From a66b521ba96d4954ebb36f1bb39e5fa4512fcb29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AE=81?= Date: Mon, 29 Sep 2025 12:01:08 +0800 Subject: [PATCH] delete unused parallel --- graph/CMakeLists.txt | 1 - graph/attr/ge_attr_define.cc | 3 - graph/parallelism/comm_task_builder.h | 43 - graph/parallelism/tensor_parallel_attrs.cc | 890 ---------------- inc/graph/debug/ge_attr_define.h | 3 - inc/graph/parallelism/graph_parallel_option.h | 82 -- inc/graph/parallelism/tensor_parallel_attrs.h | 396 -------- .../tensor_parallel_attrs_unittest.cc | 953 ------------------ 8 files changed, 2371 deletions(-) delete mode 100644 graph/parallelism/comm_task_builder.h delete mode 100644 graph/parallelism/tensor_parallel_attrs.cc delete mode 100644 inc/graph/parallelism/graph_parallel_option.h delete mode 100644 inc/graph/parallelism/tensor_parallel_attrs.h delete mode 100644 tests/ut/graph/testcase/tensor_parallel_attrs_unittest.cc diff --git a/graph/CMakeLists.txt b/graph/CMakeLists.txt index 11caf72ba2..c27f96a2f7 100644 --- a/graph/CMakeLists.txt +++ b/graph/CMakeLists.txt @@ -120,7 +120,6 @@ SET(GRAPH_SOURCE_LIST "utils/graph_thread_pool.cc" "utils/multi_thread_graph_builder.cc" "utils/type_utils_ex.cc" - "parallelism/tensor_parallel_attrs.cc" "utils/screen_printer.cc" "${METADEF_DIR}/third_party/transformer/src/axis_util.cc" "${METADEF_DIR}/third_party/transformer/src/transfer_shape_according_to_format.cc" diff --git a/graph/attr/ge_attr_define.cc b/graph/attr/ge_attr_define.cc index 335e51f40b..1c5a0f6333 100644 --- a/graph/attr/ge_attr_define.cc +++ b/graph/attr/ge_attr_define.cc @@ -1453,9 +1453,6 @@ const std::string ATTR_NAME_SHARD_GRAPH_EXT_ATTRS = "_shard_graph_ext_attrs"; const std::string ATTR_NAME_IS_SHARD_GRAPH_FOR_LOAD = "_is_shard_graph_for_load"; const std::string ATTR_NAME_GRAPH_MODEL_DEPLOY_MODE = "_graphModelDeployMode"; -// for tensor parallelism -const std::string ATTR_NAME_TP_RESHARD_ATTR = "_reshard_attr"; - // for lowering const std::string ATTR_NAME_GRAPH_FLATTEN_OFFSET = "graph_flatten_offset"; diff --git a/graph/parallelism/comm_task_builder.h b/graph/parallelism/comm_task_builder.h deleted file mode 100644 index ac1088424f..0000000000 --- a/graph/parallelism/comm_task_builder.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_PARALLELISM_COMM_TASK_BUILDER_H_ -#define METADEF_GRAPH_PARALLELISM_COMM_TASK_BUILDER_H_ - -#include "graph/parallelism/tensor_parallel_attrs.h" -#include "nlohmann/json.hpp" - -namespace ge { -namespace tp { -class CommTaskBuilder { - public: - static CommTaskBuilder &GetInstance() { - static CommTaskBuilder instance; - return instance; - } - - void BuildCommTask(const nlohmann::json &j, CommTask &comm_task); - Status ConvertToJson(const CommTask &comm_task, nlohmann::json &j); - - private: - CommTaskBuilder(); - ~CommTaskBuilder() = default; - - void InitCommTaskBuilders(); - void InitJsonConverters(); - template - static Status ConvertToJson(const T *reshard_task, nlohmann::json &j); - - std::map> builders_; - std::map> json_converters_; -}; -} // namespace tp -} // namespace ge - -#endif // METADEF_GRAPH_PARALLELISM_COMM_TASK_BUILDER_H_ diff --git a/graph/parallelism/tensor_parallel_attrs.cc b/graph/parallelism/tensor_parallel_attrs.cc deleted file mode 100644 index c5ea39aa6b..0000000000 --- a/graph/parallelism/tensor_parallel_attrs.cc +++ /dev/null @@ -1,890 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "parallelism/tensor_parallel_attrs.h" -#include "common/ge_common/util.h" -#include "graph/debug/ge_util.h" -#include "nlohmann/json.hpp" -#include "parallelism/comm_task_builder.h" - -#define USED_BY_JSON __attribute__((unused)) static - -namespace ge { -namespace tp { -namespace { -using Json = nlohmann::json; - -constexpr size_t kValidDimSliceItemNum = 2U; -constexpr size_t kIndexStepId = 0U; -constexpr size_t kIndexOutputIndex = 1U; - -Status StringToJson(const std::string &json_str, Json &json) { - std::stringstream ss; - ss << json_str; - try { - ss >> json; - } catch (const nlohmann::json::exception &e) { - GELOGE(PARAM_INVALID, "Failed to init json object, err = %s, json_str = %s", e.what(), json_str.c_str()); - return PARAM_INVALID; - } - return SUCCESS; -} - -template -Status ParseFromJson(const std::string &type, const std::string &json_str, T &value) { - Json json; - GE_CHK_STATUS_RET_NOLOG(StringToJson(json_str, json)); - try { - value = json.get(); - } catch (const nlohmann::json::exception &e) { - GELOGE(PARAM_INVALID, - "Failed to parse json object, type = %s, err = %s, json_str = %s", - type.c_str(), - e.what(), - json_str.c_str()); - return PARAM_INVALID; - } - return SUCCESS; -} - -template -std::shared_ptr CreateReshardTaskInfo(const Json &j) { - return ComGraphMakeShared(j.get()); -} - -template -std::string ToJsonString(const T &obj) { - try { - const Json j = obj; - return j.dump(); - } catch (const nlohmann::json::exception &e) { - GELOGE(FAILED, "Failed to dump object, err = %s", e.what()); - return ""; - } -} - -template -void GetValue(const Json &j, const std::string &key, T &value) { - value = j.at(key).template get(); -} -template -void TryGetValue(const Json &j, const std::string &key, T &value) { - if (j.contains(key)) { - value = j.at(key).template get(); - } -} -} // namespace - -void CommTaskBuilder::BuildCommTask(const Json &j, CommTask &comm_task) { - auto task_type = j.at("task_type").get(); - const decltype(builders_)::const_iterator it = builders_.find(task_type); - if (it == builders_.cend()) { - GELOGE(PARAM_INVALID, "unsupported op type %s", comm_task.task_type.c_str()); - return; - } - it->second(j, comm_task); // exception caught by caller - comm_task.task_type = std::move(task_type); -} - -Status CommTaskBuilder::ConvertToJson(const CommTask &comm_task, nlohmann::json &j) { - const decltype(json_converters_)::const_iterator it = json_converters_.find(comm_task.task_type); - GE_CHK_BOOL_RET_STATUS(it != json_converters_.cend(), - PARAM_INVALID, - "unsupported op type %s", - comm_task.task_type.c_str()); - return it->second(comm_task, j); // exception caught by caller -} - -std::string DeviceIndex::DebugString() const { - return engine_type + ToString(indices); -} - -USED_BY_JSON void to_json(Json &j, const DimSlice &dim_slice) { - j = std::vector{dim_slice.begin, dim_slice.end}; -} - -USED_BY_JSON void from_json(const Json &j, DimSlice &dim_slice) { - const auto &range = j.get>(); - if (range.size() == kValidDimSliceItemNum) { - dim_slice.begin = range.front(); - dim_slice.end = range.back(); - } else { - dim_slice.begin = -1; - dim_slice.end = -1; - GELOGE(PARAM_INVALID, "invalid DimSlice: %s", j.dump().c_str()); - } -} - -USED_BY_JSON void to_json(Json &j, const DeviceIndex &device_index) { - j = Json(); - j["engine_type"] = device_index.engine_type; - j["index"] = device_index.indices; -} - -USED_BY_JSON void from_json(const Json &j, DeviceIndex &device_index) { - GetValue(j, "engine_type", device_index.engine_type); - GetValue(j, "index", device_index.indices); -} - -USED_BY_JSON void to_json(Json &j, const ModelIndex &model_index) { - j = Json(); - j["device_index"] = model_index.device_index; - j["virtual_stage_id"] = model_index.virtual_stage_id; - j["stage_id"] = model_index.stage_id; -} - -USED_BY_JSON void from_json(const Json &j, ModelIndex &model_index) { - GetValue(j, "device_index", model_index.device_index); - GetValue(j, "virtual_stage_id", model_index.virtual_stage_id); - GetValue(j, "stage_id", model_index.stage_id); -} - -USED_BY_JSON void to_json(Json &j, const PipelineConfig &pipeline_config) { - j = Json(); - j["micro_batch"] = pipeline_config.micro_batch; - j["stage_id"] = pipeline_config.stage_id; - j["virtual_stage_id"] = pipeline_config.virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, PipelineConfig &pipeline_config) { - GetValue(j, "micro_batch", pipeline_config.micro_batch); - GetValue(j, "stage_id", pipeline_config.stage_id); - GetValue(j, "virtual_stage_id", pipeline_config.virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const TensorSliceDeployment &tensor_slice_deployment) { - j = Json(); - j["device_indices_each_slice"] = tensor_slice_deployment.device_indices_each_slice; - j["axis_slices"] = tensor_slice_deployment.axis_slices; -} - -USED_BY_JSON void from_json(const Json &j, TensorSliceDeployment &tensor_slice_deployment) { - GetValue(j, "device_indices_each_slice", tensor_slice_deployment.device_indices_each_slice); - GetValue(j, "axis_slices", tensor_slice_deployment.axis_slices); -} - -USED_BY_JSON void to_json(Json &j, const TensorDeployment &tensor_deployment) { - j = Json(); - j["shard_deployment"] = tensor_deployment.shard_deployment; - if (!tensor_deployment.verbose.empty()) { - j["verbose"] = tensor_deployment.verbose; - } -} - -USED_BY_JSON void from_json(const Json &j, TensorDeployment &tensor_deployment) { - GetValue(j, "shard_deployment", tensor_deployment.shard_deployment); - TryGetValue(j, "verbose", tensor_deployment.verbose); -} - -USED_BY_JSON void to_json(Json &j, const TensorDeployments &tensor_deployments) { - j = Json(); - j["deployments"] = tensor_deployments.deployments; -} - -USED_BY_JSON void from_json(const Json &j, NodeDeployments &node_deployments) { - GetValue(j, "deployments", node_deployments.deployments); -} - -USED_BY_JSON void from_json(const Json &j, TensorDeployments &tensor_deployments) { - GetValue(j, "deployments", tensor_deployments.deployments); -} - -USED_BY_JSON void to_json(Json &j, const NodeDeployment &node_deployment) { - j = Json(); - j["devices"] = node_deployment.devices; - j["pipeline_config"] = node_deployment.pipeline_config; -} - -USED_BY_JSON void from_json(const Json &j, NodeDeployment &node_deployment) { - GetValue(j, "devices", node_deployment.devices); - TryGetValue(j, "pipeline_config", node_deployment.pipeline_config); -} - -USED_BY_JSON void to_json(Json &j, const NodeDeployments &node_deployments) { - j = Json(); - j["deployments"] = node_deployments.deployments; -} - - -USED_BY_JSON void to_json(Json &j, const CommPair &comm_pair) { - j = Json(); - j["src_device_index"] = comm_pair.src_device_index; - j["dst_device_index"] = comm_pair.dst_device_index; - j["src_virtual_stage_id"] = comm_pair.src_virtual_stage_id; - j["dst_virtual_stage_id"] = comm_pair.dst_virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, CommPair &comm_pair) { - GetValue(j, "src_device_index", comm_pair.src_device_index); - GetValue(j, "dst_device_index", comm_pair.dst_device_index); - TryGetValue(j, "src_virtual_stage_id", comm_pair.src_virtual_stage_id); - TryGetValue(j, "dst_virtual_stage_id", comm_pair.dst_virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const FlowAttr &comm_group) { - j = Json(); - j["depth"] = comm_group.depth; - j["enqueue_policy"] = comm_group.enqueue_policy; -} - -USED_BY_JSON void from_json(const Json &j, FlowAttr &comm_group) { - GetValue(j, "depth", comm_group.depth); - GetValue(j, "enqueue_policy", comm_group.enqueue_policy); -} - -USED_BY_JSON void to_json(Json &j, const CommGroup &comm_group) { - j = comm_group.device_indices; -} - -USED_BY_JSON void from_json(const Json &j, CommGroup &comm_group) { - comm_group.device_indices = j.get>(); -} - -USED_BY_JSON void to_json(Json &j, const SendRecvReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSendReceive; - j["comm_pairs"] = task_info.comm_pairs; - j["parallel_group"] = task_info.parallel_group; - j["comm_type"] = task_info.comm_type; - j["flow_attr"] = task_info.flow_attr; -} - -USED_BY_JSON void from_json(const Json &j, SendRecvReshardTask &task_info) { - GetValue(j, "comm_pairs", task_info.comm_pairs); - TryGetValue(j, "comm_type", task_info.comm_type); - TryGetValue(j, "parallel_group", task_info.parallel_group); - TryGetValue(j, "flow_attr", task_info.flow_attr); -} - -USED_BY_JSON void to_json(Json &j, const AllGatherReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllGather; - j["axis"] = task_info.axis; - j["comm_groups"] = task_info.comm_groups; - j["parallel_group"] = task_info.parallel_group; - j["output_allocator"] = task_info.output_allocator; -} - -USED_BY_JSON void from_json(const Json &j, AllGatherReshardTask &all_gather_task_info) { - GetValue(j, "comm_groups", all_gather_task_info.comm_groups); - GetValue(j, "axis", all_gather_task_info.axis); - TryGetValue(j, "parallel_group", all_gather_task_info.parallel_group); - TryGetValue(j, "output_allocator", all_gather_task_info.output_allocator); -} - -USED_BY_JSON void to_json(Json &j, const AllToAllReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllToAll; - j["comm_groups"] = task_info.comm_groups; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, AllToAllReshardTask &all_to_all_task_info) { - GetValue(j, "comm_groups", all_to_all_task_info.comm_groups); - TryGetValue(j, "parallel_group", all_to_all_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const AllReduceReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllReduce; - j["comm_groups"] = task_info.comm_groups; - j["reduction"] = task_info.reduction; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, AllReduceReshardTask &all_reduce_task_info) { - GetValue(j, "reduction", all_reduce_task_info.reduction); - GetValue(j, "comm_groups", all_reduce_task_info.comm_groups); - TryGetValue(j, "parallel_group", all_reduce_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const AllReduceMeanReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllReduceMean; - j["comm_groups"] = task_info.comm_groups; - j["axis"] = task_info.axis; - j["value"] = task_info.value; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, AllReduceMeanReshardTask &task_info) { - GetValue(j, "comm_groups", task_info.comm_groups); - GetValue(j, "axis", task_info.axis); - GetValue(j, "value", task_info.value); - TryGetValue(j, "parallel_group", task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const ReduceScatterReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomReduceScatter; - j["comm_groups"] = task_info.comm_groups; - j["reduction"] = task_info.reduction; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, ReduceScatterReshardTask &reduce_scatter_task_info) { - GetValue(j, "reduction", reduce_scatter_task_info.reduction); - GetValue(j, "comm_groups", reduce_scatter_task_info.comm_groups); - TryGetValue(j, "parallel_group", reduce_scatter_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const BroadcastReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomBroadcast; - j["comm_groups"] = task_info.comm_groups; - j["roots"] = task_info.root_device_indices; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, BroadcastReshardTask &broadcast_task_info) { - GetValue(j, "roots", broadcast_task_info.root_device_indices); - GetValue(j, "comm_groups", broadcast_task_info.comm_groups); - TryGetValue(j, "parallel_group", broadcast_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const SliceReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSlice; - j["axes"] = task_info.axes; - j["offsets"] = task_info.offsets; - j["size"] = task_info.sizes; - j["device_index"] = task_info.device_index; -} - -USED_BY_JSON void from_json(const Json &j, SliceReshardTask &task_info) { - TryGetValue(j, "axes", task_info.axes); - GetValue(j, "offsets", task_info.offsets); - GetValue(j, "size", task_info.sizes); - TryGetValue(j, "device_index", task_info.device_index); -} - -USED_BY_JSON void to_json(Json &j, const SliceByAxisReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSliceByAxis; - j["axis_to_slice_deployments"] = task_info.axis_to_slice_deployments; -} - -USED_BY_JSON void from_json(const Json &j, SliceByAxisReshardTask &task_info) { - GetValue(j, "axis_to_slice_deployments", task_info.axis_to_slice_deployments); -} - -USED_BY_JSON void to_json(Json &j, const ConcatReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeConcat; - j["concat_dim"] = task_info.concat_dim; -} - -USED_BY_JSON void from_json(const Json &j, ConcatReshardTask &task_info) { - GetValue(j, "concat_dim", task_info.concat_dim); -} - -USED_BY_JSON void to_json(Json &j, const UniqueConcatReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeUniqueConcat; - j["unique_id"] = task_info.unique_id; - j["concat_dim"] = task_info.concat_dim; - j["src_device_indices"] = task_info.src_device_indices; - j["dst_device_index"] = task_info.dst_device_index; -} - -USED_BY_JSON void from_json(const Json &j, UniqueConcatReshardTask &task_info) { - TryGetValue(j, "unique_id", task_info.unique_id); - GetValue(j, "concat_dim", task_info.concat_dim); - GetValue(j, "src_device_indices", task_info.src_device_indices); - GetValue(j, "dst_device_index", task_info.dst_device_index); -} - -USED_BY_JSON void to_json(Json &j, const SplitReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSplit; - j["num_split"] = task_info.num_split; - j["split_dim"] = task_info.split_dim; -} - -USED_BY_JSON void from_json(const Json &j, SplitReshardTask &task_info) { - GetValue(j, "num_split", task_info.num_split); - GetValue(j, "split_dim", task_info.split_dim); -} - -USED_BY_JSON void to_json(Json &j, const TransposeReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeTranspose; - j["perm"] = task_info.perm; -} - -USED_BY_JSON void from_json(const Json &j, TransposeReshardTask &task_info) { - GetValue(j, "perm", task_info.perm); -} - -USED_BY_JSON void to_json(Json &j, const ReshapeReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeReshape; - j["shape"] = task_info.shape; -} - -USED_BY_JSON void from_json(const Json &j, ReshapeReshardTask &task_info) { - GetValue(j, "shape", task_info.shape); -} - -USED_BY_JSON void to_json(Json &j, const ModifyValueReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeModifyValue; - j["op_type"] = task_info.op_type; - j["value"] = task_info.value; -} - -USED_BY_JSON void from_json(const Json &j, ModifyValueReshardTask &task_info) { - GetValue(j, "op_type", task_info.op_type); - GetValue(j, "value", task_info.value); -} - -USED_BY_JSON void to_json(Json &j, const CastReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeCast; - j["dst_type"] = static_cast(task_info.dst_type); -} - -USED_BY_JSON void from_json(const Json &j, CastReshardTask &task_info) { - int32_t dst_type = -1; - GetValue(j, "dst_type", dst_type); - task_info.dst_type = static_cast(dst_type); -} - -USED_BY_JSON void to_json(Json &j, const CommTask &comm_task) { - GE_CHK_STATUS(CommTaskBuilder::GetInstance().ConvertToJson(comm_task, j)); -} - -USED_BY_JSON void from_json(const Json &j, CommTask &comm_task) { - CommTaskBuilder::GetInstance().BuildCommTask(j, comm_task); -} - -USED_BY_JSON void to_json(Json &j, const CommStepInput &step_input) { - j = std::vector{step_input.step_id, step_input.output_index}; -} - -USED_BY_JSON void from_json(const Json &j, CommStepInput &step_input) { - const auto step_id_and_out_index = j.get>(); - const size_t num_items = step_id_and_out_index.size(); - if (num_items > kIndexStepId) { - step_input.step_id = step_id_and_out_index[kIndexStepId]; - } - if (num_items > kIndexOutputIndex) { - step_input.output_index = step_id_and_out_index[kIndexOutputIndex]; - } -} - -USED_BY_JSON void to_json(Json &j, const CommStep &comm_step) { - j = Json(); - j["id"] = comm_step.id; - if (!comm_step.inputs.empty()) { - j["input_ids"] = comm_step.inputs; - } - j["comm_task"] = comm_step.comm_task; -} - -USED_BY_JSON void from_json(const Json &j, CommStep &comm_step) { - comm_step.id = j.at("id").get(); - if (j.contains("input_ids")) { - comm_step.inputs = j.at("input_ids").get>(); - } - comm_step.comm_task = j.at("comm_task").get(); -} - -USED_BY_JSON void to_json(Json &j, const PeerInput &peer_input) { - j = Json(); - j["step_id"] = peer_input.step_id; - j["node_name"] = peer_input.node_name; - j["input_index"] = peer_input.input_index; - j["stage_id"] = peer_input.stage_id; - j["virtual_stage_id"] = peer_input.virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, PeerInput &peer_input) { - GetValue(j, "step_id", peer_input.step_id); - GetValue(j, "node_name", peer_input.node_name); - GetValue(j, "input_index", peer_input.input_index); - TryGetValue(j, "stage_id", peer_input.stage_id); - TryGetValue(j, "virtual_stage_id", peer_input.virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const OutputReshardRes &reshard_res) { - j = Json(); - j["comm_steps"] = reshard_res.comm_steps; - j["peer_inputs"] = reshard_res.peer_inputs; - j["device_list"] = reshard_res.device_indices; - j["stage_id"] = reshard_res.stage_id; - j["virtual_stage_id"] = reshard_res.virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, OutputReshardRes &reshard_res) { - GetValue(j, "comm_steps", reshard_res.comm_steps); - GetValue(j, "peer_inputs", reshard_res.peer_inputs); - GetValue(j, "device_list", reshard_res.device_indices); - TryGetValue(j, "stage_id", reshard_res.stage_id); - TryGetValue(j, "virtual_stage_id", reshard_res.virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const ReshardAttr &reshard_attr) { - j = reshard_attr.reshard_infos; -} - -USED_BY_JSON void to_json(Json &j, const ShardGraphExtAttrs &shard_graph_ext_attrs) { - j = Json(); - j["dev_index_to_logic_dev_id"] = shard_graph_ext_attrs.dev_index_to_logic_dev_id; - j["graph_name_to_endpoints"] = shard_graph_ext_attrs.graph_name_to_endpoints; - j["group_name_to_dev_ids"] = shard_graph_ext_attrs.group_name_to_dev_ids; -} - -USED_BY_JSON void from_json(const Json &j, ShardGraphExtAttrs &shard_graph_ext_attrs) { - shard_graph_ext_attrs.dev_index_to_logic_dev_id = - j.at("dev_index_to_logic_dev_id").get>>(); - shard_graph_ext_attrs.graph_name_to_endpoints = - j.at("graph_name_to_endpoints").get>>>(); - shard_graph_ext_attrs.group_name_to_dev_ids = - j.at("group_name_to_dev_ids").get>>(); -} - -USED_BY_JSON void from_json(const Json &j, ReshardAttr &reshard_attr) { - reshard_attr.reshard_infos = j.get>>(); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, ShardGraphExtAttrs &shard_graph_ext_attrs) { - return ParseFromJson("ShardGraphExtAttrs", json_str, shard_graph_ext_attrs); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, DeviceIndex &device_index) { - return ParseFromJson("DeviceIndex", json_str, device_index); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, ModelIndex &model_index) { - return ParseFromJson("ModelIndex", json_str, model_index); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, PipelineConfig &pipeline_config) { - return ParseFromJson("PipelineConfig", json_str, pipeline_config); -} - - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - ReshardAttr &reshard_attr) { - return ParseFromJson("ReshardRes", json_str, reshard_attr); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - TensorDeployment &tensor_deployment) { - return ParseFromJson("TensorDeployment", json_str, tensor_deployment); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - TensorDeployments &tensor_deployments) { - return ParseFromJson("TensorDeployments", json_str, tensor_deployments); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - NodeDeployments &node_deployments) { - return ParseFromJson("NodeDeployments", json_str, node_deployments); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, CommTask &comm_task) { - return ParseFromJson("CommTask", json_str, comm_task); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, CommStep &comm_step) { - return ParseFromJson("CommStep", json_str, comm_step); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - OutputReshardRes &output_reshard_res) { - return ParseFromJson("TensorReshardInfo", json_str, output_reshard_res); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, NodeDeployment &node_deployment) { - return ParseFromJson("NodeDeployment", json_str, node_deployment); -} - -std::string TensorParallelAttrs::ToJson(const ShardGraphExtAttrs &shard_graph_ext_attrs) { - return ToJsonString(shard_graph_ext_attrs); -} - -std::string TensorParallelAttrs::ToJson(const DeviceIndex &device_index) { - return ToJsonString(device_index); -} - -std::string TensorParallelAttrs::ToJson(const ModelIndex &model_index) { - return ToJsonString(model_index); -} - -std::string TensorParallelAttrs::ToJson(const PipelineConfig &pipeline_config) { - return ToJsonString(pipeline_config); -} - -std::string TensorParallelAttrs::ToJson(const NodeDeployment &node_deployment) { - return ToJsonString(node_deployment); -} - -std::string TensorParallelAttrs::ToJson(const TensorDeployment &tensor_deployment) { - return ToJsonString(tensor_deployment); -} - -std::string TensorParallelAttrs::ToJson(const ReshardAttr &reshard_attr) { - return ToJsonString(reshard_attr); -} - -std::string TensorParallelAttrs::ToJson(const TensorDeployments &tensor_deployments) { - return ToJsonString(tensor_deployments); -} - -std::string TensorParallelAttrs::ToJson(const NodeDeployments &node_deployments) { - return ToJsonString(node_deployments); -} - -CommTaskBuilder::CommTaskBuilder() { - InitCommTaskBuilders(); - InitJsonConverters(); -} - -void CommTaskBuilder::InitCommTaskBuilders() { - builders_[kCommTaskTypeSlice] = [](const Json &j, CommTask &comm_task) { - comm_task.slice_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeSliceByAxis] = [](const Json &j, CommTask &comm_task) { - comm_task.slice_by_axis_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeSplit] = [](const Json &j, CommTask &comm_task) { - comm_task.split_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeConcat] = [](const Json &j, CommTask &comm_task) { - comm_task.concat_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeUniqueConcat] = [](const Json &j, CommTask &comm_task) { - comm_task.unique_concat_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeTranspose] = [](const Json &j, CommTask &comm_task) { - comm_task.transpose_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllGather] = [](const Json &j, CommTask &comm_task) { - comm_task.all_gather_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllReduce] = [](const Json &j, CommTask &comm_task) { - comm_task.all_reduce_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllReduceMean] = [](const Json &j, CommTask &comm_task) { - comm_task.all_reduce_mean_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomReduceScatter] = [](const Json &j, CommTask &comm_task) { - comm_task.reduce_scatter_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomBroadcast] = [](const Json &j, CommTask &comm_task) { - comm_task.broadcast_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllToAll] = [](const Json &j, CommTask &comm_task) { - comm_task.all_to_all_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeSendReceive] = [](const Json &j, CommTask &comm_task) { - comm_task.send_recv_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeModifyValue] = [](const Json &j, CommTask &comm_task) { - comm_task.modify_value_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeReshape] = [](const Json &j, CommTask &comm_task) { - comm_task.reshape_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeCast] = [](const Json &j, CommTask &comm_task) { - comm_task.cast_reshard_task = CreateReshardTaskInfo(j); - }; -} - -template -Status CommTaskBuilder::ConvertToJson(const T *reshard_task, nlohmann::json &j) { - GE_CHECK_NOTNULL(reshard_task); - j = *reshard_task; - return SUCCESS; -} - -void CommTaskBuilder::InitJsonConverters() { - json_converters_[kCommTaskTypeSlice] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.slice_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeSliceByAxis] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.slice_by_axis_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeSplit] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.split_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeConcat] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.concat_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeUniqueConcat] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.unique_concat_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeTranspose] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.transpose_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllGather] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_gather_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllReduce] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_reduce_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllReduceMean] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_reduce_mean_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomReduceScatter] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.reduce_scatter_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomBroadcast] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.broadcast_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllToAll] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_to_all_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeSendReceive] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.send_recv_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeModifyValue] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.modify_value_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeReshape] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.reshape_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeCast] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.cast_reshard_task.get(), j); - }; -} - -bool operator==(const DeviceIndex &lhs, const DeviceIndex &rhs) { - return lhs.engine_type == rhs.engine_type && - lhs.indices == rhs.indices; -} - -bool operator!=(const DeviceIndex &lhs, const DeviceIndex &rhs) { - return !(rhs == lhs); -} - -bool operator<(const DeviceIndex &lhs, const DeviceIndex &rhs) { - if (lhs.engine_type < rhs.engine_type) { - return true; - } - if (rhs.engine_type < lhs.engine_type) { - return false; - } - return lhs.indices < rhs.indices; -} - -bool operator==(const ModelIndex &lhs, const ModelIndex &rhs) { - return (lhs.device_index == rhs.device_index) && (lhs.virtual_stage_id == rhs.virtual_stage_id); -} - -bool operator!=(const ModelIndex &lhs, const ModelIndex &rhs) { - return !(rhs == lhs); -} - -bool operator<(const ModelIndex &lhs, const ModelIndex &rhs) { - if (lhs.virtual_stage_id < rhs.virtual_stage_id) { - return true; - } - if (rhs.virtual_stage_id < lhs.virtual_stage_id) { - return false; - } - return lhs.device_index < rhs.device_index; -} - -bool operator==(const CommStepInput &lhs, const CommStepInput &rhs) { - return (lhs.step_id == rhs.step_id) && (lhs.output_index == rhs.output_index); -} - -bool operator<(const CommStepInput &lhs, const CommStepInput &rhs) { - if (lhs.step_id < rhs.step_id) { - return true; - } - if (rhs.step_id < lhs.step_id) { - return false; - } - return lhs.output_index < rhs.output_index; -} - -bool operator==(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs) { - return (lhs.inserted_node_id == rhs.inserted_node_id) && (lhs.output_index == rhs.output_index); -} -bool operator<(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs) { - if (lhs.inserted_node_id < rhs.inserted_node_id) { - return true; - } - if (rhs.inserted_node_id < lhs.inserted_node_id) { - return false; - } - return lhs.output_index < rhs.output_index; -} - -bool operator==(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs) { - return (lhs.node_name == rhs.node_name) && (lhs.sliced_id == rhs.sliced_id); -} - -bool operator<(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs) { - if (lhs.node_name < rhs.node_name) { - return true; - } - if (rhs.node_name < lhs.node_name) { - return false; - } - return lhs.sliced_id < rhs.sliced_id; -} - -bool operator==(const DstNodeInfo &lhs, const DstNodeInfo &rhs) { - return (lhs.orig_node_info == rhs.orig_node_info) && (lhs.input_indexes == rhs.input_indexes); -} - -bool operator<(const DstNodeInfo &lhs, const DstNodeInfo &rhs) { - if (lhs.orig_node_info < rhs.orig_node_info) { - return true; - } - if (rhs.orig_node_info < lhs.orig_node_info) { - return false; - } - return lhs.InputIndexesToString() < rhs.InputIndexesToString(); -} - -bool operator==(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs) { - if ((lhs.input_info.inserted_node_id >= 0) && (rhs.input_info.inserted_node_id >= 0)) { - return (lhs.input_info == rhs.input_info); - } - if ((lhs.input_info.inserted_node_id < 0) && (rhs.input_info.inserted_node_id < 0)) { - return (lhs.input_info == rhs.input_info) && (lhs.orig_node_info == rhs.orig_node_info); - } - return false; -} -bool operator<(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs) { - if (lhs.input_info < rhs.input_info) { - return true; - } - if (rhs.input_info < lhs.input_info) { - return false; - } - return lhs.orig_node_info < rhs.orig_node_info; -} - -bool operator==(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs) { - return (lhs.input_info == rhs.input_info) && (lhs.node_info == rhs.node_info); -} - -bool operator<(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs) { - if (lhs.input_info < rhs.input_info) { - return true; - } - if (rhs.input_info < lhs.input_info) { - return false; - } - return lhs.node_info < rhs.node_info; -} - -std::string ModelIndex::DebugString() const { - return device_index.DebugString() + "[S" + std::to_string(stage_id) + ", V" + std::to_string(virtual_stage_id) + "]"; -} -} // namespace tp -} // namespace ge diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h index 6b04da83f7..ef42a723f2 100644 --- a/inc/graph/debug/ge_attr_define.h +++ b/inc/graph/debug/ge_attr_define.h @@ -1454,9 +1454,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_SHARD_GRAPH_FOR_LOAD; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_MODEL_DEPLOY_MODE; -// for tensor parallelism -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TP_RESHARD_ATTR; - // for lowering GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_FLATTEN_OFFSET; diff --git a/inc/graph/parallelism/graph_parallel_option.h b/inc/graph/parallelism/graph_parallel_option.h deleted file mode 100644 index 74196f19c8..0000000000 --- a/inc/graph/parallelism/graph_parallel_option.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ -#define METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ - -#include -#include - -namespace ge { -struct PipelineParallelOption { - bool is_enabled = false; - bool is_auto = false; - std::string pipeline_strategy; - int32_t pipe_stage_num = -1; - int32_t schedule_opt_virtual_stage_num = -1; -}; - -struct TensorParallelOption { - bool is_enabled = false; - bool is_auto = false; - int32_t tensor_parallel_size = -1; - int32_t inter_batch_flow_num = 1; -}; - -struct DataParallelOption { - bool is_enabled = false; - bool is_auto = false; - // to be deleted below - bool optimizer_state_sharding = false; - bool gradient_sharding = false; - bool model_weight_sharding = false; - bool model_weight_prefetch = true; - int32_t data_parallel_size = -1; - // model weight prefetch buffer size(MB) - uint32_t model_weight_prefetch_buffer_size = 0U; -}; - -struct TensorShardingOption { - bool is_enabled = false; - bool optimizer_state_sharding = false; - bool gradient_sharding = false; - bool model_weight_sharding = false; - bool model_weight_prefetch = true; - // model weight prefetch buffer size(MB) - uint32_t model_weight_prefetch_buffer_size = 0U; -}; - -struct OptimizerOffloadGraphOption { - bool is_enabled = false; - std::string offload; // cpu or NVME, NVME is reserved - std::string offload_path; // NVME path, reserved -}; - -struct EngineParallelOption { - bool is_enabled = false; - bool is_auto = false; - std::string config_path; // used if is_auto == true -}; - -struct GraphParallelOption { - bool auto_deploy = false; - std::string mode; // AOE mode, search_strategy/search_and_shard_graph/load_strategy/load_and_eval_strategy - std::string work_dir; // AOE dump/load path for strategies - std::string opt_level; - int32_t global_batch_size = -1; - DataParallelOption data_parallel_option; - TensorParallelOption tensor_parallel_option; - TensorShardingOption tensor_sharding_option; - PipelineParallelOption pipeline_parallel_option; - OptimizerOffloadGraphOption optimizer_offload_option; - EngineParallelOption engine_parallel_option; -}; -} // namespace ge - -#endif // METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ diff --git a/inc/graph/parallelism/tensor_parallel_attrs.h b/inc/graph/parallelism/tensor_parallel_attrs.h deleted file mode 100644 index ea56fe1544..0000000000 --- a/inc/graph/parallelism/tensor_parallel_attrs.h +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ -#define METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ - -#include -#include -#include -#include -#include -#include "external/ge_common/ge_api_types.h" - -namespace ge { -namespace tp { -constexpr const char_t *kCommTaskTypeConcat = "Concat"; -constexpr const char_t *kCommTaskTypeUniqueConcat = "UniqueConcat"; -constexpr const char_t *kCommTaskTypeModifyValue = "ModifyValue"; -constexpr const char_t *kCommTaskTypeSlice = "Slice"; -constexpr const char_t *kCommTaskTypeSliceByAxis = "SliceByAxis"; -constexpr const char_t *kCommTaskTypeSplit = "Split"; -constexpr const char_t *kCommTaskTypeTranspose = "Transpose"; -constexpr const char_t *kCommTaskTypeReshape = "Reshape"; -constexpr const char_t *kCommTaskTypeCast = "Cast"; -constexpr const char_t *kCommTaskTypeHcomAllGather = "HcomAllGather"; -constexpr const char_t *kCommTaskTypeHcomAllReduce = "HcomAllReduce"; -constexpr const char_t *kCommTaskTypeHcomAllReduceMean = "HcomAllReduceMean"; -constexpr const char_t *kCommTaskTypeHcomReduceScatter = "HcomReduceScatter"; -constexpr const char_t *kCommTaskTypeHcomBroadcast = "HcomBroadcast"; -constexpr const char_t *kCommTaskTypeHcomAllToAll = "HcomAllToAll"; -constexpr const char_t *kCommTaskTypeSendReceive = "SendReceive"; -constexpr const char_t *kCommTaskTypeLocalReduce = "LocalReduce"; -constexpr const char_t *kGraphSlicingSuffix = "_by_graph_slice_"; -constexpr const char_t *kFlowAttrEnqueuePolicyFifo = "FIFO"; -constexpr const char_t *kFlowAttrEnqueuePolicyOverwrite = "OVERWRITE"; -constexpr const char_t *kSendRecvCommTypeQueue = "Queue"; -constexpr const char_t *kSendRecvCommTypeP2p = "P2pComm"; - -// tensor deployment attrs -struct DimSlice { - int64_t begin; - int64_t end; -}; - -struct DeviceIndex { - std::string engine_type; - std::vector indices; - std::string DebugString() const; -}; - -struct ModelIndex { - // use this construct when need use stage id - ModelIndex() = default; - ModelIndex(const DeviceIndex &device_index, const int64_t stage_id, const int64_t virtual_stage_id) - : device_index(device_index), virtual_stage_id(virtual_stage_id), stage_id(stage_id) {} - // use this construct when do not need use stage id - ModelIndex(const DeviceIndex &device_index, const int64_t virtual_stage_id) - : device_index(device_index), virtual_stage_id(virtual_stage_id), stage_id(0L) {} - ~ModelIndex() = default; - DeviceIndex device_index; - int64_t virtual_stage_id = 0L; - int64_t stage_id = 0L; - std::string DebugString() const; -}; - -struct PipelineConfig { - int64_t micro_batch = 1L; - int64_t stage_id = 0L; - std::vector virtual_stage_id {0L}; -}; - -bool operator==(const DeviceIndex &lhs, const DeviceIndex &rhs); -bool operator!=(const DeviceIndex &lhs, const DeviceIndex &rhs); -bool operator<(const DeviceIndex &lhs, const DeviceIndex &rhs); - -bool operator==(const ModelIndex &lhs, const ModelIndex &rhs); -bool operator!=(const ModelIndex &lhs, const ModelIndex &rhs); -bool operator<(const ModelIndex &lhs, const ModelIndex &rhs); - -struct TensorSliceDeployment { - std::vector> axis_slices; - std::vector> device_indices_each_slice; - std::string reduce_type; -}; - -struct TensorDeployment { - TensorSliceDeployment shard_deployment; - std::string verbose; -}; - -struct NodeDeployment { - std::vector devices; - PipelineConfig pipeline_config; -}; - -struct NodeDeployments { - std::map deployments; -}; - -struct TensorDeployments { - std::map deployments; -}; - -// P2P communications -struct CommPair { - DeviceIndex src_device_index; - int64_t src_virtual_stage_id = 0L; - DeviceIndex dst_device_index; - int64_t dst_virtual_stage_id = 0L; -}; - -struct FlowAttr { - int32_t depth = 1; - std::string enqueue_policy = kFlowAttrEnqueuePolicyFifo; -}; - -struct SendRecvReshardTask { - std::vector comm_pairs; - std::string parallel_group; - std::string comm_type = kSendRecvCommTypeQueue; - FlowAttr flow_attr; // used when comm_type is Queue -}; - -struct CastReshardTask { - DataType dst_type = DT_MAX; -}; - -// group communications -struct CommGroup { - std::vector device_indices; -}; - -struct AllToAllReshardTask { - std::vector comm_groups; - std::string parallel_group; -}; - -struct AllGatherReshardTask { - std::vector comm_groups; - int32_t axis; // axis to concat - std::string parallel_group; - std::string output_allocator; -}; - -struct AllReduceReshardTask { - std::string reduction; - std::vector comm_groups; - std::string parallel_group; -}; - -struct AllReduceMeanReshardTask { - std::vector comm_groups; - int32_t axis; - int32_t value; - std::string parallel_group; -}; - -struct ReduceScatterReshardTask { - std::string reduction; - std::vector comm_groups; - std::string parallel_group; -}; - -struct BroadcastReshardTask { - std::vector root_device_indices; // size == num_groups - std::vector comm_groups; - std::string parallel_group; -}; - -// local reshardings -struct SliceReshardTask { - std::vector axes; - std::vector offsets; - std::vector sizes; - DeviceIndex device_index; -}; - -struct SliceByAxisReshardTask { - // key: axis to split - // value: index: slice index - // element: devices to deploy - std::map>> axis_to_slice_deployments; -}; - -struct SplitReshardTask { - int32_t split_dim = 0; - int32_t num_split = 0; -}; - -struct ConcatReshardTask { - int32_t concat_dim = 0; -}; - -struct UniqueConcatReshardTask { - std::string unique_id; - int32_t concat_dim = 0; - std::vector src_device_indices; - DeviceIndex dst_device_index; -}; - -struct TransposeReshardTask { - std::vector perm; -}; - -struct ReshapeReshardTask { - std::vector shape; -}; - -struct ModifyValueReshardTask { - std::string op_type; // mul, div - std::vector value; -}; - -struct LocalReduceReshardTask { - std::string op_type; -}; - -struct CommTask { - std::string task_type; - std::shared_ptr send_recv_reshard_task; - std::shared_ptr all_gather_reshard_task; - std::shared_ptr all_to_all_reshard_task; - std::shared_ptr all_reduce_reshard_task; - std::shared_ptr all_reduce_mean_reshard_task; - std::shared_ptr reduce_scatter_reshard_task; - std::shared_ptr broadcast_reshard_task; - std::shared_ptr split_reshard_task; - std::shared_ptr concat_reshard_task; - std::shared_ptr unique_concat_reshard_task; - std::shared_ptr slice_reshard_task; - std::shared_ptr slice_by_axis_reshard_task; - std::shared_ptr transpose_reshard_task; - std::shared_ptr modify_value_reshard_task; - std::shared_ptr local_reduce_reshard_task; - std::shared_ptr reshape_reshard_task; - std::shared_ptr cast_reshard_task; -}; - -struct CommStepInput { - int32_t step_id = -1; - int32_t output_index = -1; -}; - -bool operator==(const CommStepInput &lhs, const CommStepInput &rhs); -bool operator<(const CommStepInput &lhs, const CommStepInput &rhs); - -struct CommStep { - int32_t id; - std::vector inputs; - CommTask comm_task; -}; - -struct PeerInput { - int32_t step_id = -1; - std::string node_name; - uint32_t input_index; - int64_t stage_id = 0L; - int64_t virtual_stage_id = 0L; -}; - -// reshard ops for one output tensor -struct OutputReshardRes { - std::vector comm_steps; - std::vector peer_inputs; - std::vector device_indices; - int64_t stage_id = 0L; - int64_t virtual_stage_id = 0L; -}; - -struct ReshardAttr { - std::vector> reshard_infos; // indexed by output index -}; - -struct SrcNodeInfo { - int32_t inserted_node_id = -1; - int32_t output_index = -1; -}; -bool operator==(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs); -bool operator<(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs); - -struct OrigNodeInfo { - std::string node_name; - int32_t sliced_id = -1; - - std::string Name() const { - return (sliced_id == -1) ? node_name : (node_name + kGraphSlicingSuffix + std::to_string(sliced_id)); - } -}; - -bool operator==(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs); -bool operator<(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs); - -struct DstNodeInfo { - OrigNodeInfo orig_node_info; - std::vector input_indexes; - - std::string InputIndexesToString() const { - std::string res; - for (const uint32_t input_index : input_indexes) { - res += std::to_string(input_index) + " "; - } - return res; - } -}; - -bool operator==(const DstNodeInfo &lhs, const DstNodeInfo &rhs); -bool operator<(const DstNodeInfo &lhs, const DstNodeInfo &rhs); - -struct InsertedNodeInput { - SrcNodeInfo input_info; - OrigNodeInfo orig_node_info; -}; - -bool operator==(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs); -bool operator<(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs); - -struct PeerOutNodeInfo { - SrcNodeInfo input_info; - DstNodeInfo node_info; -}; - -bool operator==(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs); -bool operator<(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs); - -struct InsertedNodeInfo { - uint32_t id; - CommTask task; - std::vector inputs; -}; - -struct OutputSlicedRes { - std::vector inserted_nodes_info; - std::vector peer_out_nodes; -}; - -struct SlicedEdgeInfo { - std::vector steps_sliced; -}; - -struct TensorShapeSlicedInfo { - std::vector> axis_slices; -}; - -struct NodeSliceStrategy { - std::map input_shape_sliced_info; - std::map output_shape_sliced_info; - - std::map outputs_sliced_edge_info; - std::vector>> dependencies; - size_t size = 1U; -}; - -struct ShardGraphExtAttrs { - // ExtAttr _device_index_to_logic_device_id, key is DeviceIndex, value is logic device id - std::map> dev_index_to_logic_dev_id; - // ExtAttr _model_events, key1 is graph name, key2 is endpoint name, value is serialized endpoints - std::map>> graph_name_to_endpoints; - // ExtAttr _hcomgroups, key is group name, value is device ids - std::map> group_name_to_dev_ids; -}; - -class TensorParallelAttrs { - public: - static Status FromJson(const std::string &json_str, DeviceIndex &device_index); - static Status FromJson(const std::string &json_str, ModelIndex &model_index); - static Status FromJson(const std::string &json_str, PipelineConfig &pipeline_config); - static Status FromJson(const std::string &json_str, NodeDeployment &node_deployment); - static Status FromJson(const std::string &json_str, TensorDeployment &tensor_deployment); - static Status FromJson(const std::string &json_str, TensorDeployments &tensor_deployments); - static Status FromJson(const std::string &json_str, NodeDeployments &node_deployments); - static Status FromJson(const std::string &json_str, CommTask &comm_task); - static Status FromJson(const std::string &json_str, CommStep &comm_step); - static Status FromJson(const std::string &json_str, OutputReshardRes &output_reshard_res); - static Status FromJson(const std::string &json_str, ReshardAttr &reshard_attr); - static Status FromJson(const std::string &json_str, ShardGraphExtAttrs &shard_graph_ext_attrs); - - static std::string ToJson(const NodeDeployment &node_deployment); - static std::string ToJson(const DeviceIndex &device_index); - static std::string ToJson(const ModelIndex &model_index); - static std::string ToJson(const PipelineConfig &pipeline_config); - static std::string ToJson(const TensorDeployment &tensor_deployment); - static std::string ToJson(const NodeDeployments &node_deployments); - static std::string ToJson(const ReshardAttr &reshard_attr); - static std::string ToJson(const TensorDeployments &tensor_deployments); - static std::string ToJson(const ShardGraphExtAttrs &shard_graph_ext_attrs); -}; -} // namespace tp -} // namespace ge - -#endif // METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ diff --git a/tests/ut/graph/testcase/tensor_parallel_attrs_unittest.cc b/tests/ut/graph/testcase/tensor_parallel_attrs_unittest.cc deleted file mode 100644 index e4399914c8..0000000000 --- a/tests/ut/graph/testcase/tensor_parallel_attrs_unittest.cc +++ /dev/null @@ -1,953 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "nlohmann/json.hpp" - -#include "graph/parallelism/tensor_parallel_attrs.h" -#include "external/ge_common/ge_api_error_codes.h" -#include "common/ge_common/ge_inner_error_codes.h" - -using namespace testing; -namespace ge { -namespace tp { -using Json = nlohmann::json; - -class TensorParallelAttrsTest : public testing::Test { - protected: - static void TestToAndFromJson(const CommTask &comm_task, CommTask &out_comm_task) { - ReshardAttr reshard_attr; - OutputReshardRes output_reshard_res; - CommStep comm_step; - comm_step.id = 0; - comm_step.comm_task = comm_task; - output_reshard_res.comm_steps.emplace_back(comm_step); - reshard_attr.reshard_infos.emplace_back(std::vector{output_reshard_res}); - const auto &json_str = TensorParallelAttrs::ToJson(reshard_attr); - ASSERT_TRUE(!json_str.empty()); - std::cout << json_str << std::endl; - ReshardAttr reshard_attr_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, reshard_attr_from_json), SUCCESS); - out_comm_task = reshard_attr_from_json.reshard_infos[0][0].comm_steps[0].comm_task; - } -}; - -TEST_F(TensorParallelAttrsTest, ParseFailed_InvalidJsonStr) { - std::string json_str = "invalid"; - DeviceIndex device_index; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, device_index), PARAM_INVALID); -} - -TEST_F(TensorParallelAttrsTest, ParseFailed_FieldMismatches) { - std::string json_str = R"( - {"engine_type": "NPU", "index": 0} -)"; - DeviceIndex device_index; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, device_index), PARAM_INVALID); -} - -TEST_F(TensorParallelAttrsTest, DeviceIndex_ToAndFromJsonStr) { - DeviceIndex device_index; - device_index.indices = {0, 1, 2}; - device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(device_index); - DeviceIndex another_device_index; - TensorParallelAttrs::FromJson(json_str, another_device_index); - EXPECT_EQ(device_index, another_device_index); -} - -TEST_F(TensorParallelAttrsTest, ModelIndex_ToAndFromJsonStr) { - ModelIndex model_index; - model_index.stage_id = 0; - model_index.virtual_stage_id = 0; - model_index.device_index.indices = {0, 1, 2}; - model_index.device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(model_index); - ModelIndex another_model_index; - TensorParallelAttrs::FromJson(json_str, another_model_index); - EXPECT_TRUE(model_index.DebugString() == "MyEngine[0, 1, 2][S0, V0]"); - EXPECT_EQ(model_index, another_model_index); -} - -TEST_F(TensorParallelAttrsTest, ModelIndex_NotEqual) { - ModelIndex model_index; - model_index.stage_id = 0; - model_index.virtual_stage_id = 0; - model_index.device_index.indices = {0, 1, 2}; - model_index.device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(model_index); - ModelIndex another_model_index(model_index); - another_model_index.virtual_stage_id = 1; - EXPECT_TRUE(model_index != another_model_index); -} - -TEST_F(TensorParallelAttrsTest, ModelIndex_LessBigger) { - ModelIndex model_index; - model_index.stage_id = 0; - model_index.virtual_stage_id = 1; - model_index.device_index.indices = {0, 1, 2}; - model_index.device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(model_index); - ModelIndex another_model_index(model_index); - another_model_index.virtual_stage_id = 2; - EXPECT_TRUE(model_index < another_model_index); - EXPECT_FALSE(another_model_index < model_index); - another_model_index.virtual_stage_id = 1; - EXPECT_FALSE(another_model_index < model_index); -} - -TEST_F(TensorParallelAttrsTest, PipelineConfig_ToAndFromJsonStr) { - PipelineConfig pipeline_config; - pipeline_config.micro_batch = 1; - pipeline_config.stage_id = 0; - pipeline_config.virtual_stage_id = {0, 1}; - std::string json_str = TensorParallelAttrs::ToJson(pipeline_config); - PipelineConfig another_pipeline_config; - TensorParallelAttrs::FromJson(json_str, another_pipeline_config); - EXPECT_EQ(pipeline_config.micro_batch, another_pipeline_config.micro_batch); - EXPECT_EQ(pipeline_config.stage_id, another_pipeline_config.stage_id); - EXPECT_EQ(pipeline_config.virtual_stage_id, another_pipeline_config.virtual_stage_id); -} - -TEST_F(TensorParallelAttrsTest, DeviceIndex_operators) { - std::map device_index_to_value; - DeviceIndex device_index; - device_index.indices = {0, 1, 2}; - device_index.engine_type = "MyEngine"; - - DeviceIndex another_device_index; - device_index.indices = {0, 1, 2}; - device_index.engine_type = "CPU"; - ASSERT_TRUE(device_index_to_value.emplace(device_index, 0).second); - ASSERT_FALSE(device_index_to_value.emplace(device_index, 0).second); - ASSERT_FALSE(device_index.DebugString().empty()); - ASSERT_EQ(device_index_to_value.count(another_device_index), 0U); - ASSERT_NE(device_index, another_device_index); - ASSERT_EQ(device_index, device_index); -} - -TEST_F(TensorParallelAttrsTest, NodeDeployment_ToAndFromJsonStr) { - DeviceIndex device_index_0; - device_index_0.indices = {0, 0, 1}; - device_index_0.engine_type = "CPU"; - - DeviceIndex device_index_1; - device_index_1.indices = {0, 0, 2}; - device_index_1.engine_type = "NPU"; - - NodeDeployment node_deployment; - node_deployment.devices = {device_index_0, device_index_1}; - std::string json_str = TensorParallelAttrs::ToJson(node_deployment); - NodeDeployment another_node_deployment; - TensorParallelAttrs::FromJson(json_str, another_node_deployment); - EXPECT_EQ(node_deployment.devices, another_node_deployment.devices); -} - -TEST_F(TensorParallelAttrsTest, ParseSendRecvTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "SendReceive", - "comm_pairs": [ - { - "src_device_index": {"engine_type": "NPU", "index": [0, 0, 1]}, - "src_virtual_stage_id": 0, - "dst_device_index": {"engine_type": "NPU", "index": [0, 0, 2]}, - "dst_virtual_stage_id": 0 - } - ], - "comm_type": "Queue", - "flow_attr": { - "depth":128, - "enqueue_policy":"FIFO" - } -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.send_recv_reshard_task != nullptr); - ASSERT_EQ(comm_task.send_recv_reshard_task->comm_pairs.size(), 1U); - EXPECT_EQ(comm_task.send_recv_reshard_task->comm_pairs[0].src_device_index.indices, (std::vector{0, 0, 1})); - EXPECT_EQ(comm_task.send_recv_reshard_task->comm_pairs[0].dst_device_index.indices, (std::vector{0, 0, 2})); - EXPECT_EQ(comm_task.send_recv_reshard_task->comm_type, kSendRecvCommTypeQueue); - EXPECT_EQ(comm_task.send_recv_reshard_task->flow_attr.depth, 128); - EXPECT_EQ(comm_task.send_recv_reshard_task->flow_attr.enqueue_policy, kFlowAttrEnqueuePolicyFifo); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); -} - -TEST_F(TensorParallelAttrsTest, ParseAllGatherCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllGather", - "parallel_group": "-1", - "output_allocator": "BufferPool", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ], - "axis": 0 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.all_gather_reshard_task != nullptr); - EXPECT_EQ(out_comm_task.all_gather_reshard_task->comm_groups.size(), 4); - EXPECT_EQ(out_comm_task.all_gather_reshard_task->parallel_group, "-1"); - EXPECT_EQ(out_comm_task.all_gather_reshard_task->output_allocator, "BufferPool"); -} - -TEST_F(TensorParallelAttrsTest, ParseAllReduceCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllReduce", - "reduction": "sum", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.all_reduce_reshard_task != nullptr); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - EXPECT_EQ(out_comm_task.all_reduce_reshard_task->reduction, "sum"); - EXPECT_EQ(out_comm_task.all_reduce_reshard_task->comm_groups.size(), 4); -} - -TEST_F(TensorParallelAttrsTest, ParseAllReduceMeanCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllReduceMean", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ], - "axis": 0, - "value": 2 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.all_reduce_mean_reshard_task != nullptr); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - EXPECT_EQ(out_comm_task.all_reduce_mean_reshard_task->comm_groups.size(), 4); -} - -TEST_F(TensorParallelAttrsTest, ParseReduceScatterCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomReduceScatter", - "reduction": "sum", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.reduce_scatter_reshard_task != nullptr); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - EXPECT_EQ(out_comm_task.reduce_scatter_reshard_task->reduction, "sum"); - EXPECT_EQ(out_comm_task.reduce_scatter_reshard_task->comm_groups.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, ParseAllToAllCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllToAll", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.all_to_all_reshard_task != nullptr); - EXPECT_EQ(out_comm_task.all_to_all_reshard_task->comm_groups.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, ParseSliceCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "Slice", - "offsets": [2, 4], - "size": [4, 8], - "device_index":{"engine_type": "NPU", "index": [0]} -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.slice_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.slice_reshard_task->offsets, (std::vector{2, 4})); - ASSERT_EQ(out_comm_task.slice_reshard_task->sizes, (std::vector{4, 8})); - ASSERT_EQ(out_comm_task.slice_reshard_task->device_index.engine_type, "NPU"); - ASSERT_EQ(out_comm_task.slice_reshard_task->device_index.indices, (std::vector{0})); -} - -TEST_F(TensorParallelAttrsTest, ParseSliceByAxisCommTask) { - CommTask comm_task; - comm_task.task_type = "SliceByAxis"; -// ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - comm_task.slice_by_axis_reshard_task = std::make_shared(); - auto &axis_to_slice_deployments = comm_task.slice_by_axis_reshard_task->axis_to_slice_deployments; - std::vector dim_0_slice_0_deployments{DeviceIndex{"NPU", {0, 0, 0}}, DeviceIndex{"NPU", {0, 0, 1}}}; - std::vector dim_0_slice_1_deployments{DeviceIndex{"NPU", {0, 0, 2}}, DeviceIndex{"NPU", {0, 0, 3}}}; - std::vector dim_1_slice_0_deployments{DeviceIndex{"NPU", {0, 1, 0}}, DeviceIndex{"NPU", {0, 1, 1}}}; - std::vector dim_1_slice_1_deployments{DeviceIndex{"NPU", {0, 1, 2}}, DeviceIndex{"NPU", {0, 1, 3}}}; - - axis_to_slice_deployments[0].emplace_back(dim_0_slice_0_deployments); - axis_to_slice_deployments[0].emplace_back(dim_0_slice_1_deployments); - axis_to_slice_deployments[1].emplace_back(dim_1_slice_0_deployments); - axis_to_slice_deployments[2].emplace_back(dim_1_slice_1_deployments); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.slice_by_axis_reshard_task != nullptr); -} - -TEST_F(TensorParallelAttrsTest, ParseSplitCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "Split", - "split_dim": 1, - "num_split": 2 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.split_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.split_reshard_task->split_dim, 1); - ASSERT_EQ(out_comm_task.split_reshard_task->num_split, 2); -} - -TEST_F(TensorParallelAttrsTest, ParseConcatCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "Concat", - "concat_dim": 1 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.concat_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.concat_reshard_task->concat_dim, 1); -} - -TEST_F(TensorParallelAttrsTest, ParseUniqueConcatCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "UniqueConcat", - "unique_id": "0:1", - "concat_dim": 1, - "src_device_indices": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - "dst_device_index": {"engine_type": "HOST_CPU", "index": [0, 0, 1]} -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.unique_concat_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.unique_concat_reshard_task->concat_dim, 1); - ASSERT_EQ(out_comm_task.unique_concat_reshard_task->src_device_indices.size(), 4); - DeviceIndex device_index{"HOST_CPU", {0, 0, 1}}; - ASSERT_EQ(out_comm_task.unique_concat_reshard_task->dst_device_index, device_index); -} - -TEST_F(TensorParallelAttrsTest, ParseTransposeTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "Transpose", - "perm": [1, 0, 2, 3] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.transpose_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.transpose_reshard_task->perm, (std::vector{1, 0, 2, 3})); -} - -TEST_F(TensorParallelAttrsTest, ParseReshapeTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "Reshape", - "shape": [1, 1, 2, 3] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.reshape_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.reshape_reshard_task->shape, (std::vector{1, 1, 2, 3})); -} - -TEST_F(TensorParallelAttrsTest, ParseCastTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "Cast", - "dst_type": 1 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.cast_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.cast_reshard_task->dst_type, DT_FLOAT16); -} - -TEST_F(TensorParallelAttrsTest, ParseModifyValueCommTask) { - const std::string &json_str = R"( -{ - "task_type": "ModifyValue", - "op_type": "Mul", - "value": [1, 2] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.modify_value_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.modify_value_reshard_task->op_type, "Mul"); - ASSERT_EQ(out_comm_task.modify_value_reshard_task->value, (std::vector{1, 2})); -} - -TEST_F(TensorParallelAttrsTest, ParseBroadcastCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomBroadcast", - "roots": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 6]} - ], - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(comm_task.broadcast_reshard_task != nullptr); - std::vector root_device_index; - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 0}}); - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 2}}); - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 4}}); - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 6}}); - EXPECT_EQ(out_comm_task.broadcast_reshard_task->root_device_indices, root_device_index); - EXPECT_EQ(out_comm_task.broadcast_reshard_task->comm_groups.size(), 4); -} - -TEST_F(TensorParallelAttrsTest, ParseParseCommStep) { - const std::string &json_str = - R"( -{ - "id": 2, - "input_ids": [[0, 0], [1, 0]], - "comm_task": { - "task_type": "Split", - "num_split": 2, - "split_dim": 1 - } -} - )"; - CommStep comm_step; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_step), SUCCESS); - EXPECT_TRUE(comm_step.comm_task.split_reshard_task != nullptr); - EXPECT_EQ(comm_step.id, 2); -} - -TEST_F(TensorParallelAttrsTest, ParseTensorReshardInfo) { - const std::string &json_str = - R"( -{ - "output_index": 1, - "device_list": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - "comm_steps": [ - { - "id": 1, - "input_ids": [], - "comm_task": { - "task_type": "SplitVD", - "size_splits": [2, 4], - "split_dim": 1 - } - }, - { - "id": 2, - "input_ids": [[1, 0]], - "comm_task": { - "task_type": "SplitVD", - "size_splits": [2, 4], - "split_dim": 1 - } - } - ], - "peer_inputs": [ - {"step_id": 1, "node_name": "dst_node", "input_index": 0, "stage_id": 0, "virtual_stage_id": 0}, - {"step_id": 1, "node_name": "dst_node_1", "input_index": 1, "stage_id": 0, "virtual_stage_id": 0} - ], - "stage_id":1, - "virtual_stage_id":2 -} -)"; - OutputReshardRes tensor_reshard_info; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, tensor_reshard_info), SUCCESS); - EXPECT_EQ(tensor_reshard_info.comm_steps.size(), 2); - ASSERT_EQ(tensor_reshard_info.peer_inputs.size(), 2); - EXPECT_EQ(tensor_reshard_info.peer_inputs[0].step_id, 1); - EXPECT_EQ(tensor_reshard_info.peer_inputs[0].node_name, "dst_node"); - EXPECT_EQ(tensor_reshard_info.peer_inputs[0].input_index, 0); - EXPECT_EQ(tensor_reshard_info.peer_inputs[1].node_name, "dst_node_1"); - EXPECT_EQ(tensor_reshard_info.peer_inputs[1].step_id, 1); - EXPECT_EQ(tensor_reshard_info.peer_inputs[1].input_index, 1); - EXPECT_EQ(tensor_reshard_info.stage_id, 1); - EXPECT_EQ(tensor_reshard_info.virtual_stage_id, 2); -} - -TEST_F(TensorParallelAttrsTest, ReshardAttrToAndFromJson) { - const std::string &json_str = - R"( -[ - [ - { - "device_list": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - "comm_steps": [ - { - "id": 1, - "input_ids": [], - "comm_task": { - "task_type": "Split", - "num_split": 3, - "split_dim": 1 - } - }, - { - "id": 2, - "input_ids": [[1, 0]], - "comm_task": { - "task_type": "Split", - "num_split": 4, - "split_dim": 1 - } - } - ], - "peer_inputs": [ - {"step_id": 1, "node_name": "dst_node", "input_index": 0, "stage_id": 0, "virtual_stage_id": 0}, - {"step_id": 1, "node_name": "dst_node_1", "input_index": 1, "stage_id": 0, "virtual_stage_id": 0} - ] - } - ] -] -)"; - ReshardAttr reshard_attr; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, reshard_attr), SUCCESS); - ASSERT_EQ(reshard_attr.reshard_infos.size(), 1); - auto str = TensorParallelAttrs::ToJson(reshard_attr); - ASSERT_EQ(TensorParallelAttrs::FromJson(str, reshard_attr), SUCCESS); -} - -TEST_F(TensorParallelAttrsTest, TensorDeploymentToAndFromJson) { - const std::string &json_str = - R"( -{ - "shard_deployment": { - "device_indices_each_slice": [ - [{"engine_type": "NPU", "index": [0, 0, 0]}], - [{"engine_type": "NPU", "index": [0, 0, 1]}], - [{"engine_type": "NPU", "index": [0, 0, 2]}], - [{"engine_type": "NPU", "index": [0, 0, 3]}] - ], - "axis_slices": [ - [[0, 2], [2, 4]], - [[0, 4], [4, 8]] - ] - }, - "verbose" : "verbose_val" -} -)"; - TensorDeployment tensor_deployment; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, tensor_deployment), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(tensor_deployment); - - TensorDeployment tensor_deployment_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, tensor_deployment_from_json), SUCCESS); - const auto &tensor_slice_deployment = tensor_deployment_from_json.shard_deployment; - EXPECT_EQ(tensor_slice_deployment.device_indices_each_slice.size(), 4); - EXPECT_EQ(tensor_slice_deployment.axis_slices.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, TensorDeploymentsToAndFromJson) { - const std::string &json_str = - R"( -{ - "deployments": [ - [1, { - "shard_deployment": { - "axis_slices": [ - [[0, 2], [2, 4]], - [[0, 4], [4, 8]] - ], - "device_indices_each_slice": [ - [{"engine_type": "NPU", "index": [0, 0, 0]}], - [{"engine_type": "NPU", "index": [0, 0, 1]}], - [{"engine_type": "NPU", "index": [0, 0, 2]}], - [{"engine_type": "NPU", "index": [0, 0, 3]}] - ] - }, - "verbose": "verbose_val" - }], - [2, { - "shard_deployment": { - "axis_slices": [ - [[0, 2], [2, 4]], - [[0, 4], [4, 8]] - ], - "device_indices_each_slice": [ - [{"engine_type": "NPU", "index": [0, 1, 0]}], - [{"engine_type": "NPU", "index": [0, 1, 1]}], - [{"engine_type": "NPU", "index": [0, 1, 2]}], - [{"engine_type": "NPU", "index": [0, 1, 3]}] - ] - }, - "verbose": "verbose_val" - }] - ] -} -)"; - TensorDeployments tensor_deployments; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, tensor_deployments), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(tensor_deployments); - ASSERT_EQ(tensor_deployments.deployments.size(), 2); - - TensorDeployments tensor_deployments_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, tensor_deployments_from_json), SUCCESS); - const auto &tensor_slice_deployment = tensor_deployments_from_json.deployments[1].shard_deployment; - EXPECT_EQ(tensor_slice_deployment.device_indices_each_slice.size(), 4); - EXPECT_EQ(tensor_slice_deployment.axis_slices.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, NodeDeploymentsToAndFromJson) { - const std::string &json_str = - R"( -{ - "deployments": [ - [1, { - "devices": [{ - "engine_type": "CPU", - "index": [0, 0, 1] - }, { - "engine_type": "NPU", - "index": [0, 0, 3] - }], - "pipeline_config": { - "micro_batch": 1, - "stage_id": 0, - "virtual_stage_id": [] - } - }], - [2, { - "devices": [{ - "engine_type": "", - "index": [] - }], - "pipeline_config": { - "micro_batch": 1, - "stage_id": 0, - "virtual_stage_id": [] - } - }] - ] -} -)"; - NodeDeployments node_deployments; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, node_deployments), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(node_deployments); - ASSERT_EQ(node_deployments.deployments.size(), 2); - - NodeDeployments node_deployments_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, node_deployments_from_json), SUCCESS); - const auto &devices = node_deployments_from_json.deployments[1].devices; - EXPECT_EQ(devices.size(), 2); - EXPECT_TRUE(devices[0].engine_type == "CPU"); - EXPECT_TRUE(devices[1].engine_type == "NPU"); -} - -TEST_F(TensorParallelAttrsTest, ShardGraphExtAttrsToAndFromJson) { - const std::string &json_str = - R"( -{ - "dev_index_to_logic_dev_id": [ - [{ - "engine_type": "NPU", - "index": [0, 0, 0] - }, - [1, 0, 0] - ], - [{ - "engine_type": "NPU", - "index": [0, 0, 1] - }, - [1, 0, 1] - ] - ], - "graph_name_to_endpoints": { - "test_graph1": { - "endpoint1": ["SerializedString1"], - "endpoint2": ["SerializedString2"] - }, - "test_graph2": { - "endpoint1": ["SerializedString3"], - "endpoint2": ["SerializedString4"] - } - }, - "group_name_to_dev_ids": { - "group1": ["0:0:0:0", "0:0:1:0"], - "group2": ["0:0:0:1", "0:0:1:1"] - } -} -)"; - ShardGraphExtAttrs shard_graph_ext_attrs; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, shard_graph_ext_attrs), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(shard_graph_ext_attrs); - ShardGraphExtAttrs shard_graph_ext_attrs_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, shard_graph_ext_attrs_from_json), SUCCESS); - EXPECT_EQ(shard_graph_ext_attrs_from_json.graph_name_to_endpoints, shard_graph_ext_attrs.graph_name_to_endpoints); - EXPECT_EQ(shard_graph_ext_attrs_from_json.dev_index_to_logic_dev_id, - shard_graph_ext_attrs.dev_index_to_logic_dev_id); - EXPECT_EQ(shard_graph_ext_attrs_from_json.group_name_to_dev_ids, shard_graph_ext_attrs.group_name_to_dev_ids); -} - -TEST_F(TensorParallelAttrsTest, StructCmp) { - SrcNodeInfo src_node_info; - src_node_info.inserted_node_id = 0; - src_node_info.output_index = 0; - SrcNodeInfo src_node_info1; - src_node_info1.inserted_node_id = 0; - src_node_info1.output_index = 0; - EXPECT_EQ(src_node_info == src_node_info1, true); - SrcNodeInfo src_node_info2; - src_node_info2.inserted_node_id = 0; - src_node_info2.output_index = 1; - EXPECT_EQ(src_node_info < src_node_info2, true); - SrcNodeInfo src_node_info3; - src_node_info3.inserted_node_id = 1; - src_node_info3.output_index = 1; - EXPECT_EQ(src_node_info < src_node_info3, true); - EXPECT_EQ(src_node_info3 < src_node_info1, false); - - OrigNodeInfo orig_node_info; - orig_node_info.node_name = "node"; - orig_node_info.sliced_id = 0; - DstNodeInfo dst_node_info; - dst_node_info.orig_node_info = orig_node_info; - dst_node_info.input_indexes = {0}; - InsertedNodeInput inserted_node_input; - inserted_node_input.orig_node_info = orig_node_info; - inserted_node_input.input_info = src_node_info; - PeerOutNodeInfo peer_out_node_info; - peer_out_node_info.input_info = src_node_info; - peer_out_node_info.node_info = dst_node_info; - - OrigNodeInfo orig_node_info1; - orig_node_info1.node_name = "node"; - orig_node_info1.sliced_id = 0; - DstNodeInfo dst_node_info1; - dst_node_info1.orig_node_info = orig_node_info1; - dst_node_info1.input_indexes = {0}; - InsertedNodeInput inserted_node_input1; - inserted_node_input1.orig_node_info = orig_node_info1; - inserted_node_input1.input_info = src_node_info1; - PeerOutNodeInfo peer_out_node_info1; - peer_out_node_info1.input_info = src_node_info1; - peer_out_node_info1.node_info = dst_node_info1; - - EXPECT_EQ(orig_node_info1 == orig_node_info, true); - EXPECT_EQ(dst_node_info == dst_node_info1, true); - EXPECT_EQ(inserted_node_input == inserted_node_input1, true); - EXPECT_EQ(peer_out_node_info == peer_out_node_info1, true); - - OrigNodeInfo orig_node_info2; - orig_node_info2.node_name = "node1"; - orig_node_info2.sliced_id = 0; - DstNodeInfo dst_node_info2; - dst_node_info2.orig_node_info = orig_node_info2; - dst_node_info2.input_indexes = {0}; - InsertedNodeInput inserted_node_input2; - inserted_node_input2.orig_node_info = orig_node_info2; - inserted_node_input2.input_info = src_node_info2; - PeerOutNodeInfo peer_out_node_info2; - peer_out_node_info2.input_info = src_node_info2; - peer_out_node_info2.node_info = dst_node_info2; - - EXPECT_EQ(orig_node_info < orig_node_info2, true); - EXPECT_EQ(orig_node_info2 < orig_node_info, false); - EXPECT_EQ(dst_node_info < dst_node_info2, true); - EXPECT_EQ(dst_node_info2 < dst_node_info, false); - EXPECT_EQ(inserted_node_input < inserted_node_input2, true); - EXPECT_EQ(inserted_node_input2 < inserted_node_input, false); - EXPECT_EQ(peer_out_node_info < peer_out_node_info2, true); - EXPECT_EQ(peer_out_node_info2 < peer_out_node_info, false); - - OrigNodeInfo orig_node_info3; - orig_node_info3.node_name = "node"; - orig_node_info3.sliced_id = 1; - DstNodeInfo dst_node_info3; - dst_node_info3.orig_node_info = orig_node_info3; - dst_node_info3.input_indexes = {0, 1}; - InsertedNodeInput inserted_node_input3; - inserted_node_input3.orig_node_info = orig_node_info3; - inserted_node_input3.input_info = src_node_info3; - PeerOutNodeInfo peer_out_node_info3; - peer_out_node_info3.input_info = src_node_info3; - peer_out_node_info3.node_info = dst_node_info3; - - EXPECT_EQ(orig_node_info < orig_node_info3, true); - EXPECT_EQ(dst_node_info < dst_node_info3, true); - EXPECT_EQ(inserted_node_input < inserted_node_input3, true); - EXPECT_EQ(peer_out_node_info < peer_out_node_info3, true); -} -} // namespace tp -} // namespace ge -- Gitee