Compute Library 23.05 - NEQLSTMLayer.cpp
1 /*
2  * Copyright (c) 2020-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
29 #include "arm_compute/core/Utils.h"
34 #include "src/common/utils/Log.h"
38 
39 namespace arm_compute
40 {
41 using namespace arm_compute::utils::info_helpers;
42 namespace
43 {
44 Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
45  float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
46 {
47  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
48  ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
49  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
50  return Status{};
51 }
52 } // namespace
53 
54 Status NEQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
55 {
56  // Output quantization scale will be different, but ignored here
57  // since it will be configured at configure() stage.
58  const TensorInfo out
59  {
60  in
61  };
62  return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
63 }
64 
65 void NEQLSTMLayer::configure_layer_norm(NEQLSTMLayer::LayerNormGate g, const ITensor *in)
66 {
67  ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
68 
69  Tensor &out = get_layer_norm_output(g);
70  _memory_group.manage(&out);
71  out.allocator()->init(*(in->info()));
72 
73  get_layer_norm(g) = std::make_unique<NEQLSTMLayerNormalizationKernel>();
74  get_layer_norm(g)->configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
75 }
76 
77 NEQLSTMLayer::TensorCopyKernel::~TensorCopyKernel() = default;
78 
79 Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
80 {
81  ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
82  ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
84  ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
85  return Status{};
86 }
87 
88 void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
89 {
90  ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
91  ARM_COMPUTE_LOG_PARAMS(src, dst);
92 
93  _src = &src;
94  _dst = &dst;
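 // Copy only the overlapping row width: src and dst may have different x-dimensions (e.g. num_units vs output_size when a projection is used).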
95  _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
96  _window = calculate_max_window(*_src->info(), Steps());
97 }
98 
99 void NEQLSTMLayer::TensorCopyKernel::run()
100 {
101  Iterator input_iter{ _src, _window };
102  Iterator output_iter{ _dst, _window };
103 
104  execute_window_loop(_window, [&](const Coordinates &)
105  {
106  memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
107  },
108  input_iter, output_iter);
109 }
110 
111 NEQLSTMLayer::~NEQLSTMLayer() = default;
112 
113 NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
114  : _memory_group(),
115  _dequantize_input_to_forget_weights(),
116  _quantize_input_to_forget_weights(),
117  _transpose_input_to_forget_weights(),
118  _transpose_input_to_cell_weights(),
119  _transpose_input_to_output_weights(),
120  _transpose_input_to_input_weights(),
121  _transpose_recurrent_to_forget_weights(),
122  _transpose_recurrent_to_cell_weights(),
123  _transpose_recurrent_to_output_weights(),
124  _transpose_recurrent_to_input_weights(),
125  _transpose_projection_weights(),
126  _input_to_input_reduction(),
127  _recurrent_to_input_reduction(),
128  _input_to_forget_reduction(),
129  _recurrent_to_forget_reduction(),
130  _input_to_cell_reduction(),
131  _recurrent_to_cell_reduction(),
132  _input_to_output_reduction(),
133  _recurrent_to_output_reduction(),
134  _projection_reduction(),
135  _projection_bias_add(),
136  _mm_input_to_forget(),
137  _mm_recurrent_to_forget(),
138  _pixelwise_mul_cell_to_forget(),
139  _input_to_forget_outstage(),
140  _recurrent_to_forget_outstage(),
141  _cell_to_forget_outstage(),
142  _accumulate_input_recurrent_forget(),
143  _accumulate_cell_forget(),
144  _forget_gate_sigmoid(),
145  _mm_input_to_cell(),
146  _input_to_cell_outstage(),
147  _mm_recurrent_to_cell(),
148  _recurrent_to_cell_outstage(),
149  _accumulate_input_recurrent_modulation(),
150  _cell_gate_tanh(),
151  _input_gate_sub(),
152  _mm_input_to_input(),
153  _input_to_input_outstage(),
154  _mm_recurrent_to_input(),
155  _recurrent_to_input_outstage(),
156  _accumulate_input_recurrent_input(),
157  _pixelwise_mul_cell_to_input(),
158  _cell_to_input_outstage(),
159  _accumulate_cell_input(),
160  _input_gate_sigmoid(),
161  _pixelwise_mul_forget_cell(),
162  _pixelwise_mul_input_cell(),
163  _add_forget_cell(),
164  _cell_clip(),
165  _mm_input_to_output(),
166  _input_to_output_outstage(),
167  _mm_recurrent_to_output(),
168  _recurrent_to_output_outstage(),
169  _accumulate_input_recurrent_output(),
170  _pixelwise_mul_cell_to_output(),
171  _cell_to_output_outstage(),
172  _accumulate_cell_to_output(),
173  _output_gate_sigmoid(),
174  _hidden_tanh(),
175  _pixelwise_mul_hidden(),
176  _hidden_outstage(),
177  _mm_projection(),
178  _projection_outstage(),
179  _accumulate_projection(),
180  _projection_clip(),
181  _projection_bias_copy(),
182  _projection_output_to_accumulate_copy(),
183  _projection_accumulate_to_output_copy(),
184  _hidden_to_output_copy(),
185  _layer_norms(),
186  _copy_output(),
187  _layer_norm_weights(),
188  _layer_norm_bias(),
189  _layer_norm_output()
190 {
191  _memory_group = MemoryGroup(std::move(memory_manager));
192 }
193 
194 void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
195  const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias,
196  Tensor *mm_res, Tensor *outstage_res, float gemmlowp_scale,
197  const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
198 {
199  _memory_group.manage(mm_res);
200  _memory_group.manage(outstage_res);
201 
202  mm_res->allocator()->init(mm_res_info);
203  outstage_res->allocator()->init(outstage_tensor_info);
204 
205  // Configure matrix-multiplication
206  mm.configure(mm_input, mm_weights, nullptr, mm_res);
207 
208  // Configure output stage
209  quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
210  outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
211  mm_res->allocator()->allocate();
212 }
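// Every call to configure_mm() follows the same scale convention: the effective requantization scale is
// (scale of the left-hand operand) * (scale of the weights) / (scale of the requantized intermediate).
// The sketch below is illustrative only (the function name and numeric scales are assumed values);
// calculate_quantized_multiplier() is the same helper configure_mm() uses.
void example_effective_scale_decomposition()
{
    const float input_scale        = 0.0039f; // assumed quantization scale of the matmul input
    const float weight_scale       = 0.0020f; // assumed quantization scale of the weights
    const float intermediate_scale = 0.0001f; // assumed scale of the requantized gate intermediate
    const float effective_scale    = input_scale * weight_scale / intermediate_scale;

    int32_t multiplier = 0;
    int32_t shift      = 0;
    quantization::calculate_quantized_multiplier(effective_scale, &multiplier, &shift);
    // configure_mm() stores the resulting multiplier/shift in gemmlowp_info before configuring the
    // NEGEMMLowpOutputStage that requantizes the S32 accumulators.
}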
213 
214 void NEQLSTMLayer::configure(const ITensor *input,
215  const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
216  const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
217  const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
218  const ITensor *cell_state_in, ITensor *output_state_in,
219  ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
220  const LSTMParams<ITensor> &lstm_params)
221 {
222  ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
223  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
224  forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
225 
226  ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
227  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
228  forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
229 
230  // Set lstm parameters
231  LSTMParams<ITensorInfo> lstm_params_info{};
232  build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
233 
234  _input_to_forget_weights_transposed.info()->set_quantization_info(input_to_forget_weights->info()->quantization_info());
235  _input_to_cell_weights_transposed.info()->set_quantization_info(input_to_cell_weights->info()->quantization_info());
236  _input_to_output_weights_transposed.info()->set_quantization_info(input_to_output_weights->info()->quantization_info());
237  _recurrent_to_forget_weights_transposed.info()->set_quantization_info(recurrent_to_forget_weights->info()->quantization_info());
238  _recurrent_to_cell_weights_transposed.info()->set_quantization_info(recurrent_to_cell_weights->info()->quantization_info());
239  _recurrent_to_output_weights_transposed.info()->set_quantization_info(recurrent_to_output_weights->info()->quantization_info());
240 
241  if(input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED)
242  {
243  _convert_input_to_forget_weights_to_qsymm8 = true;
244  // Setup dequantize output tensor to go from QASYMM8_SIGNED -> F32
245 
246  _input_to_forget_weights_f32.allocator()->init(TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::F32)
247  .set_data_layout(input_to_forget_weights->info()->data_layout()));
248  // Setup the quantize output tensor to go from F32 -> QSYMM8
249  _input_to_forget_weights_symm8.allocator()->init((TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::QSYMM8)
250  .set_data_layout(input_to_forget_weights->info()->data_layout())
251  .set_quantization_info(input_to_forget_weights->info()->quantization_info())));
252 
253  _dequantize_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_f32);
254  _quantize_input_to_forget_weights.configure(&_input_to_forget_weights_f32, &_input_to_forget_weights_symm8);
255 
256  ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), _input_to_forget_weights_symm8.info(), input_to_cell_weights->info(), input_to_output_weights->info(),
257  recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
258  forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
259  cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
260  lstm_params_info));
261  }
262  else
263  {
264  ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
265  recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
266  forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
267  cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
268  lstm_params_info));
269  }
270 
271  const int batch_size = input->info()->dimension(1);
272  const int num_units = input_to_output_weights->info()->dimension(1);
273  const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
274 
275  const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
276  const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
277  const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
278 
279  _projection_bias = lstm_params.projection_bias();
280  _input_to_forget_weights = (input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED) ? &_input_to_forget_weights_symm8 : input_to_forget_weights;
281  _input_to_cell_weights = input_to_cell_weights;
282  _input_to_output_weights = input_to_output_weights;
283  _recurrent_to_forget_weights = recurrent_to_forget_weights;
284  _recurrent_to_cell_weights = recurrent_to_cell_weights;
285  _recurrent_to_output_weights = recurrent_to_output_weights;
286  _projection_weights = lstm_params.projection_weights();
287 
288  // Layer normalization
289  _has_layer_norm = lstm_params.use_layer_norm();
290  if(_has_layer_norm)
291  {
292  set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
293  set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
294  set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
295  set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
296 
297  set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
298  set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
299  set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
300  set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
301  }
302 
303  _has_cifg = lstm_params.has_cifg_opt();
304  _has_projection = lstm_params.has_projection();
305  _has_peephole = lstm_params.has_peephole_opt();
306 
307  // Calculate and decompose effective scales for optimizing matmul calculation
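 // The cell state is expected to use a power-of-two QSYMM16 scale, so log2(scale) is an exact (negative) shift; validate() rejects shifts greater than -9.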
308  const int32_t cell_shift = log2(qcell_state_in.scale);
309 
310  // Calculate quantized parameters for clipping.
311  int16_t quantized_cell_clip = 0;
312  if(lstm_params.cell_clip() > 0.0f)
313  {
314  quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
315  }
316  _has_cell_clipping = quantized_cell_clip > 0;
317 
318  // Precompute effective bias for optimizing the matmul computations.
319  if(!_has_cifg)
320  {
321  _input_to_input_weights = lstm_params.input_to_input_weights();
322  _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
323 
324  _input_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
325  _recurrent_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
326  _input_to_input_reduction->configure(_input_to_input_weights->info(), _input_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
327  _recurrent_to_input_reduction->configure(_recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
328  }
329 
330  _input_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
331  _recurrent_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
332  _input_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
333  _recurrent_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
334  _input_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
335  _recurrent_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
336 
337  _input_to_forget_reduction->configure(input_to_forget_weights->info(), _input_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
338  _recurrent_to_forget_reduction->configure(recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
339  _input_to_cell_reduction->configure(input_to_cell_weights->info(), _input_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
340  _recurrent_to_cell_reduction->configure(recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
341  _input_to_output_reduction->configure(input_to_output_weights->info(), _input_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
342  _recurrent_to_output_reduction->configure(recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
343  if(_has_projection)
344  {
345  _projection_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
346  _projection_reduction->configure(_projection_weights->info(), _projection_eff_bias.info(), GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
347  if(_projection_bias != nullptr)
348  {
349  _projection_bias_add.configure(_projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
350  }
351  }
352 
353  // Pre-transpose weights to be used in GEMM.
354  _transpose_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_transposed);
355  _transpose_input_to_cell_weights.configure(input_to_cell_weights, &_input_to_cell_weights_transposed);
356  _transpose_input_to_output_weights.configure(input_to_output_weights, &_input_to_output_weights_transposed);
357  _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
358  _transpose_recurrent_to_cell_weights.configure(recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
359  _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
360  if(!_has_cifg)
361  {
362  _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
363  _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
364  }
365  if(_has_projection)
366  {
367  _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
368  }
369 
370  GEMMLowpOutputStageInfo gemmlowp_info;
371  gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
372  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
373  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
374  gemmlowp_info.output_data_type = DataType::QSYMM16;
375 
376  const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
377  // Forget gate.
378  const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
379  const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
380  configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
381  input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
382  &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
383  mm_out_info, forget_gate_outstage_info);
384 
385  const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
386  configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
387  output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
388  &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
389  mm_out_info, forget_gate_outstage_info);
390 
391  _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
392  _input_to_forget_outstage_res.allocator()->allocate();
393 
394  if(_has_peephole)
395  {
396  _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
397  _memory_group.manage(&_mul_cell_to_forget_res);
398  _pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
399  _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
400  _memory_group.manage(&_cell_to_forget_outstage_res);
401  const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
402  quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
403  _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
404  _mul_cell_to_forget_res.allocator()->allocate();
405  _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
406  _cell_to_forget_outstage_res.allocator()->allocate();
407  }
408 
409  Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
410 
411  if(_has_layer_norm)
412  {
413  configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
414  forget_activation_input->allocator()->allocate();
415  forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
416  }
417 
418  // Output quantization info of Sigmoid and Tanh activations
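 // Both use QSYMM16 with scale 1/32768 = 2^-15, i.e. Q0.15 fixed point covering [-1, 1).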
419  const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
420  const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
421 
422  _memory_group.manage(&_forget_gate);
423  _forget_gate.allocator()->init(forget_gate_info);
424  _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
425  forget_activation_input->allocator()->allocate();
426 
427  // Modulation gate.
428  const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
429  const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
430  configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
431  input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
432  &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
433  mm_out_info, cell_outstage_info);
434 
435  const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
436  configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
437  output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
438  &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
439  mm_out_info, cell_outstage_info);
440 
441  _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
442  _input_to_cell_outstage_res.allocator()->allocate();
443 
444  Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
445 
446  if(_has_layer_norm)
447  {
448  configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
449  cell_activation_input->allocator()->allocate();
450  cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
451  }
452 
453  const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
454 
455  _memory_group.manage(&_cell_gate);
456  _cell_gate.allocator()->init(cell_gate_info);
457  _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
458  cell_activation_input->allocator()->allocate();
459 
460  // Input gate.
461  const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
462  _input_gate.allocator()->init(input_gate_info);
463  _memory_group.manage(&_input_gate);
464  if(_has_cifg)
465  {
466  _ones.allocator()->init(*_forget_gate.info());
467  _input_gate_sub.configure(&_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
468  _ones.allocator()->allocate();
469  }
470  else
471  {
472  const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
473  const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
474  configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
475  input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
476  &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
477  mm_out_info, input_outstage_info);
478 
479  const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
480  configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
481  output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
482  &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
483  mm_out_info, input_outstage_info);
484  _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
485  _input_to_input_outstage_res.allocator()->allocate();
486 
487  if(_has_peephole)
488  {
489  _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
490  _memory_group.manage(&_mul_cell_to_input_res);
491  _pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
492  const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
493  quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
494  _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
495  _memory_group.manage(&_cell_to_input_outstage_res);
496  _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
497  _mul_cell_to_input_res.allocator()->allocate();
498  _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
499  _cell_to_input_outstage_res.allocator()->allocate();
500  }
501 
502  Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
503 
504  if(_has_layer_norm)
505  {
506  configure_layer_norm(LayerNormGate::Input, input_activation_input);
507  input_activation_input->allocator()->allocate();
508  input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
509  }
510 
511  _input_gate_sigmoid.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
512  input_activation_input->allocator()->allocate();
513  }
514  // Cell.
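 // cell_state_out = forget_gate * cell_state_in + input_gate * cell_gate, computed in QSYMM16.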
515  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
516  _pixelwise_mul_forget_cell.configure(&_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
517  const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
518  const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
519  const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
520  _memory_group.manage(&_mul_input_cell_res);
521  _mul_input_cell_res.allocator()->init(mul_input_cell_info);
522  _pixelwise_mul_input_cell.configure(&_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
523  _cell_gate.allocator()->allocate();
524  _add_forget_cell.configure(&_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
525  _mul_input_cell_res.allocator()->allocate();
526  _forget_gate.allocator()->allocate();
527  if(_has_cell_clipping)
528  {
529  _cell_clip.configure(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
530  }
531  // Output gate.
532  const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
533  const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
534  configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
535  input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
536  &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
537  mm_out_info, output_outstage_info);
538 
539  const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
540  configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
541  output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
542  &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
543  mm_out_info, output_outstage_info);
544 
545  _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
546  _input_to_output_outstage_res.allocator()->allocate();
547 
548  if(_has_peephole)
549  {
550  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
551  // Here we are not using the output stage because all operations are done in float
552  _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
553  _memory_group.manage(&_mul_cell_to_output_res);
554  _pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
555 
556  const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
557  quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
558  _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
559  _memory_group.manage(&_cell_to_output_outstage_res);
560  _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
561  _mul_cell_to_output_res.allocator()->allocate();
562 
563  _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
564  _cell_to_output_outstage_res.allocator()->allocate();
565  }
566 
567  Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
568 
569  if(_has_layer_norm)
570  {
571  configure_layer_norm(LayerNormGate::Output, output_activation_input);
572  output_activation_input->allocator()->allocate();
573  output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
574  }
575  const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
576 
577  _memory_group.manage(&_output_gate);
578  _output_gate.allocator()->init(output_gate_info);
579  _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
580  output_activation_input->allocator()->allocate();
581 
582  // Hidden.
583  _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
584  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
585  _memory_group.manage(&_hidden_mul_res);
586  const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
587  _hidden_mul_res.allocator()->init(hidden_mul_res);
588  _pixelwise_mul_hidden.configure(&_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
589  _output_gate.allocator()->allocate();
590  _input_gate.allocator()->allocate();
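 // Both multiplicands (output-gate sigmoid and cell tanh) are Q0.15 with scale 2^-15, hence the two 2^-15 factors in the effective hidden-state rescaling below.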
591  const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
592  quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
593  gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
594  gemmlowp_info.output_data_type = output_state_in->info()->data_type();
595 
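 // When the projection changes the row width (num_units != output_size), intermediate tensors are staged and copied to/from the output state via TensorCopyKernel.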
596  _projection_tensor_copy_required = (num_units != output_size);
597  ITensor *hidden_gate_result = output_state_out;
598 
599  _memory_group.manage(&_hidden_gate);
600 
601  if(_projection_tensor_copy_required)
602  {
603  _hidden_gate.allocator()->init(*output_state_out->info());
604  _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
605  hidden_gate_result = &_hidden_gate;
606  }
607 
608  _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
609  _hidden_mul_res.allocator()->allocate();
610 
611  // Projection.
612  if(_has_projection)
613  {
614  const TensorInfo projection_outstage_info(*output_state_out->info());
615  const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
616  const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
617  gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
618  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
619  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
620  gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
621 
622  TensorInfo projection_mm_out_info{ mm_out_info };
623  projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
624 
625  configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
626  hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
627  &_mm_projection_res, &_projection_outstage_res, projection_scale,
628  projection_mm_out_info, projection_outstage_info);
629 
630  ITensor *accumulate_destination = output_state_out;
631 
632  if(_projection_tensor_copy_required)
633  {
634  _hidden_gate.allocator()->allocate();
635  _projection_accumulate_res.allocator()->init(*output_state_in->info());
636  _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
637  _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
638  accumulate_destination = &_projection_accumulate_res;
639  }
640 
641  _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
642  _projection_outstage_res.allocator()->allocate();
643 
644  if(_projection_tensor_copy_required)
645  {
646  _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
647  _projection_accumulate_res.allocator()->allocate();
648  }
649 
650  int8_t quantized_projection_clip{ 0 };
651  if(lstm_params.projection_clip() > 0.0f)
652  {
653  quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
654  }
655 
656  if(quantized_projection_clip > 0)
657  {
658  _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip, quantized_projection_clip));
659  _has_projection_clipping = true;
660  }
661  }
662  else
663  {
664  if(_projection_tensor_copy_required)
665  {
666  _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
667  _hidden_gate.allocator()->allocate();
668  }
669  }
670 
671  // Copy output_state_out to output
672  _copy_output.configure(output_state_out, output);
673 }
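// Hedged usage sketch (illustrative only; not part of the original file). The shapes, scales and zero
// points below are assumed values: a real caller must supply quantization info that satisfies validate(),
// and must allocate and fill every tensor before run() is called.
void example_qlstm_usage()
{
    const int input_size = 32, batch_size = 2, num_units = 16, output_size = 16;
    ARM_COMPUTE_UNUSED(output_size);

    Tensor input, input_to_forget_w, input_to_cell_w, input_to_output_w;
    Tensor recurrent_to_forget_w, recurrent_to_cell_w, recurrent_to_output_w;
    Tensor forget_bias, cell_bias, output_bias;
    Tensor cell_state_in, output_state_in, cell_state_out, output_state_out, out;

    input.allocator()->init(TensorInfo(TensorShape(input_size, batch_size), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 128.f, 0)));
    input_to_forget_w.allocator()->init(TensorInfo(TensorShape(input_size, num_units), 1, DataType::QSYMM8, QuantizationInfo(0.002f, 0)));
    // ... initialise the remaining weights, biases and state tensors with matching shapes ...

    LSTMParams<ITensor> params; // no optional tensors set here; see the LSTMParams setters for CIFG/peephole/projection variants
    params.set_matmul_scale_params(0.007f, 0.007f, 0.007f, 0.007f) // assumed gate intermediate scales
          .set_hidden_state_params(0, 0.007f);                     // assumed hidden-state zero point and scale

    NEQLSTMLayer qlstm;
    qlstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w,
                    &recurrent_to_forget_w, &recurrent_to_cell_w, &recurrent_to_output_w,
                    &forget_bias, &cell_bias, &output_bias,
                    &cell_state_in, &output_state_in, &cell_state_out, &output_state_out, &out,
                    params);

    // Allocate backing memory and fill the tensors, then:
    qlstm.run();
}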
674 
675 Status NEQLSTMLayer::validate(const ITensorInfo *input,
676  const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
677  const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
678  const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
679  const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
680  const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
681  const LSTMParams<ITensorInfo> &lstm_params)
682 {
683  ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
684  recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
685  cell_state_out, output_state_out, output);
686 
688  ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
689 
690  const unsigned int input_size = input->dimension(0);
691  const unsigned int batch_size = input->dimension(1);
692  const unsigned int num_units = input_to_output_weights->dimension(1);
693  const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
694 
695  ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
696  ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
697  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
698  ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
699  ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
700  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
702 
703  // If the input_to_forget_weights data type is DataType::QSYMM8 then it can never match the other weights as they are all DataType::QASYMM8_SIGNED
704  if (input_to_forget_weights->data_type() == DataType::QSYMM8)
705  {
706  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_cell_weights, input_to_output_weights,
707  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
708  }
709  else
710  {
711  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
712  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
713  }
714  ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
715  ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
716  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
718  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
719 
720  ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
721  ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
722  ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
724 
725  ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
726  ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
727  ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
729 
730  // Check whether peephole weights are all there or none
731  if(lstm_params.has_peephole_opt())
732  {
735  ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
736  ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
739 
740  if(!lstm_params.has_cifg_opt())
741  {
745  }
746  }
747 
748  const UniformQuantizationInfo qinput = input->quantization_info().uniform();
749  const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
750  const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
751 
752  // Calculate and decompose effective scales for optimizing matmul calculation
753  const int32_t cell_shift = log2(qcell_state_in.scale);
754  ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
755 
756  // Calculate quantized parameters for clipping.
757  int16_t quantized_cell_clip = 0;
758  if(lstm_params.cell_clip() > 0.0f)
759  {
760  quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
761  }
762 
763  // Precompute effective bias for optimizing the matmul computations.
764  const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
765  const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
766  if(!lstm_params.has_cifg_opt())
767  {
768  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
769  -qinput.offset, true)));
770  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
771  -qoutput_state_in.offset,
772  true)));
773  }
774  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
775  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
776  -qoutput_state_in.offset, true)));
777  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
778  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
779  true)));
780  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
781  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
782  -qoutput_state_in.offset, true)));
783  if(lstm_params.has_projection())
784  {
785  ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
786  lstm_params.hidden_state_zero(),
787  true)));
788  if(lstm_params.projection_bias() != nullptr)
789  {
791  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info, &projection_eff_bias_info, ConvertPolicy::SATURATE));
792  }
793  }
794 
795  const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_cell_weights->data_type(), input_to_cell_weights->quantization_info());
796  const TensorInfo input_to_output_weights_transposed(TensorShape(num_units, input_size), 1, input_to_output_weights->data_type(), input_to_output_weights->quantization_info());
797  const TensorInfo recurrent_to_forget_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
798  const TensorInfo recurrent_to_cell_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_cell_weights->data_type(), recurrent_to_cell_weights->quantization_info());
799  const TensorInfo recurrent_to_output_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_output_weights->data_type(), recurrent_to_output_weights->quantization_info());
800  const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
801 
802  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed));
803  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_to_output_weights_transposed));
804  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_to_forget_weights_transposed));
805  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_to_cell_weights_transposed));
806  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_to_output_weights_transposed));
807  if(!lstm_params.has_cifg_opt())
808  {
809  const TensorInfo recurrent_to_input_weights_transposed(TensorShape(num_units, output_size), 1,
810  recurrent_to_forget_weights->data_type(), lstm_params.recurrent_to_input_weights()->quantization_info());
811  const TensorInfo input_to_input_weights_transposed(TensorShape(num_units, input_size), 1,
812  lstm_params.input_to_input_weights()->data_type(), lstm_params.input_to_input_weights()->quantization_info());
813  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_to_input_weights_transposed));
814  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_to_input_weights_transposed));
815  }
816  if(lstm_params.has_projection())
817  {
818  const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
819  ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
820  }
821 
822  GEMMLowpOutputStageInfo gemmlowp_info;
823  gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
824  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
825  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
826  gemmlowp_info.output_data_type = DataType::QSYMM16;
827 
828  const bool has_layer_norm = lstm_params.use_layer_norm();
829 
830  // Forget gate.
832  const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
833  const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
834  const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
835  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
836 
837  const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
838  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
839 
840  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
841 
842  if(lstm_params.has_peephole_opt())
843  {
847  const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
849  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
850  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
851  }
852 
853  if(has_layer_norm)
854  {
855  const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
856  const ITensorInfo *b_info = forget_gate_bias;
857  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
858  }
859 
860  // Output quantization info of Sigmoid and Tanh activations
861  const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
862  const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
863 
865 
866  // Modulation gate.
868  const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
869  const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
870  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
871 
872  const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
873  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
874 
875  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
876 
877  if(has_layer_norm)
878  {
879  const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
880  const ITensorInfo *b_info = cell_bias;
881  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
882  }
883  const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
884 
886 
887  // Input gate.
888  const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
889  if(lstm_params.has_cifg_opt())
890  {
891  ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
892  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
893  }
894  else
895  {
897 
898  // If the input_to_forget_weights data type is DataType::QSYMM8 then it can never match the other weights as they are all DataType::QASYMM8_SIGNED
899  if (input_to_forget_weights->data_type() == DataType::QSYMM8)
900  {
902  }
903  else
904  {
906  }
907  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
908  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
910  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
911 
913  const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
914  const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
915  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
916 
917  const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
918  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
919 
920  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
921 
922  if(lstm_params.has_peephole_opt())
923  {
926  const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
928  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
929  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
930  }
931 
932  if(has_layer_norm)
933  {
934  const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
935  const ITensorInfo *b_info = lstm_params.input_gate_bias();
936  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
937  }
938 
940  }
941  // Cell.
942  ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
944  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
945  if(quantized_cell_clip > 0)
946  {
947  ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
948  quantized_cell_clip)));
949  }
950  // Output gate.
952  const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
953  const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
954  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
955 
956  const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
957  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
958 
959  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
960  if(lstm_params.has_peephole_opt())
961  {
963  // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
964  // Here we are not using the output stage because all operations are done in float
965  // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
966  // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
969  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
970  }
971 
972  if(has_layer_norm)
973  {
974  const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
975  const ITensorInfo *b_info = output_gate_bias;
976  ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
977  }
978 
979  const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
981 
982  // Hidden.
984  const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
985  const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
986  ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
987 
989  const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
990  ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
991  gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
992  gemmlowp_info.output_data_type = hidden_out_info.data_type();
993  ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
994 
995  const bool projection_tensor_copy_required = num_units != output_size;
996 
997  // Projection.
998  if(lstm_params.has_projection())
999  {
1000  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
1001  ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
1002 
1003  const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
1004  const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
1005  ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
1006  gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
1007  gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
1008  gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
1009  gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
1010 
1011  const TensorInfo projection_outstage_info(*output_state_out);
1012  const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
1013 
1014  TensorInfo projection_mm_out_info{ mm_out_info };
1015  projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
1016 
1017  ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
1018  &projection_outstage_info));
1019 
1020  if(projection_tensor_copy_required)
1021  {
1022  ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
1023  }
1024 
1025  ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
1026 
1027  if(projection_tensor_copy_required)
1028  {
1029  ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
1030  }
1031 
1032  int8_t quantized_projection_clip{ 0 };
1033  if(lstm_params.projection_clip() > 0.0f)
1034  {
1035  quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
1036  }
1037 
1038  if(quantized_projection_clip > 0)
1039  {
1040  ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
1041  quantized_projection_clip)));
1042  }
1043  }
1044  else
1045  {
1046  if(projection_tensor_copy_required)
1047  {
1048  ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
1049  }
1050  }
1051 
1052  if(cell_state_out->total_size() > 0)
1053  {
1054  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
1055  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
1056  }
1057 
1058  if(output_state_out->total_size() > 0)
1059  {
1060  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
1061  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
1062  }
1063 
1064  ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(output_state_out, output));
1065  return Status{};
1066 }
1067 
1068  void NEQLSTMLayer::run()
1069  {
1070  prepare();
1071 
1072  // Acquire all the temporaries
1073  MemoryGroupResourceScope scope_mg(_memory_group);
1074 
1075  // Forget gate.
1076  _mm_input_to_forget.run();
1077  _input_to_forget_outstage.run();
1078 
1079  _mm_recurrent_to_forget.run();
1080  _recurrent_to_forget_outstage.run();
1081  _accumulate_input_recurrent_forget.run();
1082 
1083  if(_has_peephole)
1084  {
1085  _pixelwise_mul_cell_to_forget.run();
1086  _cell_to_forget_outstage.run();
1087  _accumulate_cell_forget.run();
1088  }
1089 
1090  if(_has_layer_norm)
1091  {
1092  NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Forget).get(), Window::DimY);
1093  }
1094 
1095  _forget_gate_sigmoid.run();
1096 
1097  // Modulation gate.
1098  _mm_input_to_cell.run();
1099  _input_to_cell_outstage.run();
1100 
1101  _mm_recurrent_to_cell.run();
1102  _recurrent_to_cell_outstage.run();
1103  _accumulate_input_recurrent_modulation.run();
1104 
1105  if(_has_layer_norm)
1106  {
1107  NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Cell).get(), Window::DimY);
1108  }
1109 
1110  _cell_gate_tanh.run();
1111 
1112  // Input gate
1113  if(_has_cifg)
1114  {
1115  _input_gate_sub.run();
1116  }
1117  else
1118  {
1119  _mm_input_to_input.run();
1120  _input_to_input_outstage.run();
1121  _mm_recurrent_to_input.run();
1122  _recurrent_to_input_outstage.run();
1123  _accumulate_input_recurrent_input.run();
1124 
1125  if(_has_peephole)
1126  {
1127  _pixelwise_mul_cell_to_input.run();
1128  _cell_to_input_outstage.run();
1129  _accumulate_cell_input.run();
1130  }
1131 
1132  if(_has_layer_norm)
1133  {
1134  NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Input).get(), Window::DimY);
1135  }
1136 
1137  _input_gate_sigmoid.run();
1138  }
1139 
1140  // Cell.
1141  _pixelwise_mul_forget_cell.run();
1142  _pixelwise_mul_input_cell.run();
1143  _add_forget_cell.run();
1144 
1145  if(_has_cell_clipping)
1146  {
1147  _cell_clip.run();
1148  }
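 // At this point cell_state_out holds forget_gate * cell_state_in + input_gate * cell_gate,
 // clamped to the configured cell clip when clipping is enabled.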
1149 
1150  // Output gate.
1151  _mm_input_to_output.run();
1152  _input_to_output_outstage.run();
1153  _mm_recurrent_to_output.run();
1154  _recurrent_to_output_outstage.run();
1155  _accumulate_input_recurrent_output.run();
1156  if(_has_peephole)
1157  {
1158  _pixelwise_mul_cell_to_output.run();
1159  _cell_to_output_outstage.run();
1160  _accumulate_cell_to_output.run();
1161  }
1162 
1163  if(_has_layer_norm)
1164  {
1165  NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Output).get(), Window::DimY);
1166  }
1167 
1168  _output_gate_sigmoid.run();
1169 
1170  // Hidden.
1171  _hidden_tanh.run();
1172  _pixelwise_mul_hidden.run();
1173  _hidden_outstage.run();
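 // The hidden state is tanh(cell_state) scaled element-wise by the output gate and requantized
 // to QASYMM8_SIGNED by the output stage above.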
1174 
1175  // Projection.
1176  if(_has_projection)
1177  {
1178  _mm_projection.run();
1179  _projection_outstage.run();
1180 
1181  if(_projection_tensor_copy_required)
1182  {
1183  _projection_output_to_accumulate_copy.run();
1184  }
1185 
1186  _accumulate_projection.run();
1187 
1188  if(_projection_tensor_copy_required)
1189  {
1190  _projection_accumulate_to_output_copy.run();
1191  }
1192 
1193  if(_has_projection_clipping)
1194  {
1195  _projection_clip.run();
1196  }
1197  }
1198  else
1199  {
1200  if(_projection_tensor_copy_required)
1201  {
1202  _hidden_to_output_copy.run();
1203  }
1204  }
1205 
1206  // Copy output_state_out to output
1207  _copy_output.run();
1208 }
1209 
1210  void NEQLSTMLayer::prepare()
1211  {
1212  if(!_is_prepared)
1213  {
1214  if(_convert_input_to_forget_weights_to_qsymm8)
1215  {
1216  _input_to_forget_weights_f32.allocator()->allocate();
1217  _input_to_forget_weights_symm8.allocator()->allocate();
1218  _dequantize_input_to_forget_weights.run();
1219  _quantize_input_to_forget_weights.run();
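 // When weight conversion is required, the input-to-forget weights are re-expressed as QSYMM8 through
 // an F32 intermediate (dequantize, then quantize) before being transposed below.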
1220  }
1221 
1222  // Pre-transpose weights to be used in GEMM.
1223  _input_to_forget_weights_transposed.allocator()->allocate();
1224  _input_to_cell_weights_transposed.allocator()->allocate();
1225  _input_to_output_weights_transposed.allocator()->allocate();
1226  _recurrent_to_forget_weights_transposed.allocator()->allocate();
1227  _recurrent_to_cell_weights_transposed.allocator()->allocate();
1228  _recurrent_to_output_weights_transposed.allocator()->allocate();
1229  _transpose_input_to_forget_weights.run();
1230  _transpose_input_to_cell_weights.run();
1231  _transpose_input_to_output_weights.run();
1232  _transpose_recurrent_to_forget_weights.run();
1233  _transpose_recurrent_to_cell_weights.run();
1234  _transpose_recurrent_to_output_weights.run();
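 // The weights are transposed once here so that every call to run() can feed the GEMMs these
 // pre-transposed copies; the original weight tensors are marked as unused further down.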
1235 
1236  // Precompute effective biases
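 // The reductions below precompute the offset-correction term of each quantized matrix multiplication
 // (weight sums along the input dimension, scaled by the negated zero-point of the tensor they multiply);
 // the results are consumed as the bias of the corresponding GEMMLowp output stage at run time.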
1237  if(_has_cifg)
1238  {
1239  std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
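 // 32767 is ~1.0 in Q0.15, so with CIFG the input gate is later derived as 1 - forget_gate.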
1240  }
1241  else
1242  {
1243  _input_to_input_eff_bias.allocator()->allocate();
1244  _recurrent_to_input_eff_bias.allocator()->allocate();
1245 
1246  ITensorPack packII =
1247  {
1248  { TensorType::ACL_SRC, _input_to_input_weights },
1249  { TensorType::ACL_DST, &_input_to_input_eff_bias }
1250  };
1251  NEScheduler::get().schedule_op(_input_to_input_reduction.get(), Window::DimY, _input_to_input_reduction->window(), packII);
1252 
1253  ITensorPack packRI =
1254  {
1255  { TensorType::ACL_SRC, _recurrent_to_input_weights },
1256  { TensorType::ACL_DST, &_recurrent_to_input_eff_bias }
1257  };
1258  NEScheduler::get().schedule_op(_recurrent_to_input_reduction.get(), Window::DimY, _recurrent_to_input_reduction->window(), packRI);
1259 
1260  _input_to_input_weights_transposed.allocator()->allocate();
1261  _recurrent_to_input_weights_transposed.allocator()->allocate();
1262  _transpose_input_to_input_weights.run();
1263  _transpose_recurrent_to_input_weights.run();
1264  _input_to_input_weights->mark_as_unused();
1265  _recurrent_to_input_weights->mark_as_unused();
1266  }
1267  _input_to_forget_eff_bias.allocator()->allocate();
1268  _recurrent_to_forget_eff_bias.allocator()->allocate();
1269  _input_to_cell_eff_bias.allocator()->allocate();
1270  _recurrent_to_cell_eff_bias.allocator()->allocate();
1271  _input_to_output_eff_bias.allocator()->allocate();
1272  _recurrent_to_output_eff_bias.allocator()->allocate();
1273 
1274  ITensorPack packIF =
1275  {
1276  { TensorType::ACL_SRC, _input_to_forget_weights },
1277  { TensorType::ACL_DST, &_input_to_forget_eff_bias }
1278  };
1279  NEScheduler::get().schedule_op(_input_to_forget_reduction.get(), Window::DimY, _input_to_forget_reduction->window(), packIF);
1280 
1281  ITensorPack packRF =
1282  {
1283  { TensorType::ACL_SRC, _recurrent_to_forget_weights },
1284  { TensorType::ACL_DST, &_recurrent_to_forget_eff_bias }
1285  };
1286  NEScheduler::get().schedule_op(_recurrent_to_forget_reduction.get(), Window::DimY, _recurrent_to_forget_reduction->window(), packRF);
1287 
1288  ITensorPack packIC =
1289  {
1290  { TensorType::ACL_SRC, _input_to_cell_weights },
1291  { TensorType::ACL_DST, &_input_to_cell_eff_bias }
1292  };
1293  NEScheduler::get().schedule_op(_input_to_cell_reduction.get(), Window::DimY, _input_to_cell_reduction->window(), packIC);
1294 
1295  ITensorPack packRC =
1296  {
1297  { TensorType::ACL_SRC, _recurrent_to_cell_weights },
1298  { TensorType::ACL_DST, &_recurrent_to_cell_eff_bias }
1299  };
1300  NEScheduler::get().schedule_op(_recurrent_to_cell_reduction.get(), Window::DimY, _recurrent_to_cell_reduction->window(), packRC);
1301 
1302  ITensorPack packIO =
1303  {
1304  { TensorType::ACL_SRC, _input_to_output_weights },
1305  { TensorType::ACL_DST, &_input_to_output_eff_bias }
1306  };
1307  NEScheduler::get().schedule_op(_input_to_output_reduction.get(), Window::DimY, _input_to_output_reduction->window(), packIO);
1308 
1309  ITensorPack packRO =
1310  {
1311  { TensorType::ACL_SRC, _recurrent_to_output_weights },
1312  { TensorType::ACL_DST, &_recurrent_to_output_eff_bias }
1313  };
1314  NEScheduler::get().schedule_op(_recurrent_to_output_reduction.get(), Window::DimY, _recurrent_to_output_reduction->window(), packRO);
1315 
1316  if(_has_projection)
1317  {
1318  _projection_eff_bias.allocator()->allocate();
1319  ITensorPack pack =
1320  {
1321  { TensorType::ACL_SRC, _projection_weights },
1322  { TensorType::ACL_DST, &_projection_eff_bias }
1323  };
1324  NEScheduler::get().schedule_op(_projection_reduction.get(), Window::DimY, _projection_reduction->window(), pack);
1325  if(_projection_bias != nullptr)
1326  {
1327  _projection_bias_add.run();
1328  _projection_bias->mark_as_unused();
1329  }
1330 
1331  _projection_weights_transposed.allocator()->allocate();
1332  _transpose_projection_weights.run();
1333  _projection_weights->mark_as_unused();
1334 
1335  if(!_projection_tensor_copy_required)
1336  {
1337  _hidden_gate.mark_as_unused();
1338  _projection_accumulate_res.mark_as_unused();
1339  }
1340  }
1341 
1342  // Mark weights as unused
1343  _input_to_forget_weights->mark_as_unused();
1344  _input_to_cell_weights->mark_as_unused();
1345  _input_to_output_weights->mark_as_unused();
1346  _recurrent_to_forget_weights->mark_as_unused();
1347  _recurrent_to_cell_weights->mark_as_unused();
1348  _recurrent_to_output_weights->mark_as_unused();
1349 
1350  _is_prepared = true;
1351  }
1352 }
1353 } // namespace arm_compute