23 inputXInfo(inputXInfo),
24 inputYInfo(inputYInfo),
25 outputInfo(outputInfo),
26 inputXDecoder(inputXDecoder),
27 inputYDecoder(inputYDecoder),
28 outputEncoder(outputEncoder)
// Batched matrix multiply: for every index of the output tensor (visited via RecurseTensor
// over outputInfo), sums InputX * InputY along the contraction axis and stores the sum in Output.
39 void BatchMatMul::ApplyBatchMatMul()
// Shift the axis pairs when the two inputs have different ranks (see AdjustAxesToMulForUnequalRanks).
45 AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
// Contraction axes: X's column axis is stepped in lockstep with Y's row axis.
47 unsigned int inputXColDim = axesXToMul.second;
48 unsigned int inputYRowDim = axesYToMul.first;
// The contraction length is read from Y's shape only.
// NOTE(review): assumes X's column size equals Y's row size; not checked here — confirm it is validated upstream.
50 unsigned int inputYRowSize = inputYInfo.
GetShape()[inputYRowDim];
// Computes a single output element at curIdx.
52 auto batchMatMulOperation = [&](
const std::vector<unsigned int>& curIdx)
// xIdx / yIdx / sum are set up in lines not visible here (presumably xIdx and yIdx start as
// copies of curIdx and sum as 0); only the contraction coordinate of each is overwritten below.
57 for (
unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
59 xIdx[inputXColDim] = inputYRowIdx;
62 yIdx[inputYRowDim] = inputYRowIdx;
64 sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
67 SetValueAt(sum, DataSlot::Output, curIdx);
// Visit every index of the output tensor (remaining arguments are not visible here).
71 RecurseTensor(outputInfo,
// Applies the pre-multiplication transforms to the inputs: Transpose and/or Adjoint of
// InputX, then of InputY. Each call is presumably guarded by a condition in lines not
// visible here (likely flags on the layer descriptor — TODO confirm).
77 void BatchMatMul::ApplyParams()
81 Transpose(DataSlot::InputX);
85 Adjoint(DataSlot::InputX);
89 Transpose(DataSlot::InputY);
93 Adjoint(DataSlot::InputY);
// Transposes (per its name) the data of one input slot. Each input case allocates a
// scratch buffer the same size as that input's data; the permutation itself is in lines
// not visible here.
// NOTE(review): the Output case's handling is not visible — confirm it is a no-op or an error.
97 void BatchMatMul::Transpose(DataSlot type)
104 case DataSlot::InputX:
109 std::vector<float> temp(inputXData.size());
118 case DataSlot::InputY:
123 std::vector<float> temp(inputYData.size());
132 case DataSlot::Output:
// Replaces each element of the selected input (InputX or InputY) in place with its cofactor
// over the two adjoint axes: (-1)^(row+col) times the determinant of the minor formed by
// deleting that element's row and column. Minor determinants are computed by Gaussian
// elimination with row swaps. Whether a transpose follows to form the adjugate is not
// visible in this excerpt.
138 void BatchMatMul::Adjoint(DataSlot type)
144 TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
// Cofactors are only defined for square matrices.
148 ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
// Snapshot of the original data. Cofactors are written back into the live slot while other
// minors still need the unmodified values, so all minor reads go through this copy.
150 std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
// An n x n matrix has (n-1) x (n-1) minors; one buffer is reused for every element.
153 unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
154 std::vector<std::vector<float>> subMat(subMatAxisSize,
155 std::vector<float>(subMatAxisSize));
// Approximate float equality; used below only for zero tests.
// NOTE(review): the bound scales with diff itself, so for any finite non-zero diff
// `diff <= bound` is false and this reduces to `diff < FLT_MIN` (zero or subnormal
// difference only). The usual form scales the tolerance by std::fabs(a + b) — confirm intent.
158 auto almostEquals = [&](
const float& a,
const float& b,
float unitsInLastPlace = 2.0f)
160 float diff = std::fabs(a-b);
161 float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
162 return (diff <= bound) || (diff < std::numeric_limits<float>::min());
// Sign correction for the determinant, negated on every row swap. This initial value is a
// placeholder: it is reset to 1.0f before each elimination pass.
165 float swapMultiplier = std::numeric_limits<float>::max();
// Swaps two rows of the minor and records the sign flip in swapMultiplier.
166 auto swapRows = [&](
unsigned int rowIdxA,
unsigned int rowIdxB)
169 for(
unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
171 float tmp = subMat[rowIdxA][colIdx];
172 subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
173 subMat[rowIdxB][colIdx] = tmp;
175 swapMultiplier *= -1.0f;
// Searches the rows below colIdx for a non-zero entry in column colIdx. The result starts
// as UINT_MAX ("none found"); the assignment/return path is in lines not visible here.
178 auto findNextValidPivotRowIdx = [&](
unsigned int colIdx)
180 unsigned int result = std::numeric_limits<unsigned int>::max();
183 for(
unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
185 if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
// Forward-elimination step: subtracts a multiple of the pivot row from every row below it,
// zeroing column pivotPos. Rows whose entry is already ~0 are presumably skipped
// (the branch body is not visible here).
194 auto eliminate = [&](
const float& pivot,
unsigned int pivotPos)
196 for(
unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
198 float multiplierNumerator = subMat[rowIdx][pivotPos];
199 if(almostEquals(multiplierNumerator, 0.0f))
203 float multiplier = multiplierNumerator / pivot;
205 for(
unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
211 subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
// Computes and stores the cofactor of the element at curIdx.
216 auto cofactorOperation = [&](
const std::vector<unsigned int>& curIdx)
218 auto row = curIdx[axesToAdjoint.first];
219 auto col = curIdx[axesToAdjoint.second];
// Checkerboard sign (-1)^(row+col). The two +1s convert to 1-based indices and cancel out.
// NOTE(review): std::pow is a heavyweight way to get a sign; ((row + col) % 2 ? -1.0f : 1.0f) is exact and cheaper.
221 float minorMultiplier =
static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
// Build the minor: copy every element except those in the current row and column, shifting
// indices at or past the deleted row/column by one. Reads come from the snapshot.
223 for(
unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
225 for(
unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
227 unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
228 unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
229 auto cloneIdx = curIdx;
230 cloneIdx[axesToAdjoint.first] = outerRow;
231 cloneIdx[axesToAdjoint.second] = outerCol;
232 subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
236 float determinant = 1.0f;
// Small minors are handled directly; larger ones use Gaussian elimination.
// (Case labels are not visible here; they are inferred from what each branch reads.)
239 switch(subMatAxisSize)
// Presumably the size-0 case (1x1 input): the element's own value is used as the determinant.
// NOTE(review): the cofactor of a 1x1 matrix is conventionally 1 — confirm this is intended.
243 determinant = GetValueAt(type, curIdx, inputDataClone);
249 determinant = subMat[0][0];
255 determinant = subMat[0][0] * subMat[1][1] -
256 subMat[0][1] * subMat[1][0];
// General case: reduce the minor to upper-triangular form. The determinant is the product of
// the pivots, with the sign corrected for row swaps.
262 swapMultiplier = 1.0f;
265 for(
unsigned int pivotRow = 0, pivotCol = 0;
266 pivotRow < subMatAxisSize;
267 pivotRow++, pivotCol++)
// `pivot` is a reference into subMat, so after swapRows it sees the swapped-in value.
269 float& pivot = subMat[pivotRow][pivotCol];
271 if(almostEquals(pivot, 0.0f))
273 unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
// No usable pivot below means the minor is singular; the handling is in lines not visible
// here (presumably the determinant becomes 0).
274 if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
281 swapRows(pivotRow, nextValidPivotRowIdx);
283 determinant *= pivot;
285 eliminate(pivot, pivotRow);
288 determinant *= swapMultiplier;
292 float cofactor = minorMultiplier * determinant;
// Written into the live input slot; all minor reads above use inputDataClone.
293 SetValueAt(cofactor, type, curIdx);
// Visit every element of the input tensor starting from the all-zero index
// (remaining arguments are not visible here).
296 auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
297 RecurseTensor(inputInfo,
// Depth-first walk over every index of tensorInfo's shape. Each recursion level iterates
// dimension curDim and recurses on the next one. Once curDim reaches the rank, curIdx names a
// complete element; operation(curIdx) is presumably invoked there (lines not visible here).
// NOTE(review): operation is a std::function called once per element; a template parameter
// would let the compiler inline it on this hot path.
305 void BatchMatMul::RecurseTensor(
const TensorInfo& tensorInfo,
306 const std::function<
void(
const std::vector<unsigned int>&)>& operation,
307 std::vector<unsigned int>& curIdx,
310 if(!(curDim < tensorInfo.GetNumDimensions()))
317 for(
unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
320 RecurseTensor(tensorInfo,
// When the inputs differ in rank, shifts one input's axis pair by the absolute rank difference
// so that both pairs index into the same (presumably the higher) rank. rankDiff is computed in
// lines not visible here. From the branches, a negative rankDiff shifts X's axes and a
// positive one shifts Y's.
// NOTE(review): std::make_unsigned<unsigned int>::type is simply unsigned int.
327 void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
328 std::pair<unsigned int, unsigned int>& axesYToMul)
336 else if(rankDiff < 0)
339 axesXToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
340 axesXToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
342 else if(rankDiff > 0)
345 axesYToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
346 axesYToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
// Reads the element at idx from the given slot. idx is taken by value because AdjustToSafeIdx
// rewrites it in place before it is flattened, and the caller's index must stay intact.
// For InputX/InputY, a non-empty customData buffer is read instead of the member data
// (Adjoint passes its snapshot this way). For Output, the value is read through outputEncoder.
350 float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx,
const std::vector<float>& customData)
355 AdjustToSafeIdx(type, idx);
356 unsigned int flatIdx = CalcFlatIdx(type, idx);
360 case DataSlot::InputX:
361 value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
363 case DataSlot::InputY:
364 value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
366 case DataSlot::Output:
// operator[]'s result is discarded; it is presumably called to position the encoder at flatIdx before Get().
367 outputEncoder[flatIdx];
368 value = outputEncoder.
Get();
// Writes value at idx into the given slot: directly into the member buffer for
// InputX/InputY, or through outputEncoder for Output. idx is taken by value because
// AdjustToSafeIdx rewrites it in place before it is flattened.
377 void BatchMatMul::SetValueAt(
float value, DataSlot type, std::vector<unsigned int> idx)
379 AdjustToSafeIdx(type, idx);
380 unsigned int flatIdx = CalcFlatIdx(type, idx);
383 case DataSlot::InputX:
384 inputXData[flatIdx] = value;
386 case DataSlot::InputY:
387 inputYData[flatIdx] = value;
389 case DataSlot::Output:
// operator[]'s result is discarded; it is presumably called to position the encoder at flatIdx before Set().
390 outputEncoder[flatIdx];
391 outputEncoder.
Set(value);
// Rewrites idx in place so it can address the given slot. For InputX/InputY, each coordinate
// is compared against that input's size along dimension dim - xDiff (or dim - yDiff).
// xDiff/yDiff are computed in lines not visible here; they are presumably the rank offset
// between the index and that input. Out-of-range coordinates are adjusted in lines not visible
// here, presumably to support broadcasting.
// NOTE(review): dim - xDiff / dim - yDiff are unsigned and wrap when dim < offset; confirm the
// hidden part of the condition guards that case.
398 void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
400 for(
unsigned int dim = 0; dim < idx.size(); dim++)
404 case DataSlot::InputX:
409 idx[dim] > inputXInfo.
GetShape()[dim-xDiff]-1)
415 case DataSlot::InputY:
420 idx[dim] > inputYInfo.
GetShape()[dim-yDiff]-1)
426 case DataSlot::Output:
437 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type,
const std::vector<unsigned int>& idx)
439 unsigned int result = idx[idx.size()-1];
440 unsigned int dimMultiplier = 1;
444 for(
unsigned int i =
static_cast<unsigned int>(idx.size()-2);
static_cast<int>(i) >= 0; i--)
448 case DataSlot::InputX:
450 dimMultiplier *= inputXInfo.
GetShape()[i + 1 - offset];
452 case DataSlot::InputY:
454 dimMultiplier *= inputYInfo.
GetShape()[i + 1 - offset];
456 case DataSlot::Output:
457 dimMultiplier *= outputInfo.
GetShape()[i+1];
462 result += (idx[i] * dimMultiplier);