Compute Library 23.05
LSTMParams.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_LSTMPARAMS_H
25 #define ARM_COMPUTE_LSTMPARAMS_H
26 
27 #include "arm_compute/core/Types.h"
29 
30 #include <cstddef>
31 #include <memory>
32 
33 namespace arm_compute
34 {
/** Parameter container for LSTM layer functions.
 *
 * Collects the optional tensor groups (CIFG input-gate tensors, peephole, projection and
 * layer-normalization weights) and the scalar settings (clip values, intermediate matmul scales,
 * hidden-state zero point and scale) that an LSTM layer may be configured with. Every set_*()
 * method returns a reference to this object so calls can be chained.
 *
 * Only raw pointers are stored and the destructor is defaulted, so this object never releases
 * the tensors it points to: the caller must keep them alive for as long as the pointers are used.
 *
 * @tparam T Type of the referenced tensor objects (presumably a tensor or tensor-info type;
 *           the concrete type is chosen by users of this header — not visible here).
 */
template <typename T>
class LSTMParams
{
public:
    /** Constructor.
     *
     * All tensor pointers start as nullptr and all scalar settings start at zero. Every feature
     * flag starts cleared except has_cifg_opt(), which starts as true and is only cleared by
     * set_cifg_params().
     */
    LSTMParams()
        : _input_to_input_weights(nullptr),
          _recurrent_to_input_weights(nullptr),
          _cell_to_input_weights(nullptr),
          _input_gate_bias(nullptr),
          _cell_to_forget_weights(nullptr),
          _cell_to_output_weights(nullptr),
          _projection_weights(nullptr),
          _projection_bias(nullptr),
          _input_layer_norm_weights(nullptr),
          _forget_layer_norm_weights(nullptr),
          _cell_layer_norm_weights(nullptr),
          _output_layer_norm_weights(nullptr),
          _cell_clip(0.0f),
          _projection_clip(0.0f),
          _input_intermediate_scale(0.0f),
          _forget_intermediate_scale(0.0f),
          _cell_intermediate_scale(0.0f),
          _output_intermediate_scale(0.0f),
          _hidden_state_zero(0),
          _hidden_state_scale(0.0f),
          _has_peephole_opt(false),
          _has_projection(false),
          _has_cifg_opt(true),
          _use_layer_norm(false)
    {
    }
    /** Prevent instances of this class from being copied (as this class contains pointers) */
    LSTMParams(const LSTMParams &) = delete;
    /** Prevent instances of this class from being copied (as this class contains pointers) */
    LSTMParams &operator=(const LSTMParams &) = delete;
    /** Default destructor. It does not release any of the stored pointers. */
    ~LSTMParams() = default;
    /** Set CIFG tensor parameters.
     *
     * Only the pointers are stored. Calling this clears the CIFG flag, so has_cifg_opt()
     * returns false afterwards.
     *
     * @param[in] input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input_to_input_weights.
     * @param[in] cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input_to_input_weights.
     * @param[in] input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_to_input_weights, S32 when @p input_to_input_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
    {
        _input_to_input_weights     = input_to_input_weights;
        _recurrent_to_input_weights = recurrent_to_input_weights;
        _cell_to_input_weights      = cell_to_input_weights;
        _input_gate_bias            = input_gate_bias;
        _has_cifg_opt               = false;
        return *this;
    }
    /** Set projection tensor parameters.
     *
     * Only the pointers are stored. Calling this sets the flag reported by has_projection().
     *
     * @param[in] projection_weights 2D weights tensor with dimensions [output_size, num_units]. Data types supported: QSYMM8/F16/F32.
     * @param[in] projection_bias    1D weights tensor with dimensions [output_size]. Data type supported: Same as @p projection_weights, S32 when @p projection_weights is QSYMM8.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_params(const T *projection_weights, const T *projection_bias)
    {
        _projection_weights = projection_weights;
        _projection_bias    = projection_bias;
        _has_projection     = true;
        return *this;
    }
    /** Set peephole tensor parameters.
     *
     * Only the pointers are stored. Calling this sets the flag reported by has_peephole_opt().
     *
     * @param[in] cell_to_forget_weights 1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] cell_to_output_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p cell_to_forget_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
    {
        _cell_to_forget_weights = cell_to_forget_weights;
        _cell_to_output_weights = cell_to_output_weights;
        _has_peephole_opt       = true;
        return *this;
    }
    /** Set layer normalization tensor parameters.
     *
     * Only the pointers are stored. Calling this sets the flag reported by use_layer_norm().
     *
     * @param[in] input_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data types supported: QSYMM16/F16/F32.
     * @param[in] forget_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] cell_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     * @param[in] output_layer_norm_weights 1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input_layer_norm_weights.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
                                               T *cell_layer_norm_weights, T *output_layer_norm_weights)
    {
        _input_layer_norm_weights  = input_layer_norm_weights;
        _forget_layer_norm_weights = forget_layer_norm_weights;
        _cell_layer_norm_weights   = cell_layer_norm_weights;
        _output_layer_norm_weights = output_layer_norm_weights;
        _use_layer_norm            = true;
        return *this;
    }

    /** Set cell clip value.
     *
     * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_cell_clip_params(float cell_clip)
    {
        _cell_clip = cell_clip;
        return *this;
    }

    /** Set projection clip value.
     *
     * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_projection_clip_params(float projection_clip)
    {
        _projection_clip = projection_clip;
        return *this;
    }

    /** Set the scales of the intermediate matmul results of each gate.
     *
     * @param[in] input_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     * @param[in] forget_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     * @param[in] cell_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     * @param[in] output_intermediate_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
    {
        _input_intermediate_scale  = input_intermediate_scale;
        _forget_intermediate_scale = forget_intermediate_scale;
        _cell_intermediate_scale   = cell_intermediate_scale;
        _output_intermediate_scale = output_intermediate_scale;
        return *this;
    }

    /** Set hidden state zero and scale parameters.
     *
     * @param[in] hidden_state_zero  The zero point of the hidden state.
     * @param[in] hidden_state_scale The scale of the hidden state.
     *
     * @return Reference to this LSTMParams object
     */
    LSTMParams &set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
    {
        _hidden_state_zero  = hidden_state_zero;
        _hidden_state_scale = hidden_state_scale;
        return *this;
    }

    /** @return Input-to-input weights stored by set_cifg_params() (nullptr by default). */
    const T *input_to_input_weights() const
    {
        return _input_to_input_weights;
    }

    /** @return Recurrent-to-input weights stored by set_cifg_params() (nullptr by default). */
    const T *recurrent_to_input_weights() const
    {
        return _recurrent_to_input_weights;
    }

    /** @return Cell-to-input weights stored by set_cifg_params() (nullptr by default). */
    T *cell_to_input_weights() const
    {
        return _cell_to_input_weights;
    }

    /** @return Input gate bias stored by set_cifg_params() (nullptr by default). */
    const T *input_gate_bias() const
    {
        return _input_gate_bias;
    }

    /** @return Cell-to-forget weights stored by set_peephole_params() (nullptr by default). */
    T *cell_to_forget_weights() const
    {
        return _cell_to_forget_weights;
    }

    /** @return Cell-to-output weights stored by set_peephole_params() (nullptr by default). */
    T *cell_to_output_weights() const
    {
        return _cell_to_output_weights;
    }

    /** @return Projection weights stored by set_projection_params() (nullptr by default). */
    const T *projection_weights() const
    {
        return _projection_weights;
    }

    /** @return Projection bias stored by set_projection_params() (nullptr by default). */
    const T *projection_bias() const
    {
        return _projection_bias;
    }

    /** @return Input layer-norm weights stored by set_layer_normalization_params() (nullptr by default). */
    T *input_layer_norm_weights() const
    {
        return _input_layer_norm_weights;
    }

    /** @return Forget layer-norm weights stored by set_layer_normalization_params() (nullptr by default). */
    T *forget_layer_norm_weights() const
    {
        return _forget_layer_norm_weights;
    }

    /** @return Cell layer-norm weights stored by set_layer_normalization_params() (nullptr by default). */
    T *cell_layer_norm_weights() const
    {
        return _cell_layer_norm_weights;
    }

    /** @return Output layer-norm weights stored by set_layer_normalization_params() (nullptr by default). */
    T *output_layer_norm_weights() const
    {
        return _output_layer_norm_weights;
    }

    /** @return Cell clip value stored by set_cell_clip_params() (0 by default). */
    float cell_clip() const
    {
        return _cell_clip;
    }

    /** @return Projection clip value stored by set_projection_clip_params() (0 by default). */
    float projection_clip() const
    {
        return _projection_clip;
    }

    /** @return Input-gate intermediate scale stored by set_matmul_scale_params() (0 by default). */
    float input_intermediate_scale() const
    {
        return _input_intermediate_scale;
    }

    /** @return Forget-gate intermediate scale stored by set_matmul_scale_params() (0 by default). */
    float forget_intermediate_scale() const
    {
        return _forget_intermediate_scale;
    }

    /** @return Cell-gate intermediate scale stored by set_matmul_scale_params() (0 by default). */
    float cell_intermediate_scale() const
    {
        return _cell_intermediate_scale;
    }

    /** @return Output-gate intermediate scale stored by set_matmul_scale_params() (0 by default). */
    float output_intermediate_scale() const
    {
        return _output_intermediate_scale;
    }

    /** @return Hidden-state zero point stored by set_hidden_state_params() (0 by default). */
    int32_t hidden_state_zero() const
    {
        return _hidden_state_zero;
    }

    /** @return Hidden-state scale stored by set_hidden_state_params() (0 by default). */
    float hidden_state_scale() const
    {
        return _hidden_state_scale;
    }

    /** @return true once set_peephole_params() has been called. */
    bool has_peephole_opt() const
    {
        return _has_peephole_opt;
    }

    /** @return true once set_projection_params() has been called. */
    bool has_projection() const
    {
        return _has_projection;
    }

    /** @return true until set_cifg_params() is called, false afterwards. */
    bool has_cifg_opt() const
    {
        return _has_cifg_opt;
    }

    /** @return true once set_layer_normalization_params() has been called. */
    bool use_layer_norm() const
    {
        return _use_layer_norm;
    }

private:
    // Declaration order must match the constructor's initializer list.
    // Input-gate tensors (set_cifg_params()).
    const T *_input_to_input_weights;
    const T *_recurrent_to_input_weights;
    T       *_cell_to_input_weights;
    const T *_input_gate_bias;
    // Peephole tensors (set_peephole_params()).
    T *_cell_to_forget_weights;
    T *_cell_to_output_weights;
    // Projection tensors (set_projection_params()).
    const T *_projection_weights;
    const T *_projection_bias;
    // Layer-normalization weights (set_layer_normalization_params()).
    T *_input_layer_norm_weights;
    T *_forget_layer_norm_weights;
    T *_cell_layer_norm_weights;
    T *_output_layer_norm_weights;
    // Scalar settings.
    float   _cell_clip;
    float   _projection_clip;
    float   _input_intermediate_scale;
    float   _forget_intermediate_scale;
    float   _cell_intermediate_scale;
    float   _output_intermediate_scale;
    int32_t _hidden_state_zero;
    float   _hidden_state_scale;
    // Feature flags.
    bool _has_peephole_opt;
    bool _has_projection;
    bool _has_cifg_opt;
    bool _use_layer_norm;
};
341 }
342 #endif /*ARM_COMPUTE_LSTMPARAMS_H */
~LSTMParams()=default
Default destructor.
const T * projection_weights() const
Definition: LSTMParams.h:225
const T * input_to_input_weights() const
Definition: LSTMParams.h:195
bool use_layer_norm() const
Definition: LSTMParams.h:310
bool has_peephole_opt() const
Definition: LSTMParams.h:295
LSTMParams & set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
Set CIFG tensor parameters.
Definition: LSTMParams.h:82
LSTMParams & set_cell_clip_params(float cell_clip)
Set cell clip value.
Definition: LSTMParams.h:145
T * forget_layer_norm_weights() const
Definition: LSTMParams.h:240
float output_intermediate_scale() const
Definition: LSTMParams.h:280
bool has_cifg_opt() const
Definition: LSTMParams.h:305
float cell_intermediate_scale() const
Definition: LSTMParams.h:275
float forget_intermediate_scale() const
Definition: LSTMParams.h:270
LSTMParams & set_hidden_state_params(int32_t hidden_state_zero, float hidden_state_scale)
Set hidden state zero and scale parameters.
Definition: LSTMParams.h:188
T * cell_to_input_weights() const
Definition: LSTMParams.h:205
Copyright (c) 2017-2023 Arm Limited.
const T * recurrent_to_input_weights() const
Definition: LSTMParams.h:200
int32_t hidden_state_zero() const
Definition: LSTMParams.h:285
const T * projection_bias() const
Definition: LSTMParams.h:230
T * output_layer_norm_weights() const
Definition: LSTMParams.h:250
float input_intermediate_scale() const
Definition: LSTMParams.h:265
float hidden_state_scale() const
Definition: LSTMParams.h:290
LSTMParams & set_matmul_scale_params(float input_intermediate_scale, float forget_intermediate_scale, float cell_intermediate_scale, float output_intermediate_scale)
Set scale of the intermediate results of matmul of each layer parameters.
Definition: LSTMParams.h:172
LSTMParams()
Constructor.
Definition: LSTMParams.h:40
LSTMParams & set_projection_params(const T *projection_weights, const T *projection_bias)
Set projection tensor parameters.
Definition: LSTMParams.h:98
float cell_clip() const
Definition: LSTMParams.h:255
LSTMParams & set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights, T *cell_layer_norm_weights, T *output_layer_norm_weights)
Set layer normalization tensor parameters.
Definition: LSTMParams.h:128
T * cell_to_forget_weights() const
Definition: LSTMParams.h:215
LSTMParams & set_projection_clip_params(float projection_clip)
Set projection clip value.
Definition: LSTMParams.h:157
LSTMParams & set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
Set peephole tensor parameters.
Definition: LSTMParams.h:112
bool has_projection() const
Definition: LSTMParams.h:300
float projection_clip() const
Definition: LSTMParams.h:260
T * cell_to_output_weights() const
Definition: LSTMParams.h:220
T * input_layer_norm_weights() const
Definition: LSTMParams.h:235
const T * input_gate_bias() const
Definition: LSTMParams.h:210
T * cell_layer_norm_weights() const
Definition: LSTMParams.h:245
LSTMParams & operator=(const LSTMParams &)=delete
Prevent instances of this class from being copied (as this class contains pointers).
Definition: LSTMParams.h:70