diff --git a/graph/CMakeLists.txt b/graph/CMakeLists.txt index 7dbb2ca49b1fdfc6aa8e9b0dbd5d9bdad8c71b77..16758a4dcc1b7c4e223d99a0bba8eef353f80480 100755 --- a/graph/CMakeLists.txt +++ b/graph/CMakeLists.txt @@ -34,6 +34,7 @@ set(GRAPH_SOURCE_LIST "utils/type_utils.cc" "utils/tensor_utils.cc" "utils/constant_utils.cc" + "utils/branch_exec_cond_calculator.cc" "tensor.cc" "debug/graph_debug.cc" "opsproto/opsproto_manager.cc" diff --git a/graph/debug/ge_op_types.h b/graph/debug/ge_op_types.h index 6913b7093cc545bcdbbb8d982ab105b31f7e346f..22efc2a96cb0a704cd6b9455040ce4f696a2c2f1 100644 --- a/graph/debug/ge_op_types.h +++ b/graph/debug/ge_op_types.h @@ -48,6 +48,8 @@ GE_REGISTER_OPTYPE(ENTER, "Enter"); GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); +GE_REGISTER_OPTYPE(LOOPCOND, "LoopCond"); +GE_REGISTER_OPTYPE(IDENTITY, "Identity"); GE_REGISTER_OPTYPE(CONSTANT, "Const"); GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); GE_REGISTER_OPTYPE(END, "End"); diff --git a/graph/utils/branch_exec_cond_calculator.cc b/graph/utils/branch_exec_cond_calculator.cc new file mode 100644 index 0000000000000000000000000000000000000000..7c9772b2c04c8fb35019544fd04a7d19ed2dac8e --- /dev/null +++ b/graph/utils/branch_exec_cond_calculator.cc @@ -0,0 +1,702 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/utils/branch_exec_cond_calculator.h" +#include +#include +#include +#include "graph/utils/node_utils.h" +#include "graph/debug/ge_op_types.h" +#include "graph/debug/ge_attr_define.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +graphStatus LogicOperatorUnit::Init(size_t length, size_t index, bool value) { + if (length == 0) { + REPORT_INNER_ERROR("E19999", "Init failed, length=%zu", length); + GELOGE(GRAPH_FAILED, "[Check][Param] Init failed, length=%zu", length); + return GRAPH_FAILED; + } + mask_ = std::move(std::vector(length, false)); + value_ = std::move(std::vector(length, false)); + valid_flag_ = true; + + if (index >= mask_.size()) { + REPORT_INNER_ERROR("E19999", "SetValue failed, length=%zu, index=%zu", mask_.size(), index); + GELOGE(GRAPH_FAILED, "[Check][Param] SetValue failed, length=%zu, index=%zu", mask_.size(), index); + return GRAPH_FAILED; + } + mask_[index] = true; + value_[index] = value; + + return GRAPH_SUCCESS; +} + +LogicOperatorUnit LogicOperatorUnit::And(const LogicOperatorUnit &left, const LogicOperatorUnit &right) { + if (!(left.IsValid() && right.IsValid())) { + GELOGW("[Check][Param] Input is invalid"); + return LogicOperatorUnit(false); + } else if (!CheckSupport(left, right)) { + REPORT_CALL_ERROR("E19999", "CheckSupport failed"); + GELOGE(GRAPH_FAILED, "[And][Check] CheckSupport failed"); + return LogicOperatorUnit(false); + } + size_t length = left.mask_.size(); + std::vector mask(length, false); + std::vector value(length, false); + for (size_t i = 0; i < length; i++) { + if (left.mask_[i] && right.mask_[i]) { + if (left.value_[i] == right.value_[i]) { + mask[i] = true; + value[i] = left.value_[i]; + } else { + return LogicOperatorUnit(false); + } + } else if (left.mask_[i]) { + mask[i] = true; + value[i] = left.value_[i]; + } else if (right.mask_[i]) { + mask[i] = true; + value[i] = right.value_[i]; + } + } + return LogicOperatorUnit(mask, value); +} + +LogicOperatorUnit LogicOperatorUnit::Or(const LogicOperatorUnit &left, const LogicOperatorUnit &right) { + if (!(left.IsValid() && right.IsValid())) { + GELOGW("[Check][Param] Input is invalid"); + return LogicOperatorUnit(false); + } else if (!left.IsValid() || !right.IsValid()) { + return left.IsValid() ? left : right; + } else if (!CheckSupport(left, right)) { + REPORT_CALL_ERROR("E19999", "CheckSupport failed"); + GELOGE(GRAPH_FAILED, "[Or][Check] CheckSupport failed"); + return LogicOperatorUnit(false); + } + size_t length = left.mask_.size(); + std::vector mask(length, false); + std::vector value(length, false); + for (size_t i = 0; i < length; i++) { + if (left.mask_[i] && right.mask_[i]) { + if (left.value_[i] == right.value_[i]) { + mask[i] = true; + value[i] = left.value_[i]; + } + } + } + return LogicOperatorUnit(mask, value); +} + +bool LogicOperatorUnit::CheckSupport(const LogicOperatorUnit &left, const LogicOperatorUnit &right) { + return ((left.mask_.size() == left.value_.size()) && (right.mask_.size() == right.value_.size()) && + (left.mask_.size() == right.mask_.size())); +} + +bool LogicOperatorUnit::IsOrthogonal(const LogicOperatorUnit &left, const LogicOperatorUnit &right) { + if (!CheckSupport(left, right)) { + return false; + } + size_t num = left.mask_.size(); + for (size_t i = 0; i < num; i++) { + if (left.mask_[i] && right.mask_[i]) { + return false; + } + } + return true; +} + +bool LogicOperatorUnit::operator==(const LogicOperatorUnit &unit) const { + return (this->mask_ == unit.mask_) && (this->value_ == unit.value_); +} + +bool LogicOperatorUnit::IsEmpty() const { + if (!IsValid()) { + return false; + } + return std::none_of(mask_.begin(), mask_.end(), [](bool i) { return i; }); +} + +bool LogicOperatorUnit::IsSubUnit(const LogicOperatorUnit &unit) const { + if (IsEmpty()) { + return unit.IsValid(); + } + if (!CheckSupport(*this, unit)) { + return false; + } + size_t num = mask_.size(); + for (size_t i = 0; i < num; i++) { + if (mask_[i]) { + if (!unit.mask_[i]) { + return false; + } + if (value_[i] != unit.value_[i]) { + return false; + } + } + } + return true; +} + +std::string LogicOperatorUnit::String() const { + if (!valid_flag_) { + return "Invalid"; + } + std::stringstream ss; + for (size_t i = 0; i < mask_.size(); i++) { + if (mask_[i]) { + if (value_[i]) { + ss << "T"; + } else { + ss << "F"; + } + } else { + ss << "0"; + } + } + return ss.str(); +} + +LogicOperatorItem LogicOperatorItem::And(const LogicOperatorItem &left, const LogicOperatorItem &right) { + if (!(left.IsValid() && right.IsValid())) { + GELOGW("[Check][Param] Input is invalid"); + return LogicOperatorItem(false); + } + if (left.IsEmpty() || right.IsEmpty()) { + return left.IsEmpty() ? right : left; + } + if (!CheckSupport(left, right)) { + REPORT_CALL_ERROR("E19999", "CheckSupport failed"); + GELOGE(GRAPH_FAILED, "[And][Check] CheckSupport failed"); + return LogicOperatorItem(false); + } + std::set units; + for (const auto &left_unit : left.units_) { + for (const auto &right_unit : right.units_) { + const auto &tmp = LogicOperatorUnit::And(left_unit, right_unit); + if (!(tmp.IsValid() && tmp.IsEmpty())) { + units.insert(tmp); + } + } + } + + return LogicOperatorItem(units).Simplify(); +} + +LogicOperatorItem LogicOperatorItem::Or(const LogicOperatorItem &left, const LogicOperatorItem &right) { + if (!(left.IsValid() || right.IsValid())) { + GELOGW("[Check][Param] Input is invalid"); + return LogicOperatorItem(false); + } else if (!left.IsValid()) { + return right; + } else if (!right.IsValid()) { + return left; + } + if (left.IsEmpty() || right.IsEmpty()) { + return LogicOperatorItem(); + } + if (!CheckSupport(left, right)) { + REPORT_CALL_ERROR("E19999", "CheckSupport failed"); + GELOGE(GRAPH_FAILED, "[Or][Check] CheckSupport failed"); + return LogicOperatorItem(false); + } + std::set units; + for (const auto &left_unit : left.units_) { + for (const auto &right_unit : right.units_) { + if (LogicOperatorUnit::IsOrthogonal(left_unit, right_unit)) { + units.insert(left_unit); + units.insert(right_unit); + continue; + } + const auto &tmp = LogicOperatorUnit::Or(left_unit, right_unit); + if (!(tmp.IsValid() && tmp.IsEmpty())) { + units.insert(tmp); + } + } + } + + return LogicOperatorItem(units).Simplify(); +} + +bool LogicOperatorItem::CheckSupport(const LogicOperatorItem &left, const LogicOperatorItem &right) { + if (left.units_.empty() || right.units_.empty()) { + REPORT_INNER_ERROR("E19999", "Inputs is empty, left.size=%zu, right.size=%zu", left.units_.size(), + right.units_.size()); + GELOGE(GRAPH_FAILED, "[Check][Param] Inputs is empty, left.size=%zu, right.size=%zu", left.units_.size(), right.units_.size()); + return false; + } + const auto &left_unit = left.units_.begin(); + for (const auto &unit : left.units_) { + if (!LogicOperatorUnit::CheckSupport(*left_unit, unit)) { + REPORT_CALL_ERROR("E19999", "Length mismatch, mask1_size=%zu, value1_size=%zu, mask2_size=%zu, value2_size=%zu", + left_unit->mask_.size(), left_unit->value_.size(), unit.mask_.size(), unit.value_.size()); + GELOGE(GRAPH_FAILED, + "[Check][Support] Length mismatch, mask1_size=%zu, value1_size=%zu, mask2_size=%zu, value2_size=%zu", + left_unit->mask_.size(), left_unit->value_.size(), unit.mask_.size(), unit.value_.size()); + return false; + } + } + + const auto &right_unit = right.units_.begin(); + for (const auto &unit : right.units_) { + if (!LogicOperatorUnit::CheckSupport(*right_unit, unit)) { + REPORT_CALL_ERROR("E19999", "Length mismatch, mask1_size=%zu, value1_size=%zu, mask2_size=%zu, value2_size=%zu", + right_unit->mask_.size(), right_unit->value_.size(), unit.mask_.size(), unit.value_.size()); + GELOGE(GRAPH_FAILED, + "[Check][Support] Length mismatch, mask1_size=%zu, value1_size=%zu, mask2_size=%zu, value2_size=%zu", + right_unit->mask_.size(), right_unit->value_.size(), unit.mask_.size(), unit.value_.size()); + return false; + } + } + + if (!LogicOperatorUnit::CheckSupport(*left_unit, *right_unit)) { + REPORT_CALL_ERROR("E19999", "Length mismatch, mask1_size=%zu, value1_size=%zu, mask2_size=%zu, value2_size=%zu", + left_unit->mask_.size(), left_unit->value_.size(), right_unit->mask_.size(), + right_unit->value_.size()); + GELOGE(GRAPH_FAILED, + "[Check][Support] Length mismatch, mask1_size=%zu, value1_size=%zu, mask2_size=%zu, value2_size=%zu", + left_unit->mask_.size(), left_unit->value_.size(), right_unit->mask_.size(), right_unit->value_.size()); + return false; + } + + return true; +} + +const LogicOperatorItem &LogicOperatorItem::Simplify() { + if (!valid_flag_) { + units_.clear(); + } + if (units_.empty()) { + return *this; + } + + std::vector tmp_units; + for (const auto &unit : units_) { + if (unit.IsValid()) { + tmp_units.emplace_back(unit); + } + } + + if (tmp_units.empty()) { + valid_flag_ = false; + } else { + for (const auto &unit : tmp_units) { + if (unit.IsEmpty()) { + tmp_units.clear(); + break; + } + } + } + + units_.clear(); + for (const auto &left_unit : tmp_units) { + bool redundant_flag = false; + for (const auto &right_unit : tmp_units) { + if (left_unit == right_unit) { + continue; + } + if (right_unit.IsSubUnit(left_unit)) { + redundant_flag = true; + break; + } + } + if (!redundant_flag) { + units_.insert(left_unit); + } + } + + return *this; +} + +std::string LogicOperatorItem::String() const { + if (!valid_flag_) { + return "Invalid"; + } + std::stringstream ss; + for (const auto &unit : units_) { + ss << unit.String() << ","; + } + return ss.str(); +} + +LogicOperatorItem ExecCondLogicOperator::Calculate(const std::vector &items, + const OperationType &operation_type) { + if (items.empty()) { + GELOGW("[Check][Param] Input compute items is empty"); + return LogicOperatorItem(true); + } + LogicOperatorItem res = items[0]; + for (size_t i = 1; i < items.size(); i++) { + if (operation_type == AND) { + res = LogicOperatorItem::And(res, items[i]); + } else if (operation_type == OR) { + res = LogicOperatorItem::Or(res, items[i]); + } else { + REPORT_INNER_ERROR("E19999", "operation type %u not support, only support AND, OR", operation_type); + GELOGE(GRAPH_FAILED, "[Check][Param] operation type %u not support, only support AND, OR", operation_type); + return LogicOperatorItem(false); + } + } + return res; +} + +graphStatus BranchExecCondCalculator::Calculate(){ + if (graph_ == nullptr) { + GELOGW("[Check][Param] Input graph is nullptr."); + return GRAPH_SUCCESS; + } + + std::map switch_to_cond_input; + std::map cond_to_index; + bool loop_branch_flag = false; + if (FindSwitchCondInput(switch_to_cond_input, cond_to_index, loop_branch_flag) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "GetSwitchCondInput failed, graph=%s", graph_->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Calculate] GetSwitchCondInput failed, graph=%s", + graph_->GetName().c_str()); + return GRAPH_FAILED; + } + if (cond_to_index.empty()) { + GELOGI("graph %s does not have v1 cond/loop branch", graph_->GetName().c_str()); + return GRAPH_SUCCESS; + } + std::map next_exec_cond; + if (loop_branch_flag) { + if (FindLoopExecCond(cond_to_index, next_exec_cond) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "FindNextIterationExecCond failed, graph=%s", graph_->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Calculate] FindNextIterationExecCond failed, graph=%s", + graph_->GetName().c_str()); + return GRAPH_FAILED; + } + } + + if (FindAllExecCond(switch_to_cond_input, cond_to_index, next_exec_cond) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "FindAllNodeExecCond failed, graph=%s", graph_->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Calculate] FindAllNodeExecCond failed, graph=%s", + graph_->GetName().c_str()); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +void BranchExecCondCalculator::GetBranchExecCondLabel(std::map &node_exec_cond_label) const { + std::map item_visited; + for (const auto &pair : node_exec_cond_) { + const auto &node = pair.first; + const auto &item = pair.second; + if (item_visited.count(item) > 0) { + node_exec_cond_label[node] = item_visited[item]; + } else { + size_t visited_num = item_visited.size(); + node_exec_cond_label[node] = visited_num; + item_visited[item] = visited_num; + } + } + + for (const auto &item : item_visited) { + GELOGD("[Debug print for GetBranchExecCondLabel] label=%zu, cond info=%s", item.second, item.first.String().c_str()); + } + for (const auto &item : node_exec_cond_label) { + GELOGD("[Debug print for GetBranchExecCondLabel] node=%s, label=%zu", item.first->GetName().c_str(), item.second); + } +} + +graphStatus BranchExecCondCalculator::FindSwitchCondInput(std::map &switch_to_cond_input, + std::map &cond_to_index, + bool &loop_branch_flag) { + for (const auto &node : graph_->GetDirectNode()) { + const std::string &type = NodeUtils::GetNodeType(node); + if ((type != SWITCH) && (type != REFSWITCH)) { + continue; + } + const auto &switch_cond_input = NodeUtils::GetOriginalSwitchCondInput(node); + if (switch_cond_input == nullptr) { + REPORT_INNER_ERROR("E19999", "Get original cond input for switch node %s failed", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Get original cond input for switch node %s failed", + node->GetName().c_str()); + return GRAPH_FAILED; + } + if (!loop_branch_flag && switch_cond_input->GetOwnerNode()->GetType() == LOOPCOND) { + loop_branch_flag = true; + } + switch_to_cond_input[node] = switch_cond_input; + if (cond_to_index.count(switch_cond_input) == 0) { + size_t cond_num = cond_to_index.size(); + GELOGD("[Debug print for GetBranchExecCond] index=%zu, cond_node=%s, out_index=%d", + cond_num, switch_cond_input->GetOwnerNode()->GetName().c_str(), switch_cond_input->GetIdx()); + cond_to_index[switch_cond_input] = cond_num; + } + } + return GRAPH_SUCCESS; +} + +graphStatus BranchExecCondCalculator::FindLoopExecCond(const std::map &cond_to_index, + std::map &next_exec_cond) { + std::map> groups_to_enter_nodes; + if (GroupEnterNodes(groups_to_enter_nodes) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Group enter nodes failed"); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Group enter nodes failed"); + return GRAPH_FAILED; + } + + std::list loop_groups; + if (FindLoopGroup(groups_to_enter_nodes, loop_groups) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Find loop group failed"); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Find loop group failed"); + return GRAPH_FAILED; + } + + for (const auto &loop_group : loop_groups) { + // LoopCond node has and only has one output + const auto &loop_cond = loop_group.loop_cond->GetOutDataAnchor(0); + const auto &iter = cond_to_index.find(loop_cond); + if (iter == cond_to_index.end()) { + REPORT_INNER_ERROR("E19999", "Find loop cond failed, LoopCond node=%s", loop_group.loop_cond->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Find loop cond failed, LoopCond node=%s", + loop_group.loop_cond->GetName().c_str()); + return GRAPH_FAILED; + } + size_t index = iter->second; + LogicOperatorUnit unit; + size_t branch_cond_num = cond_to_index.size(); + GELOGD("Init LogicOperatorUnit for LoopCond %s, cond_num=%zu, index=%zu", + loop_group.loop_cond->GetName().c_str(), branch_cond_num, index); + if (unit.Init(branch_cond_num, index, true) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "LogicOperatorUnit Init failed, LoopCond node=%s, cond_num=%zu, index=%zu", + loop_group.loop_cond->GetName().c_str(), branch_cond_num, index); + GELOGE(GRAPH_FAILED, + "[GetBranchExecCond][Check] LogicOperatorUnit Init failed, LoopCond node=%s, cond_num=%zu, index=%zu", + loop_group.loop_cond->GetName().c_str(), branch_cond_num, index); + return GRAPH_FAILED; + } + + for (const auto &next_node : loop_group.next_iteration_nodes) { + next_exec_cond[next_node] = unit; + } + } + + for (const auto &item : next_exec_cond) { + GELOGD("[Debug print for GetBranchExecCond] node=%s, label=%s", item.first->GetName().c_str(), + item.second.String().c_str()); + } + return GRAPH_SUCCESS; +} + +graphStatus BranchExecCondCalculator::FindAllExecCond(const std::map &switch_to_cond_input, + const std::map &cond_to_index, + const std::map &next_exec_cond) { + for (const auto &node : graph_->GetDirectNode()) { + std::vector items; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + const auto &in_node = peer_out_anchor->GetOwnerNode(); + if (in_node == nullptr) { + REPORT_INNER_ERROR("E19999", "in node is null, cur node=%s", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] in node is null, cur node=%s", node->GetName().c_str()); + return GRAPH_FAILED; + } + const std::string &in_type = NodeUtils::GetNodeType(in_node); + if ((in_type == NEXTITERATION) || (in_type == NEXTITERATION)) { + const auto &iter0 = next_exec_cond.find(in_node); + if (iter0 == next_exec_cond.end()) { + REPORT_INNER_ERROR("E19999", "Find NextIteration node %s exec_cond failed", in_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Find NextIteration node %s exec_cond failed", + in_node->GetName().c_str()); + return GRAPH_FAILED; + } + items.emplace_back(LogicOperatorItem({ iter0->second })); + } else { + const auto &iter1 = node_exec_cond_.find(in_node); + if (iter1 == node_exec_cond_.end()) { + REPORT_INNER_ERROR("E19999", "Find NextIteration node %s exec_cond failed", in_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Find NextIteration node %s exec_cond failed", + in_node->GetName().c_str()); + return GRAPH_FAILED; + } + if ((in_type == SWITCH) || (in_type == REFSWITCH)) { + const auto &switch_iter = switch_to_cond_input.find(in_node); + if (switch_iter == switch_to_cond_input.end()) { + REPORT_INNER_ERROR("E19999", "Find Switch node %s cond_input failed", in_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Find Switch node %s cond_input failed", + in_node->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &cond_iter = cond_to_index.find(switch_iter->second); + if (cond_iter == cond_to_index.end()) { + REPORT_INNER_ERROR("E19999", "Find Switch node %s failed", in_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Find Switch node %s failed", in_node->GetName().c_str()); + return GRAPH_FAILED; + } + LogicOperatorUnit unit; + size_t branch_cond_num = cond_to_index.size(); + size_t index = cond_iter->second; + // switch output: 0 for false, 1 for true + bool value = (peer_out_anchor->GetIdx() == 1); + GELOGD("Init LogicOperatorUnit for Switch output node %s, cond_num=%zu, index=%zu, value=%u", + node->GetName().c_str(), branch_cond_num, index, value); + if (unit.Init(branch_cond_num, index, value) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "LogicOperatorUnit Init failed, switch node=%s, cond_num=%zu, index=%zu, value=%u", + in_node->GetName().c_str(), branch_cond_num, index, value); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] LogicOperatorUnit Init failed, switch node=%s, " + "cond_num=%zu, index=%zu, value=%u", + in_node->GetName().c_str(), branch_cond_num, index, value); + return GRAPH_FAILED; + } + const auto &input_item = ExecCondLogicOperator::Calculate({ iter1->second, LogicOperatorItem({unit}) }, AND); + if (!input_item.IsValid()) { + GELOGW("[GetBranchExecCond][Check] exec_cond of node %s is invalid", node->GetName().c_str()); + } + items.emplace_back(input_item); + } else { + items.emplace_back(iter1->second); + } + } + } + + for (const auto &in_ctrl_node : node->GetInControlNodes()) { + const auto &iter2 = node_exec_cond_.find(in_ctrl_node); + if (iter2 == node_exec_cond_.end()) { + REPORT_INNER_ERROR("E19999", "Find in ctrl node %s exec_cond failed", in_ctrl_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Find in ctrl node %s exec_cond failed", + in_ctrl_node->GetName().c_str()); + return GRAPH_FAILED; + } + items.emplace_back(iter2->second); + } + + const std::string &type = NodeUtils::GetNodeType(node); + OperationType operation_type = AND; + if ((type == MERGE) || (type == REFMERGE)) { + operation_type = OR; + } + + const auto &res = ExecCondLogicOperator::Calculate(items, operation_type); + if (!res.IsValid()) { + GELOGW("[GetBranchExecCond][Check] exec_cond of node %s is invalid", node->GetName().c_str()); + } + + GELOGD("[Debug print for GetBranchExecCond] node=%s, exec_cond=%s", node->GetName().c_str(), res.String().c_str()); + node_exec_cond_[node] = res; + } + return GRAPH_SUCCESS; +} + +graphStatus BranchExecCondCalculator::GroupEnterNodes(std::map> &groups_to_enter_nodes) { + for (const auto &node : graph_->GetDirectNode()) { + if ((node->GetType() != ENTER) && (node->GetType() != REFENTER)) { + continue; + } + // op_desc of node should not be null + OpDescPtr enter_desc = node->GetOpDesc(); + std::string frame_name; + if (!AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { + REPORT_CALL_ERROR("E19999", "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", + enter_desc->GetName().c_str()); + return GRAPH_FAILED; + } + + string batch_label; + if (AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { + frame_name.append("_").append(batch_label); + } + groups_to_enter_nodes[frame_name].emplace_back(node); + } + + return GRAPH_SUCCESS; +} + +graphStatus BranchExecCondCalculator::FindLoopGroup( + const std::map> &groups_to_enter_nodes, std::list &loop_groups) { + for (const auto &group_to_enter_nodes : groups_to_enter_nodes) { + std::list loop_switch_nodes; + LoopGroup group; + FindSwitchAndNextIteration(group_to_enter_nodes.second, loop_switch_nodes, group.next_iteration_nodes); + NodePtr loop_cond_node = nullptr; + for (const auto &loop_switch : loop_switch_nodes) { + // pred input index of switch node is 1 + const auto &pred_input_node = NodeUtils::GetInDataNodeByIndex(*loop_switch, 1); + if (pred_input_node == nullptr) { + REPORT_INNER_ERROR("E19999", "peer input node of switch node %s is null", loop_switch->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] peer input node of switch node %s is null", + loop_switch->GetName().c_str()); + return GRAPH_FAILED; + } + if (loop_cond_node == nullptr) { + loop_cond_node = pred_input_node; + } else if (loop_cond_node != pred_input_node) { + REPORT_INNER_ERROR("E19999", + "peer input node of switch nodes in the same loop group is different, pred1=%s, pred2=%s", + loop_cond_node->GetName().c_str(), pred_input_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] peer input node of switch nodes in the same loop group is " + "different, pred1=%s, pred2=%s", + loop_cond_node->GetName().c_str(), pred_input_node->GetName().c_str()); + return GRAPH_FAILED; + } + } + if (loop_cond_node == nullptr) { + REPORT_INNER_ERROR("E19999", "LoopCond node for loop group %s is null", group_to_enter_nodes.first.c_str()); + GELOGE(GRAPH_FAILED, "[GetBranchExecCond][Check] LoopCond node for loop group %s is null", + group_to_enter_nodes.first.c_str()); + return GRAPH_FAILED; + } + group.loop_cond = loop_cond_node; + loop_groups.emplace_back(group); + } + + return GRAPH_SUCCESS; +} + +void BranchExecCondCalculator::FindSwitchAndNextIteration(const std::list &enter_nodes, + std::list &loop_switch_nodes, + std::list &next_iteration_nodes) { + /// + /// find v1 loop group structure like: + /// Exit Node + /// \F T/ | + /// Switch | + /// | | + /// Merge | + /// / \ | + /// Enter NextIteration + /// + for (const auto &enter_node : enter_nodes) { + for (const auto &enter_out_node : enter_node->GetOutDataNodes()) { + const std::string &enter_out_type = NodeUtils::GetNodeType(enter_out_node); + if ((enter_out_type != MERGE) && (enter_out_type != REFMERGE)) { + continue; + } + for (const auto &out_node : enter_out_node->GetOutDataNodes()) { + const std::string &out_type = NodeUtils::GetNodeType(out_node); + if ((out_type != SWITCH) && (out_type != REFSWITCH)) { + continue; + } + loop_switch_nodes.emplace_back(out_node); + } + for (const auto &in_node : enter_out_node->GetInDataNodes()) { + const std::string &in_type = NodeUtils::GetNodeType(in_node); + if ((in_type != NEXTITERATION) && (in_type != REFNEXTITERATION)) { + continue; + } + next_iteration_nodes.emplace_back(in_node); + } + } + } +} +} // namespace ge diff --git a/graph/utils/graph_utils.cc b/graph/utils/graph_utils.cc index d95aab5af6aaad35f8445a238f28530ce398f9ef..7c0007e30cc36c95f7d95a31c937e465818e7533 100644 --- a/graph/utils/graph_utils.cc +++ b/graph/utils/graph_utils.cc @@ -2160,7 +2160,7 @@ ComputeGraphPtr GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std:: /// Copy tensor attribute to new node. /// @param [in] dst_node: cloned node. /// @param [in] src_node: original node. -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) { if (dst_desc == nullptr) { @@ -2204,7 +2204,7 @@ graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr /// @param [in] node: original node. /// @param [in] prefix: node name prefix of new node. /// @param [in] all_nodes: all nodes in new graph. -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix, const std::unordered_map &all_nodes) { @@ -2282,7 +2282,7 @@ graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &pref /// @param [in] graph /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, std::map> &symbol_to_anchors, @@ -2335,7 +2335,7 @@ NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::stri /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::HandleInAnchorMapping(const ComputeGraphPtr &graph, const NodePtr &node, std::map> &symbol_to_anchors, @@ -2382,7 +2382,7 @@ graphStatus GraphUtils::HandleInAnchorMapping(const ComputeGraphPtr &graph, cons /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, std::map> &symbol_to_anchors, @@ -2422,7 +2422,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, std::map> &symbol_to_anchors, @@ -2460,7 +2460,7 @@ graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, std::map> &symbol_to_anchors, @@ -2533,7 +2533,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, std::map> &symbol_to_anchors, @@ -2583,7 +2583,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol /// @param [out] symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, std::map> &symbol_to_anchors, @@ -2635,7 +2635,7 @@ graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, /// @param [in] exist_node_info /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS +/// @return success: GRAPH_SUCCESS /// graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, std::map> &symbol_to_anchors, @@ -3289,6 +3289,18 @@ graphStatus GraphUtils::MergeNetOutputNode(const ComputeGraphPtr &graph) { return GRAPH_SUCCESS; } +graphStatus GraphUtils::GetBranchExecCondLabel(const ComputeGraphPtr &graph, + std::map &node_exec_cond_label) { + BranchExecCondCalculator calculator(graph); + if (calculator.Calculate() != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Get branch exec cond label failed"); + GELOGE(GRAPH_FAILED, "[GetBranchExecCondLabel][Check] Get branch exec cond label failed"); + return GRAPH_FAILED; + } + calculator.GetBranchExecCondLabel(node_exec_cond_label); + return GRAPH_SUCCESS; +} + /// /// @brief Add node to graph /// @param [in] op_desc diff --git a/graph/utils/node_utils.cc b/graph/utils/node_utils.cc index 8235f0f4a50ede929fb489aa6206a17889ae4232..478f89504f86777aad34968ab6310ea183a1a1c5 100644 --- a/graph/utils/node_utils.cc +++ b/graph/utils/node_utils.cc @@ -1369,4 +1369,53 @@ graphStatus NodeUtils::UpdateOutputOriginalShapeAndShape(const Node &node, uint3 output_desc->SetOriginShape(shape); return GRAPH_SUCCESS; } + +OutDataAnchorPtr NodeUtils::GetOriginalSwitchCondInput(const NodePtr &switch_node) { + std::string type = NodeUtils::GetNodeType(switch_node); + if ((type != SWITCH) && (type != REFSWITCH)) { + REPORT_INNER_ERROR("E19999", "Input node is not switch, node=%s, type=%s", switch_node->GetName().c_str(), + type.c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] Input node is not switch, node=%s, type=%s", + switch_node->GetName().c_str(), type.c_str()); + return nullptr; + } + // pred input index of switch node is 1 + InDataAnchorPtr in_cond_anchor = switch_node->GetInDataAnchor(1); + OutDataAnchorPtr peer_cond_anchor = in_cond_anchor->GetPeerOutAnchor(); + if (peer_cond_anchor == nullptr) { + REPORT_INNER_ERROR("E19999", "Get peer out anchor for switch pred input failed, switch_node=%s", + switch_node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] Get peer out anchor for switch pred input failed, switch_node=%s", + switch_node->GetName().c_str()); + return nullptr; + } + NodePtr node = peer_cond_anchor->GetOwnerNode(); + while (true) { + if (node == nullptr) { + REPORT_INNER_ERROR("E19999", "Node is null"); + GELOGE(GRAPH_FAILED, "[Check][Param] Node is null"); + return nullptr; + } + type = NodeUtils::GetNodeType(node); + if ((type == IDENTITY) || (type == SWITCH) || (type == REFSWITCH)) { + // Identity node has and only has one output, data input index of switch node is 0 + in_cond_anchor = node->GetInDataAnchor(0); + } else { + break; + } + if (in_cond_anchor == nullptr) { + REPORT_INNER_ERROR("E19999", "in cond anchor of node %s is null", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] in cond anchor of node %s is null", node->GetName().c_str()); + return nullptr; + } + peer_cond_anchor = in_cond_anchor->GetPeerOutAnchor(); + if (peer_cond_anchor == nullptr) { + REPORT_INNER_ERROR("E19999", "peer cond anchor of node %s is null", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "[Check][Param] peer cond anchor of node %s is null", node->GetName().c_str()); + return nullptr; + } + node = peer_cond_anchor->GetOwnerNode(); + } + return peer_cond_anchor; +} } // namespace ge diff --git a/inc/graph/utils/branch_exec_cond_calculator.h b/inc/graph/utils/branch_exec_cond_calculator.h new file mode 100644 index 0000000000000000000000000000000000000000..46011767bbc0ec9080cb731a9b8b31e488aed74e --- /dev/null +++ b/inc/graph/utils/branch_exec_cond_calculator.h @@ -0,0 +1,161 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_BRANCH_EXEC_COND_LOGIC_CALCULATOR_H_ +#define INC_GRAPH_UTILS_BRANCH_EXEC_COND_LOGIC_CALCULATOR_H_ + +#include "graph/compute_graph.h" + +/************************************************************************************************** + * + * 1. Node1 ---> Node2 + * {A} {A} + * + * 2. Node1 \ + * {A} Node3 + * Node2 / {A&&B} + * {B} + * + * 3. F/ FalseBranch + * Pred ---> Switch {A&&} + * {A} {A} T\ TrueBranch + * {A&&} + * + * 4. Node1 \ + * {A} Merge + * Node2 / {A||B} + * {B} + * + * ============================== V1 Cond Branch ================================================== + * + * F/ FalseBranch \ + * Pred ---> Switch {A&&} Merge + * {A} {A} T\ TrueBranch / {A} + * {A&&} + * + * ============================== V1 Loop Branch ================================================== + * + * LoopCond Exit + * {A||} \ F/ {} + * Enter -------- Switch ------ + * {A} \ / {A||} T\ + * -------------- Merge BodyNode + * / {A||} / {} + * NextIteration <------------------------------------------------------| + * {} + * +**************************************************************************************************/ + +namespace ge { +enum OperationType { + AND, + OR +}; + +class LogicOperatorUnit { + public: + explicit LogicOperatorUnit(bool valid_flag = true) : valid_flag_(valid_flag) {} + LogicOperatorUnit(std::vector mask, std::vector value) + : mask_(std::move(mask)), value_(std::move(value)), valid_flag_(true) {} + ~LogicOperatorUnit() = default; + graphStatus Init(size_t length, size_t index, bool value); + + static LogicOperatorUnit And(const LogicOperatorUnit &left, const LogicOperatorUnit &right); + static LogicOperatorUnit Or(const LogicOperatorUnit &left, const LogicOperatorUnit &right); + static bool CheckSupport(const LogicOperatorUnit &left, const LogicOperatorUnit &right); + static bool IsOrthogonal(const LogicOperatorUnit &left, const LogicOperatorUnit &right); + bool operator==(const LogicOperatorUnit &unit) const; + bool operator<(const LogicOperatorUnit &unit) const { + return String() < unit.String(); + } + bool IsEmpty() const; + bool IsValid() const { return valid_flag_; } + bool IsSubUnit(const LogicOperatorUnit &unit) const; + std::string String() const; + + private: + friend class LogicOperatorItem; + std::vector mask_; + std::vector value_; + bool valid_flag_; +}; + +class LogicOperatorItem { + public: + explicit LogicOperatorItem(bool valid_flag = true) : valid_flag_(valid_flag) {} + explicit LogicOperatorItem(std::set units) + : units_(std::move(units)), valid_flag_(true) {} + ~LogicOperatorItem() = default; + static LogicOperatorItem And(const LogicOperatorItem &left, const LogicOperatorItem &right); + static LogicOperatorItem Or(const LogicOperatorItem &left, const LogicOperatorItem &right); + static bool CheckSupport(const LogicOperatorItem &left, const LogicOperatorItem &right); + bool operator<(const LogicOperatorItem &item) const { + return String() < item.String(); + } + const LogicOperatorItem &Simplify(); + bool IsEmpty() const { return valid_flag_ && units_.empty(); } + bool IsValid() const { return valid_flag_; } + std::string String() const; + + private: + std::set units_; + bool valid_flag_; +}; + +class ExecCondLogicOperator { + public: + static LogicOperatorItem Calculate(const std::vector &items, + const OperationType &operation_type); +}; + +class BranchExecCondCalculator { + struct LoopGroup { + NodePtr loop_cond; // LoopCond node + std::list next_iteration_nodes; // NextIteration nodes + }; + + public: + explicit BranchExecCondCalculator(ComputeGraphPtr graph) : graph_(std::move(graph)) {} + graphStatus Calculate(); + void GetBranchExecCondLabel(std::map &node_exec_cond_label) const; + const std::map &GetBranchExecCond() const { return node_exec_cond_; } + + private: + graphStatus FindSwitchCondInput(std::map &switch_to_cond_input, + std::map &cond_to_index, + bool &loop_branch_flag); + + graphStatus FindLoopExecCond(const std::map &cond_to_index, + std::map &next_exec_cond); + + graphStatus FindAllExecCond(const std::map &switch_to_cond_input, + const std::map &cond_to_index, + const std::map &next_exec_cond); + + graphStatus GroupEnterNodes(std::map> &groups_to_enter_nodes); + + static graphStatus FindLoopGroup(const std::map> &groups_to_enter_nodes, + std::list &loop_groups); + + static void FindSwitchAndNextIteration(const std::list &enter_nodes, + std::list &loop_switch_nodes, + std::list &next_iteration_nodes); + + ComputeGraphPtr graph_; + std::map node_exec_cond_; +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_BRANCH_EXEC_COND_LOGIC_CALCULATOR_H_ diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h index 7a6672dac88b8cc233e82e988527414b84007bff..afed675664831ff45ff8bba958c56602edad07b8 100644 --- a/inc/graph/utils/graph_utils.h +++ b/inc/graph/utils/graph_utils.h @@ -31,6 +31,7 @@ #include "graph/model.h" #include "graph/node.h" #include "graph/utils/anchor_utils.h" +#include "graph/utils/branch_exec_cond_calculator.h" #define GE_DUMP(compute_graph, name) \ do { \ @@ -382,13 +383,16 @@ class GraphUtils { static graphStatus UnfoldSubgraph(const ComputeGraphPtr &graph, const std::function &filter); + static graphStatus GetBranchExecCondLabel(const ComputeGraphPtr &graph, + std::map &node_exec_cond_label); + private: /// /// Get reference-mapping for in_data_anchors of node /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus HandleInAnchorMapping(const ComputeGraphPtr &graph, const NodePtr &node, std::map> &symbol_to_anchors, @@ -399,7 +403,7 @@ class GraphUtils { /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus HandleOutAnchorMapping(const NodePtr &node, std::map> &symbol_to_anchors, @@ -410,7 +414,7 @@ class GraphUtils { /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus HandleSubgraphInput(const NodePtr &node, std::map> &symbol_to_anchors, @@ -421,7 +425,7 @@ class GraphUtils { /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus HandleMergeInput(const NodePtr &node, std::map> &symbol_to_anchors, @@ -432,7 +436,7 @@ class GraphUtils { /// @param [in] node /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus HandleSubgraphOutput(const NodePtr &node, std::map> &symbol_to_anchors, @@ -443,7 +447,7 @@ class GraphUtils { /// @param [in] node: original node. /// @param [in] prefix: node name prefix of new node. /// @param [in] all_nodes: all nodes in new graph. - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, const std::unordered_map &all_nodes); @@ -455,7 +459,7 @@ class GraphUtils { /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol /// @param [out] symbol - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, std::map> &symbol_to_anchors, @@ -467,7 +471,7 @@ class GraphUtils { /// @param [in] exist_node_info /// @param [out] symbol_to_anchors /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS + /// @return success: GRAPH_SUCCESS /// static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, std::map> &symbol_to_anchors, diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h index 714cf3584433edb0b9194724200f9ad945b1a759..d809c56261847e54df7dc7220eac53c7084de435 100644 --- a/inc/graph/utils/node_utils.h +++ b/inc/graph/utils/node_utils.h @@ -213,6 +213,7 @@ class NodeUtils { static graphStatus UpdateInputOriginalShapeAndShape(const Node &node, uint32_t index, const GeShape &shape); static graphStatus UpdateOutputOriginalShapeAndShape(const Node &node, uint32_t index, const GeShape &shape); + static OutDataAnchorPtr GetOriginalSwitchCondInput(const NodePtr &switch_node); private: static std::map> map_send_info_; diff --git a/tests/ut/graph/CMakeLists.txt b/tests/ut/graph/CMakeLists.txt index 370a4fc9419865aa9120ec89a423ebbd4f472763..5cc6ec6ef0c54037417ac681a323c175a22ff12e 100644 --- a/tests/ut/graph/CMakeLists.txt +++ b/tests/ut/graph/CMakeLists.txt @@ -139,6 +139,7 @@ set(GRAPH_SRC_FILES "${METADEF_DIR}/graph/utils/tuning_utils.cc" "${METADEF_DIR}/graph/utils/type_utils.cc" "${METADEF_DIR}/graph/utils/constant_utils.cc" + "${METADEF_DIR}/graph/utils/branch_exec_cond_calculator.cc" "${METADEF_DIR}/ops/op_imp.cpp" "${METADEF_DIR}/third_party/transformer/src/axis_util.cc" "${METADEF_DIR}/third_party/transformer/src/expand_dimension.cc" diff --git a/tests/ut/graph/testcase/graph_utils_unittest.cc b/tests/ut/graph/testcase/graph_utils_unittest.cc index ea18c90b1b26597dfbfebe316496a3edaa47810f..896455c8d871a3de1624a33e13495cdcd9e4e5ac 100644 --- a/tests/ut/graph/testcase/graph_utils_unittest.cc +++ b/tests/ut/graph/testcase/graph_utils_unittest.cc @@ -15,10 +15,6 @@ */ #include - -#define protected public -#define private public - #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" #include "graph/op_desc_impl.h" @@ -166,22 +162,21 @@ void BuildGraphForUnfold(ComputeGraphPtr &graph, ComputeGraphPtr &subgraph) { class UtestGraphUtils : public testing::Test { protected: - void SetUp() {} - - void TearDown() {} + void SetUp() override {} + void TearDown() override {} }; -/* -* var var -* atomicclean | \ | \ -* \\ | assign | assign -* \\ | // =======> | // -* allreduce identity atomicclean -* | | // -* netoutput allreduce -* | -* netoutput - */ +/// +/// var var +/// atomic_clean | \ | \ +/// \\ | assign | assign +/// \\ | // =======> | // +/// all_reduce identity atomic_clean +/// | | // +/// net_output all_reduce +/// | +/// net_output +/// TEST_F(UtestGraphUtils, InsertNodeBefore_DoNotMoveCtrlEdgeFromAtomicClean) { // build test graph auto builder = ut::GraphBuilder("test"); @@ -193,20 +188,781 @@ TEST_F(UtestGraphUtils, InsertNodeBefore_DoNotMoveCtrlEdgeFromAtomicClean) { const auto &identity = builder.AddNode("identity", "Identity", 1, 1); builder.AddDataEdge(var, 0, assign, 0); - builder.AddDataEdge(var,0,allreduce,0); + builder.AddDataEdge(var, 0, allreduce, 0); builder.AddControlEdge(assign, allreduce); builder.AddControlEdge(atomic_clean, allreduce); auto graph = builder.GetGraph(); - // insert identity before allreduce + // insert identity before all_reduce GraphUtils::InsertNodeBefore(allreduce->GetInDataAnchor(0), identity, 0, 0); // check assign control-in on identity ASSERT_EQ(identity->GetInControlNodes().at(0)->GetName(), "assign"); - // check atomicclean control-in still on allreuce + // check atomic_clean control-in still on all_reduce ASSERT_EQ(allreduce->GetInControlNodes().at(0)->GetName(), "atomic_clean"); } +/// +/// add +/// / \ +/// const0 const1 +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_no_v1_branch) { + auto builder = ut::GraphBuilder("no_v1_branch"); + auto const0 = builder.AddNode("const0", "Const", 0, 1); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto add = builder.AddNode("add", "Add", 2, 1); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(const0, 0, add, 0); + builder.AddDataEdge(const1, 0, add, 1); + builder.AddDataEdge(add, 0, net_output, 0); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_TRUE(node_exec_labels.empty()); +} + +/// +/// net_output +/// | +/// merge +/// / \ +/// square add +/// F| T/ T\ +/// switch1 switch2 +/// / \ / \ +/// var1 var2 var3 +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_cond_branch) { + auto builder = ut::GraphBuilder("v1_cond_branch"); + auto var1 = builder.AddNode("var1", "VariableV2", 0, 1); + auto var2 = builder.AddNode("var2", "VariableV2", 0, 1, FORMAT_ND, DT_BOOL, {}); + auto var3 = builder.AddNode("var3", "VariableV2", 0, 1); + auto switch1 = builder.AddNode("switch1", "RefSwitch", 2, 2); + auto switch2 = builder.AddNode("switch2", "Switch", 2, 2); + auto add = builder.AddNode("add", "Add", 2, 1); + auto square = builder.AddNode("square", "Square", 1, 1); + auto merge = builder.AddNode("merge", "Merge", 2, 2); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(var1, 0, switch1, 0); + builder.AddDataEdge(var2, 0, switch1, 1); + builder.AddDataEdge(var3, 0, switch2, 0); + builder.AddDataEdge(var2, 0, switch2, 1); + builder.AddDataEdge(switch1, 0, square, 0); + builder.AddDataEdge(switch1, 1, add, 0); + builder.AddDataEdge(switch2, 1, add, 1); + builder.AddDataEdge(square, 0, merge, 0); + builder.AddDataEdge(add, 0, merge, 1); + builder.AddDataEdge(merge, 0, net_output, 0); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var3]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[merge]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[net_output]); + + ASSERT_NE(node_exec_labels[var1], node_exec_labels[add]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[square]); + + ASSERT_NE(node_exec_labels[add], node_exec_labels[square]); +} + +/// +/// net_output +/// | +/// exit next_iteration +/// \ | | +/// \ add | +/// F\ T/ \ | +/// switch1 enter1 | +/// / | | | +/// loop_cond | const1 | +/// | | | +/// less | | +/// / \ | | +/// enter2 merge ---------| +/// | | +/// const2 enter3 +/// | +/// var +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_loop_branch) { + auto builder = ut::GraphBuilder("v1_loop_branch"); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto enter1 = builder.AddNode("enter1", "Enter", 1, 1); + AttrUtils::SetStr(enter1->GetOpDesc(), "frame_name", "frame_name"); + auto const2 = builder.AddNode("const2", "Const", 0, 1); + auto enter2 = builder.AddNode("enter2", "Enter", 1, 1); + AttrUtils::SetStr(enter2->GetOpDesc(), "frame_name", "frame_name"); + auto var = builder.AddNode("var", "VariableV2", 0, 1); + auto enter3 = builder.AddNode("enter3", "Enter", 1, 1); + AttrUtils::SetStr(enter3->GetOpDesc(), "frame_name", "frame_name"); + auto merge = builder.AddNode("merge", "Merge", 2, 2); + auto less = builder.AddNode("less", "Less", 2, 1); + auto loop_cond = builder.AddNode("loop_cond", "LoopCond", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto switch1 = builder.AddNode("switch1", "Switch", 2, 2); + auto add = builder.AddNode("add", "add", 2, 1); + auto next_iteration = builder.AddNode("next_iteration", "NextIteration", 1, 1); + auto exit = builder.AddNode("exit", "Exit", 1, 1); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(const1, 0, enter1, 0); + builder.AddDataEdge(const2, 0, enter2, 0); + builder.AddDataEdge(var, 0, enter3, 0); + builder.AddDataEdge(enter3, 0, merge, 0); + builder.AddDataEdge(enter2, 0, less, 0); + builder.AddDataEdge(merge, 0, less, 1); + builder.AddDataEdge(merge, 0, switch1, 0); + builder.AddDataEdge(less, 0, loop_cond, 0); + builder.AddDataEdge(loop_cond, 0, switch1, 1); + builder.AddDataEdge(switch1, 1, add, 0); + builder.AddDataEdge(enter1, 0, add, 1); + builder.AddDataEdge(add, 0, next_iteration, 0); + builder.AddDataEdge(next_iteration, 0, merge, 1); + builder.AddDataEdge(switch1, 0, exit, 0); + builder.AddDataEdge(exit, 0, net_output, 0); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[enter1]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[const2]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[enter2]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[var]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[enter3]); + + ASSERT_EQ(node_exec_labels[merge], node_exec_labels[less]); + ASSERT_EQ(node_exec_labels[merge], node_exec_labels[loop_cond]); + ASSERT_EQ(node_exec_labels[merge], node_exec_labels[switch1]); + + ASSERT_EQ(node_exec_labels[add], node_exec_labels[next_iteration]); + + ASSERT_EQ(node_exec_labels[exit], node_exec_labels[net_output]); + + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[merge]); + + ASSERT_NE(node_exec_labels[const1], node_exec_labels[add]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[exit]); + + ASSERT_NE(node_exec_labels[add], node_exec_labels[exit]); +} + +/// +/// net_output +/// | +/// merge3 +/// / \ +/// merge1 merge2 +/// / \ / \ +/// I3 I4 I5 I6 +/// F\ T/ F\ T/ +/// switch4 switch5 +/// | T\ F/ | +/// I1 switch3 I2 +/// T| | \ F| +/// switch1 |var4 switch2 +/// / \ | / \ +/// variable1 variable2 variable3 +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_nest_cond_branch) { + auto builder = ut::GraphBuilder("v1_nest_cond_branch"); + auto var1 = builder.AddNode("var1", "VariableV2", 0, 1); + auto var2 = builder.AddNode("var2", "VariableV2", 0, 1, FORMAT_ND, DT_BOOL, {}); + auto var3 = builder.AddNode("var3", "VariableV2", 0, 1); + auto var4 = builder.AddNode("var4", "VariableV2", 0, 1, FORMAT_ND, DT_BOOL, {}); + auto switch1 = builder.AddNode("switch1", "RefSwitch", 2, 2); + auto switch2 = builder.AddNode("switch2", "Switch", 2, 2); + auto switch3 = builder.AddNode("switch3", "Switch", 2, 2); + auto switch4 = builder.AddNode("switch4", "Switch", 2, 2); + auto switch5 = builder.AddNode("switch5", "Switch", 2, 2); + auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); + auto identity2 = builder.AddNode("identity2", "Identity", 1, 1); + auto identity3 = builder.AddNode("identity3", "Identity", 1, 1); + auto identity4 = builder.AddNode("identity4", "Identity", 1, 1); + auto identity5 = builder.AddNode("identity5", "Identity", 1, 1); + auto identity6 = builder.AddNode("identity6", "Identity", 1, 1); + auto merge1 = builder.AddNode("merge1", "Merge", 2, 2); + auto merge2 = builder.AddNode("merge2", "Merge", 2, 2); + auto merge3 = builder.AddNode("merge3", "Merge", 2, 2); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(var1, 0, switch1, 0); + builder.AddDataEdge(var2, 0, switch1, 1); + builder.AddDataEdge(var3, 0, switch2, 0); + builder.AddDataEdge(var2, 0, switch2, 1); + builder.AddDataEdge(var4, 0, switch3, 0); + builder.AddDataEdge(var2, 0, switch3, 1); + builder.AddDataEdge(switch1, 1, identity1, 0); + builder.AddDataEdge(switch2, 0, identity2, 0); + builder.AddDataEdge(identity1, 0, switch4, 0); + builder.AddDataEdge(switch3, 1, switch4, 1); + builder.AddDataEdge(switch4, 0, identity3, 0); + builder.AddDataEdge(switch4, 1, identity4, 0); + builder.AddDataEdge(identity3, 0, merge1, 0); + builder.AddDataEdge(identity4, 0, merge1, 1); + builder.AddDataEdge(identity2, 0, switch5, 0); + builder.AddDataEdge(switch3, 0, switch5, 1); + builder.AddDataEdge(switch5, 0, identity5, 0); + builder.AddDataEdge(switch5, 1, identity6, 0); + builder.AddDataEdge(identity5, 0, merge2, 0); + builder.AddDataEdge(identity6, 0, merge2, 1); + builder.AddDataEdge(merge1, 0, merge3, 0); + builder.AddDataEdge(merge2, 0, merge3, 1); + builder.AddDataEdge(merge3, 0, net_output, 0); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var3]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var4]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch3]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[merge3]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[net_output]); + + ASSERT_EQ(node_exec_labels[identity1], node_exec_labels[switch4]); + + ASSERT_EQ(node_exec_labels[identity2], node_exec_labels[switch5]); + + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity3]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity4]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity5]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity6]); + + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity3]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity4]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity5]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity6]); + + ASSERT_NE(node_exec_labels[identity2], node_exec_labels[identity3]); + ASSERT_NE(node_exec_labels[identity2], node_exec_labels[identity4]); + ASSERT_NE(node_exec_labels[identity2], node_exec_labels[identity5]); + ASSERT_NE(node_exec_labels[identity2], node_exec_labels[identity6]); + + ASSERT_NE(node_exec_labels[identity3], node_exec_labels[identity4]); + ASSERT_NE(node_exec_labels[identity3], node_exec_labels[identity5]); + ASSERT_NE(node_exec_labels[identity3], node_exec_labels[identity6]); + + ASSERT_NE(node_exec_labels[identity4], node_exec_labels[identity5]); + ASSERT_NE(node_exec_labels[identity4], node_exec_labels[identity6]); + + ASSERT_NE(node_exec_labels[identity5], node_exec_labels[identity6]); +} + +/// +/// next_iteration2 +/// | | +/// | I2 exit2 +/// | T\ F/ | +/// | switch2 | +/// | / | | +/// | loop_cond2 | | +/// | | | | +/// | less2 | | +/// | / \ | | +/// |-----/---> merge2 | +/// / / | +/// enter5 enter3 | +/// | / | +/// enter4 | | +/// | | | +/// const2 | | +/// | | +/// net_output | next_iteration1 +/// | / | +/// exit1 I1 | +/// F\ T/ | +/// switch1 | +/// / | | +/// loop_cond1 | | +/// | | | +/// less1 | | +/// / \ | | +/// enter2 merge1 <--------| +/// | | +/// const1 enter1 +/// | +/// var1 + +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_nest_loop_branch) { + auto builder = ut::GraphBuilder("v1_nest_loop_branch"); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto enter1 = builder.AddNode("enter1", "Enter", 1, 1); + AttrUtils::SetStr(enter1->GetOpDesc(), "frame_name", "frame_name_1"); + auto var1 = builder.AddNode("var1", "VariableV2", 0, 1); + auto enter2 = builder.AddNode("enter2", "Enter", 1, 1); + AttrUtils::SetStr(enter2->GetOpDesc(), "frame_name", "frame_name_1"); + auto merge1 = builder.AddNode("merge1", "Merge", 2, 2); + auto less1 = builder.AddNode("less1", "Less", 2, 1); + auto loop_cond1 = builder.AddNode("loop_cond1", "LoopCond", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto switch1 = builder.AddNode("switch1", "Switch", 2, 2); + auto exit1 = builder.AddNode("exit1", "Exit", 1, 1); + auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); + auto enter3 = builder.AddNode("enter3", "Enter", 1, 1); + AttrUtils::SetStr(enter3->GetOpDesc(), "frame_name", "frame_name_2"); + auto const2 = builder.AddNode("const2", "Const", 0, 1); + auto enter4 = builder.AddNode("enter4", "Enter", 1, 1); + AttrUtils::SetStr(enter4->GetOpDesc(), "frame_name", "frame_name_1"); + auto enter5 = builder.AddNode("enter5", "Enter", 1, 1); + AttrUtils::SetStr(enter5->GetOpDesc(), "frame_name", "frame_name_2"); + auto merge2 = builder.AddNode("merge2", "Merge", 2, 2); + auto less2 = builder.AddNode("less2", "Less", 2, 1); + auto loop_cond2 = builder.AddNode("loop_cond2", "LoopCond", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto switch2 = builder.AddNode("switch2", "Switch", 2, 2); + auto exit2 = builder.AddNode("exit2", "Exit", 1, 1); + auto identity2 = builder.AddNode("identity2", "Identity", 1, 1); + auto next_iteration2 = builder.AddNode("next_iteration2", "NextIteration", 1, 1); + auto next_iteration1 = builder.AddNode("next_iteration1", "NextIteration", 1, 1); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(var1, 0, enter1, 0); + builder.AddDataEdge(enter1, 0, merge1, 0); + builder.AddDataEdge(merge1, 0, less1, 0); + builder.AddDataEdge(merge1, 0, switch1, 0); + builder.AddDataEdge(const1, 0, enter2, 0); + builder.AddDataEdge(enter2, 0, less1, 1); + builder.AddDataEdge(less1, 0, loop_cond1, 0); + builder.AddDataEdge(loop_cond1, 0, switch1, 1); + builder.AddDataEdge(switch1, 0, exit1, 0); + builder.AddDataEdge(exit1, 0, net_output, 0); + builder.AddDataEdge(switch1, 1, identity1, 0); + builder.AddDataEdge(identity1, 0, enter3, 0); + builder.AddDataEdge(enter3, 0, merge2, 0); + builder.AddDataEdge(merge2, 0, less2, 0); + builder.AddDataEdge(merge2, 0, switch2, 0); + builder.AddDataEdge(const2, 0, enter4, 0); + builder.AddDataEdge(enter4, 0, enter5, 0); + builder.AddDataEdge(enter5, 0, less2, 1); + builder.AddDataEdge(less2, 0, loop_cond2, 0); + builder.AddDataEdge(loop_cond2, 0, switch2, 1); + builder.AddDataEdge(switch2, 1, identity2, 0); + builder.AddDataEdge(identity2, 0, next_iteration2, 0); + builder.AddDataEdge(next_iteration2, 0, merge2, 1); + builder.AddDataEdge(switch2, 0, exit2, 0); + builder.AddDataEdge(exit2, 0, next_iteration1, 0); + builder.AddDataEdge(next_iteration1, 0, merge1, 1); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[enter1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[merge1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[const1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[enter2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[less1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[loop_cond1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[const2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[enter4]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[enter5]); + + ASSERT_EQ(node_exec_labels[exit1], node_exec_labels[net_output]); + + ASSERT_EQ(node_exec_labels[identity1], node_exec_labels[enter3]); + + ASSERT_EQ(node_exec_labels[merge2], node_exec_labels[less2]); + ASSERT_EQ(node_exec_labels[merge2], node_exec_labels[loop_cond2]); + ASSERT_EQ(node_exec_labels[merge2], node_exec_labels[switch2]); + + ASSERT_EQ(node_exec_labels[identity2], node_exec_labels[next_iteration2]); + + ASSERT_EQ(node_exec_labels[exit2], node_exec_labels[next_iteration1]); + + ASSERT_NE(node_exec_labels[var1], node_exec_labels[exit1]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[merge2]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[exit2]); + + ASSERT_NE(node_exec_labels[exit1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[exit1], node_exec_labels[merge2]); + ASSERT_NE(node_exec_labels[exit1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[exit1], node_exec_labels[exit2]); + + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[merge2]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[exit2]); + + ASSERT_NE(node_exec_labels[merge2], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[merge2], node_exec_labels[exit2]); + + ASSERT_NE(node_exec_labels[identity2], node_exec_labels[exit2]); +} + +/// +/// next_iteration ---| +/// | | +/// net_output merge2 | +/// | / \ | +/// exit I1 I2 | +/// \ T\ /F | +/// \ switch2 | +/// F\ T/ \ | +/// switch1 enter1 | +/// / | | | +/// loop_cond | const1 | +/// | | | +/// less | | +/// / \ | | +/// enter2 merge1 --------| +/// | | +/// const2 enter3 +/// | +/// var +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_loop_cond_branch) { + auto builder = ut::GraphBuilder("v1_loop_cond_branch"); + auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_ND, DT_BOOL, {}); + auto enter1 = builder.AddNode("enter1", "Enter", 1, 1); + AttrUtils::SetStr(enter1->GetOpDesc(), "frame_name", "frame_name"); + auto const2 = builder.AddNode("const2", "Const", 0, 1); + auto enter2 = builder.AddNode("enter2", "Enter", 1, 1); + AttrUtils::SetStr(enter2->GetOpDesc(), "frame_name", "frame_name"); + auto var = builder.AddNode("var", "VariableV2", 0, 1); + auto enter3 = builder.AddNode("enter3", "Enter", 1, 1); + AttrUtils::SetStr(enter3->GetOpDesc(), "frame_name", "frame_name"); + auto merge1 = builder.AddNode("merge1", "Merge", 2, 2); + auto less = builder.AddNode("less", "Less", 2, 1); + auto loop_cond = builder.AddNode("loop_cond", "LoopCond", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto switch1 = builder.AddNode("switch1", "Switch", 2, 2); + auto switch2 = builder.AddNode("switch2", "Switch", 2, 2); + auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); + auto identity2 = builder.AddNode("identity2", "Identity", 1, 1); + auto merge2 = builder.AddNode("merge2", "Merge", 2, 2); + auto next_iteration = builder.AddNode("next_iteration", "NextIteration", 1, 1); + auto exit = builder.AddNode("exit", "Exit", 1, 1); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(const1, 0, enter1, 0); + builder.AddDataEdge(const2, 0, enter2, 0); + builder.AddDataEdge(var, 0, enter3, 0); + builder.AddDataEdge(enter3, 0, merge1, 0); + builder.AddDataEdge(enter2, 0, less, 0); + builder.AddDataEdge(merge1, 0, less, 1); + builder.AddDataEdge(merge1, 0, switch1, 0); + builder.AddDataEdge(less, 0, loop_cond, 0); + builder.AddDataEdge(loop_cond, 0, switch1, 1); + builder.AddDataEdge(switch1, 1, switch2, 0); + builder.AddDataEdge(enter1, 0, switch2, 1); + builder.AddDataEdge(switch2, 1, identity1, 0); + builder.AddDataEdge(switch2, 0, identity2, 0); + builder.AddDataEdge(identity1, 0, merge2, 0); + builder.AddDataEdge(identity2, 0, merge2, 1); + builder.AddDataEdge(merge2, 0, next_iteration, 0); + builder.AddDataEdge(next_iteration, 0, merge1, 1); + builder.AddDataEdge(switch1, 0, exit, 0); + builder.AddDataEdge(exit, 0, net_output, 0); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[enter1]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[const2]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[enter2]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[var]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[enter3]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[merge1]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[less]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[loop_cond]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[switch1]); + + ASSERT_EQ(node_exec_labels[switch2], node_exec_labels[merge2]); + ASSERT_EQ(node_exec_labels[switch2], node_exec_labels[next_iteration]); + + ASSERT_EQ(node_exec_labels[exit], node_exec_labels[net_output]); + + ASSERT_NE(node_exec_labels[const1], node_exec_labels[switch2]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[exit]); + + ASSERT_NE(node_exec_labels[switch2], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[switch2], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[switch2], node_exec_labels[exit]); + + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[exit]); + + ASSERT_NE(node_exec_labels[identity2], node_exec_labels[exit]); +} + +/// +/// net_output +/// | +/// merge2 next_iteration +/// / \ | | +/// / exit I1 | +/// / F\ T| | +/// / switch3 | +/// / / | | +/// I2 loop_cond | | +/// | | | | +/// | less | | +/// | / \ | | +/// | / merge1 ----| +/// | enter1 | +/// | / enter2 +/// T| F/ F/ +/// switch1 switch2 +/// / \ / \ +/// const1 const2 var +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_cond_loop_branch) { + auto builder = ut::GraphBuilder("v1_cond_loop_branch"); + auto const1 = builder.AddNode("const1", "Const", 0, 1); + auto const2 = builder.AddNode("const2", "Const", 0, 1, FORMAT_ND, DT_BOOL, {}); + auto var = builder.AddNode("var", "VariableV2", 0, 1); + auto switch1 = builder.AddNode("switch1", "Switch", 2, 2); + auto identity2 = builder.AddNode("identity2", "Identity", 1, 1); + auto switch2 = builder.AddNode("switch2", "Switch", 2, 2); + auto enter1 = builder.AddNode("enter1", "Enter", 1, 1); + AttrUtils::SetStr(enter1->GetOpDesc(), "frame_name", "frame_name"); + auto enter2 = builder.AddNode("enter2", "Enter", 1, 1); + AttrUtils::SetStr(enter2->GetOpDesc(), "frame_name", "frame_name"); + auto merge1 = builder.AddNode("merge1", "Merge", 2, 2); + auto less = builder.AddNode("less", "Less", 2, 1); + auto loop_cond = builder.AddNode("loop_cond", "LoopCond", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto switch3 = builder.AddNode("switch3", "Switch", 2, 2); + auto exit = builder.AddNode("exit", "Exit", 1, 1); + auto merge2 = builder.AddNode("merge2", "Merge", 2, 2); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); + auto next_iteration = builder.AddNode("next_iteration", "NextIteration", 1, 1); + + builder.AddDataEdge(const1, 0, switch1, 0); + builder.AddDataEdge(const2, 0, switch1, 1); + builder.AddDataEdge(var, 0, switch2, 0); + builder.AddDataEdge(const2, 0, switch2, 1); + builder.AddDataEdge(switch1, 1, identity2, 0); + builder.AddDataEdge(identity2, 0, merge2, 0); + builder.AddDataEdge(switch1, 0, enter1, 0); + builder.AddDataEdge(enter1, 0, less, 0); + builder.AddDataEdge(switch2, 0, enter2, 0); + builder.AddDataEdge(enter2, 0, merge1, 0); + builder.AddDataEdge(merge1, 0, less, 1); + builder.AddDataEdge(merge1, 0, switch3, 0); + builder.AddDataEdge(less, 0, loop_cond, 0); + builder.AddDataEdge(loop_cond, 0, switch3, 1); + builder.AddDataEdge(switch3, 0, exit, 0); + builder.AddDataEdge(exit, 0, merge2, 1); + builder.AddDataEdge(merge2, 0, net_output, 0); + builder.AddDataEdge(switch3, 1, identity1, 0); + builder.AddDataEdge(identity1, 0, next_iteration, 0); + builder.AddDataEdge(next_iteration, 0, merge1, 1); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[const2]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[var]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[switch1]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[switch2]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[merge2]); + ASSERT_EQ(node_exec_labels[const1], node_exec_labels[net_output]); + + ASSERT_EQ(node_exec_labels[enter1], node_exec_labels[enter2]); + + ASSERT_EQ(node_exec_labels[less], node_exec_labels[loop_cond]); + ASSERT_EQ(node_exec_labels[less], node_exec_labels[switch3]); + + ASSERT_EQ(node_exec_labels[identity1], node_exec_labels[next_iteration]); + + ASSERT_EQ(node_exec_labels[merge2], node_exec_labels[net_output]); + + ASSERT_NE(node_exec_labels[const1], node_exec_labels[enter1]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[merge1]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[exit]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[const1], node_exec_labels[identity2]); + + ASSERT_NE(node_exec_labels[enter1], node_exec_labels[merge1]); + ASSERT_NE(node_exec_labels[enter1], node_exec_labels[exit]); + ASSERT_NE(node_exec_labels[enter1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[enter1], node_exec_labels[identity2]); + + ASSERT_NE(node_exec_labels[merge1], node_exec_labels[exit]); + ASSERT_NE(node_exec_labels[merge1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[merge1], node_exec_labels[identity2]); + + ASSERT_NE(node_exec_labels[exit], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[exit], node_exec_labels[identity2]); + + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity2]); +} + +/// +/// net_output +/// | +/// merge3 +/// / \ +/// merge1 merge2 +/// / \ / \ +/// I5 I6 I7 I8 +/// T\ F/ T\ F/ +/// switch3 switch4 +/// \ \ | \ +/// \ -|-- \ +/// \ | \ \ +/// I1 I2 I3 I4 +/// T| F| T| F| +/// switch1 switch2 +/// | \ | +/// var1 var2 +/// +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_redundant_cond_branch) { + auto builder = ut::GraphBuilder("v1_redundant_cond_branch"); + auto var1 = builder.AddNode("var1", "VariableV2", 0, 1); + auto var2 = builder.AddNode("var2", "VariableV2", 0, 1, FORMAT_ND, DT_BOOL, {}); + auto switch1 = builder.AddNode("switch1", "Switch", 2, 2); + auto switch2 = builder.AddNode("switch2", "Switch", 2, 2, FORMAT_ND, DT_BOOL, {}); + auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); + auto identity2 = builder.AddNode("identity2", "Identity", 1, 1); + auto identity3 = builder.AddNode("identity3", "Identity", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto identity4 = builder.AddNode("identity4", "Identity", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto switch3 = builder.AddNode("switch3", "Switch", 2, 2); + auto switch4 = builder.AddNode("switch4", "Switch", 2, 2); + auto identity5 = builder.AddNode("identity5", "Identity", 1, 1); + auto identity6 = builder.AddNode("identity6", "Identity", 1, 1); + auto identity7 = builder.AddNode("identity7", "Identity", 1, 1); + auto identity8 = builder.AddNode("identity8", "Identity", 1, 1); + auto merge1 = builder.AddNode("merge1", "Merge", 2, 2); + auto merge2 = builder.AddNode("merge2", "Merge", 2, 2); + auto merge3 = builder.AddNode("merge3", "Merge", 2, 2); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(var1, 0, switch1, 0); + builder.AddDataEdge(var2, 0, switch1, 1); + builder.AddDataEdge(var2, 0, switch2, 0); + builder.AddDataEdge(var2, 0, switch2, 1); + builder.AddDataEdge(switch1, 1, identity1, 0); + builder.AddDataEdge(switch1, 0, identity2, 0); + builder.AddDataEdge(switch2, 1, identity3, 0); + builder.AddDataEdge(switch2, 0, identity4, 0); + builder.AddDataEdge(identity1, 0, switch3, 0); + builder.AddDataEdge(identity2, 0, switch4, 0); + builder.AddDataEdge(identity3, 0, switch3, 1); + builder.AddDataEdge(identity4, 0, switch4, 1); + builder.AddDataEdge(switch3, 1, identity5, 0); + builder.AddDataEdge(switch3, 0, identity6, 0); + builder.AddDataEdge(switch4, 1, identity7, 0); + builder.AddDataEdge(switch4, 0, identity8, 0); + builder.AddDataEdge(identity5, 0, merge1, 0); + builder.AddDataEdge(identity6, 0, merge1, 1); + builder.AddDataEdge(identity7, 0, merge2, 0); + builder.AddDataEdge(identity8, 0, merge2, 1); + builder.AddDataEdge(merge1, 0, merge3, 0); + builder.AddDataEdge(merge2, 0, merge3, 1); + builder.AddDataEdge(merge3, 0, net_output, 0); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[merge3]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[net_output]); + + ASSERT_EQ(node_exec_labels[identity1], node_exec_labels[identity3]); + ASSERT_EQ(node_exec_labels[identity1], node_exec_labels[switch3]); + ASSERT_EQ(node_exec_labels[identity1], node_exec_labels[identity5]); + ASSERT_EQ(node_exec_labels[identity1], node_exec_labels[merge1]); + + ASSERT_EQ(node_exec_labels[identity2], node_exec_labels[identity4]); + ASSERT_EQ(node_exec_labels[identity2], node_exec_labels[switch4]); + ASSERT_EQ(node_exec_labels[identity2], node_exec_labels[identity8]); + ASSERT_EQ(node_exec_labels[identity2], node_exec_labels[merge2]); + + ASSERT_EQ(node_exec_labels[identity6], node_exec_labels[identity7]); + + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity1]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[identity6]); + + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity2]); + ASSERT_NE(node_exec_labels[identity1], node_exec_labels[identity6]); + + ASSERT_NE(node_exec_labels[identity2], node_exec_labels[identity6]); +} + + +/// +/// net_output +/// | +/// merge +/// / | +/// add1 <--|--------| +/// | add2 | +/// T| F/ T\ F| +/// switch1 switch2 +/// | \ | \ +/// | \ not var3 +/// | \ / +/// var1 var2 +/// +TEST_F(UtestGraphUtils, GetBranchExecCond_v1_diff_cond_branch) { + auto builder = ut::GraphBuilder("v1_diff_cond_branch"); + auto var1 = builder.AddNode("var1", "VariableV2", 0, 1); + auto var2 = builder.AddNode("var2", "VariableV2", 0, 1, FORMAT_ND, DT_BOOL, {}); + auto var3 = builder.AddNode("var3", "VariableV2", 0, 1); + auto logical_not = builder.AddNode("logical_not", "LogicalNot", 1, 1, FORMAT_ND, DT_BOOL, {}); + auto switch1 = builder.AddNode("switch1", "RefSwitch", 2, 2); + auto switch2 = builder.AddNode("switch2", "Switch", 2, 2); + auto add1 = builder.AddNode("add1", "Add", 2, 1); + auto add2 = builder.AddNode("add2", "Add", 2, 1); + auto merge = builder.AddNode("merge", "Merge", 2, 2); + auto net_output = builder.AddNode("net_output", "NetOutput", 1, 0); + + builder.AddDataEdge(var1, 0, switch1, 0); + builder.AddDataEdge(var2, 0, switch1, 1); + builder.AddDataEdge(var3, 0, switch2, 0); + builder.AddDataEdge(var2, 0, logical_not, 0); + builder.AddDataEdge(logical_not, 0, switch2, 1); + builder.AddDataEdge(switch1, 1, add1, 0); + builder.AddDataEdge(switch1, 0, add2, 0); + builder.AddDataEdge(switch2, 1, add2, 1); + builder.AddDataEdge(switch2, 0, add1, 1); + builder.AddDataEdge(add1, 0, merge, 0); + builder.AddDataEdge(add2, 0, merge, 1); + builder.AddDataEdge(merge, 0, net_output, 0); + + std::map node_exec_labels; + ASSERT_EQ(GraphUtils::GetBranchExecCondLabel(builder.GetGraph(), node_exec_labels), GRAPH_SUCCESS); + ASSERT_FALSE(node_exec_labels.empty()); + + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[var3]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch1]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[switch2]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[merge]); + ASSERT_EQ(node_exec_labels[var1], node_exec_labels[net_output]); + + ASSERT_NE(node_exec_labels[var1], node_exec_labels[add1]); + ASSERT_NE(node_exec_labels[var1], node_exec_labels[add2]); + + ASSERT_NE(node_exec_labels[add1], node_exec_labels[add2]); +} + TEST_F(UtestGraphUtils, GetSubgraphs) { auto root_builder = ut::GraphBuilder("root"); const auto &case0 = root_builder.AddNode("case0", "Case", 0, 0);