38 #include "src/core/NEON/kernels/convolution/common/shims.hpp" 51 static const std::array<PermutationVector, 2> permutations2 =
58 static const std::array<PermutationVector, 6> permutations3 =
69 static const std::array<PermutationVector, 24> permutations4 =
99 return (permutations2.end() !=
std::find(permutations2.begin(), permutations2.end(), v)) || (permutations3.end() !=
std::find(permutations3.begin(), permutations3.end(), v))
100 || (permutations4.end() !=
std::find(permutations4.begin(), permutations4.end(), v));
111 if(dst->total_size() != 0)
121 template <
typename T>
122 void run_permute(
const Window &window,
const ITensor *src,
const ITensor *dst,
const PermutationVector &perm)
124 const DataLayout src_layout = src->info()->data_layout();
127 Window window_src = window;
133 window_src.set(
Window::DimX, Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start()));
134 window_src.set(
Window::DimY, Window::Dimension(window.y().start(), window.y().end(), window.y().end() - window.y().start()));
135 window_src.set(
Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), window.z().end() - window.z().start()));
136 window_src.set(3, Window::Dimension(window[3].
start(), window[3].
end(), window[3].
end() - window[3].
start()));
140 Window window_dst(window);
141 const Window::Dimension zero_window = Window::Dimension(0, 0, 0);
142 for(
size_t d = 0; d <= dst->info()->num_dimensions(); ++d)
144 window_dst.set(d, zero_window);
148 Iterator src_it(src, window_src);
149 Iterator dst_it(dst, window_dst);
151 int in_row_stride = 0;
152 int in_col_stride = 0;
153 int in_channel_stride = 0;
154 int in_batch_stride = 0;
164 in_row_stride = src->info()->strides_in_bytes().y() /
sizeof(T);
165 in_channel_stride = src->info()->strides_in_bytes().z() /
sizeof(T);
166 in_batch_stride = src->info()->strides_in_bytes()[3] /
sizeof(T);
167 n_cols = src->info()->tensor_shape().x();
168 n_rows = window_src.y().step();
169 n_channels = src->info()->tensor_shape().z();
170 n_batches = src->info()->tensor_shape()[3];
175 in_col_stride = src->info()->strides_in_bytes().y() /
sizeof(T);
176 in_row_stride = src->info()->strides_in_bytes().z() /
sizeof(T);
177 in_batch_stride = src->info()->strides_in_bytes()[3] /
sizeof(T);
178 n_channels = src->info()->tensor_shape().x();
179 n_cols = window_src.y().step();
180 n_rows = src->info()->tensor_shape().z();
181 n_batches = src->info()->tensor_shape()[3];
194 const int out_channel_stride = dst->info()->strides_in_bytes().x() /
sizeof(T);
195 const int out_col_stride = dst->info()->strides_in_bytes().y() /
sizeof(T);
196 const int out_row_stride = dst->info()->strides_in_bytes().z() /
sizeof(T);
197 const int out_batch_stride = dst->info()->strides_in_bytes()[3] /
sizeof(T);
200 const int idx =
id[0] * out_col_stride +
id[1] * out_row_stride +
id[2] * out_channel_stride;
201 reorder::nchw_to_nhwc(reinterpret_cast<const T *>(src_it.ptr()), reinterpret_cast<T *>(dst_it.ptr()) + idx,
202 n_batches, n_channels, n_rows, n_cols,
203 in_batch_stride, in_channel_stride, in_row_stride,
204 out_batch_stride, out_row_stride, out_col_stride);
211 const int out_col_stride = dst->info()->strides_in_bytes().x() /
sizeof(T);
212 const int out_row_stride = dst->info()->strides_in_bytes().y() /
sizeof(T);
213 const int out_channel_stride = dst->info()->strides_in_bytes().z() /
sizeof(T);
214 const int out_batch_stride = dst->info()->strides_in_bytes()[3] /
sizeof(T);
217 const int idx =
id[0] * out_channel_stride +
id[1] * out_col_stride +
id[2] * out_row_stride;
218 reorder::nhwc_to_nchw(reinterpret_cast<const T *>(src_it.ptr()), reinterpret_cast<T *>(dst_it.ptr()) + idx,
219 n_batches, n_rows, n_cols, n_channels,
220 in_batch_stride, in_row_stride, in_col_stride,
221 out_batch_stride, out_channel_stride, out_row_stride);
229 Strides strides = dst->info()->strides_in_bytes();
230 Strides perm_strides = strides;
232 const int perm_stride_3 = src->info()->num_dimensions() >= 4 ? perm_strides[3] : 0;
235 const int idx =
id[0] * perm_strides[0] +
id[1] * perm_strides[1] +
id[2] * perm_strides[2] +
id[3] * perm_stride_3;
236 *(
reinterpret_cast<T *
>(dst_it.ptr() + idx)) = *(
reinterpret_cast<const T *
>(src_it.ptr()));
263 ICpuKernel::configure(win);
281 switch(
src->info()->element_size())
300 return "CpuPermuteKernel";
virtual size_t num_dimensions() const =0
The number of dimensions of the tensor (rank)
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
const Window & window() const
The maximum window the kernel can be executed on.
TensorShape compute_permutation_output_shape(const ITensorInfo &input, const PermutationVector &perm)
Calculate the permuted shape of an input given a permutation vector.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(...)
void permute_strides(Dimensions< T > &dimensions, const PermutationVector &perm)
Permutes the given dimensions according to the permutation vector.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Strides PermutationVector
Permutation vector.
Store the tensor's metadata.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
void configure(const ITensorInfo *src, ITensorInfo *dst, const PermutationVector &perm)
Configure kernel for a given list of arguments.
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(...)
SimpleTensor< float > src
Copyright (c) 2017-2021 Arm Limited.
virtual void set_valid_region(const ValidRegion &valid_region)=0
Set the valid region of the tensor.
const char * name() const override
Name of the kernel.
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
Class to describe a number of elements in each dimension.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment is empty.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
void end(TokenStream &in, bool &valid)
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
Num samples, channels, height, width.
Strides of an item in bytes.
static constexpr size_t DimY
Alias for dimension 1 also known as Y dimension.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
ITensor * get_tensor(int id)
Get tensor of a given id from the pack.
Information about executing thread and CPU.
static constexpr size_t DimZ
Alias for dimension 2 also known as Z dimension.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Num samples, height, width, channels.
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo *output_stage)
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_function for each element.
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
void set_num_dimensions(size_t num_dimensions)
Set number of dimensions.
static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm)
Static function to check if given info will lead to a valid configuration of CpuPermuteKernel.
Container for valid region of a window.
DataLayout
[DataLayout enum definition]
Describe a multidimensional execution window.
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)