diff --git a/graph/ascendc_ir/core/ascendc_ir.cc b/graph/ascendc_ir/core/ascendc_ir.cc index f78eceea7546309cc6d041eac5fa0846316e027c..d08dfbc01ca34ebf70b2f60d750f9e103f9a710e 100644 --- a/graph/ascendc_ir/core/ascendc_ir.cc +++ b/graph/ascendc_ir/core/ascendc_ir.cc @@ -377,8 +377,16 @@ std::pair AscGraphImpl::DoSplit(const int64_t axis_id, const s if (actual_outer_axis_name.empty()) { actual_outer_axis_name = single_axis.name + outer_suffix; } - const auto inner_size = CreateSizeVar(actual_inner_axis_name + "_size"); - const auto outer_size = ge::sym::Ceiling(single_axis.size / inner_size); + ge::Expression inner_size; + ge::Expression outer_size; + if (single_axis.size == sym::kSymbolOne) { + inner_size = sym::kSymbolOne; + outer_size = sym::kSymbolOne; + } else { + inner_size = CreateSizeVar(actual_inner_axis_name + "_size"); + outer_size = ge::sym::Ceiling(single_axis.size / inner_size); + } + Axis::Type inner_type = is_tile_split ? Axis::kAxisTypeTileInner : Axis::kAxisTypeBlockInner; Axis::Type outer_type = is_tile_split ? Axis::kAxisTypeTileOuter : Axis::kAxisTypeBlockOuter; int64_t outter_id = static_cast(graph_attr_group_ptr->axis.size()); diff --git a/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc b/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc index ea33a607188694d7abe2114edc21fc7b240a67bf..76cc9866e8a0900c78b3d58e796316fbc86d6207 100644 --- a/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc +++ b/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc @@ -139,6 +139,20 @@ TEST_F(UtestAscendCIR, TileSplit) { EXPECT_EQ(outer_axis.type, Axis::kAxisTypeTileOuter); } +TEST_F(UtestAscendCIR, TileSplitSizeOneAxis) { + AscGraph graph("test_graph"); + Axis &s0_axis = graph.CreateAxis("S0", ge::sym::kSymbolOne); + auto split_axis = graph.TileSplit(s0_axis.id); + EXPECT_NE(split_axis.first, nullptr); + EXPECT_NE(split_axis.second, nullptr); + auto &outer_axis = *split_axis.first; + auto &inner_axis = *split_axis.second; + EXPECT_EQ(inner_axis.type, Axis::kAxisTypeTileInner); + EXPECT_EQ(inner_axis.size, ge::sym::kSymbolOne); + EXPECT_EQ(outer_axis.type, Axis::kAxisTypeTileOuter); + EXPECT_EQ(outer_axis.size, ge::sym::kSymbolOne); +} + TEST_F(UtestAscendCIR, MergeAxis) { AscGraph graph("test_graph"); const Expression s0 = graph.CreateSizeVar("s0");