diff --git a/graph/compute_graph.cc b/graph/compute_graph.cc index f19f2a29e78d52b7a4da2a4050f9e221c2ceb201..cf4a32635fcb33dcbf4aedcc3b1ca998d14345ee 100644 --- a/graph/compute_graph.cc +++ b/graph/compute_graph.cc @@ -731,11 +731,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, - std::vector &stack) { + std::vector &stack, bool reverse) { GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); // Record the number of non data nodes but no input nodes GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); - + std::vector out_nodes; + auto stack_push = [&reverse, &stack](std::vector& out_nodes) { + if (reverse) { + std::reverse(out_nodes.begin(), out_nodes.end()); + } + stack.insert(stack.end(), out_nodes.begin(), out_nodes.end()); + out_nodes.clear(); + }; // Only data nodes here while (!stack.empty()) { NodePtr node = stack.back(); @@ -749,16 +756,18 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, GE_CHECK_NOTNULL(peer_in_anchor); auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); + out_nodes.push_back(peer_in_anchor->GetOwnerNode()); } } + stack_push(out_nodes); for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { GE_CHECK_NOTNULL(peer_in_anchor); auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); + out_nodes.push_back(peer_in_anchor->GetOwnerNode()); } } + stack_push(out_nodes); } GE_IF_BOOL_EXEC( node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor @@ -766,9 +775,10 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, GE_CHECK_NOTNULL(peer_in_anchor); 
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); + out_nodes.push_back(peer_in_anchor->GetOwnerNode()); } - }) + } + stack_push(out_nodes);) } return GRAPH_SUCCESS; diff --git a/graph/utils/tuning_utils.cc b/graph/utils/tuning_utils.cc index 6d9ce08a2083fe966e73318b23baff6a2255e826..4a2543dd05d066af342ff445e56447c8bf2989d8 100644 --- a/graph/utils/tuning_utils.cc +++ b/graph/utils/tuning_utils.cc @@ -119,7 +119,11 @@ graphStatus TuningUtils::ConvertGraphToFile(std::vector tuning_ graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo& help_info) { GE_CHECK_NOTNULL(exe_graph); - + graphStatus ret = exe_graph->TopologicalSorting(true); + if (ret != SUCCESS) { + GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); + return ret; + } // clear graph id GELOGI("TUU:clear [%s] session_graph_id %s", exe_graph->GetName().c_str(), (AttrUtils::SetStr(*exe_graph, ATTR_NAME_SESSION_GRAPH_ID, "") ? 
"success" : "not success")); @@ -148,7 +152,7 @@ graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, } } } - graphStatus ret = exe_graph->TopologicalSorting(); + ret = exe_graph->TopologicalSorting(true); if (ret != SUCCESS) { GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); return ret; diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h index 90a4c743abdeb147e4e1f4f497590825e9b67f6c..60de380e47c0919dd30ef8161a9b74280b6be870 100644 --- a/inc/graph/compute_graph.h +++ b/inc/graph/compute_graph.h @@ -137,7 +137,11 @@ class ComputeGraph : public std::enable_shared_from_this, public A /// graphStatus UpdateOutputMapping(const std::map &output_mapping); - graphStatus TopologicalSorting(); + /// nodes like: (a) <--- (c) ---> (b) + /// nodes a and b have only one parent node, c, and a is connected to c first + /// the DFS topo order is `c, b, a` with the default `dfs_reverse=false` + /// in the same case, the user gets `c, a, b` with `dfs_reverse=true` + graphStatus TopologicalSorting(bool dfs_reverse = false); bool IsValid() const; void InValid() { is_valid_flag_ = false; } void Dump() const; @@ -249,12 +253,12 @@ class ComputeGraph : public std::enable_shared_from_this, public A private: graphStatus DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, - std::vector &stack); + std::vector &stack, bool reverse); graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::deque &stack); graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, std::map &breadth_node_map); - graphStatus TopologicalSortingGraph(); + graphStatus TopologicalSortingGraph(bool dfs_reverse); graphStatus SortNodes(std::vector &stack, std::map &mapInEdgeNum); Vistor AllGraphNodes(std::vector> &subgraphs) const; size_t GetInEdgeSize(const NodePtr &node);