ArmNN 24.02
BatchMatMulDescriptor Struct Reference

A BatchMatMulDescriptor for the BatchMatMul operator.

#include <Descriptors.hpp>


Public Member Functions

 BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW)
 
bool operator== (const BatchMatMulDescriptor &rhs) const
 
- Public Member Functions inherited from BaseDescriptor
virtual bool IsNull () const
 
virtual ~BaseDescriptor ()=default
 

Static Public Member Functions

static std::pair< unsigned int, unsigned int > GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the two axes (for each input) for multiplication.
 
static std::vector< unsigned int > GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes (for each input) that will not be multiplied together.
 
static PermutationVector GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes which will be transposed.
 

Public Attributes

bool m_TransposeX
 Transpose the slices of each input tensor (m_TransposeX for input X, m_TransposeY for input Y). Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_TransposeY
 
bool m_AdjointX
 Adjoint the slices of each input tensor (m_AdjointX for input X, m_AdjointY for input Y). Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_AdjointY
 
DataLayout m_DataLayoutX
 Data layout of each input tensor (m_DataLayoutX for input X, m_DataLayoutY for input Y), such as NHWC/NDHWC. Leave as the default for an arbitrary layout.
 
DataLayout m_DataLayoutY
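The Transpose/Adjoint restriction stated for these attributes can be expressed as a simple per-tensor check. The sketch below is standalone and does not use ArmNN: `BatchMatMulParams` is a hypothetical mirror of the descriptor's attributes and defaults, `DataLayout` is a local stand-in for `armnn::DataLayout`, and `HasValidTransposeAdjoint` is an illustrative helper. This page does not show where or how ArmNN itself enforces the rule.

```cpp
// Local stand-in for armnn::DataLayout (hypothetical, illustration only).
enum class DataLayout { NCHW, NHWC, NCDHW, NDHWC };

// Hypothetical mirror of the BatchMatMulDescriptor attributes and their defaults.
struct BatchMatMulParams
{
    bool m_TransposeX = false;
    bool m_TransposeY = false;
    bool m_AdjointX = false;
    bool m_AdjointY = false;
    DataLayout m_DataLayoutX = DataLayout::NCHW;
    DataLayout m_DataLayoutY = DataLayout::NCHW;
};

// Returns false if any single input has both Transpose and Adjoint set.
// Mixing them across inputs (e.g. TransposeX together with AdjointY) is allowed.
bool HasValidTransposeAdjoint(const BatchMatMulParams& p)
{
    const bool xOk = !(p.m_TransposeX && p.m_AdjointX);
    const bool yOk = !(p.m_TransposeY && p.m_AdjointY);
    return xOk && yOk;
}
```

The rule is per tensor: transposing X while taking the adjoint of Y passes the check, whereas setting both flags on X does not.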
 

Detailed Description

A BatchMatMulDescriptor for the BatchMatMul operator.

Definition at line 1584 of file Descriptors.hpp.

Constructor & Destructor Documentation

◆ BatchMatMulDescriptor()

BatchMatMulDescriptor(bool transposeX = false,
                      bool transposeY = false,
                      bool adjointX = false,
                      bool adjointY = false,
                      DataLayout dataLayoutX = DataLayout::NCHW,
                      DataLayout dataLayoutY = DataLayout::NCHW)
inline

Definition at line 1586 of file Descriptors.hpp.

    : m_TransposeX(transposeX)
    , m_TransposeY(transposeY)
    , m_AdjointX(adjointX)
    , m_AdjointY(adjointY)
    , m_DataLayoutX(dataLayoutX)
    , m_DataLayoutY(dataLayoutY)
{}
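To illustrate what the transpose flags mean for a single pair of matrix slices, the following standalone sketch computes the shape of op(X) · op(Y), where op is either the identity or a transpose. It illustrates ordinary matrix-multiplication shape rules under that assumption and is not ArmNN's BatchMatMulLayer::InferOutputShapes(); `SliceShape` and `SliceProductShape` are hypothetical names, and the adjoint flags are deliberately left out.

```cpp
#include <array>
#include <optional>
#include <utility>

// A 2-D matrix slice shape: {rows, cols}. Hypothetical alias.
using SliceShape = std::array<unsigned int, 2>;

// Hypothetical helper: shape of op(X) * op(Y) for one pair of slices, where
// op transposes the slice when the corresponding flag is set.
// Returns std::nullopt if the inner dimensions do not match.
std::optional<SliceShape> SliceProductShape(SliceShape x, SliceShape y,
                                            bool transposeX, bool transposeY)
{
    if (transposeX) { std::swap(x[0], x[1]); }
    if (transposeY) { std::swap(y[0], y[1]); }
    if (x[1] != y[0])
    {
        return std::nullopt; // columns of op(X) must equal rows of op(Y)
    }
    return SliceShape{x[0], y[1]};
}
```

For example, an X slice of 2×3 and a Y slice of 4×3 only multiply when transposeY is set, giving a 2×4 result.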

Member Function Documentation

◆ GetAxesNotMul()

std::vector< unsigned int > GetAxesNotMul(DataLayout dataLayout, const TensorShape &tensorShape)
static

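Static helper to get the axes (for each input) that will not be multiplied together: every dimension of tensorShape except the two returned by GetAxesToMul(), in ascending order. For a 4-D NCHW or NHWC tensor, these are the N and C axes.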

Definition at line 505 of file Descriptors.cpp.

{
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    std::vector<unsigned int> axesNotMul;
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first || i == axesToMul.second)
        {
            continue;
        }
        axesNotMul.push_back(i);
    }
    return axesNotMul;
}

References BatchMatMulDescriptor::GetAxesToMul() and TensorShape::GetNumDimensions().

Referenced by BatchMatMulQueueDescriptor::Validate().
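A standalone sketch of the same logic follows, with the TensorShape reduced to its rank and a local stand-in for `armnn::DataLayout`. `AxesToMul` and `AxesNotMul` are hypothetical names. Like the original, the sketch assumes the tensor has at least two dimensions (three for NHWC/NDHWC), because `numDims - 2` would otherwise wrap around.

```cpp
#include <utility>
#include <vector>

// Local stand-in for armnn::DataLayout (hypothetical, illustration only).
enum class DataLayout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors the axis selection of GetAxesToMul, given only the tensor rank.
std::pair<unsigned int, unsigned int> AxesToMul(DataLayout layout, unsigned int numDims)
{
    std::pair<unsigned int, unsigned int> axes{numDims - 2, numDims - 1};
    if (layout == DataLayout::NHWC || layout == DataLayout::NDHWC)
    {
        // Channels-last layouts: step back over the trailing C dimension.
        axes.first -= 1;
        axes.second -= 1;
    }
    return axes;
}

// Mirrors GetAxesNotMul: every axis except the two matrix axes, in ascending order.
std::vector<unsigned int> AxesNotMul(DataLayout layout, unsigned int numDims)
{
    const auto axesToMul = AxesToMul(layout, numDims);
    std::vector<unsigned int> axesNotMul;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i != axesToMul.first && i != axesToMul.second)
        {
            axesNotMul.push_back(i);
        }
    }
    return axesNotMul;
}
```

For a 4-D tensor this yields {0, 1} under NCHW and {0, 3} under NHWC, that is, the N and C axes in both cases.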

◆ GetAxesToMul()

std::pair< unsigned int, unsigned int > GetAxesToMul(DataLayout dataLayout, const TensorShape &tensorShape)
static

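Static helper to get the two axes (for each input) for multiplication. For NHWC and NDHWC these are the two dimensions immediately before the trailing channel dimension; for NCHW, NCDHW and any other layout they are the last two dimensions.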

Definition at line 484 of file Descriptors.cpp.

{
    auto numDims = tensorShape.GetNumDimensions();
    std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
    switch(dataLayout)
    {
        case DataLayout::NDHWC:
        case DataLayout::NHWC:
            axes.first -= 1;
            axes.second -= 1;
            break;
        case DataLayout::NCDHW:
        case DataLayout::NCHW:
        default:
            break;
    }
    return axes;
}

References TensorShape::GetNumDimensions(), armnn::NCDHW, armnn::NCHW, armnn::NDHWC, and armnn::NHWC.

Referenced by BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetPermuteVec(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
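To connect the selected axes to concrete sizes, the sketch below applies the same axis selection to a shape given as a `std::vector<unsigned int>` and returns the sizes of the two multiplied dimensions. `MatrixSliceDims` is a hypothetical helper and `DataLayout` a local stand-in; neither is part of ArmNN.

```cpp
#include <utility>
#include <vector>

// Local stand-in for armnn::DataLayout (hypothetical, illustration only).
enum class DataLayout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors GetAxesToMul, then looks up the sizes of those two axes in `shape`.
// Assumes shape.size() >= 2 (>= 3 for NHWC/NDHWC); the original does not check this.
std::pair<unsigned int, unsigned int> MatrixSliceDims(DataLayout layout,
                                                      const std::vector<unsigned int>& shape)
{
    const auto numDims = static_cast<unsigned int>(shape.size());
    unsigned int rowAxis = numDims - 2;
    unsigned int colAxis = numDims - 1;
    if (layout == DataLayout::NHWC || layout == DataLayout::NDHWC)
    {
        // Channels-last layouts: step back over the trailing C dimension.
        rowAxis -= 1;
        colAxis -= 1;
    }
    return {shape[rowAxis], shape[colAxis]};
}
```

With NHWC, a shape of {1, 4, 5, 3} gives a 4×5 slice; with NCHW, {1, 3, 4, 5} also gives a 4×5 slice.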

◆ GetPermuteVec()

PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
static

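Static helper to get the axes which will be transposed. The returned PermutationVector is the identity except that the two axes returned by GetAxesToMul() are swapped.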

Definition at line 522 of file Descriptors.cpp.

{
    std::vector<unsigned int> vec;
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first)
        {
            vec.push_back(i+1);
        }
        else if(i == axesToMul.second)
        {
            vec.push_back(i-1);
        }
        else
        {
            vec.push_back(i);
        }
    }
    return PermutationVector(vec.data(),
                             static_cast<unsigned int>(vec.size()));
}

References BatchMatMulDescriptor::GetAxesToMul() and TensorShape::GetNumDimensions().

Referenced by BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
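The standalone sketch below reproduces the GetPermuteVec() logic on a plain `std::vector<unsigned int>` and applies the result to a shape. `PermuteVec`, `ApplySwap` and the local `DataLayout` enum are hypothetical, and ArmNN's own PermutationVector type is not used. Because the permutation is a single swap, the result does not depend on which mapping direction a permutation convention uses.

```cpp
#include <cstddef>
#include <vector>

// Local stand-in for armnn::DataLayout (hypothetical, illustration only).
enum class DataLayout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors GetPermuteVec: the identity permutation, except that the two matrix
// axes chosen by GetAxesToMul (`first` and `first + 1`) are swapped.
std::vector<unsigned int> PermuteVec(DataLayout layout, unsigned int numDims)
{
    unsigned int first = numDims - 2;
    if (layout == DataLayout::NHWC || layout == DataLayout::NDHWC)
    {
        first -= 1; // skip the trailing channel dimension
    }
    std::vector<unsigned int> vec;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i == first)          { vec.push_back(i + 1); }
        else if (i == first + 1) { vec.push_back(i - 1); }
        else                     { vec.push_back(i); }
    }
    return vec;
}

// Applies an axis swap to a shape: out[i] = shape[perm[i]].
// A swap is its own inverse, so either mapping convention gives the same result.
std::vector<unsigned int> ApplySwap(const std::vector<unsigned int>& shape,
                                    const std::vector<unsigned int>& perm)
{
    std::vector<unsigned int> out(shape.size());
    for (std::size_t i = 0; i < shape.size(); ++i)
    {
        out[i] = shape[perm[i]];
    }
    return out;
}
```

Under NHWC, the 4-D shape {1, 4, 5, 3} becomes {1, 5, 4, 3}: only H and W trade places.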

◆ operator==()

bool operator==(const BatchMatMulDescriptor &rhs) const
inline

Definition at line 1600 of file Descriptors.hpp.

{
    return m_TransposeX == rhs.m_TransposeX &&
           m_TransposeY == rhs.m_TransposeY &&
           m_AdjointX == rhs.m_AdjointX &&
           m_AdjointY == rhs.m_AdjointY &&
           m_DataLayoutX == rhs.m_DataLayoutX &&
           m_DataLayoutY == rhs.m_DataLayoutY;
}

References BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, BatchMatMulDescriptor::m_TransposeX, and BatchMatMulDescriptor::m_TransposeY.
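The same field-by-field comparison can be written more compactly with `std::tie`. Tuple equality compares the elements pairwise, so it returns true only when all six fields match. The struct below is a hypothetical standalone mirror of the descriptor, not the ArmNN type, and `DataLayout` is a local stand-in.

```cpp
#include <tuple>

// Local stand-in for armnn::DataLayout (hypothetical, illustration only).
enum class DataLayout { NCHW, NHWC, NCDHW, NDHWC };

// Hypothetical mirror of the descriptor's fields, showing the same
// memberwise equality as operator== expressed with std::tie.
struct BatchMatMulParams
{
    bool m_TransposeX = false;
    bool m_TransposeY = false;
    bool m_AdjointX = false;
    bool m_AdjointY = false;
    DataLayout m_DataLayoutX = DataLayout::NCHW;
    DataLayout m_DataLayoutY = DataLayout::NCHW;

    bool operator==(const BatchMatMulParams& rhs) const
    {
        // Build tuples of references and compare them element by element.
        return std::tie(m_TransposeX, m_TransposeY, m_AdjointX, m_AdjointY,
                        m_DataLayoutX, m_DataLayoutY) ==
               std::tie(rhs.m_TransposeX, rhs.m_TransposeY, rhs.m_AdjointX,
                        rhs.m_AdjointY, rhs.m_DataLayoutX, rhs.m_DataLayoutY);
    }
};
```

Listing every field in a single `std::tie` call keeps the comparison in one place, which makes it harder to forget a member when the struct grows.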

Member Data Documentation

◆ m_AdjointX

bool m_AdjointX

Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1617 of file Descriptors.hpp.

◆ m_AdjointY

bool m_AdjointY

Definition at line 1618 of file Descriptors.hpp.

◆ m_DataLayoutX

DataLayout m_DataLayoutX

Data layout of each input tensor, such as NHWC/NDHWC. Leave as the default for an arbitrary layout.

Definition at line 1621 of file Descriptors.hpp.

◆ m_DataLayoutY

DataLayout m_DataLayoutY

Definition at line 1622 of file Descriptors.hpp.

◆ m_TransposeX

bool m_TransposeX

Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1612 of file Descriptors.hpp.

◆ m_TransposeY

bool m_TransposeY

Definition at line 1613 of file Descriptors.hpp.


The documentation for this struct was generated from the following files:
Descriptors.hpp
Descriptors.cpp