diff --git a/inc/external/exe_graph/runtime/stride.h b/inc/external/exe_graph/runtime/stride.h
new file mode 100644
index 0000000000000000000000000000000000000000..ecc532bbb36e929afa7f7704841f4f83b14047ec
--- /dev/null
+++ b/inc/external/exe_graph/runtime/stride.h
@@ -0,0 +1,186 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef METADEF_CXX_INC_EXE_GRAPH_STRIDE_H_
+#define METADEF_CXX_INC_EXE_GRAPH_STRIDE_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <initializer_list>
+#include <limits>
+#include <type_traits>
+#include "utils/extern_math_util.h"
+
+namespace gert {
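+/**
+ * Stride holds the per-dimension strides of a tensor view.
+ * Illustrative example (not taken from this change): for a contiguous tensor of
+ * shape {2, 3, 4} the element strides are {12, 4, 1}; a permuted, non-contiguous
+ * view of the same storage might instead carry strides {1, 4, 12} while the
+ * underlying data stays untouched.
+ */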
+struct Stride {
+ public:
+  static constexpr size_t kMaxDimNum = 25;
+  static constexpr int64_t kInvalidDimValue = std::numeric_limits<int64_t>::min();
+
+ public:
+  /**
+   * Default constructor; a default-constructed Stride has a dim num of 0.
+   */
+  Stride() : dim_num_(0), strides_{0} {
+    (void)memset(reserved_, 0, sizeof(reserved_));  // suppress the MISRA warning on memset
+  }
+
+  /**
+   * Constructs a Stride from dim values, e.g. Stride({2, 4, 8, 16}) creates an instance
+   * with 4 dims whose values are 2, 4, 8 and 16. If more than kMaxDimNum values are
+   * passed in, the instance stays empty.
+   * @param args all dim values of the Stride
+   */
+  Stride(const std::initializer_list<int64_t> &args) : Stride() {
+    if (args.size() > kMaxDimNum) {
+      return;
+    }
+    dim_num_ = args.size();
+    size_t i = 0;
+    for (const int64_t arg : args) {
+      strides_[i++] = arg;
+    }
+  }
+
+  /**
+   * Copy constructor
+   * @param other the source object
+   * For performance, the elements of strides_ beyond dim_num_ are not copied and may hold dirty data.
+   */
+  Stride(const Stride &other) {
+    dim_num_ = other.dim_num_;
+    for (size_t i = 0U; i < dim_num_; ++i) {
+      strides_[i] = other.strides_[i];
+    }
+    (void)memset(reserved_, 0, sizeof(reserved_));  // suppress the MISRA warning on memset
+  }
+
+  /**
+   * Copy assignment
+   * @param other the source object
+   * @return reference to *this
+   * For performance, the elements of strides_ beyond dim_num_ are not copied and may hold dirty data.
+   */
+  Stride &operator=(const Stride &other) {
+    if (&other != this) {
+      dim_num_ = other.dim_num_;
+      for (size_t i = 0U; i < dim_num_; ++i) {
+        strides_[i] = other.strides_[i];
+      }
+    }
+    (void)memset(reserved_, 0, sizeof(reserved_));  // suppress the MISRA warning on memset
+    return *this;
+  }
+
+  /**
+   * Checks equality with another Stride. Two Strides are considered equal when their dim nums
+   * are equal and every dim value within dim num is equal.
+   * @param rht the other Stride object
+   * @return true/false
+   */
+  bool operator==(const Stride &rht) const {
+    if (this->dim_num_ != rht.dim_num_) {
+      return false;
+    }
+    for (size_t i = 0; i < this->dim_num_; i++) {
+      if (this->strides_[i] != rht.strides_[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Checks inequality with another Stride object.
+   * @param rht the other Stride object
+   * @return true/false
+   */
+  bool operator!=(const Stride &rht) const {
+    return !(*this == rht);
+  }
+
+  /**
+   * Gets the dim num.
+   * @return the dim num
+   */
+  size_t GetDimNum() const {
+    return dim_num_;
+  }
+
+  /**
+   * Sets the dim num.
+   * @param dim_num the dim num to set
+   */
+  void SetDimNum(const size_t dim_num) {
+    this->dim_num_ = dim_num;
+  }
+
+  /**
+   * Gets a dim value.
+   * @param idx index of the dim; the caller must ensure the index is valid
+   * @return the dim value; kInvalidDimValue when idx exceeds kMaxDimNum
+   */
+  int64_t GetStride(const size_t idx) const {
+    if (idx >= kMaxDimNum) {
+      return kInvalidDimValue;
+    }
+    return strides_[idx];
+  }
+
+  /**
+   * Gets a dim value.
+   * @param idx index of the dim; the caller must ensure the index is valid
+   * @return the dim value; the behavior is undefined for an invalid idx
+   */
+  const int64_t &operator[](const size_t idx) const {
+    return strides_[idx];
+  }
+
+  /**
+   * Gets a dim value.
+   * @param idx index of the dim; the caller must ensure the index is valid
+   * @return the dim value; the behavior is undefined when idx exceeds kMaxDimNum
+   */
+  int64_t &operator[](const size_t idx) {
+    return strides_[idx];
+  }
+
+  /**
+   * Sets a dim value.
+   * @param idx index of the dim; the caller must ensure the index is valid
+   * @param dim_value the dim value to set
+   */
+  void SetDim(const size_t idx, const int64_t dim_value) {
+    if (idx >= kMaxDimNum) {
+      return;
+    }
+    strides_[idx] = dim_value;
+    this->dim_num_ = (this->dim_num_ < idx + 1U) ? (idx + 1U) : this->dim_num_;  // keep idx within dim_num_
+  }
+
+  /**
+   * Appends a dim value at the end. If appending would exceed the maximum dim num of Stride,
+   * this function does nothing.
+   * @param value the dim value to append
+   * @return reference to *this
+   */
+  Stride &AppendDim(const int64_t value) {
+    if (this->dim_num_ >= kMaxDimNum) {
+      return *this;
+    }
+    strides_[this->dim_num_++] = value;
+    return *this;
+  }
+
+ private:
+  size_t dim_num_;
+  int64_t strides_[kMaxDimNum];
+  uint8_t reserved_[40];  // Reserved field, 32+8, do not directly use when only 8-byte left
+};
+static_assert(std::is_standard_layout<Stride>::value, "The class Stride must be a POD");
+}  // namespace gert
+
+#endif  // METADEF_CXX_INC_EXE_GRAPH_STRIDE_H_
diff --git a/inc/external/exe_graph/runtime/tensor.h b/inc/external/exe_graph/runtime/tensor.h
index 82504595f5994c41879ea62093e3a29f0dd6c5fa..b8dff583d85e5f6d3f40a330d338b7a973d5646d 100644
--- a/inc/external/exe_graph/runtime/tensor.h
+++ b/inc/external/exe_graph/runtime/tensor.h
@@ -16,6 +16,7 @@
 #include "storage_shape.h"
 #include "storage_format.h"
 #include "tensor_data.h"
+#include "stride.h"

 namespace gert {
 using TensorAddress = void *; ///< Tensor地址
@@ -32,14 +33,16 @@ class Tensor {
                 const ge::DataType data_type, TensorAddress addr)
       : Tensor(storage_shape, storage_format, placement, data_type, addr, nullptr) {}
   Tensor(const StorageShape &storage_shape, const StorageFormat &storage_format, ge::DataType data_type)
-      : storage_shape_(storage_shape), storage_format_(storage_format), data_type_(data_type) {
+      : storage_shape_(storage_shape), storage_format_(storage_format), data_type_(data_type),
+        is_view_(false), view_offset_(0) {
     (void) memset(reserved_, 0, sizeof(reserved_));
     (void) memset(reserved_field_, 0, sizeof(reserved_field_));
   }
   Tensor(const StorageShape &storage_shape, const StorageFormat &storage_format, const TensorPlacement placement,
          const ge::DataType data_type, TensorAddress addr, TensorAddrManager manager)
       : storage_shape_(storage_shape), storage_format_(storage_format), data_type_(data_type),
-        tensor_data_(addr, manager, static_cast<size_t>(ge::GetSizeInBytes(GetShapeSize(), data_type_)), placement) {
+        tensor_data_(addr, manager, static_cast<size_t>(ge::GetSizeInBytes(GetShapeSize(), data_type_)), placement),
+        is_view_(false), view_offset_(0) {
     (void) memset(reserved_, 0, sizeof(reserved_));
     (void) memset(reserved_field_, 0, sizeof(reserved_field_));
   }
@@ -288,6 +291,63 @@ class Tensor {
   TensorData &MutableTensorData() {
     return tensor_data_;
   }
+
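+  // Non-contiguous ("view") tensor support: a view tensor reuses another tensor's
+  // storage, and its logical layout is described by the view stride and view offset
+  // stored below. The exact interpretation of the offset is left to the component
+  // that sets it; the accessors here only store and return the values.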
+  /**
+   * Queries whether this tensor is a non-contiguous (view) tensor.
+   * @return whether this tensor is a view
+   */
+  bool IsView() const {
+    return is_view_;
+  }
+
+  /**
+   * Gets the view shape.
+   * @return read-only view shape
+   */
+  const Shape &GetViewShape() const {
+    return storage_shape_.GetStorageShape();  // the shape can still be queried when the view is not enabled
+  }
+
+  /**
+   * Gets the view stride.
+   * @return read-only view stride
+   */
+  const Stride &GetViewStride() const {
+    return view_stride_;
+  }
+
+  /**
+   * Gets the view offset.
+   * @return view offset
+   */
+  int64_t GetViewOffset() const {
+    return view_offset_;
+  }
+
+  /**
+   * Sets whether this tensor is a non-contiguous (view) tensor.
+   * @param is_view whether this tensor is a view
+   */
+  void SetIsView(const bool is_view) {
+    is_view_ = is_view;
+  }
+
+  /**
+   * Sets the view stride.
+   * @param view_stride the view stride to set
+   */
+  void SetViewStride(Stride &&view_stride) {
+    view_stride_ = std::move(view_stride);
+  }
+
+  /**
+   * Sets the view offset.
+   * @param view_offset the view offset to set
+   */
+  void SetViewOffset(const int64_t view_offset) {
+    view_offset_ = view_offset;
+  }
+
 private:
  static std::unique_ptr<uint8_t[]> NewFollowingTensor(const ge::DataType dt, size_t &total_size) {
    if (ge::AddOverflow(total_size, sizeof(Tensor), total_size)) {
@@ -309,7 +369,10 @@ class Tensor {
   uint8_t reserved_[4]; // Reserved field, 4-byte aligned
   ge::DataType data_type_;
   TensorData tensor_data_;
-  uint8_t reserved_field_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left
+  bool is_view_;
+  uint8_t reserved_field_[39]; // Reserved field (39 = 40 - sizeof(is_view_)), do not directly use when only 8-byte left
+  Stride view_stride_;
+  int64_t view_offset_;
 };
 static_assert(std::is_standard_layout<Tensor>::value, "The class Tensor must be a POD");
 }  // namespace gert
diff --git a/inc/external/exe_graph/runtime/tiling_context.h b/inc/external/exe_graph/runtime/tiling_context.h
index a8a30ecf1e0b0c6c77bee49e2edc0d9ce7b2adff..2b021f2434c323f214eb59a6a30ca11900b9b74a 100644
--- a/inc/external/exe_graph/runtime/tiling_context.h
+++ b/inc/external/exe_graph/runtime/tiling_context.h
@@ -440,6 +440,118 @@ class TilingContext : public ExtendedKernelContext {
     }
     return *p;
   }
+
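+  // View (non-contiguous) tensor queries: the helpers below expose the view
+  // information of an input/output tensor. For a tensor that is not a view, or for
+  // an invalid index, they return false, nullptr or 0 respectively.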
+  /**
+   * Queries whether an input is a non-contiguous (view) tensor.
+   * @param index input index
+   * @return whether the input is a view tensor
+   */
+  bool InputIsView(size_t index) {
+    const auto p = GetInputPointer<Tensor>(index);
+    if (p == nullptr) {
+      return false;
+    }
+
+    return p->IsView();
+  }
+
+  /**
+   * Gets the view shape of an input.
+   * @param index input index
+   * @return the input view shape, or nullptr when the input is not a view
+   */
+  const Shape *GetInputViewShape(size_t index) {
+    const auto p = GetInputPointer<Tensor>(index);
+    if (p == nullptr || !(p->IsView())) {
+      return nullptr;
+    }
+
+    return &(p->GetViewShape());
+  }
+
+  /**
+   * Gets the view stride of an input.
+   * @param index input index
+   * @return the input view stride, or nullptr when the input is not a view
+   */
+  const Stride *GetInputStride(size_t index) {
+    const auto p = GetInputPointer<Tensor>(index);
+    if (p == nullptr || !(p->IsView())) {
+      return nullptr;
+    }
+
+    return &(p->GetViewStride());
+  }
+
+  /**
+   * Gets the view offset of an input.
+   * @param index input index
+   * @return the input view offset, or 0 when the input is not a view
+   */
+  int64_t GetInputViewOffset(size_t index) const {
+    const auto p = GetInputPointer<Tensor>(index);
+    if (p == nullptr || !(p->IsView())) {
+      return 0;
+    }
+
+    return p->GetViewOffset();
+  }
+
+  /**
+   * Queries whether an output is a non-contiguous (view) tensor.
+   * @param index output index
+   * @return whether the output is a view tensor
+   */
+  bool OutputIsView(size_t index) {
+    const auto p = GetOutputPointer<Tensor>(index);
+    if (p == nullptr) {
+      return false;
+    }
+
+    return p->IsView();
+  }
+
+  /**
+   * Gets the view shape of an output.
+   * @param index output index
+   * @return the output view shape, or nullptr when the output is not a view
+   */
+  const Shape *GetOutputViewShape(size_t index) {
+    const auto p = GetOutputPointer<Tensor>(index);
+    if (p == nullptr || !(p->IsView())) {
+      return nullptr;
+    }
+
+    return &(p->GetViewShape());
+  }
+
+  /**
+   * Gets the view stride of an output.
+   * @param index output index
+   * @return the output view stride, or nullptr when the output is not a view
+   */
+  const Stride *GetOutputStride(size_t index) {
+    const auto p = GetOutputPointer<Tensor>(index);
+    if (p == nullptr || !(p->IsView())) {
+      return nullptr;
+    }
+
+    return &(p->GetViewStride());
+  }
+
+  /**
+   * Gets the view offset of an output.
+   * @param index output index
+   * @return the output view offset, or 0 when the output is not a view
+   */
+  int64_t GetOutputViewOffset(size_t index) const {
+    const auto p = GetOutputPointer<Tensor>(index);
+    if (p == nullptr || !(p->IsView())) {
+      return 0;
+    }
+
+    return p->GetViewOffset();
+  }
 };
 static_assert(std::is_standard_layout<TilingContext>::value, "The class TilingContext must be a POD");
 }  // namespace gert
diff --git a/tests/ut/base/testcase/context_builder_unittest.cc b/tests/ut/base/testcase/context_builder_unittest.cc
index 8ac2afa9a0faaf37d41513727b9e4da02cb9de5b..a9a061827fa232b8511ececee2aea8fd1293eec5 100644
--- a/tests/ut/base/testcase/context_builder_unittest.cc
+++ b/tests/ut/base/testcase/context_builder_unittest.cc
@@ -448,6 +448,240 @@ TEST_F(UtestContextBuilder, CreateTilingContextOK) {
   EXPECT_EQ(((int64_t *) (int_vec_vec->Get(1)->GetData()))[2], 6);
 }

+TEST_F(UtestContextBuilder, CreateStrideOK) {
+  Stride stride_default;
+  EXPECT_EQ(stride_default.GetDimNum(), 0);
+  EXPECT_EQ(stride_default.GetStride(0), 0);
+
+  Stride stride_1({1, 2, 3, 4});
+  EXPECT_EQ(stride_1.GetDimNum(), 4);
+  EXPECT_EQ(stride_1.GetStride(0), 1);
+  EXPECT_EQ(stride_1.GetStride(1), 2);
+  EXPECT_EQ(stride_1.GetStride(2), 3);
+  EXPECT_EQ(stride_1.GetStride(3), 4);
+
+  Stride stride_2(stride_1);
+  EXPECT_EQ(stride_1, stride_2);
+
+  Stride stride_3 = stride_1;
+  EXPECT_EQ(stride_3.GetDimNum(), 4);
+  EXPECT_EQ(stride_3.GetStride(0), 1);
+  EXPECT_EQ(stride_3.GetStride(1), 2);
+  EXPECT_EQ(stride_3.GetStride(2), 3);
+  EXPECT_EQ(stride_3.GetStride(3), 4);
+  EXPECT_EQ(stride_1, stride_3);
+  EXPECT_NE(stride_default, stride_3);
+
+  Stride stride_4;
+  stride_4.SetDimNum(4);
+  stride_4.SetDim(0, 1);
+  stride_4.SetDim(1, 2);
+  stride_4.SetDim(2, 3);
+  stride_4.SetDim(3, 4);
+  EXPECT_EQ(stride_1, stride_4);
+  const int64_t dim_0 = stride_4[0];
+  const int64_t dim_1 = stride_4[1];
+  const int64_t dim_2 = stride_4[2];
+  const int64_t dim_3 = stride_4[3];
+  EXPECT_EQ(dim_0, 1);
+  EXPECT_EQ(dim_1, 2);
+  EXPECT_EQ(dim_2, 3);
+  EXPECT_EQ(dim_3, 4);
+  EXPECT_EQ(stride_4[0], 1);
+  EXPECT_EQ(stride_4[1], 2);
+  EXPECT_EQ(stride_4[2], 3);
+  EXPECT_EQ(stride_4[3], 4);
+
+  Stride stride_5({1, 2, 3, 4});
+  stride_5.SetDim(Stride::kMaxDimNum + 1, 100);
+  EXPECT_EQ(stride_5.GetDimNum(), 4);
+  EXPECT_EQ(stride_5.GetStride(Stride::kMaxDimNum + 1), stride_5.kInvalidDimValue);
+  stride_5.AppendDim(5);
+  EXPECT_EQ(stride_5.GetDimNum(), 5);
+  EXPECT_EQ(stride_5.GetStride(4), 5);
+
+  Stride stride_6({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 28, 19, 20, 21, 22, 23, 24, 25, 26});
+  EXPECT_EQ(stride_6.GetDimNum(), 0);
+  EXPECT_EQ(stride_6.GetStride(0), 0);
+
+  Stride stride_7({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 28, 19, 20, 21, 22, 23, 24, 25});
+  EXPECT_EQ(stride_7.GetDimNum(), 25);
+  stride_7.AppendDim(26);
+  EXPECT_EQ(stride_7.GetDimNum(), 25);
+  EXPECT_EQ(stride_7.GetStride(24), 25);
+}
+
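+// Checks that the view information (is_view / view stride / view offset) set on the
+// input and output tensors is visible through both the Tensor accessors and the
+// TilingContext helpers, and that tensors without the view flag report defaults.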
+TEST_F(UtestContextBuilder, CreateTilingContextViewOK) {
+  auto workspace_size_holer = gert::ContinuousVector::Create(4096);
+  auto ws_ptr = reinterpret_cast<gert::ContinuousVector *>(workspace_size_holer.get());
+
+  gert::Shape shape_0{1, 1, 1, 1, 1};
+  gert::Shape shape_1{10, 10, 10, 10, 20};
+  gert::Shape shape_2{1, 1, 1, 1, 1};
+  gert::Shape shape_3{10, 10, 10, 10, 20};
+  gert::Shape resultShape{10, 10, 10, 10, 20};
+
+  gert::StorageShape x({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
+  gert::StorageShape resultIn({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20});
+  gert::StorageShape gamma({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
+  gert::StorageShape beta({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20});
+  gert::StorageShape result({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20});
+
+  uint8_t data_x[1] = {9};
+  // constructor overload 1
+  gert::Tensor x_tensor(x, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost,
+                        ge::DT_FLOAT, (void *) data_x);
+  // constructor overload 2
+  gert::Tensor resultIn_tensor(resultIn, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, ge::DT_FLOAT);
+  // constructor overload 3
+  gert::Tensor gammax_tensor(gamma, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost,
+                             ge::DT_FLOAT, nullptr);
+  // set view information, view enabled
+  gert::Tensor beta1_tensor(beta, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost,
+                            ge::DT_FLOAT, nullptr);
+  beta1_tensor.SetIsView(true);
+  Stride stride1({1, 1, 1, 1, 1});
+  beta1_tensor.SetViewStride(std::move(stride1));
+  beta1_tensor.SetViewOffset(8);
+
+  // set view information, view disabled
+  gert::Tensor beta2_tensor(beta, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost,
+                            ge::DT_FLOAT, nullptr);
+  beta2_tensor.SetIsView(false);
+  Stride stride2({2, 2, 2, 2, 2});
+  beta2_tensor.SetViewStride(std::move(stride2));
+  beta2_tensor.SetViewOffset(8);
+
+  // set view information, view enabled
+  gert::Tensor result1_tensor(result, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()},
+                              TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr);
+  result1_tensor.SetIsView(true);
+  Stride stride3({3, 3, 3, 3, 3});
+  result1_tensor.SetViewStride(std::move(stride3));
+  result1_tensor.SetViewOffset(8);
+
+  // set view information, view disabled
+  gert::Tensor result2_tensor(result, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()},
+                              TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr);
+  result2_tensor.SetIsView(false);
+  Stride stride4({4, 4, 4, 4, 4});
+  result2_tensor.SetViewStride(std::move(stride4));
+  result2_tensor.SetViewOffset(8);
+
+  auto tmp_tiling_data = gert::TilingData::CreateCap(100);
+  uint8_t tmp_compile_info[] = {1, 2, 3, 4, 5, 6, 7};
+  uint8_t tmp_platform_info[] = {1, 2, 3, 4, 5, 6, 7};
+  int32_t deterministic = 1;
+
+  OpTilingContextBuilder ctx_builder;
+  auto holder = ctx_builder.OpName("tmp")
+                    .OpType("DIY")
+                    .IONum(5, 2)
+                    .AppendAttr(int64_t(1))
+                    .AppendAttr(bool(true))
+                    .AppendAttr(float(0.3))
+                    .AppendAttr(AscendString("my_info"))
+                    .AppendAttr(std::vector<bool>({true, false, true}))
+                    .AppendAttr(std::vector<int64_t>({1, 2, 3}))
+                    .AppendAttr(std::vector<float>({0.1, 0.2, 0.3}))
+                    .AppendAttr(std::vector<AscendString>({"123", "234"}))
+                    .AppendAttr(std::vector<std::vector<int64_t>>({{1, 2, 3}, {4, 5, 6}}))
+                    .TilingData(reinterpret_cast<gert::TilingData *>(tmp_tiling_data.get()))
+                    .Workspace(ws_ptr)
+                    .CompileInfo(tmp_compile_info)
+                    .Deterministic(deterministic)
+                    .PlatformInfo(tmp_platform_info)
+                    .InputTensors({&x_tensor, &resultIn_tensor, &gammax_tensor, &beta1_tensor, &beta2_tensor})
+                    .OutputTensors({&result1_tensor, &result2_tensor})
+                    .Build();
+
+  auto ctx = holder.GetContext();
+  EXPECT_NE(ctx, nullptr);
+  auto ctx_compute_node_info = ctx->GetComputeNodeInfo();
+  EXPECT_NE(ctx_compute_node_info, nullptr);
+  EXPECT_EQ(ctx->GetCompileInfo(), tmp_compile_info);
+
+  EXPECT_EQ(ctx->GetInputTensor(0)->IsView(), false);
+  EXPECT_EQ(ctx->GetInputTensor(0)->GetViewStride().GetDimNum(), 0);
+  EXPECT_EQ(ctx->GetInputTensor(0)->GetViewStride().GetStride(0), 0);
+  EXPECT_EQ(ctx->GetInputTensor(0)->GetViewOffset(), 0);
+  EXPECT_EQ(ctx->GetInputTensor(0)->GetViewShape(), shape_0);
+  EXPECT_EQ(ctx->InputIsView(0), false);
+  EXPECT_EQ(ctx->GetInputStride(0), nullptr);
+  EXPECT_EQ(ctx->GetInputViewOffset(0), 0);
+  EXPECT_EQ(ctx->GetInputViewShape(0), nullptr);
+
+  EXPECT_EQ(ctx->GetInputTensor(1)->IsView(), false);
+  EXPECT_EQ(ctx->GetInputTensor(1)->GetViewStride().GetDimNum(), 0);
+  EXPECT_EQ(ctx->GetInputTensor(1)->GetViewStride().GetStride(0), 0);
+  EXPECT_EQ(ctx->GetInputTensor(1)->GetViewOffset(), 0);
+  EXPECT_EQ(ctx->GetInputTensor(1)->GetViewShape(), shape_1);
+  EXPECT_EQ(ctx->InputIsView(1), false);
+  EXPECT_EQ(ctx->GetInputStride(1), nullptr);
+  EXPECT_EQ(ctx->GetInputViewOffset(1), 0);
+  EXPECT_EQ(ctx->GetInputViewShape(1), nullptr);
+
+  EXPECT_EQ(ctx->GetInputTensor(2)->IsView(), false);
+  EXPECT_EQ(ctx->GetInputTensor(2)->GetViewStride().GetDimNum(), 0);
+  EXPECT_EQ(ctx->GetInputTensor(2)->GetViewStride().GetStride(0), 0);
+  EXPECT_EQ(ctx->GetInputTensor(2)->GetViewOffset(), 0);
+  EXPECT_EQ(ctx->GetInputTensor(2)->GetViewShape(), shape_2);
+  EXPECT_EQ(ctx->InputIsView(2), false);
+  EXPECT_EQ(ctx->GetInputStride(2), nullptr);
+  EXPECT_EQ(ctx->GetInputViewOffset(2), 0);
+  EXPECT_EQ(ctx->GetInputViewShape(2), nullptr);
+
+  EXPECT_EQ(ctx->GetInputTensor(3)->IsView(), true);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewStride().GetDimNum(), 5);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewStride().GetStride(0), 1);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewStride().GetStride(1), 1);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewStride().GetStride(2), 1);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewStride().GetStride(3), 1);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewStride().GetStride(4), 1);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewOffset(), 8);
+  EXPECT_EQ(ctx->GetInputTensor(3)->GetViewShape(), shape_3);
+  EXPECT_EQ(ctx->InputIsView(3), true);
+  EXPECT_EQ(ctx->GetInputStride(3)->GetDimNum(), 5);
+  EXPECT_EQ(ctx->GetInputStride(3)->GetStride(0), 1);
+  EXPECT_EQ(ctx->GetInputStride(3)->GetStride(1), 1);
+  EXPECT_EQ(ctx->GetInputStride(3)->GetStride(2), 1);
+  EXPECT_EQ(ctx->GetInputStride(3)->GetStride(3), 1);
+  EXPECT_EQ(ctx->GetInputStride(3)->GetStride(4), 1);
+  EXPECT_EQ(ctx->GetInputViewOffset(3), 8);
+  EXPECT_EQ(*(ctx->GetInputViewShape(3)), shape_3);
+
+  EXPECT_EQ(ctx->GetInputTensor(4)->IsView(), false);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewStride().GetDimNum(), 5);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewStride().GetStride(0), 2);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewStride().GetStride(1), 2);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewStride().GetStride(2), 2);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewStride().GetStride(3), 2);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewStride().GetStride(4), 2);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewOffset(), 8);
+  EXPECT_EQ(ctx->GetInputTensor(4)->GetViewShape(), shape_3);
+  EXPECT_EQ(ctx->InputIsView(4), false);
+  EXPECT_EQ(ctx->GetInputStride(4), nullptr);
+  EXPECT_EQ(ctx->GetInputViewOffset(4), 0);
+  EXPECT_EQ(ctx->GetInputViewShape(4), nullptr);
+
+  EXPECT_EQ(ctx->OutputIsView(0), true);
+  EXPECT_EQ(ctx->GetOutputStride(0)->GetDimNum(), 5);
+  EXPECT_EQ(ctx->GetOutputStride(0)->GetStride(0), 3);
+  EXPECT_EQ(ctx->GetOutputStride(0)->GetStride(1), 3);
+  EXPECT_EQ(ctx->GetOutputStride(0)->GetStride(2), 3);
+  EXPECT_EQ(ctx->GetOutputStride(0)->GetStride(3), 3);
+  EXPECT_EQ(ctx->GetOutputStride(0)->GetStride(4), 3);
+  EXPECT_EQ(ctx->GetOutputViewOffset(0), 8);
+  EXPECT_EQ(*(ctx->GetOutputViewShape(0)), resultShape);
+
+  EXPECT_EQ(ctx->OutputIsView(1), false);
+  EXPECT_EQ(ctx->GetOutputStride(1), nullptr);
+  EXPECT_EQ(ctx->GetOutputViewOffset(1), 0);
+  EXPECT_EQ(ctx->GetOutputViewShape(1), nullptr);
+}
+
 TEST_F(UtestContextBuilder, CreateTilingContextTilingDataSizeOK) {
   auto workspace_size_holer = gert::ContinuousVector::Create(4096);
   auto ws_ptr = reinterpret_cast<gert::ContinuousVector *>(workspace_size_holer.get());