ArmNN
 25.11
Loading...
Searching...
No Matches
QuantizedLstmQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for QuantizedLstmQueueDescriptor:
[legend]
Collaboration diagram for QuantizedLstmQueueDescriptor:
[legend]

Public Member Functions

 QuantizedLstmQueueDescriptor ()
void Validate (const WorkloadInfo &workloadInfo) const
Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
template<typename T>
const T * GetAdditionalInformation () const

Public Attributes

const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_OutputGateBias
Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
void * m_AdditionalInfoObject
bool m_AllowExpandedDims = false

Additional Inherited Members

Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 QueueDescriptor (QueueDescriptor const &)=default
QueueDescriptor & operator= (QueueDescriptor const &)=default

Detailed Description

Definition at line 614 of file WorkloadData.hpp.

Constructor & Destructor Documentation

◆ QuantizedLstmQueueDescriptor()

QuantizedLstmQueueDescriptor ( )

Definition at line 616 of file WorkloadData.hpp.

617 : m_InputToInputWeights(nullptr)
618 , m_InputToForgetWeights(nullptr)
619 , m_InputToCellWeights(nullptr)
620 , m_InputToOutputWeights(nullptr)
621
622 , m_RecurrentToInputWeights(nullptr)
623 , m_RecurrentToForgetWeights(nullptr)
624 , m_RecurrentToCellWeights(nullptr)
625 , m_RecurrentToOutputWeights(nullptr)
626
627 , m_InputGateBias(nullptr)
628 , m_ForgetGateBias(nullptr)
629 , m_CellBias(nullptr)
630 , m_OutputGateBias(nullptr)
631 {}

References m_CellBias, m_ForgetGateBias, m_InputGateBias, m_InputToCellWeights, m_InputToForgetWeights, m_InputToInputWeights, m_InputToOutputWeights, m_OutputGateBias, m_RecurrentToCellWeights, m_RecurrentToForgetWeights, m_RecurrentToInputWeights, and m_RecurrentToOutputWeights.

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 3446 of file WorkloadData.cpp.

