ArmNN
 25.11
Loading...
Searching...
No Matches
QuantizedLstmParams.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "TensorFwd.hpp"
8#include "Exceptions.hpp"
9
10namespace armnn
11{
12
14{
16 : m_InputToInputWeights(nullptr)
17 , m_InputToForgetWeights(nullptr)
18 , m_InputToCellWeights(nullptr)
19 , m_InputToOutputWeights(nullptr)
20
25
26 , m_InputGateBias(nullptr)
27 , m_ForgetGateBias(nullptr)
28 , m_CellBias(nullptr)
29 , m_OutputGateBias(nullptr)
30 {
31 }
32
37
42
47
48 const ConstTensor& Deref(const ConstTensor* tensorPtr) const
49 {
50 if (tensorPtr != nullptr)
51 {
52 const ConstTensor &temp = *tensorPtr;
53 return temp;
54 }
55 throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer");
56 }
57
59 {
61 }
62
67
69 {
71 }
72
77
82
87
92
97
99 {
100 return Deref(m_InputGateBias);
101 }
102
104 {
105 return Deref(m_ForgetGateBias);
106 }
107
109 {
110 return Deref(m_CellBias);
111 }
112
114 {
115 return Deref(m_OutputGateBias);
116 }
117};
118
120{
122 : m_InputToInputWeights(nullptr)
123 , m_InputToForgetWeights(nullptr)
124 , m_InputToCellWeights(nullptr)
125 , m_InputToOutputWeights(nullptr)
126
129 , m_RecurrentToCellWeights(nullptr)
131
132 , m_InputGateBias(nullptr)
133 , m_ForgetGateBias(nullptr)
134 , m_CellBias(nullptr)
135 , m_OutputGateBias(nullptr)
136 {
137 }
138
143
148
153
154
155 const TensorInfo& Deref(const TensorInfo* tensorInfo) const
156 {
157 if (tensorInfo != nullptr)
158 {
159 const TensorInfo &temp = *tensorInfo;
160 return temp;
161 }
162 throw InvalidArgumentException("Can't dereference a null pointer");
163 }
164
166 {
168 }
170 {
172 }
174 {
176 }
178 {
180 }
181
198
200 {
201 return Deref(m_InputGateBias);
202 }
204 {
205 return Deref(m_ForgetGateBias);
206 }
207 const TensorInfo& GetCellBias() const
208 {
209 return Deref(m_CellBias);
210 }
212 {
213 return Deref(m_OutputGateBias);
214 }
215};
216
217} // namespace armnn
218
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
Copyright (c) 2021 ARM Limited and Contributors.
const ConstTensor & GetCellBias() const
const ConstTensor & GetRecurrentToCellWeights() const
const ConstTensor & GetOutputGateBias() const
const ConstTensor * m_RecurrentToInputWeights
const ConstTensor * m_RecurrentToOutputWeights
const ConstTensor & GetInputToCellWeights() const
const ConstTensor & GetInputToOutputWeights() const
const ConstTensor & GetRecurrentToOutputWeights() const
const ConstTensor & GetForgetGateBias() const
const ConstTensor & GetInputGateBias() const
const ConstTensor & GetInputToInputWeights() const
const ConstTensor & Deref(const ConstTensor *tensorPtr) const
const ConstTensor & GetInputToForgetWeights() const
const ConstTensor * m_RecurrentToForgetWeights
const ConstTensor & GetRecurrentToForgetWeights() const
const ConstTensor & GetRecurrentToInputWeights() const
const TensorInfo & GetRecurrentToCellWeights() const
const TensorInfo & GetInputGateBias() const
const TensorInfo & GetInputToOutputWeights() const
const TensorInfo & GetRecurrentToForgetWeights() const
const TensorInfo & GetInputToForgetWeights() const
const TensorInfo & GetOutputGateBias() const
const TensorInfo & GetInputToInputWeights() const
const TensorInfo & Deref(const TensorInfo *tensorInfo) const
const TensorInfo & GetForgetGateBias() const
const TensorInfo & GetRecurrentToInputWeights() const
const TensorInfo & GetRecurrentToOutputWeights() const
const TensorInfo & GetInputToCellWeights() const