ArmNN 25.11
BatchMatMulDescriptor Struct Reference

A BatchMatMulDescriptor for the BatchMatMul operator.

#include <Descriptors.hpp>

[Inheritance and collaboration diagrams for BatchMatMulDescriptor not reproduced.]

Public Member Functions

 BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW)
bool operator== (const BatchMatMulDescriptor &rhs) const
Public Member Functions inherited from BaseDescriptor
virtual bool IsNull () const
virtual ~BaseDescriptor ()=default

Static Public Member Functions

static std::pair< unsigned int, unsigned int > GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the two axes (for each input) for multiplication.
static std::vector< unsigned int > GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes (for each input) that will not be multiplied together.
static PermutationVector GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper that returns the permutation which swaps (transposes) the two multiplication axes.

Public Attributes

bool m_TransposeX
 Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
bool m_TransposeY
bool m_AdjointX
 Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
bool m_AdjointY
DataLayout m_DataLayoutX
 Data layout of each input tensor, such as NHWC/NDHWC (leave as the default for an arbitrary layout).
DataLayout m_DataLayoutY
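The rule stated for m_TransposeX/m_AdjointX and m_TransposeY/m_AdjointY (Transpose and Adjoint may not both be true for the same tensor) can be expressed as a simple check. The sketch below is hypothetical: BatchMatMulFlags and HasValidTransposeAdjointFlags are illustrative names, not ArmNN API, and the struct only mirrors the four boolean attributes of the descriptor.

```cpp
#include <cassert>

// Hypothetical stand-in for the four boolean attributes of
// armnn::BatchMatMulDescriptor; this is not the ArmNN type itself.
struct BatchMatMulFlags
{
    bool transposeX = false;
    bool transposeY = false;
    bool adjointX   = false;
    bool adjointY   = false;
};

// Hypothetical helper (not part of ArmNN): returns true when neither input
// has both Transpose and Adjoint set, which is the documented rule.
bool HasValidTransposeAdjointFlags(const BatchMatMulFlags& flags)
{
    const bool xValid = !(flags.transposeX && flags.adjointX);
    const bool yValid = !(flags.transposeY && flags.adjointY);
    return xValid && yValid;
}
```

Setting Transpose on one input and Adjoint on the other is allowed; only the combination on the same tensor is rejected.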

Detailed Description

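A BatchMatMulDescriptor holds the parameters of the BatchMatMul operator: whether to transpose or adjoint the matrix slices of each input (X and Y), and the data layout of each input.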

Definition at line 1584 of file Descriptors.hpp.

Constructor & Destructor Documentation

◆ BatchMatMulDescriptor()

BatchMatMulDescriptor ( bool transposeX = false,
bool transposeY = false,
bool adjointX = false,
bool adjointY = false,
DataLayout dataLayoutX = DataLayout::NCHW,
DataLayout dataLayoutY = DataLayout::NCHW )
inline

Definition at line 1586 of file Descriptors.hpp.

    : m_TransposeX(transposeX)
    , m_TransposeY(transposeY)
    , m_AdjointX(adjointX)
    , m_AdjointY(adjointY)
    , m_DataLayoutX(dataLayoutX)
    , m_DataLayoutY(dataLayoutY)
{}

References m_AdjointX, m_AdjointY, m_DataLayoutX, m_DataLayoutY, m_TransposeX, m_TransposeY, and armnn::NCHW.

Referenced by operator==().

Member Function Documentation

◆ GetAxesNotMul()

std::vector< unsigned int > GetAxesNotMul ( DataLayout dataLayout,
const TensorShape & tensorShape )
static

Static helper to get the axes (for each input) that will not be multiplied together.

Definition at line 506 of file Descriptors.cpp.

{
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    std::vector<unsigned int> axesNotMul;
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first || i == axesToMul.second)
        {
            continue;
        }
        axesNotMul.push_back(i);
    }
    return axesNotMul;
}

References GetAxesToMul(), and TensorShape::GetNumDimensions().

Referenced by BatchMatMulQueueDescriptor::Validate().
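GetAxesNotMul keeps every axis index that is not one of the two multiplication axes. The standalone sketch below reproduces that logic outside ArmNN so it can be tried in isolation; it assumes a hypothetical Layout enum in place of armnn::DataLayout and a plain rank (numDims) in place of armnn::TensorShape, and the function names AxesToMul/AxesNotMul are illustrative only. For a rank-4 NHWC tensor the multiplication axes are 1 and 2 (H and W), so the remaining axes are {0, 3}.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for the armnn::DataLayout values used here.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors GetAxesToMul: the last two axes, shifted one step left for
// channels-last layouts (NHWC/NDHWC).
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, unsigned int numDims)
{
    assert(numDims >= 2); // precondition check added for this sketch only
    std::pair<unsigned int, unsigned int> axes{numDims - 2, numDims - 1};
    if (layout == Layout::NHWC || layout == Layout::NDHWC)
    {
        assert(numDims >= 3); // channels-last needs room for the trailing C axis
        axes.first -= 1;
        axes.second -= 1;
    }
    return axes;
}

// Mirrors GetAxesNotMul: every axis except the two multiplication axes,
// in ascending order.
std::vector<unsigned int> AxesNotMul(Layout layout, unsigned int numDims)
{
    const auto axesToMul = AxesToMul(layout, numDims);
    std::vector<unsigned int> axesNotMul;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i != axesToMul.first && i != axesToMul.second)
        {
            axesNotMul.push_back(i);
        }
    }
    return axesNotMul;
}
```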

◆ GetAxesToMul()

std::pair< unsigned int, unsigned int > GetAxesToMul ( DataLayout dataLayout,
const TensorShape & tensorShape )
static

Static helper to get the two axes (for each input) for multiplication.

Definition at line 485 of file Descriptors.cpp.

{
    auto numDims = tensorShape.GetNumDimensions();
    std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
    switch(dataLayout)
    {
        case DataLayout::NDHWC:
        case DataLayout::NHWC:
            axes.first -= 1;
            axes.second -= 1;
            break;
        case DataLayout::NCDHW:
        case DataLayout::NCHW:
        default:
            break;
    }
    return axes;
}

References TensorShape::GetNumDimensions(), armnn::NCDHW, armnn::NCHW, armnn::NDHWC, and armnn::NHWC.

Referenced by GetAxesNotMul(), GetPermuteVec(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
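The switch above picks the last two axes for channels-first layouts (and, through the default case, for any other layout), and shifts both one step left for channels-last layouts so that the trailing C axis is skipped. The standalone sketch below mirrors that rule; Layout and AxesToMul are hypothetical stand-ins for armnn::DataLayout and GetAxesToMul, and the rank is passed directly instead of a TensorShape. The bounds checks are additions for the sketch; the ArmNN code above does not perform them.

```cpp
#include <cassert>
#include <utility>

// Hypothetical stand-in for the armnn::DataLayout values handled by the switch.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors GetAxesToMul using a rank instead of a TensorShape.
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, unsigned int numDims)
{
    assert(numDims >= 2); // precondition check added for this sketch only
    std::pair<unsigned int, unsigned int> axes{numDims - 2, numDims - 1};
    if (layout == Layout::NHWC || layout == Layout::NDHWC)
    {
        assert(numDims >= 3); // channels-last needs room for the trailing C axis
        axes.first -= 1;      // skip C: e.g. NHWC rank 4 gives (1, 2)
        axes.second -= 1;
    }
    return axes;
}
```

In every one of the four named layouts the selected pair is the H and W axes: (2, 3) for rank-4 NCHW, (1, 2) for rank-4 NHWC, (3, 4) for rank-5 NCDHW and (2, 3) for rank-5 NDHWC.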

◆ GetPermuteVec()

PermutationVector GetPermuteVec ( DataLayout dataLayout,
const TensorShape & tensorShape )
static

Static helper that returns the permutation which swaps (transposes) the two multiplication axes.

Definition at line 523 of file Descriptors.cpp.

{
    std::vector<unsigned int> vec;
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first)
        {
            vec.push_back(i+1);
        }
        else if(i == axesToMul.second)
        {
            vec.push_back(i-1);
        }
        else
        {
            vec.push_back(i);
        }
    }
    return PermutationVector(vec.data(),
                             static_cast<unsigned int>(vec.size()));
}

References GetAxesToMul(), and TensorShape::GetNumDimensions().

Referenced by ConvertBatchMatMulToTosaOperator(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
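Because GetAxesToMul always returns two adjacent axes (second == first + 1), GetPermuteVec produces the identity permutation with those two entries swapped. The sketch below reproduces that construction with a plain std::vector<unsigned int> instead of armnn::PermutationVector; Layout, AxesToMul and PermuteVec are hypothetical stand-ins. A single swap is its own inverse, so the result reads the same whichever direction (source to destination or the reverse) a caller interprets the mapping in.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for the armnn::DataLayout values used here.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors GetAxesToMul (see the documentation of that function).
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, unsigned int numDims)
{
    assert(numDims >= 2); // precondition check added for this sketch only
    std::pair<unsigned int, unsigned int> axes{numDims - 2, numDims - 1};
    if (layout == Layout::NHWC || layout == Layout::NDHWC)
    {
        assert(numDims >= 3);
        axes.first -= 1;
        axes.second -= 1;
    }
    return axes;
}

// Mirrors GetPermuteVec: identity except the two adjacent multiplication
// axes, which swap places.
std::vector<unsigned int> PermuteVec(Layout layout, unsigned int numDims)
{
    const auto axesToMul = AxesToMul(layout, numDims);
    std::vector<unsigned int> vec;
    vec.reserve(numDims);
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i == axesToMul.first)       { vec.push_back(i + 1); } // first -> second
        else if (i == axesToMul.second) { vec.push_back(i - 1); } // second -> first
        else                            { vec.push_back(i); }     // unchanged axis
    }
    return vec;
}
```

For example, rank-4 NCHW yields {0, 1, 3, 2} and rank-4 NHWC yields {0, 2, 1, 3}.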

◆ operator==()

bool operator== ( const BatchMatMulDescriptor & rhs) const
inline

Definition at line 1600 of file Descriptors.hpp.

{
    return m_TransposeX == rhs.m_TransposeX &&
           m_TransposeY == rhs.m_TransposeY &&
           m_AdjointX == rhs.m_AdjointX &&
           m_AdjointY == rhs.m_AdjointY &&
           m_DataLayoutX == rhs.m_DataLayoutX &&
           m_DataLayoutY == rhs.m_DataLayoutY;
}

References BatchMatMulDescriptor(), m_AdjointX, m_AdjointY, m_DataLayoutX, m_DataLayoutY, m_TransposeX, and m_TransposeY.
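operator== compares all six attributes, so the data layouts take part in equality alongside the transpose and adjoint flags. The sketch below uses a hypothetical mirror struct, MatMulDesc, with the same constructor defaults and memberwise comparison; it is an illustration, not the ArmNN type, and Layout stands in for armnn::DataLayout.

```cpp
#include <cassert>

// Hypothetical stand-in for armnn::DataLayout.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Hypothetical mirror of BatchMatMulDescriptor's attributes and constructor
// defaults (all flags false, both layouts NCHW).
struct MatMulDesc
{
    MatMulDesc(bool transposeX = false, bool transposeY = false,
               bool adjointX = false, bool adjointY = false,
               Layout dataLayoutX = Layout::NCHW,
               Layout dataLayoutY = Layout::NCHW)
        : m_TransposeX(transposeX), m_TransposeY(transposeY),
          m_AdjointX(adjointX), m_AdjointY(adjointY),
          m_DataLayoutX(dataLayoutX), m_DataLayoutY(dataLayoutY) {}

    // Memberwise comparison of all six attributes.
    bool operator==(const MatMulDesc& rhs) const
    {
        return m_TransposeX == rhs.m_TransposeX &&
               m_TransposeY == rhs.m_TransposeY &&
               m_AdjointX == rhs.m_AdjointX &&
               m_AdjointY == rhs.m_AdjointY &&
               m_DataLayoutX == rhs.m_DataLayoutX &&
               m_DataLayoutY == rhs.m_DataLayoutY;
    }

    bool m_TransposeX;
    bool m_TransposeY;
    bool m_AdjointX;
    bool m_AdjointY;
    Layout m_DataLayoutX;
    Layout m_DataLayoutY;
};
```

A default-constructed value equals one built with the defaults spelled out, while changing any single flag or layout makes the two compare unequal.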

Member Data Documentation

◆ m_AdjointX

bool m_AdjointX

◆ m_AdjointY

bool m_AdjointY

◆ m_DataLayoutX

DataLayout m_DataLayoutX

◆ m_DataLayoutY

DataLayout m_DataLayoutY

◆ m_TransposeX

bool m_TransposeX

◆ m_TransposeY

bool m_TransposeY


The documentation for this struct was generated from the following files:

Descriptors.hpp
Descriptors.cpp