ArmNN
 24.08
BatchMatMulDescriptor Struct Reference

A BatchMatMulDescriptor for the BatchMatMul operator.

#include <Descriptors.hpp>


Public Member Functions

 BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW)
 
bool operator== (const BatchMatMulDescriptor &rhs) const
 
- Public Member Functions inherited from BaseDescriptor
virtual bool IsNull () const
 
virtual ~BaseDescriptor ()=default
 

Static Public Member Functions

static std::pair< unsigned int, unsigned int > GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the two axes (for each input) for multiplication.
 
static std::vector< unsigned int > GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes (for each input) that will not be multiplied together.
 
static PermutationVector GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes which will be transposed.
 

Public Attributes

bool m_TransposeX
 Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_TransposeY
 
bool m_AdjointX
 Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_AdjointY
 
DataLayout m_DataLayoutX
 Data layout of each input tensor, such as NHWC/NDHWC. Leave it as the default (NCHW) for an arbitrary layout, in which case the last two dimensions are the ones multiplied.
 
DataLayout m_DataLayoutY
 

Detailed Description

A BatchMatMulDescriptor for the BatchMatMul operator.

Definition at line 1584 of file Descriptors.hpp.
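To make concrete what this descriptor configures, here is a naive reference sketch of a batched matrix multiply for rank-3 inputs in the default layout, where dimension 0 is the batch axis and the last two dimensions are the matrix axes (the pair GetAxesToMul() returns for NCHW). It is not ArmNN's implementation. The function name BatchMatMulRef, dense row-major float storage, equal batch sizes for both inputs (no broadcasting), and reading transposeX/transposeY as "use the stored slice transposed" are all assumptions of this sketch, and Adjoint is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace sketch {

// Hypothetical naive batch matmul for rank-3, row-major float tensors.
//   x:   [batch, m, k]   (stored as [batch, k, m] when transposeX is true)
//   y:   [batch, k, n]   (stored as [batch, n, k] when transposeY is true)
//   out: [batch, m, n]
// Assumes both inputs have the same batch size (no broadcasting).
std::vector<float> BatchMatMulRef(const std::vector<float>& x,
                                  const std::vector<float>& y,
                                  std::size_t batch, std::size_t m,
                                  std::size_t k, std::size_t n,
                                  bool transposeX = false,
                                  bool transposeY = false)
{
    assert(x.size() == batch * m * k && y.size() == batch * k * n);
    std::vector<float> out(batch * m * n, 0.0f);
    for (std::size_t b = 0; b < batch; ++b) {
        const float* xb = x.data() + b * m * k;  // start of X slice b
        const float* yb = y.data() + b * k * n;  // start of Y slice b
        float* ob = out.data() + b * m * n;      // start of output slice b
        for (std::size_t i = 0; i < m; ++i) {
            for (std::size_t j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (std::size_t p = 0; p < k; ++p) {
                    // Element (i, p) of the effective X slice.
                    float xv = transposeX ? xb[p * m + i] : xb[i * k + p];
                    // Element (p, j) of the effective Y slice.
                    float yv = transposeY ? yb[j * k + p] : yb[p * n + j];
                    acc += xv * yv;
                }
                ob[i * n + j] = acc;
            }
        }
    }
    return out;
}

} // namespace sketch
```

For example, with one batch of 2x2 matrices, x = {1, 2, 3, 4} and y = {5, 6, 7, 8} produce {19, 22, 43, 50}; passing transposeY = true instead produces {17, 23, 39, 53}.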

Constructor & Destructor Documentation

◆ BatchMatMulDescriptor()

BatchMatMulDescriptor(bool transposeX = false,
                      bool transposeY = false,
                      bool adjointX = false,
                      bool adjointY = false,
                      DataLayout dataLayoutX = DataLayout::NCHW,
                      DataLayout dataLayoutY = DataLayout::NCHW)
inline

Definition at line 1586 of file Descriptors.hpp.

    : m_TransposeX(transposeX)
    , m_TransposeY(transposeY)
    , m_AdjointX(adjointX)
    , m_AdjointY(adjointY)
    , m_DataLayoutX(dataLayoutX)
    , m_DataLayoutY(dataLayoutY)
    {}

Member Function Documentation

◆ GetAxesNotMul()

std::vector<unsigned int> GetAxesNotMul(DataLayout dataLayout,
                                        const TensorShape &tensorShape)
static

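Static helper to get the axes (for each input) that will not be multiplied together: every dimension index of tensorShape, in ascending order, except the two returned by GetAxesToMul().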

Definition at line 506 of file Descriptors.cpp.

{
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    std::vector<unsigned int> axesNotMul;
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first || i == axesToMul.second)
        {
            continue;
        }
        axesNotMul.push_back(i);
    }
    return axesNotMul;
}

References BatchMatMulDescriptor::GetAxesToMul() and TensorShape::GetNumDimensions().

Referenced by BatchMatMulQueueDescriptor::Validate().
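The loop keeps every axis that GetAxesToMul() did not select. The standalone restatement below uses a hypothetical signature that takes the rank and the already-computed axis pair instead of a DataLayout and TensorShape, and adds an assert as a sanity check. For rank-4 inputs, NCHW multiplies axes {2, 3} and leaves {0, 1}, whereas NHWC multiplies {1, 2} and leaves {0, 3}, so the channel axis is carried along with the batch axis rather than multiplied.

```cpp
#include <cassert>
#include <utility>
#include <vector>

namespace sketch {

// Hypothetical restatement of the GetAxesNotMul() loop: the rank and the
// pair of multiplied axes are passed in directly.
std::vector<unsigned int> AxesNotMul(unsigned int numDims,
                                     std::pair<unsigned int, unsigned int> axesToMul)
{
    // Added sanity check (not in the listing above).
    assert(axesToMul.first < numDims && axesToMul.second < numDims);
    std::vector<unsigned int> axesNotMul;
    for (unsigned int i = 0; i < numDims; ++i) {
        if (i == axesToMul.first || i == axesToMul.second) {
            continue; // skip the two matrix axes
        }
        axesNotMul.push_back(i); // every remaining axis is kept, in order
    }
    return axesNotMul;
}

} // namespace sketch
```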

◆ GetAxesToMul()

std::pair<unsigned int, unsigned int> GetAxesToMul(DataLayout dataLayout,
                                                   const TensorShape &tensorShape)
static

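Static helper to get the two axes (for each input) for multiplication. For NCHW, NCDHW and any layout not listed in the switch, these are the last two dimensions (numDims-2 and numDims-1); for NHWC and NDHWC both indices are shifted down by one, so the trailing channel dimension is excluded.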

Definition at line 485 of file Descriptors.cpp.

{
    auto numDims = tensorShape.GetNumDimensions();
    std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
    switch(dataLayout)
    {
        case DataLayout::NDHWC:
        case DataLayout::NHWC:
            axes.first -= 1;
            axes.second -= 1;
            break;
        case DataLayout::NCDHW:
        case DataLayout::NCHW:
        default:
            break;
    }
    return axes;
}

References TensorShape::GetNumDimensions(), armnn::NCDHW, armnn::NCHW, armnn::NDHWC, and armnn::NHWC.

Referenced by BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetPermuteVec(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
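A minimal standalone sketch of the same arithmetic, assuming a hypothetical sketch::DataLayout enum in place of armnn::DataLayout and the rank passed directly in place of a TensorShape. The assert is an added precondition that is not in the listing above: channels-last layouts need rank 3 or more so that the unsigned subtraction cannot wrap around.

```cpp
#include <cassert>
#include <utility>

namespace sketch {

// Hypothetical stand-in for armnn::DataLayout, limited to the four layouts
// handled in the listing above.
enum class DataLayout { NCHW, NHWC, NCDHW, NDHWC };

std::pair<unsigned int, unsigned int> AxesToMul(DataLayout layout, unsigned int numDims)
{
    const bool channelsLast = layout == DataLayout::NHWC || layout == DataLayout::NDHWC;
    // Added precondition, not part of the listing above.
    assert(numDims >= (channelsLast ? 3u : 2u));
    // Default: the last two dimensions are the matrix axes.
    std::pair<unsigned int, unsigned int> axes{numDims - 2, numDims - 1};
    if (channelsLast) {
        // The last dimension holds the channels, so shift both axes left by one.
        axes.first -= 1;
        axes.second -= 1;
    }
    return axes;
}

} // namespace sketch
```

For example, a rank-4 tensor gives {2, 3} for NCHW and {1, 2} for NHWC; a rank-5 tensor gives {3, 4} for NCDHW and {2, 3} for NDHWC.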

◆ GetPermuteVec()

PermutationVector GetPermuteVec(DataLayout dataLayout,
                                const TensorShape &tensorShape)
static

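Static helper to get the axes which will be transposed. The returned PermutationVector swaps the two axes returned by GetAxesToMul() and maps every other axis to itself.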

Definition at line 523 of file Descriptors.cpp.

{
    std::vector<unsigned int> vec;
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first)
        {
            vec.push_back(i+1);
        }
        else if(i == axesToMul.second)
        {
            vec.push_back(i-1);
        }
        else
        {
            vec.push_back(i);
        }
    }
    return PermutationVector(vec.data(),
                             static_cast<unsigned int>(vec.size()));
}

References BatchMatMulDescriptor::GetAxesToMul() and TensorShape::GetNumDimensions().

Referenced by BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
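Because GetAxesToMul() always returns two adjacent axes, the i+1 / i-1 mapping in GetPermuteVec() amounts to swapping them. The sketch below restates the loop with a plain std::vector<unsigned int> in place of armnn::PermutationVector and a hypothetical (rank, axis pair) signature. Since a swap is its own inverse, the result is the same whichever direction the mapping is read in.

```cpp
#include <cassert>
#include <utility>
#include <vector>

namespace sketch {

// Hypothetical restatement of the GetPermuteVec() loop. It relies on the
// two matrix axes being adjacent, as GetAxesToMul() guarantees.
std::vector<unsigned int> PermuteVec(unsigned int numDims,
                                     std::pair<unsigned int, unsigned int> axesToMul)
{
    // Added sanity check (not in the listing above).
    assert(axesToMul.second == axesToMul.first + 1 && axesToMul.second < numDims);
    std::vector<unsigned int> vec;
    for (unsigned int i = 0; i < numDims; ++i) {
        if (i == axesToMul.first) {
            vec.push_back(i + 1); // first matrix axis maps to the second
        } else if (i == axesToMul.second) {
            vec.push_back(i - 1); // second matrix axis maps to the first
        } else {
            vec.push_back(i);     // all other axes stay in place
        }
    }
    return vec;
}

} // namespace sketch
```

For a rank-4 NCHW input (axes {2, 3}) this yields {0, 1, 3, 2}, and for a rank-4 NHWC input (axes {1, 2}) it yields {0, 2, 1, 3}.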

◆ operator==()

bool operator==(const BatchMatMulDescriptor &rhs) const
inline

Definition at line 1600 of file Descriptors.hpp.

{
    return m_TransposeX == rhs.m_TransposeX &&
           m_TransposeY == rhs.m_TransposeY &&
           m_AdjointX == rhs.m_AdjointX &&
           m_AdjointY == rhs.m_AdjointY &&
           m_DataLayoutX == rhs.m_DataLayoutX &&
           m_DataLayoutY == rhs.m_DataLayoutY;
}

References BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, BatchMatMulDescriptor::m_TransposeX, and BatchMatMulDescriptor::m_TransposeY.
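operator== compares all six fields memberwise, so two descriptors are equal only if every flag and both layouts match. The sketch below mirrors the constructor defaults and this comparison in a hypothetical stand-in, sketch::BatchMatMulDesc (with its own sketch::DataLayout enum), rather than using the ArmNN type. It also adds an assert for the documented rule that Transpose and Adjoint must not both be set for the same tensor; the constructor listed above performs no such check.

```cpp
#include <cassert>

namespace sketch {

// Hypothetical stand-in for armnn::DataLayout (only the layouts named on this page).
enum class DataLayout { NCHW, NHWC, NCDHW, NDHWC };

// Hypothetical stand-in mirroring the fields, defaults and memberwise
// operator== of armnn::BatchMatMulDescriptor; it is not the ArmNN type.
struct BatchMatMulDesc
{
    BatchMatMulDesc(bool transposeX = false, bool transposeY = false,
                    bool adjointX = false, bool adjointY = false,
                    DataLayout dataLayoutX = DataLayout::NCHW,
                    DataLayout dataLayoutY = DataLayout::NCHW)
        : m_TransposeX(transposeX), m_TransposeY(transposeY),
          m_AdjointX(adjointX), m_AdjointY(adjointY),
          m_DataLayoutX(dataLayoutX), m_DataLayoutY(dataLayoutY)
    {
        // Added check for the documented rule; the ArmNN constructor listed
        // on this page does not perform it.
        assert(!(m_TransposeX && m_AdjointX) && !(m_TransposeY && m_AdjointY));
    }

    // Equal only when all six fields match.
    bool operator==(const BatchMatMulDesc& rhs) const
    {
        return m_TransposeX == rhs.m_TransposeX &&
               m_TransposeY == rhs.m_TransposeY &&
               m_AdjointX == rhs.m_AdjointX &&
               m_AdjointY == rhs.m_AdjointY &&
               m_DataLayoutX == rhs.m_DataLayoutX &&
               m_DataLayoutY == rhs.m_DataLayoutY;
    }

    bool m_TransposeX;
    bool m_TransposeY;
    bool m_AdjointX;
    bool m_AdjointY;
    DataLayout m_DataLayoutX;
    DataLayout m_DataLayoutY;
};

} // namespace sketch
```

A default-constructed instance equals one built with every argument spelled out as false/NCHW, while changing any single argument (for example transposeY, or dataLayoutX to NHWC) makes the comparison false.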

Member Data Documentation

◆ m_AdjointX

bool m_AdjointX

Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1617 of file Descriptors.hpp.

◆ m_AdjointY

bool m_AdjointY

Definition at line 1618 of file Descriptors.hpp.

◆ m_DataLayoutX

DataLayout m_DataLayoutX

Data layout of each input tensor, such as NHWC/NDHWC. Leave it as the default (NCHW) for an arbitrary layout, in which case the last two dimensions are the ones multiplied.

Definition at line 1621 of file Descriptors.hpp.

◆ m_DataLayoutY

DataLayout m_DataLayoutY

Definition at line 1622 of file Descriptors.hpp.

◆ m_TransposeX

bool m_TransposeX

Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1612 of file Descriptors.hpp.

◆ m_TransposeY

bool m_TransposeY

Definition at line 1613 of file Descriptors.hpp.


The documentation for this struct was generated from the following files:
Descriptors.hpp
Descriptors.cpp