18 const std::vector<const TensorInfo*>& inputs,
19 const std::vector<const TensorInfo*>& outputs,
24 "ConvertSoftmaxToTosaOperator: Softmax currently only supports Int8 Quantized inputs");
27 "ConvertSoftmaxToTosaOperator: Softmax must have only one input");
30 "ConvertSoftmaxToTosaOperator: Softmax must have at least one output");
32 std::string inputName = std::string(
"input_");
33 std::string outputName = std::string(
"output0_");
104 std::vector<TosaSerializationTensor *> tensors;
105 std::vector<TosaSerializationOperator *> operators;
116 if (inputName.find(
"input_") != std::string::npos)
118 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape0, inputDType0, {}));
126 std::vector<uint8_t> uint8Data;
129 const std::vector<int32_t> singleValueShape(rank,1);
130 auto axis =
static_cast<int32_t
>(rank - 1);
131 TosaAxisAttribute tosaAxisAttribute(axis);
134 std::vector<int32_t> reduceShape = inputShape0;
135 reduceShape[
static_cast<unsigned long>(axis)] = 1;
137 TosaSerializationOperator *rescaleOp1 =
nullptr;
140 tensors.push_back(
new TosaSerializationTensor(outputNameRescale1, inputShape0, DType_INT32, {}));
141 operators.push_back(rescaleOp1);
143 auto *reduceMaxOp1 =
new TosaSerializationOperator(Op_REDUCE_MAX,
144 Attribute_AxisAttribute,
146 {outputNameRescale1},
147 {outputNameReduceMax1});
148 tensors.push_back(
new TosaSerializationTensor(outputNameReduceMax1, reduceShape, DType_INT32, {}));
149 operators.push_back(reduceMaxOp1);
151 auto *subOp1 =
new TosaSerializationOperator(Op_SUB,
154 {outputNameRescale1, outputNameReduceMax1},
156 tensors.push_back(
new TosaSerializationTensor(outputNameSub1, inputShape0, DType_INT32, {}));
157 operators.push_back(subOp1);
159 TosaSerializationOperator *rescaleOp2 =
nullptr;
161 tensors.push_back(
new TosaSerializationTensor(outputNameRescale2, inputShape0, DType_INT16, {}));
162 operators.push_back(rescaleOp2);
164 std::array<std::vector <int16_t>, 4> lookupTables;
167 const std::vector<int16_t> first = lookupTables[0];
168 const std::vector<int16_t> table1(&first[0],&first[0]+513);
169 const std::vector<int16_t> second = lookupTables[1];
170 const std::vector<int16_t> table2(&second[0],&second[0]+513);
171 const std::vector<int16_t> third = lookupTables[2];
172 const std::vector<int16_t> table3(&third[0],&third[0]+513);
173 const std::vector<int16_t> fourth = lookupTables[3];
174 const std::vector<int16_t> table4(&fourth[0],&fourth[0]+513);
176 TosaTableAttribute tosaTableAttribute1(table1);
177 TosaTableAttribute tosaTableAttribute2(table2);
178 TosaTableAttribute tosaTableAttribute3(table3);
179 TosaTableAttribute tosaTableAttribute4(table4);
181 auto* tableOp1 =
new TosaSerializationOperator(Op_TABLE,
182 Attribute_TableAttribute,
183 &tosaTableAttribute1,
184 {outputNameRescale2},
186 tensors.push_back(
new TosaSerializationTensor(outputNameTable1, inputShape0, DType_INT32, {}));
187 operators.push_back(tableOp1);
189 auto* tableOp2 =
new TosaSerializationOperator(Op_TABLE,
190 Attribute_TableAttribute,
191 &tosaTableAttribute2,
192 {outputNameRescale2},
194 tensors.push_back(
new TosaSerializationTensor(outputNameTable2, inputShape0, DType_INT32, {}));
195 operators.push_back(tableOp2);
197 auto* tableOp3 =
new TosaSerializationOperator(Op_TABLE,
198 Attribute_TableAttribute,
199 &tosaTableAttribute3,
200 {outputNameRescale2},
202 tensors.push_back(
new TosaSerializationTensor(outputNameTable3, inputShape0, DType_INT32, {}));
203 operators.push_back(tableOp3);
205 auto* tableOp4 =
new TosaSerializationOperator(Op_TABLE,
206 Attribute_TableAttribute,
207 &tosaTableAttribute4,
208 {outputNameRescale2},
210 tensors.push_back(
new TosaSerializationTensor(outputNameTable4, inputShape0, DType_INT32, {}));
211 operators.push_back(tableOp4);
213 TosaSerializationHandler::ConvertI32toU8({17}, uint8Data);
214 tensors.push_back(
new TosaSerializationTensor(inputNameConst1,singleValueShape, DType_INT32,uint8Data));
215 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst1}));
217 auto* logicalLOp1 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
220 {outputNameTable1, inputNameConst1},
221 {outputNameLogicalL1});
222 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL1, inputShape0, DType_INT32, {}));
223 operators.push_back(logicalLOp1);
225 TosaSerializationHandler::ConvertI32toU8({9}, uint8Data);
226 tensors.push_back(
new TosaSerializationTensor(inputNameConst2, singleValueShape, DType_INT32,uint8Data));
227 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst2}));
229 auto* logicalLOp2 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
232 {outputNameTable2, inputNameConst2},
233 {outputNameLogicalL2});
234 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL2, inputShape0, DType_INT32, {}));
235 operators.push_back(logicalLOp2);
237 TosaSerializationHandler::ConvertI32toU8({1}, uint8Data);
238 tensors.push_back(
new TosaSerializationTensor(inputNameConst3, singleValueShape, DType_INT32,uint8Data));
239 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst3}));
241 auto* logicalLOp3 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
244 {outputNameTable3, inputNameConst3},
245 {outputNameLogicalL3});
246 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL3, inputShape0, DType_INT32, {}));
247 operators.push_back(logicalLOp3);
249 bool rounding =
true;
250 TosaArithmeticRightShiftAttribute shiftRAttribute(rounding);
252 TosaSerializationHandler::ConvertI32toU8({7}, uint8Data);
253 tensors.push_back(
new TosaSerializationTensor(inputNameConst4, singleValueShape, DType_INT32,uint8Data));
254 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst4}));
256 auto* arithmeticROp1 =
new TosaSerializationOperator(Op_ARITHMETIC_RIGHT_SHIFT,
257 Attribute_ArithmeticRightShiftAttribute,
259 {outputNameTable4, inputNameConst4},
260 {outputNameArithmeticR1});
261 tensors.push_back(
new TosaSerializationTensor(outputNameArithmeticR1, inputShape0, DType_INT32, {}));
262 operators.push_back(arithmeticROp1);
264 auto* addOp1 =
new TosaSerializationOperator(Op_ADD,
267 {outputNameLogicalL1, outputNameLogicalL2},
269 tensors.push_back(
new TosaSerializationTensor(outputNameAdd1, inputShape0, DType_INT32, {}));
270 operators.push_back(addOp1);
272 auto* addOp2 =
new TosaSerializationOperator(Op_ADD,
275 {outputNameAdd1, outputNameLogicalL3},
277 tensors.push_back(
new TosaSerializationTensor(outputNameAdd2, inputShape0, DType_INT32, {}));
278 operators.push_back(addOp2);
280 auto* addOp3 =
new TosaSerializationOperator(Op_ADD,
283 {outputNameAdd2, outputNameArithmeticR1},
285 tensors.push_back(
new TosaSerializationTensor(outputNameAdd3, inputShape0, DType_INT32, {}));
286 operators.push_back(addOp3);
288 TosaSerializationHandler::ConvertI32toU8({12}, uint8Data);
289 tensors.push_back(
new TosaSerializationTensor(inputNameConst5, singleValueShape, DType_INT32,uint8Data));
290 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst5}));
292 auto* arithmeticROp2 =
new TosaSerializationOperator(Op_ARITHMETIC_RIGHT_SHIFT,
293 Attribute_ArithmeticRightShiftAttribute,
295 {outputNameAdd3, inputNameConst5},
296 {outputNameArithmeticR2});
297 tensors.push_back(
new TosaSerializationTensor(outputNameArithmeticR2, inputShape0, DType_INT32, {}));
298 operators.push_back(arithmeticROp2);
300 auto* reduceSumOp1 =
new TosaSerializationOperator(Op_REDUCE_SUM,
301 Attribute_AxisAttribute,
303 {outputNameArithmeticR2},
304 {outputNameReduceSum1});
305 tensors.push_back(
new TosaSerializationTensor(outputNameReduceSum1, reduceShape, DType_INT32, {}));
306 operators.push_back(reduceSumOp1);
308 auto* countLeadingZeroOp1 =
new TosaSerializationOperator(Op_CLZ,
311 {outputNameReduceSum1},
313 tensors.push_back(
new TosaSerializationTensor(outputNameCLZ1, reduceShape, DType_INT32, {}));
314 operators.push_back(countLeadingZeroOp1);
316 TosaSerializationHandler::ConvertI32toU8({1}, uint8Data);
317 tensors.push_back(
new TosaSerializationTensor(inputNameConst3a, singleValueShape, DType_INT32, uint8Data));
318 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst3a}));
320 auto* subOp2 =
new TosaSerializationOperator(Op_SUB,
323 {outputNameCLZ1, inputNameConst3a},
325 tensors.push_back(
new TosaSerializationTensor(outputNameSub2, reduceShape, DType_INT32, {}));
326 operators.push_back(subOp2);
329 auto* logicalLOp4 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
332 {outputNameReduceSum1, outputNameSub2},
333 {outputNameLogicalL4});
334 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL4, reduceShape, DType_INT32, {}));
335 operators.push_back(logicalLOp4);
337 TosaMulAttribute mulAttribute1(31);
339 TosaSerializationHandler::ConvertI32toU8({-1010580540}, uint8Data);
340 tensors.push_back(
new TosaSerializationTensor(inputNameConst6, singleValueShape, DType_INT32,uint8Data));
341 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst6}));
344 auto* mulOp1 =
new TosaSerializationOperator(Op_MUL,
345 Attribute_MulAttribute,
347 {outputNameLogicalL4, inputNameConst6},
349 tensors.push_back(
new TosaSerializationTensor(outputNameMul1, reduceShape, DType_INT32, {}));
350 operators.push_back(mulOp1);
352 TosaSerializationHandler::ConvertI32toU8({1515870810}, uint8Data);
353 tensors.push_back(
new TosaSerializationTensor(inputNameConst7, singleValueShape, DType_INT32,uint8Data));
354 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst7}));
357 auto* addOp4 =
new TosaSerializationOperator(Op_ADD,
360 {outputNameMul1, inputNameConst7},
362 tensors.push_back(
new TosaSerializationTensor(outputNameAdd4, reduceShape, DType_INT32, {}));
363 operators.push_back(addOp4);
366 auto* mulOp2 =
new TosaSerializationOperator(Op_MUL,
367 Attribute_MulAttribute,
369 {outputNameAdd4, outputNameLogicalL4},
371 tensors.push_back(
new TosaSerializationTensor(outputNameMul2, reduceShape, DType_INT32, {}));
372 operators.push_back(mulOp2);
374 TosaSerializationHandler::ConvertI32toU8({536870912}, uint8Data);
375 tensors.push_back(
new TosaSerializationTensor(inputNameConst8, singleValueShape, DType_INT32,uint8Data));
376 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst8}));
378 auto* subOp3 =
new TosaSerializationOperator(Op_SUB,
381 {inputNameConst8, outputNameMul2},
383 tensors.push_back(
new TosaSerializationTensor(outputNameSub3, reduceShape, DType_INT32, {}));
384 operators.push_back(subOp3);
386 auto* mulOp3 =
new TosaSerializationOperator(Op_MUL,
387 Attribute_MulAttribute,
389 {outputNameAdd4, outputNameSub3},
391 tensors.push_back(
new TosaSerializationTensor(outputNameMul3, reduceShape, DType_INT32, {}));
392 operators.push_back(mulOp3);
394 TosaMulAttribute mulAttribute2(0);
396 TosaSerializationHandler::ConvertI32toU8({4}, uint8Data);
397 tensors.push_back(
new TosaSerializationTensor(inputNameConst9, singleValueShape, DType_INT32,uint8Data));
398 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst9}));
400 auto* mulOp4 =
new TosaSerializationOperator(Op_MUL,
401 Attribute_MulAttribute,
403 {outputNameMul3, inputNameConst9},
405 tensors.push_back(
new TosaSerializationTensor(outputNameMul4, reduceShape, DType_INT32, {}));
406 operators.push_back(mulOp4);
408 auto* addOp5 =
new TosaSerializationOperator(Op_ADD,
411 {outputNameAdd4, outputNameMul4},
413 tensors.push_back(
new TosaSerializationTensor(outputNameAdd5, reduceShape, DType_INT32, {}));
414 operators.push_back(addOp5);
417 auto* mulOp5 =
new TosaSerializationOperator(Op_MUL,
418 Attribute_MulAttribute,
420 {outputNameAdd5, outputNameLogicalL4},
422 tensors.push_back(
new TosaSerializationTensor(outputNameMul5, reduceShape, DType_INT32, {}));
423 operators.push_back(mulOp5);
425 TosaSerializationHandler::ConvertI32toU8({536870912}, uint8Data);
426 tensors.push_back(
new TosaSerializationTensor(inputNameConst8a, singleValueShape, DType_INT32,uint8Data));
427 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst8a}));
429 auto* subOp4 =
new TosaSerializationOperator(Op_SUB,
432 {inputNameConst8a, outputNameMul5},
434 tensors.push_back(
new TosaSerializationTensor(outputNameSub4, reduceShape, DType_INT32, {}));
435 operators.push_back(subOp4);
437 auto* mulOp6 =
new TosaSerializationOperator(Op_MUL,
438 Attribute_MulAttribute,
440 {outputNameAdd5, outputNameSub4},
442 tensors.push_back(
new TosaSerializationTensor(outputNameMul6, reduceShape, DType_INT32, {}));
443 operators.push_back(mulOp6);
445 TosaSerializationHandler::ConvertI32toU8({4}, uint8Data);
446 tensors.push_back(
new TosaSerializationTensor(inputNameConst9a, singleValueShape, DType_INT32,uint8Data));
447 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst9a}));
449 auto* mulOp7 =
new TosaSerializationOperator(Op_MUL,
450 Attribute_MulAttribute,
452 {outputNameMul6, inputNameConst9a},
454 tensors.push_back(
new TosaSerializationTensor(outputNameMul7, reduceShape, DType_INT32, {}));
455 operators.push_back(mulOp7);
457 auto* addOp6 =
new TosaSerializationOperator(Op_ADD,
460 {outputNameAdd5, outputNameMul7},
462 tensors.push_back(
new TosaSerializationTensor(outputNameAdd6, reduceShape, DType_INT32, {}));
463 operators.push_back(addOp6);
466 auto* mulOp8 =
new TosaSerializationOperator(Op_MUL,
467 Attribute_MulAttribute,
469 {outputNameAdd6, outputNameLogicalL4},
471 tensors.push_back(
new TosaSerializationTensor(outputNameMul8, reduceShape, DType_INT32, {}));
472 operators.push_back(mulOp8);
474 TosaSerializationHandler::ConvertI32toU8({536870912}, uint8Data);
475 tensors.push_back(
new TosaSerializationTensor(inputNameConst8b, singleValueShape, DType_INT32,uint8Data));
476 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst8b}));
478 auto* subOp5 =
new TosaSerializationOperator(Op_SUB,
481 {inputNameConst8b, outputNameMul8},
483 tensors.push_back(
new TosaSerializationTensor(outputNameSub5, reduceShape, DType_INT32, {}));
484 operators.push_back(subOp5);
486 auto* mulOp9 =
new TosaSerializationOperator(Op_MUL,
487 Attribute_MulAttribute,
489 {outputNameAdd6, outputNameSub5},
491 tensors.push_back(
new TosaSerializationTensor(outputNameMul9, reduceShape, DType_INT32, {}));
492 operators.push_back(mulOp9);
494 TosaSerializationHandler::ConvertI32toU8({4}, uint8Data);
495 tensors.push_back(
new TosaSerializationTensor(inputNameConst9b, singleValueShape, DType_INT32,uint8Data));
496 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst9b}));
498 auto* mulOp10 =
new TosaSerializationOperator(Op_MUL,
499 Attribute_MulAttribute,
501 {outputNameMul9, inputNameConst9b},
503 tensors.push_back(
new TosaSerializationTensor(outputNameMul10, reduceShape, DType_INT32, {}));
504 operators.push_back(mulOp10);
506 auto* addOp7 =
new TosaSerializationOperator(Op_ADD,
509 {outputNameAdd6, outputNameMul10},
511 tensors.push_back(
new TosaSerializationTensor(outputNameAdd7, reduceShape, DType_INT32, {}));
512 operators.push_back(addOp7);
514 TosaMulAttribute mulAttribute3(30);
516 auto* mulOp11 =
new TosaSerializationOperator(Op_MUL,
517 Attribute_MulAttribute,
519 {outputNameAdd3, outputNameAdd7},
521 tensors.push_back(
new TosaSerializationTensor(outputNameMul11, outputShape0, DType_INT32, {}));
522 operators.push_back(mulOp11);
524 TosaSerializationHandler::ConvertI32toU8({35}, uint8Data);
525 tensors.push_back(
new TosaSerializationTensor(inputNameConst10, singleValueShape, DType_INT32,uint8Data));
526 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst10}));
528 auto* subOp6 =
new TosaSerializationOperator(Op_SUB,
531 {inputNameConst10, outputNameCLZ1},
533 tensors.push_back(
new TosaSerializationTensor(outputNameSub6, reduceShape, DType_INT32, {}));
534 operators.push_back(subOp6);
536 auto* arithmeticROp3 =
new TosaSerializationOperator(Op_ARITHMETIC_RIGHT_SHIFT,
537 Attribute_ArithmeticRightShiftAttribute,
539 {outputNameMul11, outputNameSub6},
540 {outputNameArithmeticR3});
541 tensors.push_back(
new TosaSerializationTensor(outputNameArithmeticR3, outputShape0, DType_INT32, {}));
542 operators.push_back(arithmeticROp3);
544 TosaSerializationOperator* rescaleOp3 =
nullptr;
547 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
548 operators.push_back(rescaleOp3);
550 return new TosaSerializationBasicBlock(blockName,