Compute Library
 21.02
LSTMParams.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_LSTMPARAMS_H
25 #define ARM_COMPUTE_LSTMPARAMS_H
26 
29 #include "arm_compute/core/Types.h"
31 
32 #include <cstddef>
33 #include <memory>
34 
namespace arm_compute
{
/** Parameter pack describing the optional inputs of an LSTM layer.
 *
 * Stores raw pointers to the optional tensors (input gate / CIFG, peephole,
 * projection and layer-normalization tensors) together with scalar clipping,
 * intermediate-scale and hidden-state quantization parameters. Every setter
 * returns a reference to this object so calls can be chained.
 *
 * Default state: no optional tensor is set (all pointers are nullptr), all
 * scalars are zero, has_cifg_opt() is true and the peephole / projection /
 * layer-norm flags are false.
 *
 * @note The destructor is defaulted and never releases the stored pointers:
 *       the caller keeps ownership of every tensor passed to a setter and must
 *       keep it alive for as long as the pointers returned by the getters are used.
 *
 * @tparam T Tensor type referred to by the stored pointers.
 */
template <typename T>
class LSTMParams
{
public:
    /** Constructor: creates an empty parameter pack (see class description for defaults). */
    LSTMParams() = default;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor */
    ~LSTMParams() = default;

    /** Set CIFG tensor parameters.
     *
     * Supplying the input gate tensors disables the CIFG optimisation:
     * has_cifg_opt() returns false after this call.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        // An explicit input gate is now available, so the CIFG path is no longer used.
        _has_cifg_opt = false;
        return *this;
    }

    /** Set projection tensor parameters and enable projection (has_projection() returns true afterwards).
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }

    /** Set peephole tensor parameters and enable peephole connections (has_peephole_opt() returns true afterwards).
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }

    /** Set layer normalization tensor parameters and enable layer normalization (use_layer_norm() returns true afterwards).
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set scale of the intermediate results of matmul of each layer parameters.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale,
                                        float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }

    /** Set hidden state zero and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }

    /** @return Input-to-input weights passed to set_cifg_params(), or nullptr if never set. */
    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    /** @return Recurrent-to-input weights passed to set_cifg_params(), or nullptr if never set. */
    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    /** @return Cell-to-input weights passed to set_cifg_params(), or nullptr if never set. */
    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    /** @return Input gate bias passed to set_cifg_params(), or nullptr if never set. */
    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    /** @return Cell-to-forget weights passed to set_peephole_params(), or nullptr if never set. */
    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    /** @return Cell-to-output weights passed to set_peephole_params(), or nullptr if never set. */
    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    /** @return Projection weights passed to set_projection_params(), or nullptr if never set. */
    const T *projection_weights() const
    {
        return _projection_weights;
    }

    /** @return Projection bias passed to set_projection_params(), or nullptr if never set. */
    const T *projection_bias() const
    {
        return _projection_bias;
    }

    /** @return Input layer-norm weights passed to set_layer_normalization_params(), or nullptr if never set. */
    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    /** @return Forget layer-norm weights passed to set_layer_normalization_params(), or nullptr if never set. */
    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    /** @return Cell layer-norm weights passed to set_layer_normalization_params(), or nullptr if never set. */
    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    /** @return Output layer-norm weights passed to set_layer_normalization_params(), or nullptr if never set. */
    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    /** @return Cell clip value passed to set_cell_clip_params() (0 by default). */
    float cell_clip() const
    {
        return _cell_clip;
    }

    /** @return Projection clip value passed to set_projection_clip_params() (0 by default). */
    float projection_clip() const
    {
        return _projection_clip;
    }

    /** @return Input gate intermediate scale passed to set_matmul_scale_params() (0 by default). */
    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    /** @return Forget gate intermediate scale passed to set_matmul_scale_params() (0 by default). */
    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    /** @return Cell gate intermediate scale passed to set_matmul_scale_params() (0 by default). */
    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    /** @return Output gate intermediate scale passed to set_matmul_scale_params() (0 by default). */
    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    /** @return Hidden state zero point passed to set_hidden_state_params() (0 by default). */
    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    /** @return Hidden state scale passed to set_hidden_state_params() (0 by default). */
    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    /** @return true once set_peephole_params() has been called, false otherwise. */
    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    /** @return true once set_projection_params() has been called, false otherwise. */
    bool has_projection() const
    {
        return _has_projection;
    }

