ArmNN 24.08
BatchMatMulImpl.cpp
//
// Copyright © 2022, 2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "BatchMatMulImpl.hpp"

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>
#include <armnnUtils/Permute.hpp>

namespace armnn
{

BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
                         const TensorInfo& inputXInfo,
                         const TensorInfo& inputYInfo,
                         const TensorInfo& outputInfo,
                         Decoder<float>& inputXDecoder,
                         Decoder<float>& inputYDecoder,
                         Encoder<float>& outputEncoder)
    : params(params),
      inputXInfo(inputXInfo),
      inputYInfo(inputYInfo),
      outputInfo(outputInfo),
      inputXDecoder(inputXDecoder),
      inputYDecoder(inputYDecoder),
      outputEncoder(outputEncoder)
{
    inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
    inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
    // At this point, we don't touch the input decoders - just the resultant vectors

    ApplyParams();

    ApplyBatchMatMul();
}

void BatchMatMul::ApplyBatchMatMul()
{
    auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
                                                          inputXInfo.GetShape());
    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
                                                          inputYInfo.GetShape());

    // The inputYRowSize (equivalently the inputXColSize) needs to be obtained using the original (unadjusted)
    // axis value, because it is read from the original tensor shape
    unsigned int inputYRowSize = inputYInfo.GetShape()[axesYToMul.first];

    AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);

    unsigned int inputXColDim = axesXToMul.second;
    unsigned int inputYRowDim = axesYToMul.first;

    auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
    {
        float sum = 0.0f;

        // InputYRowSize is synonymous with inputXColSize
        for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
        {
            auto xIdx = curIdx;
            xIdx[inputXColDim] = inputYRowIdx;

            auto yIdx = curIdx;
            yIdx[inputYRowDim] = inputYRowIdx;

            sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
        }

        SetValueAt(sum, DataSlot::Output, curIdx);
    };

    auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
    RecurseTensor(outputInfo,
                  batchMatMulOperation,
                  startIdx,
                  0);
}
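
// Note on batchMatMulOperation above: for each output coordinate the lambda computes a single dot
// product. For example, with 3-D inputs and an output index {b, i, j}, it accumulates
//     sum over k of X[b, i, k] * Y[b, k, j]
// where k runs over inputYRowSize (the shared multiplication axis); broadcasting for unequal ranks is
// handled inside GetValueAt via AdjustToSafeIdx.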

void BatchMatMul::ApplyParams()
{
    if(params.m_TransposeX)
    {
        Transpose(DataSlot::InputX);
    }
    else if(params.m_AdjointX)
    {
        Adjoint(DataSlot::InputX);
    }
    if(params.m_TransposeY)
    {
        Transpose(DataSlot::InputY);
    }
    else if(params.m_AdjointY)
    {
        Adjoint(DataSlot::InputY);
    }
}

void BatchMatMul::Transpose(DataSlot type)
{
    // AKA the permute of the tensor
    // This modifies the tensor's info.

    switch(type)
    {
        case DataSlot::InputX:
        {
            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
                                                                   inputXInfo.GetShape());
            inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
            std::vector<float> temp(inputXData.size());
            armnnUtils::Permute(inputXInfo.GetShape(),
                                permuteVec,
                                inputXData.data(),
                                temp.data(),
                                sizeof(float));
            inputXData = temp;
            break;
        }
        case DataSlot::InputY:
        {
            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
                                                                   inputYInfo.GetShape());
            inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
            std::vector<float> temp(inputYData.size());
            armnnUtils::Permute(inputYInfo.GetShape(),
                                permuteVec,
                                inputYData.data(),
                                temp.data(),
                                sizeof(float));
            inputYData = temp;
            break;
        }
        case DataSlot::Output: // We needn't transpose the output tensor
        default:
            break;
    }
}
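
// Note on Transpose above: GetPermuteVec supplies the permutation that swaps the two multiplication
// (row/column) axes for the given data layout; for the default layout this typically amounts to swapping
// the last two dimensions, e.g. a shape of [2, 3, 4] becoming [2, 4, 3]. Both the tensor info and the
// decoded data vector are updated, so later index calculations see the transposed shape.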

void BatchMatMul::Adjoint(DataSlot type)
{
    // Finding the adjoint of a square matrix:
    // Calculate the cofactor of each element (using Gauss elimination here)
    // Apply a transpose to it (this also modifies the tensor's info)

    TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
    const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
    const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout, inputInfo.GetShape());

    // We grab a copy of the tensor data to prevent overwriting
    std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;

    // The sub-matrix is the resultant matrix when the row and column of the current index are removed
    unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
    std::vector<std::vector<float>> subMat(subMatAxisSize,
                                           std::vector<float>(subMatAxisSize));

    // Lambdas for each sub-step of the cofactor operation
    auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
    {
        float diff = std::fabs(a - b);
        float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
        return (diff <= bound) || (diff < std::numeric_limits<float>::min());
    };

    float swapMultiplier = std::numeric_limits<float>::max();
    auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
    {
        // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
        for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
        {
            float tmp = subMat[rowIdxA][colIdx];
            subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
            subMat[rowIdxB][colIdx] = tmp;
        }
        swapMultiplier *= -1.0f;
    };

    auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
    {
        unsigned int result = std::numeric_limits<unsigned int>::max();

        // The original diagonal has been checked and is invalid
        for(unsigned int rowIdx = colIdx + 1; rowIdx < subMatAxisSize; rowIdx++)
        {
            if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
            {
                result = rowIdx;
                break;
            }
        }
        return result;
    };

    auto eliminate = [&](const float& pivot, unsigned int pivotPos)
    {
        for(unsigned int rowIdx = pivotPos + 1; rowIdx < subMatAxisSize; rowIdx++)
        {
            float multiplierNumerator = subMat[rowIdx][pivotPos];
            if(almostEquals(multiplierNumerator, 0.0f))
            {
                continue;
            }
            float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies,
                                                            // hence the almostEquals usage to counteract this
            for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
            {
                // We start at col=pivotPos as we have assumed that all elements
                // to our left have been eliminated to zero already

                // We subtract based on the element directly above us in our pivot row
                subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
            }
        }
    };

    auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
    {
        auto row = curIdx[axesToAdjoint.first];
        auto col = curIdx[axesToAdjoint.second];

        float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));

        for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
        {
            for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
            {
                unsigned int outerRow = (subRow >= row) ? subRow + 1 : subRow;
                unsigned int outerCol = (subCol >= col) ? subCol + 1 : subCol;
                auto cloneIdx = curIdx;
                cloneIdx[axesToAdjoint.first] = outerRow;
                cloneIdx[axesToAdjoint.second] = outerCol;
                subMat[subRow][subCol] = GetValueAt(type, cloneIdx, inputDataClone);
            }
        }

        float determinant = 1.0f;

        // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
        switch(subMatAxisSize)
        {
            case 0:
            {
                determinant = GetValueAt(type, curIdx, inputDataClone);
                break;
            }
            case 1:
            {
                // If the resultant sub-matrix is just one element - that's the determinant
                determinant = subMat[0][0];
                break;
            }
            case 2:
            {
                // For a 2x2 sub-matrix, the determinant is just a*d - b*c
                determinant = subMat[0][0] * subMat[1][1] -
                              subMat[0][1] * subMat[1][0];
                break;
            }
            default:
            {
                // Gaussian elimination to find the determinant of this sub-matrix
                swapMultiplier = 1.0f;
                // March diagonally down the pivots and, if one is invalid (a zero), swap its row with the
                // nearest row below it that has a non-zero in that column
                for(unsigned int pivotRow = 0, pivotCol = 0;
                    pivotRow < subMatAxisSize;
                    pivotRow++, pivotCol++)
                {
                    float& pivot = subMat[pivotRow][pivotCol];

                    if(almostEquals(pivot, 0.0f))
                    {
                        unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
                        if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
                        {
                            // No valid pivot down this column, which means that this pivot remains a zero.
                            // The determinant of this entire sub-matrix is therefore zero.
                            determinant = 0.0f;
                            break;
                        }
                        swapRows(pivotRow, nextValidPivotRowIdx);
                    }
                    determinant *= pivot;
                    // The actual elimination bit (which will update/propagate to the pivots down the line)
                    eliminate(pivot, pivotRow); // Synonymous with pivotCol
                }

                determinant *= swapMultiplier;
                break;
            }
        }
        float cofactor = minorMultiplier * determinant;
        SetValueAt(cofactor, type, curIdx);
    };

    auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
    RecurseTensor(inputInfo,
                  cofactorOperation,
                  startIdx,
                  0);

    Transpose(type);
}
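
// Worked example of the cofactor-then-transpose steps above, for a single 2x2 slice [[a, b], [c, d]]:
// cofactorOperation writes the cofactor matrix [[d, -c], [-b, a]] in place, and the final Transpose(type)
// call turns it into the adjugate (adjoint) [[d, -b], [-c, a]].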

void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
                                const std::function<void(const std::vector<unsigned int>&)>& operation,
                                std::vector<unsigned int>& curIdx,
                                unsigned int curDim)
{
    if(!(curDim < tensorInfo.GetNumDimensions()))
    {
        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
        operation(curIdx);
        return;
    }

    for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
    {
        curIdx[curDim] = i;
        RecurseTensor(tensorInfo,
                      operation,
                      curIdx,
                      curDim + 1);
    }
}
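
// RecurseTensor visits every index of the given tensor in row-major (lexicographic) order; for a shape
// of [2, 2] the operation is applied at {0, 0}, {0, 1}, {1, 0} and then {1, 1}.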

void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
                                                 std::pair<unsigned int, unsigned int>& axesYToMul)
{
    int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
                   static_cast<int>(inputYInfo.GetNumDimensions());
    if(rankDiff == 0)
    {
        return;
    }
    else if(rankDiff < 0)
    {
        // Y is the larger one
        axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
    else if(rankDiff > 0)
    {
        // X is the larger one
        axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
        axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
    }
}
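
// Example of the rank adjustment above (assuming the default data layouts): if X has shape [M, K]
// (rank 2) and Y has shape [B0, B1, K, N] (rank 4), rankDiff is -2, so axesXToMul shifts from {0, 1}
// to {2, 3}. Both axis pairs then index into the larger, rank-4 coordinate space that the matmul lambda
// iterates over, with the missing leading dimensions of X handled later by broadcasting.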

float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
{
    // This reads the data from the input vectors that we hold, not the decoders.
    // For the output, however, it operates on the encoder itself.

    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    float value = 0.0f;
    switch(type)
    {
        case DataSlot::InputX:
            value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
            break;
        case DataSlot::InputY:
            value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx];
            value = outputEncoder.Get();
            break;
        default:
            break;
    }

    return value;
}

void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
    AdjustToSafeIdx(type, idx);
    unsigned int flatIdx = CalcFlatIdx(type, idx);
    switch(type)
    {
        case DataSlot::InputX:
            inputXData[flatIdx] = value;
            break;
        case DataSlot::InputY:
            inputYData[flatIdx] = value;
            break;
        case DataSlot::Output:
            outputEncoder[flatIdx];
            outputEncoder.Set(value);
            break;
        default:
            break;
    }
}

void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
{
    for(unsigned int dim = 0; dim < idx.size(); dim++)
    {
        switch(type)
        {
            case DataSlot::InputX:
            {
                auto xRank = inputXInfo.GetNumDimensions();
                auto xDiff = outputInfo.GetNumDimensions() - xRank;
                if (dim < xDiff ||
                    idx[dim] > inputXInfo.GetShape()[dim - xDiff] - 1)
                {
                    idx[dim] = 0; // Broadcasting
                }
                break;
            }
            case DataSlot::InputY:
            {
                auto yRank = inputYInfo.GetNumDimensions();
                auto yDiff = outputInfo.GetNumDimensions() - yRank;
                if (dim < yDiff ||
                    idx[dim] > inputYInfo.GetShape()[dim - yDiff] - 1)
                {
                    idx[dim] = 0;
                }
                break;
            }
            case DataSlot::Output:
            {
                // Our indices are based off the output
                break;
            }
            default:
                break;
        }
    }
}
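
// Broadcasting example for AdjustToSafeIdx above: with a rank-3 output and a rank-2 input X of shape
// [3, 4], xDiff is 1, so an output index {1, 2, 3} becomes {0, 2, 3} when reading X - the leading
// (batch) coordinate is clamped to 0, and trailing coordinates are only zeroed if they exceed the
// input's actual extent along that dimension.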

unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
    unsigned int result = idx[idx.size() - 1];
    unsigned int dimMultiplier = 1;
    unsigned int offset = 0;

    // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
    // Check offset in relation to i, to stop calculating flat index once all input shape fields considered
    for(unsigned int i = static_cast<unsigned int>(idx.size() - 2); static_cast<int>(i) >= 0 && (i + 1) > offset; i--)
    {
        switch(type)
        {
            case DataSlot::InputX:
                offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
                dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::InputY:
                offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
                dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
                break;
            case DataSlot::Output:
                dimMultiplier *= outputInfo.GetShape()[i + 1];
                break;
            default:
                break;
        }
        result += (idx[i] * dimMultiplier);
    }
    return result;
}
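
// Worked example for CalcFlatIdx above: for DataSlot::Output with shape [2, 3, 4] and idx {1, 2, 3},
// result starts at 3, then adds 2*4 and 1*(3*4), giving 3 + 8 + 12 = 23 - the usual row-major offset.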

} // namespace armnn
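
For reference, the arithmetic this file implements is NumPy-style batched matrix multiplication with
broadcasting over the leading (batch) dimensions. The following is a minimal standalone sketch of those
semantics, independent of Arm NN's Decoder/Encoder plumbing; the shapes, variable names and the naive
triple loop are illustrative only and not part of the Arm NN API.

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // X: [M, K] (rank 2), Y: [B, K, N] (rank 3); X is implicitly broadcast across the batch dimension B,
    // mirroring what AdjustAxesToMulForUnequalRanks and AdjustToSafeIdx achieve in the implementation above.
    const std::size_t B = 2, M = 2, K = 3, N = 2;
    std::vector<float> x(M * K);
    std::vector<float> y(B * K * N);
    std::vector<float> out(B * M * N, 0.0f);

    for (std::size_t i = 0; i < x.size(); ++i) { x[i] = static_cast<float>(i + 1); }
    for (std::size_t i = 0; i < y.size(); ++i) { y[i] = static_cast<float>(i + 1); }

    for (std::size_t b = 0; b < B; ++b)
    {
        for (std::size_t m = 0; m < M; ++m)
        {
            for (std::size_t n = 0; n < N; ++n)
            {
                float sum = 0.0f;
                for (std::size_t k = 0; k < K; ++k)
                {
                    // Row-major flat indexing, as in CalcFlatIdx
                    sum += x[m * K + k] * y[b * K * N + k * N + n];
                }
                out[b * M * N + m * N + n] = sum;
            }
        }
    }

    for (float v : out) { std::cout << v << ' '; }
    std::cout << '\n';
    return 0;
}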