ArmNN
 25.11
Loading...
Searching...
No Matches
ConcatQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for ConcatQueueDescriptor:
[legend]
Collaboration diagram for ConcatQueueDescriptor:
[legend]

Classes

struct  ViewOrigin

Public Member Functions

void Validate (const WorkloadInfo &workloadInfo) const
Public Member Functions inherited from QueueDescriptorWithParameters< OriginsDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
template<typename T>
const T * GetAdditionalInformation () const

Public Attributes

std::vector< ViewOrigin > m_ViewOrigins
Public Attributes inherited from QueueDescriptorWithParameters< OriginsDescriptor >
OriginsDescriptor m_Parameters
Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
void * m_AdditionalInfoObject
bool m_AllowExpandedDims = false

Additional Inherited Members

Protected Member Functions inherited from QueueDescriptorWithParameters< OriginsDescriptor >
 QueueDescriptorWithParameters ()=default
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 QueueDescriptor (QueueDescriptor const &)=default
QueueDescriptor & operator= (QueueDescriptor const &)=default

Detailed Description

Definition at line 130 of file WorkloadData.hpp.

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 820 of file WorkloadData.cpp.

821{
822 const std::string descriptorName{"ConcatQueueDescriptor"};
823
824 ValidateNumOutputs(workloadInfo, descriptorName, 1);
825
826 if (m_Inputs.size() <= 0)
827 {
828 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
829 }
830 if (m_Outputs.size() <= 0)
831 {
832 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
833 }
834
835 if (workloadInfo.m_InputTensorInfos.size() <= 0)
836 {
837 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
838 }
839 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
840 {
841 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
842 }
843
844 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
845 {
846 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
847 }
848
849 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
850 {
851 return;
852 }
853
854 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
855 {
856 throw InvalidArgumentException(
857 descriptorName + ": Number of split windows "
858 "has to match number of workloadInfo.m_InputTensorInfos. "
859 "Number of windows: " +
860 to_string(m_ViewOrigins.size()) +
861 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
862 }
863
864 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
865 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
866 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
867 {
868 //Checks that the dimensionality of output is same as the split windows.
869 ViewOrigin const& e = m_ViewOrigins[w];
870 if (e.m_Origin.size() != outputDims)
871 {
872 throw InvalidArgumentException(descriptorName + ": Window origin have to "
873 "have the same dimensionality as the output tensor. "
874 "Window origin (index: " +
875 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
876 " dimensions, the output "
877 "tensor has " +
878 to_string(outputDims) + " dimensions.");
879 }
880 //Checks that the merge windows are within the output tensor.
881 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
882 {
883 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
884 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
885 {
886 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
887 "be smaller or equal than the size of the output in that coord.");
888 }
889 }
890 }
891
892 // Check the supported data types
893 std::vector<DataType> supportedTypes =
894 {
903 };
904
905 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
906 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
907 {
908 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
909 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
910
911 const std::string inputName = "input_" + std::to_string(i);
912 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
913 }
914}
std::vector< ViewOrigin > m_ViewOrigins
unsigned int GetConcatAxis() const
Get the concatenation axis value.
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
std::vector< TensorInfo > m_OutputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos

References armnn::BFloat16, armnn::Boolean, armnn::Float16, armnn::Float32, QueueDescriptor::m_Inputs, WorkloadInfo::m_InputTensorInfos, ConcatQueueDescriptor::ViewOrigin::m_Origin, QueueDescriptor::m_Outputs, WorkloadInfo::m_OutputTensorInfos, QueueDescriptorWithParameters< OriginsDescriptor >::m_Parameters, m_ViewOrigins, armnn::QAsymmS8, armnn::QAsymmU8, armnn::QSymmS16, and armnn::Signed32.

Member Data Documentation

◆ m_ViewOrigins

std::vector<ViewOrigin> m_ViewOrigins

Definition at line 143 of file WorkloadData.hpp.

Referenced by armnn::Concatenate(), ConcatLayer::CreateWorkload(), and Validate().


The documentation for this struct was generated from the following files: