Compute Library
 22.11
NEFuseBatchNormalizationKernel.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
26 
30 #include "arm_compute/core/Utils.h"
34 #include "src/core/CPP/Validate.h"
39 
40 #include <map>
41 
42 namespace arm_compute
43 {
44 namespace
45 {
46 struct FuseBatchNormalizeSelectorData
47 {
51  cpuinfo::CpuIsaInfo isa;
52 };
53 
/** Function-pointer type for a fuse-batch-normalization micro-kernel.
 *
 * Signature matches the fused_batch_normalization_* implementations:
 * (input_weights, input_bias, fused_weights, fused_bias,
 *  bn_mean, bn_var, bn_beta, bn_gamma, epsilon, window)
 */
using FBNUKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, ITensor *,
                                            const ITensor *, const ITensor *, const ITensor *, const ITensor *, float, const Window &)>::type;

/** Entry of the micro-kernel registry: a named kernel plus the predicate that selects it. */
struct FBNUKernel
{
    const char          *name;        /**< Human-readable kernel name (used for identification/logging) */
    const FBNSelectorPtr is_selected; /**< Predicate: returns true when this kernel matches the selection data */
    FBNUKernelPtr        ukernel;     /**< Pointer to the micro-kernel implementation */
};
64 
65 static const FBNUKernel available_kernels[] =
66 {
67  {
68  "fused_batch_normalization_conv_NHWC_F16",
69  [](const FuseBatchNormalizeSelectorData & data)
70  {
71  return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
72  },
74  },
75  {
76  "fused_batch_normalization_conv_NCHW_F16",
77  [](const FuseBatchNormalizeSelectorData & data)
78  {
79  return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
80  },
82  },
83  {
84  "fused_batch_normalization_dwc_NHWC_F16",
85  [](const FuseBatchNormalizeSelectorData & data)
86  {
87  return data.dt == DataType::F16 && data.dl == DataLayout::NHWC && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
88  },
90  },
91  {
92  "fused_batch_normalization_dwc_NCHW_F16",
93  [](const FuseBatchNormalizeSelectorData & data)
94  {
95  return data.dt == DataType::F16 && data.dl == DataLayout::NCHW && data.isa.fp16 && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
96  },
98  },
99  {
100  "fused_batch_normalization_conv_NHWC_F32",
101  [](const FuseBatchNormalizeSelectorData & data)
102  {
103  return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
104  },
106  },
107  {
108  "fused_batch_normalization_conv_NCHW_F32",
109  [](const FuseBatchNormalizeSelectorData & data)
110  {
111  return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::CONVOLUTION;
112  },
114  },
115  {
116  "fused_batch_normalization_dwc_NHWC_F32",
117  [](const FuseBatchNormalizeSelectorData & data)
118  {
119  return data.dt == DataType::F32 && data.dl == DataLayout::NHWC && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
120  },
122  },
123  {
124  "fused_batch_normalization_dwc_NCHW_F32",
125  [](const FuseBatchNormalizeSelectorData & data)
126  {
127  return data.dt == DataType::F32 && data.dl == DataLayout::NCHW && data.fbn_type == FuseBatchNormalizationType::DEPTHWISECONVOLUTION;
128  },
130  }
131 };
132 
133 /** Micro-kernel selector
134  *
135  * @param[in] data Selection data passed to help pick the appropriate micro-kernel
136  *
137  * @param[in]
138  *
139  * @return A matching micro-kernel else nullptr
140  */
141 const FBNUKernel *get_implementation(const FuseBatchNormalizeSelectorData &data)
142 {
143  for(const auto &uk : available_kernels)
144  {
145  if(uk.is_selected(data))
146  {
147  return &uk;
148  }
149  }
150  return nullptr;
151 }
152 
153 Status validate_arguments(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
154  const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
155  const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
157 {
158  ARM_COMPUTE_UNUSED(epsilon);
159  ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
163  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, bn_mean, bn_var);
164  ARM_COMPUTE_RETURN_ERROR_ON(input_bias == nullptr && fused_bias == nullptr);
165  ARM_COMPUTE_RETURN_ERROR_ON(bn_mean->num_dimensions() > 1);
166 
168  {
169  ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(3) != bn_mean->dimension(0));
170  }
171  else
172  {
173  const size_t channel_idx = get_data_layout_dimension_index(input_weights->data_layout(), DataLayoutDimension::CHANNEL);
174  ARM_COMPUTE_RETURN_ERROR_ON(input_weights->dimension(channel_idx) != bn_mean->dimension(0));
175  }
176  // Validate bias
177  if(input_bias != nullptr)
178  {
180  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, input_bias);
181  }
182  // Validate beta
183  if(bn_beta != nullptr)
184  {
187  }
188  // Validate gamma
189  if(bn_gamma != nullptr)
190  {
193  }
194 
195  // Validate output weights
196  if(fused_weights != nullptr && fused_weights->total_size() != 0)
197  {
198  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_weights, fused_weights);
199  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input_weights, fused_weights);
200  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_weights);
201  }
202  // Validate output bias
203  if(fused_bias != nullptr && fused_bias->total_size() != 0)
204  {
206  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_weights, fused_bias);
207  }
208 
209  return Status{};
210 }
211 
212 } // namespace
213 
215  : _input_weights(nullptr), _input_bias(nullptr), _bn_mean(nullptr), _bn_var(nullptr), _bn_gamma(nullptr), _bn_beta(nullptr), _fused_weights(nullptr), _fused_bias(nullptr), _epsilon(),
216  _run_in_place_weights(false), _run_in_place_bias(false), _func(nullptr)
217 {
218 }
219 
220 void NEFuseBatchNormalizationKernel::configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var,
221  ITensor *fused_weights, ITensor *fused_bias,
222  const ITensor *input_bias, const ITensor *bn_beta, const ITensor *bn_gamma,
224 {
225  ARM_COMPUTE_ERROR_ON_NULLPTR(input_weights, bn_mean, bn_var);
226 
227  _input_weights = input_weights;
228  _input_bias = input_bias;
229  _bn_mean = bn_mean;
230  _bn_var = bn_var;
231  _bn_beta = bn_beta;
232  _bn_gamma = bn_gamma;
233  _fused_weights = fused_weights;
234  _fused_bias = fused_bias;
235  _epsilon = epsilon;
236 
237  _run_in_place_weights = (fused_weights == nullptr) || (fused_weights == input_weights);
238  _run_in_place_bias = (fused_bias == nullptr) || (input_bias != nullptr && fused_bias == input_bias);
239 
240  // Auto initialize outputs
241  if(_fused_weights != nullptr)
242  {
243  // Output tensor auto initialization if not yet initialized
244  auto_init_if_empty(*_fused_weights->info(), *_input_weights->info()->clone());
245  }
246  if(_fused_bias != nullptr)
247  {
248  // Output tensor auto initialization if not yet initialized
249  auto_init_if_empty(*_fused_bias->info(), *_bn_mean->info()->clone());
250  }
251 
252  // Validate arguments
253  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input_weights->info(), bn_mean->info(), bn_var->info(),
254  (fused_weights != nullptr) ? fused_weights->info() : nullptr,
255  (fused_bias != nullptr) ? fused_bias->info() : nullptr,
256  (input_bias != nullptr) ? input_bias->info() : nullptr,
257  (bn_beta != nullptr) ? bn_beta->info() : nullptr,
258  (bn_gamma != nullptr) ? bn_gamma->info() : nullptr,
259  epsilon, fbn_type));
260 
261  const auto *uk = get_implementation(FuseBatchNormalizeSelectorData{ input_weights->info()->data_type(), input_weights->info()->data_layout(), fbn_type, CPUInfo::get().get_isa() });
263  ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
264  _func = uk->ukernel;
265 
266  // Configure kernel window
267  Window win = calculate_max_window(*input_weights->info());
268  INEKernel::configure(win);
269 }
270 
271 Status NEFuseBatchNormalizationKernel::validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var,
272  const ITensorInfo *fused_weights, const ITensorInfo *fused_bias,
273  const ITensorInfo *input_bias, const ITensorInfo *bn_beta, const ITensorInfo *bn_gamma,
275 {
276  ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input_weights, bn_mean, bn_var, fused_weights, fused_bias, input_bias, bn_beta, bn_gamma, epsilon, fbn_type));
277  return Status{};
278 }
279 
281 {
282  ARM_COMPUTE_UNUSED(info);
285 
286  ARM_COMPUTE_ERROR_ON(_func == nullptr);
287  (*_func)(_input_weights, _input_bias, _fused_weights, _fused_bias, _bn_mean, _bn_var, _bn_beta, _bn_gamma, _epsilon, window);
288 }
289 } // namespace arm_compute
static Status validate(const ITensorInfo *input_weights, const ITensorInfo *bn_mean, const ITensorInfo *bn_var, const ITensorInfo *fused_weights, const ITensorInfo *fused_bias, const ITensorInfo *input_bias=nullptr, const ITensorInfo *bn_beta=nullptr, const ITensorInfo *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Static function to check if given info will lead to a valid configuration of NEFuseBatchNormalization...
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
const Window & window() const
The maximum window the kernel can be executed on.
Definition: IKernel.cpp:28
FuseBatchNormalizationType fbn_type
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...)
Definition: Validate.h:490
#define ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(tensor)
Definition: Validate.h:115
#define REGISTER_FP16_NEON(func_name)
Definition: Registrars.h:48
const FBNSelectorPtr is_selected
void fused_batch_normalization_conv_f32(const ITensor *conv_weights, const ITensor *conv_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
Definition: fp32.cpp:31
#define REGISTER_FP32_NEON(func_name)
Definition: Registrars.h:74
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
virtual DataType data_type() const =0
Data type used for each element of the tensor.
1 channel, 1 F32 per channel
void fused_batch_normalization_dwc_nhwc_f16(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
Definition: Error.h:466
Store the tensor's metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
Status class.
Definition: Error.h:52
cpuinfo::CpuIsaInfo isa
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:296
decltype(strategy::transforms) typedef type
Interface for CPU tensor.
Definition: ITensor.h:36
Copyright (c) 2017-2022 Arm Limited.
1 channel, 1 F16 per channel
void run(const Window &window, const ThreadInfo &info) override
Execute the kernel on the passed window.
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Definition: Validate.h:159
FuseBatchNormalizationType
Available FuseBatchNormalizationType.
Definition: Types.h:158
#define ARM_COMPUTE_UNUSED(...)
To avoid unused variables warnings.
Definition: Error.h:152
void configure(const ITensor *input_weights, const ITensor *bn_mean, const ITensor *bn_var, ITensor *fused_weights, ITensor *fused_bias, const ITensor *input_bias=nullptr, const ITensor *bn_beta=nullptr, const ITensor *bn_gamma=nullptr, float epsilon=0.001f, FuseBatchNormalizationType fbn_type=FuseBatchNormalizationType::CONVOLUTION)
Set the source, destination of the kernel.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment i...
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
void fused_batch_normalization_dwc_nchw_f32(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
Definition: all.cpp:128
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
Definition: Validate.h:915
Num samples, channels, height, width.
FBNUKernelPtr ukernel
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
Information about executing thread and CPU.
Definition: CPPTypes.h:179
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...)
Definition: Validate.h:439
size_t get_data_layout_dimension_index(const DataLayout &data_layout, const DataLayoutDimension &data_layout_dimension)
Get the index of the given dimension.
Definition: Helpers.inl:193
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Definition: Validate.h:541
Num samples, height, width, channels.
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Definition: Validate.h:788
void fused_batch_normalization_conv_f16(const ITensor *conv_weights, const ITensor *conv_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:157
void fused_batch_normalization_dwc_nhwc_f32(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
Definition: fp32.cpp:32
Includes all wrapper headers at once.
static CPUInfo & get()
Access the CPUInfo singleton.
Definition: CPPTypes.cpp:40
void fused_batch_normalization_dwc_nchw_f16(const ITensor *dwc_weights, const ITensor *dwc_bias, ITensor *fused_weights, ITensor *fused_bias, const ITensor *bn_mean, const ITensor *bn_var, const ITensor *bn_beta, const ITensor *bn_gamma, float epsilon, const Window &window)
DataType
Available data types.
Definition: Types.h:79
const char * name
DataLayout
[DataLayout enum definition]
Definition: Types.h:113
Describe a multidimensional execution window.
Definition: Window.h:39
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)
Definition: Validate.h:201
cpuinfo::CpuIsaInfo get_isa() const
Gets the current cpu's ISA information.
Definition: CPPTypes.cpp:124
virtual DataLayout data_layout() const =0
Get the data layout of the tensor.