44 template < typename T, typename TB, typename std::enable_if < is_floating_point<T>::value &&is_floating_point<TB>::value,
int >
::type = 0 >
45 void vector_matrix_multiply(
const SimpleTensor<T> &
src,
const SimpleTensor<T> &weights,
const SimpleTensor<TB> &bias, SimpleTensor<T> &
dst,
int offset_src,
int offset_dst,
int cols_weights,
48 const T *src_ptr = src.data() + offset_src;
49 const T *weights_ptr = weights.data();
50 const TB *bias_ptr = bias.data();
51 T *dst_ptr = dst.data() + offset_dst;
53 #pragma omp parallel for 55 for(
int y = 0; y < rows_weights; ++y)
57 dst_ptr[y] = std::inner_product(src_ptr, src_ptr + cols_weights, &weights_ptr[cols_weights * y], static_cast<T>(0)) + bias_ptr[y];
62 template < typename T, typename TB, typename std::enable_if < (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) &&std::is_same<TB, int32_t>::value,
int >::
type = 0 >
63 void vector_matrix_multiply(
const SimpleTensor<T> &src,
const SimpleTensor<T> &weights,
const SimpleTensor<TB> &bias, SimpleTensor<T> &dst,
int offset_src,
int offset_dst,
64 int cols_weights,
int rows_weights)
66 const T *src_ptr = src.data() + offset_src;
67 const T *weights_ptr = weights.data();
68 const TB *bias_ptr = bias.data();
69 T *dst_ptr = dst.data() + offset_dst;
71 const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
72 const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
73 const UniformQuantizationInfo oq_info = dst.quantization_info().uniform();
75 const int input_offset = -iq_info.offset;
76 const float input_scale = iq_info.scale;
77 const int weights_offset = -wq_info.offset;
78 const float weights_scale = wq_info.scale;
79 const int output_offset = oq_info.offset;
80 const float output_scale = oq_info.scale;
82 int output_multiplier = 0;
84 const float multiplier = input_scale * weights_scale / output_scale;
88 const int max = std::numeric_limits<T>::max();
90 #pragma omp parallel for 92 for(
int y = 0; y < rows_weights; ++y)
97 for(
int x = 0; x < cols_weights; ++x)
99 acc += (src_ptr[x] + input_offset) * (weights_ptr[x + y * cols_weights] + weights_offset);
109 dst_ptr[y] =
static_cast<T
>(acc);
114 template <
typename T,
typename TB>
127 const int num_batch_dimensions = std::max(0, static_cast<int>(dst_shape.
num_dimensions()) - 1);
128 const int num_input_dimensions = src.
shape().num_dimensions() - num_batch_dimensions;
129 const unsigned int linear_input_size = src.
shape().total_size_lower(num_input_dimensions);
139 const int cols_weights = weights.
shape().x();
140 const int rows_weights = weights.
shape().y();
143 for(
int k = 0; k < num_batches; ++k)
145 const int offset_in = k * cols_weights;
146 const int offset_out = k * rows_weights;
148 vector_matrix_multiply<T>(
src,
SimpleTensor< T > fully_connected_layer(const SimpleTensor< T > &src, const SimpleTensor< T > &weights, const SimpleTensor< TB > &bias, const TensorShape &dst_shape, QuantizationInfo out_quant_info)
DataType data_type() const override
Data type of the tensor.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
size_t total_size_upper(size_t dimension) const
Collapses given dimension and above.
TensorShape shape() const override
Shape of the tensor.
Status calculate_quantized_multiplier(float multiplier, int32_t *quant_multiplier, int32_t *shift, bool ignore_epsilon=false)
Calculate quantized representation of multiplier.
decltype(strategy::transforms) typedef type
SimpleTensor< float > src
Copyright (c) 2017-2021 Arm Limited.
Quantization information.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
int32_t quantize_down_scale_by_fixedpoint(int32_t val, int32_t result_mult_int, int32_t result_shift, int32_t result_offset_after_shift, int32_t min, int32_t max)
Quantize down the input value in range [min, max].
Simple tensor object that stores elements in a consecutive chunk of memory.
unsigned int num_dimensions() const
Returns the effective dimensionality of the tensor.
QuantizationInfo quantization_info() const override
Quantization info in case of asymmetric quantized type.