23 inputXInfo(inputXInfo),
24 inputYInfo(inputYInfo),
25 outputInfo(outputInfo),
26 inputXDecoder(inputXDecoder),
27 inputYDecoder(inputYDecoder),
28 outputEncoder(outputEncoder)
// Computes every output element as a sum of products of InputX and InputY
// values and stores it through SetValueAt(..., DataSlot::Output, ...).
// NOTE(review): this listing is fragmentary - the embedded line numbers show
// gaps (braces and several declarations are not visible), so the comments
// below describe only the statements that can be seen.
39 void BatchMatMul::ApplyBatchMatMul()
// Both axis pairs are passed by non-const reference and may be shifted in place.
45 AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
// The summation walks X along the second axis of its pair and Y along the
// first axis of its pair.
47 unsigned int inputXColDim = axesXToMul.second;
48 unsigned int inputYRowDim = axesYToMul.first;
// Number of terms in each sum: Y's extent along inputYRowDim.
// NOTE(review): the loop below uses this same count to index X along
// inputXColDim - presumably the two extents are validated as equal elsewhere;
// confirm.
50 unsigned int inputYRowSize = inputYInfo.
GetShape()[inputYRowDim];
// Per-element callback: curIdx is the multi-dimensional index of the output
// element being produced.
52 auto batchMatMulOperation = [&](
const std::vector<unsigned int>& curIdx)
57 for (
unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
// xIdx / yIdx / sum are declared on lines not visible here - presumably xIdx
// and yIdx are copies of curIdx and sum starts at zero; TODO confirm.
59 xIdx[inputXColDim] = inputYRowIdx;
62 yIdx[inputYRowDim] = inputYRowIdx;
64 sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
// Store the accumulated value at the output position.
67 SetValueAt(sum, DataSlot::Output, curIdx);
// Traversal is driven by outputInfo; the remaining arguments of this call are
// not visible in this listing.
71 RecurseTensor(outputInfo,
// Applies Transpose and/or Adjoint to the InputX and InputY slots.
// NOTE(review): the conditions guarding each of the four calls sit on lines
// that are not visible in this listing - presumably they test descriptor
// flags; confirm against the full source.
77 void BatchMatMul::ApplyParams()
81 Transpose(DataSlot::InputX);
85 Adjoint(DataSlot::InputX);
89 Transpose(DataSlot::InputY);
93 Adjoint(DataSlot::InputY);
// Transposes the data held for the given slot, dispatching on `type`.
// NOTE(review): only the case labels and the temporary buffers are visible in
// this listing; how the buffers are filled and written back is not shown.
97 void BatchMatMul::Transpose(DataSlot type)
104 case DataSlot::InputX:
// Scratch buffer with one float per element of inputXData.
109 std::vector<float> temp(inputXData.size());
118 case DataSlot::InputY:
// Scratch buffer with one float per element of inputYData.
123 std::vector<float> temp(inputYData.size());
// The body of the Output case is not visible here.
132 case DataSlot::Output:
// Replaces values of the selected input slot with cofactors: for each index,
// the determinant of the sub-matrix obtained by dropping that element's row
// and column, multiplied by a sign that depends on row + col.
// The visible code writes each cofactor back at the element's own index; any
// transpose step that would turn the cofactor matrix into the adjoint is not
// visible in this listing.
// NOTE(review): this listing is fragmentary (braces, case labels and several
// statements are missing), so only visible statements are described.
138 void BatchMatMul::Adjoint(DataSlot type)
// Any `type` other than InputX selects the Y tensor info / data.
144 TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
// Declared by value, so this is a copy: cofactors are computed from this
// snapshot while SetValueAt overwrites the live data.
149 std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
// Side length of each minor: one less than the extent along axesToAdjoint.first.
// NOTE(review): unsigned arithmetic - an extent of 0 would wrap to a huge value.
152 unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
// Square scratch matrix reused for every element's minor.
153 std::vector<std::vector<float>> subMat(subMatAxisSize,
154 std::vector<float>(subMatAxisSize));
// Approximate float comparison helper.
// NOTE(review): `bound` is derived from `diff` itself, so with the default
// unitsInLastPlace (epsilon * 2 < 1) `diff <= bound` can only hold when
// diff == 0; the second clause then accepts differences below the smallest
// normal float. Scaling by the operands' magnitude may have been intended -
// confirm.
157 auto almostEquals = [&](
const float& a,
const float& b,
float unitsInLastPlace = 2.0f)
159 float diff = std::fabs(a-b);
160 float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
161 return (diff <= bound) || (diff < std::numeric_limits<float>::min());
// Tracks the determinant's sign change caused by row swaps. The initial value
// here is overwritten with 1.0f (original line 261) before the elimination
// loop multiplies it into the determinant.
164 float swapMultiplier = std::numeric_limits<float>::max();
// Exchanges two rows of subMat element by element and flips swapMultiplier.
165 auto swapRows = [&](
unsigned int rowIdxA,
unsigned int rowIdxB)
168 for(
unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
170 float tmp = subMat[rowIdxA][colIdx];
171 subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
172 subMat[rowIdxB][colIdx] = tmp;
// Presumably executed once per swap, after the column loop - the closing
// brace of the loop is not visible; confirm.
174 swapMultiplier *= -1.0f;
// Searches the rows below `colIdx` for one whose entry in column `colIdx` is
// not (almost) zero. `result` starts as the unsigned-int maximum, which the
// caller compares against as a "not found" sentinel; the assignment made on a
// hit is not visible here.
177 auto findNextValidPivotRowIdx = [&](
unsigned int colIdx)
179 unsigned int result = std::numeric_limits<unsigned int>::max();
182 for(
unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
184 if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
// For every row below the pivot row, subtracts a multiple of the pivot row so
// that the entry in the pivot column is cancelled.
193 auto eliminate = [&](
const float& pivot,
unsigned int pivotPos)
195 for(
unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
197 float multiplierNumerator = subMat[rowIdx][pivotPos];
// What happens for an (almost) zero entry is not visible - presumably the row
// is skipped; confirm.
198 if(almostEquals(multiplierNumerator, 0.0f))
202 float multiplier = multiplierNumerator / pivot;
// Columns left of the pivot column are not touched by this loop.
204 for(
unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
210 subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
// Per-element callback: computes the cofactor for the element at curIdx.
215 auto cofactorOperation = [&](
const std::vector<unsigned int>& curIdx)
217 auto row = curIdx[axesToAdjoint.first];
218 auto col = curIdx[axesToAdjoint.second];
// Sign factor: +1 when row + col is even, -1 when it is odd.
220 float minorMultiplier =
static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
// Fill subMat with the minor: sub-matrix indices at or beyond the removed
// row / column are shifted by one so that row `row` and column `col` of the
// source are skipped. Values are read from the snapshot inputDataClone.
222 for(
unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
224 for(
unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
226 unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
227 unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
228 auto cloneIdx = curIdx;
229 cloneIdx[axesToAdjoint.first] = outerRow;
230 cloneIdx[axesToAdjoint.second] = outerCol;
231 subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
235 float determinant = 1.0f;
// Determinant of the minor, selected by its size. The case labels are not
// visible in this listing; the visible assignments are, in order: the
// element's own value, the single entry subMat[0][0], the closed-form 2x2
// expression, and a general path based on row elimination.
238 switch(subMatAxisSize)
242 determinant = GetValueAt(type, curIdx, inputDataClone);
248 determinant = subMat[0][0];
254 determinant = subMat[0][0] * subMat[1][1] -
255 subMat[0][1] * subMat[1][0];
// General path: reset the swap sign, then walk the diagonal multiplying the
// pivots together while eliminating the entries below each pivot.
261 swapMultiplier = 1.0f;
264 for(
unsigned int pivotRow = 0, pivotCol = 0;
265 pivotRow < subMatAxisSize;
266 pivotRow++, pivotCol++)
// Bound by reference: after swapRows exchanges the row contents, `pivot`
// reads the swapped-in value.
268 float& pivot = subMat[pivotRow][pivotCol];
270 if(almostEquals(pivot, 0.0f))
272 unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
// Sentinel check: no usable row was found below. The action taken in that
// case is not visible here.
273 if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
280 swapRows(pivotRow, nextValidPivotRowIdx);
282 determinant *= pivot;
284 eliminate(pivot, pivotRow);
// Account for the sign flips introduced by row swaps.
287 determinant *= swapMultiplier;
// Write the signed minor back into the live data at the same index.
291 float cofactor = minorMultiplier * determinant;
292 SetValueAt(cofactor, type, curIdx);
// All-zero starting index with one entry per dimension of the input.
295 auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
// Remaining arguments of this call are not visible in this listing.
296 RecurseTensor(inputInfo,
// Recursively walks the dimensions of `tensorInfo`, taking a callback that
// receives a multi-dimensional index.
// Visible parameters: tensorInfo (supplies rank and per-dimension extents),
// operation (the callback), curIdx (index vector, passed by non-const
// reference). `curDim` is used below but its declaration is on a line not
// visible in this listing.
304 void BatchMatMul::RecurseTensor(
const TensorInfo& tensorInfo,
305 const std::function<
void(
const std::vector<unsigned int>&)>& operation,
306 std::vector<unsigned int>& curIdx,
// Base case: curDim has run past the last dimension. The action taken here is
// not visible - presumably `operation(curIdx)` is invoked; confirm.
309 if(!(curDim < tensorInfo.GetNumDimensions()))
// Iterate over every position along the current dimension and recurse.
316 for(
unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
319 RecurseTensor(tensorInfo,
// Shifts one of the two axis pairs, in place, by the absolute rank difference.
// `rankDiff` is computed on a line not visible in this listing, and the
// leading `if` branch that precedes the first `else if` is also not shown.
326 void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
327 std::pair<unsigned int, unsigned int>& axesYToMul)
// Negative difference: both X axes are moved up by |rankDiff|.
335 else if(rankDiff < 0)
338 axesXToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
339 axesXToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
// Positive difference: both Y axes are moved up by |rankDiff|.
341 else if(rankDiff > 0)
344 axesYToMul.first +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
345 axesYToMul.second +=
static_cast<std::make_unsigned<unsigned int>::type
>(std::abs(rankDiff));
// Reads one float from the slot selected by `type` at multi-dimensional
// index `idx`.
// `idx` is taken by value, so the adjustment made by AdjustToSafeIdx does not
// affect the caller's vector.
// `customData`, when non-empty, is read instead of the stored input data for
// the InputX / InputY cases; the visible Output case does not use it.
349 float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx,
const std::vector<float>& customData)
354 AdjustToSafeIdx(type, idx);
355 unsigned int flatIdx = CalcFlatIdx(type, idx);
359 case DataSlot::InputX:
360 value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
362 case DataSlot::InputY:
363 value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
365 case DataSlot::Output:
// operator[] is evaluated only for its effect on the encoder - presumably it
// positions the encoder at flatIdx so that Get() reads that element; confirm.
366 outputEncoder[flatIdx];
367 value = outputEncoder.
Get();
// Writes `value` into the slot selected by `type` at multi-dimensional
// index `idx`.
// `idx` is taken by value, so the adjustment made by AdjustToSafeIdx does not
// affect the caller's vector.
376 void BatchMatMul::SetValueAt(
float value, DataSlot type, std::vector<unsigned int> idx)
378 AdjustToSafeIdx(type, idx);
379 unsigned int flatIdx = CalcFlatIdx(type, idx);
382 case DataSlot::InputX:
383 inputXData[flatIdx] = value;
385 case DataSlot::InputY:
386 inputYData[flatIdx] = value;
388 case DataSlot::Output:
// operator[] is evaluated only for its effect on the encoder - presumably it
// positions the encoder at flatIdx so that Set() writes that element; confirm.
389 outputEncoder[flatIdx];
390 outputEncoder.
Set(value);
// Checks each component of `idx` against the shape of the slot selected by
// `type`; `idx` is passed by non-const reference so it can be modified.
// NOTE(review): only the loop header, the case labels and the upper-bound
// comparisons are visible. xDiff / yDiff and the action taken when a
// comparison succeeds are on lines not shown in this listing.
397 void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
399 for(
unsigned int dim = 0; dim < idx.size(); dim++)
403 case DataSlot::InputX:
// True when the index exceeds the last valid position of the matching X
// dimension (dim shifted down by xDiff).
// NOTE(review): `extent - 1` is evaluated in unsigned arithmetic, so an extent
// of 0 would wrap - confirm that cannot occur.
408 idx[dim] > inputXInfo.
GetShape()[dim-xDiff]-1)
414 case DataSlot::InputY:
// Same check against the matching Y dimension (dim shifted down by yDiff).
419 idx[dim] > inputYInfo.
GetShape()[dim-yDiff]-1)
// The body of the Output case is not visible here.
425 case DataSlot::Output:
// Converts the multi-dimensional index `idx` into a single flat offset for
// the slot selected by `type`: the last component contributes directly and
// each earlier component is weighted by the product of the extents of the
// dimensions after it.
// NOTE(review): `offset` and the return statement are on lines not visible in
// this listing.
436 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type,
const std::vector<unsigned int>& idx)
// NOTE(review): an empty `idx` would index out of bounds here - presumably
// callers never pass one; confirm.
438 unsigned int result = idx[idx.size()-1];
439 unsigned int dimMultiplier = 1;
// Walk from the second-to-last component down to component 0. The counter is
// unsigned, so the loop ends when `i--` wraps around and the cast to int
// turns negative; for a one-element idx, size()-2 wraps the same way and the
// loop body never runs.
443 for(
unsigned int i =
static_cast<unsigned int>(idx.size()-2);
static_cast<int>(i) >= 0; i--)
447 case DataSlot::InputX:
// Fold in the extent of the dimension just inside component i; the shape is
// indexed with `offset` subtracted.
449 dimMultiplier *= inputXInfo.
GetShape()[i + 1 - offset];
451 case DataSlot::InputY:
453 dimMultiplier *= inputYInfo.
GetShape()[i + 1 - offset];
455 case DataSlot::Output:
// The output shape is indexed directly, without an offset.
456 dimMultiplier *= outputInfo.
GetShape()[i+1];
461 result += (idx[i] * dimMultiplier);