diff --git a/third_party/transformer/src/expand_dimension.h b/third_party/transformer/inc/expand_dimension.h similarity index 94% rename from third_party/transformer/src/expand_dimension.h rename to third_party/transformer/inc/expand_dimension.h index 351d2922eba68eef7785b08485f94ffa68e9564a..aeefa08291fe257e1a2f5c5c144e04270587a0a6 100644 --- a/third_party/transformer/src/expand_dimension.h +++ b/third_party/transformer/inc/expand_dimension.h @@ -9,30 +9,28 @@ #ifndef COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ #define COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ - + #include #include #include "graph/types.h" #include "graph/ge_tensor.h" #include "exe_graph/runtime/shape.h" - +#include "transfer_def.h" + namespace transformer { -/* Pad dimension according to reshape type */ + /* Pad dimension according to reshape type */ bool ExpandDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format, const uint32_t &tensor_index, const std::string &reshape_type, ge::GeShape &shape); - + bool ExpandRangeDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format, const uint32_t &tensor_index, const std::string &reshape_type, std::vector> &ranges); - -const std::set kFormatNZSet = {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ_C0_16, - ge::FORMAT_FRACTAL_NZ_C0_32}; - + class ExpandDimension { public: ExpandDimension(); ~ExpandDimension(); - + static int64_t GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, const size_t &origin_dim_size, const std::string &reshape_type); static bool GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, @@ -61,3 +59,4 @@ private: }; } // namespace transformer #endif // COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ + \ No newline at end of file diff --git a/third_party/transformer/inc/transfer_def.h b/third_party/transformer/inc/transfer_def.h index aa20b8dd6078b43a51b765b7d05db9f7b8e9c147..508ecbb25ca94724a8525c8e5f455413f787901e 100644 --- a/third_party/transformer/inc/transfer_def.h +++ b/third_party/transformer/inc/transfer_def.h @@ -16,10 +16,11 @@ #include "exe_graph/runtime/shape.h" namespace transformer { -namespace { + const size_t ORIGIN_FORMAT_DIM_SIZE = 5; const size_t EXT_AXIS_SIZE = 4; -} +const size_t EXT_AXIS_OP_SIZE = 3; + struct AlignShapeInfo { ge::Format src_format; @@ -46,6 +47,7 @@ using TransferDimsFunc = std::function; using ExtAxisValue = std::array; +using ExtAxisOpValue = std::array; } // namespace transformer #endif // TRANSFORMER_INC_TRANSFER_DEF_H_ diff --git a/third_party/transformer/inc/transfer_range_according_to_format.h b/third_party/transformer/inc/transfer_range_according_to_format.h index 10ae453ab99eba381990e3c998631fd2c50db21a..4341927c65192669d3bb65c8d777624dd5958732 100644 --- a/third_party/transformer/inc/transfer_range_according_to_format.h +++ b/third_party/transformer/inc/transfer_range_according_to_format.h @@ -48,6 +48,9 @@ class RangeTransferAccordingToFormat { static bool GetRangeAccordingToFormat(RangeAndFormat &range_and_format_info); + static bool GetRangeAccordingToFormat(const ExtAxisOpValue &op_value, RangeAndFormat &range_and_format_info); + + // deprecated static bool GetRangeAccordingToFormat(const ge::OpDescPtr &op_desc, RangeAndFormat &range_and_format_info); }; } // namespace fe diff --git a/third_party/transformer/inc/transfer_shape_according_to_format.h b/third_party/transformer/inc/transfer_shape_according_to_format.h index 1608814f80eff8b4f84e0c6dfb1c24ba1767bcd6..f6c716df8c40b835fad79e5242d14eaf307024ba 100644 --- a/third_party/transformer/inc/transfer_shape_according_to_format.h +++ b/third_party/transformer/inc/transfer_shape_according_to_format.h @@ -16,8 +16,6 @@ #include "graph/op_desc.h" #include "platform/platform_info.h" #include "transfer_def.h" -// 先带着上线,改完其他代码仓的依赖再去掉 -#include "transfer_shape_utils.h" namespace transformer { struct CalcShapeExtraAttr { @@ -52,7 +50,7 @@ class ShapeTransferAccordingToFormat { static bool GetShapeAccordingToFormat(ShapeAndFormat &shapeAndFormatInfo); - static bool GetShapeAccordingToFormat(const ge::OpDescPtr &op_desc, ShapeAndFormat &shapeAndFormatInfo); + static bool GetShapeAccordingToFormat(const ExtAxisOpValue &op_value, ShapeAndFormat &shapeAndFormatInfo); static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, const ExtAxisValue &ext_axis, ge::GeShape &shape); @@ -61,12 +59,13 @@ class ShapeTransferAccordingToFormat { const ExtAxisValue &ext_axis, const ge::GeShape &origin_shape, ge::GeShape &shape); static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - gert::Shape &shape, const ge::OpDescPtr op_desc = nullptr, const fe::PlatFormInfos *platform_infos_ptr = nullptr); + gert::Shape &shape, const ExtAxisOpValue &op_value, + const fe::PlatFormInfos *platform_infos_ptr = nullptr); static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - const gert::Shape &origin_shape, gert::Shape &shape, const ge::OpDescPtr op_desc = nullptr); + const gert::Shape &origin_shape, gert::Shape &shape, const ExtAxisOpValue &op_value); - static void InitExtAxisValue(const ge::OpDescPtr &op_desc, ExtAxisValue &ext_axis); + static void InitExtAxisValue(const ExtAxisOpValue &op_value, ExtAxisValue &ext_axis); static bool InitPlatformInfo(); static int64_t GetC0ByDtype(const ge::DataType &data_type); @@ -74,6 +73,19 @@ class ShapeTransferAccordingToFormat { static int64_t GetN0ByDtype(const ge::DataType &data_type); static bool GetAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); static bool TransferDims(const TransferDimsInfo &transfer_dims_info, AxisIndexMapping &axis_index_mapping); + + // deprecated + static bool GetShapeAccordingToFormat(const ge::OpDescPtr &op_desc, ShapeAndFormat &shapeAndFormatInfo); + + static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, + gert::Shape &shape, const ge::OpDescPtr op_desc = nullptr, + const fe::PlatFormInfos *platform_infos_ptr = nullptr); + + static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, + const gert::Shape &origin_shape, gert::Shape &shape, + const ge::OpDescPtr op_desc = nullptr); + + static void InitExtAxisValue(const ge::OpDescPtr &op_desc, ExtAxisValue &ext_axis); }; } // namespace transformer #endif // COMMON_UTILS_TRANSFORMER_INC_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ diff --git a/third_party/transformer/src/axis_constants.h b/third_party/transformer/src/axis_constants.h index b7aaa4f8a84704c785256a9c48de11e04f10ae2e..08be86e0815f7ae76cb4c8fe6fe23c2ee0159f87 100644 --- a/third_party/transformer/src/axis_constants.h +++ b/third_party/transformer/src/axis_constants.h @@ -12,6 +12,8 @@ #include #include +#include +#include "graph/types.h" namespace transformer { extern const size_t DIM_SIZE_TWO; @@ -20,7 +22,7 @@ extern const size_t DIM_SIZE_FIVE; extern const size_t DIM_SIZE_SIX; extern const size_t EXT_INDEX_INPUT_SIZE; -extern const size_t EXT_INDEX_HIDEEN_SIZE; +extern const size_t EXT_INDEX_HIDDEN_SIZE; extern const size_t EXT_INDEX_STATE_SIZE; extern const size_t EXT_INDEX_M0_VAL; @@ -79,6 +81,10 @@ extern const int32_t AXIS_C1HWNCoC0_DIM_H; extern const int32_t AXIS_C1HWNCoC0_DIM_W; extern const int32_t AXIS_C1HWNCoC0_DIM_N; extern const int32_t AXIS_C1HWNCoC0_DIM_Co; + +const std::set kFormatNZSet = {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ_C0_16, + ge::FORMAT_FRACTAL_NZ_C0_32}; + } // namespace transformer #endif // COMMON_UTILS_TRANSFORMER_INC_AXIS_CONSTANTS_H_ diff --git a/third_party/transformer/src/axis_util.cc b/third_party/transformer/src/axis_util.cc index 1e936c074633b2284b5ea1babcaac0b6a8cb7d88..b520d732d7efcb81ad1104541738000b0b40334d 100644 --- a/third_party/transformer/src/axis_util.cc +++ b/third_party/transformer/src/axis_util.cc @@ -20,7 +20,7 @@ const size_t DIM_SIZE_FIVE = 5; const size_t DIM_SIZE_SIX = 6; const size_t EXT_INDEX_INPUT_SIZE = 0; -const size_t EXT_INDEX_HIDEEN_SIZE = 1; +const size_t EXT_INDEX_HIDDEN_SIZE = 1; const size_t EXT_INDEX_STATE_SIZE = 2; const size_t EXT_INDEX_M0_VAL = 3; diff --git a/third_party/transformer/src/axis_util.h b/third_party/transformer/src/axis_util.h index 433b1a4c0eb5f873f0d38a12a69d2800448f98e6..319dcfa6a8a3cb7e4992076d8829e4b3211da816 100644 --- a/third_party/transformer/src/axis_util.h +++ b/third_party/transformer/src/axis_util.h @@ -48,7 +48,7 @@ enum AxisValueType { AXIS_G = 8, AXIS_M0 = 9, AXIS_INPUT_SIZE = 10, - AXIS_HIDEEN_SIZE = 11, + AXIS_HIDDEN_SIZE = 11, AXIS_STATE_SIZE = 12, AXIS_BOTTOM = 13 }; diff --git a/third_party/transformer/src/expand_dimension.cc b/third_party/transformer/src/expand_dimension.cc index 3d78b095e6b9fe22392b5a0bd579105a710cf622..b86fe42ba375f84bbee379ba2d4d102591cc6949 100644 --- a/third_party/transformer/src/expand_dimension.cc +++ b/third_party/transformer/src/expand_dimension.cc @@ -7,7 +7,7 @@ * See LICENSE in the root of the software repository for the full text of the License. * ===================================================================================================================*/ -#include "expand_dimension.h" +#include "../inc/expand_dimension.h" #include #include #include diff --git a/third_party/transformer/src/transfer_range_according_to_format.cc b/third_party/transformer/src/transfer_range_according_to_format.cc index 163d96898146ed2a12b410e36e9a546ffb640574..a29f2bf2db8308f4e4d1314190557d2ce9fca591 100644 --- a/third_party/transformer/src/transfer_range_according_to_format.cc +++ b/third_party/transformer/src/transfer_range_according_to_format.cc @@ -38,8 +38,37 @@ bool RangeTransferAccordingToFormat::GetRangeAccordingToFormat(const ge::OpDescP range_and_format_info.new_range.emplace_back(shape_low.GetDim(i), shape_upper.GetDim(i)); } return res; +} +bool RangeTransferAccordingToFormat::GetRangeAccordingToFormat(const ExtAxisOpValue &op_value, + RangeAndFormat &range_and_format_info) { + /* The default new range is old range */ + std::vector range_upper_old; + std::vector range_low_old; + for (auto &i : range_and_format_info.old_range) { + range_low_old.emplace_back(i.first); + range_upper_old.emplace_back(i.second); + } + + ge::GeShape shape_low(range_low_old); + ge::GeShape shape_upper(range_upper_old); + transformer::ShapeAndFormat shape_and_format_info_low {shape_low, range_and_format_info.old_format, + range_and_format_info.new_format, range_and_format_info.current_data_type}; + transformer::ShapeAndFormat shape_and_format_info_upper {shape_upper, range_and_format_info.old_format, + range_and_format_info.new_format, range_and_format_info.current_data_type}; + ShapeTransferAccordingToFormat shape_transfer; + bool res = (shape_transfer.GetShapeAccordingToFormat(op_value, shape_and_format_info_low) && + shape_transfer.GetShapeAccordingToFormat(op_value, shape_and_format_info_upper)); + if (!res || (shape_low.GetDimNum() != shape_upper.GetDimNum())) { + return false; + } + range_and_format_info.new_range.clear(); + for (size_t i = 0; i < shape_low.GetDimNum(); ++i) { + range_and_format_info.new_range.emplace_back(shape_low.GetDim(i), shape_upper.GetDim(i)); + } + return res; } + bool RangeTransferAccordingToFormat::GetRangeAccordingToFormat(RangeAndFormat &range_and_format_info) { /* The default new range is old range */ std::vector range_upper_old; diff --git a/third_party/transformer/src/transfer_shape_according_to_format.cc b/third_party/transformer/src/transfer_shape_according_to_format.cc index 3c41169dbbeec21e715aa720c670506eb7e46025..c124a393984b6c0d899c569a34d43218bc12b387 100644 --- a/third_party/transformer/src/transfer_shape_according_to_format.cc +++ b/third_party/transformer/src/transfer_shape_according_to_format.cc @@ -52,6 +52,21 @@ bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(const ge::OpDescP return ret; } +bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(const ExtAxisOpValue &op_value, + ShapeAndFormat &shapeAndFormatInfo) { + if (shapeAndFormatInfo.oldShape.IsUnknownDimNum()) { + return true; + } + gert::Shape shape; + GeShapeToRtShape(shapeAndFormatInfo.oldShape, shape); + ExtAxisValue ext_axis; + InitExtAxisValue(op_value, ext_axis); + bool ret = TransferShapeUtils::TransferShape(shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat, + shapeAndFormatInfo.currentDataType, ext_axis, shape); + RtShapeToGeShape(shape, shapeAndFormatInfo.oldShape); + return ret; +} + bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat &shapeAndFormatInfo) { if (shapeAndFormatInfo.oldShape.IsUnknownDimNum()) { return true; @@ -96,6 +111,15 @@ bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_form return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, shape, platform_infos_ptr); } +bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, + const ge::DataType &data_type, gert::Shape &shape, + const ExtAxisOpValue &op_value, + const fe::PlatFormInfos *platform_infos_ptr) { + ExtAxisValue ext_axis; + InitExtAxisValue(op_value, ext_axis); + return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, shape, platform_infos_ptr); +} + bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, const gert::Shape &origin_shape, gert::Shape &shape, const ge::OpDescPtr op_desc) { @@ -104,6 +128,14 @@ bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_form return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, origin_shape, shape); } +bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, + const ge::DataType &data_type, const gert::Shape &origin_shape, + gert::Shape &shape, const ExtAxisOpValue &op_value) { + ExtAxisValue ext_axis; + InitExtAxisValue(op_value, ext_axis); + return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, origin_shape, shape); +} + void ShapeTransferAccordingToFormat::InitExtAxisValue(const ge::OpDescPtr &op_desc, ExtAxisValue &ext_axis) { int64_t input_size = 1; int64_t hidden_size = 1; @@ -115,11 +147,18 @@ void ShapeTransferAccordingToFormat::InitExtAxisValue(const ge::OpDescPtr &op_de } ext_axis[EXT_INDEX_INPUT_SIZE] = input_size; - ext_axis[EXT_INDEX_HIDEEN_SIZE] = hidden_size; + ext_axis[EXT_INDEX_HIDDEN_SIZE] = hidden_size; ext_axis[EXT_INDEX_STATE_SIZE] = state_size; ext_axis[EXT_INDEX_M0_VAL] = kM0DefaultVal; } +void ShapeTransferAccordingToFormat::InitExtAxisValue(const ExtAxisOpValue &op_value, ExtAxisValue &ext_axis) { + ext_axis[EXT_INDEX_INPUT_SIZE] = op_value[EXT_INDEX_INPUT_SIZE]; + ext_axis[EXT_INDEX_HIDDEN_SIZE] = op_value[EXT_INDEX_HIDDEN_SIZE]; + ext_axis[EXT_INDEX_STATE_SIZE] = op_value[EXT_INDEX_STATE_SIZE]; + ext_axis[EXT_INDEX_M0_VAL] = kM0DefaultVal; +} + bool ShapeTransferAccordingToFormat::InitPlatformInfo() { return TransferShapeUtils::InitPlatformInfo(); } diff --git a/third_party/transformer/src/transfer_shape_utils.cc b/third_party/transformer/src/transfer_shape_utils.cc index 238591d135e3823cf207083f79361943f8c5a75c..c1f04d82e58d3eebfbd577aec2bea9479fad5829 100644 --- a/third_party/transformer/src/transfer_shape_utils.cc +++ b/third_party/transformer/src/transfer_shape_utils.cc @@ -198,7 +198,7 @@ bool TransferShapeUtils::TransferShape(const ge::Format &origin_format, const ge axis_value[AXIS_M0] = GetM0ByDtype(data_type); if (primary_format == ge::FORMAT_FRACTAL_ZN_RNN || primary_format == ge::FORMAT_ND_RNN_BIAS) { axis_value[AXIS_INPUT_SIZE] = ext_axis[EXT_INDEX_INPUT_SIZE]; - axis_value[AXIS_HIDEEN_SIZE] = ext_axis[EXT_INDEX_HIDEEN_SIZE]; + axis_value[AXIS_HIDDEN_SIZE] = ext_axis[EXT_INDEX_HIDDEN_SIZE]; axis_value[AXIS_STATE_SIZE] = ext_axis[EXT_INDEX_STATE_SIZE]; } @@ -722,7 +722,7 @@ bool TransferShapeUtils::GetFznRNNShapeByAxisValue(const AxisValue &axis_value, CHECK(origin_shape_size < DIM_SIZE_TWO, GELOGW("ndValue's dim num is less than 2!"), return true); /* check nd shape value */ int64_t k_value = shape.GetDim(origin_shape_size - MINUS_VALUE_TWO); - int64_t hidden_or_state_size = axis_value[AXIS_HIDEEN_SIZE]; + int64_t hidden_or_state_size = axis_value[AXIS_HIDDEN_SIZE]; if (axis_value[AXIS_STATE_SIZE] != RNN_STATE_SIZE_DEFAULT_VALUE) { hidden_or_state_size = axis_value[AXIS_STATE_SIZE]; } @@ -740,9 +740,9 @@ bool TransferShapeUtils::GetFznRNNShapeByAxisValue(const AxisValue &axis_value, } int64_t n_value = shape.GetDim(origin_shape_size - MINUS_VALUE_ONE); - INT64_ZEROCHECK(axis_value[AXIS_HIDEEN_SIZE]); - int64_t n_num = n_value / axis_value[AXIS_HIDEEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDEEN_SIZE], axis_value[AXIS_C0]), n_num); + INT64_ZEROCHECK(axis_value[AXIS_HIDDEN_SIZE]); + int64_t n_num = n_value / axis_value[AXIS_HIDDEN_SIZE]; + MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDDEN_SIZE], axis_value[AXIS_C0]), n_num); shape.SetDim(origin_shape_size - MINUS_VALUE_ONE, n_num); shape.AppendDim(SHAPE_NUMBER_16); shape.AppendDim(axis_value[AXIS_C0]); @@ -750,11 +750,11 @@ bool TransferShapeUtils::GetFznRNNShapeByAxisValue(const AxisValue &axis_value, } bool TransferShapeUtils::GetNDRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - CHECK(axis_value[AXIS_HIDEEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true); + CHECK(axis_value[AXIS_HIDDEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true); size_t size_of_original_vec = shape.GetDimNum(); /* check nd shape value */ - int64_t n_num = shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / axis_value[AXIS_HIDEEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDEEN_SIZE], axis_value[AXIS_C0]), n_num); + int64_t n_num = shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / axis_value[AXIS_HIDDEN_SIZE]; + MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDDEN_SIZE], axis_value[AXIS_C0]), n_num); MUL_OVERFLOW(n_num, axis_value[AXIS_C0], n_num); shape.SetDim(size_of_original_vec - MINUS_VALUE_ONE, n_num); return true; @@ -1114,7 +1114,7 @@ bool TransferShapeUtils::GetFractalZnRnnShape(const ExtAxisValue &ext_axis, cons } /* check nd shape value */ int64_t k_value = origin_shape.GetDim(origin_shape_size - MINUS_VALUE_TWO); - int64_t hidden_or_state_size = ext_axis[EXT_INDEX_HIDEEN_SIZE]; + int64_t hidden_or_state_size = ext_axis[EXT_INDEX_HIDDEN_SIZE]; if (ext_axis[EXT_INDEX_STATE_SIZE] != RNN_STATE_SIZE_DEFAULT_VALUE) { hidden_or_state_size = ext_axis[EXT_INDEX_STATE_SIZE]; } @@ -1130,9 +1130,9 @@ bool TransferShapeUtils::GetFractalZnRnnShape(const ExtAxisValue &ext_axis, cons } int64_t n_value = origin_shape.GetDim(origin_shape_size - MINUS_VALUE_ONE); - INT64_ZEROCHECK(ext_axis[EXT_INDEX_HIDEEN_SIZE]); - int64_t n_num = n_value / ext_axis[EXT_INDEX_HIDEEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDEEN_SIZE], c0), n_num); + INT64_ZEROCHECK(ext_axis[EXT_INDEX_HIDDEN_SIZE]); + int64_t n_num = n_value / ext_axis[EXT_INDEX_HIDDEN_SIZE]; + MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDDEN_SIZE], c0), n_num); shape.AppendDim(n_num); shape.AppendDim(SHAPE_NUMBER_16); shape.AppendDim(c0); @@ -1141,15 +1141,15 @@ bool TransferShapeUtils::GetFractalZnRnnShape(const ExtAxisValue &ext_axis, cons bool TransferShapeUtils::GetNdRnnBiasShape(const ExtAxisValue &ext_axis, const int64_t &c0, const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(ext_axis[EXT_INDEX_HIDEEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true); + CHECK(ext_axis[EXT_INDEX_HIDDEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true); size_t size_of_original_vec = origin_shape.GetDimNum(); shape.SetDimNum(0); for (size_t i = 0; i < size_of_original_vec - MINUS_VALUE_ONE; i++) { shape.AppendDim(origin_shape.GetDim(i)); } /* check nd shape value */ - int64_t n_num = origin_shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / ext_axis[EXT_INDEX_HIDEEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDEEN_SIZE], c0), n_num); + int64_t n_num = origin_shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / ext_axis[EXT_INDEX_HIDDEN_SIZE]; + MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDDEN_SIZE], c0), n_num); MUL_OVERFLOW(n_num, c0, n_num); shape.AppendDim(n_num); return true;