    /** @return true unless set_cifg_params() has been called. */
    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    /** @return true once set_layer_normalization_params() has been called, false otherwise. */
    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }

private:
    // Non-owning tensor pointers; nullptr until the matching setter is called.
    const T *_input_to_input_weights{ nullptr };
    const T *_recurrent_to_input_weights{ nullptr };
    T       *_cell_to_input_weights{ nullptr };
    const T *_input_gate_bias{ nullptr };
    T       *_cell_to_forget_weights{ nullptr };
    T       *_cell_to_output_weights{ nullptr };
    const T *_projection_weights{ nullptr };
    const T *_projection_bias{ nullptr };
    T       *_input_layer_norm_weights{ nullptr };
    T       *_forget_layer_norm_weights{ nullptr };
    T       *_cell_layer_norm_weights{ nullptr };
    T       *_output_layer_norm_weights{ nullptr };
    // Scalar parameters, all zero by default.
    float   _cell_clip{ 0.0f };
    float   _projection_clip{ 0.0f };
    float   _input_intermediate_scale{ 0.0f };
    float   _forget_intermediate_scale{ 0.0f };
    float   _cell_intermediate_scale{ 0.0f };
    float   _output_intermediate_scale{ 0.0f };
    int32_t _hidden_state_zero{ 0 };
    float   _hidden_state_scale{ 0.0f };
    // Feature flags; CIFG is on until input gate tensors are supplied.
    bool _has_peephole_opt{ false };
    bool _has_projection{ false };
    bool _has_cifg_opt{ true };
    bool _use_layer_norm{ false };
};
} // namespace arm_compute
344 #endif /*ARM_COMPUTE_LSTMPARAMS_H */
~LSTMParams()=default
Default destructor.
const T * projection_weights() const
Definition: LSTMParams.h:227
const T * input_to_input_weights() const
Definition: LSTMParams.h:197
bool use_layer_norm() const
Definition: LSTMParams.h:312
bool has_peephole_opt() const
Definition: LSTMParams.h:297
LSTMParams & set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
Set CIFG tensor parameters.
Definition: LSTMParams.h:84
LSTMParams & set_cell_clip_params(float cell_clip)
Set cell clip value.
Definition: LSTMParams.h:147
T * forget_layer_norm_weights() const
Definition: LSTMParams.h:242
float output_intermediate_scale() const
Definition: LSTMParams.h:282
bool has_cifg_opt() const
Definition: LSTMParams.h:307
float cell_intermediate_scale() const
Definition: LSTMParams.h:277
float forget_intermediate_scale() const
Definition: LSTMParams.h:272
LSTMParams & set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
Set hidden state zero and scale parameters.
Definition: LSTMParams.h:190
T * cell_to_input_weights() const
Definition: LSTMParams.h:207
Copyright (c) 2017-2021 Arm Limited.
const T * recurrent_to_input_weights() const
Definition: LSTMParams.h:202
int32_t hidden_state_zero() const
Definition: LSTMParams.h:287
const T * projection_bias() const
Definition: LSTMParams.h:232
T * output_layer_norm_weights() const
Definition: LSTMParams.h:252
float input_intermediate_scale() const
Definition: LSTMParams.h:267
float hidden_state_scale() const
Definition: LSTMParams.h:292
LSTMParams & set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
Set scale of the intermediate results of matmul of each layer parameters.
Definition: LSTMParams.h:174
LSTMParams()
Constructor.
Definition: LSTMParams.h:42
LSTMParams & set_projection_params(const T *projection_weights, const T *projection_bias)
Set projection tensor parameters.
Definition: LSTMParams.h:100
float cell_clip() const
Definition: LSTMParams.h:257
LSTMParams & set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights, T *cell_layer_norm_weights, T *output_layer_norm_weights)
Set layer normalization tensor parameters.
Definition: LSTMParams.h:130
T * cell_to_forget_weights() const
Definition: LSTMParams.h:217
LSTMParams & set_projection_clip_params(float projection_clip)
Set projection clip value.
Definition: LSTMParams.h:159
LSTMParams & set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
Set peephole tensor parameters.
Definition: LSTMParams.h:114
bool has_projection() const
Definition: LSTMParams.h:302
float projection_clip() const
Definition: LSTMParams.h:262
T * cell_to_output_weights() const
Definition: LSTMParams.h:222
T * input_layer_norm_weights() const
Definition: LSTMParams.h:237
const T * input_gate_bias() const
Definition: LSTMParams.h:212
T * cell_layer_norm_weights() const
Definition: LSTMParams.h:247
LSTMParams & operator=(const LSTMParams &)=delete
Prevent instances of this class from being copied (As this class contains pointers) ...