23 inputXInfo(inputXInfo),
24 inputYInfo(inputYInfo),
25 outputInfo(outputInfo),
26 inputXDecoder(inputXDecoder),
27 inputYDecoder(inputYDecoder),
28 outputEncoder(outputEncoder)
// Computes the batched product of InputX and InputY into Output.
// For every index of the output tensor, it accumulates the products of X and Y
// elements along the contracted axis (X's column axis, Y's row axis).
// The sum is written to Output at that index.
39 void BatchMatMul::ApplyBatchMatMul()
// Length of the contracted axis, read from Y's own shape.
// This read happens before AdjustAxesToMulForUnequalRanks below, which may
// offset axesYToMul when the input ranks differ.
48 unsigned int inputYRowSize = inputYInfo.
GetShape()[axesYToMul.first];
// When X and Y have different ranks, shift the lower-rank input's
// multiplication axes by the rank difference.
50 AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
// Positions of the contracted axis within the X and Y index vectors.
52 unsigned int inputXColDim = axesXToMul.second;
53 unsigned int inputYRowDim = axesYToMul.first;
// Per-output-element operation. xIdx, yIdx and sum are initialised in lines
// not visible in this excerpt (presumably from curIdx and 0 respectively).
55 auto batchMatMulOperation = [&](
const std::vector<unsigned int>& curIdx)
60 for (
unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
// Step the same position k along X's column axis and Y's row axis.
62 xIdx[inputXColDim] = inputYRowIdx;
65 yIdx[inputYRowDim] = inputYRowIdx;
67 sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
70 SetValueAt(sum, DataSlot::Output, curIdx);
// Apply the operation to every index of the output tensor.
74 RecurseTensor(outputInfo,
// Applies the optional input transformations before multiplication.
// For each input, Transpose is called first and then Adjoint: InputX, then InputY.
// The conditions guarding each call are not visible in this excerpt; they are
// presumably the layer's transpose/adjoint parameters. TODO confirm.
80 void BatchMatMul::ApplyParams()
84 Transpose(DataSlot::InputX);
88 Adjoint(DataSlot::InputX);
92 Transpose(DataSlot::InputY);
96 Adjoint(DataSlot::InputY);
// Transpose step for InputX or InputY.
// It works through a temporary buffer sized like the selected input's data.
// The element shuffling, any shape update, and the write-back are not visible
// in this excerpt. An Output case exists in the switch, but its handling is
// not visible here.
100 void BatchMatMul::Transpose(DataSlot type)
107 case DataSlot::InputX:
112 std::vector<float> temp(inputXData.size());
121 case DataSlot::InputY:
126 std::vector<float> temp(inputYData.size());
135 case DataSlot::Output:
// Overwrites the selected input's data (InputX or InputY) with cofactor values.
// For each element at (row, col) on the two axesToAdjoint axes, it writes
// (-1)^(row+col) * det(minor). The minor is the matrix with that row and that
// column removed.
// All reads come from an unmodified clone of the input, so values written
// during the traversal never feed into later minors.
// No transposition of the cofactor matrix happens in this function. If the
// classical adjugate is intended, the transpose is presumably applied elsewhere.
// TODO confirm.
// NOTE(review): some frameworks' "adjoint" flags (e.g. TensorFlow's adj_x/adj_y)
// mean the conjugate transpose rather than the classical adjugate computed
// here. Confirm which semantics this operator is meant to match.
// NOTE(review): only the extent of axesToAdjoint.first sizes the minor, so the
// two adjoint axes are assumed to have equal length (square matrices).
// TODO confirm this is validated upstream.
// NOTE(review): every element runs its own O(n^3) elimination, so one n x n
// matrix costs O(n^5).
141 void BatchMatMul::Adjoint(DataSlot type)
147 TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
// Snapshot of the original values. SetValueAt below writes into the live
// input buffer, and all minors are built from this copy.
152 std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
// Side length of each (n-1) x (n-1) minor, plus the scratch matrix that holds it.
// NOTE(review): if this axis has extent 0, the unsigned subtraction wraps to
// UINT_MAX and the subMat allocation below becomes enormous. Confirm that empty
// axes are rejected earlier.
155 unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
156 std::vector<std::vector<float>> subMat(subMatAxisSize,
157 std::vector<float>(subMatAxisSize));
// Near-zero test used for pivot selection and elimination. Every visible call
// passes b == 0.0f.
// NOTE(review): `bound` is scaled by `diff` itself, so `diff <= bound` can only
// hold when diff == 0, or when diff is infinite (inf <= inf). In effect the test
// is "exactly equal, or |a - b| below std::numeric_limits<float>::min()"; it is
// not a ULP tolerance. A relative check would normally scale the bound by
// (|a| + |b|). Confirm the intent before changing it, because it affects pivot
// choice.
160 auto almostEquals = [&](
const float& a,
const float& b,
float unitsInLastPlace = 2.0f)
162 float diff = std::fabs(a-b);
163 float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
164 return (diff <= bound) || (diff < std::numeric_limits<float>::min());
// Sign of the row permutation applied during elimination; swapRows flips it.
// The elimination path resets it to 1.0f before use (see below).
167 float swapMultiplier = std::numeric_limits<float>::max();
// Swaps rows rowIdxA and rowIdxB of subMat and flips the determinant's sign.
168 auto swapRows = [&](
unsigned int rowIdxA,
unsigned int rowIdxB)
171 for(
unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
173 float tmp = subMat[rowIdxA][colIdx];
174 subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
175 subMat[rowIdxB][colIdx] = tmp;
177 swapMultiplier *= -1.0f;
// Searches the rows below row colIdx for one whose entry in column colIdx is
// not (near-)zero.
// `result` defaults to UINT_MAX, which the caller treats as "no usable pivot".
// The assignment and return for a found row are not visible in this excerpt.
180 auto findNextValidPivotRowIdx = [&](
unsigned int colIdx)
182 unsigned int result = std::numeric_limits<unsigned int>::max();
185 for(
unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
187 if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
// Forward-elimination step. For each row below pivotPos, it subtracts
// (entry / pivot) times the pivot row, which zeroes column pivotPos.
// Rows whose entry is already (near-)zero are presumably skipped; that branch
// body is not visible.
196 auto eliminate = [&](
const float& pivot,
unsigned int pivotPos)
198 for(
unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
200 float multiplierNumerator = subMat[rowIdx][pivotPos];
201 if(almostEquals(multiplierNumerator, 0.0f))
205 float multiplier = multiplierNumerator / pivot;
207 for(
unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
213 subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
// Computes the cofactor for the element at curIdx and writes it back.
218 auto cofactorOperation = [&](
const std::vector<unsigned int>& curIdx)
220 auto row = curIdx[axesToAdjoint.first];
221 auto col = curIdx[axesToAdjoint.second];
// Checkerboard sign (-1)^(row+col). Adding 1 to both indices leaves the
// parity unchanged.
// NOTE(review): `((row + col) % 2 == 0) ? 1.0f : -1.0f` gives the same sign
// without a floating-point std::pow call.
223 float minorMultiplier =
static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
// Gather the minor from the clone. Indices at or past this element's row or
// column shift up by one, which skips that row and that column.
225 for(
unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
227 for(
unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
229 unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
230 unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
231 auto cloneIdx = curIdx;
232 cloneIdx[axesToAdjoint.first] = outerRow;
233 cloneIdx[axesToAdjoint.second] = outerCol;
234 subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
// Determinant of the minor: closed forms for small sizes, Gaussian elimination
// otherwise. The case labels are not visible in this excerpt.
238 float determinant = 1.0f;
241 switch(subMatAxisSize)
// Presumably the 1x1-input case (empty minor), which uses the element itself.
// NOTE(review): the determinant of an empty minor is conventionally 1.
// Confirm the intended result.
245 determinant = GetValueAt(type, curIdx, inputDataClone);
// 1x1 minor: its single entry.
251 determinant = subMat[0][0];
// 2x2 minor: ad - bc.
257 determinant = subMat[0][0] * subMat[1][1] -
258 subMat[0][1] * subMat[1][0];
// General case: reduce the minor to upper-triangular form.
// det = (product of the pivots) * (row-swap sign).
264 swapMultiplier = 1.0f;
267 for(
unsigned int pivotRow = 0, pivotCol = 0;
268 pivotRow < subMatAxisSize;
269 pivotRow++, pivotCol++)
// A reference into subMat, so after swapRows below it reads the newly
// swapped-in value.
271 float& pivot = subMat[pivotRow][pivotCol];
273 if(almostEquals(pivot, 0.0f))
275 unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
// No usable pivot in this column means the minor is singular. The handling is
// not visible here; presumably determinant becomes 0.
276 if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
283 swapRows(pivotRow, nextValidPivotRowIdx);
285 determinant *= pivot;
287 eliminate(pivot, pivotRow);
// Account for the sign changes introduced by row swaps.
290 determinant *= swapMultiplier;
// Write the signed cofactor into the live input buffer, not the clone.
294 float cofactor = minorMultiplier * determinant;
295 SetValueAt(cofactor, type, curIdx);
// Visit every element of the input tensor, starting from the all-zero index.
298 auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
299 RecurseTensor(inputInfo,
// Depth-first traversal over every index of `tensorInfo`'s shape.
// It iterates dimension curDim over its extent and recurses into the next
// dimension. The base case fires once curDim reaches the tensor's rank; it
// presumably invokes `operation(curIdx)`, but its body is not visible here.
// curIdx is a non-const reference. It is presumably updated in place
// (curIdx[curDim] = i) before each recursive call. TODO confirm.
// NOTE(review): recursion depth equals the tensor rank. `operation` is a
// std::function, so every visited element pays an indirect call.
307 void BatchMatMul::RecurseTensor(
const TensorInfo& tensorInfo,
308 const std::function<
void(
const std::vector<unsigned int>&)>& operation,
309 std::vector<unsigned int>& curIdx,
// Every dimension has a fixed coordinate, so this is the base case.
312 if(!(curDim < tensorInfo.GetNumDimensions()))
// Iterate this dimension's coordinate and recurse into the next dimension.
319 for(
unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
322 RecurseTensor(tensorInfo,
// Offsets both multiplication axes of one input by |rankDiff|.
// A negative rankDiff shifts X's axes; a positive rankDiff shifts Y's axes.
// rankDiff is computed in lines not visible here, and the equal-rank case is
// presumably handled there too. The branches suggest rankDiff < 0 means X has
// the lower rank, i.e. presumably rankDiff = rank(X) - rank(Y). TODO confirm.
// NOTE(review): std::make_unsigned<unsigned int>::type is simply unsigned int.
// static_cast<unsigned int>(std::abs(rankDiff)) would read more plainly.
329 void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
330 std::pair<unsigned int, unsigned int>& axesYToMul)
338 else if(rankDiff < 0)
341 axesXToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
342 axesXToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
344 else if(rankDiff > 0)
347 axesYToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
348 axesYToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
// Reads one element of the given slot at idx.
// idx is taken by value because AdjustToSafeIdx rewrites it (for coordinates
// beyond the slot's shape) before CalcFlatIdx flattens it.
// For InputX/InputY, a non-empty customData is read instead of the slot's own
// buffer; Adjoint passes its snapshot this way. Callers that omit customData
// presumably get an empty default from the declaration.
// For Output, the value is read back through outputEncoder.
352 float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx,
const std::vector<float>& customData)
357 AdjustToSafeIdx(type, idx);
358 unsigned int flatIdx = CalcFlatIdx(type, idx);
362 case DataSlot::InputX:
363 value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
365 case DataSlot::InputY:
366 value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
368 case DataSlot::Output:
// Presumably positions the encoder at flatIdx; the expression's result is unused.
369 outputEncoder[flatIdx];
370 value = outputEncoder.
Get();
// Writes `value` into the given slot at idx.
// idx is taken by value because AdjustToSafeIdx rewrites it before CalcFlatIdx
// flattens it.
// InputX/InputY write directly into their float buffers (inputXData/inputYData).
// Output goes through outputEncoder.
379 void BatchMatMul::SetValueAt(
float value, DataSlot type, std::vector<unsigned int> idx)
381 AdjustToSafeIdx(type, idx);
382 unsigned int flatIdx = CalcFlatIdx(type, idx);
385 case DataSlot::InputX:
386 inputXData[flatIdx] = value;
388 case DataSlot::InputY:
389 inputYData[flatIdx] = value;
391 case DataSlot::Output:
// Presumably positions the encoder at flatIdx; the expression's result is unused.
392 outputEncoder[flatIdx];
393 outputEncoder.
Set(value);
// Rewrites, in place, index coordinates that exceed the extent of the selected
// input's shape.
// xDiff/yDiff are defined in lines not visible here. They are presumably the
// rank difference between the index and that input, so that `dim - xDiff` maps
// an index dimension onto the input's own dimension.
// The replacement value for an out-of-range coordinate is not visible. It is
// presumably 0, which would broadcast size-1 dimensions.
// NOTE(review): if dim < xDiff (or yDiff) is not guarded in the hidden part of
// the condition, `dim - xDiff` would underflow. Confirm that the guard exists.
400 void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
402 for(
unsigned int dim = 0; dim < idx.size(); dim++)
406 case DataSlot::InputX:
411 idx[dim] > inputXInfo.
GetShape()[dim-xDiff]-1)
417 case DataSlot::InputY:
422 idx[dim] > inputYInfo.
GetShape()[dim-yDiff]-1)
// Output indices: handling is not visible in this excerpt.
428 case DataSlot::Output:
439 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type,
const std::vector<unsigned int>& idx)
441 unsigned int result = idx[idx.size()-1];
442 unsigned int dimMultiplier = 1;
443 unsigned int offset = 0;
447 for(
unsigned int i =
static_cast<unsigned int>(idx.size()-2);
static_cast<int>(i) >= 0 && (i + 1) > offset; i--)
451 case DataSlot::InputX:
453 dimMultiplier *= inputXInfo.
GetShape()[i + 1 - offset];
455 case DataSlot::InputY:
457 dimMultiplier *= inputYInfo.
GetShape()[i + 1 - offset];
459 case DataSlot::Output:
460 dimMultiplier *= outputInfo.
GetShape()[i+1];
465 result += (idx[i] * dimMultiplier);