23 inputXInfo(inputXInfo),
24 inputYInfo(inputYInfo),
25 outputInfo(outputInfo),
26 inputXDecoder(inputXDecoder),
27 inputYDecoder(inputYDecoder),
28 outputEncoder(outputEncoder)
30 inputXData = this->inputXDecoder.
DecodeTensor(inputXInfo.GetShape());
31 inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
39void BatchMatMul::ApplyBatchMatMul()
48 unsigned int inputYRowSize = inputYInfo.
GetShape()[axesYToMul.first];
50 AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
52 unsigned int inputXColDim = axesXToMul.second;
53 unsigned int inputYRowDim = axesYToMul.first;
55 auto batchMatMulOperation = [&](
const std::vector<unsigned int>& curIdx)
60 for (
unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
62 xIdx[inputXColDim] = inputYRowIdx;
65 yIdx[inputYRowDim] = inputYRowIdx;
67 sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
70 SetValueAt(sum, DataSlot::Output, curIdx);
73 auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
74 RecurseTensor(outputInfo,
80void BatchMatMul::ApplyParams()
82 if(params.m_TransposeX)
86 else if(params.m_AdjointX)
88 Adjoint(DataSlot::InputX);
90 if(params.m_TransposeY)
94 else if(params.m_AdjointY)
96 Adjoint(DataSlot::InputY);
100void BatchMatMul::Transpose(DataSlot type)
107 case DataSlot::InputX:
110 inputXInfo.GetShape());
112 std::vector<float> temp(inputXData.size());
121 case DataSlot::InputY:
124 inputYInfo.GetShape());
126 std::vector<float> temp(inputYData.size());
135 case DataSlot::Output:
141void BatchMatMul::Adjoint(DataSlot type)
147 TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
148 const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
152 std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
155 unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
156 std::vector<std::vector<float>> subMat(subMatAxisSize,
157 std::vector<float>(subMatAxisSize));
160 auto almostEquals = [&](
const float& a,
const float& b,
float unitsInLastPlace = 2.0f)
162 float diff = std::fabs(a-b);
163 float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
164 return (diff <= bound) || (diff < std::numeric_limits<float>::min());
167 float swapMultiplier = std::numeric_limits<float>::max();
168 auto swapRows = [&](
unsigned int rowIdxA,
unsigned int rowIdxB)
171 for(
unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
173 float tmp = subMat[rowIdxA][colIdx];
174 subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
175 subMat[rowIdxB][colIdx] = tmp;
177 swapMultiplier *= -1.0f;
180 auto findNextValidPivotRowIdx = [&](
unsigned int colIdx)
182 unsigned int result = std::numeric_limits<unsigned int>::max();
185 for(
unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
187 if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
196 auto eliminate = [&](
const float& pivot,
unsigned int pivotPos)
198 for(
unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
200 float multiplierNumerator = subMat[rowIdx][pivotPos];
201 if(almostEquals(multiplierNumerator, 0.0f))
205 float multiplier = multiplierNumerator / pivot;
207 for(
unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
213 subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
218 auto cofactorOperation = [&](
const std::vector<unsigned int>& curIdx)
220 auto row = curIdx[axesToAdjoint.first];
221 auto col = curIdx[axesToAdjoint.second];
223 float minorMultiplier =
static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
225 for(
unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
227 for(
unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
229 unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
230 unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
231 auto cloneIdx = curIdx;
232 cloneIdx[axesToAdjoint.first] = outerRow;
233 cloneIdx[axesToAdjoint.second] = outerCol;
234 subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
238 float determinant = 1.0f;
241 switch(subMatAxisSize)
245 determinant = GetValueAt(type, curIdx, inputDataClone);
251 determinant = subMat[0][0];
257 determinant = subMat[0][0] * subMat[1][1] -
258 subMat[0][1] * subMat[1][0];
264 swapMultiplier = 1.0f;
267 for(
unsigned int pivotRow = 0, pivotCol = 0;
268 pivotRow < subMatAxisSize;
269 pivotRow++, pivotCol++)
271 float& pivot = subMat[pivotRow][pivotCol];
273 if(almostEquals(pivot, 0.0f))
275 unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
276 if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
283 swapRows(pivotRow, nextValidPivotRowIdx);
285 determinant *= pivot;
287 eliminate(pivot, pivotRow);
290 determinant *= swapMultiplier;
294 float cofactor = minorMultiplier * determinant;
295 SetValueAt(cofactor, type, curIdx);
298 auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
299 RecurseTensor(inputInfo,
307void BatchMatMul::RecurseTensor(
const TensorInfo& tensorInfo,
308 const std::function<
void(
const std::vector<unsigned int>&)>& operation,
309 std::vector<unsigned int>& curIdx,
312 if(!(curDim < tensorInfo.GetNumDimensions()))
319 for(
unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
322 RecurseTensor(tensorInfo,
329void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
330 std::pair<unsigned int, unsigned int>& axesYToMul)
332 int rankDiff =
static_cast<int>(inputXInfo.GetNumDimensions()) -
333 static_cast<int>(inputYInfo.GetNumDimensions());
338 else if(rankDiff < 0)
341 axesXToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
342 axesXToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
344 else if(rankDiff > 0)
347 axesYToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
348 axesYToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
352float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx,
const std::vector<float>& customData)
357 AdjustToSafeIdx(type, idx);
358 unsigned int flatIdx = CalcFlatIdx(type, idx);
362 case DataSlot::InputX:
363 value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
365 case DataSlot::InputY:
366 value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
368 case DataSlot::Output:
369 outputEncoder[flatIdx];
370 value = outputEncoder.Get();
379void BatchMatMul::SetValueAt(
float value, DataSlot type, std::vector<unsigned int> idx)
381 AdjustToSafeIdx(type, idx);
382 unsigned int flatIdx = CalcFlatIdx(type, idx);
385 case DataSlot::InputX:
386 inputXData[flatIdx] = value;
388 case DataSlot::InputY:
389 inputYData[flatIdx] = value;
391 case DataSlot::Output:
392 outputEncoder[flatIdx];
393 outputEncoder.Set(value);
400void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
402 for(
unsigned int dim = 0; dim < idx.size(); dim++)
406 case DataSlot::InputX:
408 auto xRank = inputXInfo.GetNumDimensions();
409 auto xDiff = outputInfo.GetNumDimensions() - xRank;
411 idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
417 case DataSlot::InputY:
419 auto yRank = inputYInfo.GetNumDimensions();
420 auto yDiff = outputInfo.GetNumDimensions() - yRank;
422 idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
428 case DataSlot::Output:
439unsigned int BatchMatMul::CalcFlatIdx(DataSlot type,
const std::vector<unsigned int>& idx)
441 unsigned int result = idx[idx.size()-1];
442 unsigned int dimMultiplier = 1;
443 unsigned int offset = 0;
447 for(
unsigned int i =
static_cast<unsigned int>(idx.size()-2);
static_cast<int>(i) >= 0 && (i + 1) > offset; i--)
451 case DataSlot::InputX:
452 offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
453 dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
455 case DataSlot::InputY:
456 offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
457 dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
459 case DataSlot::Output:
460 dimMultiplier *= outputInfo.GetShape()[i+1];
465 result += (idx[i] * dimMultiplier);
BatchMatMul(const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder)
virtual std::vector< float > DecodeTensor(const TensorShape &tensorShape, bool isDepthwise=false)=0
const TensorShape & GetShape() const
Copyright (c) 2021 ARM Limited and Contributors.
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
A BatchMatMulDescriptor for the BatchMatMul operator.
static std::pair< unsigned int, unsigned int > GetAxesToMul(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the two axes (for each input) for multiplication.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)