ArmNN
 25.11
Loading...
Searching...
No Matches
ConcatOperator.cpp File Reference
#include "ConcatOperator.hpp"
Include dependency graph for ConcatOperator.cpp:

Go to the source code of this file.

Functions

TosaSerializationBasicBlock * ConvertConcatToTosaOperator (const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const OriginsDescriptor *concatDescriptor)

Function Documentation

◆ ConvertConcatToTosaOperator()

TosaSerializationBasicBlock * ConvertConcatToTosaOperator ( const Layer * layer,
const std::vector< const TensorInfo * > & inputs,
const std::vector< const TensorInfo * > & outputs,
const OriginsDescriptor * concatDescriptor )

Definition at line 8 of file ConcatOperator.cpp.

12{
// Builds a TosaSerializationBasicBlock that holds a single Op_CONCAT operator together with the
// tensors it references.
//   layer            - may be nullptr; when it is, placeholder tensor names are used (see below).
//   inputs           - one TensorInfo per concat input; shape and data type are read from each entry.
//   outputs          - only outputs[0] is read (shape and data type of the concatenated result).
//   concatDescriptor - supplies the concatenation axis.
// Returns a block allocated with `new`.
// NOTE(review): the operator, tensors and block are all raw `new` allocations and nothing visible
// here deletes them; presumably ownership passes to the returned block and then to the caller - confirm.
13 auto numInputs = inputs.size();
14 std::vector<std::string> inputNames;
15 inputNames.reserve(numInputs);
// Default output name; it is overwritten below when a layer is supplied.
16 std::string outputName = std::string("output0_");
17 std::string blockName = std::string("Op_CONCAT_block_") + GetUniqueTosaMappingID();
18
19 // No layer supplied: use placeholder input names "input_<i>" (validation purposes only).
20 if (layer == nullptr)
21 {
22 for (uint32_t i = 0; i < numInputs; ++i)
23 {
24 inputNames.push_back("input_"+ std::to_string(i));
25 }
26 }
27 // If a layer is present then the block will be used for execution, so input and output names must be derived
28 // from the previous and following layers so that the graph is connected correctly. For validation this doesn't matter.
29 else
30 {
31 // Get the layers connected to the input slots and determine unique tensor names.
32 for (uint32_t i = 0; i < numInputs; ++i)
33 {
34 std::string inputName = GenerateUniqueInputName(layer->GetInputSlot(i));
35 inputNames.push_back(inputName);
36 }
37
38 // Determine unique output tensor name.
39 outputName = GenerateUniqueOutputName(*layer);
40 }
41
// NOTE(review): concatDescriptor is dereferenced without a null check - confirm callers never pass nullptr.
// GetConcatAxis() returns unsigned int; the value is narrowed to int32_t for the TOSA axis attribute.
42 auto axis = static_cast<int32_t>(concatDescriptor->GetConcatAxis());
43 TosaAxisAttribute attribute(axis);
44
// NOTE(review): `attribute` is a stack local passed by address; presumably the operator constructor
// copies it rather than retaining the pointer - confirm against the serialization library.
45 TosaSerializationOperator* op = new TosaSerializationOperator(Op_CONCAT,
46 Attribute_AxisAttribute,
47 &attribute,
48 inputNames,
49 {outputName});
50
51 std::vector<TosaSerializationTensor*> tensors;
// Capacity for every input tensor plus the single output tensor.
52 tensors.reserve(numInputs + 1);
53 for (uint32_t i = 0; i < numInputs; ++i)
54 {
55 // Only add input tensors for validation or when the connected layer is an input layer,
56 // because duplicate tensors are not allowed and intermediate or constant tensors are created separately.
// The decision is a substring match: any name containing "input" gets a tensor entry.
// NOTE(review): inputs[i] is dereferenced without a null check - confirm entries are never nullptr.
57 if(inputNames[i].find("input") != std::string::npos)
58 {
59 std::vector<int32_t> inputShape = GetTosaTensorShape(inputs[i]->GetShape());
60 DType inputDType = ArmNNToDType(inputs[i]->GetDataType());
61 tensors.push_back(new TosaSerializationTensor(inputNames[i], inputShape, inputDType, {}));
62 }
63 }
64
// NOTE(review): outputs[0] is read without checking that outputs is non-empty or that the entry is non-null.
65 std::vector<int32_t> outputShape0 = GetTosaTensorShape(outputs[0]->GetShape());
66 DType outputDType0 = ArmNNToDType(outputs[0]->GetDataType());
67
68 TosaSerializationTensor* outputTensor0 = new TosaSerializationTensor(outputName, outputShape0, outputDType0, {});
69 tensors.push_back(outputTensor0);
70
71 // operatorInputNames/operatorOutputNames end up being the same as
72 // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings.
73 return new TosaSerializationBasicBlock(blockName, // name
74 mainName, // region name
75 {op}, // operators
76 tensors, // tensors
77 inputNames, // inputs
78 {outputName}); // outputs
79}
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot=0)
const std::string mainName
DType ArmNNToDType(const DataType &type)
std::string GenerateUniqueInputName(const armnn::InputSlot &slot)
std::string GetUniqueTosaMappingID()
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition Layer.hpp:337
unsigned int GetConcatAxis() const
Get the concatenation axis value.

References ArmNNToDType(), GenerateUniqueInputName(), GenerateUniqueOutputName(), OriginsDescriptor::GetConcatAxis(), Layer::GetInputSlot(), GetTosaTensorShape(), GetUniqueTosaMappingID(), and mainName.

Referenced by GetTosaMapping().