From 6f6c2c80e851f915426c07d142f2ec4f9cec9d9b Mon Sep 17 00:00:00 2001 From: l00800044 Date: Tue, 23 Sep 2025 16:20:57 +0800 Subject: [PATCH] fix gefunc initialize --- .../_ge_concrete_graph/fx2ge_converter.py | 4 ++- tests/st/torchair_aclgraph_st.py | 11 ++++++ torchair/concrete_graph/session.cpp | 35 +++++++++++++------ torchair/core/torchair.cpp | 6 ++++ torchair/include/session.h | 9 ++++- 5 files changed, 52 insertions(+), 13 deletions(-) diff --git a/python/torchair/_ge_concrete_graph/fx2ge_converter.py b/python/torchair/_ge_concrete_graph/fx2ge_converter.py index 917640a57..648729339 100644 --- a/python/torchair/_ge_concrete_graph/fx2ge_converter.py +++ b/python/torchair/_ge_concrete_graph/fx2ge_converter.py @@ -637,6 +637,8 @@ class GeConcreteGraph(ConcreteGraphBase): self._all_meta_tensor_input = {} self._fx_graph = None self._has_empty_tensor = False + _, global_compile_options = self._normalize_ge_option() + initialize_graph_engine(global_compile_options) def __call__(self, *args: Any, **kwargs: Any) -> Any: """ @@ -1087,7 +1089,7 @@ class GeConcreteGraph(ConcreteGraphBase): Returns: Any: Processed result of the parsed node. - """ + """ if str(target) in ['air.scope_enter.default', 'air.scope_exit.default']: return target(*args, **kwargs, need_excute=True) all_zero_and_nosym = all([is_zero_element_tensor(t) and not get_used_syms_in_meta(t) diff --git a/tests/st/torchair_aclgraph_st.py b/tests/st/torchair_aclgraph_st.py index 3ee96d65e..b11a1467e 100644 --- a/tests/st/torchair_aclgraph_st.py +++ b/tests/st/torchair_aclgraph_st.py @@ -39,6 +39,15 @@ stub_fa = StubNpuFA() stub_fa.default = stub_npu_fa_func stub_fa.out = stub_npu_fa_func + +class StubConf: + def __init__(self): + self.allow_hf32 = 0 + pass + + +stub_conf = StubConf() + _GLOBAL_POOL_ID = 0 @@ -282,6 +291,8 @@ class StubNpu: self.synchronize = stub_synchronize self.empty_cache = stub_empty_cache self.memory_snapshot = memory_snapshot + self.matmul = stub_conf + self.conv = stub_conf self._C = Stub_C diff --git a/torchair/concrete_graph/session.cpp b/torchair/concrete_graph/session.cpp index ea7266a8a..f48756bd0 100644 --- a/torchair/concrete_graph/session.cpp +++ b/torchair/concrete_graph/session.cpp @@ -58,7 +58,7 @@ Status Session::Initialize(const std::map &options) { } if (option.first == "ge_dump_with_acl_config") { TNG_RETURN_IF_ERROR(AclDumpConfigInit(option.second)); - aclmd_initialzed_ = true; + aclmd_initialized_ = true; continue; } ge_options[option.first.c_str()] = option.second.c_str(); @@ -84,26 +84,37 @@ Status Session::Initialize(const std::map &options) { auto ret = aclrtSetDevice(device_index_); TNG_ASSERT(ret == ACL_ERROR_NONE, "ACL set device id failed, return %d", ret); + if (!get_ge_func_) { + TNG_RETURN_IF_ERROR(GetGeFunc()); + } + + initialized_ = true; + return status_; +} + +Status Session::EnsureInitialized() { + if (initialized_) { + return status_; + } + return Status::Error("Session is not initialized"); +} + +Status Session::GetGeFunc() { libge_runner_handle = dlopen("libge_runner.so", RTLD_NOW); TNG_ASSERT_NOTNULL(libge_runner_handle, "libge_runner.so dlopen failed, %s", dlerror()); + fast_load_graph_ = reinterpret_cast(dlsym(libge_runner_handle, "GeSessionLoadGraph")); fast_execute_graph_async_ = reinterpret_cast(dlsym(libge_runner_handle, "GeSessionExecuteGraphWithStreamAsync")); get_registered_ir_def_ = reinterpret_cast(dlsym(libge_runner_handle, "GetRegisteredIrDef")); + TNG_LOG(DEBUG) << "In current cann version" << ", FastLoadGraph api is " << (IsFastLoadGraphSupported() ? "supported" : "unsupported") << ", FastExecuteGraph api is " << (IsFastExecuteGraphSupported() ? "supported" : "unsupported") << ", GetRegisteredIr api is " << (IsGetRegisteredIrDefSupported() ? "supported" : "unsupported"); - initialized_ = true; - return status_; -} - -Status Session::EnsureInitialized() { - if (initialized_) { - return status_; - } - return Status::Error("Session is not initialized"); + get_ge_func_ = true; + return Status::Success(); } Status Session::Finalize() { @@ -119,7 +130,7 @@ Status Session::Finalize() { TNG_LOG(DEBUG) << "ACL synchronize device success in Finalize."; } - if (aclmd_initialzed_) { + if (aclmd_initialized_) { (void)AclDumpConfigFinalize(); } @@ -148,6 +159,8 @@ Status Session::Finalize() { auto ctx_ptr = (ctx_ret == ACL_ERROR_NONE) ? detect_context : nullptr; initialized_ = false; + get_ge_func_ = false; + aclmd_initialized_ = false; TNG_LOG(DEBUG) << "After torchair finalize, got context pointer: " << ctx_ptr << ", and the initialized flag is set to " << initialized_; diff --git a/torchair/core/torchair.cpp b/torchair/core/torchair.cpp index c58e8c8c2..78e9ef613 100644 --- a/torchair/core/torchair.cpp +++ b/torchair/core/torchair.cpp @@ -134,6 +134,12 @@ std::tuple GetGeIrD using GeOutType = std::vector>; GeOutType inputs_ge, outputs_ge, attrs_ge; IrOutType inputs, outputs, attrs; + if (!Session::GetInstance().IsGetGeFunc()) { + TNG_RAISE_IF_ERROR(Session::GetInstance().GetGeFunc()); + } + if (!Session::GetInstance().IsInitialized()) { + TNG_LOG(WARNING) << "GE has not initialized, GetRegisteredIrDef func can not get correct geir def."; + } static bool enable_get_registered_ir_def = Session::GetInstance().IsGetRegisteredIrDefSupported(); if (enable_get_registered_ir_def) { tng::Status status = Session::GetInstance().GeGetRegisteredIrDef(op_type, inputs_ge, outputs_ge, attrs_ge); diff --git a/torchair/include/session.h b/torchair/include/session.h index 7ef54b636..0851baa67 100644 --- a/torchair/include/session.h +++ b/torchair/include/session.h @@ -24,6 +24,8 @@ class Session { Status EnsureInitialized(); + Status GetGeFunc(); + Status Finalize(); Status AddGraph(uint32_t id, const ge::Graph &graph, const std::map &options); @@ -73,6 +75,10 @@ class Session { return initialized_; } + bool IsGetGeFunc() const { + return get_ge_func_; + } + Status AclDumpConfigInit(const std::string &dump_path); Status AclDumpConfigFinalize(); @@ -81,7 +87,8 @@ class Session { Session() : initialized_(false), status_(Status::Success()){}; std::mutex mu_; std::atomic_bool initialized_; - std::atomic_bool aclmd_initialzed_; + std::atomic_bool aclmd_initialized_ = false; + std::atomic_bool get_ge_func_ = false; std::atomic_bool run_with_torch_npu_ = false; Status status_; int32_t device_index_ = -1; -- Gitee