ArmNN 25.11 — Doxygen source listing for TfLiteParser.hpp
1//
2// Copyright © 2017-2025 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
8#include "armnn/INetwork.hpp"
10#include "armnn/Types.hpp"
11
12#include <schema_generated.h>
13#include <functional>
14#include <unordered_map>
15#include <vector>
16
17#include <tensorflow/lite/version.h>
18
19#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3)
20#define ARMNN_POST_TFLITE_2_3
21#endif
22
23namespace armnnTfLiteParser
24{
25
27{
28public:
29 // Shorthands for TfLite types
30 using ModelPtr = std::unique_ptr<tflite::ModelT>;
31 using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>;
32 using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
33 using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
34 using TensorPtr = std::unique_ptr<tflite::TensorT>;
35 using TensorRawPtr = const tflite::TensorT *;
36 using TensorRawPtrVector = std::vector<TensorRawPtr>;
37 using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
38 using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
39 using BufferPtr = std::unique_ptr<tflite::BufferT>;
40 using BufferRawPtr = const tflite::BufferT *;
41
42public:
43 /// Create the network from a flatbuffers binary file on disk
45
46 /// Create the network from a flatbuffers binary
47 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
48
49
50 /// Retrieve binding info (layer id and tensor info) for the network input identified by
51 /// the given layer name and subgraph id
53 const std::string& name) const;
54
55 /// Retrieve binding info (layer id and tensor info) for the network output identified by
56 /// the given layer name and subgraph id
58 const std::string& name) const;
59
60 /// Return the number of subgraphs in the parsed model
61 size_t GetSubgraphCount() const;
62
63 /// Return the input tensor names for a given subgraph
64 std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
65
66 /// Return the output tensor names for a given subgraph
67 std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
68
70 ~TfLiteParserImpl() = default;
71
72public:
73 // testable helpers
74 armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent);
75
76 armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model);
77
78 static ModelPtr LoadModelFromFile(const char* fileName);
79 static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len);
80 static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
81 static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
82 static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex);
83 static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex);
84 static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
85 static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
86
87 static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
88 static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
89 const armnn::TensorInfo& inputTensorInfo);
90 static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
91 const std::vector<int32_t>& targetDimsIn);
92
93 /// Retrieve version in X.Y.Z form
94 static const std::string GetVersion();
95
96private:
97
98 // No copying allowed until it is wanted and properly implemented
99 TfLiteParserImpl(const TfLiteParserImpl &) = delete;
100 TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete;
101
102 /// Create the network from an already loaded flatbuffers model
103 armnn::INetworkPtr CreateNetworkFromModel();
104
105 // signature for the parser functions
106 using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex);
107
108 void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
109 void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
110
111 void ParseAbs(size_t subgraphIndex, size_t operatorIndex);
112 void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
113 void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
114 void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction);
115 void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
116 void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
117 void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
118 void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex);
119 void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
120 void ParseBroadcastTo(size_t subgraphIndex, size_t operatorIndex);
121 void ParseCast(size_t subgraphIndex, size_t operatorIndex);
122 void ParseCeil(size_t subgraphIndex, size_t operatorIndex);
123 void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
124 void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
125 void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
126 // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed.
127 #if defined(ARMNN_POST_TFLITE_2_4)
128 void ParseConv3D(size_t subgraphIndex, size_t operatorIndex);
129 #endif
130 void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex);
131 void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
132 void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
133 void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
134 void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
135 void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation);
136 void ParseElu(size_t subgraphIndex, size_t operatorIndex);
137 void ParseEqual(size_t subgraphIndex, size_t operatorIndex);
138 void ParseExp(size_t subgraphIndex, size_t operatorIndex);
139 void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex);
140 void ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex);
141 void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
142 void ParseGather(size_t subgraphIndex, size_t operatorIndex);
143 void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex);
144 void ParseGelu(size_t subgraphIndex, size_t operatorIndex);
145 void ParseGreater(size_t subgraphIndex, size_t operatorIndex);
146 void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex);
147 void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
148 void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
149 void ParseLess(size_t subgraphIndex, size_t operatorIndex);
150 void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex);
151 void ParseLog(size_t subgraphIndex, size_t operatorIndex);
152 void ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex);
153 void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex);
154 void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
155 void ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex);
156 void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
157 void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
158 void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
159 void ParseMean(size_t subgraphIndex, size_t operatorIndex);
160 void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
161 void ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex);
162 void ParseMul(size_t subgraphIndex, size_t operatorIndex);
163 void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
164 void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex);
165 void ParsePack(size_t subgraphIndex, size_t operatorIndex);
166 void ParsePad(size_t subgraphIndex, size_t operatorIndex);
167 void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
168 void ParsePower(size_t subgraphIndex, size_t operatorIndex);
169 void ParsePrelu(size_t subgraphIndex, size_t operatorIndex);
170 void ParseQuantize(size_t subgraphIndex, size_t operatorIndex);
171 void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation);
172 void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex);
173 void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex);
174 void ParseReduceProd(size_t subgraphIndex, size_t operatorIndex);
175 void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
176 void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
177 void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
178 void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod);
179 void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
180 void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex);
181 void ParseReverseV2(size_t subgraphIndex, size_t operatorIndex);
182 void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex);
183 void ParseScatterNd(size_t subgraphIndex, size_t operatorIndex);
184 void ParseShape(size_t subgraphIndex, size_t operatorIndex);
185 void ParseSin(size_t subgraphIndex, size_t operatorIndex);
186 void ParseSlice(size_t subgraphIndex, size_t operatorIndex);
187 void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
188 void ParseSqrt(size_t subgraphIndex, size_t operatorIndex);
189 void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
190 void ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex);
191 void ParseSplit(size_t subgraphIndex, size_t operatorIndex);
192 void ParseSplitV(size_t subgraphIndex, size_t operatorIndex);
193 void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
194 void ParseSquare(size_t subgraphIndex, size_t operatorIndex);
195 void ParseSquaredDifference(size_t subgraphIndex, size_t operatorIndex);
196 void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
197 void ParseSub(size_t subgraphIndex, size_t operatorIndex);
198 void ParseSum(size_t subgraphIndex, size_t operatorIndex);
199 void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
200 void ParseTile(size_t subgraphIndex, size_t operatorIndex);
201 void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
202 void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
203 void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex);
204 void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
205
206 void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
207 void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
208 void RegisterInputSlots(size_t subgraphIndex,
209 size_t operatorIndex,
211 const std::vector<unsigned int>& tensorIndexes,
212 unsigned int startingSlotIndex = 0);
213 void RegisterOutputSlots(size_t subgraphIndex,
214 size_t operatorIndex,
216 const std::vector<unsigned int>& tensorIndexes);
217
218 void SetupInputLayerTensorInfos(size_t subgraphIndex);
219 void SetupConstantLayerTensorInfos(size_t subgraphIndex);
220
221 void SetupInputLayers(size_t subgraphIndex);
222 void SetupOutputLayers(size_t subgraphIndex);
223 void SetupConstantLayers(size_t subgraphIndex);
224
225 void ResetParser();
226
227 // Function that checks the provided buffer is valid.
228 static void ValidateBuffer(BufferRawPtr bufferPtr,
229 const armnn::TensorInfo& tensorInfo,
230 const std::string& bufferName);
231
232 void AddBroadcastReshapeLayer(size_t subgraphIndex,
233 size_t operatorIndex,
235
236 /// Attach an reshape layer to the one passed as a parameter
238 unsigned int outputSlot,
239 std::string reshapeLayerName,
240 armnn::TensorInfo outputShape);
241
242 /// Attach an activation layer to the one passed as a parameter
243 armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
244 unsigned int outputSlot,
245 tflite::ActivationFunctionType activationType);
246
247 /// Attach a floor layer to the one passed as a parameter
248 armnn::IConnectableLayer* AddFusedFloorLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot);
249
250 // SupportedDataStorage's purpose is to hold data till we pass over to the network.
251 // We don't care about the content, and we want a single datatype to simplify the code.
252 struct SupportedDataStorage
253 {
254 public:
255 // Convenience constructors
256 SupportedDataStorage(std::unique_ptr<float[]>&& data);
257 SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
258 SupportedDataStorage(std::unique_ptr<int8_t[]>&& data);
259 SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
260
261 private:
262 // Pointers to the data buffers
263 std::unique_ptr<float[]> m_FloatData;
264 std::unique_ptr<uint8_t[]> m_Uint8Data;
265 std::unique_ptr<int8_t[]> m_Int8Data;
266 std::unique_ptr<int32_t[]> m_Int32Data;
267 };
268
269 bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);
270
271 bool IsConstTensor(TensorRawPtr tensorPtr);
272
273 bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
274 armnn::DataType inputDataType,
275 armnn::DataType filterDataType);
276
277 armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
278 armnn::TensorInfo& tensorInfo);
279
280 std::pair<armnn::ConstTensor, SupportedDataStorage>
281 CreateConstTensorPermuted(TensorRawPtr tensorPtr,
282 armnn::TensorInfo& tensorInfo,
284
285 std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
286 CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
287 armnn::TensorInfo& tensorInfo,
288 armnn::DataType inputDataType);
289
290 template<typename T>
291 std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
292 CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,
294 armnn::TensorInfo& tensorInfo,
296
297 std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
298 CreateConstTensorPtr(TensorRawPtr tensorPtr,
299 armnn::TensorInfo& inputTensorInfo);
300
301 armnn::TensorInfo InputTensorInfo(size_t subgraphIndex,
302 size_t operatorIndex,
303 int input);
304
305 armnn::TensorInfo OutputTensorInfoFromInputs(size_t subgraphIndex,
306 size_t operatorIndex,
308 int output,
309 std::vector<int> inputs);
310
311 armnn::TensorInfo OutputTensorInfoFromShapes(size_t subgraphIndex,
312 size_t operatorIndex,
314 int output = 0,
315 std::vector<armnn::TensorShape> inputShapes = {});
316
317 /// Settings for configuring the TfLiteParser
319
320 /// The network we're building. Gets cleared after it is passed to the user
321 armnn::INetworkPtr m_Network;
322 ModelPtr m_Model;
323
324 std::vector<OperatorParsingFunction> m_ParserFunctions;
325 std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;
326
327 /// A mapping of an output slot to each of the input slots it should be connected to
328 /// The outputSlot is from the layer that creates this tensor as one of its ouputs
329 /// The inputSlots are from the layers that use this tensor as one of their inputs
330 struct TensorSlots
331 {
332 armnn::IOutputSlot* outputSlot;
333 std::vector<armnn::IInputSlot*> inputSlots;
334
335 TensorSlots() : outputSlot(nullptr) { }
336 };
337 typedef std::vector<TensorSlots> TensorConnections;
338 /// Connections for tensors in each subgraph
339 /// The first index is the subgraph ID, the second index is the tensor ID
340 std::vector<TensorConnections> m_SubgraphConnections;
341
342 /// This is used in case that the model does not specify the output.
343 /// The shape can be calculated from the options.
344 std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes;
345
346 std::vector<unsigned int> m_ConstantsToDequantize;
347 std::vector<unsigned int> m_ConstantsToBeCreated;
348 std::map<size_t, armnn::TensorInfo> m_TensorInfos;
349};
350
351}
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition INetwork.hpp:81
An input connection slot for a layer.
Definition INetwork.hpp:26
An output connection slot for a layer.
Definition INetwork.hpp:54
armnn::INetworkPtr LoadModel(std::unique_ptr< tflite::ModelT > model)
size_t GetSubgraphCount() const
Return the number of subgraphs in the parsed model.
static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr &model, size_t subgraphIndex)
static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr &model, size_t subgraphIndex)
static TensorRawPtrVector GetOutputs(const ModelPtr &model, size_t subgraphIndex, size_t operatorIndex)
armnn::INetworkPtr CreateNetworkFromBinary(const std::vector< uint8_t > &binaryContent)
Create the network from a flatbuffers binary.
static TensorRawPtrVector GetInputs(const ModelPtr &model, size_t subgraphIndex, size_t operatorIndex)
std::unique_ptr< tflite::TensorT > TensorPtr
BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, const std::string &name) const
Retrieve binding info (layer id and tensor info) for the network output identified by the given layer...
static BufferRawPtr GetBuffer(const ModelPtr &model, size_t bufferIndex)
std::pair< size_t, TensorRawPtr > TensorIdRawPtr
static armnn::TensorInfo OutputShapeOfSqueeze(std::vector< uint32_t > squeezeDims, const armnn::TensorInfo &inputTensorInfo)
static ModelPtr LoadModelFromBinary(const uint8_t *binaryContent, size_t len)
std::vector< TensorIdRawPtr > TensorIdRawPtrVector
static std::vector< int32_t > & GetInputTensorIds(const ModelPtr &model, size_t subgraphIndex, size_t operatorIndex)
BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, const std::string &name) const
Retrieve binding info (layer id and tensor info) for the network input identified by the given layer ...
std::vector< std::string > GetSubgraphOutputTensorNames(size_t subgraphId) const
Return the output tensor names for a given subgraph.
std::unique_ptr< tflite::SubGraphT > SubgraphPtr
static const std::string GetVersion()
Retrieve version in X.Y.Z form.
const tflite::BufferT * BufferRawPtr
std::unique_ptr< tflite::OperatorT > OperatorPtr
static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo &inputTensorInfo, const std::vector< int32_t > &targetDimsIn)
std::unique_ptr< tflite::OperatorCodeT > OperatorCodePtr
std::unique_ptr< tflite::BufferT > BufferPtr
std::vector< TensorRawPtr > TensorRawPtrVector
std::unique_ptr< tflite::ModelT > ModelPtr
const tflite::TensorT * TensorRawPtr
armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector< uint8_t > &binaryContent)
TfLiteParserImpl(const armnn::Optional< ITfLiteParser::TfLiteParserOptions > &options=armnn::EmptyOptional())
armnn::INetworkPtr CreateNetworkFromBinaryFile(const char *graphFile)
Create the network from a flatbuffers binary file on disk.
static ModelPtr LoadModelFromFile(const char *fileName)
std::vector< std::string > GetSubgraphInputTensorNames(size_t subgraphId) const
Return the input tensor names for a given subgraph.
static std::vector< int32_t > & GetOutputTensorIds(const ModelPtr &model, size_t subgraphIndex, size_t operatorIndex)
UnaryOperation
Definition Types.hpp:126
ComparisonOperation
Definition Types.hpp:110
ActivationFunction
Definition Types.hpp:87
PoolingAlgorithm
Definition Types.hpp:152
ResizeMethod
Definition Types.hpp:168
ReduceOperation
Definition Types.hpp:159
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
Definition INetwork.hpp:339
DataType
Definition Types.hpp:49
ArgMinMaxFunction
Definition Types.hpp:104
armnn::BindingPointInfo BindingPointInfo
EmptyOptional is used to initialize the Optional class in case we want to have default value for an O...
Definition Optional.hpp:32