23.11
|
Go to the documentation of this file.
54 if (
dst->tensor_shape().total_size() != 0)
65 void reshape_tensor_per_element(
const Window &window,
const ITensor *
src, ITensor *
dst)
67 const TensorShape &src_shape =
src->info()->tensor_shape();
68 const TensorShape &
dst_shape =
dst->info()->tensor_shape();
70 Iterator dst_it(
dst, window);
74 [&](
const Coordinates &dst_coord)
77 const auto output_ptr =
dst->ptr_to_element(dst_coord);
78 const auto input_ptr =
src->ptr_to_element(src_coord);
80 *
reinterpret_cast<T *
>(output_ptr) = *
reinterpret_cast<T *
>(input_ptr);
85 void reshape_tensor_per_element_selector(
const Window &window,
const ITensor *
src, ITensor *
dst)
87 switch (
src->info()->data_type())
95 reshape_tensor_per_element<uint8_t>(window,
src,
dst);
100 reshape_tensor_per_element<uint16_t>(window,
src,
dst);
105 reshape_tensor_per_element<uint32_t>(window,
src,
dst);
110 reshape_tensor_per_element<uint64_t>(window,
src,
dst);
117 void reshape_tensor_per_row(
const Window &window,
const ITensor *
src, ITensor *
dst)
119 const TensorShape &src_shape =
src->info()->tensor_shape();
120 const TensorShape &
dst_shape =
dst->info()->tensor_shape();
121 Coordinates src_coord{};
122 Coordinates dst_coord{};
124 const auto element_size =
dst->info()->element_size();
125 const auto window_start_x =
static_cast<int>(window.x().start());
126 const auto window_end_x =
static_cast<int>(window.x().end());
127 const auto src_row_size =
static_cast<int>(src_shape[0]);
128 const auto row_size_in_bytes = src_row_size * element_size;
130 auto output_ptr =
dst->ptr_to_element(dst_coord);
131 auto input_ptr =
src->ptr_to_element(src_coord);
136 Iterator dst_it(
dst, win);
143 for (
int x = window_start_x; x < window_end_x; x += src_row_size)
146 output_ptr =
dst->ptr_to_element(dst_coord);
147 input_ptr =
src->ptr_to_element(src_coord);
149 std::memcpy(output_ptr, input_ptr, row_size_in_bytes);
157 void reshape_tensor_per_window(
const Window &window,
const ITensor *
src, ITensor *
dst)
159 Iterator src_it(
src, window);
160 Iterator dst_it(
dst, window);
162 const size_t element_size =
dst->info()->element_size();
163 const auto window_size = window.x().end() - window.x().start();
164 const auto window_size_in_bytes = window_size * element_size;
166 const auto input_ptr = src_it.ptr();
167 const auto output_ptr = dst_it.ptr();
169 std::memcpy(output_ptr, input_ptr, window_size_in_bytes);
179 _reshape_tensor_fn = reshape_tensor_per_element_selector;
183 ICpuKernel::configure(win);
205 return "CpuReshapeKernel";
232 const auto dst_row_size =
static_cast<int>(dst_info->
tensor_shape()[0]);
234 if (!src_has_holes && !dst_has_holes)
244 _reshape_tensor_fn = reshape_tensor_per_window;
253 if (!src_has_holes_in_x && !dst_has_holes_in_x && (src_row_size == dst_row_size))
255 _reshape_tensor_fn = reshape_tensor_per_row;
262 _reshape_tensor_fn = reshape_tensor_per_element_selector;
266 ICPPKernel::configure(win);
@ QSYMM8_PER_CHANNEL
quantized, symmetric per channel fixed-point 8-bit number
@ U64
unsigned 64-bit number
SimpleTensor< float > src
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
@ F64
64-bit floating-point number
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
size_t get_mws(const CPUInfo &platform, size_t thread_count) const override
Return minimum workload size of the relevant kernel.
@ QASYMM8
quantized, asymmetric fixed-point 8-bit number unsigned
@ U16
unsigned 16-bit number
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
@ QSYMM8
quantized, symmetric fixed-point 8-bit number
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
static Status validate(const ITensorInfo *src, const ITensorInfo *dst)
Static function to check if given info will lead to a valid configuration.
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Coordinates index2coords(const TensorShape &shape, int index)
Convert a linear index into n-dimensional coordinates.
ITensor * get_tensor(int id)
Get tensor of a given id from the pac.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
size_t num_dimensions() const override
The number of dimensions of the tensor (rank)
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
const char * name() const override
Name of the kernel.
const ITensor * get_const_tensor(int id) const
Get constant tensor of a given id.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
@ U32
unsigned 32-bit number
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&...iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_funct...
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
void prepare(ITensorPack &tensors)
Prepare the reshape kernel for execution (Only executed once) by calculating max or squashed window a...
@ U8
unsigned 8-bit number
@ S16
signed 16-bit number
@ QASYMM8_SIGNED
quantized, asymmetric fixed-point 8-bit number signed
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
const Window & window() const
The maximum window the kernel can be executed on.
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(...)
Information about executing thread and CPU.
int coords2index(const TensorShape &shape, const Coordinates &coord)
Convert n-dimensional coordinates into a linear index.
void configure(const ITensorInfo *src, ITensorInfo *dst)
Configure kernel for a given list of arguments.
Describe a multidimensional execution window.
@ S64
signed 64-bit number
std::pair< Window, size_t > calculate_squashed_or_max_window(const ITensorInfo &src0, const ITensorInfo &src1)
Copyright (c) 2017-2023 Arm Limited.
@ F16
16-bit floating-point number
@ S32
signed 32-bit number
TensorInfo src_info(src_shape, 1, data_type)
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Store the tensor's metadata.
@ F32
32-bit floating-point number
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
@ UNKNOWN
Unknown data type.
bool has_holes(const ITensorInfo &info)
Check if the tensor has any holes.
static constexpr size_t default_mws
const TensorShape & tensor_shape() const override
Size for each dimension of the tensor.
virtual size_t num_dimensions() const =0
The number of dimensions of the tensor (rank)