46 template <
typename InputType,
typename AccType = InputType>
47 void vector_float_sum(AccType &result, AccType &result_square,
const InputType &inputs)
53 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 55 inline void vector_float_sum(float32x4_t &result, float32x4_t &result_square,
const float16x8_t &inputs)
57 vector_float_sum(result, result_square, wrapper::vcvt<float>(
wrapper::vgetlow(inputs)));
58 vector_float_sum(result, result_square, wrapper::vcvt<float>(
wrapper::vgethigh(inputs)));
60 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 62 template <
typename InputType,
typename AccType = InputType>
63 InputType vector_float_norm(
const InputType &inputs,
const AccType &vec_mean,
const AccType &vec_multip,
const AccType &vec_beta)
68 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 70 inline float16x8_t vector_float_norm(
const float16x8_t &inputs,
const float32x4_t &vec_mean,
const float32x4_t &vec_multip,
const float32x4_t &vec_beta)
74 const auto result_low = wrapper::vcvt<float16_t>(vector_float_norm(input_low, vec_mean, vec_multip, vec_beta));
75 const auto result_high = wrapper::vcvt<float16_t>(vector_float_norm(input_high, vec_mean, vec_multip, vec_beta));
80 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 82 template <
typename T,
typename AccType = T>
83 void instance_normalization_nchw(ITensor *
input, ITensor *output,
float gamma,
float beta,
float epsilon,
const Window &window)
86 using ExactTagType =
typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
93 constexpr
int window_step_x = 16 /
sizeof(T);
94 const unsigned int elements_plane =
input->info()->dimension(0) * output->info()->dimension(1);
96 Iterator input_it(
input, win);
99 Window win_plane = window;
100 win_plane.set(
Window::DimX, Window::Dimension(0, 1, 1));
101 win_plane.set(
Window::DimZ, Window::Dimension(
id[2],
id[2] + 1, 1));
102 win_plane.set(3, Window::Dimension(
id[3],
id[3] + 1, 1));
104 Iterator input_plane_it(
input, win_plane);
105 Iterator output_plane_it(output, win_plane);
107 auto sum_h_w = static_cast<AccType>(0.f);
108 auto sum_squares_h_w = static_cast<AccType>(0.f);
112 const auto input_ptr = reinterpret_cast<const T *>(input_plane_it.ptr());
114 auto vec_sum_h_w =
wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
115 auto vec_sum_squares_h_w =
wrapper::vdup_n(static_cast<AccType>(0.f), ExactTagType{});
118 int x = window.x().start();
119 for(; x <= (window.x().end() - window_step_x); x += window_step_x)
122 vector_float_sum(vec_sum_h_w, vec_sum_squares_h_w, vec_input_val);
129 vec2_sum_squares_h_w =
wrapper::vpadd(vec2_sum_squares_h_w, vec2_sum_squares_h_w);
135 for(; x < window.x().end(); ++x)
137 const auto value = static_cast<AccType>(*(input_ptr + x));
139 sum_squares_h_w += value * value;
142 input_plane_it, output_plane_it);
144 const auto mean_h_w = sum_h_w / elements_plane;
145 const auto var_h_w = sum_squares_h_w / elements_plane - mean_h_w * mean_h_w;
147 const auto multip_h_w = gamma / std::sqrt(var_h_w +
epsilon);
148 const auto vec_mean_h_w =
wrapper::vdup_n(static_cast<AccType>(mean_h_w), ExactTagType{});
149 const auto vec_multip_h_w =
wrapper::vdup_n(static_cast<AccType>(multip_h_w), ExactTagType{});
150 const auto vec_beta =
wrapper::vdup_n(static_cast<AccType>(beta), ExactTagType{});
154 auto input_ptr = reinterpret_cast<T *>(input_plane_it.ptr());
155 auto output_ptr = reinterpret_cast<T *>(output_plane_it.ptr());
158 int x = window.x().start();
160 for(; x <= (window.x().end() - window_step_x); x += window_step_x)
163 const auto normalized_vec = vector_float_norm(vec_val, vec_mean_h_w, vec_multip_h_w, vec_beta);
168 for(; x < window.x().end(); ++x)
170 const auto val = static_cast<AccType>(*(input_ptr + x));
171 *(output_ptr + x) = static_cast<T>((val - mean_h_w) * multip_h_w + beta);
174 input_plane_it, output_plane_it);
189 if(output !=
nullptr && output->total_size() != 0)
199 std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *
input, ITensorInfo *output)
208 return std::make_pair(Status{}, win);
213 : _func(nullptr), _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12)
222 _output = output ==
nullptr ?
input : output;
225 _epsilon =
info.epsilon;
226 _use_mixed_precision =
info.use_mixed_precision;
232 _func = &instance_normalization_nchw<float>;
234 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 237 if(_use_mixed_precision)
239 _func = &instance_normalization_nchw<float16_t, float>;
243 _func = &instance_normalization_nchw<float16_t>;
246 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 253 auto win_config = validate_and_configure_window(_input->
info(), _output->
info());
256 INEKernel::configure(std::get<1>(win_config));
271 (*_func)(_input, _output, _gamma, _beta, _epsilon,
window);
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
NEInstanceNormalizationLayerKernel()
Default constructor.
const Window & window() const
The maximum window the kernel can be executed on.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...)
#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor)
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
uint8x16_t vloadq(const uint8_t *ptr)
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
virtual DataType data_type() const =0
Data type used for each element of the tensor.
uint8x8_t vadd(const uint8x8_t &a, const uint8x8_t &b)
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
Static function to check if given info will lead to a valid configuration of NEInstanceNormalizationLayerKernel.
1 channel, 1 F32 per channel
Store the tensor's metadata.
void configure(ITensor *input, ITensor *output, const InstanceNormalizationLayerKernelInfo &info)
Set the input and output tensors.
#define ARM_COMPUTE_ERROR_THROW_ON(status)
uint8x8_t vsub(const uint8x8_t &a, const uint8x8_t &b)
Interface for CPU tensor.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
uint8x8_t vpadd(const uint8x8_t &a, const uint8x8_t &b)
uint8_t vgetlane(const uint8x8_t vector, const unsigned int lane)
static constexpr size_t DimX
Alias for dimension 0 also known as X dimension.
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment is empty.
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
uint8x8_t vgetlow(const uint8x16_t val)
uint8x16_t vcombine(const uint8x8_t &a, const uint8x8_t &b)
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
uint8x8_t vgethigh(const uint8x16_t val)
static constexpr size_t DimY
Alias for dimension 1 also known as Y dimension.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
uint8x8_t vmul(const uint8x8_t &a, const uint8x8_t &b)
Information about executing thread and CPU.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...)
static constexpr size_t DimZ
Alias for dimension 2 also known as Z dimension.
void run(const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Num samples, height, width, channels.
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo *output_stage)
void vstore(uint8_t *ptr, uint8x8_t val)
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
uint8x8_t vdup_n(uint8_t value, traits::vector_64_tag)
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_function for each element.
Includes all wrapper headers at once.
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(t,...)
Describe a multidimensional execution window.
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)