ArmNN
 25.11
Loading...
Searching...
No Matches
BatchMatMulQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for BatchMatMulQueueDescriptor:
[legend]
Collaboration diagram for BatchMatMulQueueDescriptor:
[legend]

Public Member Functions

void Validate (const WorkloadInfo &workloadInfo) const
Public Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
template<typename T>
const T * GetAdditionalInformation () const

Additional Inherited Members

Public Attributes inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
BatchMatMulDescriptor m_Parameters
Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
std::vector< ITensorHandle * > m_Outputs
void * m_AdditionalInfoObject
bool m_AllowExpandedDims = false
Protected Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
 QueueDescriptorWithParameters ()=default
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 QueueDescriptor (QueueDescriptor const &)=default
QueueDescriptor & operator= (QueueDescriptor const &)=default

Detailed Description

Definition at line 753 of file WorkloadData.hpp.

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 4188 of file WorkloadData.cpp.

4189{
4190 const std::string descriptorName{"BatchMatMulDescriptor"};
4191
4192 ValidateNumInputs(workloadInfo, descriptorName, 2);
4193 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4194
4195 // Inputs must be: both 2D+
4196 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4197 // axes N and I must be the same size
4198
4199 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4200 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4201 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4202 // Output info has already been inferred
4203
4204 std::vector<DataType> supportedTypes =
4205 {
4212 };
4213
4214 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4215 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4216 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4217
4218 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4219 (inputYInfoBeforeParams.GetNumDimensions() < 2))
4220 {
4221 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4222 }
4223
4224 TensorInfo inputXInfoAfterParams;
4225 TensorInfo inputYInfoAfterParams;
4226
4229 {
4230 throw InvalidArgumentException(descriptorName +
4231 ": Invalid descriptor parameters - Transpose and Adjoint "
4232 "cannot both be true for a given input tensor.");
4233 }
4235 {
4236 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4239 inputXInfoBeforeParams.GetShape()));
4240 }
4241 else if(m_Parameters.m_AdjointX)
4242 {
4244 inputXInfoBeforeParams.GetShape());
4245 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4246 inputXInfoBeforeParams.GetShape()[axesToMul.second])
4247 {
4248 throw InvalidArgumentException(descriptorName +
4249 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4250 }
4251 // Shape remains the same as it's square
4252 inputXInfoAfterParams = inputXInfoBeforeParams;
4253 }
4254 else
4255 {
4256 inputXInfoAfterParams = inputXInfoBeforeParams;
4257 }
4258
4260 {
4261 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4264 inputYInfoBeforeParams.GetShape()));
4265 }
4266 else if(m_Parameters.m_AdjointY)
4267 {
4269 inputYInfoBeforeParams.GetShape());
4270 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4271 inputYInfoBeforeParams.GetShape()[axesToMul.second])
4272 {
4273 throw InvalidArgumentException(descriptorName +
4274 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4275 }
4276 // Shape remains the same as it's square
4277 inputYInfoAfterParams = inputYInfoBeforeParams;
4278 }
4279 else
4280 {
4281 inputYInfoAfterParams = inputYInfoBeforeParams;
4282 }
4283
4285 {
4286 case DataLayout::NCDHW:
4287 case DataLayout::NDHWC:
4288 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4289 {
4290 throw InvalidArgumentException(descriptorName +
4291 ": Input tensor X does not have the correct "
4292 "number of dimensions for the Data Layout that it has been assigned.");
4293 }
4294 break;
4295 case DataLayout::NCHW:
4296 case DataLayout::NHWC:
4297 default:
4298 break;
4299 }
4300
4302 {
4303 case DataLayout::NCDHW:
4304 case DataLayout::NDHWC:
4305 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4306 {
4307 throw InvalidArgumentException(descriptorName +
4308 ": Input tensor Y does not have the correct "
4309 "number of dimensions for the Data Layout that it has been assigned.");
4310 }
4311 break;
4312 case DataLayout::NCHW:
4313 case DataLayout::NHWC:
4314 default:
4315 break;
4316 }
4317
4319 inputXInfoAfterParams.GetShape());
4321 inputYInfoBeforeParams.GetShape());
4322
4323 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4324 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
4325 {
4326 throw InvalidArgumentException(descriptorName +
4327 ": The final axis of input tensor X must be the same size as "
4328 "the second last axis of input tensor Y.");
4329 }
4330
4331 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4332 // e.g. NHWC isnt compatible with NCHW as of now
4335
4336 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4337 {
4338 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4339 {
4340 throw InvalidArgumentException(descriptorName +
4341 ": Invalid input tensor data layout combination.");
4342 }
4343 }
4344 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4345 {
4346 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4347 {
4348 throw InvalidArgumentException(descriptorName +
4349 ": Invalid input tensor data layout combination.");
4350 }
4351 }
4352 }
4353
4354 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4355 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4356 inputYInfoAfterParams.GetNumDimensions());
4357 if(outputTensorDimSize-2 > 0)
4358 {
4359 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4361 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4363 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4365
4366 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4367 {
4368 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4369
4370 for(unsigned int i = 0; i < sizeDiff; i++)
4371 {
4372 axisIndices.insert(axisIndices.begin(), 1);
4373 }
4374
4375 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4376 {
4377 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
4378 }
4379 };
4380
4382 inputXInfoAfterParams.GetShape());
4384 inputYInfoAfterParams.GetShape());
4385
4386 doAxisExtension(axesXNotMul, tiXNotMul);
4387 doAxisExtension(axesYNotMul, tiYNotMul);
4388
4389 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4390 {
4391 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4392 tiYNotMul.GetShape()[i]);
4393 }
4394
4395 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4396 tiYNotMul,
4397 tiOutNotMul,
4398 descriptorName,
4399 "input_X",
4400 "input_Y");
4401 }
4402}
const TensorShape & GetShape() const
Definition Tensor.hpp:193
unsigned int GetNumDimensions() const
Definition Tensor.hpp:197
DataLayout
Definition Types.hpp:63
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition Permute.cpp:125
bool m_AdjointX
Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
static std::pair< unsigned int, unsigned int > GetAxesToMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the two axes (for each input) for multiplication.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
static std::vector< unsigned int > GetAxesNotMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes (for each input) that will not be multiplied together.
bool m_TransposeX
Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
std::vector< TensorInfo > m_OutputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos

References armnn::BFloat16, armnn::Float16, armnn::Float32, BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetAxesToMul(), TensorInfo::GetNumDimensions(), BatchMatMulDescriptor::GetPermuteVec(), TensorInfo::GetShape(), WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, QueueDescriptorWithParameters< BatchMatMulDescriptor >::m_Parameters, armnn::NCDHW, armnn::NCHW, armnn::NDHWC, armnn::NHWC, armnnUtils::Permuted(), armnn::QAsymmS8, armnn::QAsymmU8, and armnn::QSymmS16.


The documentation for this struct was generated from the following files: