ArmNN
 25.11
Loading...
Searching...
No Matches
LstmQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for LstmQueueDescriptor:
[legend]
Collaboration diagram for LstmQueueDescriptor:
[legend]

Public Member Functions

 LstmQueueDescriptor ()
void Validate (const WorkloadInfo &workloadInfo) const
Public Member Functions inherited from QueueDescriptorWithParameters< LstmDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
template<typename T>
const T * GetAdditionalInformation () const

Public Attributes

const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_CellToInputWeights
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_CellToOutputWeights
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_ProjectionBias
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_CellLayerNormWeights
const ConstTensorHandle * m_OutputLayerNormWeights
Public Attributes inherited from QueueDescriptorWithParameters< LstmDescriptor >
LstmDescriptor m_Parameters
Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
void * m_AdditionalInfoObject
bool m_AllowExpandedDims = false

Additional Inherited Members

Protected Member Functions inherited from QueueDescriptorWithParameters< LstmDescriptor >
 QueueDescriptorWithParameters ()=default
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 QueueDescriptor (QueueDescriptor const &)=default
QueueDescriptor & operator= (QueueDescriptor const &)=default

Detailed Description

Definition at line 400 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ LstmQueueDescriptor()

LstmQueueDescriptor ( )
inline

Definition at line 402 of file WorkloadData.hpp.

403 : m_InputToInputWeights(nullptr)
404 , m_InputToForgetWeights(nullptr)
405 , m_InputToCellWeights(nullptr)
406 , m_InputToOutputWeights(nullptr)
407 , m_RecurrentToInputWeights(nullptr)
408 , m_RecurrentToForgetWeights(nullptr)
409 , m_RecurrentToCellWeights(nullptr)
410 , m_RecurrentToOutputWeights(nullptr)
411 , m_CellToInputWeights(nullptr)
412 , m_CellToForgetWeights(nullptr)
413 , m_CellToOutputWeights(nullptr)
414 , m_InputGateBias(nullptr)
415 , m_ForgetGateBias(nullptr)
416 , m_CellBias(nullptr)
417 , m_OutputGateBias(nullptr)
418 , m_ProjectionWeights(nullptr)
419 , m_ProjectionBias(nullptr)
420 , m_InputLayerNormWeights(nullptr)
421 , m_ForgetLayerNormWeights(nullptr)
422 , m_CellLayerNormWeights(nullptr)
423 , m_OutputLayerNormWeights(nullptr)
424 {
425 }

References m_CellBias, m_CellLayerNormWeights, m_CellToForgetWeights, m_CellToInputWeights, m_CellToOutputWeights, m_ForgetGateBias, m_ForgetLayerNormWeights, m_InputGateBias, m_InputLayerNormWeights, m_InputToCellWeights, m_InputToForgetWeights, m_InputToInputWeights, m_InputToOutputWeights, m_OutputGateBias, m_OutputLayerNormWeights, m_ProjectionBias, m_ProjectionWeights, m_RecurrentToCellWeights, m_RecurrentToForgetWeights, m_RecurrentToInputWeights, and m_RecurrentToOutputWeights.

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 2021 of file WorkloadData.cpp.

2022{
2023 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
2024
2025 const std::string descriptorName{"LstmQueueDescriptor"};
2026
2027 // check dimensions of all inputs and outputs
2028 if (workloadInfo.m_InputTensorInfos.size() != 3)
2029 {
2030 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
2031 }
2032 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2033 {
2034 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
2035 }
2036
2037 std::vector<DataType> supportedTypes =
2038 {
2043 };
2044
2045 // check for supported type of one input and match them with all the other input and output
2046 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
2047
2048 // type matches all other inputs
2049 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
2050 {
2051 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2052 workloadInfo.m_InputTensorInfos[i],
2053 descriptorName,
2054 "input_0",
2055 "input_" + std::to_string(i));
2056 }
2057 // type matches all other outputs
2058 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
2059 {
2060 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
2061 workloadInfo.m_OutputTensorInfos[i],
2062 "LstmQueueDescriptor",
2063 "input_0",
2064 "output_" + std::to_string(i));
2065 }
2066
2067 // Making sure clipping parameters have valid values.
2068 // == 0 means no clipping
2069 // > 0 means clipping
2071 {
2072 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2073 }
2075 {
2076 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2077 }
2078
2079 // Inferring batch size, number of outputs and number of cells from the inputs.
2080 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2081 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2082 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2083 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2084 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2085 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2086
2087 // input tensor
2088 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2089 descriptorName + " input_0");
2090 // outputStateInTensor
2091 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2092 descriptorName + " input_1");
2093 // outputStateInTensor
2094 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2095 descriptorName + " input_2");
2096 // scratchBufferTensor
2097 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
2098 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2099 descriptorName + " output_0");
2100 // outputStateOutTensor
2101 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2102 descriptorName + " output_1");
2103 // cellStateOutTensor
2104 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2105 descriptorName + " output_2");
2106 // outputTensor
2107 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2108 descriptorName + " output_3");
2109
2110 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2112 {
2114 (n_cell * n_input), "InputLayerNormWeights");
2115 }
2116
2117 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2119 (n_cell * n_input), "InputToForgetWeights");
2120
2121 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2123 (n_cell * n_input), "InputToCellWeights");
2124
2126 {
2128 (n_cell * n_output), "RecurrentToInputWeights");
2129 }
2130
2131 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2133 (n_cell * n_output), "RecurrentToForgetWeights");
2134
2135 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2137 (n_cell * n_output), "RecurrentToCellWeights");
2138
2139 // Make sure the input-gate's parameters are either both present (regular
2140 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2141 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2145 if (!cifg_weights_all_or_none)
2146 {
2147 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2148 "RecurrentToInputWeights must either both be present (regular LSTM) "
2149 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2150 "accordingly.");
2151 }
2152
2154 {
2156 n_cell, "CellToInputWeights");
2157 }
2159 {
2161 n_cell, "CellToForgetWeights");
2162 }
2164 {
2166 n_cell, "CellToOutputWeights");
2167 }
2168
2169 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2170 bool peephole_weights_all_or_none =
2175 if (!peephole_weights_all_or_none)
2176 {
2177 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
2178 }
2179
2180 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2182 {
2183 if (m_InputGateBias)
2184 {
2185 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
2186 }
2187 }
2188 else
2189 {
2190 if (!m_InputGateBias)
2191 {
2192 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2193 "must be present.");
2194 }
2195 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2196 n_cell, "InputGateBias");
2197 }
2198
2199 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2200 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2201
2202 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2203 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2204
2205 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2206 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2207
2209 {
2211 (n_cell * n_output), "ProjectionWeights");
2212 }
2213 if (m_ProjectionBias)
2214 {
2215 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2216 }
2217
2218 // Making sure the projection tensors are consistent:
2219 // 1) If projection weight is not present, then projection bias should not be
2220 // present.
2221 // 2) If projection weight is present, then projection bias is optional.
2222 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2228 if (!projecton_tensors_consistent)
2229 {
2230 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
2231 }
2232
2233 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2234 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2235 // either all have values or none of them have values. Layer normalization is used when the values of all the
2236 // layer normalization weights are present
2238 {
2239 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2240 }
2242 {
2243 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2244 }
2246 {
2247 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2248 }
2250 {
2251 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2252 }
2253
2255 {
2257 {
2259 {
2260 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2261 "disabled but InputLayerNormWeights are not present");
2262 }
2264 1, n_cell, "InputLayerNormWeights");
2265 }
2266 else if (m_InputLayerNormWeights)
2267 {
2268 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2269 "enabled");
2270 }
2271
2272 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2273 "ForgetLayerNormWeights");
2274 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2275
2276 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2277 "OutputLayerNormWeights");
2278 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2279
2280 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2281 "CellLayerNormWeights");
2282 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2283 }
2285 {
2286 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2287 "normalisation weights are present.");
2288 }
2289}
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
float m_ClippingThresCell
Clipping threshold value for the cell state.
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_ClippingThresProj
Clipping threshold value for the projection.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
const ConstTensorHandle * m_OutputLayerNormWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_InputLayerNormWeights
const ConstTensorHandle * m_CellToForgetWeights
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_ProjectionWeights
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_CellToInputWeights
const ConstTensorHandle * m_CellToOutputWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_ProjectionBias
const ConstTensorHandle * m_ForgetLayerNormWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_CellLayerNormWeights
void ValidateTensorNumDimNumElem(const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
std::vector< TensorInfo > m_OutputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos

References armnn::BFloat16, armnn::Float16, armnn::Float32, m_CellBias, m_CellLayerNormWeights, m_CellToForgetWeights, m_CellToInputWeights, m_CellToOutputWeights, m_ForgetGateBias, m_ForgetLayerNormWeights, m_InputGateBias, m_InputLayerNormWeights, WorkloadInfo::m_InputTensorInfos, m_InputToCellWeights, m_InputToForgetWeights, m_InputToInputWeights, m_InputToOutputWeights, m_OutputGateBias, m_OutputLayerNormWeights, WorkloadInfo::m_OutputTensorInfos, QueueDescriptorWithParameters< LstmDescriptor >::m_Parameters, m_ProjectionBias, m_ProjectionWeights, m_RecurrentToCellWeights, m_RecurrentToForgetWeights, m_RecurrentToInputWeights, m_RecurrentToOutputWeights, armnn::QSymmS16, and QueueDescriptor::ValidateTensorNumDimNumElem().

Member Data Documentation

◆ m_CellBias

const ConstTensorHandle* m_CellBias

Definition at line 440 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_CellLayerNormWeights

const ConstTensorHandle* m_CellLayerNormWeights

Definition at line 446 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_CellToForgetWeights

const ConstTensorHandle* m_CellToForgetWeights

Definition at line 436 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_CellToInputWeights

const ConstTensorHandle* m_CellToInputWeights

Definition at line 435 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_CellToOutputWeights

const ConstTensorHandle* m_CellToOutputWeights

Definition at line 437 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_ForgetGateBias

const ConstTensorHandle* m_ForgetGateBias

Definition at line 439 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_ForgetLayerNormWeights

const ConstTensorHandle* m_ForgetLayerNormWeights

Definition at line 445 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_InputGateBias

const ConstTensorHandle* m_InputGateBias

Definition at line 438 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_InputLayerNormWeights

const ConstTensorHandle* m_InputLayerNormWeights

Definition at line 444 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_InputToCellWeights

const ConstTensorHandle* m_InputToCellWeights

Definition at line 429 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_InputToForgetWeights

const ConstTensorHandle* m_InputToForgetWeights

Definition at line 428 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_InputToInputWeights

const ConstTensorHandle* m_InputToInputWeights

Definition at line 427 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_InputToOutputWeights

const ConstTensorHandle* m_InputToOutputWeights

Definition at line 430 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_OutputGateBias

const ConstTensorHandle* m_OutputGateBias

Definition at line 441 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_OutputLayerNormWeights

const ConstTensorHandle* m_OutputLayerNormWeights

Definition at line 447 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_ProjectionBias

const ConstTensorHandle* m_ProjectionBias

Definition at line 443 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_ProjectionWeights

const ConstTensorHandle* m_ProjectionWeights

Definition at line 442 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_RecurrentToCellWeights

const ConstTensorHandle* m_RecurrentToCellWeights

Definition at line 433 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_RecurrentToForgetWeights

const ConstTensorHandle* m_RecurrentToForgetWeights

Definition at line 432 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_RecurrentToInputWeights

const ConstTensorHandle* m_RecurrentToInputWeights

Definition at line 431 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().

◆ m_RecurrentToOutputWeights

const ConstTensorHandle* m_RecurrentToOutputWeights

Definition at line 434 of file WorkloadData.hpp.

Referenced by LstmLayer::CreateWorkload(), LstmQueueDescriptor(), and Validate().


The documentation for this struct was generated from the following files: