Compute Library 20.08
NEQLSTMLayer.cpp
1 /*
2  * Copyright (c) 2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
25 
26 #include "arm_compute/core/KernelDescriptors.h"
27 #include "arm_compute/core/QuantizationInfo.h"
28 #include "arm_compute/core/Utils.h"
29 #include "arm_compute/core/Validate.h"
30 #include "arm_compute/core/utils/misc/InfoHelpers.h"
31 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
32 #include "arm_compute/runtime/NEON/NEScheduler.h"
33 
34 namespace arm_compute
35 {
36 using namespace arm_compute::utils::info_helpers;
37 namespace
38 {
39 Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
40  float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
41 {
42  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
43  ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
44  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
45  return Status{};
46 }
47 } // namespace
48 
49 Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
50 {
51  ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
52  ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
54  ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
55  return Status{};
56 }
57 
58 void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
59 {
61  _src = &src;
62  _dst = &dst;
63  _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
64  _window = calculate_max_window(*_src->info(), Steps());
65 }
66 
67 void NEQLSTMLayer::TensorCopyKernel::run()
68 {
69  Iterator input_iter{ _src, _window };
70  Iterator output_iter{ _dst, _window };
71 
72  execute_window_loop(_window, [&](const Coordinates &)
73  {
74  memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
75  },
76  input_iter, output_iter);
77 }
78 
79 NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
80 {
81  _memory_group = MemoryGroup(std::move(memory_manager));
82 }
83 
84 void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
85  const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias,
86  Tensor *mm_res, Tensor *outstage_res, float gemmlowp_scale,
87  const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
88 {
89  _memory_group.manage(mm_res);
90  _memory_group.manage(outstage_res);
91 
92  mm_res->allocator()->init(mm_res_info);
93  outstage_res->allocator()->init(outstage_tensor_info);
94 
95  // Configure matrix-multiplication
96  mm.configure(mm_input, mm_weights, nullptr, mm_res);
97 
98  // Configure output stage
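 // calculate_quantized_multiplier() decomposes the floating-point rescale factor into an integer multiplier and shift for the fixed-point output stage.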
99  quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
100  outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
101  mm_res->allocator()->allocate();
102 }
103 
104 void NEQLSTMLayer::configure(const ITensor *input,
105  const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
106  const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
107  const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
108  const ITensor *cell_state_in, const ITensor *output_state_in,
109  ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
110  const LSTMParams<ITensor> &lstm_params)
111 {
112  ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
113  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
114  forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
115 
116  // Set lstm parameters
117  LSTMParams<ITensorInfo> lstm_params_info{};
118  build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
119 
120  // Validate
121  ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
122  recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
123  forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
124  cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
125  lstm_params_info));
126 
127  const int batch_size = input->info()->dimension(1);
128  const int num_units = input_to_output_weights->info()->dimension(1);
129  const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
130 
131  const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
132  const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
133  const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
134 
135  _projection_bias = lstm_params.projection_bias();
136  _input_to_forget_weights = input_to_forget_weights;
137  _input_to_cell_weights = input_to_cell_weights;
138  _input_to_output_weights = input_to_output_weights;
139  _recurrent_to_forget_weights = recurrent_to_forget_weights;
140  _recurrent_to_cell_weights = recurrent_to_cell_weights;
141  _recurrent_to_output_weights = recurrent_to_output_weights;
142  _projection_weights = lstm_params.projection_weights();
143 
144  // Layer normalization
145  _has_layer_norm = lstm_params.use_layer_norm();
146  if(_has_layer_norm)
147  {
148  set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
149  set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
150  set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
151  set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
152 
153  set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
154  set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
155  set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
156  set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
157  }
158 
159  _has_cifg = lstm_params.has_cifg_opt();
160  _has_projection = lstm_params.has_projection();
161  _has_peephole = lstm_params.has_peephole_opt();
162 
163  // Calculate and decompose effective scales for optimizing matmul calculation
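 // The cell state is QSYMM16 and is expected to use a power-of-two scale, so log2(scale) yields an exact integer shift.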
164  const int32_t cell_shift = log2(qcell_state_in.scale);
165 
166  // Calculate quantized parameters for clipping.
167  int16_t quantized_cell_clip = 0;
168  if(lstm_params.cell_clip() > 0.0f)
169  {
170  quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
171  }
172  _has_cell_clipping = quantized_cell_clip > 0;
173 
174  // Precompute effective bias for optimizing the matmul computations.
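 // Each reduction folds -zero_point * (row sum of the weights) into a bias vector, so the GEMM cores do not have to apply the offsets themselves.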
175  if(!_has_cifg)
176  {
177  _input_to_input_weights = lstm_params.input_to_input_weights();
178  _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
179 
180  _input_to_input_reduction.configure(_input_to_input_weights, &_input_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
181  _recurrent_to_input_reduction.configure(_recurrent_to_input_weights, &_recurrent_to_input_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
182  }
183  _input_to_forget_reduction.configure(input_to_forget_weights, &_input_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
184  _recurrent_to_forget_reduction.configure(recurrent_to_forget_weights, &_recurrent_to_forget_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
185  _input_to_cell_reduction.configure(input_to_cell_weights, &_input_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
186  _recurrent_to_cell_reduction.configure(recurrent_to_cell_weights, &_recurrent_to_cell_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
187  _input_to_output_reduction.configure(input_to_output_weights, &_input_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
188  _recurrent_to_output_reduction.configure(recurrent_to_output_weights, &_recurrent_to_output_eff_bias, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
189  if(_has_projection)
190  {
191  _projection_reduction.configure(_projection_weights, &_projection_eff_bias, GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
192  if(_projection_bias != nullptr)
193  {
194  _projection_bias_add.configure(_projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
195  }
196  }
197 
198  // Pre-transpose weights to be used in GEMM.
199  _transpose_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_transposed);
200  _transpose_input_to_cell_weights.configure(input_to_cell_weights, &_input_to_cell_weights_transposed);
201  _transpose_input_to_output_weights.configure(input_to_output_weights, &_input_to_output_weights_transposed);
202  _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
203  _transpose_recurrent_to_cell_weights.configure(recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
204  _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
205  if(!_has_cifg)
206  {
207  _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
208  _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
209  }
210  if(_has_projection)
211  {
212  _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
213  }
214 
215  GEMMLowpOutputStageInfo gemmlowp_info;
216  gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
217  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
218  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
219  gemmlowp_info.output_data_type = DataType::QSYMM16;
220 
221  const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
222  // Forget gate.
223  const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
224  const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
225  configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
226  input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
227  &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
228  mm_out_info, forget_gate_outstage_info);
229 
230  const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
231  configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
232  output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
233  &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
234  mm_out_info, forget_gate_outstage_info);
235 
236  _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
237  _input_to_forget_outstage_res.allocator()->allocate();
238 
239  if(_has_peephole)
240  {
241  _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
242  _memory_group.manage(&_mul_cell_to_forget_res);
243  _pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
244  _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
245  _memory_group.manage(&_cell_to_forget_outstage_res);
246  const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
247  quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
248  _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
249  _mul_cell_to_forget_res.allocator()->allocate();
250  _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
251  _cell_to_forget_outstage_res.allocator()->allocate();
252  }
253 
254  Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
255 
256  if(_has_layer_norm)
257  {
258  configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
259  forget_activation_input->allocator()->allocate();
260  forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
261  }
262 
263  // Output quantization info of Sigmoid and Tanh activations
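 // Gate activations are produced as QSYMM16 values in Q0.15 format: scale 1/32768 with zero offset.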
264  const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
265  const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
266 
267  _memory_group.manage(&_forget_gate);
268  _forget_gate.allocator()->init(forget_gate_info);
269  _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
270  forget_activation_input->allocator()->allocate();
271 
272  // Modulation gate.
273  const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
274  const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
275  configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
276  input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
277  &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
278  mm_out_info, cell_outstage_info);
279 
280  const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
281  configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
282  output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
283  &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
284  mm_out_info, cell_outstage_info);
285 
286  _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
287  _input_to_cell_outstage_res.allocator()->allocate();
288 
289  Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
290 
291  if(_has_layer_norm)
292  {
293  configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
294  cell_activation_input->allocator()->allocate();
295  cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
296  }
297 
298  const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
299 
300  _memory_group.manage(&_cell_gate);
301  _cell_gate.allocator()->init(cell_gate_info);
302  _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
303  cell_activation_input->allocator()->allocate();
304 
305  // Input gate.
306  const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
307  _input_gate.allocator()->init(input_gate_info);
308  _memory_group.manage(&_input_gate);
309  if(_has_cifg)
310  {
311  _ones.allocator()->init(*_forget_gate.info());
312  _input_gate_sub.configure(&_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
313  _ones.allocator()->allocate();
314  }
315  else
316  {
317  const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
318  const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
319  configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
320  input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
321  &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
322  mm_out_info, input_outstage_info);
323 
324  const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
325  configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
326  output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
327  &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
328  mm_out_info, input_outstage_info);
329  _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
330  _input_to_input_outstage_res.allocator()->allocate();
331 
332  if(_has_peephole)
333  {
334  _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
335  _memory_group.manage(&_mul_cell_to_input_res);
336  _pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
337  const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
338  quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
339  _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
340  _memory_group.manage(&_cell_to_input_outstage_res);
341  _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
342  _mul_cell_to_input_res.allocator()->allocate();
343  _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
344  _cell_to_input_outstage_res.allocator()->allocate();
345  }
346 
347  Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
348 
349  if(_has_layer_norm)
350  {
351  configure_layer_norm(LayerNormGate::Input, input_activation_input);
352  input_activation_input->allocator()->allocate();
353  input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
354  }
355 
356  _input_gate_sigmoid.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
357  input_activation_input->allocator()->allocate();
358  }
359  // Cell.
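 // cell_state_out = forget_gate * cell_state_in + input_gate * cell_gate, computed in QSYMM16.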
360  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
361  _pixelwise_mul_forget_cell.configure(&_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
362  const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
363  const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
364  const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
365  _memory_group.manage(&_mul_input_cell_res);
366  _mul_input_cell_res.allocator()->init(mul_input_cell_info);
367  _pixelwise_mul_input_cell.configure(&_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
368  _cell_gate.allocator()->allocate();
369  _add_forget_cell.configure(&_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
370  _mul_input_cell_res.allocator()->allocate();
371  _forget_gate.allocator()->allocate();
372  if(_has_cell_clipping)
373  {
374  _cell_clip.configure(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
375  }
376  // Output gate.
377  const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
378  const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
379  configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
380  input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
381  &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
382  mm_out_info, output_outstage_info);
383 
384  const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
385  configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
386  output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
387  &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
388  mm_out_info, output_outstage_info);
389 
390  _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
391  _input_to_output_outstage_res.allocator()->allocate();
392 
393  if(_has_peephole)
394  {
395  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
396  // Here we are not using the output stage because all operations are done in float
397  _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
398  _memory_group.manage(&_mul_cell_to_output_res);
399  _pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
400 
401  const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
402  quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
403  _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
404  _memory_group.manage(&_cell_to_output_outstage_res);
405  _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
406  _mul_cell_to_output_res.allocator()->allocate();
407 
408  _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
409  _cell_to_output_outstage_res.allocator()->allocate();
410  }
411 
412  Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
413 
414  if(_has_layer_norm)
415  {
416  configure_layer_norm(LayerNormGate::Output, output_activation_input);
417  output_activation_input->allocator()->allocate();
418  output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
419  }
420  const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
421 
422  _memory_group.manage(&_output_gate);
423  _output_gate.allocator()->init(output_gate_info);
424  _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
425  output_activation_input->allocator()->allocate();
426 
427  // Hidden.
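 // hidden = output_gate * tanh(cell_state_out), requantized below to the output state data type.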
428  _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
429  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
430  _memory_group.manage(&_hidden_mul_res);
431  const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
432  _hidden_mul_res.allocator()->init(hidden_mul_res);
433  _pixelwise_mul_hidden.configure(&_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
434  _output_gate.allocator()->allocate();
435  _input_gate.allocator()->allocate();
436  const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
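 // Both factors of the product are Q0.15 (scale 2^-15), so the S32 product has scale 2^-30; dividing by the hidden state scale gives the requantization factor.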
437  quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
438  gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
439  gemmlowp_info.output_data_type = output_state_in->info()->data_type();
440 
441  _projection_tensor_copy_required = (num_units != output_size);
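 // A staging copy around the projection is only needed when it changes the row width, i.e. num_units != output_size.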
442  ITensor *hidden_gate_result = output_state_out;
443 
444  _memory_group.manage(&_hidden_gate);
445 
446  if(_projection_tensor_copy_required)
447  {
448  _hidden_gate.allocator()->init(*output_state_out->info());
449  _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
450  hidden_gate_result = &_hidden_gate;
451  }
452 
453  _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
454  _hidden_mul_res.allocator()->allocate();
455 
456  // Projection.
457  if(_has_projection)
458  {
459  const TensorInfo projection_outstage_info(*output_state_out->info());
460  const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
461  const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
462  gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
463  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
464  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
465  gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
466 
467  TensorInfo projection_mm_out_info{ mm_out_info };
468  projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
469 
470  configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
471  hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
472  &_mm_projection_res, &_projection_outstage_res, projection_scale,
473  projection_mm_out_info, projection_outstage_info);
474 
475  ITensor *accumulate_destination = output_state_out;
476 
477  if(_projection_tensor_copy_required)
478  {
479  _hidden_gate.allocator()->allocate();
480  _projection_accumulate_res.allocator()->init(*output_state_out->info());
481  _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
482  _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res);
483  accumulate_destination = &_projection_accumulate_res;
484  }
485 
486  _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
487  _projection_outstage_res.allocator()->allocate();
488 
489  if(_projection_tensor_copy_required)
490  {
491  _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
492  _projection_accumulate_res.allocator()->allocate();
493  }
494 
495  int8_t quantized_projection_clip{ 0 };
496  if(lstm_params.projection_clip() > 0.0f)
497  {
498  quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
499  }
500 
501  if(quantized_projection_clip > 0)
502  {
503  _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip, quantized_projection_clip));
504  _has_projection_clipping = true;
505  }
506  }
507  else
508  {
509  if(_projection_tensor_copy_required)
510  {
511  _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
512  _hidden_gate.allocator()->allocate();
513  }
514  }
515 
516  // Copy output_state_out to output
517  _copy_output.configure(output_state_out, output);
518 }
519 
520 Status NEQLSTMLayer::validate(const ITensorInfo *input,
521  const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
522  const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
523  const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
524  const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
525  const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
526  const LSTMParams<ITensorInfo> &lstm_params)
527 {
528  ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
529  recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
530  cell_state_out, output_state_out, output);
531 
533  ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
534 
535  const unsigned int input_size = input->dimension(0);
536  const unsigned int batch_size = input->dimension(1);
537  const unsigned int num_units = input_to_output_weights->dimension(1);
538  const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
539 
544  ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
549 
550  ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
551  ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
555 
556  ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
557  ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
558  ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
560 
561  ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
562  ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
563  ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
565 
566  // Check whether peephole weights are all there or none
567  if(lstm_params.has_peephole_opt())
568  {
571  ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
572  ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
575 
576  if(!lstm_params.has_cifg_opt())
577  {
581  }
582  }
583 
584  const UniformQuantizationInfo qinput = input->quantization_info().uniform();
585  const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
586  const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
587 
588  // Calculate and decompose effective scales for optimizing matmul calculation
589  const int32_t cell_shift = log2(qcell_state_in.scale);
590  ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
591 
592  // Calculate quantized parameters for clipping.
593  int16_t quantized_cell_clip = 0;
594  if(lstm_params.cell_clip() > 0.0f)
595  {
596  quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
597  }
598 
599  // Precompute effective bias for optimizing the matmul computations.
600  const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
601  const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
602  if(!lstm_params.has_cifg_opt())
603  {
606  true)));
607  }
614  if(lstm_params.has_projection())
615  {
617  lstm_params.hidden_state_zero(),
618  true)));
619  if(lstm_params.projection_bias() != nullptr)
620  {
622  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info, &projection_eff_bias_info, ConvertPolicy::SATURATE));
623  }
624  }
625 
626  const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
627  const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
628 
629  // Validate weights transpose
636  if(!lstm_params.has_cifg_opt())
637  {
638  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
639  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
640  }
641  if(lstm_params.has_projection())
642  {
643  const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
644  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
645  }
646 
647  GEMMLowpOutputStageInfo gemmlowp_info;
648  gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
649  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
650  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
651  gemmlowp_info.output_data_type = DataType::QSYMM16;
652 
653  const bool has_layer_norm = lstm_params.use_layer_norm();
654 
655  // Forget gate.
657  const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
658  const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
659  const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
660  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
661 
662  const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
663  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
664 
665  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
666 
667  if(lstm_params.has_peephole_opt())
668  {
672  const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
674  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
675  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
676  }
677 
678  if(has_layer_norm)
679  {
680  const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
681  const ITensorInfo *b_info = forget_gate_bias;
682  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
683  }
684 
685  // Output quantization info of Sigmoid and Tanh activations
686  const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
687  const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
688 
690 
691  // Modulation gate.
693  const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
694  const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
695  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
696 
697  const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
698  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
699 
700  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
701 
702  if(has_layer_norm)
703  {
704  const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
705  const ITensorInfo *b_info = cell_bias;
706  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
707  }
708  const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
709 
711 
712  // Input gate.
713  const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
714  if(lstm_params.has_cifg_opt())
715  {
716  ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
717  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
718  }
719  else
720  {
727 
729  const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
730  const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
731  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
732 
733  const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
734  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
735 
736  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
737 
738  if(lstm_params.has_peephole_opt())
739  {
742  const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
744  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
745  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
746  }
747 
748  if(has_layer_norm)
749  {
750  const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
751  const ITensorInfo *b_info = lstm_params.input_gate_bias();
752  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
753  }
754 
756  }
757  // Cell.
758  ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
760  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
761  if(quantized_cell_clip > 0)
762  {
763  ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
764  quantized_cell_clip)));
765  }
766  // Output gate.
768  const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
769  const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
770  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
771 
772  const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
773  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
774 
775  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
776  if(lstm_params.has_peephole_opt())
777  {
779  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
780  // Here we are not using the output stage because all operations are done in float
781  // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
782  // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
785  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
786  }
787 
788  if(has_layer_norm)
789  {
790  const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
791  const ITensorInfo *b_info = output_gate_bias;
792  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
793  }
794 
795  const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
797 
798  // Hidden.
800  const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
801  const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
802  ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
803 
805  const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
806  ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
807  gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
808  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
809 
810  const bool projection_tensor_copy_required = num_units != output_size;
811 
812  // Projection.
813  if(lstm_params.has_projection())
814  {
816  ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
817 
818  const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
819  const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
821  gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
822  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
823  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
824  gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
825 
826  const TensorInfo projection_outstage_info(*output_state_out);
827  const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
828 
829  TensorInfo projection_mm_out_info{ mm_out_info };
830  projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
831 
832  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
833  &projection_outstage_info));
834 
835  if(projection_tensor_copy_required)
836  {
837  ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info));
838  }
839 
840  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
841 
842  if(projection_tensor_copy_required)
843  {
844  ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
845  }
846 
847  int8_t quantized_projection_clip{ 0 };
848  if(lstm_params.projection_clip() > 0.0f)
849  {
850  quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
851  }
852 
853  if(quantized_projection_clip > 0)
854  {
855  ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
856  quantized_projection_clip)));
857  }
858  }
859  else
860  {
861  if(projection_tensor_copy_required)
862  {
863  ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
864  }
865  }
866 
867  if(cell_state_out->total_size() > 0)
868  {
869  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
870  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
871  }
872 
873  if(output_state_out->total_size() > 0)
874  {
876  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
877  }
878 
879  ARM_COMPUTE_RETURN_ON_ERROR(NECopyKernel::validate(output_state_out, output));
880  return Status{};
881 }
882 
883 void NEQLSTMLayer::run()
884 {
885  prepare();
886 
887  // Acquire all the temporaries
888  MemoryGroupResourceScope scope_mg(_memory_group);
889 
890  // Forget gate.
891  _mm_input_to_forget.run();
892  _input_to_forget_outstage.run();
893 
894  _mm_recurrent_to_forget.run();
895  _recurrent_to_forget_outstage.run();
896  _accumulate_input_recurrent_forget.run();
897 
898  if(_has_peephole)
899  {
900  _pixelwise_mul_cell_to_forget.run();
901  _cell_to_forget_outstage.run();
902  _accumulate_cell_forget.run();
903  }
904 
905  if(_has_layer_norm)
906  {
907  NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Forget), Window::DimY);
908  }
909 
910  _forget_gate_sigmoid.run();
911 
912  // Modulation gate.
913  _mm_input_to_cell.run();
914  _input_to_cell_outstage.run();
915 
916  _mm_recurrent_to_cell.run();
917  _recurrent_to_cell_outstage.run();
918  _accumulate_input_recurrent_modulation.run();
919 
920  if(_has_layer_norm)
921  {
922  NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Cell), Window::DimY);
923  }
924 
925  _cell_gate_tanh.run();
926 
927  // Input gate
928  if(_has_cifg)
929  {
930  _input_gate_sub.run();
931  }
932  else
933  {
934  _mm_input_to_input.run();
935  _input_to_input_outstage.run();
936  _mm_recurrent_to_input.run();
937  _recurrent_to_input_outstage.run();
938  _accumulate_input_recurrent_input.run();
939 
940  if(_has_peephole)
941  {
942  _pixelwise_mul_cell_to_input.run();
943  _cell_to_input_outstage.run();
944  _accumulate_cell_input.run();
945  }
946 
947  if(_has_layer_norm)
948  {
949  NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Input), Window::DimY);
950  }
951 
952  _input_gate_sigmoid.run();
953  }
954 
955  // Cell.
956  _pixelwise_mul_forget_cell.run();
957  _pixelwise_mul_input_cell.run();
958  _add_forget_cell.run();
959 
960  if(_has_cell_clipping)
961  {
962  _cell_clip.run();
963  }
964 
965  // Output gate.
966  _mm_input_to_output.run();
967  _input_to_output_outstage.run();
968  _mm_recurrent_to_output.run();
969  _recurrent_to_output_outstage.run();
970  _accumulate_input_recurrent_output.run();
971  if(_has_peephole)
972  {
973  _pixelwise_mul_cell_to_output.run();
974  _cell_to_output_outstage.run();
975  _accumulate_cell_to_output.run();
976  }
977 
978  if(_has_layer_norm)
979  {
980  NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Output), Window::DimY);
981  }
982 
983  _output_gate_sigmoid.run();
984 
985  // Hidden.
986  _hidden_tanh.run();
987  _pixelwise_mul_hidden.run();
988  _hidden_outstage.run();
989 
990  // Projection.
991  if(_has_projection)
992  {
993  _mm_projection.run();
994  _projection_outstage.run();
995 
996  if(_projection_tensor_copy_required)
997  {
998  _projection_output_to_accumulate_copy.run();
999  }
1000 
1001  _accumulate_projection.run();
1002 
1003  if(_projection_tensor_copy_required)
1004  {
1005  _projection_accumulate_to_output_copy.run();
1006  }
1007 
1008  if(_has_projection_clipping)
1009  {
1010  _projection_clip.run();
1011  }
1012  }
1013  else
1014  {
1015  if(_projection_tensor_copy_required)
1016  {
1017  _hidden_to_output_copy.run();
1018  }
1019  }
1020 
1021  // Copy output_state_out to output
1022  NEScheduler::get().schedule(&_copy_output, Window::DimY);
1023 }
1024 
1025 void NEQLSTMLayer::prepare()
1026 {
1027  if(!_is_prepared)
1028  {
1029  // Pre-transpose weights to be used in GEMM.
1030  _input_to_forget_weights_transposed.allocator()->allocate();
1031  _input_to_cell_weights_transposed.allocator()->allocate();
1032  _input_to_output_weights_transposed.allocator()->allocate();
1033  _recurrent_to_forget_weights_transposed.allocator()->allocate();
1034  _recurrent_to_cell_weights_transposed.allocator()->allocate();
1035  _recurrent_to_output_weights_transposed.allocator()->allocate();
1036  _transpose_input_to_forget_weights.run();
1037  _transpose_input_to_cell_weights.run();
1038  _transpose_input_to_output_weights.run();
1039  _transpose_recurrent_to_forget_weights.run();
1040  _transpose_recurrent_to_cell_weights.run();
1041  _transpose_recurrent_to_output_weights.run();
1042 
1043  // Precompute effective biases
1044  if(_has_cifg)
1045  {
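 // 32767 is ~1.0 in Q0.15; with CIFG the input gate is computed as 1 - forget_gate using this tensor of ones.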
1046  std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1047  }
1048  else
1049  {
1050  _input_to_input_eff_bias.allocator()->allocate();
1051  _recurrent_to_input_eff_bias.allocator()->allocate();
1052  NEScheduler::get().schedule(&_input_to_input_reduction, Window::DimY);
1053  NEScheduler::get().schedule(&_recurrent_to_input_reduction, Window::DimY);
1054 
1055  _input_to_input_weights_transposed.allocator()->allocate();
1056  _recurrent_to_input_weights_transposed.allocator()->allocate();
1057  _transpose_input_to_input_weights.run();
1058  _transpose_recurrent_to_input_weights.run();
1059  _input_to_input_weights->mark_as_unused();
1060  _recurrent_to_input_weights->mark_as_unused();
1061  }
1062  _input_to_forget_eff_bias.allocator()->allocate();
1063  _recurrent_to_forget_eff_bias.allocator()->allocate();
1064  _input_to_cell_eff_bias.allocator()->allocate();
1065  _recurrent_to_cell_eff_bias.allocator()->allocate();
1066  _input_to_output_eff_bias.allocator()->allocate();
1067  _recurrent_to_output_eff_bias.allocator()->allocate();
1068  NEScheduler::get().schedule(&_input_to_forget_reduction, Window::DimY);
1069  NEScheduler::get().schedule(&_recurrent_to_forget_reduction, Window::DimY);
1070  NEScheduler::get().schedule(&_input_to_cell_reduction, Window::DimY);
1071  NEScheduler::get().schedule(&_recurrent_to_cell_reduction, Window::DimY);
1072  NEScheduler::get().schedule(&_input_to_output_reduction, Window::DimY);
1073  NEScheduler::get().schedule(&_recurrent_to_output_reduction, Window::DimY);
1074 
1075  if(_has_projection)
1076  {
1077  _projection_eff_bias.allocator()->allocate();
1078  NEScheduler::get().schedule(&_projection_reduction, Window::DimY);
1079  if(_projection_bias != nullptr)
1080  {
1081  _projection_bias_add.run();
1082  _projection_bias->mark_as_unused();
1083  }
1084 
1085  _projection_weights_transposed.allocator()->allocate();
1086  _transpose_projection_weights.run();
1087  _projection_weights->mark_as_unused();
1088 
1089  if(!_projection_tensor_copy_required)
1090  {
1091  _hidden_gate.mark_as_unused();
1092  _projection_accumulate_res.mark_as_unused();
1093  }
1094  }
1095 
1096  // Mark weights as unused
1097  _input_to_forget_weights->mark_as_unused();
1098  _input_to_cell_weights->mark_as_unused();
1099  _input_to_output_weights->mark_as_unused();
1100  _recurrent_to_forget_weights->mark_as_unused();
1101  _recurrent_to_cell_weights->mark_as_unused();
1102  _recurrent_to_output_weights->mark_as_unused();
1103 
1104  _is_prepared = true;
1105  }
1106 }
1107 
1108 } // namespace arm_compute
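A minimal usage sketch of NEQLSTMLayer follows, matching the data types validate() above checks for (QASYMM8_SIGNED input and states, QSYMM8 weights, S32 biases, QSYMM16 cell state with a power-of-two scale). The shapes, quantization values and the helper name run_qlstm_once are illustrative assumptions; the LSTMParams setters for the mandatory intermediate and hidden-state parameters (set_matmul_scale_params, set_hidden_state_params) are assumed from LSTMParams.h, and depending on the LSTMParams defaults the non-CIFG input-gate tensors may also have to be supplied via set_cifg_params(). Data filling and error handling are omitted.

#include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_qlstm_once() // illustrative helper, not part of the library
{
    // Illustrative dimensions and quantization parameters.
    const unsigned int input_size = 32, batch_size = 2, num_units = 16, output_size = 16;
    const QuantizationInfo qweights(0.005f, 0);    // QSYMM8 weights
    const QuantizationInfo qstate(0.007f, 0);      // QASYMM8_SIGNED input and output states
    const QuantizationInfo qcell(1.f / 2048.f, 0); // QSYMM16 cell state, power-of-two scale

    Tensor input, cell_state_in, output_state_in, cell_state_out, output_state_out, output;
    Tensor in2forget_w, in2cell_w, in2out_w, rec2forget_w, rec2cell_w, rec2out_w;
    Tensor forget_bias, cell_bias, output_bias;

    auto init = [](Tensor &t, const TensorShape &shape, DataType dt, const QuantizationInfo &qi)
    {
        t.allocator()->init(TensorInfo(shape, 1, dt, qi));
    };

    init(input, TensorShape(input_size, batch_size), DataType::QASYMM8_SIGNED, qstate);
    for(Tensor *t : { &output_state_in, &output_state_out, &output })
        init(*t, TensorShape(output_size, batch_size), DataType::QASYMM8_SIGNED, qstate);
    for(Tensor *t : { &cell_state_in, &cell_state_out })
        init(*t, TensorShape(num_units, batch_size), DataType::QSYMM16, qcell);
    for(Tensor *w : { &in2forget_w, &in2cell_w, &in2out_w })
        init(*w, TensorShape(input_size, num_units), DataType::QSYMM8, qweights);
    for(Tensor *w : { &rec2forget_w, &rec2cell_w, &rec2out_w })
        init(*w, TensorShape(output_size, num_units), DataType::QSYMM8, qweights);
    for(Tensor *b : { &forget_bias, &cell_bias, &output_bias })
        init(*b, TensorShape(num_units), DataType::S32, QuantizationInfo());

    // Mandatory QLSTM quantization parameters (values are placeholders).
    LSTMParams<ITensor> params;
    params.set_matmul_scale_params(0.007f, 0.007f, 0.007f, 0.007f); // input/forget/cell/output intermediate scales
    params.set_hidden_state_params(0, 0.007f);                      // hidden state zero point and scale

    NEQLSTMLayer qlstm;
    qlstm.configure(&input, &in2forget_w, &in2cell_w, &in2out_w, &rec2forget_w, &rec2cell_w, &rec2out_w,
                    &forget_bias, &cell_bias, &output_bias,
                    &cell_state_in, &output_state_in, &cell_state_out, &output_state_out, &output, params);

    // Allocate backing memory, fill the inputs, weights, biases and "_in" states, then run one time step.
    for(Tensor *t : { &input, &cell_state_in, &output_state_in, &cell_state_out, &output_state_out, &output,
                      &in2forget_w, &in2cell_w, &in2out_w, &rec2forget_w, &rec2cell_w, &rec2out_w,
                      &forget_bias, &cell_bias, &output_bias })
        t->allocator()->allocate();

    qlstm.run();
}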
virtual size_t num_dimensions() const =0
The number of dimensions of the tensor (rank)
void configure(const ITensor *input, const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights, const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights, const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias, const ITensor *cell_state_in, const ITensor *output_state_in, ITensor *cell_state_out, ITensor *output_state_out, ITensor *output, const LSTMParams< ITensor > &lstm_params)
Initialize function's tensors.
const T * projection_weights() const
Definition: LSTMParams.h:227
int32_t gemmlowp_multiplier
GEMMLowp output stage multiplier used for quantizing to QASYMM8.
Definition: Types.h:1885
const T * input_to_input_weights() const
Definition: LSTMParams.h:197
int16_t quantize_qsymm16(float value, const UniformQuantizationInfo &qinfo, RoundingPolicy rounding_policy=RoundingPolicy::TO_NEAREST_UP)
Quantize a value given a 16-bit symmetric quantization scheme.
Shape of a tensor.
Definition: TensorShape.h:39
Quantize using a fixed point multiplication.
NEQLSTMLayer(std::shared_ptr< IMemoryManager > memory_manager=nullptr)
Default constructor.
quantized, symmetric fixed-point 16-bit number
bool use_layer_norm() const
Definition: LSTMParams.h:312
TensorInfo * info() const override
Interface to be implemented by the child class to return the tensor's metadata.
Definition: CLTensor.cpp:41
static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy, const ActivationLayerInfo &act_info=ActivationLayerInfo())
Static function to check if given info will lead to a valid configuration of NEArithmeticAddition.
void init(const TensorAllocator &allocator, const Coordinates &coords, TensorInfo &sub_info)
Shares the same backing memory with another tensor allocator, while the tensor info might be differen...
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
bool has_peephole_opt() const
Definition: LSTMParams.h:297
#define ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Definition: Validate.h:543
T * forget_layer_norm_weights() const
Definition: LSTMParams.h:242
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Definition: Validate.h:545
void build_lstm_params_tensor_info(const LSTMParams< T > &lstm_params, LSTMParams< ITensorInfo > *lstm_params_info)
Build LSTMParams<ITensorInfo> object by extracting the metadata from each tensor.
Definition: InfoHelpers.h:71
virtual ITensorInfo & set_tensor_shape(const TensorShape &shape)=0
Set the shape of an already initialized tensor.
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
virtual DataType data_type() const =0
Data type used for each element of the tensor.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const ActivationLayerInfo &act_info)
[NEActivationLayer snippet]
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Definition: Validate.h:792
float output_intermediate_scale() const
Definition: LSTMParams.h:282
bool has_cifg_opt() const
Definition: LSTMParams.h:307
float cell_intermediate_scale() const
Definition: LSTMParams.h:277
float forget_intermediate_scale() const
Definition: LSTMParams.h:272
Store the tensor's metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
Quantization info when assuming per layer quantization.
int32_t gemmlowp_offset
GEMMLowp output stage offset used for quantizing to QASYMM8.
Definition: Types.h:1884
T * cell_to_input_weights() const
Definition: LSTMParams.h:207
Status calculate_quantized_multiplier(float multiplier, int32_t *quant_multiplier, int32_t *shift, bool ignore_epsilon=false)
Calculate quantized representation of multiplier.
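A sketch of converting a floating-point rescale factor into the integer multiplier/shift pair consumed by the GEMMLowp output stage; the effective scale is an assumed value:

#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include <cstdint>

arm_compute::Status quantize_effective_scale()
{
    // Assumed effective scale: input_scale * weight_scale / output_scale.
    const float effective_scale = 0.0005f;

    int32_t multiplier = 0;
    int32_t shift      = 0;
    return arm_compute::quantization::calculate_quantized_multiplier(effective_scale, &multiplier, &shift);
}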
Status class.
Definition: Error.h:52
int32_t gemmlowp_max_bound
GEMMLowp max value used to saturate down the output result before converting back to QASYMM8.
Definition: Types.h:1888
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:296
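A sketch of how these return-on-error macros compose inside a user-written validate helper (validate_square_transpose is a hypothetical name, not part of the library):

#include "arm_compute/core/Error.h"
#include "arm_compute/core/ITensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NETranspose.h"

arm_compute::Status validate_square_transpose(const arm_compute::ITensorInfo *in, const arm_compute::ITensorInfo *out)
{
    using namespace arm_compute;
    // Returns an error Status immediately if the condition is true...
    ARM_COMPUTE_RETURN_ERROR_ON(in->dimension(0) != in->dimension(1));
    // ...and propagates any error Status produced by a nested validate call.
    ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(in, out));
    return Status{};
}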
Activation Layer Information class.
Definition: Types.h:1517
GEMMLowpOutputStageType type
GEMMLowp output stage type.
Definition: Types.h:1883
Interface for NEON tensor.
Definition: ITensor.h:36
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps=Steps(), bool skip_border=false, BorderSize border_size=BorderSize())
Calculate the maximum window for a given tensor shape and border setting.
Definition: Helpers.cpp:28
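A minimal sketch of obtaining a maximal execution window from tensor metadata; the 16x4 QASYMM8_SIGNED shape is an assumption for illustration:

#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"

arm_compute::Window full_window()
{
    using namespace arm_compute;
    // Assumed 2D shape of 16x4 elements, signed 8-bit asymmetric data.
    TensorInfo info(TensorShape(16U, 4U), 1, DataType::QASYMM8_SIGNED);
    return calculate_max_window(info, Steps()); // default Steps() == one element per iteration
}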
static Status validate(const ITensorInfo *input, const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights, const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights, const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias, const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in, const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output, const LSTMParams< ITensorInfo > &lstm_params)
Static function to check if given info will lead to a valid configuration of NEQLSTMLayer.
TensorAllocator * allocator()
Return a pointer to the tensor's allocator.
Definition: Tensor.cpp:48
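The usual init-then-allocate flow for a runtime Tensor; the shape and data type are assumptions for illustration:

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/Tensor.h"

int main()
{
    using namespace arm_compute;
    Tensor t;
    t.allocator()->init(TensorInfo(TensorShape(32U, 2U), 1, DataType::S32)); // describe the tensor first
    t.allocator()->allocate();                                               // then back it with CPU memory
    return 0;
}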
1 channel, 1 S32 per channel
const T * recurrent_to_input_weights() const
Definition: LSTMParams.h:202
int32_t hidden_state_zero() const
Definition: LSTMParams.h:287
const T * projection_bias() const
Definition: LSTMParams.h:232
Quantization information.
T * output_layer_norm_weights() const
Definition: LSTMParams.h:252
float input_intermediate_scale() const
Definition: LSTMParams.h:267
void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *output, const GEMMInfo &gemm_info=GEMMInfo())
Initialise the kernel's inputs, output.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(...)
Definition: Validate.h:443
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
int8_t quantize_qasymm8_signed(float value, const INFO_TYPE &qinfo, RoundingPolicy rounding_policy=RoundingPolicy::TO_NEAREST_UP)
Quantize a value given a signed 8-bit asymmetric quantization scheme.
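A usage sketch; the scale and zero-point below are assumed values:

#include "arm_compute/core/QuantizationInfo.h"
#include <cstdint>

int8_t quantize_input(float value)
{
    using namespace arm_compute;
    // Assumed asymmetric signed 8-bit scheme: scale 1/128, zero-point 0.
    const QuantizationInfo qinfo(1.f / 128.f, 0);
    return quantize_qasymm8_signed(value, qinfo); // e.g. 0.25f -> 32
}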
float hidden_state_scale() const
Definition: LSTMParams.h:292
void allocate() override
Allocate size specified by TensorInfo of CPU memory.
UniformQuantizationInfo uniform() const
Return per layer quantization info.
GEMMLowp output stage info.
Definition: Types.h:1881
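A sketch of how the fields of this struct are commonly filled for a fixed-point requantization to signed 8-bit output; the effective scale and zero-point are assumed values:

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"

arm_compute::GEMMLowpOutputStageInfo make_outstage_info()
{
    using namespace arm_compute;
    GEMMLowpOutputStageInfo info{};
    info.type               = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    info.output_data_type   = DataType::QASYMM8_SIGNED;
    info.gemmlowp_offset    = 0;    // destination zero-point (assumed)
    info.gemmlowp_min_bound = -128; // saturation bounds for signed 8-bit
    info.gemmlowp_max_bound = 127;

    const float effective_scale = 0.0005f; // assumed: input_scale * weight_scale / output_scale
    const Status status         = quantization::calculate_quantized_multiplier(effective_scale, &info.gemmlowp_multiplier, &info.gemmlowp_shift);
    ARM_COMPUTE_ERROR_THROW_ON(status);
    return info;
}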
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
Basic implementation of the tensor interface.
Definition: Tensor.h:37
virtual QuantizationInfo quantization_info() const =0
Get the quantization settings (scale and offset) of the tensor.
quantized, symmetric fixed-point 8-bit number
float cell_clip() const
Definition: LSTMParams.h:257
T * cell_to_forget_weights() const
Definition: LSTMParams.h:217
void init(Format format)
Initialize the tensor info with just a format.
Definition: TensorInfo.cpp:107
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Definition: Validate.h:163
bool has_projection() const
Definition: LSTMParams.h:302
float projection_clip() const
Definition: LSTMParams.h:262
int32_t gemmlowp_shift
GEMMLowp output stage shift used for quantizing to uint8.
Definition: Types.h:1886
static constexpr size_t DimY
Alias for dimension 1 also known as Y dimension.
Definition: Window.h:45
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const PaddingList &padding=PaddingList())
Static function to check if given info will lead to a valid configuration of NECopyKernel.
T * cell_to_output_weights() const
Definition: LSTMParams.h:222
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:161
void configure(const ITensor *input, const ITensor *bias, ITensor *output, const GEMMLowpOutputStageInfo &info)
Initialise the kernel's inputs, output.
Memory group resources scope handling class.
Definition: IMemoryGroup.h:82
static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy, const ActivationLayerInfo &act_info=ActivationLayerInfo())
Static function to check if given info will lead to a valid configuration of NEPixelWiseMultiplication.
virtual size_t total_size() const =0
Returns the total size of the tensor in bytes.
T * input_layer_norm_weights() const
Definition: LSTMParams.h:237
virtual void schedule(ICPPKernel *kernel, const Hints &hints)=0
Runs the kernel in the same thread as the caller synchronously.
static Status validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo &info)
Static function to check if given info will lead to a valid configuration of NEGEMMLowpOutputStage.
void run() override
Run the kernels contained in the function.
Basic function to execute GEMMLowpQuantizeDown kernels on NEON.
const T * input_gate_bias() const
Definition: LSTMParams.h:212
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Definition: Error.h:244
Store the tensor's metadata.
Definition: TensorInfo.h:45
void execute_window_loop(const Window &w, L &&lambda_function, Ts &&... iterators)
Iterate through the passed window, automatically adjusting the iterators and calling the lambda_function on each element.
Definition: Helpers.inl:128
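A self-contained sketch of the window/iterator pattern this helper drives; the tensor shape and the per-element write are illustrative assumptions:

#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/Tensor.h"
#include <cstdint>

int main()
{
    using namespace arm_compute;
    Tensor t;
    t.allocator()->init(TensorInfo(TensorShape(8U, 2U), 1, DataType::QASYMM8_SIGNED));
    t.allocator()->allocate();

    const Window win = calculate_max_window(*t.info(), Steps());
    Iterator     it(&t, win);
    execute_window_loop(win, [&](const Coordinates &id)
    {
        // Write the x coordinate of each element into the tensor.
        *reinterpret_cast<int8_t *>(it.ptr()) = static_cast<int8_t>(id.x());
    },
    it);
    return 0;
}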
ITensorInfo & set_tensor_shape(const TensorShape &shape) override
Set the shape of an already initialized tensor.
Definition: TensorInfo.cpp:350
static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info=GEMMInfo())
Static function to check if given info will lead to a valid configuration of NEGEMMLowpMatrixMultiplyCore.
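A small validate-then-configure sketch for the low-precision GEMM core with S32 accumulator output; shapes and quantization parameters are illustrative assumptions:

#include "arm_compute/core/Error.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/Tensor.h"

int main()
{
    using namespace arm_compute;
    // a: 1x16 input row, b: 16x4 weights, dst: 1x4 S32 accumulators (assumed shapes).
    TensorInfo a_info(TensorShape(16U, 1U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 128.f, 0));
    TensorInfo b_info(TensorShape(4U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 64.f, 0));
    TensorInfo d_info(TensorShape(4U, 1U), 1, DataType::S32);

    const Status status = NEGEMMLowpMatrixMultiplyCore::validate(&a_info, &b_info, nullptr, &d_info);
    if(status.error_code() != ErrorCode::OK)
    {
        return 1;
    }

    Tensor a, b, dst;
    a.allocator()->init(a_info);
    b.allocator()->init(b_info);
    dst.allocator()->init(d_info);

    NEGEMMLowpMatrixMultiplyCore gemm;
    gemm.configure(&a, &b, nullptr, &dst); // bias stays nullptr; any output stage is applied separately
    a.allocator()->allocate();
    b.allocator()->allocate();
    dst.allocator()->allocate();
    gemm.run();
    return 0;
}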
T * cell_layer_norm_weights() const
Definition: LSTMParams.h:247
quantized, asymmetric fixed-point 8-bit number signed
static Status validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
Static function to check if given info will lead to a valid configuration of NEGEMMLowpMatrixAReductionKernel.
int32_t gemmlowp_min_bound
GEMMLowp min value used to saturate down the output result before converting back to QASYMM8.
Definition: Types.h:1887
const TensorShape & tensor_shape() const override
Size for each dimension of the tensor.
Definition: TensorInfo.h:261
DataType output_data_type
Output tensor data type to use if the output is not initialized.
Definition: Types.h:1893
Truncates the least significant values that are lost in operations.
void prepare() override
Prepare the function for executing.
Basic function to execute GEMMLowpMatrixMultiplyCore on NEON.
static Status validate(const ITensorInfo *input, const ITensorInfo *output)
Static function to check if given info will lead to a valid configuration of NETranspose.
Definition: NETranspose.cpp:40
Status validate(const ITensorInfo *scores_in, const ITensorInfo *boxes_in, const ITensorInfo *batch_splits_in, const ITensorInfo *scores_out, const ITensorInfo *boxes_out, const ITensorInfo *classes, const ITensorInfo *batch_splits_out, const ITensorInfo *keeps, const ITensorInfo *keeps_size, const BoxNMSLimitInfo info)
cast configure & src
Definition: Cast.cpp:169
static IScheduler & get()
Access the scheduler singleton.
Definition: Scheduler.cpp:95
static Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy, const ActivationLayerInfo &act_info=ActivationLayerInfo())
Static function to check if given info will lead to a valid configuration of NEArithmeticSubtraction.