24.08
LstmParams.hpp
Go to the documentation of this file.
1
//
2
// Copyright © 2017 Arm Ltd. All rights reserved.
3
// SPDX-License-Identifier: MIT
4
//
5
#pragma once
6
7
#include "
TensorFwd.hpp
"
8
#include "
Exceptions.hpp
"
9
10
namespace
armnn
11
{
12
13
struct
LstmInputParams
14
{
15
LstmInputParams
()
16
:
m_InputToInputWeights
(nullptr)
17
,
m_InputToForgetWeights
(nullptr)
18
,
m_InputToCellWeights
(nullptr)
19
,
m_InputToOutputWeights
(nullptr)
20
,
m_RecurrentToInputWeights
(nullptr)
21
,
m_RecurrentToForgetWeights
(nullptr)
22
,
m_RecurrentToCellWeights
(nullptr)
23
,
m_RecurrentToOutputWeights
(nullptr)
24
,
m_CellToInputWeights
(nullptr)
25
,
m_CellToForgetWeights
(nullptr)
26
,
m_CellToOutputWeights
(nullptr)
27
,
m_InputGateBias
(nullptr)
28
,
m_ForgetGateBias
(nullptr)
29
,
m_CellBias
(nullptr)
30
,
m_OutputGateBias
(nullptr)
31
,
m_ProjectionWeights
(nullptr)
32
,
m_ProjectionBias
(nullptr)
33
,
m_InputLayerNormWeights
(nullptr)
34
,
m_ForgetLayerNormWeights
(nullptr)
35
,
m_CellLayerNormWeights
(nullptr)
36
,
m_OutputLayerNormWeights
(nullptr)
37
{
38
}
39
40
const
ConstTensor
*
m_InputToInputWeights
;
41
const
ConstTensor
*
m_InputToForgetWeights
;
42
const
ConstTensor
*
m_InputToCellWeights
;
43
const
ConstTensor
*
m_InputToOutputWeights
;
44
const
ConstTensor
*
m_RecurrentToInputWeights
;
45
const
ConstTensor
*
m_RecurrentToForgetWeights
;
46
const
ConstTensor
*
m_RecurrentToCellWeights
;
47
const
ConstTensor
*
m_RecurrentToOutputWeights
;
48
const
ConstTensor
*
m_CellToInputWeights
;
49
const
ConstTensor
*
m_CellToForgetWeights
;
50
const
ConstTensor
*
m_CellToOutputWeights
;
51
const
ConstTensor
*
m_InputGateBias
;
52
const
ConstTensor
*
m_ForgetGateBias
;
53
const
ConstTensor
*
m_CellBias
;
54
const
ConstTensor
*
m_OutputGateBias
;
55
const
ConstTensor
*
m_ProjectionWeights
;
56
const
ConstTensor
*
m_ProjectionBias
;
57
const
ConstTensor
*
m_InputLayerNormWeights
;
58
const
ConstTensor
*
m_ForgetLayerNormWeights
;
59
const
ConstTensor
*
m_CellLayerNormWeights
;
60
const
ConstTensor
*
m_OutputLayerNormWeights
;
61
};
62
63
struct
LstmInputParamsInfo
64
{
65
LstmInputParamsInfo
()
66
:
m_InputToInputWeights
(nullptr)
67
,
m_InputToForgetWeights
(nullptr)
68
,
m_InputToCellWeights
(nullptr)
69
,
m_InputToOutputWeights
(nullptr)
70
,
m_RecurrentToInputWeights
(nullptr)
71
,
m_RecurrentToForgetWeights
(nullptr)
72
,
m_RecurrentToCellWeights
(nullptr)
73
,
m_RecurrentToOutputWeights
(nullptr)
74
,
m_CellToInputWeights
(nullptr)
75
,
m_CellToForgetWeights
(nullptr)
76
,
m_CellToOutputWeights
(nullptr)
77
,
m_InputGateBias
(nullptr)
78
,
m_ForgetGateBias
(nullptr)
79
,
m_CellBias
(nullptr)
80
,
m_OutputGateBias
(nullptr)
81
,
m_ProjectionWeights
(nullptr)
82
,
m_ProjectionBias
(nullptr)
83
,
m_InputLayerNormWeights
(nullptr)
84
,
m_ForgetLayerNormWeights
(nullptr)
85
,
m_CellLayerNormWeights
(nullptr)
86
,
m_OutputLayerNormWeights
(nullptr)
87
{
88
}
89
const
TensorInfo
*
m_InputToInputWeights
;
90
const
TensorInfo
*
m_InputToForgetWeights
;
91
const
TensorInfo
*
m_InputToCellWeights
;
92
const
TensorInfo
*
m_InputToOutputWeights
;
93
const
TensorInfo
*
m_RecurrentToInputWeights
;
94
const
TensorInfo
*
m_RecurrentToForgetWeights
;
95
const
TensorInfo
*
m_RecurrentToCellWeights
;
96
const
TensorInfo
*
m_RecurrentToOutputWeights
;
97
const
TensorInfo
*
m_CellToInputWeights
;
98
const
TensorInfo
*
m_CellToForgetWeights
;
99
const
TensorInfo
*
m_CellToOutputWeights
;
100
const
TensorInfo
*
m_InputGateBias
;
101
const
TensorInfo
*
m_ForgetGateBias
;
102
const
TensorInfo
*
m_CellBias
;
103
const
TensorInfo
*
m_OutputGateBias
;
104
const
TensorInfo
*
m_ProjectionWeights
;
105
const
TensorInfo
*
m_ProjectionBias
;
106
const
TensorInfo
*
m_InputLayerNormWeights
;
107
const
TensorInfo
*
m_ForgetLayerNormWeights
;
108
const
TensorInfo
*
m_CellLayerNormWeights
;
109
const
TensorInfo
*
m_OutputLayerNormWeights
;
110
111
const
TensorInfo
&
Deref
(
const
TensorInfo
* tensorInfo)
const
112
{
113
if
(tensorInfo !=
nullptr
)
114
{
115
const
TensorInfo
&temp = *tensorInfo;
116
return
temp;
117
}
118
throw
InvalidArgumentException
(
"Can't dereference a null pointer"
);
119
}
120
121
const
TensorInfo
&
GetInputToInputWeights
()
const
122
{
123
return
Deref
(
m_InputToInputWeights
);
124
}
125
const
TensorInfo
&
GetInputToForgetWeights
()
const
126
{
127
return
Deref
(
m_InputToForgetWeights
);
128
}
129
const
TensorInfo
&
GetInputToCellWeights
()
const
130
{
131
return
Deref
(
m_InputToCellWeights
);
132
}
133
const
TensorInfo
&
GetInputToOutputWeights
()
const
134
{
135
return
Deref
(
m_InputToOutputWeights
);
136
}
137
const
TensorInfo
&
GetRecurrentToInputWeights
()
const
138
{
139
return
Deref
(
m_RecurrentToInputWeights
);
140
}
141
const
TensorInfo
&
GetRecurrentToForgetWeights
()
const
142
{
143
return
Deref
(
m_RecurrentToForgetWeights
);
144
}
145
const
TensorInfo
&
GetRecurrentToCellWeights
()
const
146
{
147
return
Deref
(
m_RecurrentToCellWeights
);
148
}
149
const
TensorInfo
&
GetRecurrentToOutputWeights
()
const
150
{
151
return
Deref
(
m_RecurrentToOutputWeights
);
152
}
153
const
TensorInfo
&
GetCellToInputWeights
()
const
154
{
155
return
Deref
(
m_CellToInputWeights
);
156
}
157
const
TensorInfo
&
GetCellToForgetWeights
()
const
158
{
159
return
Deref
(
m_CellToForgetWeights
);
160
}
161
const
TensorInfo
&
GetCellToOutputWeights
()
const
162
{
163
return
Deref
(
m_CellToOutputWeights
);
164
}
165
const
TensorInfo
&
GetInputGateBias
()
const
166
{
167
return
Deref
(
m_InputGateBias
);
168
}
169
const
TensorInfo
&
GetForgetGateBias
()
const
170
{
171
return
Deref
(
m_ForgetGateBias
);
172
}
173
const
TensorInfo
&
GetCellBias
()
const
174
{
175
return
Deref
(
m_CellBias
);
176
}
177
const
TensorInfo
&
GetOutputGateBias
()
const
178
{
179
return
Deref
(
m_OutputGateBias
);
180
}
181
const
TensorInfo
&
GetProjectionWeights
()
const
182
{
183
return
Deref
(
m_ProjectionWeights
);
184
}
185
const
TensorInfo
&
GetProjectionBias
()
const
186
{
187
return
Deref
(
m_ProjectionBias
);
188
}
189
const
TensorInfo
&
GetInputLayerNormWeights
()
const
190
{
191
return
Deref
(
m_InputLayerNormWeights
);
192
}
193
const
TensorInfo
&
GetForgetLayerNormWeights
()
const
194
{
195
return
Deref
(
m_ForgetLayerNormWeights
);
196
}
197
const
TensorInfo
&
GetCellLayerNormWeights
()
const
198
{
199
return
Deref
(
m_CellLayerNormWeights
);
200
}
201
const
TensorInfo
&
GetOutputLayerNormWeights
()
const
202
{
203
return
Deref
(
m_OutputLayerNormWeights
);
204
}
205
};
206
207
}
// namespace armnn
208
armnn::LstmInputParams::m_RecurrentToForgetWeights
const ConstTensor * m_RecurrentToForgetWeights
Definition:
LstmParams.hpp:45
armnn::LstmInputParamsInfo::m_InputToInputWeights
const TensorInfo * m_InputToInputWeights
Definition:
LstmParams.hpp:89
armnn::LstmInputParams::m_OutputLayerNormWeights
const ConstTensor * m_OutputLayerNormWeights
Definition:
LstmParams.hpp:60
armnn::LstmInputParamsInfo::GetCellBias
const TensorInfo & GetCellBias() const
Definition:
LstmParams.hpp:173
armnn::LstmInputParamsInfo::m_InputLayerNormWeights
const TensorInfo * m_InputLayerNormWeights
Definition:
LstmParams.hpp:106
armnn::LstmInputParamsInfo::Deref
const TensorInfo & Deref(const TensorInfo *tensorInfo) const
Definition:
LstmParams.hpp:111
armnn::LstmInputParams::m_ProjectionBias
const ConstTensor * m_ProjectionBias
Definition:
LstmParams.hpp:56
armnn::LstmInputParamsInfo::GetInputToCellWeights
const TensorInfo & GetInputToCellWeights() const
Definition:
LstmParams.hpp:129
armnn::LstmInputParams::m_RecurrentToCellWeights
const ConstTensor * m_RecurrentToCellWeights
Definition:
LstmParams.hpp:46
armnn::LstmInputParams::m_CellBias
const ConstTensor * m_CellBias
Definition:
LstmParams.hpp:53
armnn::LstmInputParams::LstmInputParams
LstmInputParams()
Definition:
LstmParams.hpp:15
armnn::TensorInfo
Definition:
Tensor.hpp:152
armnn::LstmInputParamsInfo::GetProjectionBias
const TensorInfo & GetProjectionBias() const
Definition:
LstmParams.hpp:185
armnn::LstmInputParamsInfo::m_OutputGateBias
const TensorInfo * m_OutputGateBias
Definition:
LstmParams.hpp:103
armnn::LstmInputParamsInfo::m_CellToInputWeights
const TensorInfo * m_CellToInputWeights
Definition:
LstmParams.hpp:97
armnn::LstmInputParamsInfo::m_CellLayerNormWeights
const TensorInfo * m_CellLayerNormWeights
Definition:
LstmParams.hpp:108
armnn::LstmInputParamsInfo::GetInputGateBias
const TensorInfo & GetInputGateBias() const
Definition:
LstmParams.hpp:165
armnn::LstmInputParamsInfo::m_RecurrentToCellWeights
const TensorInfo * m_RecurrentToCellWeights
Definition:
LstmParams.hpp:95
armnn::LstmInputParams::m_CellToOutputWeights
const ConstTensor * m_CellToOutputWeights
Definition:
LstmParams.hpp:50
armnn::LstmInputParams::m_InputToCellWeights
const ConstTensor * m_InputToCellWeights
Definition:
LstmParams.hpp:42
armnn::LstmInputParamsInfo::GetRecurrentToInputWeights
const TensorInfo & GetRecurrentToInputWeights() const
Definition:
LstmParams.hpp:137
armnn::LstmInputParamsInfo::GetRecurrentToForgetWeights
const TensorInfo & GetRecurrentToForgetWeights() const
Definition:
LstmParams.hpp:141
armnn::LstmInputParamsInfo::m_OutputLayerNormWeights
const TensorInfo * m_OutputLayerNormWeights
Definition:
LstmParams.hpp:109
armnn::LstmInputParamsInfo::GetRecurrentToCellWeights
const TensorInfo & GetRecurrentToCellWeights() const
Definition:
LstmParams.hpp:145
armnn::LstmInputParams::m_ForgetGateBias
const ConstTensor * m_ForgetGateBias
Definition:
LstmParams.hpp:52
armnn::LstmInputParams::m_CellToInputWeights
const ConstTensor * m_CellToInputWeights
Definition:
LstmParams.hpp:48
armnn::LstmInputParamsInfo::m_RecurrentToOutputWeights
const TensorInfo * m_RecurrentToOutputWeights
Definition:
LstmParams.hpp:96
armnn::LstmInputParamsInfo::GetInputLayerNormWeights
const TensorInfo & GetInputLayerNormWeights() const
Definition:
LstmParams.hpp:189
armnn::LstmInputParams::m_InputToOutputWeights
const ConstTensor * m_InputToOutputWeights
Definition:
LstmParams.hpp:43
TensorFwd.hpp
armnn::LstmInputParamsInfo::m_ForgetLayerNormWeights
const TensorInfo * m_ForgetLayerNormWeights
Definition:
LstmParams.hpp:107
armnn::LstmInputParams::m_CellToForgetWeights
const ConstTensor * m_CellToForgetWeights
Definition:
LstmParams.hpp:49
armnn::LstmInputParams::m_RecurrentToInputWeights
const ConstTensor * m_RecurrentToInputWeights
Definition:
LstmParams.hpp:44
armnn::LstmInputParamsInfo::m_RecurrentToForgetWeights
const TensorInfo * m_RecurrentToForgetWeights
Definition:
LstmParams.hpp:94
armnn::LstmInputParamsInfo::m_CellToForgetWeights
const TensorInfo * m_CellToForgetWeights
Definition:
LstmParams.hpp:98
armnn::LstmInputParams::m_InputToInputWeights
const ConstTensor * m_InputToInputWeights
Definition:
LstmParams.hpp:40
armnn::LstmInputParamsInfo::GetCellToInputWeights
const TensorInfo & GetCellToInputWeights() const
Definition:
LstmParams.hpp:153
armnn::LstmInputParamsInfo::GetRecurrentToOutputWeights
const TensorInfo & GetRecurrentToOutputWeights() const
Definition:
LstmParams.hpp:149
armnn::LstmInputParamsInfo::GetInputToInputWeights
const TensorInfo & GetInputToInputWeights() const
Definition:
LstmParams.hpp:121
armnn::LstmInputParams::m_RecurrentToOutputWeights
const ConstTensor * m_RecurrentToOutputWeights
Definition:
LstmParams.hpp:47
armnn::LstmInputParams::m_InputGateBias
const ConstTensor * m_InputGateBias
Definition:
LstmParams.hpp:51
armnn::LstmInputParamsInfo::GetForgetGateBias
const TensorInfo & GetForgetGateBias() const
Definition:
LstmParams.hpp:169
armnn::InvalidArgumentException
Definition:
Exceptions.hpp:80
armnn::LstmInputParamsInfo::GetCellToForgetWeights
const TensorInfo & GetCellToForgetWeights() const
Definition:
LstmParams.hpp:157
armnn::LstmInputParamsInfo::LstmInputParamsInfo
LstmInputParamsInfo()
Definition:
LstmParams.hpp:65
armnn::LstmInputParamsInfo::m_InputToCellWeights
const TensorInfo * m_InputToCellWeights
Definition:
LstmParams.hpp:91
armnn::LstmInputParamsInfo::m_CellBias
const TensorInfo * m_CellBias
Definition:
LstmParams.hpp:102
armnn::LstmInputParams::m_InputLayerNormWeights
const ConstTensor * m_InputLayerNormWeights
Definition:
LstmParams.hpp:57
armnn::LstmInputParamsInfo::m_RecurrentToInputWeights
const TensorInfo * m_RecurrentToInputWeights
Definition:
LstmParams.hpp:93
armnn::LstmInputParamsInfo::m_ForgetGateBias
const TensorInfo * m_ForgetGateBias
Definition:
LstmParams.hpp:101
armnn::LstmInputParams::m_ForgetLayerNormWeights
const ConstTensor * m_ForgetLayerNormWeights
Definition:
LstmParams.hpp:58
armnn::LstmInputParamsInfo::GetInputToOutputWeights
const TensorInfo & GetInputToOutputWeights() const
Definition:
LstmParams.hpp:133
armnn::LstmInputParamsInfo::GetOutputGateBias
const TensorInfo & GetOutputGateBias() const
Definition:
LstmParams.hpp:177
armnn::LstmInputParamsInfo::GetCellToOutputWeights
const TensorInfo & GetCellToOutputWeights() const
Definition:
LstmParams.hpp:161
armnn::LstmInputParams::m_OutputGateBias
const ConstTensor * m_OutputGateBias
Definition:
LstmParams.hpp:54
armnn::LstmInputParams::m_ProjectionWeights
const ConstTensor * m_ProjectionWeights
Definition:
LstmParams.hpp:55
armnn::LstmInputParamsInfo::GetOutputLayerNormWeights
const TensorInfo & GetOutputLayerNormWeights() const
Definition:
LstmParams.hpp:201
armnn::LstmInputParams::m_InputToForgetWeights
const ConstTensor * m_InputToForgetWeights
Definition:
LstmParams.hpp:41
Exceptions.hpp
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition:
01_00_quick_start.dox:6
armnn::LstmInputParamsInfo::m_CellToOutputWeights
const TensorInfo * m_CellToOutputWeights
Definition:
LstmParams.hpp:99
armnn::LstmInputParamsInfo
Definition:
LstmParams.hpp:63
armnn::ConstTensor
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition:
Tensor.hpp:329
armnn::LstmInputParamsInfo::m_InputGateBias
const TensorInfo * m_InputGateBias
Definition:
LstmParams.hpp:100
armnn::LstmInputParamsInfo::GetProjectionWeights
const TensorInfo & GetProjectionWeights() const
Definition:
LstmParams.hpp:181
armnn::LstmInputParamsInfo::m_InputToForgetWeights
const TensorInfo * m_InputToForgetWeights
Definition:
LstmParams.hpp:90
armnn::LstmInputParamsInfo::m_InputToOutputWeights
const TensorInfo * m_InputToOutputWeights
Definition:
LstmParams.hpp:92
armnn::LstmInputParamsInfo::m_ProjectionBias
const TensorInfo * m_ProjectionBias
Definition:
LstmParams.hpp:105
armnn::LstmInputParams
Definition:
LstmParams.hpp:13
armnn::LstmInputParamsInfo::m_ProjectionWeights
const TensorInfo * m_ProjectionWeights
Definition:
LstmParams.hpp:104
armnn::LstmInputParams::m_CellLayerNormWeights
const ConstTensor * m_CellLayerNormWeights
Definition:
LstmParams.hpp:59
armnn::LstmInputParamsInfo::GetForgetLayerNormWeights
const TensorInfo & GetForgetLayerNormWeights() const
Definition:
LstmParams.hpp:193
armnn::LstmInputParamsInfo::GetCellLayerNormWeights
const TensorInfo & GetCellLayerNormWeights() const
Definition:
LstmParams.hpp:197
armnn::LstmInputParamsInfo::GetInputToForgetWeights
const TensorInfo & GetInputToForgetWeights() const
Definition:
LstmParams.hpp:125
include
armnn
LstmParams.hpp
Generated on Wed Aug 28 2024 14:31:47 for Arm NN by
1.8.17