ArmNN
 25.11
Loading...
Searching...
No Matches
BatchMatMulImpl.cpp
Go to the documentation of this file.
1//
2// Copyright © 2022, 2024 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "BatchMatMulImpl.hpp"
7
9#include <armnn/Logging.hpp>
11
12namespace armnn
13{
14
16 const TensorInfo& inputXInfo,
17 const TensorInfo& inputYInfo,
18 const TensorInfo& outputInfo,
19 Decoder<float>& inputXDecoder,
20 Decoder<float>& inputYDecoder,
21 Encoder<float>& outputEncoder)
22 : params(params),
23 inputXInfo(inputXInfo),
24 inputYInfo(inputYInfo),
25 outputInfo(outputInfo),
26 inputXDecoder(inputXDecoder),
27 inputYDecoder(inputYDecoder),
28 outputEncoder(outputEncoder)
29{
30 inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
31 inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
32 // At this point, we don't touch the input decoders - just the resultant vectors
33
34 ApplyParams();
35
36 ApplyBatchMatMul();
37}
38
39void BatchMatMul::ApplyBatchMatMul()
40{
42 inputXInfo.GetShape());
44 inputYInfo.GetShape());
45
46 // the inputYRowSize (or inputXColSize) needs to be obtained using the original (unadjusted) axis value,
47 // because it's obtained from the original tensor shape
48 unsigned int inputYRowSize = inputYInfo.GetShape()[axesYToMul.first];
49
50 AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
51
52 unsigned int inputXColDim = axesXToMul.second;
53 unsigned int inputYRowDim = axesYToMul.first;
54
55 auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
56 {
57 float sum = 0.0f;
58
59 // InputYRowSize is synonymous with inputXColSize
60 for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
61 auto xIdx = curIdx;
62 xIdx[inputXColDim] = inputYRowIdx;
63
64 auto yIdx = curIdx;
65 yIdx[inputYRowDim] = inputYRowIdx;
66
67 sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
68 }
69
70 SetValueAt(sum, DataSlot::Output, curIdx);
71 };
72
73 auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
74 RecurseTensor(outputInfo,
75 batchMatMulOperation,
76 startIdx,
77 0);
78}
79
80void BatchMatMul::ApplyParams()
81{
82 if(params.m_TransposeX)
83 {
84 Transpose(DataSlot::InputX);
85 }
86 else if(params.m_AdjointX)
87 {
88 Adjoint(DataSlot::InputX);
89 }
90 if(params.m_TransposeY)
91 {
92 Transpose(DataSlot::InputY);
93 }
94 else if(params.m_AdjointY)
95 {
96 Adjoint(DataSlot::InputY);
97 }
98}
99
100void BatchMatMul::Transpose(DataSlot type)
101{
102 // AKA the permute of the tensor
103 // This modifies the tensor's info.
104
105 switch(type)
106 {
107 case DataSlot::InputX:
108 {
109 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
110 inputXInfo.GetShape());
111 inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
112 std::vector<float> temp(inputXData.size());
113 armnnUtils::Permute(inputXInfo.GetShape(),
114 permuteVec,
115 inputXData.data(),
116 temp.data(),
117 sizeof(float));
118 inputXData = temp;
119 break;
120 }
121 case DataSlot::InputY:
122 {
123 auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
124 inputYInfo.GetShape());
125 inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
126 std::vector<float> temp(inputYData.size());
127 armnnUtils::Permute(inputYInfo.GetShape(),
128 permuteVec,
129 inputYData.data(),
130 temp.data(),
131 sizeof(float));
132 inputYData = temp;
133 break;
134 }
135 case DataSlot::Output: // We needn't transpose the output tensor
136 default:
137 break;
138 }
139}
140
141void BatchMatMul::Adjoint(DataSlot type)
142{
143 // Finding the adjoint of a square matrix:
144 // Calculate the cofactor of each element (using Gauss elimination here)
145 // Apply a transpose to it (this also modifies the tensor's info)
146
147 TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
148 const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
149 const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout,inputInfo.GetShape());
150
151 // We grab a copy of the tensor data to prevent overwriting
152 std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
153
154 // The sub-matrix is the resultant matrix when the row and column of the current index is removed
155 unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
156 std::vector<std::vector<float>> subMat(subMatAxisSize,
157 std::vector<float>(subMatAxisSize));
158
159 // Lambdas for each sub-step of the cofactor operation
160 auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
161 {
162 float diff = std::fabs(a-b);
163 float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
164 return (diff <= bound) || (diff < std::numeric_limits<float>::min());
165 };
166
167 float swapMultiplier = std::numeric_limits<float>::max();
168 auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
169 {
170 // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
171 for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
172 {
173 float tmp = subMat[rowIdxA][colIdx];
174 subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
175 subMat[rowIdxB][colIdx] = tmp;
176 }
177 swapMultiplier *= -1.0f;
178 };
179
180 auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
181 {
182 unsigned int result = std::numeric_limits<unsigned int>::max();
183
184 // The original diagonal has been checked and is invalid
185 for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
186 {
187 if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
188 {
189 result = rowIdx;
190 break;
191 }
192 }
193 return result;
194 };
195
196 auto eliminate = [&](const float& pivot, unsigned int pivotPos)
197 {
198 for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
199 {
200 float multiplierNumerator = subMat[rowIdx][pivotPos];
201 if(almostEquals(multiplierNumerator, 0.0f))
202 {
203 continue;
204 }
205 float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
206 // Hence the almostEquals usage to counteract this
207 for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
208 {
209 // We start at col=pivotPos as we have assumed that all elements
210 // to our left have been eliminated to zero already
211
212 // We subtract based on the element directly above us in our pivot row
213 subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
214 }
215 }
216 };
217
218 auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
219 {
220 auto row = curIdx[axesToAdjoint.first];
221 auto col = curIdx[axesToAdjoint.second];
222
223 float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
224
225 for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
226 {
227 for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
228 {
229 unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
230 unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
231 auto cloneIdx = curIdx;
232 cloneIdx[axesToAdjoint.first] = outerRow;
233 cloneIdx[axesToAdjoint.second] = outerCol;
234 subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
235 }
236 }
237
238 float determinant = 1.0f;
239
240 // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
241 switch(subMatAxisSize)
242 {
243 case 0:
244 {
245 determinant = GetValueAt(type, curIdx, inputDataClone);
246 break;
247 }
248 case 1:
249 {
250 // If the resultant sub-matrix is just one element - that's the determinant
251 determinant = subMat[0][0];
252 break;
253 }
254 case 2:
255 {
256 // For a 2x2 sub-matrix, the determinant is just a*d-b*c
257 determinant = subMat[0][0] * subMat[1][1] -
258 subMat[0][1] * subMat[1][0];
259 break;
260 }
261 default:
262 {
263 // Gaussian elimination to find the determinant of this sub-matrix
264 swapMultiplier = 1.0f;
265 // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
266 // nearest non-zero down within the column
267 for(unsigned int pivotRow = 0, pivotCol = 0;
268 pivotRow < subMatAxisSize;
269 pivotRow++, pivotCol++)
270 {
271 float& pivot = subMat[pivotRow][pivotCol];
272
273 if(almostEquals(pivot, 0.0f))
274 {
275 unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
276 if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
277 {
278 // No valid pivot down this column, which means that this pivot remains a zero.
279 // This results in the determinant for this entire sub-matrix to just be zero.
280 determinant = 0.0f;
281 break;
282 }
283 swapRows(pivotRow, nextValidPivotRowIdx);
284 }
285 determinant *= pivot;
286 // The actual elimination bit (which will update/propagate to the pivots down the line)
287 eliminate(pivot, pivotRow); // Synonymous with pivotCol
288 }
289
290 determinant *= swapMultiplier;
291 break;
292 }
293 }
294 float cofactor = minorMultiplier * determinant;
295 SetValueAt(cofactor, type, curIdx);
296 };
297
298 auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
299 RecurseTensor(inputInfo,
300 cofactorOperation,
301 startIdx,
302 0);
303
304 Transpose(type);
305}
306
307void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
308 const std::function<void(const std::vector<unsigned int>&)>& operation,
309 std::vector<unsigned int>& curIdx,
310 unsigned int curDim)
311{
312 if(!(curDim < tensorInfo.GetNumDimensions()))
313 {
314 // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
315 operation(curIdx);
316 return;
317 }
318
319 for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
320 {
321 curIdx[curDim] = i;
322 RecurseTensor(tensorInfo,
323 operation,
324 curIdx,
325 curDim + 1);
326 }
327}
328
329void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
330 std::pair<unsigned int, unsigned int>& axesYToMul)
331{
332 int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
333 static_cast<int>(inputYInfo.GetNumDimensions());
334 if(rankDiff == 0)
335 {
336 return;
337 }
338 else if(rankDiff < 0)
339 {
340 // Y is the larger one
341 axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
342 axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
343 }
344 else if(rankDiff > 0)
345 {
346 // X is the larger one
347 axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
348 axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
349 }
350}
351
352float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
353{
354 // This gets the data from the input vector that we have, Not the decoder
355 // But for the output, it is operating on the encoder itself
356
357 AdjustToSafeIdx(type, idx);
358 unsigned int flatIdx = CalcFlatIdx(type, idx);
359 float value = 0.0f;
360 switch(type)
361 {
362 case DataSlot::InputX:
363 value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
364 break;
365 case DataSlot::InputY:
366 value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
367 break;
368 case DataSlot::Output:
369 outputEncoder[flatIdx];
370 value = outputEncoder.Get();
371 break;
372 default:
373 break;
374 }
375
376 return value;
377}
378
379void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
380{
381 AdjustToSafeIdx(type, idx);
382 unsigned int flatIdx = CalcFlatIdx(type, idx);
383 switch(type)
384 {
385 case DataSlot::InputX:
386 inputXData[flatIdx] = value;
387 break;
388 case DataSlot::InputY:
389 inputYData[flatIdx] = value;
390 break;
391 case DataSlot::Output:
392 outputEncoder[flatIdx];
393 outputEncoder.Set(value);
394 break;
395 default:
396 break;
397 }
398}
399
400void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
401{
402 for(unsigned int dim = 0; dim < idx.size(); dim++)
403 {
404 switch(type)
405 {
406 case DataSlot::InputX:
407 {
408 auto xRank = inputXInfo.GetNumDimensions();
409 auto xDiff = outputInfo.GetNumDimensions() - xRank;
410 if (dim < xDiff ||
411 idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
412 {
413 idx[dim] = 0; // Broadcasting
414 }
415 break;
416 }
417 case DataSlot::InputY:
418 {
419 auto yRank = inputYInfo.GetNumDimensions();
420 auto yDiff = outputInfo.GetNumDimensions() - yRank;
421 if (dim < yDiff ||
422 idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
423 {
424 idx[dim] = 0;
425 }
426 break;
427 }
428 case DataSlot::Output:
429 {
430 // Our indices are based off the output
431 break;
432 }
433 default:
434 break;
435 }
436 }
437}
438
439unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
440{
441 unsigned int result = idx[idx.size()-1];
442 unsigned int dimMultiplier = 1;
443 unsigned int offset = 0;
444
445 // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
446 // Check offset in relation to i, to stop calculating flat index once all input shape fields considered
447 for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0 && (i + 1) > offset; i--)
448 {
449 switch(type)
450 {
451 case DataSlot::InputX:
452 offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
453 dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
454 break;
455 case DataSlot::InputY:
456 offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
457 dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
458 break;
459 case DataSlot::Output:
460 dimMultiplier *= outputInfo.GetShape()[i+1];
461 break;
462 default:
463 break;
464 }
465 result += (idx[i] * dimMultiplier);
466 }
467 return result;
468}
469
470} // namespace armnn
BatchMatMul(const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder)
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
const TensorShape & GetShape() const
Definition Tensor.hpp:193
Copyright (c) 2021 ARM Limited and Contributors.
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition Permute.cpp:125
A BatchMatMulDescriptor for the BatchMatMul operator.
static std::pair< unsigned int, unsigned int > GetAxesToMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the two axes (for each input) for multiplication.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)