ArmNN
 25.11
Loading...
Searching...
No Matches
LstmParams.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
7#include "TensorFwd.hpp"
8#include "Exceptions.hpp"
9
10namespace armnn
11{
12
14{
16 : m_InputToInputWeights(nullptr)
17 , m_InputToForgetWeights(nullptr)
18 , m_InputToCellWeights(nullptr)
19 , m_InputToOutputWeights(nullptr)
24 , m_CellToInputWeights(nullptr)
25 , m_CellToForgetWeights(nullptr)
26 , m_CellToOutputWeights(nullptr)
27 , m_InputGateBias(nullptr)
28 , m_ForgetGateBias(nullptr)
29 , m_CellBias(nullptr)
30 , m_OutputGateBias(nullptr)
31 , m_ProjectionWeights(nullptr)
32 , m_ProjectionBias(nullptr)
35 , m_CellLayerNormWeights(nullptr)
37 {
38 }
39
61};
62
64{
66 : m_InputToInputWeights(nullptr)
67 , m_InputToForgetWeights(nullptr)
68 , m_InputToCellWeights(nullptr)
69 , m_InputToOutputWeights(nullptr)
74 , m_CellToInputWeights(nullptr)
75 , m_CellToForgetWeights(nullptr)
76 , m_CellToOutputWeights(nullptr)
77 , m_InputGateBias(nullptr)
78 , m_ForgetGateBias(nullptr)
79 , m_CellBias(nullptr)
80 , m_OutputGateBias(nullptr)
81 , m_ProjectionWeights(nullptr)
82 , m_ProjectionBias(nullptr)
85 , m_CellLayerNormWeights(nullptr)
87 {
88 }
110
111 const TensorInfo& Deref(const TensorInfo* tensorInfo) const
112 {
113 if (tensorInfo != nullptr)
114 {
115 const TensorInfo &temp = *tensorInfo;
116 return temp;
117 }
118 throw InvalidArgumentException("Can't dereference a null pointer");
119 }
120
122 {
124 }
126 {
128 }
130 {
132 }
134 {
136 }
154 {
156 }
158 {
160 }
162 {
164 }
166 {
167 return Deref(m_InputGateBias);
168 }
170 {
171 return Deref(m_ForgetGateBias);
172 }
173 const TensorInfo& GetCellBias() const
174 {
175 return Deref(m_CellBias);
176 }
178 {
179 return Deref(m_OutputGateBias);
180 }
182 {
184 }
186 {
187 return Deref(m_ProjectionBias);
188 }
190 {
192 }
198 {
200 }
205};
206
207} // namespace armnn
208
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
Copyright (c) 2021 ARM Limited and Contributors.
const ConstTensor * m_InputLayerNormWeights
const ConstTensor * m_RecurrentToCellWeights
const ConstTensor * m_InputToForgetWeights
const ConstTensor * m_CellToForgetWeights
const ConstTensor * m_RecurrentToInputWeights
const ConstTensor * m_ProjectionBias
const ConstTensor * m_CellToInputWeights
const ConstTensor * m_InputToCellWeights
const ConstTensor * m_CellBias
const ConstTensor * m_RecurrentToOutputWeights
const ConstTensor * m_InputToOutputWeights
const ConstTensor * m_OutputGateBias
const ConstTensor * m_OutputLayerNormWeights
const ConstTensor * m_InputGateBias
const ConstTensor * m_ProjectionWeights
const ConstTensor * m_ForgetGateBias
const ConstTensor * m_CellLayerNormWeights
const ConstTensor * m_RecurrentToForgetWeights
const ConstTensor * m_ForgetLayerNormWeights
const ConstTensor * m_CellToOutputWeights
const ConstTensor * m_InputToInputWeights
const TensorInfo * m_InputGateBias
const TensorInfo * m_RecurrentToForgetWeights
const TensorInfo * m_InputToInputWeights
const TensorInfo * m_CellToOutputWeights
const TensorInfo * m_CellToForgetWeights
const TensorInfo * m_ForgetGateBias
const TensorInfo & GetRecurrentToCellWeights() const
const TensorInfo & GetInputGateBias() const
const TensorInfo * m_InputToOutputWeights
const TensorInfo & GetCellToOutputWeights() const
const TensorInfo * m_OutputGateBias
const TensorInfo & GetCellToInputWeights() const
const TensorInfo * m_InputToForgetWeights
const TensorInfo & GetInputToOutputWeights() const
const TensorInfo * m_ForgetLayerNormWeights
const TensorInfo & GetCellLayerNormWeights() const
const TensorInfo * m_CellLayerNormWeights
const TensorInfo * m_RecurrentToCellWeights
const TensorInfo & GetInputLayerNormWeights() const
const TensorInfo & GetRecurrentToForgetWeights() const
const TensorInfo & GetInputToForgetWeights() const
const TensorInfo & GetOutputLayerNormWeights() const
const TensorInfo & GetOutputGateBias() const
const TensorInfo & GetInputToInputWeights() const
const TensorInfo * m_OutputLayerNormWeights
const TensorInfo * m_RecurrentToInputWeights
const TensorInfo & Deref(const TensorInfo *tensorInfo) const
const TensorInfo & GetForgetGateBias() const
const TensorInfo & GetRecurrentToInputWeights() const
const TensorInfo & GetCellBias() const
const TensorInfo * m_CellBias
const TensorInfo * m_InputLayerNormWeights
const TensorInfo * m_CellToInputWeights
const TensorInfo * m_ProjectionWeights
const TensorInfo & GetRecurrentToOutputWeights() const
const TensorInfo & GetProjectionWeights() const
const TensorInfo & GetInputToCellWeights() const
const TensorInfo * m_ProjectionBias
const TensorInfo & GetForgetLayerNormWeights() const
const TensorInfo * m_InputToCellWeights
const TensorInfo & GetProjectionBias() const
const TensorInfo & GetCellToForgetWeights() const
const TensorInfo * m_RecurrentToOutputWeights