ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
BatchMatMulQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for BatchMatMulQueueDescriptor:
[legend]
Collaboration diagram for BatchMatMulQueueDescriptor:
[legend]

Public Member Functions

void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
 
- Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
 
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
 
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
 
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Additional Inherited Members

- Public Attributes inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
BatchMatMulDescriptor m_Parameters
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 
bool m_AllowExpandedDims = false
 
- Protected Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
 QueueDescriptorWithParameters ()=default
 
 QueueDescriptorWithParameters (QueueDescriptorWithParameters const &)=default
 
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
 
- Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 753 of file WorkloadData.hpp.

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 4188 of file WorkloadData.cpp.

4189 {
4190  const std::string descriptorName{"BatchMatMulDescriptor"};
4191 
4192  ValidateNumInputs(workloadInfo, descriptorName, 2);
4193  ValidateNumOutputs(workloadInfo, descriptorName, 1);
4194 
4195  // Inputs must be: both 2D+
4196  // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4197  // axes N and I must be the same size
4198 
4199  const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4200  const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4201  const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4202  // Output info has already been inferred
4203 
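4204  std::vector<DataType> supportedTypes =
4205  {
4206  DataType::BFloat16,
4207  DataType::Float16,
4208  DataType::Float32,
4209  DataType::QAsymmS8,
4210  DataType::QAsymmU8,
4211  DataType::QSymmS16
4212  };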
4213 
4214  ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4215  ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4216  ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4217 
4218  if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4219  (inputYInfoBeforeParams.GetNumDimensions() < 2))
4220  {
4221  throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4222  }
4223 
4224  TensorInfo inputXInfoAfterParams;
4225  TensorInfo inputYInfoAfterParams;
4226 
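4227  if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4228  (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
4229  {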
4230  throw InvalidArgumentException(descriptorName +
4231  ": Invalid descriptor parameters - Transpose and Adjoint "
4232  "cannot both be true for a given input tensor.");
4233  }
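4234  if(m_Parameters.m_TransposeX)
4235  {
4236  inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4237  BatchMatMulDescriptor::GetPermuteVec(
4238  m_Parameters.m_DataLayoutX,
4239  inputXInfoBeforeParams.GetShape()));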
4240  }
4241  else if(m_Parameters.m_AdjointX)
4242  {
4243  auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4244  inputXInfoBeforeParams.GetShape());
4245  if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4246  inputXInfoBeforeParams.GetShape()[axesToMul.second])
4247  {
4248  throw InvalidArgumentException(descriptorName +
4249  ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4250  }
4251  // Shape remains the same as it's square
4252  inputXInfoAfterParams = inputXInfoBeforeParams;
4253  }
4254  else
4255  {
4256  inputXInfoAfterParams = inputXInfoBeforeParams;
4257  }
4258 
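4259  if(m_Parameters.m_TransposeY)
4260  {
4261  inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4262  BatchMatMulDescriptor::GetPermuteVec(
4263  m_Parameters.m_DataLayoutY,
4264  inputYInfoBeforeParams.GetShape()));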
4265  }
4266  else if(m_Parameters.m_AdjointY)
4267  {
4268  auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4269  inputYInfoBeforeParams.GetShape());
4270  if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4271  inputYInfoBeforeParams.GetShape()[axesToMul.second])
4272  {
4273  throw InvalidArgumentException(descriptorName +
4274  ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4275  }
4276  // Shape remains the same as it's square
4277  inputYInfoAfterParams = inputYInfoBeforeParams;
4278  }
4279  else
4280  {
4281  inputYInfoAfterParams = inputYInfoBeforeParams;
4282  }
4283 
4284  switch(m_Parameters.m_DataLayoutX)
4285  {
4286  case DataLayout::NCDHW:
4287  case DataLayout::NDHWC:
4288  if(inputXInfoAfterParams.GetNumDimensions() < 3)
4289  {
4290  throw InvalidArgumentException(descriptorName +
4291  ": Input tensor X does not have the correct "
4292  "number of dimensions for the Data Layout that it has been assigned.");
4293  }
4294  break;
4295  case DataLayout::NCHW:
4296  case DataLayout::NHWC:
4297  default:
4298  break;
4299  }
4300 
4301  switch(m_Parameters.m_DataLayoutY)
4302  {
4303  case DataLayout::NCDHW:
4304  case DataLayout::NDHWC:
4305  if(inputYInfoAfterParams.GetNumDimensions() < 3)
4306  {
4307  throw InvalidArgumentException(descriptorName +
4308  ": Input tensor Y does not have the correct "
4309  "number of dimensions for the Data Layout that it has been assigned.");
4310  }
4311  break;
4312  case DataLayout::NCHW:
4313  case DataLayout::NHWC:
4314  default:
4315  break;
4316  }
4317 
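4318  auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4319  inputXInfoAfterParams.GetShape());
4320  auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4321  inputYInfoBeforeParams.GetShape());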
4322 
4323  if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4324  != inputYInfoAfterParams.GetShape()[axesYToMul.first])
4325  {
4326  throw InvalidArgumentException(descriptorName +
4327  ": The final axis of input tensor X must be the same size as "
4328  "the second last axis of input tensor Y.");
4329  }
4330 
4331  { // Separate scope so we don't pollute the rest of the scope with our temp variables
4332  // e.g. NHWC isnt compatible with NCHW as of now
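4333  const auto& xLayout = m_Parameters.m_DataLayoutX;
4334  const auto& yLayout = m_Parameters.m_DataLayoutY;
4335 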
4336  if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4337  {
4338  if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4339  {
4340  throw InvalidArgumentException(descriptorName +
4341  ": Invalid input tensor data layout combination.");
4342  }
4343  }
4344  if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4345  {
4346  if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4347  {
4348  throw InvalidArgumentException(descriptorName +
4349  ": Invalid input tensor data layout combination.");
4350  }
4351  }
4352  }
4353 
4354  // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4355  unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4356  inputYInfoAfterParams.GetNumDimensions());
4357  if(outputTensorDimSize-2 > 0)
4358  {
4359  TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4361  TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4363  TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4365 
4366  auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4367  {
4368  auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4369 
4370  for(unsigned int i = 0; i < sizeDiff; i++)
4371  {
4372  axisIndices.insert(axisIndices.begin(), 1);
4373  }
4374 
4375  for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4376  {
4377  ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
4378  }
4379  };
4380 
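4381  auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4382  inputXInfoAfterParams.GetShape());
4383  auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4384  inputYInfoAfterParams.GetShape());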
4385 
4386  doAxisExtension(axesXNotMul, tiXNotMul);
4387  doAxisExtension(axesYNotMul, tiYNotMul);
4388 
4389  for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4390  {
4391  tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4392  tiYNotMul.GetShape()[i]);
4393  }
4394 
4395  ValidateBroadcastTensorShapesMatch(tiXNotMul,
4396  tiYNotMul,
4397  tiOutNotMul,
4398  descriptorName,
4399  "input_X",
4400  "input_Y");
4401  }
4402 }
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:197
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
DataLayout
Definition: Types.hpp:63
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Permute.cpp:125
bool m_AdjointX
Adjoint the slices of each input tensor. Transpose and Adjoint can not both be set to true for the same tensor at the same time.
static std::pair< unsigned int, unsigned int > GetAxesToMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the two axes (for each input) for multiplication.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
static std::vector< unsigned int > GetAxesNotMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes (for each input) that will not be multiplied together.
bool m_TransposeX
Transpose the slices of each input tensor. Transpose and Adjoint can not both be set to true for the same tensor at the same time.
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
std::vector< TensorInfo > m_OutputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos

References armnn::BFloat16, armnn::Float16, armnn::Float32, BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetAxesToMul(), TensorInfo::GetNumDimensions(), BatchMatMulDescriptor::GetPermuteVec(), TensorInfo::GetShape(), BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, QueueDescriptorWithParameters< BatchMatMulDescriptor >::m_Parameters, BatchMatMulDescriptor::m_TransposeX, BatchMatMulDescriptor::m_TransposeY, armnn::NCDHW, armnn::NCHW, armnn::NDHWC, armnn::NHWC, armnnUtils::Permuted(), armnn::QAsymmS8, armnn::QAsymmU8, and armnn::QSymmS16.


The documentation for this struct was generated from the following files: