42 const int stride_x =
conv_info.stride().first;
43 const int stride_y =
conv_info.stride().second;
44 const int kernel_width = kernel_dims.
width;
45 const int kernel_height = kernel_dims.
height;
48 const int src_width =
src.shape().x();
49 const int src_height =
src.shape().y();
50 const int src_channels =
src.shape().z();
51 const int batches =
src.shape().total_size_upper(3);
52 const int dst_height =
dst.shape().y();
61 for(
int g = 0; g < static_cast<int>(
num_groups); ++g)
63 const int first_group_ch = g * (src_channels /
num_groups);
64 const int last_group_ch = (g + 1) * (src_channels /
num_groups);
66 for(
int yo = 0; yo < dst_height; ++yo)
69 const int xi = (yo % convolved_dims.first) * stride_x;
70 const int yi = (yo / convolved_dims.first) * stride_y;
72 for(
int ci = first_group_ch; ci < last_group_ch; ++ci)
74 for(
int yk = 0; yk < kernel_height; ++yk)
76 for(
int xk = 0; xk < kernel_width; ++xk)
85 dst[dst_idx++] = static_cast<T>(1);
96 const int stride_x =
conv_info.stride().first;
97 const int stride_y =
conv_info.stride().second;
98 const int kernel_width = kernel_dims.
width;
99 const int kernel_height = kernel_dims.
height;
101 const int pad_y =
conv_info.pad().second;
102 const int src_width =
src.shape().y();
103 const int src_height =
src.shape().z();
104 const int src_channels =
src.shape().x();
105 const int batches =
src.shape().total_size_upper(3);
106 const int dst_width =
has_bias ?
dst.shape().x() - 1 :
dst.shape().x();
107 const int dst_height =
dst.shape().y();
113 #pragma omp parallel for schedule(dynamic, 1) collapse(2) 117 for(
int yo = 0; yo < dst_height; ++yo)
120 const int xi = (yo % convolved_dims.first) * stride_x;
121 const int yi = (yo / convolved_dims.first) * stride_y;
123 for(
int ci = 0; ci < src_channels; ++ci)
125 for(
int yk = 0; yk < kernel_height; ++yk)
127 for(
int xk = 0; xk < kernel_width; ++xk)
137 dst[dst_width + yo *
dst.shape().x() +
b *
dst.shape().x() *
dst.shape().y()] = static_cast<T>(1);
143 template <
typename T>
146 switch(
src.data_layout())
T tensor_elem_at(const SimpleTensor< T > &src, Coordinates coord, BorderMode border_mode, T constant_border_value)
void im2col_nchw(const SimpleTensor< T > &src, SimpleTensor< T > &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups)
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
SimpleTensor< float > src
Copyright (c) 2017-2021 Arm Limited.
size_t height
Height of the image region or rectangle.
void im2col(const SimpleTensor< T > &src, SimpleTensor< T > &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int num_groups)
std::pair< unsigned int, unsigned int > scaled_dimensions(int width, int height, int kernel_width, int kernel_height, const PadStrideInfo &pad_stride_info, const Size2D &dilation=Size2D(1U, 1U))
Returns expected width and height of output scaled tensor depending on dimensions rounding mode.
const unsigned int num_groups
Padding and stride information class.
void im2col_nhwc(const SimpleTensor< T > &src, SimpleTensor< T > &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
Num samples, channels, height, width.
bool is_data_type_quantized_asymmetric(DataType dt)
Check if a given data type is of asymmetric quantized type.
Simple tensor object that stores elements in a consecutive chunk of memory.
size_t width
Width of the image region or rectangle.
Class for specifying the size of an image or rectangle.
Num samples, height, width, channels.