From accf78bddd35aaabc8674bc5d778c58945054ca9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BE=B7=E9=B9=8F?= Date: Mon, 8 Sep 2025 13:33:05 +0000 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!6876=20?= =?UTF-8?q?:=200902=E6=85=A2=E8=BD=A6'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- graph/attr/ge_attr_define.cc | 3 - inc/CMakeLists.txt | 1 + inc/graph/debug/ge_attr_define.h | 2 - register/CMakeLists.txt | 1 - register/inference_rule.cc | 908 --------------- register/inference_rule.h | 120 -- register/shape_inference.cc | 164 +-- tests/ut/register/CMakeLists.txt | 4 +- .../testcase/inference_rule_unittest.cc | 1004 ----------------- .../testcase/shape_inference_unittest.cc | 2 +- 10 files changed, 55 insertions(+), 2154 deletions(-) delete mode 100644 register/inference_rule.cc delete mode 100644 register/inference_rule.h delete mode 100644 tests/ut/register/testcase/inference_rule_unittest.cc diff --git a/graph/attr/ge_attr_define.cc b/graph/attr/ge_attr_define.cc index 335e51f40b..e15172321c 100644 --- a/graph/attr/ge_attr_define.cc +++ b/graph/attr/ge_attr_define.cc @@ -1565,7 +1565,4 @@ const std::string ATTR_NAME_DO_NOT_CONSTANT_FOLDING = "_do_not_constant_folding" // for super kernel const std::string ATTR_NAME_SUPER_KERNEL_SCOPE = "_super_kernel_scope"; const std::string ATTR_NAME_SUPER_KERNEL_OPTIONS = "_super_kernel_options"; - -// inference rule for torch or other framework with symbols -const std::string ATTR_NAME_INFER_RULE = "_inference_rule"; } // namespace ge diff --git a/inc/CMakeLists.txt b/inc/CMakeLists.txt index f8be6daf9d..d252e479f4 100644 --- a/inc/CMakeLists.txt +++ b/inc/CMakeLists.txt @@ -30,6 +30,7 @@ target_include_directories(metadef_headers INTERFACE $ $ $ + $ $ $ $ diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h index 6b04da83f7..ccc7fc8bc3 100644 --- a/inc/graph/debug/ge_attr_define.h +++ b/inc/graph/debug/ge_attr_define.h @@ -1556,8 +1556,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM // for super kernel GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUPER_KERNEL_SCOPE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUPER_KERNEL_OPTIONS; -// inference rule for torch or other framework with symbols -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INFER_RULE; } // namespace ge /*lint +e618*/ diff --git a/register/CMakeLists.txt b/register/CMakeLists.txt index a813ea5265..791e97b6e6 100644 --- a/register/CMakeLists.txt +++ b/register/CMakeLists.txt @@ -59,7 +59,6 @@ set(SRC_LIST "scope/scope_util.cc" "scope/scope_pass_registry.cc" "shape_inference.cc" - "inference_rule.cc" "ascendc/ascendc_py.cc" "ascendc/op_check.cc" "ascendc/tilingdata_base.cc" diff --git a/register/inference_rule.cc b/register/inference_rule.cc deleted file mode 100644 index 7d807377b8..0000000000 --- a/register/inference_rule.cc +++ /dev/null @@ -1,908 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd.|Hisilicon Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "inference_rule.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/checker.h" -#include "external/graph/ge_error_codes.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" - -using Json = nlohmann::json; -namespace { -/** - * @brief 表达一个符号的来源 - * - * 用于描述某个符号源自输入的某个维度或某个值。并支持生成对应的C++定义代码片段。 - */ -class SymbolDef { - public: - explicit SymbolDef(const std::string &name) : name_(name), is_value_(name[0] == 'v') {} - - void RecordSource(size_t input_index, size_t offset) { - sources_.emplace_back(input_index, offset); - } - - [[nodiscard]] std::string Codegen() const { - std::stringstream ss; - if (!sources_.empty()) { - const size_t input = sources_.front().first; - const size_t offset = sources_.front().second; - if (is_value_) { - ss << " GET_SYMBOL_VALUE(" << name_ << ", " << input << ", " << offset << ");"; - } else { - ss << " GET_SYMBOL_DIM(" << name_ << ", " << input << ", " << offset << ");"; - } - } - return ss.str(); - } - - private: - std::string name_; - std::vector> sources_; - bool is_value_; -}; - -/** - * @brief 表达一个Shape维度由符号表达的输出Tensor - * - * 用于描述输出Shape每个维度的计算表达式,表达式是支持受限的表达式(+,-,*,Div,Floor,Ceil,Mod,Pow),也可以是常量表达式。 - */ -class SymbolTensor { - public: - explicit SymbolTensor(const size_t output_index) : output_index_(output_index) {} - - void AppendDim(const std::string &dim) { - dims_.push_back(dim); - } - - // 生成执行时的Shape设置代码片段 - [[nodiscard]] std::string Codegen() const { - std::stringstream ss; - ss << " SET_OUTPUT_RANK(" << output_index_ << ", " << dims_.size() << ");" << std::endl; - for (size_t i = 0; i < dims_.size(); i++) { - ss << " SET_OUTPUT_DIM(" << output_index_ << ", " << i << ", static_cast(" << dims_[i] << "));" - << std::endl; - } - return ss.str(); - } - - // 生成编译时的Shape设置代码片段 - [[nodiscard]] std::string CodegenCompileTime() const { - std::stringstream ss; - ss << " SET_OUTPUT_RANK(" << output_index_ << ", " << dims_.size() << ");" << std::endl; - for (size_t i = 0; i < dims_.size(); i++) { - const bool has_symbol = dims_[i].find('s') != std::string::npos || dims_[i].find('v') != std::string::npos; - ss << " SET_OUTPUT_DIM(" << output_index_ << ", " << i << ", " << (has_symbol ? "-1" : dims_[i]) << ");" - << std::endl; - } - return ss.str(); - } - - private: - size_t output_index_; - std::vector dims_; -}; - -/** - * @brief Shape推导规则的JSON解析器 - * - * 完成推导规则JSON的解析、合法性校验以及到InferShape代码的生成。 - */ -class RuleJsonParser { - public: - std::string ParseJson(const std::string &json_str) { - std::stringstream ss; - Json rule_json; - try { - rule_json = Json::parse(json_str); - } catch (const std::exception &e) { - ss << "Error parsing json: " << e.what(); - return ss.str(); - } - - if (!rule_json.contains("shape")) { - ss << "Missing 'shape' field in rule json."; - return ss.str(); - } - - auto shape_json = rule_json["shape"]; - std::vector> inputs; - std::vector> outputs; - - std::string error_msg = ParseJsonToVecVecString(shape_json["inputs"], inputs); - if (!error_msg.empty()) { - ss << "Invalid 'shape.inputs' field: " << shape_json["inputs"] << " " << error_msg; - return ss.str(); - } - error_msg = ParseJsonToVecVecString(shape_json["outputs"], outputs); - if (!error_msg.empty()) { - ss << "Invalid 'shape.outputs' field: " << shape_json["outputs"] << " " << error_msg; - return ss.str(); - } - std::map symbol_defs; - error_msg = GetInputSymbolDefs(inputs, symbol_defs); - if (!error_msg.empty()) { - ss << "Error parsing input symbols: " << error_msg; - return ss.str(); - } - error_msg = GetOutputSymbolTensors(outputs, symbol_defs, symbols_, symbol_tensors_); - if (!error_msg.empty()) { - ss << "Error parsing output tensors: " << error_msg; - return ss.str(); - } - return ss.str(); - } - - void CodegenInferShape(std::stringstream &code_ss) const { - code_ss << R"(extern "C" {)"; - code_ss << R"(bool infer_shape(Ctx *ctx) {)" << std::endl; - - for (const auto &symbol : symbols_) { - code_ss << symbol.Codegen() << std::endl; - } - - code_ss << std::endl; - - for (const auto &tensor : symbol_tensors_) { - code_ss << tensor.Codegen() << std::endl; - } - - code_ss << " return true;\n}" << std::endl; - - code_ss << R"(bool infer_shape_on_compile(Ctx *ctx) {)" << std::endl; - for (const auto &tensor : symbol_tensors_) { - code_ss << tensor.CodegenCompileTime() << std::endl; - } - code_ss << " return true;\n}"; - - code_ss << "}"; - } - - private: - std::vector symbols_; - std::vector symbol_tensors_; - - static std::string GetInputSymbolDefs(const std::vector> &inputs, - std::map &symbol_defs) { - for (size_t i = 0; i < inputs.size(); i++) { - const auto &dims = inputs[i]; - for (size_t j = 0; j < dims.size(); j++) { - const auto &dim = dims[j]; - if (dim.empty() || IsNumber(dim)) { - continue; - } - if (!IsSymbol(dim)) { - std::stringstream ss; - ss << "Invalid input[" << i << "].size(" << j << "): " << dim - << ", symbol dimension must start with 's' or 'v' and follow with a number"; - return ss.str(); - } - auto it = symbol_defs.find(dim); - if (it != symbol_defs.end()) { - // 已经存在,记录来源 - it->second.RecordSource(i, j); - } else { - // 新建符号定义 - SymbolDef symbol(dim); - symbol.RecordSource(i, j); - symbol_defs.emplace(dim, std::move(symbol)); - } - } - } - return ""; - } - - static std::string GetOutputSymbolTensors(const std::vector> &outputs, - const std::map &symbol_defs, - std::vector &used_symbol_defs, - std::vector &symbol_tensors) { - std::set used_symbols; - std::stringstream ss; - for (size_t i = 0; i < outputs.size(); i++) { - symbol_tensors.emplace_back(i); - const auto &dims = outputs[i]; - - for (size_t j = 0; j < dims.size(); j++) { - auto &dim = dims[j]; - if (dim.empty()) { - ss << "Invalid output[" << i << "].size(" << j << "): empty dimension"; - return ss.str(); - } - std::string error_msg = ValidateDimExpr(dim, used_symbols); - if (!error_msg.empty()) { - ss << "Invalid dim expr '" << dim << "': " << error_msg; - return ss.str(); - } - symbol_tensors.back().AppendDim(dim); - } - } - - for (const auto &symbol : used_symbols) { - auto it = symbol_defs.find(symbol); - if (it == symbol_defs.end()) { - ss << "Symbol '" << symbol << "' used in output but not defined in inputs"; - return ss.str(); - } - used_symbol_defs.emplace_back(it->second); - } - - return ""; - } - - static std::string ValidateDimExpr(std::string expr, std::set &used_symbols) { - expr.erase(remove_if(expr.begin(), expr.end(), isspace), expr.end()); - - // 2. 定义 token 正则 - // - 函数/变量名: [A-Za-z0-9_]* - // - 运算符: [+*()-,] - const std::regex token_regex(R"([A-Za-z0-9_]*|\+|\-|\*|\(|\)|,)"); - const auto begin = std::sregex_iterator(expr.begin(), expr.end(), token_regex); - const auto end = std::sregex_iterator(); - - std::vector tokens; // 存储匹配到的 token,应当为操作符、操作数、函数名、括号之一 - for (auto it = begin; it != end; ++it) { - if (!it->str().empty()) { - tokens.push_back(it->str()); - } - } - - // 检查是否所有字符都被匹配(防止非法字符) - size_t totalLen = 0U; - for (auto &t : tokens) totalLen += t.size(); - if (totalLen != expr.size()) { - return "Expression contains invalid characters"; - } - - // 3. 遍历 tokens 检查合法性 - std::stack func_stack; - for (size_t i = 0U; i < tokens.size(); i++) { - const std::string &token = tokens[i]; - - if (std::isalpha(token[0])) { - if (i + 1U < tokens.size() && tokens[i + 1U] == "(") { - if (!IsSupportedFunc(token)) { - return "Invalid function: " + token + ", supported [Div, Floor, Ceil, Pow, Mod]"; - } - } else { - used_symbols.insert(token); - } - } else if (token == "(") { - func_stack.emplace("("); - } else if (token == ")") { - if (func_stack.empty()) { - return "Unmatched ')'"; - } - func_stack.pop(); - } else if (IsSupportedOperator(token) || IsNumber(token)) { - // 运算符不做额外语法检查,由C++编译器处理 - } else { - return "Invalid identifier: '" + token + "', expected start with 's' or 'v' and follow with a number"; - } - } - - if (!func_stack.empty()) { - return "Unmatched '('"; - } - - return ""; - } - - static std::string ParseJsonToVecVecString(const Json &json, std::vector> &result) { - if (json.is_null()) { - return ""; - } - if (!json.is_array()) { - return "field must be an array or null."; - } - - for (const auto &dims : json) { - if (dims.is_null()) { - result.emplace_back(); - continue; - } - if (!dims.is_array()) { - return "element must be an array of dimension expressions."; - } - result.emplace_back(); - for (const auto &dim : dims) { - if (dim.is_null()) { - result.back().emplace_back(); - continue; - } - if (!dim.is_string() && !dim.is_number_integer()) { - return "dimension expression must be a string or integer."; - } - result.back().push_back(dim.is_string() ? dim.get() : std::to_string(dim.get())); - } - } - return ""; - } - - static bool IsSymbol(const std::string &token) { - // 符号必须以 's' 或 'v' 开头,后跟数字 - return token.size() > 1 && (token[0] == 's' || token[0] == 'v') && IsNumber(&token[1]); - } - - static bool IsSupportedFunc(const std::string &func) { - static const std::unordered_set kAllowedFuncs = {"Div", "Floor", "Ceil", "Pow", "Mod"}; - return kAllowedFuncs.find(func) != kAllowedFuncs.end(); - } - - static bool IsSupportedOperator(const std::string &op) { - // 支持的运算符 - return op == "+" || op == "-" || op == "*" || op == ","; - } - - static bool IsNumber(const std::string &s) { - try { - size_t idx; - std::stod(s, &idx); - return idx == s.size(); // 必须整个字符串都被解析 - } catch (...) { - return false; - } - } -}; - -/** - * @brief Cpp JIT编译器 - * - * 用于将生成的C++代码编译为内存中的.so,并加载以供调用。 - */ -class CppJitCompiler { - public: - std::string Error() const { - return err_.str(); - } - - std::vector Compile(const std::string &source_code) { - std::vector so_data; - - const int32_t cpp_fd = CreateMemFd("source.cpp"); - const int32_t so_fd = CreateMemFd("output.so"); - if (cpp_fd == -1 || so_fd == -1) { - err_ << "mem fd create failed: " << strerror(errno); - return {}; - } - - ClearCloexec(cpp_fd); - ClearCloexec(so_fd); - - if (!WriteToFd(cpp_fd, source_code)) { - err_ << "write source code to mem fd failed: " << strerror(errno); - return {}; - } - - lseek(cpp_fd, 0, SEEK_SET); - lseek(so_fd, 0, SEEK_SET); - - if (!CompileToSo(cpp_fd, so_fd)) { - return {}; - } - - lseek(so_fd, 0, SEEK_SET); - - char buf[4096]; - ssize_t n; - while ((n = read(so_fd, buf, sizeof(buf))) > 0) { - so_data.insert(so_data.end(), buf, buf + n); - } - - close(cpp_fd); - close(so_fd); - return so_data; - } - - void *Load(const std::vector &so_binary) { - static std::atomic loaded{0}; - - char tmp_filename[256] = {}; - // make sure the filename is unique for disable cache for dlopen - const std::string filename = "/tmp/temp_so" + std::to_string(loaded++) + "XXXXXX"; - if (snprintf_s(tmp_filename, sizeof(tmp_filename), filename.size(), "%s", filename.c_str()) < 0) { - err_ << "snprintf file name failed: " << strerror(errno); - return nullptr; - } - - const int32_t fd = mkstemp(tmp_filename); - if (fd == -1) { - err_ << "mkstemp failed: " << strerror(errno); - return nullptr; - } - - const ssize_t written = write(fd, so_binary.data(), so_binary.size()); - if (written != static_cast(so_binary.size())) { - err_ << "write so binary to temp file failed: " << strerror(errno); - close(fd); - unlink(tmp_filename); - return nullptr; - } - - close(fd); - - void *handle = dlopen(tmp_filename, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - err_ << "dlopen failed: " << dlerror(); - unlink(tmp_filename); - return nullptr; - } - - unlink(tmp_filename); - return handle; - } - - private: - std::stringstream err_; - static std::string GetSystemCompiler() { - if (system("g++ --version > /dev/null 2>&1") == 0) { - return "g++"; - } - if (system("gcc --version > /dev/null 2>&1") == 0) { - return "gcc"; - } - return ""; - } - - static int32_t CreateMemFd(const std::string &name) { - return syscall(__NR_memfd_create, name.c_str(), MFD_CLOEXEC); - } - - static void ClearCloexec(const int32_t fd) { - const int32_t flags = fcntl(fd, F_GETFD); - if (flags != -1) { - fcntl(fd, F_SETFD, flags & ~FD_CLOEXEC); - } - } - - static bool WriteToFd(const int32_t fd, const std::string &data) { - size_t written = 0; - while (written < data.size()) { - const ssize_t n = write(fd, data.data() + written, data.size() - written); - if (n <= 0) { - return false; - } - written += n; - } - return true; - } - - static bool WriteToFd(const int32_t fd, const std::vector &data) { - size_t written = 0; - while (written < data.size()) { - const ssize_t n = write(fd, data.data() + written, data.size() - written); - if (n <= 0) { - return false; - } - written += n; - } - return true; - } - - bool CompileToSo(const int32_t input_fd, const int32_t output_fd) { - const std::string input_path = "/proc/self/fd/" + std::to_string(input_fd); - const std::string output_path = "/proc/self/fd/" + std::to_string(output_fd); - - const std::string compiler = GetSystemCompiler(); - if (compiler.empty()) { - err_ << "No C++ compiler found (g++ or gcc) for jit compiling symbol infer"; - return false; - } - - const std::vector args = { - compiler.c_str(), "-x", "c++", "-shared", "-fPIC", "-o", output_path.c_str(), - input_path.c_str(), "-lstdc++", nullptr}; - - const pid_t pid = fork(); - if (pid == 0) { - execvp(compiler.c_str(), const_cast(args.data())); - _exit(1); - } - - int32_t status = 0; - waitpid(pid, &status, 0); - const bool succeed = WIFEXITED(status) && WEXITSTATUS(status) == 0; - if (!succeed) { - err_ << "syntax error"; - } - return succeed; - } -}; - -const std::string kHeader = R"( -#include -#include - -inline double Pow(const double base, const double exp) { return std::pow(base, exp); } -inline double Floor(const double x) { return std::floor(x); } -inline double Div(const double x, const double y) { return x / y; } -inline double Ceil(const double x) { return std::ceil(x); } -inline double Mod(const double a, const double b) { - double r = std::fmod(a, b); - if ((r != 0) && ((b < 0 && r > 0) || (b > 0 && r < 0))) { - r += b; - } - return r; -} - -extern "C" { -int64_t version() { return 1; } -} - -class Ctx { - public: - virtual ~Ctx() = default; - virtual bool GetInputDim(int64_t input, int64_t dim_index, int64_t &dim) = 0; - virtual bool GetInputValue(int64_t input, int64_t offset, int64_t &value) = 0; - virtual bool SetOutputDimNum(int64_t output, int64_t dim_num) = 0; - virtual bool SetOutputDim(int64_t output, int64_t dim_index, int64_t dim) = 0; - virtual void SetError(const char *) = 0; -}; - -#define GET_SYMBOL_DIM(S, INPUT, DIM) \ -int64_t S##_int; \ -if (!ctx->GetInputDim(INPUT, DIM, S##_int)) { \ - ctx->SetError("Failed to get dim sym '" #S "' from input[" #INPUT "], dim: " #DIM); \ - return false; \ -} \ -const double S = static_cast(S##_int); - -#define GET_SYMBOL_VALUE(S, INPUT, DIM) \ -int64_t S##_int; \ -if (!ctx->GetInputValue(INPUT, DIM, S##_int)) { \ - ctx->SetError("Failed to get value sym '" #S "' from input[" #INPUT "], offset: " #DIM); \ - return false; \ -} \ -const double S = static_cast(S##_int); - -#define SET_OUTPUT_RANK(OUTPUT, RANK) \ -if (!ctx->SetOutputDimNum(OUTPUT, RANK)) { \ - ctx->SetError("Failed to set rank " #RANK " for output[" #OUTPUT "]"); \ - return false; \ -} - -#define SET_OUTPUT_DIM(OUTPUT, INDEX, DIM) \ -if (!ctx->SetOutputDim(OUTPUT, INDEX, DIM)) { \ - ctx->SetError("Failed to set dim " #DIM " for output[" #OUTPUT "], dim: " #INDEX); \ - return false; \ -} -)"; - -/** - * @brief 适用于GertCtx的包装器 - * - * Jit生成InferShape代码时,设计时保证不使用任何本地头文件参与编译,通过运行时的Ctx封装,隔离本地文件依赖。 - */ -class GertContextWrapper final : public ShapeInferenceRule::Ctx { - public: - explicit GertContextWrapper(gert::InferShapeContext *ctx) : ctx_(ctx) {} - - bool GetInputDim(int64_t input, int64_t dim_index, int64_t &dim) override { - const auto shape = ctx_->GetInputShape(input); - if (shape == nullptr) { - return false; - } - dim = shape->GetDim(dim_index); - return true; - } - - bool GetInputValue(int64_t input, int64_t offset, int64_t &value) override { - auto *tensor = ctx_->GetInputTensor(input); - if (tensor == nullptr || tensor->GetAddr() == nullptr) { - return false; - } - if (offset < 0 || offset >= tensor->GetShapeSize()) { - return false; - } - if (tensor->GetDataType() == ge::DT_INT64) { - value = tensor->GetData()[offset]; - } else if (tensor->GetDataType() == ge::DT_INT32) { - value = tensor->GetData()[offset]; - } else if (tensor->GetDataType() == ge::DT_UINT32) { - value = tensor->GetData()[offset]; - } else { - SetError("Only int32, uint32 and int64 are supported for input value tensors"); - return false; - } - return true; - } - - bool SetOutputDimNum(int64_t output, int64_t dim_num) override { - const auto shape = ctx_->GetOutputShape(output); - if (shape == nullptr) { - return false; - } - shape->SetDimNum(dim_num); - return true; - } - - bool SetOutputDim(int64_t output, int64_t dim_index, int64_t dim) override { - const auto shape = ctx_->GetOutputShape(output); - if (shape == nullptr) { - return false; - } - shape->SetDim(dim_index, dim); - return true; - } - - void SetError(const char *msg) override { - if (msg != nullptr) { - error_message_ << msg << std::endl; - } - } - - std::string Error() const { - return error_message_.str(); - } - - private: - gert::InferShapeContext *ctx_ = nullptr; - std::stringstream error_message_; -}; - -template -class Cache { - public: - std::shared_ptr Get(const std::string &key) { - std::lock_guard lock(mtx_); - auto it = cache_.find(key); - if (it != cache_.end()) { - return it->second; - } - return nullptr; - } - - std::shared_ptr GetWithDefault(const std::string &key, const std::shared_ptr &value) { - std::lock_guard lock(mtx_); - return cache_.emplace(key, value).first->second; - } - - private: - std::mutex mtx_; - std::map> cache_; -}; - -Cache g_shape_rule_cache; -Cache g_dtype_rule_cache; -} // namespace - -ShapeInferenceRule::~ShapeInferenceRule() { - if (handle_) { - dlclose(handle_); - handle_ = nullptr; - infer_shape_ = nullptr; - infer_shape_on_compile_ = nullptr; - } -} - -ge::graphStatus ShapeInferenceRule::InferOnRuntime(Ctx *ctx) const { - if (!infer_shape_) { - ctx->SetError("infer_shape function is not set"); - return ge::GRAPH_FAILED; - } - if (!infer_shape_(ctx)) { - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ShapeInferenceRule::InferOnCompile(Ctx *ctx) const { - if (!infer_shape_on_compile_) { - ctx->SetError("infer_shape_on_compile function is not set"); - return ge::GRAPH_FAILED; - } - if (!infer_shape_on_compile_(ctx)) { - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ShapeInferenceRule::InferOnRuntime(gert::InferShapeContext *infer_shape_ctx) const { - GE_ASSERT_NOTNULL(infer_shape_ctx); - auto ctx = GertContextWrapper(infer_shape_ctx); - const ge::graphStatus result = InferOnRuntime(&ctx); - if (result != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Failed infer shape by rule for op %s(%s): %s", infer_shape_ctx->GetNodeName(), - infer_shape_ctx->GetNodeType(), ctx.Error().c_str()); - } - return result; -} - -ge::graphStatus ShapeInferenceRule::InferOnCompile(gert::InferShapeContext *infer_shape_ctx) const { - GE_ASSERT_NOTNULL(infer_shape_ctx); - auto ctx = GertContextWrapper(infer_shape_ctx); - const ge::graphStatus result = InferOnCompile(&ctx); - if (result != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Failed infer shape on compile by rule for op %s(%s): %s", infer_shape_ctx->GetNodeName(), - infer_shape_ctx->GetNodeType(), ctx.Error().c_str()); - } - return result; -} - -std::string InferenceRule::GetInferenceRule(const ge::OpDescPtr &op) { - if (op == nullptr) { - return ""; - } - std::string rule_json; - (void) ge::AttrUtils::GetStr(op, ge::ATTR_NAME_INFER_RULE, rule_json); - return rule_json; -} - -std::shared_ptr ShapeInferenceRule::FromOpDesc(const ge::OpDescPtr &op) { - std::string rule_json; - if (!ge::AttrUtils::GetStr(op, ge::ATTR_NAME_INFER_RULE, rule_json)) { - // Skip log error if op does not with rule - return nullptr; - } - return FromJsonString(rule_json); -} - -std::shared_ptr ShapeInferenceRule::FromJsonString(const std::string &json_str) { - auto cached = g_shape_rule_cache.Get(json_str); - if (cached != nullptr) { - return cached; - } - - const auto rule = std::make_shared(); - RuleJsonParser parser; - const std::string error_msg = parser.ParseJson(json_str); - if (!error_msg.empty()) { - *rule << error_msg; - return g_shape_rule_cache.GetWithDefault(json_str, rule); - } - - std::stringstream gen_code_ss; - parser.CodegenInferShape(gen_code_ss); - - std::stringstream code_ss; - code_ss << kHeader << std::endl; - code_ss << gen_code_ss.str() << std::endl; - - CppJitCompiler compiler; - const auto binary = compiler.Compile(code_ss.str()); - if (binary.empty()) { - *rule << "Failed to compile C++ code to shared object:\n" << gen_code_ss.str() << "\nError: " << compiler.Error(); - return g_shape_rule_cache.GetWithDefault(json_str, rule); - } - return g_shape_rule_cache.GetWithDefault(json_str, std::make_shared(FromCompiledBinary(binary))); -} - -ShapeInferenceRule ShapeInferenceRule::FromCompiledBinary(const std::vector &binary) { - ShapeInferenceRule infer_handle; - CppJitCompiler compiler; - void *handle = compiler.Load(binary); - if (!handle) { - infer_handle << "Failed to load compiled shared object from memory: " << compiler.Error(); - return infer_handle; - } - - infer_handle.handle_ = handle; - infer_handle.infer_shape_ = (InferShapeFunc) dlsym(handle, "infer_shape"); - if (!infer_handle.infer_shape_) { - infer_handle << "dlsym infer_shape failed: " << dlerror(); - return infer_handle; - } - infer_handle.infer_shape_on_compile_ = (InferShapeFunc) dlsym(handle, "infer_shape_on_compile"); - if (!infer_handle.infer_shape_on_compile_) { - infer_handle << "dlsym infer_shape_on_compile failed: " << dlerror(); - return infer_handle; - } - return infer_handle; -} - -ge::graphStatus ShapeInferenceRule::CompileJsonString(const std::string &json_str, std::vector &binary) { - RuleJsonParser parser; - const std::string error_msg = parser.ParseJson(json_str); - if (!error_msg.empty()) { - GELOGE(ge::FAILED, "%s", error_msg.c_str()); - return ge::GRAPH_FAILED; - } - - std::stringstream code_ss; - code_ss << kHeader << std::endl; - parser.CodegenInferShape(code_ss); - - CppJitCompiler compiler; - binary = compiler.Compile(code_ss.str()); - if (binary.empty()) { - GELOGE(ge::FAILED, "Failed to compile C++ code to shared object:%s,\nError:%s", code_ss.str().c_str(), - compiler.Error().c_str()); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus DtypeInferenceRule::InferDtype(gert::InferDataTypeContext *infer_dtype_ctx) const { - GE_ASSERT_NOTNULL(infer_dtype_ctx); - if (!Error().empty()) { - GELOGE(ge::FAILED, "Failed infer dtype by rule for op %s(%s): %s", infer_dtype_ctx->GetNodeName(), - infer_dtype_ctx->GetNodeType(), Error().c_str()); - return ge::GRAPH_FAILED; - } - for (size_t i = 0U; i < dtypes_.size(); i++) { - GE_ASSERT_GRAPH_SUCCESS(infer_dtype_ctx->SetOutputDataType(i, dtypes_[i])); - } - return ge::GRAPH_SUCCESS; -} - -std::shared_ptr DtypeInferenceRule::FromOpDesc(const ge::OpDescPtr &op) { - std::string rule_json; - if (!ge::AttrUtils::GetStr(op, ge::ATTR_NAME_INFER_RULE, rule_json)) { - // Skip log error if op does not with rule - return nullptr; - } - return FromJsonString(rule_json); -} - -std::shared_ptr DtypeInferenceRule::FromJsonString(const std::string &json_str) { - auto cached = g_dtype_rule_cache.Get(json_str); - if (cached != nullptr) { - return cached; - } - - const auto rule = std::make_shared(); - Json rule_json; - try { - rule_json = Json::parse(json_str); - } catch (const std::exception &e) { - *rule << "Error parsing json: " << e.what(); - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - if (!rule_json.contains("dtype")) { - *rule << "Missing 'dtype' field in rule json."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - const auto dtype_json = rule_json["dtype"]; - if (dtype_json.is_null()) { - *rule << "Filed 'dtype' must not be null."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - if (!dtype_json.is_array()) { - *rule << "Field 'dtype' must be an array."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - for (const auto &dtype : dtype_json) { - if (dtype.is_null()) { - *rule << "Element in 'dtype' field must not be null."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - if (!dtype.is_number_integer()) { - *rule << "Element in 'dtype' field must be an integer."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - const int32_t dtype_value = dtype.get(); - if (dtype_value >= ge::DataType::DT_MAX || dtype_value < 0 || dtype_value == ge::DataType::DT_UNDEFINED) { - *rule << "Element " << dtype_value << " in 'dtype' field is out of range [0," << ge::DataType::DT_MAX - << "(DT_MAX)) and cannot be " << ge::DataType::DT_UNDEFINED << "(DT_UNDEFINED)."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - rule->dtypes_.emplace_back(static_cast(dtype_value)); - } - - return g_dtype_rule_cache.GetWithDefault(json_str, rule); -} diff --git a/register/inference_rule.h b/register/inference_rule.h deleted file mode 100644 index 056385a0e5..0000000000 --- a/register/inference_rule.h +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd.|Hisilicon Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_INFERENCE_RULE_H -#define REGISTER_INFERENCE_RULE_H - -#include -#include - -#include "external/exe_graph/runtime/infer_shape_context.h" -#include "external/exe_graph/runtime/infer_datatype_context.h" -#include "graph/op_desc.h" - -/** - * @brief 推导规则基类 - * - * 为了引导原始错误记录在对象上,不分散的打印日志,有助于向用户展示明确报错。 - */ -class InferenceRule { - public: - template - InferenceRule &operator<<(const T &msg) { - err_ << msg; - return *this; - } - - std::string Error() const { - return err_.str(); - } - - bool IsValid() const { - return err_.str().empty(); - } - static std::string GetInferenceRule(const ge::OpDescPtr &op); - - protected: - std::stringstream err_; -}; - -/** - * @brief Shape推导实现类 - * - * 负责从不同类型的输入编译并加载得到Shape推导可执行函数,并与GE数据结构配合工作。 - */ -class ShapeInferenceRule : public InferenceRule { - public: - // Ctx接口定义,供推导函数调用,不依赖任何头文件。实现与用户环境完全隔离。 - class Ctx { - public: - virtual ~Ctx() = default; - - virtual bool GetInputDim(int64_t input, int64_t dim_index, int64_t &dim) = 0; - - virtual bool GetInputValue(int64_t input, int64_t offset, int64_t &value) = 0; - - virtual bool SetOutputDimNum(int64_t output, int64_t dim_num) = 0; - - virtual bool SetOutputDim(int64_t output, int64_t dim_index, int64_t dim) = 0; - - virtual void SetError(const char *) = 0; - }; - - using InferShapeFunc = bool (*)(Ctx *); - - ShapeInferenceRule() : handle_(nullptr), infer_shape_(nullptr), infer_shape_on_compile_(nullptr) {} - ~ShapeInferenceRule(); - ShapeInferenceRule(const ShapeInferenceRule &) = delete; - ShapeInferenceRule &operator=(const ShapeInferenceRule &) = delete; - ShapeInferenceRule &operator=(ShapeInferenceRule &&other) = delete; - ShapeInferenceRule(ShapeInferenceRule &&other) noexcept { - handle_ = other.handle_; - infer_shape_ = other.infer_shape_; - infer_shape_on_compile_ = other.infer_shape_on_compile_; - err_ << other.err_.str(); - other.handle_ = nullptr; - other.infer_shape_ = nullptr; - other.infer_shape_on_compile_ = nullptr; - } - - static std::shared_ptr FromOpDesc(const ge::OpDescPtr &op); - static std::shared_ptr FromJsonString(const std::string &json_str); - - // 编译后的二进制以属性的方式保存在节点上,用于RT2执行时加载 - static ge::graphStatus CompileJsonString(const std::string &json_str, std::vector &binary); - static ShapeInferenceRule FromCompiledBinary(const std::vector &binary); - - ge::graphStatus InferOnRuntime(gert::InferShapeContext *infer_shape_ctx) const; - ge::graphStatus InferOnCompile(gert::InferShapeContext *infer_shape_ctx) const; - - ge::graphStatus InferOnRuntime(Ctx *ctx) const; - ge::graphStatus InferOnCompile(Ctx *ctx) const; - - private: - void *handle_; - InferShapeFunc infer_shape_; - InferShapeFunc infer_shape_on_compile_; -}; - -/** - * @brief Dtype推导实现类 - * - * 负责从不同类型的解析得到Shape推导可执行函数,并与GE图结构配合工作。Dtype推导实现无需编译。 - */ -class DtypeInferenceRule : public InferenceRule { - public: - static std::shared_ptr FromOpDesc(const ge::OpDescPtr &op); - static std::shared_ptr FromJsonString(const std::string &json_str); - - ge::graphStatus InferDtype(gert::InferDataTypeContext *infer_dtype_ctx) const; - - private: - std::vector dtypes_; -}; -#endif diff --git a/register/shape_inference.cc b/register/shape_inference.cc index b321dcedd2..f744629783 100644 --- a/register/shape_inference.cc +++ b/register/shape_inference.cc @@ -18,7 +18,6 @@ #include "graph/utils/transformer_utils.h" #include "register/op_impl_space_registry.h" #include "common/checker.h" -#include "register/inference_rule.h" namespace gert { namespace { @@ -522,88 +521,6 @@ ge::graphStatus UpdateOpDescOutFormat(const ge::OpDescPtr &op_desc, gert::InferF } return ge::GRAPH_SUCCESS; } - - -ge::graphStatus InferShapeByRegisteredFuncOrRule(const OpImplKernelRegistry::OpImplFunctionsV2 *functions, - const ge::OpDescPtr &op_desc, - gert::InferShapeContext *infer_shape_ctx) { - if (functions && functions->infer_shape) { - if (functions->IsOutputShapeDependOnCompute()) { - GELOGD("OpDesc %s(%s) is third class operator", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - (void) ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, - static_cast(ge::DEPEND_SHAPE_RANGE)); - } - GELOGD("Infer shape for %s[%s] by registered func", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return functions->infer_shape(infer_shape_ctx); - } - const auto shape_infer_rule = ShapeInferenceRule::FromOpDesc(op_desc); - if (shape_infer_rule == nullptr) { - REPORT_CALL_ERROR("EZ9999", - "Can not find infer_shape func of node %s[%s]. Please confirm whether the op_proto shared " - "library (.so) has been loaded " - "successfully, and that you have already developed the infer_shape func.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - GELOGE(ge::GRAPH_FAILED, - "Can not find infer_shape func of node %s[%s]. Please confirm whether the op_proto shared library (.so) " - "has been loaded " - "successfully, and that you have already developed the infer_shape func.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return ge::GRAPH_FAILED; - } - if (!shape_infer_rule->IsValid()) { - REPORT_CALL_ERROR( - "EZ9999", - "No infer shape func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), InferenceRule::GetInferenceRule(op_desc).c_str(), - shape_infer_rule->Error().c_str()); - GELOGE(ge::GRAPH_FAILED, - "No infer shape func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), InferenceRule::GetInferenceRule(op_desc).c_str(), - shape_infer_rule->Error().c_str()); - return ge::GRAPH_FAILED; - } - GELOGD("Infer shape for %s[%s] by inference rule", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return shape_infer_rule->InferOnCompile(infer_shape_ctx); -} - -ge::graphStatus InferDtypeByRegisteredFuncOrRule(const OpImplKernelRegistry::OpImplFunctionsV2 *functions, - const ge::OpDescPtr &op_desc, - gert::InferDataTypeContext *infer_dtype_ctx) { - if (functions && functions->infer_datatype) { - GELOGD("Infer dtype for %s[%s] by registered func", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return functions->infer_datatype(infer_dtype_ctx); - } - const auto dtype_infer_rule = DtypeInferenceRule::FromOpDesc(op_desc); - if (dtype_infer_rule == nullptr) { - REPORT_CALL_ERROR("EZ9999", - "Can not find Node %s[%s] custom infer_datatype func. Please confirm whether the op_proto " - "shared library (.so) has been " - "loaded successfully, and that you have already developed the infer_datatype func or marked " - "the T-derivation rules on the IR.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - GELOGE(ge::GRAPH_FAILED, - "Can not find Node %s[%s] custom infer_datatype func. Please confirm whether the op_proto shared library " - "(.so) has been " - "loaded successfully, and that you have already developed the infer_datatype func or marked " - "the T-derivation rules on the IR.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return ge::GRAPH_FAILED; - } - if (!dtype_infer_rule->IsValid()) { - REPORT_CALL_ERROR( - "EZ9999", - "No infer dtype func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), InferenceRule::GetInferenceRule(op_desc).c_str(), - dtype_infer_rule->Error().c_str()); - GELOGE(ge::GRAPH_FAILED, - "No infer dtype func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), InferenceRule::GetInferenceRule(op_desc).c_str(), - dtype_infer_rule->Error().c_str()); - return ge::GRAPH_FAILED; - } - GELOGD("Infer dtype for %s[%s] by inference rule", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return dtype_infer_rule->InferDtype(infer_dtype_ctx); -} } ge::graphStatus InferShapeRangeOnCompile(const ge::Operator &op, const ge::OpDescPtr &op_desc) { @@ -632,15 +549,35 @@ ge::graphStatus InferShapeRangeOnCompile(const ge::Operator &op, const ge::OpDes } ge::graphStatus InferShapeOnCompile(const ge::Operator &op, const ge::OpDescPtr &op_desc) { - const auto *const space_registry = - DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); + const auto *const space_registry = DefaultOpImplSpaceRegistry::GetInstance() + .GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); GE_ASSERT_NOTNULL(space_registry); + const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); + if ((functions == nullptr) || (functions->infer_shape == nullptr)) { + REPORT_CALL_ERROR( + "EZ9999", + "Can not find infer_shape func of node %s[%s]. Please confirm whether the op_proto shared library (.so) has been loaded " + "successfully, and that you have already developed the infer_shape func.", + op_desc->GetNamePtr(), op_desc->GetTypePtr()); + GELOGE(ge::GRAPH_FAILED, + "Can not find infer_shape func of node %s[%s]. Please confirm whether the op_proto shared library (.so) has been loaded " + "successfully, and that you have already developed the infer_shape func.", op_desc->GetNamePtr(), + op_desc->GetTypePtr()); + return ge::GRAPH_FAILED; + } + + if (functions->IsOutputShapeDependOnCompute()) { + GELOGD("OpDesc %s(%s) is third class operator", op_desc->GetNamePtr(), op_desc->GetTypePtr()); + (void)ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, + static_cast(ge::DEPEND_SHAPE_RANGE)); + } + ge::NodeShapeTransUtils transformer(op_desc); - GE_CHK_BOOL_RET_STATUS(transformer.Init(), ge::GRAPH_FAILED, "Failed to init transformer for %s", - op_desc->GetNamePtr()); - GE_CHK_BOOL_RET_STATUS(transformer.CatchFormatAndShape(), ge::GRAPH_FAILED, "Failed to catch format and shape for %s", - op_desc->GetNamePtr()); + GE_CHK_BOOL_RET_STATUS(transformer.Init(), ge::GRAPH_FAILED, + "Failed to init transformer for %s", op_desc->GetNamePtr()); + GE_CHK_BOOL_RET_STATUS(transformer.CatchFormatAndShape(), ge::GRAPH_FAILED, + "Failed to catch format and shape for %s", op_desc->GetNamePtr()); std::vector> inputs_holder; std::vector> outputs_holder; std::vector> ge_tensors_holder; @@ -653,27 +590,19 @@ ge::graphStatus InferShapeOnCompile(const ge::Operator &op, const ge::OpDescPtr GE_ASSERT_GRAPH_SUCCESS(ConstructCompileKernelContextOutputs(op_desc, outputs_holder), "[Construct][InferShapeContextOutputs] failed, op_desc[%s]", op_desc->GetName().c_str()); const auto kernel_context_holder = gert::KernelRunContextBuilder() - .Inputs(GetInputs(op, inputs_holder)) - .Outputs(GetOutputs(outputs_holder)) - .Build(op_desc); + .Inputs(GetInputs(op, inputs_holder)).Outputs(GetOutputs(outputs_holder)).Build(op_desc); auto infer_shape_ctx = reinterpret_cast(kernel_context_holder.context_); - - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - - ret = InferShapeByRegisteredFuncOrRule(functions, op_desc, infer_shape_ctx); + ret = functions->infer_shape(infer_shape_ctx); GE_CHK_STATUS_RET(ret, "[Call][InferShapeV2Func] failed, op_desc[%s], ret[%d]", op_desc->GetName().c_str(), ret); - - GE_ASSERT_GRAPH_SUCCESS(UpdateOpDescOutShape(op_desc, infer_shape_ctx), - "UpdateOpDescOutShape failed, OutputShape is nullptr. op_desc[%s]", - op_desc->GetName().c_str()); + GE_ASSERT_GRAPH_SUCCESS(UpdateOpDescOutShape(op_desc, infer_shape_ctx), "UpdateOpDescOutShape failed, OutputShape is nullptr. op_desc[%s]", op_desc->GetName().c_str()); GE_CHK_BOOL_RET_STATUS(transformer.UpdateFormatAndShape(), ge::GRAPH_FAILED, "Failed to update format and shape for %s", op_desc->GetNamePtr()); return ge::GRAPH_SUCCESS; } ge::graphStatus InferDataTypeOnCompile(const ge::OpDescPtr &op_desc) { - const auto *const space_registry = - DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); + const auto *const space_registry = DefaultOpImplSpaceRegistry::GetInstance() + .GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); if (space_registry == nullptr) { GELOGW("Default space registry has not been initialized!"); if (op_desc->IsSupportSymbolicInferDataType()) { @@ -684,22 +613,33 @@ ge::graphStatus InferDataTypeOnCompile(const ge::OpDescPtr &op_desc) { op_desc->GetNamePtr(), op_desc->GetTypePtr()); return ge::GRAPH_FAILED; } + const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); + if ((functions == nullptr) || + (functions->infer_datatype == nullptr)) { + if (op_desc->IsSupportSymbolicInferDataType()) { + return op_desc->SymbolicInferDataType(); + } + + REPORT_CALL_ERROR( + "EZ9999", + "Can not find Node %s[%s] custom infer_datatype func. Please confirm whether the op_proto shared library (.so) has been " + "loaded successfully, and that you have already developed the infer_datatype func or marked " + "the T-derivation rules on the IR.", + op_desc->GetNamePtr(), op_desc->GetTypePtr()); + GELOGE(ge::GRAPH_FAILED, + "Can not find Node %s[%s] custom infer_datatype func. Please confirm whether the op_proto shared library (.so) has been " + "loaded successfully, and that you have already developed the infer_datatype func or marked " + "the T-derivation rules on the IR.", + op_desc->GetNamePtr(), op_desc->GetTypePtr()); + return ge::GRAPH_FAILED; + } std::vector inputs; std::vector outputs; ConstructDataTypeContextInputs(op_desc, inputs); ConstructDataTypeContextOutputs(op_desc, outputs); const auto kernel_context_holder = gert::KernelRunContextBuilder().Inputs(inputs).Outputs(outputs).Build(op_desc); const auto kernel_context = reinterpret_cast(kernel_context_holder.context_); - - ge::graphStatus ret = ge::GRAPH_FAILED; - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - - if ((!functions || !functions->infer_datatype) && op_desc->IsSupportSymbolicInferDataType()) { - GELOGD("Infer dtype for %s[%s] by ir symbol", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return op_desc->SymbolicInferDataType(); - } - - ret = InferDtypeByRegisteredFuncOrRule(functions, op_desc, kernel_context); + const auto ret = functions->infer_datatype(kernel_context); GE_CHK_STATUS_RET(ret, "[Check][InferDataType] result failed, op_desc[%s], ret[%d]", op_desc->GetName().c_str(), ret); for (size_t i = 0UL; i < op_desc->GetOutputsSize(); i++) { const auto &out_desc = op_desc->MutableOutputDesc(static_cast(i)); diff --git a/tests/ut/register/CMakeLists.txt b/tests/ut/register/CMakeLists.txt index a8f532b624..537b915f2c 100644 --- a/tests/ut/register/CMakeLists.txt +++ b/tests/ut/register/CMakeLists.txt @@ -16,11 +16,10 @@ include_directories(${METADEF_DIR}) include_directories(${METADEF_DIR}/register) file(GLOB_RECURSE REGISTER_UT_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/register/*.cc" ) -file(GLOB_RECURSE FAKER_SRCS CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/faker/*.cc") file(GLOB_RECURSE UTILS_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/graph/common/*.cc" ) add_executable(ut_register - ${REGISTER_UT_FILES} ${UTILS_FILES} ${FAKER_SRCS} + ${REGISTER_UT_FILES} ${UTILS_FILES} ) add_compile_definitions(CMAKE_BINARY_DIR=\"${CMAKE_BINARY_DIR}\") target_compile_options(ut_register PRIVATE @@ -61,6 +60,5 @@ target_link_libraries(ut_register target_include_directories(ut_register PRIVATE ${METADEF_DIR}/tests/ut/graph/common - ${METADEF_DIR}/tests/depends ) diff --git a/tests/ut/register/testcase/inference_rule_unittest.cc b/tests/ut/register/testcase/inference_rule_unittest.cc deleted file mode 100644 index e92db0149b..0000000000 --- a/tests/ut/register/testcase/inference_rule_unittest.cc +++ /dev/null @@ -1,1004 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/operator_reg.h" -#include "graph/debug/ge_log.h" -#include "register/inference_rule.h" -#include "register/shape_inference.h" -#include "register/op_impl_space_registry.h" -#include "register/op_impl_registry_base.h" -#include "tests/depends/faker/kernel_run_context_faker.h" - -using Json = nlohmann::json; -using namespace gert; - -namespace ge { -REG_OP(RuleInferOp) - .DYNAMIC_INPUT(x, TensorType::ALL()) - .DYNAMIC_OUTPUT(y, TensorType::ALL()) - .OP_END_FACTORY_REG(RuleInferOp); -} // namespace ge - -namespace { -class CtxMaker { - public: - CtxMaker() : compile_holder(), runtime_holder(), dtypes_holder() { - json["shape"]["inputs"] = Json::array(); - json["shape"]["outputs"] = Json::array(); - json["dtype"] = Json::array(); - } - - CtxMaker &Input(const Json::array_t &input, const std::initializer_list runtime_input) { - json["shape"]["inputs"].push_back(input); - compile_inputs.emplace_back(NewShape()); - runtime_inputs.emplace_back(NewShape(runtime_input)); - auto &compile_input = compile_inputs.back()->MutableOriginShape(); - compile_input.SetDimNum(runtime_input.size()); - for (size_t i = 0; i < runtime_input.size(); ++i) { - const auto &dim = input[i]; - if (dim.is_string()) { - compile_input.SetDim(i, -1); - } else if (dim.is_number_integer()) { - const int64_t dim_value = dim.get(); - compile_input.SetDim(i, dim_value); - } else { - compile_input.SetDim(i, -3); - } - } - return *this; - } - - CtxMaker &ValueInput(const Json::array_t &input, const std::initializer_list runtime_input, - ge::DataType dtype) { - json["shape"]["inputs"].push_back(input); - compile_inputs.emplace_back(NewTensor(runtime_input, dtype)); - runtime_inputs.emplace_back(NewTensor(runtime_input, dtype)); - return *this; - } - - CtxMaker &NullInput() { - json["shape"]["inputs"].push_back(nullptr); - compile_inputs.emplace_back(nullptr); - runtime_inputs.emplace_back(nullptr); - return *this; - } - - CtxMaker &Output(const Json::array_t &output) { - json["shape"]["outputs"].push_back(output); - compile_outputs.emplace_back(NewShape()); - runtime_outputs.emplace_back(NewShape()); - return *this; - } - - CtxMaker &Dtypes(const Json::array_t &dtypes) { - json["dtype"] = dtypes; - output_dtypes.resize(dtypes.size(), ge::DataType::DT_UNDEFINED); - for (auto &output_dtype : output_dtypes) { - ctx_dtypes.emplace_back(&output_dtype); - } - return *this; - } - - std::string Str() const { - return json.dump(); - } - - void Build(bool with_rule = true) { - const auto rule_op = std::make_shared("op"); - rule_op->create_dynamic_input_x(compile_inputs.size()); - rule_op->create_dynamic_output_y(compile_outputs.size()); - for (size_t i = 0; i < compile_inputs.size(); ++i) { - if (compile_inputs[i] == nullptr) { - rule_op->UpdateDynamicInputDesc("x", i, ge::TensorDesc()); - continue; - } - auto &storage_shape = compile_inputs[i]->MutableOriginShape(); - std::vector dims; - dims.reserve(storage_shape.GetDimNum()); - for (size_t j = 0; j < storage_shape.GetDimNum(); ++j) { - dims.push_back(storage_shape.GetDim(j)); - } - rule_op->UpdateDynamicInputDesc("x", i, ge::TensorDesc(ge::Shape(dims), ge::FORMAT_ND, ge::DT_FLOAT16)); - } - desc = ge::OpDescUtils::GetOpDescFromOperator(*rule_op); - if (with_rule) { - ge::AttrUtils::SetStr(desc, "_inference_rule", Str()); - } - op = rule_op; - - std::vector inputs; - std::vector outputs; - inputs.reserve(compile_inputs.size()); - for (auto &input : compile_inputs) { - inputs.emplace_back(input); - } - outputs.reserve(compile_outputs.size()); - for (auto &output : compile_outputs) { - outputs.emplace_back(output); - } - - compile_holder = InferShapeContextFaker() - .IrInputNum(inputs.size()) - .NodeIoNum(inputs.size(), outputs.size()) - .InputShapes(inputs) - .OutputShapes(outputs) - .Build(); - - std::vector rt_inputs; - std::vector rt_outputs; - rt_inputs.reserve(runtime_inputs.size()); - for (auto &input : runtime_inputs) { - rt_inputs.emplace_back(input); - } - rt_outputs.reserve(runtime_outputs.size()); - for (auto &output : runtime_outputs) { - rt_outputs.emplace_back(output); - } - - runtime_holder = InferShapeContextFaker() - .IrInputNum(rt_inputs.size()) - .NodeIoNum(rt_inputs.size(), rt_outputs.size()) - .InputShapes(rt_inputs) - .OutputShapes(rt_outputs) - .Build(); - - dtypes_holder = InferDataTypeContextFaker() - .IrInputNum(rt_inputs.size()) - .NodeIoNum(rt_inputs.size(), rt_outputs.size()) - .OutputDataTypes(ctx_dtypes) - .Build(); - } - - InferShapeContext *CompileCtx() { - return compile_holder.GetContext(); - } - - InferShapeContext *RuntimeCtx() { - return runtime_holder.GetContext(); - } - - InferDataTypeContext *DtypeCtx() { - return dtypes_holder.GetContext(); - } - - ge::OpDescPtr OpDesc() const { - return desc; - } - - ge::Operator &Operator() const { - return *op; - } - - StorageShape *NewShape() { - holders.emplace_back(std::make_shared()); - return holders.back().get(); - } - - StorageShape *NewTensor(const std::initializer_list &runtime_input, ge::DataType dtype) { - values.emplace_back(std::shared_ptr(malloc(sizeof(int64_t) * runtime_input.size()), std::free)); - auto shape = StorageShape({static_cast(runtime_input.size())}, {static_cast(runtime_input.size())}); - tensor_holders.emplace_back(std::make_shared(shape, StorageFormat(), kOnHost, dtype, values.back().get())); - if (dtype == ge::DT_INT32) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = static_cast(dim); - } - } else if (dtype == ge::DT_INT64) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = dim; - } - } else if (dtype == ge::DT_UINT32) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = static_cast(dim); - } - } - return reinterpret_cast(tensor_holders.back().get()); - } - - StorageShape *NewShape(const std::initializer_list &runtime_input) { - holders.emplace_back(std::make_shared(runtime_input, runtime_input)); - return holders.back().get(); - } - - Json json; - std::vector compile_inputs; - std::vector runtime_inputs; - std::vector compile_outputs; - std::vector runtime_outputs; - - std::vector> holders; - FakeKernelContextHolder compile_holder; - FakeKernelContextHolder runtime_holder; - FakeKernelContextHolder dtypes_holder; - - std::vector> values; - std::vector> tensor_holders; - - std::vector ctx_dtypes; - std::vector output_dtypes; - - std::shared_ptr op = nullptr; - ge::OpDescPtr desc = nullptr; -}; -} // namespace - -class InferenceRuleUtest : public testing::Test { - protected: - void SetUp() override { - // construct op impl registry - const auto space_registry = std::make_shared(); - const auto registry_holder = std::make_shared(); - const auto funcs = gert::OpImplRegistry::GetInstance().CreateOrGetOpImpl("RuleInferOp"); - registry_holder->AddTypesToImpl("RuleInferOp", funcs); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - } - - void TearDown() override {} - - static std::string ShapeEqual(Shape *shape, std::initializer_list dims) { - std::stringstream ss; - if (shape == nullptr) { - return "shape == nullptr"; - } - if (shape->GetDimNum() != dims.size()) { - ss << "dim num not equal, expect " << dims.size() << ", got " << shape->GetDimNum(); - return ss.str(); - } - for (size_t i = 0; i < dims.size(); ++i) { - if (shape->GetDim(i) != *(dims.begin() + i)) { - ss << "dim[" << i << "] not equal, expect " << *(dims.begin() + i) << ", got " << shape->GetDim(i); - return ss.str(); - } - } - return ""; - } - - static std::string ShapeEqual(const ge::GeShape &shape, std::initializer_list dims) { - std::stringstream ss; - if (shape.GetDimNum() != dims.size()) { - ss << "dim num not equal, expect " << dims.size() << ", got " << shape.GetDimNum(); - return ss.str(); - } - for (size_t i = 0; i < dims.size(); ++i) { - if (shape.GetDim(i) != *(dims.begin() + i)) { - ss << "dim[" << i << "] not equal, expect " << *(dims.begin() + i) << ", got " << shape.GetDim(i); - return ss.str(); - } - } - return ""; - } -}; - -TEST_F(InferenceRuleUtest, BasicDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32}), ""); -} - -TEST_F(InferenceRuleUtest, MultiDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", "s1"}, {32, 64}).Output({"s1", "s0"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, -1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {64, 32}), ""); -} - -TEST_F(InferenceRuleUtest, DimSymbolWithFunctionVertical) { - CtxMaker ctx_maker; - int64_t s0 = 32; - int64_t s1 = 64; - // "+", "-", "*", "Div", "Floor", "Ceil", "Pow", "Mod" - ctx_maker.Input({"s0", "s1"}, {s0, s1}) - .Output({"s1+s0"}) - .Output({"s1-s0"}) - .Output({"s1*s0"}) - .Output({"Div(s1,s0)"}) - .Output({"Floor(Div(s1,3))"}) - .Output({"Ceil(Div(s1,3))"}) - .Output({"Pow(s0,2)"}) - .Output({"Mod(s1,7)"}) - .Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(1), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(2), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(3), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(4), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(5), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(6), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(7), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {s1 + s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(1), {s1 - s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(2), {s1 * s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(3), {s1 / s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(4), {s1 / 3}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(5), {(s1 + 2) / 3}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(6), {s0 * s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(7), {s1 % 7}), ""); -} - -TEST_F(InferenceRuleUtest, DimSymbolWithFunctionHorizontal) { - CtxMaker ctx_maker; - int64_t s0 = 32; - int64_t s1 = 64; - // "+", "-", "*", "Div", "Floor", "Ceil", "Pow", "Mod" - ctx_maker.Input({"s0", "s1"}, {s0, s1}) - .Output( - {"s1+s0", "s1-s0", "s1*s0", "Div(s1,s0)", "Floor(Div(s1,3))", "Ceil(Div(s1,3))", "Pow(s0,2)", "Mod(s1,7)"}) - .Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, -1, -1, -1, -1, -1, -1, -1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), - {s1 + s0, s1 - s0, s1 * s0, s1 / s0, s1 / 3, (s1 + 2) / 3, s0 * s0, s1 % 7}), - ""); -} - -TEST_F(InferenceRuleUtest, StaticDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Output({"128", "32+24"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {128, 56}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {128, 56}), ""); -} - -TEST_F(InferenceRuleUtest, NullDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", nullptr, "s1"}, {32, 20, 24}).Output({"s0+s1"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {56}), ""); -} - -TEST_F(InferenceRuleUtest, RepeatDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", "s0"}, {32, 32}).Input({"s1"}, {24}).Input({"s1"}, {24}).Output({"s0+s1"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {56}), ""); -} - -TEST_F(InferenceRuleUtest, SymbolMixStrAndIntAndNull) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", 128, "s1", nullptr, "s3", "24"}, {4, 128, 8, 0, 16, 24}) - .Output({"s1", "128", 32, "128+32"}) - .Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, 128, 32, 160}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {8, 128, 32, 160}), ""); -} - -TEST_F(InferenceRuleUtest, SymbolWithNullInput) { - CtxMaker ctx_maker; - ctx_maker.NullInput().Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolBasic) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0"}, {32}, ge::DT_INT32).Output({"v0+3"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {35}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolMultiDtype) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0"}, {32}, ge::DT_INT32) - .ValueInput({"v1"}, {24}, ge::DT_UINT32) - .ValueInput({"v2"}, {8}, ge::DT_INT64) - .Output({"v0+v1+v2"}) - .Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 24 + 8}), ""); -} - -TEST_F(InferenceRuleUtest, MultiValueSymbol) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0", "v2", "v1"}, {32, 2, 6}, ge::DT_INT32).Output({"v0+v1+v2"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 2 + 6}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolMixNull) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0", nullptr, "v1"}, {32, 2, 6}, ge::DT_INT32).Output({"v0+v1"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 6}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolMixDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", "s1"}, {3, 4}) - .ValueInput({"v0", nullptr, "v1"}, {32, 2, 6}, ge::DT_INT32) - .Output({"v0+s0", "v1+s1"}) - .Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, -1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 3, 6 + 4}), ""); -} - -TEST_F(InferenceRuleUtest, CompileAndLoadSucceed) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - std::vector binary; - ASSERT_EQ(ShapeInferenceRule::CompileJsonString(ctx_maker.Str(), binary), ge::GRAPH_SUCCESS); - const auto handle = ShapeInferenceRule::FromCompiledBinary(binary); - ASSERT_EQ(handle.Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle.InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle.InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32}), ""); -} - -TEST_F(InferenceRuleUtest, OutputWithUndefinedSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s1"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Symbol 's1' used in output but not defined in inputs"); -} - -TEST_F(InferenceRuleUtest, InputIsNotRawSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"t0"}, {32}).Output({"t1"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing input symbols: Invalid input[0].size(0): t0, symbol dimension must start with 's' or 'v' " - "and follow with a number"); -} - -TEST_F(InferenceRuleUtest, InputIsNotSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0+2"}, {32}).Output({"s0+2"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing input symbols: Invalid input[0].size(0): s0+2, symbol dimension must start with 's' or 'v' " - "and follow with a number"); -} - -TEST_F(InferenceRuleUtest, NoShapeFiled) { - const auto handle = ShapeInferenceRule::FromJsonString("{}"); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Missing 'shape' field in rule json."); -} - -TEST_F(InferenceRuleUtest, InputsFormatError) { - { - Json json; - json["shape"]["inputs"] = 3; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.inputs' field: 3 field must be an array or null."); - } - - { - Json json; - json["shape"]["inputs"] = {3}; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.inputs' field: [3] element must be an array of dimension expressions."); - } - - { - Json json; - json["shape"]["inputs"] = {{2.5}}; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Invalid 'shape.inputs' field: [[2.5]] dimension expression must be a string or integer."); - } -} - -TEST_F(InferenceRuleUtest, OutputsFormatError) { - { - Json json; - json["shape"]["outputs"] = 3; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.outputs' field: 3 field must be an array or null."); - } - - { - Json json; - json["shape"]["outputs"] = {3}; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.outputs' field: [3] element must be an array of dimension expressions."); - } - - { - Json json; - json["shape"]["outputs"] = {{2.5}}; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Invalid 'shape.outputs' field: [[2.5]] dimension expression must be a string or integer."); - } - - { - Json json; - json["shape"]["outputs"] = {{nullptr}}; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid output[0].size(0): empty dimension"); - } - - { - Json json; - json["shape"]["outputs"] = {{""}}; - const auto handle = ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid output[0].size(0): empty dimension"); - } -} - -TEST_F(InferenceRuleUtest, UnsupportedFunction) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"Abc(s0)"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing output tensors: Invalid dim expr 'Abc(s0)': Invalid function: Abc, supported [Div, Floor, " - "Ceil, Pow, Mod]"); -} - -TEST_F(InferenceRuleUtest, UnsupportedOperator) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0 / 3"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing output tensors: Invalid dim expr 's0 / 3': Expression contains invalid characters"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_UnmatchedRightParenthesis) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0)"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid dim expr 's0)': Unmatched ')'"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_UnmatchedLeftParenthesis) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"(s0"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid dim expr '(s0': Unmatched '('"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_InvalidSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"2s0)"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing output tensors: Invalid dim expr '2s0)': Invalid identifier: '2s0', expected start with 's' " - "or 'v' and follow with a number"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_SyntaxError) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0 ++ 2"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Failed to compile C++ code to shared object:\nextern \"C\" {bool infer_shape(Ctx *ctx) {\n " - "GET_SYMBOL_DIM(s0, 0, 0);\n\n SET_OUTPUT_RANK(0, 1);\n SET_OUTPUT_DIM(0, 0, static_cast(s0 " - "++ 2));\n\n return true;\n}\nbool infer_shape_on_compile(Ctx *ctx) {\n SET_OUTPUT_RANK(0, 1);\n " - "SET_OUTPUT_DIM(0, 0, -1);\n\n return true;\n}}\nError: syntax error"); -} - -TEST_F(InferenceRuleUtest, BasicDtypeInfer) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({ge::DataType::DT_BF16}).Build(); - - const auto handle = DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto dtype_ctx = ctx_maker.DtypeCtx(); - - ASSERT_EQ(handle->InferDtype(dtype_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(dtype_ctx->GetOutputDataType(0), ge::DataType::DT_BF16); -} - -TEST_F(InferenceRuleUtest, InvalidDtype1) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({ge::DataType::DT_UNDEFINED}).Build(); - - const auto handle = DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Element 28 in 'dtype' field is out of range [0,42(DT_MAX)) and cannot be 28(DT_UNDEFINED)."); -} - -TEST_F(InferenceRuleUtest, InvalidDtype2) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({ge::DataType::DT_MAX}).Build(); - - const auto handle = DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Element 42 in 'dtype' field is out of range [0,42(DT_MAX)) and cannot be 28(DT_UNDEFINED)."); -} - -TEST_F(InferenceRuleUtest, InvalidDtype3) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({-1}).Build(); - - const auto handle = DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Element -1 in 'dtype' field is out of range [0,42(DT_MAX)) and cannot be 28(DT_UNDEFINED)."); -} - -TEST_F(InferenceRuleUtest, DtypesFormatError) { - { - Json json; - const auto handle = DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Missing 'dtype' field in rule json."); - } - - { - Json json; - json["dtype"] = 3; - const auto handle = DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Field 'dtype' must be an array."); - } - - { - Json json; - json["dtype"] = {nullptr}; - const auto handle = DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Element in 'dtype' field must not be null."); - } - - { - Json json; - json["dtype"] = {2.5}; - const auto handle = DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Element in 'dtype' field must be an integer."); - } - - { - Json json; - json["dtype"] = nullptr; - const auto handle = DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Filed 'dtype' must not be null."); - } -} - -TEST_F(InferenceRuleUtest, JsonFormatError) { - Json json; - const auto handle = DtypeInferenceRule::FromJsonString("{"); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing json: [json.exception.parse_error.101] parse error at line 1, column 2: syntax error while " - "parsing object key - unexpected end of input; expected string literal"); -} - -TEST_F(InferenceRuleUtest, CalledByInferShapeOnCompile) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_EQ(InferShapeOnCompile(ctx_maker.Operator(), desc), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(desc->GetOutputDesc(0).GetShape(), {-1}), ""); -} - -TEST_F(InferenceRuleUtest, CalledByInferShapeOnCompileNoRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(false); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferShapeOnCompile(ctx_maker.Operator(), desc), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CalledByInferShapeOnCompileInvalidRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0+s4"}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferShapeOnCompile(ctx_maker.Operator(), desc), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CalledByInferDtypeOnCompile) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Dtypes({ge::DT_FLOAT16}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_EQ(InferDataTypeOnCompile(desc), ge::GRAPH_SUCCESS); - ASSERT_EQ(desc->GetOutputDesc(0).GetDataType(), ge::DT_FLOAT16); -} - -TEST_F(InferenceRuleUtest, CalledByInferDtypeOnCompileNoRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Dtypes({ge::DT_FLOAT16}).Build(false); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferDataTypeOnCompile(desc), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CalledByInferDtypeOnCompileInvalidRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Dtypes({ge::DT_UNDEFINED}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferDataTypeOnCompile(desc), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CalledByInvalidDimCtx) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - { - CtxMaker ctx_bug; - ctx_bug.Build(); - - const auto compile_ctx = ctx_bug.CompileCtx(); - ASSERT_NE(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } - - { - CtxMaker ctx_bug; - ctx_bug.Input({"s0"}, {32}).Build(); - - const auto compile_ctx = ctx_bug.CompileCtx(); - ASSERT_NE(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } -} - -TEST_F(InferenceRuleUtest, CalledByInvalidValueCtx) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0", "v1"}, {32, 24}, ge::DT_INT32).Output({"v1"}).Build(); - - const auto handle = ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - { - CtxMaker ctx_bug; - ctx_bug.Build(); - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } - - { - CtxMaker ctx_bug; - ctx_bug.ValueInput({"v0"}, {32}, ge::DT_INT32).Output({"v0"}).Build(); - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } - - { - CtxMaker ctx_bug; - ctx_bug.ValueInput({"v0, v1"}, {32, 24}, ge::DT_INT16).Output({"v1"}).Build(); - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } -} - -TEST_F(InferenceRuleUtest, CompileInvalidJsonStrOrCode) { - std::vector binary; - ASSERT_NE(ShapeInferenceRule::CompileJsonString("{", binary), ge::GRAPH_SUCCESS); - - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0 ++ 2"}).Build(); - ASSERT_NE(ShapeInferenceRule::CompileJsonString(ctx_maker.Str(), binary), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CallInvalidRule) { - { - const auto rule = ShapeInferenceRule::FromJsonString("{"); - - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - ASSERT_NE(rule->InferOnCompile(ctx_maker.CompileCtx()), ge::GRAPH_SUCCESS); - ASSERT_NE(rule->InferOnRuntime(ctx_maker.RuntimeCtx()), ge::GRAPH_SUCCESS); - } - - { - const auto rule = DtypeInferenceRule::FromJsonString("{"); - - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - ASSERT_NE(rule->InferDtype(ctx_maker.DtypeCtx()), ge::GRAPH_SUCCESS); - } -} - -TEST_F(InferenceRuleUtest, JustForCoverage) { - auto handle = ShapeInferenceRule::FromCompiledBinary({}); - ASSERT_NE(handle.Error(), ""); - - ASSERT_TRUE(ShapeInferenceRule::GetInferenceRule(nullptr).empty()); -} \ No newline at end of file diff --git a/tests/ut/register/testcase/shape_inference_unittest.cc b/tests/ut/register/testcase/shape_inference_unittest.cc index 8e2d27c2d3..550cd0b694 100644 --- a/tests/ut/register/testcase/shape_inference_unittest.cc +++ b/tests/ut/register/testcase/shape_inference_unittest.cc @@ -1112,7 +1112,7 @@ TEST_F(ShapeInferenceUT, CallInferV2Func_NoInferShape_failed) { gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); auto status = InferShapeOnCompile(op, op_desc); - ASSERT_NE(status, ge::GRAPH_SUCCESS); + ASSERT_EQ(status, ge::GRAPH_FAILED); } TEST_F(ShapeInferenceUT, CallInferFormatFunc_OptionalInput) { -- Gitee