3447{
3448 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3449
3450 // Validate number of inputs/outputs
3451 ValidateNumInputs(workloadInfo, descriptorName, 3);
3452 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3453
3454 // Input/output tensor infos
3455 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3456 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3457 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3458
3459 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3460 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3461
3462 std::vector<DataType> inputOutputSupportedTypes =
3463 {
3464 DataType::QAsymmU8
3465 };
3466
3467 std::vector<DataType> cellStateSupportedTypes =
3468 {
3469 DataType::QSymmS16
3470 };
3471
3472 std::vector<DataType> weightsSupportedTypes =
3473 {
3474 DataType::QAsymmU8
3475 };
3476
3477 std::vector<DataType> biasSupportedTypes =
3478 {
3479 DataType::Signed32
3480 };
3481
3482 // Validate types of input/output tensors
3483 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3484 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3485 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3486
3487 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3488 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3489
3490 // Validate matching types of input/output tensors
3491 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3492 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3493 "outputStateIn", "outputStateOut");
3494 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3495
3496 // Validate matching quantization info for input/output tensors
3497 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3498 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3499 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3500
3501 // Infer number of batches, input size and output size from tensor dimensions
3502 const uint32_t numBatches = inputInfo.GetShape()[0];
3503 const uint32_t inputSize = inputInfo.GetShape()[1];
3504 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3505
3506 // Validate number of dimensions and number of elements for input/output tensors
3507 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3508 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3509 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3510 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3511 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3512
3513 // Validate number of dimensions and number of elements for weights tensors
3514 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3515 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3516 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3517
3518 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3519 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3520 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3521
3522 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3523 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3524 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3525
3526 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3527 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3528 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3529
3530 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3531 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3532 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3533
3534 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3535 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3536 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3537 " RecurrentToForgetWeights");
3538
3539 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3540 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3541 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3542
3543 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3544 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3545 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3546
3547 // Validate data types for weights tensors (all should match each other)
3548 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3549
3550 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3551 "inputToInputWeights", "inputToForgetWeights");
3552 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3553 "inputToInputWeights", "inputToCellWeights");
3554 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3555 "inputToInputWeights", "inputToOutputWeights");
3556
3557 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3558 "inputToInputWeights", "recurrentToInputWeights");
3559 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3560 "inputToInputWeights", "recurrentToForgeteights");
3561 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3562 "inputToInputWeights", "recurrentToCellWeights");
3563 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3564 "inputToInputWeights", "recurrentToOutputWeights");
3565
3566 // Validate matching quantization info for weight tensors (all should match each other)
3567 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3568 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3569 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3570 descriptorName, "inputToInputWeights", "inputToCellWeights");
3571 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3572 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3573
3574 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3575 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3576 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3577 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3578 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3579 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3580 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3581 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3582
3583 // Validate number of dimensions and number of elements in bias tensors
3584 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3585 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3586 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3587
3588 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3589 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3590 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3591
3592 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3593 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3594 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3595
3596 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3597 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3598 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3599
3600 // Validate data types for bias tensors (all should match each other)
3601 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3602
3603 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3604 "inputGateBias", "forgetGateBias");
3605 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3606 "inputGateBias", "cellBias");
3607 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3608 "inputGateBias", "outputGateBias");
3609
3610 // Validate bias tensor quantization info
3611 ValidateBiasTensorQuantization(inputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3612 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3613 ValidateBiasTensorQuantization(cellBiasInfo, inputToInputWeightsInfo, descriptorName);
3614 ValidateBiasTensorQuantization(outputGateBiasInfo, inputToInputWeightsInfo, descriptorName);
3615}
const ConstTensorHandle * m_InputToOutputWeights
const ConstTensorHandle * m_RecurrentToInputWeights
const ConstTensorHandle * m_ForgetGateBias
const ConstTensorHandle * m_InputGateBias
const ConstTensorHandle * m_RecurrentToOutputWeights
const ConstTensorHandle * m_OutputGateBias
const ConstTensorHandle * m_CellBias
const ConstTensorHandle * m_InputToCellWeights
const ConstTensorHandle * m_InputToForgetWeights
const ConstTensorHandle * m_InputToInputWeights
const ConstTensorHandle * m_RecurrentToCellWeights
const ConstTensorHandle * m_RecurrentToForgetWeights
void ValidateTensorNumDimNumElem(const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
std::vector< TensorInfo > m_OutputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos

References TensorInfo::GetShape(), m_CellBias, m_ForgetGateBias, m_InputGateBias, WorkloadInfo::m_InputTensorInfos, m_InputToCellWeights, m_InputToForgetWeights, m_InputToInputWeights, m_InputToOutputWeights, m_OutputGateBias, WorkloadInfo::m_OutputTensorInfos, m_RecurrentToCellWeights, m_RecurrentToForgetWeights, m_RecurrentToInputWeights, m_RecurrentToOutputWeights, armnn::QAsymmU8, armnn::QSymmS16, armnn::Signed32, and QueueDescriptor::ValidateTensorNumDimNumElem().

Member Data Documentation

◆ m_CellBias

const ConstTensorHandle* m_CellBias

◆ m_ForgetGateBias

const ConstTensorHandle* m_ForgetGateBias

◆ m_InputGateBias

const ConstTensorHandle* m_InputGateBias

◆ m_InputToCellWeights

const ConstTensorHandle* m_InputToCellWeights

◆ m_InputToForgetWeights

const ConstTensorHandle* m_InputToForgetWeights

◆ m_InputToInputWeights

const ConstTensorHandle* m_InputToInputWeights

◆ m_InputToOutputWeights

const ConstTensorHandle* m_InputToOutputWeights

◆ m_OutputGateBias

const ConstTensorHandle* m_OutputGateBias

◆ m_RecurrentToCellWeights

const ConstTensorHandle* m_RecurrentToCellWeights

◆ m_RecurrentToForgetWeights

const ConstTensorHandle* m_RecurrentToForgetWeights

◆ m_RecurrentToInputWeights

const ConstTensorHandle* m_RecurrentToInputWeights

◆ m_RecurrentToOutputWeights

const ConstTensorHandle* m_RecurrentToOutputWeights

The documentation for this struct was generated from the following files: