ArmNN
 24.08
BatchMatMulOperator.hpp File Reference
Include dependency graph for BatchMatMulOperator.hpp:
This graph shows which files directly or indirectly include this file:

Go to the source code of this file.

Functions

TosaSerializationBasicBlock * ConvertBatchMatMulToTosaOperator (const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const BatchMatMulDescriptor *descriptor=nullptr)
 

Function Documentation

◆ ConvertBatchMatMulToTosaOperator()

TosaSerializationBasicBlock* ConvertBatchMatMulToTosaOperator ( const Layer layer,
const std::vector< const TensorInfo * > &  inputs,
const std::vector< const TensorInfo * > &  outputs,
const BatchMatMulDescriptor descriptor = nullptr 
)

Definition at line 14 of file BatchMatMulOperator.cpp.

18 {
19  if (descriptor->m_AdjointX || descriptor->m_AdjointY )
20  {
21  throw Exception("Support for adjoint not implemented.");
22  }
23  if (descriptor->m_DataLayoutX != armnn::DataLayout::NCHW || descriptor->m_DataLayoutY != armnn::DataLayout::NCHW )
24  {
25  throw Exception("MatMul only supported in the last 2 dimensions");
26  }
27 
28  std::string input0Name = std::string("input_0");
29  std::string input1Name = std::string("input_1");
30  std::string outputName = std::string("output_0");
31  std::string outputReshape0Name = std::string("intermediate0_") + GetUniqueTosaMappingID();
32  std::string outputReshape1Name = std::string("intermediate0_") + GetUniqueTosaMappingID();
33  std::string outputTranspose0Name = std::string("intermediate1_") + GetUniqueTosaMappingID();
34  std::string outputTranspose1Name = std::string("intermediate1_") + GetUniqueTosaMappingID();
35 
36  std::string blockName = std::string("Op_BATCHMATMUL_block_") + GetUniqueTosaMappingID();
37 
38  // If a layer is present then the block will be used for execution, so input and output names need to be determined
39  // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
40  if(layer != nullptr)
41  {
42  // Get the layer connected to the input slot and determine unique tensor names.
43  input0Name = GenerateUniqueInputName(layer->GetInputSlot(0));
44  input1Name = GenerateUniqueInputName(layer->GetInputSlot(1));
45  outputName = GenerateUniqueOutputName(*layer);
46  }
47 
48  // Assumes both input types are same data type
49  DType inputDType = ArmNNToDType(inputs[0]->GetDataType());
50  bool isInputInt8 = (inputDType == DType_INT8);
51  bool isInputInt16 = (inputDType == DType_INT16);
52 
53  std::vector<TosaSerializationTensor*> tensors;
54  std::vector<TosaSerializationOperator*> operators;
55 
56  // Only add input tensors if connected layer is an input layer.
57  // As intermediate or constant tensors will be created separately.
58  // There also can't be duplicate tensor.
59  if(input0Name.find("input_") != std::string::npos)
60  {
61  std::vector<int32_t> inputShape0 = GetTosaTensorShape(inputs[0]->GetShape());
62  tensors.push_back(new TosaSerializationTensor(input0Name, inputShape0, inputDType, {}));
63  }
64  if(input1Name.find("input_") != std::string::npos)
65  {
66  std::vector<int32_t> inputShape1 = GetTosaTensorShape(inputs[1]->GetShape());
67  tensors.push_back(new TosaSerializationTensor(input1Name, inputShape1, inputDType, {}));
68  }
69 
70  std::string input0TransposeName = input0Name;
71  std::string input1TransposeName = input1Name;
72  std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
73 
74  std::string input0MatMulName = input0Name;
75  std::string input1MatMulName = input1Name;
76 
77  // *** ADD OP STEPS ***
78 
79  // ADD a RESHAPE OPs if BATCH DIMS > 1
80  // RESHAPE input 1
81  std::vector<int32_t> targetShape0 = GetTosaTensorShape(outputs[0]->GetShape());
82  std::vector<int32_t> transpose0Shape = GetTosaTensorShape(inputs[0]->GetShape());
83  uint32_t input0Dimensions = inputs[0]->GetNumDimensions();
84  if (input0Dimensions > 3)
85  {
86  uint32_t x = 1;
87  for (uint32_t i = 0; i < (input0Dimensions - 2); ++i)
88  {
89  x *=(inputs[0]->GetShape()[i]);
90  }
91 
92  targetShape0 = {static_cast<int32_t>(x),
93  static_cast<int32_t>(inputs[0]->GetShape()[input0Dimensions - 2]),
94  static_cast<int32_t>(inputs[0]->GetShape()[input0Dimensions - 1])};
95 
96  TosaReshapeAttribute attribute(targetShape0);
97 
98  auto* input0ReshapeOp = new TosaSerializationOperator(Op_RESHAPE,
99  Attribute_ReshapeAttribute,
100  &attribute,
101  {input0Name},
102  {outputReshape0Name});
103 
104  operators.push_back(input0ReshapeOp);
105  transpose0Shape = targetShape0;
106  tensors.push_back(new TosaSerializationTensor(outputReshape0Name, targetShape0, inputDType, {}));
107  input0TransposeName = outputReshape0Name;
108  input0MatMulName = outputReshape0Name;
109  }
110 
111  // RESHAPE input 2
112  std::vector<int32_t> targetShape1 = GetTosaTensorShape(outputs[0]->GetShape());
113  std::vector<int32_t> transpose1Shape = GetTosaTensorShape(inputs[1]->GetShape());
114  uint32_t input1Dimensions = inputs[1]->GetNumDimensions();
115  if (input1Dimensions > 3)
116  {
117  uint32_t x = 1;
118  for (uint32_t i = 0; i < (input1Dimensions - 2); i++)
119  {
120  x *= (inputs[1]->GetShape()[i]);
121  }
122 
123  targetShape1 = {static_cast<int32_t>(x),
124  static_cast<int32_t>(inputs[1]->GetShape()[input1Dimensions - 2]),
125  static_cast<int32_t>(inputs[1]->GetShape()[input1Dimensions - 1])};
126 
127  TosaReshapeAttribute attribute(targetShape1);
128 
129  auto* input1ReshapeOp = new TosaSerializationOperator(Op_RESHAPE,
130  Attribute_ReshapeAttribute,
131  &attribute,
132  {input1Name},
133  {outputReshape1Name});
134 
135  operators.push_back(input1ReshapeOp);
136  transpose1Shape = targetShape1;
137  tensors.push_back(new TosaSerializationTensor(outputReshape1Name, targetShape1, inputDType, {}));
138  input1TransposeName = outputReshape1Name;
139  input1MatMulName = outputReshape1Name;
140  }
141  bool needsReshape = input0Dimensions > 3 || input1Dimensions > 3;
142 
143  // ADD a TRANSPOSE OP for one/both inputs if transpose set to true
144  if (descriptor->m_TransposeX)
145  {
146  auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(descriptor->m_DataLayoutX,
147  inputs[0]->GetShape());
148 
149  std::vector<int32_t> mappings(permuteVec.begin(),
150  permuteVec.end());
151  TosaTransposeAttribute transposeAttribute(mappings);
152 
153  TosaSerializationOperator *transposeOp = new TosaSerializationOperator(Op_TRANSPOSE,
154  Attribute_TransposeAttribute,
155  &transposeAttribute,
156  {input0TransposeName},
157  {outputTranspose0Name});
158  operators.push_back(transposeOp);
159  tensors.push_back(new TosaSerializationTensor(outputTranspose0Name, transpose0Shape, inputDType, {}));
160  input0MatMulName = outputTranspose0Name;
161  }
162 
163  if (descriptor->m_TransposeY)
164  {
165  auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(descriptor->m_DataLayoutY,
166  inputs[1]->GetShape());
167 
168 
169  std::vector<int32_t> mappings(permuteVec.begin(),
170  permuteVec.end());
171  TosaTransposeAttribute transposeAttribute(mappings);
172 
173  TosaSerializationOperator *transposeOp = new TosaSerializationOperator(Op_TRANSPOSE,
174  Attribute_TransposeAttribute,
175  &transposeAttribute,
176  {input1TransposeName},
177  {outputTranspose1Name});
178  operators.push_back(transposeOp);
179  tensors.push_back(new TosaSerializationTensor(outputTranspose1Name, transpose1Shape, inputDType, {}));
180  input1MatMulName = outputTranspose1Name;
181  }
182 
183  // ADD MAT MUL layer
184  std::string matMulOutputStr = needsReshape || isInputInt8 || isInputInt16 ?
185  std::string("intermediate2_") + GetUniqueTosaMappingID() : outputName;
186 
187  TosaMatMulAttribute matMulAttribute(0,0); // input0_zp, input1_zp
188  DType matMulOutDType = ArmNNToDType(inputs[1]->GetDataType());
189  if (isInputInt8)
190  {
191  matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
192  matMulOutDType = DType_INT32;
193  }
194  if (isInputInt16)
195  {
196  matMulAttribute = TosaMatMulAttribute(inputs[0]->GetQuantizationOffset(), inputs[1]->GetQuantizationOffset());
197  matMulOutDType = DType_INT48;
198  }
199  TosaSerializationOperator* matMulOp = new TosaSerializationOperator(Op_MATMUL,
200  Attribute_MatMulAttribute,
201  &matMulAttribute,
202  {input0MatMulName, input1MatMulName},
203  {matMulOutputStr});
204 
205  operators.push_back(matMulOp);
206  tensors.push_back(new TosaSerializationTensor(matMulOutputStr, targetShape0, matMulOutDType, {}));
207 
208  std::string outputRescale = needsReshape ?
209  std::string("intermediate3_") + GetUniqueTosaMappingID() : outputName;
210  std::string inputReshape2Name = isInputInt8 || isInputInt16 ? outputRescale : matMulOutputStr;
211 
212  // ADD Rescale layer if it is int8
213  if (isInputInt8 || isInputInt16)
214  {
215  bool scale32 = isInputInt16 ? false : true;
216  bool doubleRound = isInputInt16 ? false : true;
217 
218  double scale_alpha = inputs[0]->GetQuantizationScale() / outputs[0]->GetQuantizationScale();
219  int32_t input_zp = inputs[0]->GetQuantizationOffset();
220  int32_t output_zp = outputs[0]->GetQuantizationOffset();
221 
222  TosaSerializationOperator* rescaleOp = nullptr;
223  CreateRescaleTosaOperator(matMulOutputStr,
224  outputRescale,
225  scale_alpha,
226  input_zp,
227  output_zp,
228  doubleRound,
229  scale32,
230  &rescaleOp);
231 
232  tensors.push_back(new TosaSerializationTensor(outputRescale,
233  targetShape0,
234  inputDType, {}));
235 
236  operators.push_back(rescaleOp);
237  }
238 
239  // ADD a RESHAPE back to expected rank
240  if (needsReshape)
241  {
242  const std::vector<int32_t>& targetShape = GetTosaTensorShape(TensorShape(outputs[0]->GetShape()));
243  TosaReshapeAttribute attribute(targetShape);
244 
245  auto* outputReshapeOp = new TosaSerializationOperator(Op_RESHAPE,
246  Attribute_ReshapeAttribute,
247  &attribute,
248  {inputReshape2Name},
249  {outputName});
250 
251  operators.push_back(outputReshapeOp);
252  tensors.push_back(new TosaSerializationTensor(outputName, targetShape, inputDType, {}));
253  }
254 
255  return new TosaSerializationBasicBlock(blockName, // name
256  mainName, // region name
257  {operators}, // operators
258  tensors, // tensors
259  {input0Name, input1Name}, // inputs
260  {outputName}); // outputs
261 }

