diff --git a/graph/ascendc_ir/generator/ascir_register.cc b/graph/ascendc_ir/generator/ascir_register.cc index 0cb3f07a07935ca0cae7ca59eab06f62556ee8d9..d7b7f53e01ac47b2b4113605362616ec5dc22d51 100644 --- a/graph/ascendc_ir/generator/ascir_register.cc +++ b/graph/ascendc_ir/generator/ascir_register.cc @@ -11,55 +11,51 @@ #include "graph/types.h" namespace ge { namespace ascir { -AscirRegister::AscirRegister(const char *type, const char *def_file_path, int64_t line) - : ir_def_{} { - ir_def_.type = type; - ir_def_.file_path = def_file_path; - ir_def_.line = line; - ir_def_.start_node = false; +AscirRegister::AscirRegister(const char *type, const char *def_file_path, int64_t line) : ir_def_{} { + ir_def_.Init(type, def_file_path, line); } AscirRegister &AscirRegister::Inputs(std::vector &&input_names) { - for (const auto &input_name: input_names) { - ir_def_.input_defs.emplace_back(input_name.GetString(), ge::IrInputType::kIrInputRequired); + for (const auto &input_name : input_names) { + ir_def_.AppendInput(input_name.GetString(), ge::IrInputType::kIrInputRequired); } return *this; } AscirRegister &AscirRegister::DynamicInput(const std::string &input_name) { - ir_def_.input_defs.emplace_back(input_name, ge::IrInputType::kIrInputDynamic); + ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputDynamic); return *this; } AscirRegister &AscirRegister::OptionalInput(const std::string &input_name) { - ir_def_.input_defs.emplace_back(input_name, ge::IrInputType::kIrInputOptional); + ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputOptional); return *this; } AscirRegister &AscirRegister::Outputs(std::vector &&output_names) { - for (const auto &output_name: output_names) { - ir_def_.output_defs.emplace_back(output_name.GetString(), ge::IrOutputType::kIrOutputRequired); + for (const auto &output_name : output_names) { + ir_def_.AppendOutput(output_name.GetString(), ge::IrOutputType::kIrOutputRequired); } return *this; } AscirRegister &AscirRegister::DynamicOutput(const std::string &output_name) { - ir_def_.output_defs.emplace_back(output_name, ge::IrOutputType::kIrOutputDynamic); + ir_def_.AppendOutput(output_name, ge::IrOutputType::kIrOutputDynamic); return *this; } AscirRegister::AscirRegister(const AscirRegister &other) { - AscirRegistry::GetInstance().RegisterAscIr(other.ir_def_.type, other.ir_def_); + AscirRegistry::GetInstance().RegisterAscIr(other.ir_def_.GetType(), other.ir_def_); } AscirRegister &AscirRegister::Attr(std::string name, std::string asc_type, std::string ge_type) { if (ir_def_.IsAttrExisted(name)) { return *this; } - ir_def_.attr_defs.emplace_back(AscIrAttrDef{std::move(name), std::move(asc_type), std::move(ge_type)}); + ir_def_.SetAttr(name, asc_type, ge_type); return *this; } AscirRegister &AscirRegister::StartNode() { - ir_def_.start_node = true; + ir_def_.StartNode(); return *this; } AscirRegister &AscirRegister::InferDataType(AscIrDef::CodeGenerator infer_data_type_generator) { @@ -72,65 +68,62 @@ AscirRegister &AscirRegister::InferView(AscIrDef::CodeGenerator infer_view_gener } AscirRegister &AscirRegister::Views(const std::vector &views_policy) { - ir_def_.output_views_policy = views_policy; + ir_def_.SetViewPolicy(views_policy); return InferView(InferViewByPolicy); } AscirRegister &AscirRegister::DataTypes(const std::vector &data_types_policy) { - ir_def_.output_dtypes_policy = data_types_policy; + ir_def_.SetDtypePolicy(data_types_policy); return InferDataType(InferDtypeByPolicy); } AscirRegister &AscirRegister::Input(const char_t *input_name, const char_t *datatype_symbol) { - ir_def_.input_defs.emplace_back(input_name, ge::IrInputType::kIrInputRequired); - ir_def_.dtype_symbol_store.SetInputSymbol(input_name, ge::kIrInputRequired, datatype_symbol); + ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputRequired); + ir_def_.MutableDataTypeSymbolStore().SetInputSymbol(input_name, ge::kIrInputRequired, datatype_symbol); return *this; } AscirRegister &AscirRegister::Output(const char_t *output_name, const char_t *datatype_symbol) { - ir_def_.output_defs.emplace_back(output_name, ge::IrOutputType::kIrOutputRequired); - ir_def_.dtype_symbol_store.SetOutputSymbol(output_name, ge::kIrOutputRequired, datatype_symbol); + ir_def_.AppendOutput(output_name, ge::IrOutputType::kIrOutputRequired); + ir_def_.MutableDataTypeSymbolStore().SetOutputSymbol(output_name, ge::kIrOutputRequired, datatype_symbol); return *this; } AscirRegister &AscirRegister::DataType(const char_t *datatype_symbol, const TensorType &type_range) { - ir_def_.dtype_symbol_store.DeclareSymbol(datatype_symbol, type_range); + ir_def_.MutableDataTypeSymbolStore().DeclareSymbol(datatype_symbol, type_range); return *this; } AscirRegister &AscirRegister::DynamicInput(const char_t *input_name, const char_t *datatype_symbol) { - ir_def_.input_defs.emplace_back(input_name, ge::IrInputType::kIrInputDynamic); - ir_def_.dtype_symbol_store.SetInputSymbol(input_name, ge::kIrInputDynamic, datatype_symbol); + ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputDynamic); + ir_def_.MutableDataTypeSymbolStore().SetInputSymbol(input_name, ge::kIrInputDynamic, datatype_symbol); return *this; } AscirRegister &AscirRegister::DataType(const char_t *datatype_symbol, const OrderedTensorTypeList &type_range) { - ir_def_.dtype_symbol_store.DeclareSymbol(datatype_symbol, type_range); + ir_def_.MutableDataTypeSymbolStore().DeclareSymbol(datatype_symbol, type_range); return *this; } -AscirRegister &AscirRegister::CalcTmpBufSize(const std::string calc_tmp_buf_size_func) { - if (!ir_def_.calc_tmp_buf_size_func.func_name.empty()) { - GELOGE(ge::FAILED, "has registered calc_tmp_buf_size_func: %s", ir_def_.calc_tmp_buf_size_func.func_name.c_str()); - return *this; - } - ir_def_.calc_tmp_buf_size_func = CalcTmpBufSizeFunc{std::move(calc_tmp_buf_size_func), CalcTmpBufSizeFuncType::CustomizeType}; +AscirRegister &AscirRegister::CalcTmpBufSize(const std::string &calc_tmp_buf_size_func) { + ir_def_.SetCalcTmpBufSizeFunc(calc_tmp_buf_size_func, CalcTmpBufSizeFuncType::CustomizeType); return *this; } AscirRegister &AscirRegister::SameTmpBufSizeFromFirstInput() { - if (!ir_def_.calc_tmp_buf_size_func.func_name.empty()) { - GELOGE(ge::FAILED, "has registered calc_tmp_buf_size_func: %s", ir_def_.calc_tmp_buf_size_func.func_name.c_str()); - return *this; - } - ir_def_.calc_tmp_buf_size_func = CalcTmpBufSizeFunc{"SameTmpBufSizeWithFirstInput", CalcTmpBufSizeFuncType::CommonType}; + ir_def_.SetCalcTmpBufSizeFunc("SameTmpBufSizeWithFirstInput", CalcTmpBufSizeFuncType::CommonType); return *this; } AscirRegister &AscirRegister::ApiTilingDataType(const std::string &tiling_data_name) { - if (!ir_def_.tiling_data_name.empty()) { - GELOGE(ge::FAILED, "%s has registered tiling data: %s", ir_def_.type.c_str(), ir_def_.tiling_data_name.c_str()); - return *this; - } - ir_def_.tiling_data_name = std::move(tiling_data_name); + ir_def_.SetApiTilingDataName(tiling_data_name); return *this; } +AscirRegister &AscirRegister::Impl(const std::vector &soc_version, const AscIrDef::AscIrImpl &impl) { + ir_def_.AddAscIrImpl(soc_version, impl); + return *this; +} + +size_t AscirRegister::GetSocImplSize() const { + return ir_def_.GetSocImplSize(); +} + template<> AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { return Attr(name.GetString(), "float", "Float"); @@ -166,8 +159,8 @@ AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { } AscirRegister &AscirRegister::Comment(const string &comment) { - ir_def_.comment = comment; + ir_def_.SetComment(comment); return *this; } } // namespace ascir -} +} // namespace ge diff --git a/graph/ascendc_ir/generator/ascir_registry.cc b/graph/ascendc_ir/generator/ascir_registry.cc index a13ccfb8d5fb454df497a532c3ca9de1e62f73ab..ecfe6c6dfef98067ab0959a8a80a5c2d591611db 100644 --- a/graph/ascendc_ir/generator/ascir_registry.cc +++ b/graph/ascendc_ir/generator/ascir_registry.cc @@ -10,6 +10,153 @@ #include "graph/ascendc_ir/ascir_registry.h" namespace ge { namespace ascir { + +struct AscIrDefImpl { + std::string file_path; + int64_t line{}; + std::string type; + std::vector> input_defs; + std::vector> output_defs; + std::vector attr_defs; + + std::vector output_views_policy; + std::vector output_dtypes_policy; + + bool start_node{false}; + IRDataTypeSymbolStore dtype_symbol_store; + std::string comment; + CalcTmpBufSizeFunc calc_tmp_buf_size_func; + std::string tiling_data_name; + std::map soc_2_impl_; +}; + +AscIrDef::AscIrDef() { + impl_ = std::make_shared(); +} + +bool AscIrDef::IsAttrExisted(const std::string &attr_name) const { + return std::find_if(impl_->attr_defs.begin(), impl_->attr_defs.end(), + [&attr_name](const AscIrAttrDef &asc_ir_attr_def) { + return asc_ir_attr_def.name == attr_name; + }) != impl_->attr_defs.end(); +} + +void AscIrDef::Init(const char *type, const char *def_file_path, int64_t line) const { + impl_->type = type; + impl_->file_path = def_file_path; + impl_->line = line; + impl_->start_node = false; +} + +const std::vector> &AscIrDef::GetInputDefs() const { + return impl_->input_defs; +} +const std::vector> &AscIrDef::GetOutputDefs() const { + return impl_->output_defs; +} + +void AscIrDef::AppendInput(const std::string &name, ge::IrInputType type) const { + impl_->input_defs.emplace_back(name, type); +} +void AscIrDef::AppendOutput(const std::string &name, ge::IrOutputType type) const { + impl_->output_defs.emplace_back(name, type); +} +const std::string &AscIrDef::GetType() const { + return impl_->type; +} +void AscIrDef::StartNode() const { + impl_->start_node = true; +} + +bool AscIrDef::IsStartNode() const { + return impl_->start_node; +} + +void AscIrDef::SetAttr(const std::string &name, const std::string &asc_type, const std::string &ge_type) const { + impl_->attr_defs.emplace_back(AscIrAttrDef{name, asc_type, ge_type}); +} + +void AscIrDef::SetDtypePolicy(const std::vector &output_dtypes_policy) const { + impl_->output_dtypes_policy = output_dtypes_policy; +} + +const std::vector &AscIrDef::GetOutputDtypePolicy() const { + return impl_->output_dtypes_policy; +} + +void AscIrDef::SetViewPolicy(const std::vector &view_policy) const { + impl_->output_views_policy = view_policy; +} + +const std::vector &AscIrDef::GetViewPolicy() const { + return impl_->output_views_policy; +} + +void AscIrDef::SetApiTilingDataName(const std::string &tiling_data_name) const { + if (!impl_->tiling_data_name.empty()) { + GELOGE(ge::FAILED, "%s has registered tiling data: %s", impl_->type.c_str(), impl_->tiling_data_name.c_str()); + return; + } + impl_->tiling_data_name = tiling_data_name; +} + +const string &AscIrDef::GetApiTilingDataName() const { + return impl_->tiling_data_name; +} + +void AscIrDef::SetCalcTmpBufSizeFunc(const std::string &calc_tmp_buf_size_func, CalcTmpBufSizeFuncType type) const { + if (!impl_->calc_tmp_buf_size_func.func_name.empty()) { + GELOGE(ge::FAILED, "has registered calc_tmp_buf_size_func: %s", impl_->calc_tmp_buf_size_func.func_name.c_str()); + return; + } + impl_->calc_tmp_buf_size_func = CalcTmpBufSizeFunc{calc_tmp_buf_size_func, type}; +} + +const CalcTmpBufSizeFunc &AscIrDef::GetCalcTmpBufSizeFunc() const { + return impl_->calc_tmp_buf_size_func; +} + +const std::vector &AscIrDef::GetAttrDefs() const { + return impl_->attr_defs; +} + +std::vector &AscIrDef::MutableAttrDefs() const { + return impl_->attr_defs; +} + +void AscIrDef::SetComment(const string &comment) const { + impl_->comment = comment; +} + +const std::string &AscIrDef::GetComment() const { + return impl_->comment; +} +const std::string &AscIrDef::GetFilePath() const { + return impl_->file_path; +} + +int64_t AscIrDef::GetLine() const { + return impl_->line; +} + +IRDataTypeSymbolStore &AscIrDef::MutableDataTypeSymbolStore() const { + return impl_->dtype_symbol_store; +} + +const IRDataTypeSymbolStore &AscIrDef::GetDataTypeSymbolStore() const { + return impl_->dtype_symbol_store; +} + +void AscIrDef::AddAscIrImpl(const std::vector &soc_versions, const AscIrImpl &impl) const { + for (auto &soc : soc_versions) { + impl_->soc_2_impl_[soc] = impl; + } +} + +size_t AscIrDef::GetSocImplSize() const { + return impl_->soc_2_impl_.size(); +} + AscirRegistry &AscirRegistry::GetInstance() { static AscirRegistry registry; return registry; @@ -21,4 +168,4 @@ const std::unordered_map &AscirRegistry::GetAll() const { return types_to_ascir_; } } // namespace ascir -} +} // namespace ge diff --git a/graph/ascendc_ir/generator/generator.cc b/graph/ascendc_ir/generator/generator.cc index 9cb16741022845acaf4a46323fd05c5711fc961d..c322c8c8c64a78d2e34dda45d63a35fad7b51b1b 100644 --- a/graph/ascendc_ir/generator/generator.cc +++ b/graph/ascendc_ir/generator/generator.cc @@ -42,8 +42,7 @@ std::string CapitalizeFirstLetter(const std::string &input) { return result; } -void GenIrAttrMemberFuncs(const std::vector &attr_defs, - std::stringstream &ss) { +void GenIrAttrMemberFuncs(const std::vector &attr_defs, std::stringstream &ss) { if (attr_defs.empty()) { return; } @@ -55,8 +54,8 @@ void GenIrAttrMemberFuncs(const std::vector &attr_defs, ss << " GE_WARN_ASSERT(attr_value != nullptr);" << std::endl; ss << " return attr_value->GetValue(" << attr_def.name << ");" << std::endl << " }" << std::endl; - ss << " graphStatus Set" << CapitalizeFirstLetter(attr_def.name) << "(" << attr_def.asc_ir_type - << " " << attr_def.name << ") {" << std::endl; + ss << " graphStatus Set" << CapitalizeFirstLetter(attr_def.name) << "(" << attr_def.asc_ir_type << " " + << attr_def.name << ") {" << std::endl; ss << " auto attr_value = attr_store_.GetOrCreateAnyValue(\"" << attr_def.name << "\");" << std::endl; ss << " GE_ASSERT_NOTNULL(attr_value);" << std::endl; ss << " return attr_value->SetValue(" << attr_def.name << ");" << std::endl << " }" << std::endl; @@ -64,13 +63,14 @@ void GenIrAttrMemberFuncs(const std::vector &attr_defs, } std::string TryGenIrAttrClass(const AscIrDef &def, std::stringstream &ss) { - const auto &attr_defs = def.attr_defs; + const auto &attr_defs = def.GetAttrDefs(); if (attr_defs.empty()) { return (""); } - std::string derived_class_name = std::string("Asc").append(def.type).append("IrAttrDef"); + const std::string &ir_type = def.GetType(); + std::string derived_class_name = std::string("Asc").append(ir_type).append("IrAttrDef"); // 暂时没啥好的办法,data的类需要先定义好,gen出来的话有点晚了 - if (def.type == ge::DATA) { + if (ir_type == ge::DATA) { ss << " using " << derived_class_name << " = ge::" << derived_class_name << ";" << std::endl; ss << " " << derived_class_name << " &ir_attr;" << std::endl; return derived_class_name; @@ -85,7 +85,7 @@ std::string TryGenIrAttrClass(const AscIrDef &def, std::stringstream &ss) { return derived_class_name; }; void GenIrInputAndOutputDef(const AscIrDef &def, std::stringstream &ss) { - const auto &input_defs = def.input_defs; + const auto &input_defs = def.GetInputDefs(); for (const auto &input_def : input_defs) { if (input_def.second == ge::IrInputType::kIrInputDynamic) { ss << " this->DynamicInputRegister(\"" << input_def.first << "\", 0U, true);" << std::endl; @@ -96,7 +96,7 @@ void GenIrInputAndOutputDef(const AscIrDef &def, std::stringstream &ss) { } } - const auto &output_defs = def.output_defs; + const auto &output_defs = def.GetOutputDefs(); for (const auto &output_def : output_defs) { if (output_def.second == ge::IrOutputType::kIrOutputDynamic) { ss << " this->DynamicOutputRegister(\"" << output_def.first << "\", 0U, true);" << std::endl; @@ -107,16 +107,18 @@ void GenIrInputAndOutputDef(const AscIrDef &def, std::stringstream &ss) { } // 初始化列表中初始化的成员依赖构造函数中的ir信息确定之后再次进行赋值 void GenOutTensorInitDef(const AscIrDef &def, std::stringstream &ss) { - for (const auto &output_def : def.output_defs) { + const auto &output_defs = def.GetOutputDefs(); + for (const auto &output_def : output_defs) { ss << " " << output_def.first << ".TryInitTensorAttr();\n"; } } void GenConstructorDef(const AscIrDef &def, const std::string &attr_class, std::stringstream &ss, const bool need_graph = false) { + const std::string &ir_type = def.GetType(); if (need_graph) { - ss << " inline " << def.type << "(const char *name, AscGraph &graph) : ge::Operator"; + ss << " inline " << ir_type << "(const char *name, AscGraph &graph) : ge::Operator"; } else { - ss << " inline " << def.type << "(const char *name) : ge::Operator"; + ss << " inline " << ir_type << "(const char *name) : ge::Operator"; } if (attr_class.empty()) { ss << "(name, Type), attr(" << "*AscNodeAttr::Create(*this))"; @@ -124,11 +126,13 @@ void GenConstructorDef(const AscIrDef &def, const std::string &attr_class, std:: ss << "(name, Type), attr(" << "*AscNodeAttr::Create<" << attr_class << ">(*this))"; ss << ", ir_attr(dynamic_cast<" << attr_class << "&>(*(attr.ir_attr)))"; } - for (const auto &input_def : def.input_defs) { + const auto &input_defs = def.GetInputDefs(); + for (const auto &input_def : input_defs) { ss << "," << std::endl << " " << input_def.first << "(this)"; } - for (size_t i = 0UL; i < def.output_defs.size(); ++i) { - ss << "," << std::endl << " " << def.output_defs[i].first << "(this, " << i << ")"; + const auto &output_defs = def.GetOutputDefs(); + for (size_t i = 0UL; i < output_defs.size(); ++i) { + ss << "," << std::endl << " " << output_defs[i].first << "(this, " << i << ")"; } ss << " {" << std::endl; GenIrInputAndOutputDef(def, ss); @@ -142,7 +146,7 @@ void GenConstructorDef(const AscIrDef &def, const std::string &attr_class, std:: std::string DataTypeToSerialString(const DataType data_type) { auto res = TypeUtils::DataTypeToSerialString(data_type); - if (res == "DT_BFLOAT16") { // 历史原因,DT_BF16的string表达是DT_BFLOAT16,所以我们需要特殊处理一下 + if (res == "DT_BFLOAT16") { // 历史原因,DT_BF16的string表达是DT_BFLOAT16,所以我们需要特殊处理一下 return "DT_BF16"; } return res; @@ -164,8 +168,7 @@ std::string TensorTypeToCode(const TensorType &tensor_type) { class SymbolProcessor { public: explicit SymbolProcessor(const AscIrDef &def) : def_(def) {} - Status ProcessSymbol(const std::pair &sym, - std::stringstream &ss) { + Status ProcessSymbol(const std::pair &sym, std::stringstream &ss) { // 外部保证非空,不允许注册一个不带dtype的sym GE_ASSERT_TRUE(!(sym.second->GetTensorType().tensor_type_impl_->GetMutableDateTypeSet().empty())); // 只有输出持有的sym不在这里处理 @@ -178,8 +181,7 @@ class SymbolProcessor { return SUCCESS; } - Status ProcessSymbolWithNoCheck(const std::pair &sym, - std::stringstream &ss) { + Status ProcessSymbolWithNoCheck(const std::pair &sym, std::stringstream &ss) { // 外部保证非空,不允许注册一个不带dtype的sym GE_ASSERT_TRUE(!(sym.second->GetTensorType().tensor_type_impl_->GetMutableDateTypeSet().empty())); // 只有输出持有的sym不在这里处理 @@ -189,6 +191,7 @@ class SymbolProcessor { GenerateInputDtypeUniquenessCheck(sym.second, ss); return SUCCESS; } + protected: static void GenerateTypeDefinition(const std::pair &sym, std::stringstream &ss) { const std::string tensor_type_obj = "support_dtypes_of_sym_" + sym.first; @@ -199,9 +202,8 @@ class SymbolProcessor { // 共sym的已经都校验过一致了,所以这里可以用第一个来校验是否在支持范围内 auto check_index = sym.second->GetIrInputIndexes().front(); ss << " GE_WARN_ASSERT(" - << "support_dtypes_of_sym_" << sym.first - << ".find(input_dtypes[" - << check_index << "]) != support_dtypes_of_sym_" << sym.first << ".end());" << std::endl; + << "support_dtypes_of_sym_" << sym.first << ".find(input_dtypes[" << check_index + << "]) != support_dtypes_of_sym_" << sym.first << ".end());" << std::endl; } void GenerateInputDtypeUniquenessCheck(const SymDtype *sym, std::stringstream &ss) { @@ -227,7 +229,7 @@ class OutputHandler { bool could_infer = true; std::stringstream warning_code; size_t out_index = 0U; - for (const auto out_sym : def_.dtype_symbol_store.GetOutSymbols()) { + for (const auto out_sym : def_.GetDataTypeSymbolStore().GetOutSymbols()) { sym_2_ir_indexs_[out_sym].push_back(out_index); if (!(out_sym->GetIrInputIndexes().empty())) { // 有输入对应,一定是可以推导的 @@ -248,7 +250,7 @@ class OutputHandler { ss << std::string(space_count, ' ') << "return FAILED;\n"; return; } - for (const auto out_sym : def_.dtype_symbol_store.GetOutSymbols()) { + for (const auto out_sym : def_.GetDataTypeSymbolStore().GetOutSymbols()) { if (syms_only_of_output_.find(out_sym) == syms_only_of_output_.end()) { // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype作为输出的dtype" ss << std::string(space_count, ' ') << "expect_output_dtypes.push_back(input_dtypes[" @@ -262,11 +264,10 @@ class OutputHandler { void GenerateOutputValidation(std::stringstream &ss) { size_t out_index = 0; - for (const auto out_sym : def_.dtype_symbol_store.GetOutSymbols()) { + for (const auto out_sym : def_.GetDataTypeSymbolStore().GetOutSymbols()) { if (syms_only_of_output_.find(out_sym) == syms_only_of_output_.end()) { // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype跟out的dtype比对 - ss << " GE_WARN_ASSERT(input_dtypes[" - << out_sym->GetIrInputIndexes().front() << "] == " + ss << " GE_WARN_ASSERT(input_dtypes[" << out_sym->GetIrInputIndexes().front() << "] == " << "expect_output_dtypes[" << out_index << "]);" << std::endl; } else { GenerateCustomTypeValidation(out_sym, ss); @@ -274,16 +275,16 @@ class OutputHandler { ++out_index; } } + protected: static void GenerateCustomTypeInference(const SymDtype *sym, std::stringstream &ss, int space_count) { // 走到这里说明out使用了跟所有input不一样的sym, 并且此输出只有一个支持的类型 auto support_types = sym->GetTensorType().tensor_type_impl_->GetMutableDateTypeSet(); - ss << std::string(space_count, ' ') << "expect_output_dtypes.push_back(" << DataTypeToSerialString(*support_types.begin()) - << ");\n"; + ss << std::string(space_count, ' ') << "expect_output_dtypes.push_back(" + << DataTypeToSerialString(*support_types.begin()) << ");\n"; } - Status GenerateCustomTypeValidation(SymDtype *sym, - std::stringstream &ss) { + Status GenerateCustomTypeValidation(SymDtype *sym, std::stringstream &ss) { if (!syms_checked_.insert(sym).second) { return SUCCESS; } @@ -298,30 +299,29 @@ class OutputHandler { } } const std::string tensor_type_obj = "support_dtypes_of_sym_" + sym->Id(); - ss << " static std::set " << tensor_type_obj << " = " - << TensorTypeToCode(sym->GetTensorType()) << ";\n"; + ss << " static std::set " << tensor_type_obj << " = " << TensorTypeToCode(sym->GetTensorType()) + << ";\n"; // 共sym的已经都校验过一致了,所以这里可以用第一个来校验是否在支持范围内 auto check_index = indexes_of_this_sym.front(); ss << " GE_WARN_ASSERT(" - << "support_dtypes_of_sym_" << sym->Id() - << ".find(expect_output_dtypes[" - << check_index << "]) != support_dtypes_of_sym_" << sym->Id() << ".end());" << std::endl; + << "support_dtypes_of_sym_" << sym->Id() << ".find(expect_output_dtypes[" << check_index + << "]) != support_dtypes_of_sym_" << sym->Id() << ".end());" << std::endl; return SUCCESS; } + private: const AscIrDef &def_; std::set syms_only_of_output_{}; std::map> sym_2_ir_indexs_{}; - std::set syms_checked_{}; // 多个输出可能sym一样 + std::set syms_checked_{}; // 多个输出可能sym一样 }; class OrderedSymbolProcessor : public SymbolProcessor { public: - explicit OrderedSymbolProcessor(const AscIrDef &def) - : SymbolProcessor(def), valid_dtype_nums_of_sym_(0U) {} + explicit OrderedSymbolProcessor(const AscIrDef &def) : SymbolProcessor(def), valid_dtype_nums_of_sym_(0U) {} Status PreProcessSymbol(std::stringstream &ss) { - const auto &symbols = def_.dtype_symbol_store.GetSymbols(); + const auto &symbols = def_.GetDataTypeSymbolStore().GetSymbols(); GE_ASSERT_TRUE(!symbols.empty()); GE_ASSERT_SUCCESS(InitializeSymbolAttributes(symbols)); GE_ASSERT_SUCCESS(ClassifySymbols(symbols, ss)); @@ -339,7 +339,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { if (input_syms_.size() > 1U) { ss << " auto iter = results.find(std::vector{"; size_t index{0U}; - for (auto input_sym: input_syms_) { + for (auto input_sym : input_syms_) { ss << "input_dtypes[" << input_sym->GetIrInputIndexes().front() << "]"; if (index++ < input_syms_.size() - 1U) { ss << ", "; @@ -365,11 +365,11 @@ class OrderedSymbolProcessor : public SymbolProcessor { private: void GenerateOutputInference(std::stringstream &ss) { size_t only_output_index{0U}; - for (const auto out_sym : def_.dtype_symbol_store.GetOutSymbols()) { + for (const auto out_sym : def_.GetDataTypeSymbolStore().GetOutSymbols()) { if (std::find(only_out_syms_.begin(), only_out_syms_.end(), out_sym) == only_out_syms_.end()) { // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype作为输出的dtype" - ss << " expect_output_dtypes.push_back(input_dtypes[" - << out_sym->GetIrInputIndexes().front() << "]);" << std::endl; + ss << " expect_output_dtypes.push_back(input_dtypes[" << out_sym->GetIrInputIndexes().front() << "]);" + << std::endl; } else { // 走到这里说明out使用了跟所有input不一样的sym, 我们使用输入的type推导输出 if (container_meta_.output_count == 1U) { @@ -389,8 +389,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { << std::endl; // std::vector> } else { - ss << " expect_output_dtypes.push_back(iter->second[" << only_output_index << "]));" - << std::endl; + ss << " expect_output_dtypes.push_back(iter->second[" << only_output_index << "]));" << std::endl; } only_output_index++; } @@ -402,11 +401,10 @@ class OrderedSymbolProcessor : public SymbolProcessor { void GenerateOutputValidation(std::stringstream &ss) { size_t only_output_index{0U}; size_t output_index{0U}; - for (const auto out_sym : def_.dtype_symbol_store.GetOutSymbols()) { + for (const auto out_sym : def_.GetDataTypeSymbolStore().GetOutSymbols()) { if (std::find(only_out_syms_.begin(), only_out_syms_.end(), out_sym) == only_out_syms_.end()) { // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype跟输出的dtype做校验" - ss << " GE_WARN_ASSERT(input_dtypes[" - << out_sym->GetIrInputIndexes().front() << "] == " + ss << " GE_WARN_ASSERT(input_dtypes[" << out_sym->GetIrInputIndexes().front() << "] == " << "expect_output_dtypes[" << output_index << "]);" << std::endl; } else { // 走到这里说明out使用了跟所有input不一样的sym, 我们使用输入的type推导输出 @@ -424,8 +422,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { // std::vector> if (container_meta_.has_multiple_solutions) { ss << " GE_WARN_ASSERT(iter->second[" << only_output_index << "].find(expect_output_dtypes[" - << output_index - << "]) != iter->second[" << only_output_index << "].end());" << std::endl; + << output_index << "]) != iter->second[" << only_output_index << "].end());" << std::endl; // std::vector> } else { ss << " GE_WARN_ASSERT(iter->second[" << only_output_index << "] == " @@ -445,8 +442,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { return SUCCESS; } - Status ClassifySymbols(const std::list> &symbols, - std::stringstream &ss) { + Status ClassifySymbols(const std::list> &symbols, std::stringstream &ss) { for (const auto &sym : symbols) { GE_ASSERT_NOTNULL(sym); GE_ASSERT_TRUE(sym->IsOrderedList()); @@ -516,17 +512,14 @@ class OrderedSymbolProcessor : public SymbolProcessor { const auto solution_map = BuildSolutionMap(); const bool has_multiple = CheckMultipleSolutions(solution_map); - container_meta_ = { - .input_count = input_syms_.size(), - .output_count = only_out_syms_.size(), - .has_multiple_solutions = has_multiple - }; + container_meta_ = {.input_count = input_syms_.size(), + .output_count = only_out_syms_.size(), + .has_multiple_solutions = has_multiple}; return BuildContainerString(solution_map, container_meta_); } - using SolutionMap = std::map, - std::set>>; + using SolutionMap = std::map, std::set>>; SolutionMap BuildSolutionMap() { SolutionMap mapping; @@ -538,8 +531,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { static bool CheckMultipleSolutions(const SolutionMap &mapping) { return std::any_of(mapping.begin(), mapping.end(), - [](const std::pair, - std::set>> &pair) { + [](const std::pair, std::set>> &pair) { return pair.second.size() > 1; }); } @@ -550,8 +542,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { bool has_multiple_solutions; }; - std::string BuildContainerString(const SolutionMap &mapping, - const ContainerMeta &meta) { + std::string BuildContainerString(const SolutionMap &mapping, const ContainerMeta &meta) { std::ostringstream oss; container_type_ = GetContainerType(meta); oss << " const static " << container_type_ << " results = {\n"; @@ -562,25 +553,19 @@ class OrderedSymbolProcessor : public SymbolProcessor { } static std::string GetContainerType(const ContainerMeta &meta) { std::ostringstream oss; - oss << "std::map<" - << (meta.input_count > 1 ? "std::vector" : "ge::DataType") - << ", "; + oss << "std::map<" << (meta.input_count > 1 ? "std::vector" : "ge::DataType") << ", "; if (meta.output_count > 1) { - oss << (meta.has_multiple_solutions ? - "std::vector>" : "std::vector"); + oss << (meta.has_multiple_solutions ? "std::vector>" : "std::vector"); } else { - oss << (meta.has_multiple_solutions ? - "std::set" : "ge::DataType"); + oss << (meta.has_multiple_solutions ? "std::set" : "ge::DataType"); } oss << ">"; return oss.str(); } - static void AppendContainerEntries(std::ostream &os, - const SolutionMap &mapping, - const ContainerMeta &meta) { + static void AppendContainerEntries(std::ostream &os, const SolutionMap &mapping, const ContainerMeta &meta) { std::vector entries; entries.reserve(mapping.size()); @@ -591,10 +576,8 @@ class OrderedSymbolProcessor : public SymbolProcessor { } static std::string BuildEntryString(const std::vector &input, - const std::set> &outputs, - const ContainerMeta &meta) { - return "{" + SerializeVector(input) + ", " + - SerializeOutputs(outputs, meta) + "}"; + const std::set> &outputs, const ContainerMeta &meta) { + return "{" + SerializeVector(input) + ", " + SerializeOutputs(outputs, meta) + "}"; } static std::string SerializeVector(const std::vector &vec) { @@ -614,16 +597,14 @@ class OrderedSymbolProcessor : public SymbolProcessor { return oss.str(); } - static std::string SerializeOutputs(const std::set> &outputs, - const ContainerMeta &meta) { + static std::string SerializeOutputs(const std::set> &outputs, const ContainerMeta &meta) { if (meta.output_count == 1U) { return SerializeSingleOutput(outputs, meta.has_multiple_solutions); } return SerializeMultiOutputs(outputs, meta); } - static std::string SerializeSingleOutput(const std::set> &outputs, - bool multiple) { + static std::string SerializeSingleOutput(const std::set> &outputs, bool multiple) { if (!multiple) { return DataTypeToSerialString(outputs.begin()->front()); } @@ -636,7 +617,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { } static std::string SerializeMultiOutputs(const std::set> &outputs, - const ContainerMeta &meta) { + const ContainerMeta &meta) { if (!meta.has_multiple_solutions) { return SerializeVector(*outputs.begin()); } @@ -677,8 +658,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { return oss.str(); } - static std::string JoinEntries(const std::vector &entries, - const std::string &delimiter) { + static std::string JoinEntries(const std::vector &entries, const std::string &delimiter) { std::ostringstream oss; for (size_t i = 0U; i < entries.size(); ++i) { oss << entries[i]; @@ -702,7 +682,7 @@ class OrderedSymbolProcessor : public SymbolProcessor { class InferDtypeCodeGenerator { public: explicit InferDtypeCodeGenerator(const AscIrDef &def) : def_(def) { - is_ordered_dtype_infer_ = def_.dtype_symbol_store.IsSupportOrderedSymbolicInferDtype(); + is_ordered_dtype_infer_ = def_.GetDataTypeSymbolStore().IsSupportOrderedSymbolicInferDtype(); } Status Generate(std::stringstream &ss) { GenerateFunctionSignature(ss); @@ -712,6 +692,7 @@ class InferDtypeCodeGenerator { GenerateReturnStatement(ss); return SUCCESS; } + private: static void GenerateFunctionSignature(std::stringstream &ss) { ss << R"( inline static Status InferDataType(const std::vector& input_dtypes, @@ -721,10 +702,9 @@ class InferDtypeCodeGenerator { void GenerateArgsSizeAssertion(std::stringstream &ss) { ss << " // 校验入参容器的元素个数是否合法" << std::endl; - ss << " GE_ASSERT_EQ(input_dtypes.size(), " - << def_.input_defs.size() << "U);" << std::endl; + ss << " GE_ASSERT_EQ(input_dtypes.size(), " << def_.GetInputDefs().size() << "U);" << std::endl; ss << " GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == " - << def_.output_defs.size() << "U);" << std::endl; + << def_.GetOutputDefs().size() << "U);" << std::endl; ss << std::endl; } @@ -741,7 +721,7 @@ class InferDtypeCodeGenerator { } ss << " // 校验同sym的输入的dtype是否在注册范围内并且一致" << std::endl; SymbolProcessor symbol_processor(def_); - for (const auto &sym : def_.dtype_symbol_store.GetNamedSymbols()) { + for (const auto &sym : def_.GetDataTypeSymbolStore().GetNamedSymbols()) { symbol_processor.ProcessSymbol(sym, ss); } ss << std::endl; @@ -770,9 +750,9 @@ class InferDtypeCodeGenerator { }; class InferDtypeWithNoCheckCodeGenerator { -public: + public: explicit InferDtypeWithNoCheckCodeGenerator(const AscIrDef &def) : def_(def) { - is_ordered_dtype_infer_ = def_.dtype_symbol_store.IsSupportOrderedSymbolicInferDtype(); + is_ordered_dtype_infer_ = def_.GetDataTypeSymbolStore().IsSupportOrderedSymbolicInferDtype(); } Status Generate(std::stringstream &ss) const { GenerateFunctionSignature(ss); @@ -789,8 +769,8 @@ public: GenerateReturnStatement(ss); return SUCCESS; } -private: + private: static void GenerateFunctionSignature(std::stringstream &ss) { ss << R"( inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, std::vector& expect_output_dtypes) {)" @@ -799,8 +779,7 @@ private: void GenerateArgsSizeAssertion(std::stringstream &ss) const { ss << " // 校验入参容器的元素个数是否合法" << std::endl; - ss << " GE_ASSERT_EQ(input_dtypes.size(), " - << def_.input_defs.size() << "U);" << std::endl; + ss << " GE_ASSERT_EQ(input_dtypes.size(), " << def_.GetInputDefs().size() << "U);" << std::endl; ss << " GE_ASSERT_TRUE(expect_output_dtypes.empty());" << std::endl; ss << std::endl; @@ -809,7 +788,7 @@ private: Status GenerateSymbolProcessing(std::stringstream &ss) const { ss << " // 校验同sym的输入的dtype是否一致" << std::endl; SymbolProcessor symbol_processor(def_); - for (const auto &sym : def_.dtype_symbol_store.GetNamedSymbols()) { + for (const auto &sym : def_.GetDataTypeSymbolStore().GetNamedSymbols()) { symbol_processor.ProcessSymbolWithNoCheck(sym, ss); } ss << std::endl; @@ -841,35 +820,39 @@ Status GenInferDtypeWithNoCheckFuncDef(const AscIrDef &def, std::stringstream &s } void GenCopyConstructor(const AscIrDef &def, std::stringstream &ss) { - ss << " inline " << def.type << "& operator=(const " << def.type << "&) = delete;" << std::endl; - ss << " inline " << def.type << "(" << def.type << " &&) = delete;" << std::endl; - ss << " inline " << def.type << "(const " << def.type << " &other)"; + const std::string &ir_type = def.GetType(); + ss << " inline " << ir_type << "& operator=(const " << ir_type << "&) = delete;" << std::endl; + ss << " inline " << ir_type << "(" << ir_type << " &&) = delete;" << std::endl; + ss << " inline " << ir_type << "(const " << ir_type << " &other)"; ss << " : ge::Operator(other),"; ss << " attr(other.attr)"; - if (!def.attr_defs.empty()) { - ss << ", ir_attr(dynamic_cast(*(attr.ir_attr)))"; + if (!def.GetAttrDefs().empty()) { + ss << ", ir_attr(dynamic_cast(*(attr.ir_attr)))"; } - for (const auto &input_def : def.input_defs) { + const auto &input_defs = def.GetInputDefs(); + for (const auto &input_def : input_defs) { ss << "," << std::endl << " " << input_def.first << "(this)"; } - for (size_t i = 0UL; i < def.output_defs.size(); ++i) { - ss << "," << std::endl << " " << def.output_defs[i].first << "(this, " << i << ")"; + const auto &output_defs = def.GetOutputDefs(); + for (size_t i = 0UL; i < output_defs.size(); ++i) { + ss << "," << std::endl << " " << output_defs[i].first << "(this, " << i << ")"; } ss << " {" << std::endl; - for (const auto &output_def : def.output_defs) { + for (const auto &output_def : output_defs) { ss << " " << output_def.first << ".TryInitTensorAttr();" << std::endl; } ss << " }" << std::endl; } void GenAscIr(const AscIrDef &def, std::stringstream &ss) { + const std::string &ir_type = def.GetType(); ss << "namespace ascir_op {" << std::endl; - ss << "struct " << def.type << " : public ge::Operator {" << std::endl; - ss << " static constexpr const char *Type = \"" << def.type << "\";" << std::endl; + ss << "struct " << ir_type << " : public ge::Operator {" << std::endl; + ss << " static constexpr const char *Type = \"" << ir_type << "\";" << std::endl; ss << " AscNodeAttr &attr;" << std::endl; const auto &ir_attr_class_name = TryGenIrAttrClass(def, ss); // generate input output definitions - const auto &input_defs = def.input_defs; + const auto &input_defs = def.GetInputDefs(); for (size_t i = 0UL; i < input_defs.size(); ++i) { const auto &input_def = input_defs[i]; if (input_def.second == ge::IrInputType::kIrInputDynamic) { @@ -879,13 +862,13 @@ void GenAscIr(const AscIrDef &def, std::stringstream &ss) { } } - const auto &output_defs = def.output_defs; + const auto &output_defs = def.GetOutputDefs(); for (const auto &output_def : output_defs) { ss << " AscOpOutput " << output_def.first << ";" << std::endl; } // generate constructor func definitions - if (def.start_node) { + if (def.IsStartNode()) { GenConstructorDef(def, ir_attr_class_name, ss, true); } GenConstructorDef(def, ir_attr_class_name, ss); @@ -897,9 +880,10 @@ void GenAscIr(const AscIrDef &def, std::stringstream &ss) { } void GenIrComment(const AscIrDef &def, std::stringstream &ss) { - if (!def.comment.empty()) { + const auto &comment = def.GetComment(); + if (!comment.empty()) { ss << "/* \n"; - ss << def.comment << "\n"; + ss << comment << "\n"; ss << "*/ \n"; } } @@ -949,13 +933,15 @@ class FunctionGenerator { ss << " SET_SCHED_AXIS_IF_IN_CONTEXT(op);" << std::endl; } virtual void TryGenOutputsVectorizedAxis(std::stringstream &ss) const { - for (const auto &name : def_.output_defs) { + const auto &output_defs = def_.GetOutputDefs(); + for (const auto &name : output_defs) { ss << " *op." << name.first << ".vectorized_axis = AxisUtils::GetDefaultVectorizedAxis(*op." << name.first << ".axis, op.attr.sched.loop_axis);" << std::endl; } } virtual void GenOutputMemInfo(std::stringstream &ss) const { - for (const auto &name : def_.output_defs) { + const auto &output_defs = def_.GetOutputDefs(); + for (const auto &name : output_defs) { ss << " op." << name.first << ".mem->tensor_id = " << "CodeGenUtils::GenNextTensorId(op);" << std::endl; } @@ -975,22 +961,24 @@ class FunctionGenerator { return generated; } virtual void GenPaddingAxis(std::stringstream &ss) const { - for (const auto &name : def_.output_defs) { + const auto &output_defs = def_.GetOutputDefs(); + for (const auto &name : output_defs) { ss << " THROW(PadOutputViewToSched(op." << name.first << "));" << std::endl; } } virtual void GenReturn(std::stringstream &ss) const { - if (def_.output_defs.empty()) { + const auto &output_defs = def_.GetOutputDefs(); + if (output_defs.empty()) { ss << " return op;" << std::endl; - } else if (def_.output_defs.size() == 1U) { - ss << " return op." << def_.output_defs[0U].first << ";" << std::endl; + } else if (output_defs.size() == 1U) { + ss << " return op." << output_defs[0U].first << ";" << std::endl; } else { ss << " return std::make_tuple("; - for (size_t i = 0; i < def_.output_defs.size(); ++i) { + for (size_t i = 0; i < output_defs.size(); ++i) { if (i == 0) { - ss << "op." << def_.output_defs[i].first; + ss << "op." << output_defs[i].first; } else { - ss << " ,op." << def_.output_defs[i].first; + ss << " ,op." << output_defs[i].first; } } ss << ");" << std::endl; @@ -1000,11 +988,12 @@ class FunctionGenerator { protected: const AscIrDef &def_; + private: static bool NeedConnectByInputArgs(const bool has_optional_input, const std::pair &input_def) { return ((has_optional_input || (input_def.second != ge::IrInputType::kIrInputOptional)) && - (input_def.second != ge::IrInputType::kIrInputDynamic)); + (input_def.second != ge::IrInputType::kIrInputDynamic)); } }; @@ -1012,13 +1001,13 @@ void ascir::FunctionGenerator::GenDefinition(std::stringstream &ss, const bool h const std::vector> *input_defs; std::vector> empty_input_defs; - if (def_.start_node) { - // TTODO 由于历史原因,start_node(例如Data)仍然带有输入定义,但是这种输入实际是不连边的。 + if (def_.IsStartNode()) { + // TTODO 由于历史原因,IsStartNode()(例如Data)仍然带有输入定义,但是这种输入实际是不连边的。 // 但是为了最小化修改,当前先不修改Data的定义,后续需要做调整,对与StartNode类型,不定义输入, // 或者认为没有输入的op就是start node,在定义IR时不需要再显式指定start node标记 input_defs = &empty_input_defs; } else { - input_defs = &def_.input_defs; + input_defs = &def_.GetInputDefs(); } auto append_output_types = [&ss](size_t count) { for (size_t i = 0; i < count; ++i) { @@ -1028,14 +1017,14 @@ void ascir::FunctionGenerator::GenDefinition(std::stringstream &ss, const bool h ss << "AscOpOutput"; } }; - + const auto &output_defs = def_.GetOutputDefs(); ss << "inline "; - if (def_.output_defs.size() > 1U) { + if (output_defs.size() > 1U) { ss << "std::tuple<"; - append_output_types(def_.output_defs.size()); - ss << "> " << def_.type << "(const char* name"; + append_output_types(output_defs.size()); + ss << "> " << def_.GetType() << "(const char* name"; } else { - ss << "AscOpOutput " << def_.type << "(const char* name"; + ss << "AscOpOutput " << def_.GetType() << "(const char* name"; } if (!input_defs->empty()) { for (const auto &input_def : *input_defs) { @@ -1046,17 +1035,17 @@ void ascir::FunctionGenerator::GenDefinition(std::stringstream &ss, const bool h } else { ss << ", ge::AscGraph &graph"; } - - for (const auto &attr_def : def_.attr_defs) { + const auto &attr_defs = def_.GetAttrDefs(); + for (const auto &attr_def : attr_defs) { ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; } ss << ") {" << std::endl; } void ascir::FunctionGenerator::GenInstantiation(std::stringstream &ss) const { - if (def_.start_node) { - ss << " const auto &op_ptr = std::make_shared(name, graph);" << std::endl; + if (def_.IsStartNode()) { + ss << " const auto &op_ptr = std::make_shared(name, graph);" << std::endl; } else { - ss << " const auto &op_ptr = std::make_shared(name);" << std::endl; + ss << " const auto &op_ptr = std::make_shared(name);" << std::endl; } ss << " auto &op = *op_ptr;" << std::endl; ss << " const auto &desc = ge::OpDescUtils::GetOpDescFromOperator(op);" << std::endl; @@ -1064,33 +1053,35 @@ void ascir::FunctionGenerator::GenInstantiation(std::stringstream &ss) const { } bool ascir::FunctionGenerator::GenConnectInputs(std::stringstream &ss, const bool has_optional_input) const { // TTODO 这里与GenFunctionDefinition同理,后续删除 - if (def_.start_node) { + if (def_.IsStartNode()) { return false; } - if (!def_.input_defs.empty()) { - for (const auto &input_def : def_.input_defs) { + const auto &input_defs = def_.GetInputDefs(); + if (!input_defs.empty()) { + for (const auto &input_def : input_defs) { if (NeedConnectByInputArgs(has_optional_input, input_def)) { ss << " op." << input_def.first << " = " << input_def.first << "_in;" << std::endl; } } } - return !def_.input_defs.empty(); + return !input_defs.empty(); } bool ascir::FunctionGenerator::GenAttrAssignment(std::stringstream &ss) const { - if (!def_.attr_defs.empty()) { - for (const auto &attr_def : def_.attr_defs) { + const auto &attr_defs = def_.GetAttrDefs(); + if (!attr_defs.empty()) { + for (const auto &attr_def : attr_defs) { // 函数命名大驼峰,所以把属性名第一个字符转换成大写字母 ss << " op.ir_attr.Set" << CapitalizeFirstLetter(attr_def.name) << "(" << attr_def.name << ");" << std::endl; } } - return !def_.attr_defs.empty(); + return !attr_defs.empty(); } class StartNodeFuncGenerator : public FunctionGenerator { public: explicit StartNodeFuncGenerator(const AscIrDef &def) : FunctionGenerator(def) {} void Gen(std::stringstream &ss, const bool has_optional_input) const override { - if (!def_.start_node || def_.output_defs.size() != 1UL) { + if (!def_.IsStartNode() || def_.GetOutputDefs().size() != 1UL) { return; } FunctionGenerator::Gen(ss, has_optional_input); @@ -1099,18 +1090,20 @@ class StartNodeFuncGenerator : public FunctionGenerator { (void) has_optional_input; // inline ascir::ops::OpType OpType ss << "inline " - << "AscOpOutput " << ' ' << def_.type << "(const char *name, ge::AscGraph &graph, ge::DataType dt" + << "AscOpOutput " << ' ' << def_.GetType() << "(const char *name, ge::AscGraph &graph, ge::DataType dt" << ", const std::vector &axis_ids" << ", const std::vector &repeats" << ", const std::vector &strides"; - for (const auto &attr_def : def_.attr_defs) { + const auto &attr_defs = def_.GetAttrDefs(); + for (const auto &attr_def : attr_defs) { ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; } ss << ") {" << std::endl; } bool GenOutputsAssignment(std::stringstream &ss) const override { - const auto &output_name = def_.output_defs[0].first; + const auto &output_defs = def_.GetOutputDefs(); + const auto &output_name = output_defs[0].first; ss << " op." << output_name << ".dtype = dt;" << std::endl; ss << " *op." << output_name << ".axis = axis_ids;" << std::endl; ss << " *op." << output_name << ".repeats = repeats;" << std::endl; @@ -1124,19 +1117,20 @@ class StoreNodeFuncGenerator : public FunctionGenerator { explicit StoreNodeFuncGenerator(const AscIrDef &def) : FunctionGenerator(def) {} void Gen(std::stringstream &ss, const bool has_optional_input) const override { (void) has_optional_input; - if (def_.type != "Store") { + if (def_.GetType() != "Store") { return; } ss << "inline " - << "void" << ' ' << def_.type << "(const char *name"; + << "void" << ' ' << def_.GetType() << "(const char *name"; ss << ", const ge::AscOpOutput &" << "ub_in"; ss << ", ge::AscOpOutput &gm_output"; - for (const auto &attr_def : def_.attr_defs) { + const auto &attr_defs = def_.GetAttrDefs(); + for (const auto &attr_def : attr_defs) { ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; } ss << ") {" << std::endl; ss << " auto store_out = Store(name, ub_in"; - for (const auto &attr_def : def_.attr_defs) { + for (const auto &attr_def : attr_defs) { ss << ", " << attr_def.name; } ss << ");" << std::endl; @@ -1154,7 +1148,7 @@ class ContiguousStartNodeFuncGenerator : FunctionGenerator { public: explicit ContiguousStartNodeFuncGenerator(const AscIrDef &def) : FunctionGenerator(def) {} void Gen(std::stringstream &ss, const bool has_optional_input) const override { - if (!def_.start_node || def_.output_defs.size() != 1UL) { + if (!def_.IsStartNode() || def_.GetOutputDefs().size() != 1UL) { return; } FunctionGenerator::Gen(ss, has_optional_input); @@ -1162,17 +1156,16 @@ class ContiguousStartNodeFuncGenerator : FunctionGenerator { void GenDefinition(std::stringstream &ss, const bool has_optional_input) const override { (void) has_optional_input; ss << "inline " - << "AscOpOutput" << " Contiguous" << def_.type - << "(const char *name, ge::AscGraph &graph, ge::DataType dt" + << "AscOpOutput" << " Contiguous" << def_.GetType() << "(const char *name, ge::AscGraph &graph, ge::DataType dt" << ", const std::vector &axes"; - - for (const auto &attr_def : def_.attr_defs) { + const auto &attr_defs = def_.GetAttrDefs(); + for (const auto &attr_def : attr_defs) { ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; } ss << ") {" << std::endl; } bool GenOutputsAssignment(std::stringstream &ss) const override { - const auto &output_name = def_.output_defs[0].first; + const auto &output_name = def_.GetOutputDefs()[0].first; ss << " op." << output_name << ".dtype = dt;" << std::endl; ss << " op." << output_name << ".SetContiguousView(axes);" << std::endl; return true; @@ -1203,7 +1196,8 @@ void GenFunc(const AscIrDef &def, std::stringstream &ss) { ContiguousStartNodeFuncGenerator(def).Gen(ss, false); StoreNodeFuncGenerator(def).Gen(ss, false); bool has_optional_input = false; - for (const auto &input_def : def.input_defs) { + const auto &input_defs = def.GetInputDefs(); + for (const auto &input_def : input_defs) { if (input_def.second == ge::IrInputType::kIrInputOptional) { has_optional_input = true; break; @@ -1217,26 +1211,29 @@ void GenFunc(const AscIrDef &def, std::stringstream &ss) { } } -void GenCalcBufFunc(std::stringstream &ss, const std::map, AscIrDef>& ordered_keys_to_def) { +void GenCalcBufFunc(std::stringstream &ss, + const std::map, AscIrDef> &ordered_keys_to_def) { std::stringstream ss_calc_tmp_buff_map; std::stringstream ss_calc_tmp_buff; for (auto &key_and_def : ordered_keys_to_def) { - if (key_and_def.second.calc_tmp_buf_size_func.func_name.empty()) { + const auto &calc_func = key_and_def.second.GetCalcTmpBufSizeFunc(); + if (calc_func.func_name.empty()) { continue; } - if (key_and_def.second.calc_tmp_buf_size_func.func_type == CalcTmpBufSizeFuncType::CustomizeType) { + if (calc_func.func_type == CalcTmpBufSizeFuncType::CustomizeType) { ss_calc_tmp_buff << "extern std::vector> "; - ss_calc_tmp_buff << key_and_def.second.calc_tmp_buf_size_func.func_name << "(const ge::AscNode &Node);" << std::endl; + ss_calc_tmp_buff << calc_func.func_name << "(const ge::AscNode &Node);" << std::endl; } - ss_calc_tmp_buff_map << " {\"" << key_and_def.second.type << "\", &"; - ss_calc_tmp_buff_map << key_and_def.second.calc_tmp_buf_size_func.func_name << "}," << std::endl; + ss_calc_tmp_buff_map << " {\"" << key_and_def.second.GetType() << "\", &"; + ss_calc_tmp_buff_map << calc_func.func_name << "}," << std::endl; } // 没有API注册时不生成CalcBuf函数 if (ss_calc_tmp_buff_map.str().empty()) { return; } ss << ss_calc_tmp_buff.str(); - ss << "inline std::vector> CalcAscNodeTmpSize(const ge::AscNode &node) {" << std::endl; + ss << "inline std::vector> CalcAscNodeTmpSize(const ge::AscNode &node) {" + << std::endl; ss << " typedef std::vector> (*calc_func_ptr) (const AscNode &node);" << std::endl; ss << " static const std::unordered_map node_calc_tmp_buff_map = {" << std::endl; ss << ss_calc_tmp_buff_map.str(); @@ -1254,18 +1251,20 @@ void GenCommonInferDtypeBaseFunc(std::stringstream &ss, const string &extra_str = "") { std::stringstream func_table; for (const auto &key_and_def : ordered_keys_to_def) { - const auto &node_type = key_and_def.second.type; + const auto &node_type = key_and_def.second.GetType(); func_table << " {\"" << node_type << "\", "; func_table << "::ge::ascir_op::" << node_type << "::InferDataType" << extra_str << "}," << std::endl; } if (func_table.str().empty()) { return; } - ss << "inline ge::Status CommonInferDtype" << extra_str << "(const std::string &type, const std::vector &input_dtypes,\n" + ss << "inline ge::Status CommonInferDtype" << extra_str + << "(const std::string &type, const std::vector &input_dtypes,\n" " std::vector &expect_output_dtypes) {" << std::endl; ss << " using func = ge::Status (*)(const std::vector &input_dtypes, \n" - " std::vector &expect_output_dtypes);" << std::endl; + " std::vector &expect_output_dtypes);" + << std::endl; ss << " static const std::unordered_map func_table = {" << std::endl; ss << func_table.str(); ss << " };" << std::endl; @@ -1304,12 +1303,13 @@ void GenAll(std::stringstream &ss) { std::map, AscIrDef> ordered_keys_to_def; for (const auto &type_and_def : AscirRegistry::GetInstance().GetAll()) { - ordered_keys_to_def[std::make_pair(type_and_def.second.file_path, type_and_def.second.line)] = type_and_def.second; + ordered_keys_to_def[std::make_pair(type_and_def.second.GetFilePath(), type_and_def.second.GetLine())] = + type_and_def.second; } for (const auto &key_and_def : ordered_keys_to_def) { - ss << "// Defined at " << GetPureFileName(key_and_def.second.file_path.c_str()) << ':' << key_and_def.second.line - << std::endl; + ss << "// Defined at " << GetPureFileName(key_and_def.second.GetFilePath().c_str()) << ':' + << key_and_def.second.GetLine() << std::endl; GenIrComment(key_and_def.second, ss); ss << "namespace ge {" << std::endl; GenAscIr(key_and_def.second, ss); @@ -1322,13 +1322,15 @@ void GenAll(std::stringstream &ss) { for (auto &key_and_def : ordered_keys_to_def) { GenFunc(key_and_def.second, ss); // 如果有node属性配置,重载一个不设置属性的构造函数,把属性变成可选 - if (!key_and_def.second.attr_defs.empty()) { - key_and_def.second.attr_defs.clear(); - GenFunc(key_and_def.second, ss); + + const auto &ascir_def = key_and_def.second; + if (!ascir_def.GetAttrDefs().empty()) { + ascir_def.MutableAttrDefs().clear(); + GenFunc(ascir_def, ss); } } - ss << "}" << std::endl; // namespace cg + ss << "}" << std::endl; // namespace cg GenCalcBufFunc(ss, ordered_keys_to_def); GenCommonInferDtypeFunc(ss, ordered_keys_to_def); GenCommonInferDtypeWithNoCheckFunc(ss, ordered_keys_to_def); @@ -1365,5 +1367,4 @@ int GenHeaderFile(const char *path) { return 0; } } // namespace ascir -} - +} // namespace ge diff --git a/inc/graph/ascendc_ir/ascir_register.h b/inc/graph/ascendc_ir/ascir_register.h index 2c6dc5c51e8b17e0b930b39694cdd9b9b200291f..c3a21ed7e03cf900a9c8d2028db7c176f19263a6 100644 --- a/inc/graph/ascendc_ir/ascir_register.h +++ b/inc/graph/ascendc_ir/ascir_register.h @@ -11,6 +11,7 @@ #define AUTOFUSE_ASCIR_REGISTER_H #include #include +#include #include "graph/ascendc_ir/ascir_registry.h" #include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h" #include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h" @@ -21,9 +22,9 @@ namespace ge { namespace ascir { class AscIrCodegen { -public: + public: virtual std::vector> CalcTmpBufSize(const ge::AscNode &node) { - (void)node; + (void) node; return std::vector>(); } virtual std::string GetApiTilingTypeName() const { @@ -36,14 +37,29 @@ public: return 0U; } - // 创建api call对象, 返回对象指针 - virtual void* CreateApiCall() const = 0; - // 创建mciro api call对象, 返回对象指针 - virtual void* CreateMicroApiCall() const = 0; + // 返回api call类的名称 + virtual std::string GetApiCallName() const { + return ""; + } + + // 返回api的名称 + virtual std::string GetApiName() const { + return ""; + } + + // 返回api call类的名称 + virtual std::string GetMicroApiCallName() const { + return ""; + } + + // 返回api的名称 + virtual std::string GetMicroApiName() const { + return ""; + } }; class AscIrAtt { -public: + public: // 最内轴建议对齐值(默认32B对齐) virtual uint32_t GetInnerDimPromptAlignSize() { return 32U; @@ -60,6 +76,13 @@ public: virtual void *GetAscendCApiPerfTable() = 0; }; +template +std::function()> AscIrImplCreator() { + return []() { + return std::make_unique(); + }; +} + class AscirRegister { public: AscirRegister() = default; @@ -81,15 +104,18 @@ class AscirRegister { AscirRegister &InferDataType(AscIrDef::CodeGenerator infer_data_type_generator); AscirRegister &UseFirstInputDataType() { - return DataTypes(std::vector(ir_def_.output_defs.size(), DtypePolicy(0U))); + const auto &output_defs = ir_def_.GetOutputDefs(); + return DataTypes(std::vector(output_defs.size(), DtypePolicy(0U))); } AscirRegister &UseSecondInputDataType() { - return DataTypes(std::vector(ir_def_.output_defs.size(), DtypePolicy(1U))); + const auto &output_defs = ir_def_.GetOutputDefs(); + return DataTypes(std::vector(output_defs.size(), DtypePolicy(1U))); } AscirRegister &InferView(AscIrDef::CodeGenerator infer_view_generator); AscirRegister &UseFirstInputView() { - return Views(std::vector(ir_def_.output_defs.size(), ViewPolicy(0))); + const auto &output_defs = ir_def_.GetOutputDefs(); + return Views(std::vector(output_defs.size(), ViewPolicy(0))); } AscirRegister &StartNode(); @@ -101,11 +127,15 @@ class AscirRegister { AscirRegister(AscirRegister &&) noexcept = delete; AscirRegister &operator=(AscirRegister &&) noexcept = delete; - AscirRegister &CalcTmpBufSize(const std::string calc_tmp_buf_size_func); + AscirRegister &CalcTmpBufSize(const std::string &calc_tmp_buf_size_func); AscirRegister &SameTmpBufSizeFromFirstInput(); AscirRegister &ApiTilingDataType(const std::string &tiling_data_name); + AscirRegister &Impl(const std::vector &soc_version, const AscIrDef::AscIrImpl &impl); + + size_t GetSocImplSize() const; + private: AscirRegister &Attr(std::string name, std::string asc_type, std::string ge_type); @@ -115,20 +145,19 @@ class AscirRegister { #define REG_ASC_IR(type) static auto g_register_##type = AscirRegister(#type, __FILE__, __LINE__) #define REG_ASC_IR_START_NODE(type) REG_ASC_IR(type).Inputs({}).Outputs({"y"}).StartNode() -#define REG_ASC_IR_START_NODE_WITH_ATTR(type) REG_ASC_IR(type).Inputs({}).Outputs({"y"}).Attr("index").StartNode() +#define REG_ASC_IR_START_NODE_WITH_ATTR(type) \ + REG_ASC_IR(type).Inputs({}).Outputs({"y"}).Attr("index").StartNode() #define REG_ASC_IR_1IO(type) REG_ASC_IR(type).Input("x", "T").Output("y", "T").DataType("T", TensorType::ALL()) -#define REG_ASC_IR_2I1O(type) REG_ASC_IR(type).Input("x1", "T").Input("x2", "T").Output("y", "T").DataType("T", TensorType::ALL()) +#define REG_ASC_IR_2I1O(type) \ + REG_ASC_IR(type).Input("x1", "T").Input("x2", "T").Output("y", "T").DataType("T", TensorType::ALL()) #define EXPAND_CHAIN_CALL(...) #__VA_ARGS__ -#define REG_ASC_IR_WITH_COMMENT(type, ...) \ - constexpr const char* comment_##type = \ - R"COMMENT(REG_ASC_IR()COMMENT" #type \ - ")\n" EXPAND_CHAIN_CALL(__VA_ARGS__) ";"; \ - static auto g_register_##type = AscirRegister(#type, __FILE__, __LINE__) \ - __VA_ARGS__ \ - .Comment(comment_##type) +#define REG_ASC_IR_WITH_COMMENT(type, ...) \ + constexpr const char *comment_##type = \ + R"COMMENT(REG_ASC_IR()COMMENT" #type ")\n" EXPAND_CHAIN_CALL(__VA_ARGS__) ";"; \ + static auto g_register_##type = AscirRegister(#type, __FILE__, __LINE__) __VA_ARGS__.Comment(comment_##type) #define EXPORT_GENERATOR() -} -} +} // namespace ascir +} // namespace ge #endif // AUTOFUSE_ASCIR_REGISTER_H diff --git a/inc/graph/ascendc_ir/ascir_registry.h b/inc/graph/ascendc_ir/ascir_registry.h index 0c280af0f584f8ee949487771cc7642bcde9c851..c8359fea54a576c822af26789b7cb6c882566156 100644 --- a/inc/graph/ascendc_ir/ascir_registry.h +++ b/inc/graph/ascendc_ir/ascir_registry.h @@ -21,6 +21,7 @@ #include "external/graph/types.h" #include "op_desc.h" #include "ir/ir_data_type_symbol_store.h" +#include "graph/ascendc_ir/ascir_register.h" namespace ge { namespace ascir { using ApplyOutputView = std::function; @@ -35,14 +36,13 @@ struct ViewPolicy { ViewPolicy(uint32_t element_wise_input_index) : use_input_index(element_wise_input_index) { view_type = kElementWise; } - ViewPolicy(uint32_t reduce_input_index, std::string reduce_axis_name) : use_input_index(reduce_input_index), - reduce_axis_attr_name(std::move( - reduce_axis_name)) { + ViewPolicy(uint32_t reduce_input_index, std::string reduce_axis_name) + : use_input_index(reduce_input_index), reduce_axis_attr_name(std::move(reduce_axis_name)) { view_type = kReduce; } - explicit ViewPolicy(std::vector broad_cast_in_indexs) : broad_cast_input_indexs(std::move( - broad_cast_in_indexs)) { + explicit ViewPolicy(std::vector broad_cast_in_indexs) + : broad_cast_input_indexs(std::move(broad_cast_in_indexs)) { view_type = kBroadCast; } @@ -66,6 +66,7 @@ struct DtypePolicy { kUseDtype, kInvalid, }; + public: DtypePolicy(uint32_t use_in_index) : use_input_index(use_in_index) { policy_type = kUseInput; @@ -97,67 +98,91 @@ struct CalcTmpBufSizeFunc { std::string func_name; CalcTmpBufSizeFuncType func_type = CalcTmpBufSizeFuncType::CommonType; CalcTmpBufSizeFunc() = default; - CalcTmpBufSizeFunc(std::string name, const CalcTmpBufSizeFuncType type) : func_name(std::move(name)), func_type(type) {} + CalcTmpBufSizeFunc(std::string name, const CalcTmpBufSizeFuncType type) + : func_name(std::move(name)), func_type(type) {} }; -struct AscIrDef { + +struct AscIrDefImpl; +class AscIrDef { + public: + AscIrDef(); using CodeGenerator = void (*)(const AscIrDef &def, std::stringstream &ss); - bool IsAttrExisted(const std::string &attr_name) const { - return std::find_if(attr_defs.begin(), attr_defs.end(), [&attr_name](const AscIrAttrDef &asc_ir_attr_def) { - return asc_ir_attr_def.name == attr_name; - }) != attr_defs.end(); - } - std::string file_path; - int64_t line; - std::string type; + bool IsAttrExisted(const std::string &attr_name) const; + + void Init(const char *type, const char *def_file_path, int64_t line) const; + + const std::vector> &GetInputDefs() const; + const std::vector> &GetOutputDefs() const; - // 当前只有必选输入一种,没有其他类型,因此暂时简单处理,后续有复杂的optional后,defs的类型就不是string了 - std::vector> input_defs; - std::vector> output_defs; - std::vector attr_defs; + void AppendInput(const string &name, ge::IrInputType type) const; + void AppendOutput(const string &name, ge::IrOutputType type) const; + const std::string &GetType() const; + void StartNode() const; + bool IsStartNode() const; + void SetAttr(const std::string &name, const std::string &asc_type, const std::string &ge_type) const; + void SetDtypePolicy(const std::vector &output_dtypes_policy) const; + const std::vector &GetOutputDtypePolicy() const; + void SetViewPolicy(const std::vector &view_policy) const; + const std::vector &GetViewPolicy() const; + void SetApiTilingDataName(const std::string &tiling_data_name) const; + const string &GetApiTilingDataName() const; + void SetCalcTmpBufSizeFunc(const std::string &calc_tmp_buf_size_func, CalcTmpBufSizeFuncType type) const; + const CalcTmpBufSizeFunc &GetCalcTmpBufSizeFunc() const; + const std::vector &GetAttrDefs() const; + std::vector &MutableAttrDefs() const; + void SetComment(const string &comment) const; + const string &GetComment() const; + const std::string &GetFilePath() const; + int64_t GetLine() const; + IRDataTypeSymbolStore &MutableDataTypeSymbolStore() const; + const IRDataTypeSymbolStore &GetDataTypeSymbolStore() const; - std::vector output_views_policy; - std::vector output_dtypes_policy; + using AscIrAttCreator = std::function()>; + using AscIrCodegenCreator = std::function()>; - bool start_node{false}; - CodeGenerator infer_data_type_generator; - CodeGenerator infer_view_generator; - IRDataTypeSymbolStore dtype_symbol_store; - std::string comment; - CalcTmpBufSizeFunc calc_tmp_buf_size_func; - std::string tiling_data_name; + struct AscIrImpl { + AscIrAttCreator att; + AscIrCodegenCreator codegen; + std::vector> support_dtypes; + }; + void AddAscIrImpl(const std::vector &soc_versions, const AscIrImpl &impl) const; + + size_t GetSocImplSize() const; + CodeGenerator infer_data_type_generator{nullptr}; + CodeGenerator infer_view_generator{nullptr}; + + private: + friend class AscirRegister; + std::shared_ptr impl_; }; -inline std::string UpdateViewIfCrossLoop(const std::string &trans_infos, - const std::string &input_api_sched_axis, - const std::string &op_attr_sched_axis, - const std::string &tie_expression) { - return "AxisUtils::UpdateViewIfCrossLoop(" + trans_infos + ", " + input_api_sched_axis + ", " + op_attr_sched_axis - + ", " + "std::move(" + tie_expression + "))"; +inline std::string UpdateViewIfCrossLoop(const std::string &trans_infos, const std::string &input_api_sched_axis, + const std::string &op_attr_sched_axis, const std::string &tie_expression) { + return "AxisUtils::UpdateViewIfCrossLoop(" + trans_infos + ", " + input_api_sched_axis + ", " + op_attr_sched_axis + + ", " + "std::move(" + tie_expression + "))"; } inline void GenChosenInputView(const AscIrDef &def, const uint32_t chosen_input_index, std::stringstream &ss) { - ss << def.input_defs[chosen_input_index].first + "_tmp = " << "{*" - << def.input_defs[chosen_input_index].first << "_in.axis, *" - << def.input_defs[chosen_input_index].first << "_in.repeats, *" - << def.input_defs[chosen_input_index].first << "_in.strides};" - << std::endl; + const auto &input_defs = def.GetInputDefs(); + ss << input_defs[chosen_input_index].first + "_tmp = " << "{*" << input_defs[chosen_input_index].first + << "_in.axis, *" << input_defs[chosen_input_index].first << "_in.repeats, *" + << input_defs[chosen_input_index].first << "_in.strides};" << std::endl; } template void GenErrorIfPolicyInvalid(Policy policy, size_t range, std::stringstream &ss) { if (policy.use_input_index < range) { return; } - ss << "Policy is invalid as use_input_index :" << policy.use_input_index << " should be less than input size:" - << range << std::endl; + ss << "Policy is invalid as use_input_index :" << policy.use_input_index + << " should be less than input size:" << range << std::endl; } -inline void DefineChosenInputView(const AscIrDef &def, const ViewPolicy &policy, - uint32_t &chosen_input_index, - std::unordered_set &chosen_input_index_set, - std::stringstream &ss) { +inline void DefineChosenInputView(const AscIrDef &def, const ViewPolicy &policy, uint32_t &chosen_input_index, + std::unordered_set &chosen_input_index_set, std::stringstream &ss) { + const auto &input_defs = def.GetInputDefs(); ss << " // set tmp view to store input view and apply view transform" << std::endl; const std::string view_type("View "); - GenErrorIfPolicyInvalid(policy, def.input_defs.size(), ss); + GenErrorIfPolicyInvalid(policy, input_defs.size(), ss); chosen_input_index = policy.use_input_index; ss << " "; if (chosen_input_index_set.insert(chosen_input_index).second) { @@ -167,31 +192,31 @@ inline void DefineChosenInputView(const AscIrDef &def, const ViewPolicy &policy, } inline void SameDataTypeFromInput(const AscIrDef &def, std::stringstream &ss, const char *input_name) { - for (const auto &output_def : def.output_defs) { + const auto &output_defs = def.GetOutputDefs(); + for (const auto &output_def : output_defs) { ss << " op." << output_def.first << ".dtype = static_cast(" << input_name << "_in.dtype);" << std::endl; } } -inline void GenerateViewUpdateCode(const AscIrDef &def, - const std::pair out_to_chosen_input, - const ApplyOutputView &apply_output_view, - std::stringstream &ss, +inline void GenerateViewUpdateCode(const AscIrDef &def, const std::pair out_to_chosen_input, + const ApplyOutputView &apply_output_view, std::stringstream &ss, bool &gen_trans_infos_instance) { + const auto &input_defs = def.GetInputDefs(); + const auto &output_defs = def.GetOutputDefs(); const size_t output_index = out_to_chosen_input.first; const size_t chosen_input_index = out_to_chosen_input.second; if (!gen_trans_infos_instance) { - ss << " auto trans_infos = CodeGenUtils::GetOwnerGraphAscAttr(op." << def.output_defs[output_index].first - << ".GetOwnerOp())" << "->trans_info_road;" - << std::endl; + ss << " auto trans_infos = CodeGenUtils::GetOwnerGraphAscAttr(op." << output_defs[output_index].first + << ".GetOwnerOp())" << "->trans_info_road;" << std::endl; gen_trans_infos_instance = true; } - const std::string which_input_api_sched_axis = def.output_defs[output_index].first + "_in_api_sched_axis"; + const std::string which_input_api_sched_axis = output_defs[output_index].first + "_in_api_sched_axis"; ss << " auto " << which_input_api_sched_axis << " = CodeGenUtils::GetOwnerOpAscAttr(" - << def.input_defs[chosen_input_index].first << "_in.GetOwnerOp())" + << input_defs[chosen_input_index].first << "_in.GetOwnerOp())" << "->sched.axis;" << std::endl; - std::string view = def.input_defs[chosen_input_index].first + "_tmp"; + std::string view = input_defs[chosen_input_index].first + "_tmp"; ss << " {" << std::endl << " const auto &[axes, repeats, strides] = "; std::string val = UpdateViewIfCrossLoop("trans_infos", which_input_api_sched_axis, "op.attr.sched.axis", view).append(".second"); @@ -201,17 +226,17 @@ inline void GenerateViewUpdateCode(const AscIrDef &def, } else { ss << val << ";" << std::endl; } - ss << " std::tie(*op." << def.output_defs[output_index].first << ".axis, *op." - << def.output_defs[output_index].first << ".repeats, *op." << def.output_defs[output_index].first - << ".strides) = std::make_tuple(axes, repeats, strides);" << std::endl + ss << " std::tie(*op." << output_defs[output_index].first << ".axis, *op." << output_defs[output_index].first + << ".repeats, *op." << output_defs[output_index].first << ".strides) = std::make_tuple(axes, repeats, strides);" + << std::endl << " }" << std::endl; } inline ApplyOutputView GenApplyOutputViewFunc(const AscIrDef &def, const size_t output_index, - uint32_t &chosen_input_index, - std::stringstream &ss) { + uint32_t &chosen_input_index, std::stringstream &ss) { (void) chosen_input_index; - const auto &policy = def.output_views_policy[output_index]; + const auto &output_views_policy = def.GetViewPolicy(); + const auto &policy = output_views_policy[output_index]; ApplyOutputView apply_output_view; switch (policy.view_type) { case ViewPolicy::kElementWise: @@ -220,12 +245,11 @@ inline ApplyOutputView GenApplyOutputViewFunc(const AscIrDef &def, const size_t if (!def.IsAttrExisted(policy.reduce_axis_attr_name)) { return apply_output_view; } - apply_output_view = [&def, output_index](const std::string &var) -> std::string { - return "AxisUtils::ReduceView(" + var + ", " + def.output_views_policy[output_index].reduce_axis_attr_name + - ")"; + apply_output_view = [&output_views_policy, output_index](const std::string &var) -> std::string { + return "AxisUtils::ReduceView(" + var + ", " + output_views_policy[output_index].reduce_axis_attr_name + ")"; }; break; - case ViewPolicy::kBroadCast: // TTODO 广播代码后续支持 + case ViewPolicy::kBroadCast: // TTODO 广播代码后续支持 case ViewPolicy::kInvalid: default: ss << "unsupported policy type: " << policy.view_type << std::endl; @@ -235,18 +259,21 @@ inline ApplyOutputView GenApplyOutputViewFunc(const AscIrDef &def, const size_t } inline void InferViewByPolicy(const AscIrDef &def, std::stringstream &ss) { - if (def.output_defs.size() != def.output_views_policy.size()) { - std::string error_info = - std::string("view_policy's size ").append(std::to_string(def.output_views_policy.size())).append( - " should be equal with output_defs's size ").append(std::to_string(def.output_defs.size())); + const auto &output_defs = def.GetOutputDefs(); + const auto &output_views_policy = def.GetViewPolicy(); + if (output_defs.size() != output_views_policy.size()) { + std::string error_info = std::string("view_policy's size ") + .append(std::to_string(output_views_policy.size())) + .append(" should be equal with output_defs's size ") + .append(std::to_string(output_defs.size())); ss << error_info; return; } bool gen_trans_infos_instance = false; std::unordered_set chosen_input_index_set; - for (size_t output_index = 0U; output_index < def.output_views_policy.size(); ++output_index) { + for (size_t output_index = 0U; output_index < output_views_policy.size(); ++output_index) { uint32_t chosen_input_index = 0U; - DefineChosenInputView(def, def.output_views_policy[output_index], chosen_input_index, chosen_input_index_set, ss); + DefineChosenInputView(def, output_views_policy[output_index], chosen_input_index, chosen_input_index_set, ss); GenerateViewUpdateCode(def, std::make_pair(output_index, chosen_input_index), GenApplyOutputViewFunc(def, output_index, chosen_input_index, ss), ss, gen_trans_infos_instance); @@ -254,55 +281,66 @@ inline void InferViewByPolicy(const AscIrDef &def, std::stringstream &ss) { } inline void InferDtypeByPolicy(const AscIrDef &def, std::stringstream &ss) { - if (def.output_defs.size() != def.output_dtypes_policy.size()) { - std::string error_info = - std::string("dtype_policy's size ").append(std::to_string(def.output_dtypes_policy.size())).append( - "should be equal with output_defs's size ").append(std::to_string(def.output_defs.size())); + const auto &output_defs = def.GetOutputDefs(); + const auto &output_dtypes_policy = def.GetOutputDtypePolicy(); + if (output_defs.size() != output_dtypes_policy.size()) { + std::string error_info = std::string("dtype_policy's size ") + .append(std::to_string(output_dtypes_policy.size())) + .append("should be equal with output_defs's size ") + .append(std::to_string(output_defs.size())); ss << error_info; return; } - for (size_t output_index = 0U; output_index < def.output_dtypes_policy.size(); ++output_index) { - const auto &policy = def.output_dtypes_policy[output_index]; + const auto &input_defs = def.GetInputDefs(); + for (size_t output_index = 0U; output_index < output_dtypes_policy.size(); ++output_index) { + const auto &policy = output_dtypes_policy[output_index]; switch (policy.policy_type) { - case DtypePolicy::kUseInput:GenErrorIfPolicyInvalid(policy, def.input_defs.size(), ss); - ss << " op." << def.output_defs[output_index].first << ".dtype = static_cast(" - << def.input_defs[policy.use_input_index].first << "_in.dtype);" << std::endl; + case DtypePolicy::kUseInput: + GenErrorIfPolicyInvalid(policy, input_defs.size(), ss); + ss << " op." << output_defs[output_index].first << ".dtype = static_cast(" + << input_defs[policy.use_input_index].first << "_in.dtype);" << std::endl; break; - case DtypePolicy::kPromptInput:GenErrorIfPolicyInvalid(policy, def.input_defs.size(), ss); - ss << " op." << def.output_defs[output_index].first + case DtypePolicy::kPromptInput: + GenErrorIfPolicyInvalid(policy, input_defs.size(), ss); + ss << " op." << output_defs[output_index].first << ".dtype = DtypeTransformUtils::Prompt(static_cast(" - << def.input_defs[policy.use_input_index].first << "_in.dtype));" << std::endl; + << input_defs[policy.use_input_index].first << "_in.dtype));" << std::endl; break; case DtypePolicy::kUseDtype: - ss << " op." << def.output_defs[output_index].first << ".dtype = static_cast(" << policy.data_type + ss << " op." << output_defs[output_index].first << ".dtype = static_cast(" << policy.data_type << ");" << std::endl; break; case DtypePolicy::kInvalid: - default:ss << "unsupported policy type: " << policy.policy_type << std::endl; + default: + ss << "unsupported policy type: " << policy.policy_type << std::endl; } } } inline void SameDataTypeFromFirstInput(const AscIrDef &def, std::stringstream &ss) { - if (!def.input_defs.empty()) { - SameDataTypeFromInput(def, ss, def.input_defs[0].first.c_str()); + const auto &input_defs = def.GetInputDefs(); + if (!input_defs.empty()) { + SameDataTypeFromInput(def, ss, input_defs[0].first.c_str()); } } inline void SameDataTypeFromSecondInput(const AscIrDef &def, std::stringstream &ss) { - if (def.input_defs.size() > 1U) { - SameDataTypeFromInput(def, ss, def.input_defs[1].first.c_str()); + const auto &input_defs = def.GetInputDefs(); + if (input_defs.size() > 1U) { + SameDataTypeFromInput(def, ss, input_defs[1].first.c_str()); } } inline void SameViewFromInput(const AscIrDef &def, std::stringstream &ss, const char *input_name) { - for (const auto &output_def : def.output_defs) { + const auto &output_defs = def.GetOutputDefs(); + for (const auto &output_def : output_defs) { ss << " op." << output_def.first << ".axis = " << input_name << "_in.axis;" << std::endl; ss << " op." << output_def.first << ".repeats = " << input_name << "_in.repeats;" << std::endl; ss << " op." << output_def.first << ".strides = " << input_name << "_in.strides;" << std::endl; } } inline void SameViewFromFirstInput(const AscIrDef &def, std::stringstream &ss) { - if (!def.input_defs.empty()) { - SameViewFromInput(def, ss, def.input_defs[0].first.c_str()); + const auto &input_defs = def.GetInputDefs(); + if (!input_defs.empty()) { + SameViewFromInput(def, ss, input_defs[0].first.c_str()); } } diff --git a/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc b/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc index 4dc0fa3f405c2a23e2397c8707a9aaba500a05e9..0bbb8497c59579efc3c66fc859b687a3af373036 100644 --- a/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc +++ b/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc @@ -189,7 +189,7 @@ TEST_F(UtestAscendCIR, BindBlock) { auto merge_axis = graph.MergeAxis({outer_axis.id, s1_axis.id}); EXPECT_TRUE(graph.BindBlock(merge_axis->id, inner_axis.id)); EXPECT_EQ(merge_axis->type, Axis::kAxisTypeBlockOuter); - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeBlockInner); + EXPECT_EQ(inner_axis.type, Axis::kAxisTypeBlockInner); } TEST_F(UtestAscendCIR, GetAllAxisTransInfo) { @@ -3323,5 +3323,29 @@ TEST_F(UtestAscendCIR, AscOpDynamicInputVectorConstructor) { TEST_F(UtestAscendCIR, RegisterTilingData) { auto ir_defs = ge::ascir::AscirRegistry::GetInstance().GetAll(); EXPECT_NE(ir_defs.find("StubTilingData"), ir_defs.end()); - EXPECT_EQ(ir_defs["StubTilingData"].tiling_data_name, "StubTilingData"); + EXPECT_EQ(ir_defs["StubTilingData"].GetApiTilingDataName(), "StubTilingData"); +} + +TEST_F(UtestAscendCIR, AscirRegisterImpTest) { + class AscIrAttSub : public ge::ascir::AscIrAtt { + virtual void *GetApiPerf() { + return nullptr; + } + + virtual void *GetMicroApiPerf() { + return nullptr; + } + virtual void *GetAscendCApiPerfTable() { + return nullptr; + } + }; + + ge::AscirRegister reg_test; + reg_test.Impl( + {"v1", "v2", "v3"}, + {ge::ascir::AscIrImplCreator(), + ge::ascir::AscIrImplCreator(), + {{"T1", OrderedTensorTypeList{DT_INT8, DT_INT16}}, {"T2", OrderedTensorTypeList{DT_UINT8, DT_INT16}}}}); + + EXPECT_EQ(reg_test.GetSocImplSize(), 3); } \ No newline at end of file