18 const std::vector<const TensorInfo*>& inputs,
19 const std::vector<const TensorInfo*>& outputs,
24 "ConvertSoftmaxToTosaOperator: Softmax currently only supports Int8 Quantized inputs");
27 "ConvertSoftmaxToTosaOperator: Softmax must have only one input");
30 "ConvertSoftmaxToTosaOperator: Softmax must have at least one output");
32 std::string inputName = std::string(
"input_");
33 std::string outputName = std::string(
"output0_");
106 std::vector<TosaSerializationTensor *> tensors;
107 std::vector<TosaSerializationOperator *> operators;
118 if (inputName.find(
"input_") != std::string::npos)
120 tensors.push_back(
new TosaSerializationTensor(inputName, inputShape0, inputDType0, {}));
128 std::vector<uint8_t> uint8Data;
131 const std::vector<int32_t> singleValueShape(rank,1);
132 auto axis =
static_cast<int32_t
>(rank - 1);
133 TosaAxisAttribute tosaAxisAttribute(axis);
136 std::vector<int32_t> reduceShape = inputShape0;
137 reduceShape[
static_cast<unsigned long>(axis)] = 1;
139 TosaSerializationOperator *rescaleOp1 =
nullptr;
141 false,
false,
false,
true, &rescaleOp1);
143 tensors.push_back(
new TosaSerializationTensor(outputNameRescale1, inputShape0, DType_INT32, {}));
144 operators.push_back(rescaleOp1);
146 auto *reduceMaxOp1 =
new TosaSerializationOperator(Op_REDUCE_MAX,
147 Attribute_AxisAttribute,
149 {outputNameRescale1},
150 {outputNameReduceMax1});
151 tensors.push_back(
new TosaSerializationTensor(outputNameReduceMax1, reduceShape, DType_INT32, {}));
152 operators.push_back(reduceMaxOp1);
154 auto *subOp1 =
new TosaSerializationOperator(Op_SUB,
157 {outputNameRescale1, outputNameReduceMax1},
159 tensors.push_back(
new TosaSerializationTensor(outputNameSub1, inputShape0, DType_INT32, {}));
160 operators.push_back(subOp1);
162 TosaSerializationOperator *rescaleOp2 =
nullptr;
164 tensors.push_back(
new TosaSerializationTensor(outputNameRescale2, inputShape0, DType_INT16, {}));
165 operators.push_back(rescaleOp2);
167 std::array<std::vector <int16_t>, 4> lookupTables;
170 const std::vector<int16_t> first = lookupTables[0];
171 const std::vector<int16_t> table1(&first[0],&first[0]+513);
172 const std::vector<int16_t> second = lookupTables[1];
173 const std::vector<int16_t> table2(&second[0],&second[0]+513);
174 const std::vector<int16_t> third = lookupTables[2];
175 const std::vector<int16_t> table3(&third[0],&third[0]+513);
176 const std::vector<int16_t> fourth = lookupTables[3];
177 const std::vector<int16_t> table4(&fourth[0],&fourth[0]+513);
179 TosaTableAttribute tosaTableAttribute1(table1);
180 TosaTableAttribute tosaTableAttribute2(table2);
181 TosaTableAttribute tosaTableAttribute3(table3);
182 TosaTableAttribute tosaTableAttribute4(table4);
184 auto* tableOp1 =
new TosaSerializationOperator(Op_TABLE,
185 Attribute_TableAttribute,
186 &tosaTableAttribute1,
187 {outputNameRescale2},
189 tensors.push_back(
new TosaSerializationTensor(outputNameTable1, inputShape0, DType_INT32, {}));
190 operators.push_back(tableOp1);
192 auto* tableOp2 =
new TosaSerializationOperator(Op_TABLE,
193 Attribute_TableAttribute,
194 &tosaTableAttribute2,
195 {outputNameRescale2},
197 tensors.push_back(
new TosaSerializationTensor(outputNameTable2, inputShape0, DType_INT32, {}));
198 operators.push_back(tableOp2);
200 auto* tableOp3 =
new TosaSerializationOperator(Op_TABLE,
201 Attribute_TableAttribute,
202 &tosaTableAttribute3,
203 {outputNameRescale2},
205 tensors.push_back(
new TosaSerializationTensor(outputNameTable3, inputShape0, DType_INT32, {}));
206 operators.push_back(tableOp3);
208 auto* tableOp4 =
new TosaSerializationOperator(Op_TABLE,
209 Attribute_TableAttribute,
210 &tosaTableAttribute4,
211 {outputNameRescale2},
213 tensors.push_back(
new TosaSerializationTensor(outputNameTable4, inputShape0, DType_INT32, {}));
214 operators.push_back(tableOp4);
216 TosaSerializationHandler::ConvertI32toU8({17}, uint8Data);
217 tensors.push_back(
new TosaSerializationTensor(inputNameConst1,singleValueShape, DType_INT32,uint8Data));
218 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst1}));
220 auto* logicalLOp1 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
223 {outputNameTable1, inputNameConst1},
224 {outputNameLogicalL1});
225 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL1, inputShape0, DType_INT32, {}));
226 operators.push_back(logicalLOp1);
228 TosaSerializationHandler::ConvertI32toU8({9}, uint8Data);
229 tensors.push_back(
new TosaSerializationTensor(inputNameConst2, singleValueShape, DType_INT32,uint8Data));
230 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst2}));
232 auto* logicalLOp2 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
235 {outputNameTable2, inputNameConst2},
236 {outputNameLogicalL2});
237 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL2, inputShape0, DType_INT32, {}));
238 operators.push_back(logicalLOp2);
240 TosaSerializationHandler::ConvertI32toU8({1}, uint8Data);
241 tensors.push_back(
new TosaSerializationTensor(inputNameConst3, singleValueShape, DType_INT32,uint8Data));
242 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst3}));
244 auto* logicalLOp3 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
247 {outputNameTable3, inputNameConst3},
248 {outputNameLogicalL3});
249 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL3, inputShape0, DType_INT32, {}));
250 operators.push_back(logicalLOp3);
252 bool rounding =
true;
253 TosaArithmeticRightShiftAttribute shiftRAttribute(rounding);
255 TosaSerializationHandler::ConvertI32toU8({7}, uint8Data);
256 tensors.push_back(
new TosaSerializationTensor(inputNameConst4, singleValueShape, DType_INT32,uint8Data));
257 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst4}));
259 auto* arithmeticROp1 =
new TosaSerializationOperator(Op_ARITHMETIC_RIGHT_SHIFT,
260 Attribute_ArithmeticRightShiftAttribute,
262 {outputNameTable4, inputNameConst4},
263 {outputNameArithmeticR1});
264 tensors.push_back(
new TosaSerializationTensor(outputNameArithmeticR1, inputShape0, DType_INT32, {}));
265 operators.push_back(arithmeticROp1);
267 auto* addOp1 =
new TosaSerializationOperator(Op_ADD,
270 {outputNameLogicalL1, outputNameLogicalL2},
272 tensors.push_back(
new TosaSerializationTensor(outputNameAdd1, inputShape0, DType_INT32, {}));
273 operators.push_back(addOp1);
275 auto* addOp2 =
new TosaSerializationOperator(Op_ADD,
278 {outputNameAdd1, outputNameLogicalL3},
280 tensors.push_back(
new TosaSerializationTensor(outputNameAdd2, inputShape0, DType_INT32, {}));
281 operators.push_back(addOp2);
283 auto* addOp3 =
new TosaSerializationOperator(Op_ADD,
286 {outputNameAdd2, outputNameArithmeticR1},
288 tensors.push_back(
new TosaSerializationTensor(outputNameAdd3, inputShape0, DType_INT32, {}));
289 operators.push_back(addOp3);
291 TosaSerializationHandler::ConvertI32toU8({12}, uint8Data);
292 tensors.push_back(
new TosaSerializationTensor(inputNameConst5, singleValueShape, DType_INT32,uint8Data));
293 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst5}));
295 auto* arithmeticROp2 =
new TosaSerializationOperator(Op_ARITHMETIC_RIGHT_SHIFT,
296 Attribute_ArithmeticRightShiftAttribute,
298 {outputNameAdd3, inputNameConst5},
299 {outputNameArithmeticR2});
300 tensors.push_back(
new TosaSerializationTensor(outputNameArithmeticR2, inputShape0, DType_INT32, {}));
301 operators.push_back(arithmeticROp2);
303 auto* reduceSumOp1 =
new TosaSerializationOperator(Op_REDUCE_SUM,
304 Attribute_AxisAttribute,
306 {outputNameArithmeticR2},
307 {outputNameReduceSum1});
308 tensors.push_back(
new TosaSerializationTensor(outputNameReduceSum1, reduceShape, DType_INT32, {}));
309 operators.push_back(reduceSumOp1);
311 auto* countLeadingZeroOp1 =
new TosaSerializationOperator(Op_CLZ,
314 {outputNameReduceSum1},
316 tensors.push_back(
new TosaSerializationTensor(outputNameCLZ1, reduceShape, DType_INT32, {}));
317 operators.push_back(countLeadingZeroOp1);
319 TosaSerializationHandler::ConvertI32toU8({1}, uint8Data);
320 tensors.push_back(
new TosaSerializationTensor(inputNameConst3a, singleValueShape, DType_INT32, uint8Data));
321 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst3a}));
323 auto* subOp2 =
new TosaSerializationOperator(Op_SUB,
326 {outputNameCLZ1, inputNameConst3a},
328 tensors.push_back(
new TosaSerializationTensor(outputNameSub2, reduceShape, DType_INT32, {}));
329 operators.push_back(subOp2);
332 auto* logicalLOp4 =
new TosaSerializationOperator(Op_LOGICAL_LEFT_SHIFT,
335 {outputNameReduceSum1, outputNameSub2},
336 {outputNameLogicalL4});
337 tensors.push_back(
new TosaSerializationTensor(outputNameLogicalL4, reduceShape, DType_INT32, {}));
338 operators.push_back(logicalLOp4);
340 TosaMulAttribute mulAttribute1(31);
342 TosaSerializationHandler::ConvertI32toU8({-1010580540}, uint8Data);
343 tensors.push_back(
new TosaSerializationTensor(inputNameConst6, singleValueShape, DType_INT32,uint8Data));
344 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst6}));
347 auto* mulOp1 =
new TosaSerializationOperator(Op_MUL,
348 Attribute_MulAttribute,
350 {outputNameLogicalL4, inputNameConst6},
352 tensors.push_back(
new TosaSerializationTensor(outputNameMul1, reduceShape, DType_INT32, {}));
353 operators.push_back(mulOp1);
355 TosaSerializationHandler::ConvertI32toU8({1515870810}, uint8Data);
356 tensors.push_back(
new TosaSerializationTensor(inputNameConst7, singleValueShape, DType_INT32,uint8Data));
357 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst7}));
360 auto* addOp4 =
new TosaSerializationOperator(Op_ADD,
363 {outputNameMul1, inputNameConst7},
365 tensors.push_back(
new TosaSerializationTensor(outputNameAdd4, reduceShape, DType_INT32, {}));
366 operators.push_back(addOp4);
369 auto* mulOp2 =
new TosaSerializationOperator(Op_MUL,
370 Attribute_MulAttribute,
372 {outputNameAdd4, outputNameLogicalL4},
374 tensors.push_back(
new TosaSerializationTensor(outputNameMul2, reduceShape, DType_INT32, {}));
375 operators.push_back(mulOp2);
377 TosaSerializationHandler::ConvertI32toU8({536870912}, uint8Data);
378 tensors.push_back(
new TosaSerializationTensor(inputNameConst8, singleValueShape, DType_INT32,uint8Data));
379 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst8}));
381 auto* subOp3 =
new TosaSerializationOperator(Op_SUB,
384 {inputNameConst8, outputNameMul2},
386 tensors.push_back(
new TosaSerializationTensor(outputNameSub3, reduceShape, DType_INT32, {}));
387 operators.push_back(subOp3);
389 auto* mulOp3 =
new TosaSerializationOperator(Op_MUL,
390 Attribute_MulAttribute,
392 {outputNameAdd4, outputNameSub3},
394 tensors.push_back(
new TosaSerializationTensor(outputNameMul3, reduceShape, DType_INT32, {}));
395 operators.push_back(mulOp3);
397 TosaMulAttribute mulAttribute2(0);
399 TosaSerializationHandler::ConvertI32toU8({4}, uint8Data);
400 tensors.push_back(
new TosaSerializationTensor(inputNameConst9, singleValueShape, DType_INT32,uint8Data));
401 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst9}));
403 auto* mulOp4 =
new TosaSerializationOperator(Op_MUL,
404 Attribute_MulAttribute,
406 {outputNameMul3, inputNameConst9},
408 tensors.push_back(
new TosaSerializationTensor(outputNameMul4, reduceShape, DType_INT32, {}));
409 operators.push_back(mulOp4);
411 auto* addOp5 =
new TosaSerializationOperator(Op_ADD,
414 {outputNameAdd4, outputNameMul4},
416 tensors.push_back(
new TosaSerializationTensor(outputNameAdd5, reduceShape, DType_INT32, {}));
417 operators.push_back(addOp5);
420 auto* mulOp5 =
new TosaSerializationOperator(Op_MUL,
421 Attribute_MulAttribute,
423 {outputNameAdd5, outputNameLogicalL4},
425 tensors.push_back(
new TosaSerializationTensor(outputNameMul5, reduceShape, DType_INT32, {}));
426 operators.push_back(mulOp5);
428 TosaSerializationHandler::ConvertI32toU8({536870912}, uint8Data);
429 tensors.push_back(
new TosaSerializationTensor(inputNameConst8a, singleValueShape, DType_INT32,uint8Data));
430 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst8a}));
432 auto* subOp4 =
new TosaSerializationOperator(Op_SUB,
435 {inputNameConst8a, outputNameMul5},
437 tensors.push_back(
new TosaSerializationTensor(outputNameSub4, reduceShape, DType_INT32, {}));
438 operators.push_back(subOp4);
440 auto* mulOp6 =
new TosaSerializationOperator(Op_MUL,
441 Attribute_MulAttribute,
443 {outputNameAdd5, outputNameSub4},
445 tensors.push_back(
new TosaSerializationTensor(outputNameMul6, reduceShape, DType_INT32, {}));
446 operators.push_back(mulOp6);
448 TosaSerializationHandler::ConvertI32toU8({4}, uint8Data);
449 tensors.push_back(
new TosaSerializationTensor(inputNameConst9a, singleValueShape, DType_INT32,uint8Data));
450 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst9a}));
452 auto* mulOp7 =
new TosaSerializationOperator(Op_MUL,
453 Attribute_MulAttribute,
455 {outputNameMul6, inputNameConst9a},
457 tensors.push_back(
new TosaSerializationTensor(outputNameMul7, reduceShape, DType_INT32, {}));
458 operators.push_back(mulOp7);
460 auto* addOp6 =
new TosaSerializationOperator(Op_ADD,
463 {outputNameAdd5, outputNameMul7},
465 tensors.push_back(
new TosaSerializationTensor(outputNameAdd6, reduceShape, DType_INT32, {}));
466 operators.push_back(addOp6);
469 auto* mulOp8 =
new TosaSerializationOperator(Op_MUL,
470 Attribute_MulAttribute,
472 {outputNameAdd6, outputNameLogicalL4},
474 tensors.push_back(
new TosaSerializationTensor(outputNameMul8, reduceShape, DType_INT32, {}));
475 operators.push_back(mulOp8);
477 TosaSerializationHandler::ConvertI32toU8({536870912}, uint8Data);
478 tensors.push_back(
new TosaSerializationTensor(inputNameConst8b, singleValueShape, DType_INT32,uint8Data));
479 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst8b}));
481 auto* subOp5 =
new TosaSerializationOperator(Op_SUB,
484 {inputNameConst8b, outputNameMul8},
486 tensors.push_back(
new TosaSerializationTensor(outputNameSub5, reduceShape, DType_INT32, {}));
487 operators.push_back(subOp5);
489 auto* mulOp9 =
new TosaSerializationOperator(Op_MUL,
490 Attribute_MulAttribute,
492 {outputNameAdd6, outputNameSub5},
494 tensors.push_back(
new TosaSerializationTensor(outputNameMul9, reduceShape, DType_INT32, {}));
495 operators.push_back(mulOp9);
497 TosaSerializationHandler::ConvertI32toU8({4}, uint8Data);
498 tensors.push_back(
new TosaSerializationTensor(inputNameConst9b, singleValueShape, DType_INT32,uint8Data));
499 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst9b}));
501 auto* mulOp10 =
new TosaSerializationOperator(Op_MUL,
502 Attribute_MulAttribute,
504 {outputNameMul9, inputNameConst9b},
506 tensors.push_back(
new TosaSerializationTensor(outputNameMul10, reduceShape, DType_INT32, {}));
507 operators.push_back(mulOp10);
509 auto* addOp7 =
new TosaSerializationOperator(Op_ADD,
512 {outputNameAdd6, outputNameMul10},
514 tensors.push_back(
new TosaSerializationTensor(outputNameAdd7, reduceShape, DType_INT32, {}));
515 operators.push_back(addOp7);
517 TosaMulAttribute mulAttribute3(30);
519 auto* mulOp11 =
new TosaSerializationOperator(Op_MUL,
520 Attribute_MulAttribute,
522 {outputNameAdd3, outputNameAdd7},
524 tensors.push_back(
new TosaSerializationTensor(outputNameMul11, outputShape0, DType_INT32, {}));
525 operators.push_back(mulOp11);
527 TosaSerializationHandler::ConvertI32toU8({35}, uint8Data);
528 tensors.push_back(
new TosaSerializationTensor(inputNameConst10, singleValueShape, DType_INT32,uint8Data));
529 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr, {}, {inputNameConst10}));
531 auto* subOp6 =
new TosaSerializationOperator(Op_SUB,
534 {inputNameConst10, outputNameCLZ1},
536 tensors.push_back(
new TosaSerializationTensor(outputNameSub6, reduceShape, DType_INT32, {}));
537 operators.push_back(subOp6);
539 TosaSerializationHandler::ConvertI32toU8({31}, uint8Data);
540 tensors.push_back(
new TosaSerializationTensor(inputMinConst, singleValueShape, DType_INT32,uint8Data));
541 operators.push_back(
new TosaSerializationOperator(Op_CONST, Attribute_NONE,
nullptr,{}, {inputMinConst}));
543 auto* minOp =
new TosaSerializationOperator(Op_MINIMUM,
544 Attribute_NONE,
nullptr,
545 {outputNameSub6,inputMinConst},
548 tensors.push_back(
new TosaSerializationTensor(outputShiftMin, reduceShape, DType_INT32, {}));
549 operators.push_back(minOp);
551 auto* arithmeticROp3 =
new TosaSerializationOperator(Op_ARITHMETIC_RIGHT_SHIFT,
552 Attribute_ArithmeticRightShiftAttribute,
554 {outputNameMul11, outputShiftMin},
555 {outputNameArithmeticR3});
556 tensors.push_back(
new TosaSerializationTensor(outputNameArithmeticR3, outputShape0, DType_INT32, {}));
557 operators.push_back(arithmeticROp3);
559 TosaSerializationOperator* rescaleOp3 =
nullptr;
571 tensors.push_back(
new TosaSerializationTensor(outputName, outputShape0, outputDType0, {}));
572 operators.push_back(rescaleOp3);
574 return new TosaSerializationBasicBlock(blockName,