Compute Library
 21.11
NEInstanceNormalizationLayerKernel.cpp
/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace
{
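// Accumulate a running sum and a running sum of squares of the input vector lanes;
// these feed the per-plane mean and variance computation below.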
template <typename InputType, typename AccType = InputType>
void vector_float_sum(AccType &result, AccType &result_square, const InputType &inputs)
{
    result        = wrapper::vadd(result, inputs);
    result_square = wrapper::vadd(result_square, wrapper::vmul(inputs, inputs));
}

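// FP16 specialization used on the mixed-precision path: each half of the FP16 vector is
// widened to FP32 before accumulation to reduce rounding error in the statistics.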
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square, const float16x8_t &inputs)
{
    vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgetlow(inputs)));
    vector_float_sum(result, result_square, wrapper::vcvt<float>(wrapper::vgethigh(inputs)));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

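// Normalize one vector of input values: (x - mean) * (gamma / stddev) + beta, applied lane-wise.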
template <typename InputType, typename AccType = InputType>
InputType vector_float_norm(const InputType &inputs, const AccType &vec_mean, const AccType &vec_multip, const AccType &vec_beta)
{
    return wrapper::vadd(wrapper::vmul(wrapper::vsub(inputs, vec_mean), vec_multip), vec_beta);
}

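// Mixed-precision FP16 specialization: widen to FP32, normalize, then narrow the result back to FP16.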
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
inline float16x8_t vector_float_norm(const float16x8_t &inputs, const float32x4_t &vec_mean, const float32x4_t &vec_multip, const float32x4_t &vec_beta)
{
    const auto  input_low   = wrapper::vcvt<float>(wrapper::vgetlow(inputs));
    const auto  input_high  = wrapper::vcvt<float>(wrapper::vgethigh(inputs));
    const auto  result_low  = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
    const auto  result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
    float16x8_t result      = wrapper::vcombine(result_low, result_high);

    return result;
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

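// Instance normalization for an NCHW tensor: every (width x height) plane is normalized
// independently as out = (in - mean) * gamma / sqrt(var + epsilon) + beta.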
template <typename T, typename AccType = T>
void instance_normalization_nchw(ITensor *input, ITensor *output, float gamma, float beta, float epsilon, const Window &window)
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    // Clear X/Y dimensions on execution window as we handle the planes manually
    Window win = window;
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    win.set(Window::DimY, Window::Dimension(0, 1, 1));

    constexpr int      window_step_x  = 16 / sizeof(T);
    const unsigned int elements_plane = input->info()->dimension(0) * output->info()->dimension(1);

    Iterator input_it(input, win);
    execute_window_loop(win, [&](const Coordinates & id)
    {
        Window win_plane = window;
        win_plane.set(Window::DimX, Window::Dimension(0, 1, 1));
        win_plane.set(Window::DimZ, Window::Dimension(id[2], id[2] + 1, 1));
        win_plane.set(3, Window::Dimension(id[3], id[3] + 1, 1));

        Iterator input_plane_it(input, win_plane);
        Iterator output_plane_it(output, win_plane);

        auto sum_h_w         = static_cast<AccType>(0.f);
        auto sum_squares_h_w = static_cast<AccType>(0.f);

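        // First pass over the plane: accumulate the sum and the sum of squares of all elements.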
        execute_window_loop(win_plane, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());

            auto vec_sum_h_w         = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
            auto vec_sum_squares_h_w = wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});

            // Compute S elements per iteration
            int x = window.x().start();
            for(; x <= (window.x().end() - window_step_x); x += window_step_x)
            {
                auto vec_input_val = wrapper::vloadq(input_ptr + x);
                vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
            }

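            // Pairwise-add reduction of the vector accumulators before adding lane 0 to the scalar plane totals.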
            auto vec2_sum_h_w         = wrapper::vpadd(wrapper::vgethigh(vec_sum_h_w), wrapper::vgetlow(vec_sum_h_w));
            auto vec2_sum_squares_h_w = wrapper::vpadd(wrapper::vgethigh(vec_sum_squares_h_w), wrapper::vgetlow(vec_sum_squares_h_w));

            vec2_sum_h_w         = wrapper::vpadd(vec2_sum_h_w, vec2_sum_h_w);
            vec2_sum_squares_h_w = wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);

            sum_h_w         += wrapper::vgetlane(vec2_sum_h_w, 0);
            sum_squares_h_w += wrapper::vgetlane(vec2_sum_squares_h_w, 0);

            // Compute left-over elements
            for(; x < window.x().end(); ++x)
            {
                const auto value = static_cast<AccType>(*(input_ptr + x));
                sum_h_w += value;
                sum_squares_h_w += value * value;
            }
        },
        input_plane_it, output_plane_it);

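        // Per-plane statistics: mean = E[x], variance = E[x^2] - (E[x])^2.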
        const auto mean_h_w = sum_h_w / elements_plane;
        const auto var_h_w  = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;

        const auto multip_h_w     = gamma / std::sqrt(var_h_w + epsilon);
        const auto vec_mean_h_w   = wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
        const auto vec_multip_h_w = wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
        const auto vec_beta       = wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});

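        // Second pass over the plane: apply the normalization using the statistics computed above.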
        execute_window_loop(win_plane, [&](const Coordinates &)
        {
            auto input_ptr  = reinterpret_cast<T *>(input_plane_it.ptr());
            auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());

            // Compute S elements per iteration
            int x = window.x().start();
            //auto vec_val = wrapper::vdup_n(static_cast<T>(0.0f), ExactTagType{});
            for(; x <= (window.x().end() - window_step_x); x += window_step_x)
            {
                const auto vec_val        = wrapper::vloadq(input_ptr + x);
                const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
                wrapper::vstore(output_ptr + x, normalized_vec);
            }

            // Compute left-over elements
            for(; x < window.x().end(); ++x)
            {
                const auto val = static_cast<AccType>(*(input_ptr + x));
                *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
            }
        },
        input_plane_it, output_plane_it);
    },
    input_it);
}

Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
    ARM_COMPUTE_UNUSED(gamma);
    ARM_COMPUTE_UNUSED(beta);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(epsilon == 0.f, "Epsilon must be different than 0");

    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(input, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC, "NHWC data layout is not supported by the kernel directly");

    if(output != nullptr && output->total_size() != 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_channels() != output->num_channels(), "Input and output have different number of channels");
    }
    return Status{};
}

std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
{
    // We handle the planes manually
    Window win = calculate_max_window(*input, Steps(1));

    // Output auto initialization if not yet initialized
    auto_init_if_empty(*output, input->tensor_shape(), 1, input->data_type());

    // NEInstanceNormalizationLayerKernel doesn't need padding so update_window_and_padding() can be skipped
    return std::make_pair(Status{}, win);
}
} // namespace

NEInstanceNormalizationLayerKernel::NEInstanceNormalizationLayerKernel()
    : _func(nullptr), _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12)
{
}

void NEInstanceNormalizationLayerKernel::configure(ITensor *input, ITensor *output, const InstanceNormalizationLayerKernelInfo &info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input);

    _input               = input;
    _output              = output == nullptr ? input : output;
    _gamma               = info.gamma;
    _beta                = info.beta;
    _epsilon             = info.epsilon;
    _use_mixed_precision = info.use_mixed_precision;

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), _gamma, _beta, _epsilon));

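    // Select the plane-normalization routine for the data type. For FP16 the statistics can
    // optionally be accumulated in FP32 (mixed precision) for better numerical accuracy.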
    if(_input->info()->data_type() == DataType::F32)
    {
        _func = &instance_normalization_nchw<float>;
    }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    else if(_input->info()->data_type() == DataType::F16)
    {
        if(_use_mixed_precision)
        {
            _func = &instance_normalization_nchw<float16_t, float>;
        }
        else
        {
            _func = &instance_normalization_nchw<float16_t>;
        }
    }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    else
    {
        ARM_COMPUTE_ERROR("Unsupported data type");
    }

    // Configure kernel window
    auto win_config = validate_and_configure_window(_input->info(), _output->info());
    ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));

    INEKernel::configure(std::get<1>(win_config));
}

Status NEInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info.gamma, info.beta, info.epsilon));
    ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), (output == nullptr ? input->clone().get() : output->clone().get()))));
    return Status{};
}

void NEInstanceNormalizationLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    (*_func)(_input, _output, _gamma, _beta, _epsilon, window);
}
} // namespace arm_compute
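
A minimal usage sketch, not part of this file: the kernel above is normally reached through the
public runtime wrapper NEInstanceNormalizationLayer rather than configured directly. The tensor
shape, the gamma/beta/epsilon values and the variable names below are illustrative assumptions.

#include "arm_compute/runtime/NEON/functions/NEInstanceNormalizationLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // NCHW input: a 32x32 plane, 3 channels, batch of 2 (hypothetical sizes)
    Tensor src, dst;
    src.allocator()->init(TensorInfo(TensorShape(32U, 32U, 3U, 2U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(32U, 32U, 3U, 2U), 1, DataType::F32));

    // gamma = 1, beta = 0, epsilon = 1e-12 match the layer's defaults
    NEInstanceNormalizationLayer norm;
    norm.configure(&src, &dst, 1.f, 0.f, 1e-12f);

    src.allocator()->allocate();
    dst.allocator()->allocate();

    // Fill src with input data here, then run the function, which dispatches this kernel.
    norm.run();
    return 0;
}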