References ArmNNToDType(), GenerateUniqueInputName(), GenerateUniqueOutputName(), Layer::GetInputSlot(), GetTosaTensorShape(), GetUniqueTosaMappingID(), BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, and armnn::NCHW.

Referenced by GetTosaMapping().

armnn::BatchMatMulDescriptor::m_TransposeX
bool m_TransposeX
Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
Definition: Descriptors.hpp:1612
armnn::BatchMatMulDescriptor::m_AdjointX
bool m_AdjointX
Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
Definition: Descriptors.hpp:1617
armnn::BatchMatMulDescriptor::m_DataLayoutX
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
Definition: Descriptors.hpp:1621
GenerateUniqueOutputName
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
Definition: TosaOperatorUtils.hpp:120
armnn::BatchMatMulDescriptor::m_AdjointY
bool m_AdjointY
Definition: Descriptors.hpp:1618
armnn::Layer::GetInputSlot
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:337
mainName
const std::string mainName
Definition: TosaOperatorUtils.hpp:19
armnn::TensorShape
Definition: Tensor.hpp:20
ArmNNToDType
DType ArmNNToDType(const DataType &type)
Definition: TosaOperatorUtils.hpp:22
armnn::BatchMatMulDescriptor::m_TransposeY
bool m_TransposeY
Definition: Descriptors.hpp:1613
armnn::BatchMatMulDescriptor::m_DataLayoutY
DataLayout m_DataLayoutY
Definition: Descriptors.hpp:1622
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
GetTosaTensorShape
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
Definition: TosaOperatorUtils.hpp:79
CreateRescaleTosaOperator
void CreateRescaleTosaOperator(const std::string &inputName, const std::string &outputName, const std::vector< int32_t > &multipliers, const std::vector< int32_t > &shifts, int32_t input_zp, int32_t output_zp, bool double_round, bool scale32, bool per_channel, TosaSerializationOperator **op)
Definition: TosaRescaleOperatorUtils.hpp:10
GenerateUniqueInputName
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
Definition: TosaOperatorUtils.hpp:109
armnn::DataLayout::NCHW
@ NCHW
GetUniqueTosaMappingID
std::string GetUniqueTosaMappingID()
Definition: TosaOperatorUtils.hpp:138