56 const std::vector<T> m_pad_row;
59 std::vector<int> m_kernel_y;
60 std::vector<int> m_kernel_x;
62 class column_handler {
67 const T *
const m_input_base;
68 const size_t m_input_stride;
71 const unsigned int m_start_pos;
72 const unsigned int m_start_offset;
75 const unsigned int m_length;
76 const unsigned int m_rounded_stringlen;
81 const column_handler &m_parent;
84 unsigned int m_start_output_y=0;
85 unsigned int m_start_output_x=0;
87 unsigned int m_length_remaining=0;
88 unsigned int m_current_pos=0;
90 unsigned int m_active_height=0;
// Cursor that hands out blocks of input-row pointers for one column range of a
// column_handler.
//
// parent        - the column_handler whose k-range (m_start_pos, m_start_offset,
//                 m_length) this cursor walks; parent.m_parent is the owning convolver.
// start_row     - linear output-pixel index; split below into (output_y, output_x)
//                 using the convolver's output_width.
// active_height - number of entries next_block() writes into row_ptr per call.
//
// NOTE(review): this extract shows no initializer for the reference member m_parent,
// yet m_length_remaining and m_current_pos are initialized from m_parent.* below.
// Confirm the full source initializes m_parent (e.g. `m_parent(parent)`) and declares
// it before those members, since members initialize in declaration order.
// NOTE(review): start_row / output_width divides by zero if output_width == 0;
// presumably excluded by the convolver parameters - confirm.
93 row_handler(
const column_handler &parent,
unsigned int start_row,
unsigned int active_height) :
94 m_convolver(parent.m_parent),
96 m_start_output_y(start_row / m_convolver.m_params.output_width),
97 m_start_output_x(start_row % m_convolver.m_params.output_width),
98 m_length_remaining(m_parent.m_length),
99 m_current_pos(m_parent.m_start_pos),
100 m_active_height(active_height) { }
// True once m_length_remaining has reached zero, i.e. next_block() has consumed
// the whole column range (next_block subtracts out_width from it on each call).
102 bool finished()
const {
103 return (m_length_remaining == 0);
// Fill row_ptr[0 .. m_active_height) with one pointer per output pixel for the
// current kernel position (m_current_pos), starting at (m_start_output_y,
// m_start_output_x).
//
// Each entry points either into the input image (m_parent.m_input_base, offset by
// input_y / input_x and m_parent.m_input_stride) or at the convolver's shared
// padding row (m_pad_row) when the sampled position falls outside the input.
//
// Returns (in_width, offset):
//   in_width - number of real input channels to read from each pointer, starting
//              at `offset`; bounded by input_channels - offset and by the
//              remaining length.
//   offset   - channel offset into the kernel position; m_parent.m_start_offset
//              for the first position of the range, 0 otherwise.
// An early (0, 0) is returned at the top. The guarding condition is not visible in
// this extract; presumably it is `if (finished())` - confirm.
//
// The range is consumed by out_width, not in_width. out_width is bounded by
// m_rounded_stringlen - offset, so any slack between input_channels and
// m_rounded_stringlen is counted as consumed without being read. This presumably
// accounts for the rounding of the string length; verify against the consumer.
// NOTE(review): the declaration of `row` and the update of m_current_pos /
// output_x / output_y are not visible in this extract.
108 std::tuple<unsigned int, unsigned int> next_block(
const T **
const row_ptr) {
110 return std::make_tuple(0, 0);
113 const T *pad_ptr = m_convolver.m_pad_row.data();
// Only the first kernel position of the range may start part-way through its channels.
117 unsigned int offset = (m_current_pos == m_parent.m_start_pos) ? m_parent.m_start_offset : 0;
118 unsigned int in_width = std::min(m_length_remaining,
static_cast<unsigned int>(m_convolver.m_params.input_channels) -
offset);
119 unsigned int out_width = std::min(m_length_remaining, m_parent.m_rounded_stringlen -
offset);
121 unsigned int output_y = m_start_output_y;
122 unsigned int output_x = m_start_output_x;
128 while (row < m_active_height) {
// NOTE(review): output_y / output_x are unsigned int, while m_kernel_y / m_kernel_x
// hold int. If the stride fields are unsigned or int, the usual arithmetic
// conversions make the sum unsigned. A negative kernel offset then wraps and is
// converted back to int. That is well-defined (modular) only since C++20 and
// implementation-defined before it. Confirm the stride types and the intended
// language standard.
129 int input_y = (output_y * m_convolver.m_params.output_stride_h) + m_convolver.m_kernel_y[m_current_pos];
130 int input_x = (output_x * m_convolver.m_params.output_stride_w) + m_convolver.m_kernel_x[m_current_pos];
// Start of input row input_y. It is formed before the bounds check below. If
// input_y < 0 or >= input_height, this is out-of-range pointer arithmetic, which
// is undefined even if never dereferenced. A negative input_y also converts to a
// huge size_t in the multiplication. Check whether the missing lines guard it.
133 const T *base_ptr = m_parent.m_input_base +
134 (input_y * m_convolver.m_params.input_width * m_parent.m_input_stride);
// Below the bottom edge of the input: every remaining entry is padding.
// NOTE(review): if input_height is unsigned, a negative input_y also converts
// to a large value here and takes this branch - confirm the intended handling.
143 if (input_y >= m_convolver.m_params.input_height) {
144 while (row < m_active_height) {
145 row_ptr[row++] = pad_ptr;
// Pad up to the end of the current output row. The condition selecting this
// path (presumably input_y < 0) is not visible in this extract.
153 while (output_x < m_convolver.m_params.output_width && row<m_active_height) {
154 row_ptr[row++] = pad_ptr;
// Left edge: output pixels whose sampled input column is negative get padding.
165 while (row < m_active_height && input_x < 0) {
166 row_ptr[row++] = pad_ptr;
169 input_x+=m_convolver.m_params.output_stride_w;
173 if (output_x == m_convolver.m_params.output_width) {
// Interior: point at column input_x of the current input row.
179 while (row < m_active_height && input_x < m_convolver.m_params.input_width) {
180 row_ptr[row++] = base_ptr + (input_x * m_parent.m_input_stride);
183 input_x+=m_convolver.m_params.output_stride_w;
185 if (output_x == m_convolver.m_params.output_width) {
// Right edge: the remaining pixels of this output row sample past input_width.
191 while (row < m_active_height && output_x < m_convolver.m_params.output_width) {
192 row_ptr[row++] = pad_ptr;
// Consume out_width (not in_width) elements of the column range; see header note.
204 m_length_remaining-=out_width;
206 return std::make_tuple(in_width,
offset);
// Describe the column slice [k_start, k_end) of a flattened index space. In that
// space each kernel position owns a run of rounded_stringlen columns:
//   m_start_pos    = k_start / rounded_stringlen  (first kernel position; used as the
//                    index into the convolver's m_kernel_y / m_kernel_x)
//   m_start_offset = k_start % rounded_stringlen  (column offset within that position)
//   m_length       = k_end - k_start              (total columns to emit)
//
// parent       - owning convolver (parameters, padding row, kernel offsets).
// input_base   - first element of the input image.
// input_stride - distance in elements between consecutive input pixels. row_handler
//                uses it as base + (y * input_width + x) * input_stride.
//
// NOTE(review): rounded_stringlen == 0 divides by zero, and k_end < k_start wraps
// m_length (unsigned int subtraction). Presumably the caller guarantees
// rounded_stringlen > 0 and k_start <= k_end - confirm.
211 column_handler(
const convolver<T> &parent,
const T *input_base,
size_t input_stride,
212 unsigned int k_start,
unsigned int k_end,
unsigned int rounded_stringlen)
213 : m_parent(parent), m_input_base(input_base), m_input_stride(input_stride),
214 m_start_pos(k_start / rounded_stringlen),
215 m_start_offset(k_start % rounded_stringlen),
216 m_length(k_end - k_start),
217 m_rounded_stringlen(rounded_stringlen) { }
// Create a row_handler for this column range. It covers active_height output
// pixels starting at linear output index start_row.
// row_handler holds a `const column_handler &` member (m_parent), presumably bound
// to *this, so this column_handler must outlive the returned handler.
219 row_handler process_rows(
unsigned int start_row,
unsigned int active_height)
const {
220 return row_handler(*
this, start_row, active_height);
226 m_params (params), m_pad_row(params.input_channels, static_cast<T>(params.padding_value)),
227 m_kernel_y(params.kernel_width * params.kernel_height, 0),
228 m_kernel_x(params.kernel_width * params.kernel_height, 0) {
241 unsigned int k_start,
unsigned int k_end,
unsigned int rounded_stringlen)
const {
242 return column_handler(*
this, input_base, input_stride, k_start, k_end, rounded_stringlen);