From 3b65a41bc899f036ac068f7c0d01643be88c2751 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 21 Oct 2020 09:42:01 +0800 Subject: [PATCH 1/3] for clear_atomic --- inc/register/op_tiling.h | 1 + 1 file changed, 1 insertion(+) diff --git a/inc/register/op_tiling.h b/inc/register/op_tiling.h index bcd4cd5e6c..383708199a 100644 --- a/inc/register/op_tiling.h +++ b/inc/register/op_tiling.h @@ -70,6 +70,7 @@ struct OpRunInfo { uint32_t block_dim; std::vector workspaces; ByteBuffer tiling_data; + bool clear_atomic; }; -- Gitee From d3349be8849b52b63bf6bddc33cc50f147c47f91 Mon Sep 17 00:00:00 2001 From: wxl Date: Thu, 22 Oct 2020 10:46:33 +0800 Subject: [PATCH 2/3] Bugfix:origin format infer bug fix --- graph/format_refiner.cc | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/graph/format_refiner.cc b/graph/format_refiner.cc index 8dc13a9bfe..dbce70478e 100644 --- a/graph/format_refiner.cc +++ b/graph/format_refiner.cc @@ -41,6 +41,7 @@ using namespace ge; using namespace std; namespace ge { namespace { +const size_t kDimSize4d = 4; const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; const string kIsGraphInferred = "_is_graph_inferred"; thread_local RefRelations reflection_builder; @@ -414,28 +415,26 @@ graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, s GE_CHECK_NOTNULL(data_node); auto op_desc = data_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); - auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); + + auto input_desc = op_desc->MutableInputDesc(0); + auto output_desc = op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(input_desc); + GE_CHECK_NOTNULL(output_desc); + + auto curr_format = output_desc->GetOriginFormat(); if (curr_format != FORMAT_ND) { // Data format has been infered , continue continue; } - // Set format for un-infered data node - auto input_descs = op_desc->GetAllInputsDescPtr(); - auto output_descs = op_desc->GetAllOutputsDescPtr(); - - for (const auto &input_desc : input_descs) { - if (input_desc != nullptr) { - input_desc->SetOriginFormat(data_format); - input_desc->SetFormat(data_format); - } - } - for (const auto &output_desc : output_descs) { - if (output_desc != nullptr) { - output_desc->SetOriginFormat(data_format); - output_desc->SetFormat(data_format); - } + // keep data format be ND because lacking of defination when input shape num is smaller than 4 + if (input_desc->MutableShape().GetDimNum() < kDimSize4d) { + continue; } + // Set format for un-infered data node + input_desc->SetOriginFormat(data_format); + input_desc->SetFormat(data_format); + output_desc->SetOriginFormat(data_format); + output_desc->SetFormat(data_format); uninfered_data_nodes.push_back(data_node); } // Reinfer format from uninfered data nodes -- Gitee From 675b74a0b8b7e81177d6e4c61e76ee087f275496 Mon Sep 17 00:00:00 2001 From: wxl Date: Fri, 23 Oct 2020 11:02:27 +0800 Subject: [PATCH 3/3] Feature:refresh options --- graph/option/ge_local_context.cc | 20 ++++++++++++++++++++ inc/graph/ge_local_context.h | 5 +++++ 2 files changed, 25 insertions(+) diff --git a/graph/option/ge_local_context.cc b/graph/option/ge_local_context.cc index 792abb605e..3a441eb67e 100644 --- a/graph/option/ge_local_context.cc +++ b/graph/option/ge_local_context.cc @@ -57,4 +57,24 @@ void GEThreadLocalContext::SetGraphOption(map options_map) graph_options_.clear(); graph_options_ = std::move(options_map); } + +map GEThreadLocalContext::GetAllGraphOptions() const { + return graph_options_; +} + +map GEThreadLocalContext::GetAllSessionOptions() const { + return session_options_; +} + +map GEThreadLocalContext::GetAllGlobalOptions() const { + return global_options_; +} + +map GEThreadLocalContext::GetAllOptions() const { + map options_all; + options_all.insert(graph_options_.begin(), graph_options_.end()); + options_all.insert(session_options_.begin(), session_options_.end()); + options_all.insert(global_options_.begin(), global_options_.end()); + return options_all; +} } // namespace ge diff --git a/inc/graph/ge_local_context.h b/inc/graph/ge_local_context.h index 36beaa7989..a691ebde9a 100644 --- a/inc/graph/ge_local_context.h +++ b/inc/graph/ge_local_context.h @@ -32,6 +32,11 @@ class GEThreadLocalContext { void SetSessionOption(map options_map); void SetGlobalOption(map options_map); + map GetAllGraphOptions() const; + map GetAllSessionOptions() const; + map GetAllGlobalOptions() const; + map GetAllOptions() const; + private: map graph_options_; map session_options_; -- Gitee