ArmNN
 25.11
Loading...
Searching...
No Matches
IDeserializer::DeserializerImpl Class Reference

#include <Deserializer.hpp>

Public Member Functions

armnn::INetworkPtr CreateNetworkFromBinary (const std::vector< uint8_t > &binaryContent)
 Create an input network from binary file contents.
armnn::INetworkPtr CreateNetworkFromBinary (std::istream &binaryContent)
 Create an input network from a binary input stream.
BindingPointInfo GetNetworkInputBindingInfo (unsigned int layerId, const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.
BindingPointInfo GetNetworkOutputBindingInfo (unsigned int layerId, const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.
 DeserializerImpl ()
 ~DeserializerImpl ()=default
 DeserializerImpl (const DeserializerImpl &)=delete
DeserializerImpl & operator= (const DeserializerImpl &)=delete

Static Public Member Functions

static GraphPtr LoadGraphFromBinary (const uint8_t *binaryContent, size_t len)
static TensorRawPtrVector GetInputs (const GraphPtr &graph, unsigned int layerIndex)
static TensorRawPtrVector GetOutputs (const GraphPtr &graph, unsigned int layerIndex)
static LayerBaseRawPtr GetBaseLayer (const GraphPtr &graphPtr, unsigned int layerIndex)
static int32_t GetBindingLayerInfo (const GraphPtr &graphPtr, unsigned int layerIndex)
static std::string GetLayerName (const GraphPtr &graph, unsigned int index)
static armnn::Pooling2dDescriptor GetPooling2dDescriptor (Pooling2dDescriptor pooling2dDescriptor, unsigned int layerIndex)
static armnn::Pooling3dDescriptor GetPooling3dDescriptor (Pooling3dDescriptor pooling3dDescriptor, unsigned int layerIndex)
static armnn::NormalizationDescriptor GetNormalizationDescriptor (NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex)
static armnn::LstmDescriptor GetLstmDescriptor (LstmDescriptorPtr lstmDescriptor)
static armnn::LstmInputParams GetLstmInputParams (LstmDescriptorPtr lstmDescriptor, LstmInputParamsPtr lstmInputParams)
static armnn::QLstmDescriptor GetQLstmDescriptor (QLstmDescriptorPtr qLstmDescriptorPtr)
static armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor (UnidirectionalSequenceLstmDescriptorPtr descriptor)
static armnn::TensorInfo OutputShapeOfReshape (const armnn::TensorInfo &inputTensorInfo, const std::vector< uint32_t > &targetDimsIn)

Detailed Description

Definition at line 34 of file Deserializer.hpp.

Constructor & Destructor Documentation

◆ DeserializerImpl() [1/2]

Definition at line 207 of file Deserializer.cpp.

208: m_Network(nullptr, nullptr),
209//May require LayerType_Max to be included
210m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupportedLayer)
211{
212 // register supported layers
213 m_ParserFunctions[Layer_AbsLayer] = &DeserializerImpl::ParseAbs;
214 m_ParserFunctions[Layer_ActivationLayer] = &DeserializerImpl::ParseActivation;
215 m_ParserFunctions[Layer_AdditionLayer] = &DeserializerImpl::ParseAdd;
216 m_ParserFunctions[Layer_ArgMinMaxLayer] = &DeserializerImpl::ParseArgMinMax;
217 m_ParserFunctions[Layer_BatchMatMulLayer] = &DeserializerImpl::ParseBatchMatMul;
218 m_ParserFunctions[Layer_BatchToSpaceNdLayer] = &DeserializerImpl::ParseBatchToSpaceNd;
219 m_ParserFunctions[Layer_BatchNormalizationLayer] = &DeserializerImpl::ParseBatchNormalization;
220 m_ParserFunctions[Layer_CastLayer] = &DeserializerImpl::ParseCast;
221 m_ParserFunctions[Layer_ChannelShuffleLayer] = &DeserializerImpl::ParseChannelShuffle;
222 m_ParserFunctions[Layer_ComparisonLayer] = &DeserializerImpl::ParseComparison;
223 m_ParserFunctions[Layer_ConcatLayer] = &DeserializerImpl::ParseConcat;
224 m_ParserFunctions[Layer_ConstantLayer] = &DeserializerImpl::ParseConstant;
225 m_ParserFunctions[Layer_Convolution2dLayer] = &DeserializerImpl::ParseConvolution2d;
226 m_ParserFunctions[Layer_Convolution3dLayer] = &DeserializerImpl::ParseConvolution3d;
227 m_ParserFunctions[Layer_DepthToSpaceLayer] = &DeserializerImpl::ParseDepthToSpace;
228 m_ParserFunctions[Layer_DepthwiseConvolution2dLayer] = &DeserializerImpl::ParseDepthwiseConvolution2d;
229 m_ParserFunctions[Layer_DequantizeLayer] = &DeserializerImpl::ParseDequantize;
230 m_ParserFunctions[Layer_DetectionPostProcessLayer] = &DeserializerImpl::ParseDetectionPostProcess;
231 m_ParserFunctions[Layer_DivisionLayer] = &DeserializerImpl::ParseDivision;
232 m_ParserFunctions[Layer_ElementwiseBinaryLayer] = &DeserializerImpl::ParseElementwiseBinary;
233 m_ParserFunctions[Layer_ElementwiseUnaryLayer] = &DeserializerImpl::ParseElementwiseUnary;
234 m_ParserFunctions[Layer_EqualLayer] = &DeserializerImpl::ParseEqual;
235 m_ParserFunctions[Layer_FullyConnectedLayer] = &DeserializerImpl::ParseFullyConnected;
236 m_ParserFunctions[Layer_FillLayer] = &DeserializerImpl::ParseFill;
237 m_ParserFunctions[Layer_FloorLayer] = &DeserializerImpl::ParseFloor;
238 m_ParserFunctions[Layer_GatherLayer] = &DeserializerImpl::ParseGather;
239 m_ParserFunctions[Layer_GatherNdLayer] = &DeserializerImpl::ParseGatherNd;
240 m_ParserFunctions[Layer_GreaterLayer] = &DeserializerImpl::ParseGreater;
241 m_ParserFunctions[Layer_InstanceNormalizationLayer] = &DeserializerImpl::ParseInstanceNormalization;
242 m_ParserFunctions[Layer_L2NormalizationLayer] = &DeserializerImpl::ParseL2Normalization;
243 m_ParserFunctions[Layer_LogicalBinaryLayer] = &DeserializerImpl::ParseLogicalBinary;
244 m_ParserFunctions[Layer_LogSoftmaxLayer] = &DeserializerImpl::ParseLogSoftmax;
245 m_ParserFunctions[Layer_LstmLayer] = &DeserializerImpl::ParseLstm;
246 m_ParserFunctions[Layer_MaximumLayer] = &DeserializerImpl::ParseMaximum;
247 m_ParserFunctions[Layer_MeanLayer] = &DeserializerImpl::ParseMean;
248 m_ParserFunctions[Layer_MinimumLayer] = &DeserializerImpl::ParseMinimum;
249 m_ParserFunctions[Layer_MergeLayer] = &DeserializerImpl::ParseMerge;
250 m_ParserFunctions[Layer_MergerLayer] = &DeserializerImpl::ParseConcat;
251 m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializerImpl::ParseMultiplication;
252 m_ParserFunctions[Layer_NormalizationLayer] = &DeserializerImpl::ParseNormalization;
253 m_ParserFunctions[Layer_PadLayer] = &DeserializerImpl::ParsePad;
254 m_ParserFunctions[Layer_PermuteLayer] = &DeserializerImpl::ParsePermute;
255 m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializerImpl::ParsePooling2d;
256 m_ParserFunctions[Layer_Pooling3dLayer] = &DeserializerImpl::ParsePooling3d;
257 m_ParserFunctions[Layer_PreluLayer] = &DeserializerImpl::ParsePrelu;
258 m_ParserFunctions[Layer_QLstmLayer] = &DeserializerImpl::ParseQLstm;
259 m_ParserFunctions[Layer_QuantizeLayer] = &DeserializerImpl::ParseQuantize;
260 m_ParserFunctions[Layer_QuantizedLstmLayer] = &DeserializerImpl::ParseQuantizedLstm;
261 m_ParserFunctions[Layer_RankLayer] = &DeserializerImpl::ParseRank;
262 m_ParserFunctions[Layer_ReduceLayer] = &DeserializerImpl::ParseReduce;
263 m_ParserFunctions[Layer_ReshapeLayer] = &DeserializerImpl::ParseReshape;
264 m_ParserFunctions[Layer_ResizeBilinearLayer] = &DeserializerImpl::ParseResizeBilinear;
265 m_ParserFunctions[Layer_ResizeLayer] = &DeserializerImpl::ParseResize;
266 m_ParserFunctions[Layer_ReverseV2Layer] = &DeserializerImpl::ParseReverseV2;
267 m_ParserFunctions[Layer_RsqrtLayer] = &DeserializerImpl::ParseRsqrt;
268 m_ParserFunctions[Layer_ScatterNdLayer] = &DeserializerImpl::ParseScatterNd;
269 m_ParserFunctions[Layer_ShapeLayer] = &DeserializerImpl::ParseShape;
270 m_ParserFunctions[Layer_SliceLayer] = &DeserializerImpl::ParseSlice;
271 m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializerImpl::ParseSoftmax;
272 m_ParserFunctions[Layer_SpaceToBatchNdLayer] = &DeserializerImpl::ParseSpaceToBatchNd;
273 m_ParserFunctions[Layer_SpaceToDepthLayer] = &DeserializerImpl::ParseSpaceToDepth;
274 m_ParserFunctions[Layer_SplitterLayer] = &DeserializerImpl::ParseSplitter;
275 m_ParserFunctions[Layer_StackLayer] = &DeserializerImpl::ParseStack;
276 m_ParserFunctions[Layer_StandInLayer] = &DeserializerImpl::ParseStandIn;
277 m_ParserFunctions[Layer_StridedSliceLayer] = &DeserializerImpl::ParseStridedSlice;
278 m_ParserFunctions[Layer_SubtractionLayer] = &DeserializerImpl::ParseSubtraction;
279 m_ParserFunctions[Layer_SwitchLayer] = &DeserializerImpl::ParseSwitch;
280 m_ParserFunctions[Layer_TileLayer] = &DeserializerImpl::ParseTile;
281 m_ParserFunctions[Layer_TransposeConvolution2dLayer] = &DeserializerImpl::ParseTransposeConvolution2d;
282 m_ParserFunctions[Layer_TransposeLayer] = &DeserializerImpl::ParseTranspose;
283 m_ParserFunctions[Layer_UnidirectionalSequenceLstmLayer] = &DeserializerImpl::ParseUnidirectionalSequenceLstm;
284}

References DeserializerImpl().

Referenced by DeserializerImpl(), DeserializerImpl(), GetLstmInputParams(), and operator=().

◆ ~DeserializerImpl()

~DeserializerImpl ( )
default

◆ DeserializerImpl() [2/2]

DeserializerImpl ( const DeserializerImpl & )
delete

References DeserializerImpl().

Member Function Documentation

◆ CreateNetworkFromBinary() [1/2]

INetworkPtr CreateNetworkFromBinary ( const std::vector< uint8_t > & binaryContent)

Create an input network from binary file contents.

Definition at line 878 of file Deserializer.cpp.

879{
880 ResetParser();
881 GraphPtr graph = LoadGraphFromBinary(binaryContent.data(), binaryContent.size());
882 return CreateNetworkFromGraph(graph);
883}
const armnnSerializer::SerializedGraph * GraphPtr

References LoadGraphFromBinary().

◆ CreateNetworkFromBinary() [2/2]

armnn::INetworkPtr CreateNetworkFromBinary ( std::istream & binaryContent)

Create an input network from a binary input stream.

Definition at line 885 of file Deserializer.cpp.

886{
887 ResetParser();
888 if (binaryContent.fail()) {
889 ARMNN_LOG(error) << (std::string("Cannot read input"));
890 throw ParseException("Unable to read Input stream data");
891 }
892 binaryContent.seekg(0, std::ios::end);
893 const std::streamoff size = binaryContent.tellg();
894 std::vector<char> content(static_cast<size_t>(size));
895 binaryContent.seekg(0);
896 binaryContent.read(content.data(), static_cast<std::streamsize>(size));
897 GraphPtr graph = LoadGraphFromBinary(reinterpret_cast<uint8_t*>(content.data()), static_cast<size_t>(size));
898 return CreateNetworkFromGraph(graph);
899}
#define ARMNN_LOG(severity)
Definition Logging.hpp:212

References ARMNN_LOG, armnn::error, and LoadGraphFromBinary().

◆ GetBaseLayer()

LayerBaseRawPtr GetBaseLayer ( const GraphPtr & graphPtr,
unsigned int layerIndex )
static

Definition at line 286 of file Deserializer.cpp.

287{
288 auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
289
290 switch(layerType)
291 {
292 case Layer::Layer_AbsLayer:
293 return graphPtr->layers()->Get(layerIndex)->layer_as_AbsLayer()->base();
294 case Layer::Layer_ActivationLayer:
295 return graphPtr->layers()->Get(layerIndex)->layer_as_ActivationLayer()->base();
296 case Layer::Layer_AdditionLayer:
297 return graphPtr->layers()->Get(layerIndex)->layer_as_AdditionLayer()->base();
298 case Layer::Layer_ArgMinMaxLayer:
299 return graphPtr->layers()->Get(layerIndex)->layer_as_ArgMinMaxLayer()->base();
300 case Layer::Layer_BatchMatMulLayer:
301 return graphPtr->layers()->Get(layerIndex)->layer_as_BatchMatMulLayer()->base();
302 case Layer::Layer_BatchToSpaceNdLayer:
303 return graphPtr->layers()->Get(layerIndex)->layer_as_BatchToSpaceNdLayer()->base();
304 case Layer::Layer_BatchNormalizationLayer:
305 return graphPtr->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer()->base();
306 case Layer::Layer_CastLayer:
307 return graphPtr->layers()->Get(layerIndex)->layer_as_CastLayer()->base();
308 case Layer::Layer_ChannelShuffleLayer:
309 return graphPtr->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->base();
310 case Layer::Layer_ComparisonLayer:
311 return graphPtr->layers()->Get(layerIndex)->layer_as_ComparisonLayer()->base();
312 case Layer::Layer_ConcatLayer:
313 return graphPtr->layers()->Get(layerIndex)->layer_as_ConcatLayer()->base();
314 case Layer::Layer_ConstantLayer:
315 return graphPtr->layers()->Get(layerIndex)->layer_as_ConstantLayer()->base();
316 case Layer::Layer_Convolution2dLayer:
317 return graphPtr->layers()->Get(layerIndex)->layer_as_Convolution2dLayer()->base();
318 case Layer::Layer_Convolution3dLayer:
319 return graphPtr->layers()->Get(layerIndex)->layer_as_Convolution3dLayer()->base();
320 case Layer::Layer_DepthToSpaceLayer:
321 return graphPtr->layers()->Get(layerIndex)->layer_as_DepthToSpaceLayer()->base();
322 case Layer::Layer_DepthwiseConvolution2dLayer:
323 return graphPtr->layers()->Get(layerIndex)->layer_as_DepthwiseConvolution2dLayer()->base();
324 case Layer::Layer_DequantizeLayer:
325 return graphPtr->layers()->Get(layerIndex)->layer_as_DequantizeLayer()->base();
326 case Layer::Layer_DetectionPostProcessLayer:
327 return graphPtr->layers()->Get(layerIndex)->layer_as_DetectionPostProcessLayer()->base();
328 case Layer::Layer_DivisionLayer:
329 return graphPtr->layers()->Get(layerIndex)->layer_as_DivisionLayer()->base();
330 case Layer::Layer_EqualLayer:
331 return graphPtr->layers()->Get(layerIndex)->layer_as_EqualLayer()->base();
332 case Layer::Layer_ElementwiseBinaryLayer:
333 return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseBinaryLayer()->base();
334 case Layer::Layer_ElementwiseUnaryLayer:
335 return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseUnaryLayer()->base();
336 case Layer::Layer_FullyConnectedLayer:
337 return graphPtr->layers()->Get(layerIndex)->layer_as_FullyConnectedLayer()->base();
338 case Layer::Layer_FillLayer:
339 return graphPtr->layers()->Get(layerIndex)->layer_as_FillLayer()->base();
340 case Layer::Layer_FloorLayer:
341 return graphPtr->layers()->Get(layerIndex)->layer_as_FloorLayer()->base();
342 case Layer::Layer_GatherLayer:
343 return graphPtr->layers()->Get(layerIndex)->layer_as_GatherLayer()->base();
344 case Layer::Layer_GatherNdLayer:
345 return graphPtr->layers()->Get(layerIndex)->layer_as_GatherNdLayer()->base();
346 case Layer::Layer_GreaterLayer:
347 return graphPtr->layers()->Get(layerIndex)->layer_as_GreaterLayer()->base();
348 case Layer::Layer_InputLayer:
349 return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base();
350 case Layer::Layer_InstanceNormalizationLayer:
351 return graphPtr->layers()->Get(layerIndex)->layer_as_InstanceNormalizationLayer()->base();
352 case Layer::Layer_L2NormalizationLayer:
353 return graphPtr->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer()->base();
354 case Layer::Layer_LogicalBinaryLayer:
355 return graphPtr->layers()->Get(layerIndex)->layer_as_LogicalBinaryLayer()->base();
356 case Layer::Layer_LogSoftmaxLayer:
357 return graphPtr->layers()->Get(layerIndex)->layer_as_LogSoftmaxLayer()->base();
358 case Layer::Layer_LstmLayer:
359 return graphPtr->layers()->Get(layerIndex)->layer_as_LstmLayer()->base();
360 case Layer::Layer_MeanLayer:
361 return graphPtr->layers()->Get(layerIndex)->layer_as_MeanLayer()->base();
362 case Layer::Layer_MinimumLayer:
363 return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base();
364 case Layer::Layer_MaximumLayer:
365 return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base();
366 case Layer::Layer_MergeLayer:
367 return graphPtr->layers()->Get(layerIndex)->layer_as_MergeLayer()->base();
368 case Layer::Layer_MergerLayer:
369 return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base();
370 case Layer::Layer_MultiplicationLayer:
371 return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base();
372 case Layer::Layer_NormalizationLayer:
373 return graphPtr->layers()->Get(layerIndex)->layer_as_NormalizationLayer()->base();
374 case Layer::Layer_OutputLayer:
375 return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
376 case Layer::Layer_PadLayer:
377 return graphPtr->layers()->Get(layerIndex)->layer_as_PadLayer()->base();
378 case Layer::Layer_PermuteLayer:
379 return graphPtr->layers()->Get(layerIndex)->layer_as_PermuteLayer()->base();
380 case Layer::Layer_Pooling2dLayer:
381 return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
382 case Layer::Layer_Pooling3dLayer:
383 return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling3dLayer()->base();
384 case Layer::Layer_PreluLayer:
385 return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
386 case Layer::Layer_QLstmLayer:
387 return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base();
388 case Layer::Layer_QuantizeLayer:
389 return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
390 case Layer::Layer_QuantizedLstmLayer:
391 return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizedLstmLayer()->base();
392 case Layer::Layer_RankLayer:
393 return graphPtr->layers()->Get(layerIndex)->layer_as_RankLayer()->base();
394 case Layer::Layer_ReduceLayer:
395 return graphPtr->layers()->Get(layerIndex)->layer_as_ReduceLayer()->base();
396 case Layer::Layer_ReshapeLayer:
397 return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base();
398 case Layer::Layer_ResizeBilinearLayer:
399 return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeBilinearLayer()->base();
400 case Layer::Layer_ResizeLayer:
401 return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeLayer()->base();
402 case Layer::Layer_ReverseV2Layer:
403 return graphPtr->layers()->Get(layerIndex)->layer_as_ReverseV2Layer()->base();
404 case Layer::Layer_RsqrtLayer:
405 return graphPtr->layers()->Get(layerIndex)->layer_as_RsqrtLayer()->base();
406 case Layer::Layer_ScatterNdLayer:
407 return graphPtr->layers()->Get(layerIndex)->layer_as_ScatterNdLayer()->base();
408 case Layer::Layer_ShapeLayer:
409 return graphPtr->layers()->Get(layerIndex)->layer_as_ShapeLayer()->base();
410 case Layer::Layer_SliceLayer:
411 return graphPtr->layers()->Get(layerIndex)->layer_as_SliceLayer()->base();
412 case Layer::Layer_SoftmaxLayer:
413 return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
414 case Layer::Layer_SpaceToBatchNdLayer:
415 return graphPtr->layers()->Get(layerIndex)->layer_as_SpaceToBatchNdLayer()->base();
416 case Layer::Layer_SpaceToDepthLayer:
417 return graphPtr->layers()->Get(layerIndex)->layer_as_SpaceToDepthLayer()->base();
418 case Layer::Layer_SplitterLayer:
419 return graphPtr->layers()->Get(layerIndex)->layer_as_SplitterLayer()->base();
420 case Layer::Layer_StackLayer:
421 return graphPtr->layers()->Get(layerIndex)->layer_as_StackLayer()->base();
422 case Layer::Layer_StandInLayer:
423 return graphPtr->layers()->Get(layerIndex)->layer_as_StandInLayer()->base();
424 case Layer::Layer_StridedSliceLayer:
425 return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base();
426 case Layer::Layer_SubtractionLayer:
427 return graphPtr->layers()->Get(layerIndex)->layer_as_SubtractionLayer()->base();
428 case Layer::Layer_SwitchLayer:
429 return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base();
430 case Layer::Layer_TileLayer:
431 return graphPtr->layers()->Get(layerIndex)->layer_as_TileLayer()->base();
432 case Layer::Layer_TransposeConvolution2dLayer:
433 return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer()->base();
434 case Layer::Layer_TransposeLayer:
435 return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeLayer()->base();
436 case Layer::Layer_UnidirectionalSequenceLstmLayer:
437 return graphPtr->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer()->base();
438 case Layer::Layer_NONE:
439 default:
440 throw ParseException(fmt::format("Layer type {} not recognized", layerType));
441 }
442}

Referenced by GetInputs(), GetLayerName(), and GetOutputs().

◆ GetBindingLayerInfo()

int32_t GetBindingLayerInfo ( const GraphPtr & graphPtr,
unsigned int layerIndex )
static

Definition at line 451 of file Deserializer.cpp.

452{
453 auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
454
455 if (layerType == Layer::Layer_InputLayer)
456 {
457 return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->layerBindingId();
458 }
459 else if ( layerType == Layer::Layer_OutputLayer )
460 {
461 return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->layerBindingId();
462 }
463 return 0;
464}

◆ GetInputs()

TensorRawPtrVector GetInputs ( const GraphPtr & graph,
unsigned int layerIndex )
static

Definition at line 827 of file Deserializer.cpp.

828{
829 CHECK_LAYERS(graphPtr, 0, layerIndex);
830 auto layer = GetBaseLayer(graphPtr, layerIndex);
831 const auto& numInputs = layer->inputSlots()->size();
832
833 TensorRawPtrVector result(numInputs);
834
835 for (unsigned int i=0; i<numInputs; ++i)
836 {
837 auto inputId = CHECKED_NON_NEGATIVE(static_cast<int32_t>
838 (layer->inputSlots()->Get(i)->connection()->sourceLayerIndex()));
839 result[i] = GetBaseLayer(graphPtr, inputId)->outputSlots()->Get(0)->tensorInfo();
840 }
841 return result;
842}
#define CHECK_LAYERS(GRAPH, LAYERS_INDEX, LAYER_INDEX)
#define CHECKED_NON_NEGATIVE(VALUE)
std::vector< TensorRawPtr > TensorRawPtrVector

References CHECK_LAYERS, CHECKED_NON_NEGATIVE, and GetBaseLayer().

◆ GetLayerName()

std::string GetLayerName ( const GraphPtr & graph,
unsigned int index )
static

Definition at line 444 of file Deserializer.cpp.

445{
446 auto layer = GetBaseLayer(graph, index);
447 assert(layer);
448 return layer->layerName()->str();
449}

References GetBaseLayer().

◆ GetLstmDescriptor()

armnn::LstmDescriptor GetLstmDescriptor ( LstmDescriptorPtr lstmDescriptor)
static

Definition at line 3309 of file Deserializer.cpp.

3310{
3311 armnn::LstmDescriptor desc;
3312
3313 desc.m_ActivationFunc = lstmDescriptor->activationFunc();
3314 desc.m_ClippingThresCell = lstmDescriptor->clippingThresCell();
3315 desc.m_ClippingThresProj = lstmDescriptor->clippingThresProj();
3316 desc.m_CifgEnabled = lstmDescriptor->cifgEnabled();
3317 desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled();
3318 desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled();
3319 desc.m_LayerNormEnabled = lstmDescriptor->layerNormEnabled();
3320
3321 return desc;
3322}
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
float m_ClippingThresCell
Clipping threshold value for the cell state.
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_ClippingThresProj
Clipping threshold value for the projection.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
uint32_t m_ActivationFunc
The activation function to use.

References LstmDescriptor::m_ActivationFunc, LstmDescriptor::m_CifgEnabled, LstmDescriptor::m_ClippingThresCell, LstmDescriptor::m_ClippingThresProj, LstmDescriptor::m_LayerNormEnabled, LstmDescriptor::m_PeepholeEnabled, and LstmDescriptor::m_ProjectionEnabled.

◆ GetLstmInputParams()

armnn::LstmInputParams GetLstmInputParams ( LstmDescriptorPtr lstmDescriptor,
LstmInputParamsPtr lstmInputParams )
static

References DeserializerImpl().

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( unsigned int layerId,
const std::string & name ) const

Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

Definition at line 963 of file Deserializer.cpp.

965{
966 IgnoreUnused(layerIndex);
967 for (auto inputBinding : m_InputBindings)
968 {
969 if (inputBinding.first == name)
970 {
971 return inputBinding.second;
972 }
973 }
974 throw ParseException(fmt::format("No input binding found for layer:{0} / {1}",
975 name,
976 CHECK_LOCATION().AsString()));
977}
#define CHECK_LOCATION()
void IgnoreUnused(Ts &&...)

References CHECK_LOCATION, and armnn::IgnoreUnused().

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( unsigned int layerId,
const std::string & name ) const

Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

Definition at line 979 of file Deserializer.cpp.

981{
982 IgnoreUnused(layerIndex);
983 for (auto outputBinding : m_OutputBindings)
984 {
985 if (outputBinding.first == name)
986 {
987 return outputBinding.second;
988 }
989 }
990 throw ParseException(fmt::format("No output binding found for layer:{0} / {1}",
991 name,
992 CHECK_LOCATION().AsString()));
993}

References CHECK_LOCATION, and armnn::IgnoreUnused().

◆ GetNormalizationDescriptor()

armnn::NormalizationDescriptor GetNormalizationDescriptor ( NormalizationDescriptorPtr normalizationDescriptor,
unsigned int layerIndex )
static

Definition at line 2974 of file Deserializer.cpp.

2977{
2978 IgnoreUnused(layerIndex);
2979 armnn::NormalizationDescriptor desc;
2980
2981 switch (normalizationDescriptor->normChannelType())
2982 {
2983 case NormalizationAlgorithmChannel_Across:
2984 {
2986 break;
2987 }
2988 case NormalizationAlgorithmChannel_Within:
2989 {
2991 break;
2992 }
2993 default:
2994 {
2995 throw ParseException("Unsupported normalization channel type");
2996 }
2997 }
2998
2999 switch (normalizationDescriptor->normMethodType())
3000 {
3001 case NormalizationAlgorithmMethod_LocalBrightness:
3002 {
3004 break;
3005 }
3006 case NormalizationAlgorithmMethod_LocalContrast:
3007 {
3009 break;
3010 }
3011 default:
3012 {
3013 throw ParseException("Unsupported normalization method type");
3014 }
3015 }
3016
3017 switch (normalizationDescriptor->dataLayout())
3018 {
3019 case DataLayout_NCHW:
3020 {
3022 break;
3023 }
3024 case DataLayout_NHWC:
3025 {
3027 break;
3028 }
3029 default:
3030 {
3031 throw ParseException("Unsupported data layout");
3032 }
3033 }
3034
3035 desc.m_Alpha = normalizationDescriptor->alpha();
3036 desc.m_Beta = normalizationDescriptor->beta();
3037 desc.m_K = normalizationDescriptor->k();
3038 desc.m_NormSize = normalizationDescriptor->normSize();
3039
3040 return desc;
3041}
@ LocalContrast
Jarrett 2009: Local Contrast Normalization.
Definition Types.hpp:219
@ LocalBrightness
Krizhevsky 2012: Local Brightness Normalization.
Definition Types.hpp:217
NormalizationAlgorithmMethod m_NormMethodType
Normalization method algorithm to use (LocalBrightness, LocalContrast).
float m_Alpha
Alpha value for the normalization equation.
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
float m_Beta
Beta value for the normalization equation.
float m_K
Kappa value used for the across channel normalization equation.
uint32_t m_NormSize
Depth radius value.
NormalizationAlgorithmChannel m_NormChannelType
Normalization channel algorithm to use (Across, Within).

References armnn::Across, armnn::IgnoreUnused(), armnn::LocalBrightness, armnn::LocalContrast, NormalizationDescriptor::m_Alpha, NormalizationDescriptor::m_Beta, NormalizationDescriptor::m_DataLayout, NormalizationDescriptor::m_K, NormalizationDescriptor::m_NormChannelType, NormalizationDescriptor::m_NormMethodType, NormalizationDescriptor::m_NormSize, armnn::NCHW, armnn::NHWC, and armnn::Within.

◆ GetOutputs()

TensorRawPtrVector GetOutputs ( const GraphPtr & graph,
unsigned int layerIndex )
static

Definition at line 844 of file Deserializer.cpp.

845{
846 CHECK_LAYERS(graphPtr, 0, layerIndex);
847 auto layer = GetBaseLayer(graphPtr, layerIndex);
848 const auto& numOutputs = layer->outputSlots()->size();
849
850 TensorRawPtrVector result(numOutputs);
851
852 for (unsigned int i=0; i<numOutputs; ++i)
853 {
854 result[i] = layer->outputSlots()->Get(i)->tensorInfo();
855 }
856 return result;
857}

References CHECK_LAYERS, and GetBaseLayer().

◆ GetPooling2dDescriptor()

armnn::Pooling2dDescriptor GetPooling2dDescriptor ( Pooling2dDescriptor pooling2dDescriptor,
unsigned int layerIndex )
static

Definition at line 2385 of file Deserializer.cpp.

2387{
2388 IgnoreUnused(layerIndex);
2389 armnn::Pooling2dDescriptor desc;
2390
2391 switch (pooling2dDesc->poolType())
2392 {
2393 case PoolingAlgorithm_Average:
2394 {
2396 break;
2397 }
2398 case PoolingAlgorithm_Max:
2399 {
2401 break;
2402 }
2403 case PoolingAlgorithm_L2:
2404 {
2406 break;
2407 }
2408 default:
2409 {
2410 throw ParseException("Unsupported pooling algorithm");
2411 }
2412 }
2413
2414 switch (pooling2dDesc->outputShapeRounding())
2415 {
2416 case OutputShapeRounding_Floor:
2417 {
2419 break;
2420 }
2421 case OutputShapeRounding_Ceiling:
2422 {
2424 break;
2425 }
2426 default:
2427 {
2428 throw ParseException("Unsupported output shape rounding");
2429 }
2430 }
2431
2432 switch (pooling2dDesc->paddingMethod())
2433 {
2434 case PaddingMethod_Exclude:
2435 {
2437 break;
2438 }
2439 case PaddingMethod_IgnoreValue:
2440 {
2442 break;
2443 }
2444 default:
2445 {
2446 throw ParseException("Unsupported padding method");
2447 }
2448 }
2449
2450 switch (pooling2dDesc->dataLayout())
2451 {
2452 case DataLayout_NCHW:
2453 {
2455 break;
2456 }
2457 case DataLayout_NHWC:
2458 {
2460 break;
2461 }
2462 default:
2463 {
2464 throw ParseException("Unsupported data layout");
2465 }
2466 }
2467
2468 desc.m_PadRight = pooling2dDesc->padRight();
2469 desc.m_PadLeft = pooling2dDesc->padLeft();
2470 desc.m_PadBottom = pooling2dDesc->padBottom();
2471 desc.m_PadTop = pooling2dDesc->padTop();
2472 desc.m_StrideX = pooling2dDesc->strideX();
2473 desc.m_StrideY = pooling2dDesc->strideY();
2474 desc.m_PoolWidth = pooling2dDesc->poolWidth();
2475 desc.m_PoolHeight = pooling2dDesc->poolHeight();
2476
2477 return desc;
2478}
@ Exclude
The padding fields don't count and are ignored.
Definition Types.hpp:194
@ IgnoreValue
The padding fields count, but are ignored.
Definition Types.hpp:192
uint32_t m_PadRight
Padding right value in the width dimension.
PoolingAlgorithm m_PoolType
The pooling algorithm to use (Max, Average, L2).
uint32_t m_PoolHeight
Pooling height value.
uint32_t m_PadTop
Padding top value in the height dimension.
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
uint32_t m_PoolWidth
Pooling width value.
PaddingMethod m_PaddingMethod
The padding method to be used. (Exclude, IgnoreValue).
uint32_t m_PadBottom
Padding bottom value in the height dimension.
uint32_t m_PadLeft
Padding left value in the width dimension.
uint32_t m_StrideY
Stride value when proceeding through input for the height dimension.
uint32_t m_StrideX
Stride value when proceeding through input for the width dimension.
OutputShapeRounding m_OutputShapeRounding
The rounding method for the output shape. (Floor, Ceiling).

References armnn::Average, armnn::Ceiling, armnn::Exclude, armnn::Floor, armnn::IgnoreUnused(), armnn::IgnoreValue, armnn::L2, Pooling2dDescriptor::m_DataLayout, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, armnn::Max, armnn::NCHW, and armnn::NHWC.

◆ GetPooling3dDescriptor()

armnn::Pooling3dDescriptor GetPooling3dDescriptor ( Pooling3dDescriptor pooling3dDescriptor,
unsigned int layerIndex )
static

Definition at line 2480 of file Deserializer.cpp.

2482{
2483 IgnoreUnused(layerIndex);
2484 armnn::Pooling3dDescriptor desc;
2485
2486 switch (pooling3dDesc->poolType())
2487 {
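2488 case PoolingAlgorithm_Average:
2489 {
2490 desc.m_PoolType = armnn::PoolingAlgorithm::Average;
2491 break;
2492 }
2493 case PoolingAlgorithm_Max:
2494 {
2495 desc.m_PoolType = armnn::PoolingAlgorithm::Max;
2496 break;
2497 }
2498 case PoolingAlgorithm_L2:
2499 {
2500 desc.m_PoolType = armnn::PoolingAlgorithm::L2;
2501 break;
2502 }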
2503 default:
2504 {
2505 throw ParseException("Unsupported pooling algorithm");
2506 }
2507 }
2508
2509 switch (pooling3dDesc->outputShapeRounding())
2510 {
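2511 case OutputShapeRounding_Floor:
2512 {
2513 desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
2514 break;
2515 }
2516 case OutputShapeRounding_Ceiling:
2517 {
2518 desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Ceiling;
2519 break;
2520 }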
2521 default:
2522 {
2523 throw ParseException("Unsupported output shape rounding");
2524 }
2525 }
2526
2527 switch (pooling3dDesc->paddingMethod())
2528 {
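2529 case PaddingMethod_Exclude:
2530 {
2531 desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
2532 break;
2533 }
2534 case PaddingMethod_IgnoreValue:
2535 {
2536 desc.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
2537 break;
2538 }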
2539 default:
2540 {
2541 throw ParseException("Unsupported padding method");
2542 }
2543 }
2544
2545 switch (pooling3dDesc->dataLayout())
2546 {
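2547 case DataLayout_NCDHW:
2548 {
2549 desc.m_DataLayout = armnn::DataLayout::NCDHW;
2550 break;
2551 }
2552 case DataLayout_NDHWC:
2553 {
2554 desc.m_DataLayout = armnn::DataLayout::NDHWC;
2555 break;
2556 }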
2557 default:
2558 {
2559 throw ParseException("Unsupported data layout");
2560 }
2561 }
2562
2563 desc.m_PadRight = pooling3dDesc->padRight();
2564 desc.m_PadLeft = pooling3dDesc->padLeft();
2565 desc.m_PadBottom = pooling3dDesc->padBottom();
2566 desc.m_PadTop = pooling3dDesc->padTop();
2567 desc.m_PadFront = pooling3dDesc->padFront();
2568 desc.m_PadBack = pooling3dDesc->padBack();
2569 desc.m_StrideX = pooling3dDesc->strideX();
2570 desc.m_StrideY = pooling3dDesc->strideY();
2571 desc.m_StrideZ = pooling3dDesc->strideZ();
2572 desc.m_PoolWidth = pooling3dDesc->poolWidth();
2573 desc.m_PoolHeight = pooling3dDesc->poolHeight();
2574 desc.m_PoolDepth = pooling3dDesc->poolDepth();
2575
2576 return desc;
2577}
uint32_t m_PadRight
Padding right value in the width dimension.
PoolingAlgorithm m_PoolType
The pooling algorithm to use (Max, Average, L2).
uint32_t m_PadBack
Padding back value in the depth dimension.
uint32_t m_StrideZ
Stride value when proceeding through input for the depth dimension.
uint32_t m_PoolHeight
Pooling height value.
uint32_t m_PadTop
Padding top value in the height dimension.
DataLayout m_DataLayout
The data layout to be used (NCDHW, NDHWC).
uint32_t m_PoolWidth
Pooling width value.
uint32_t m_PadFront
Padding front value in the depth dimension.
PaddingMethod m_PaddingMethod
The padding method to be used. (Exclude, IgnoreValue).
uint32_t m_PadBottom
Padding bottom value in the height dimension.
uint32_t m_PadLeft
Padding left value in the width dimension.
uint32_t m_StrideY
Stride value when proceeding through input for the height dimension.
uint32_t m_PoolDepth
Pooling depth value.
uint32_t m_StrideX
Stride value when proceeding through input for the width dimension.
OutputShapeRounding m_OutputShapeRounding
The rounding method for the output shape. (Floor, Ceiling).

References armnn::Average, armnn::Ceiling, armnn::Exclude, armnn::Floor, armnn::IgnoreUnused(), armnn::IgnoreValue, armnn::L2, Pooling3dDescriptor::m_DataLayout, Pooling3dDescriptor::m_OutputShapeRounding, Pooling3dDescriptor::m_PadBack, Pooling3dDescriptor::m_PadBottom, Pooling3dDescriptor::m_PaddingMethod, Pooling3dDescriptor::m_PadFront, Pooling3dDescriptor::m_PadLeft, Pooling3dDescriptor::m_PadRight, Pooling3dDescriptor::m_PadTop, Pooling3dDescriptor::m_PoolDepth, Pooling3dDescriptor::m_PoolHeight, Pooling3dDescriptor::m_PoolType, Pooling3dDescriptor::m_PoolWidth, Pooling3dDescriptor::m_StrideX, Pooling3dDescriptor::m_StrideY, Pooling3dDescriptor::m_StrideZ, armnn::Max, armnn::NCDHW, and armnn::NDHWC.

◆ GetQLstmDescriptor()

armnn::QLstmDescriptor GetQLstmDescriptor ( QLstmDescriptorPtr qLstmDescriptorPtr)
static

Definition at line 3440 of file Deserializer.cpp.

3441{
3442 armnn::QLstmDescriptor desc;
3443
3444 desc.m_CifgEnabled = qLstmDescriptor->cifgEnabled();
3445 desc.m_PeepholeEnabled = qLstmDescriptor->peepholeEnabled();
3446 desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled();
3447 desc.m_LayerNormEnabled = qLstmDescriptor->layerNormEnabled();
3448
3449 desc.m_CellClip = qLstmDescriptor->cellClip();
3450 desc.m_ProjectionClip = qLstmDescriptor->projectionClip();
3451
3452 desc.m_InputIntermediateScale = qLstmDescriptor->inputIntermediateScale();
3453 desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale();
3454 desc.m_CellIntermediateScale = qLstmDescriptor->cellIntermediateScale();
3455 desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale();
3456
3457 desc.m_HiddenStateScale = qLstmDescriptor->hiddenStateScale();
3458 desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint();
3459
3460 return desc;
3461}
float m_CellIntermediateScale
Cell intermediate quantization scale.
float m_InputIntermediateScale
Input intermediate quantization scale.
bool m_PeepholeEnabled
Enable/disable peephole.
int32_t m_HiddenStateZeroPoint
Hidden State zero point.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_OutputIntermediateScale
Output intermediate quantization scale.
float m_ProjectionClip
Clipping threshold value for the projection.
float m_CellClip
Clipping threshold value for the cell state.
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
float m_HiddenStateScale
Hidden State quantization scale.
float m_ForgetIntermediateScale
Forget intermediate quantization scale.

References QLstmDescriptor::m_CellClip, QLstmDescriptor::m_CellIntermediateScale, QLstmDescriptor::m_CifgEnabled, QLstmDescriptor::m_ForgetIntermediateScale, QLstmDescriptor::m_HiddenStateScale, QLstmDescriptor::m_HiddenStateZeroPoint, QLstmDescriptor::m_InputIntermediateScale, QLstmDescriptor::m_LayerNormEnabled, QLstmDescriptor::m_OutputIntermediateScale, QLstmDescriptor::m_PeepholeEnabled, QLstmDescriptor::m_ProjectionClip, and QLstmDescriptor::m_ProjectionEnabled.

◆ GetUnidirectionalSequenceLstmDescriptor()

armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor ( UnidirectionalSequenceLstmDescriptorPtr descriptor)
static

Definition at line 3899 of file Deserializer.cpp.

3901{
3902 armnn::UnidirectionalSequenceLstmDescriptor desc;
3903
3904 desc.m_ActivationFunc = descriptor->activationFunc();
3905 desc.m_ClippingThresCell = descriptor->clippingThresCell();
3906 desc.m_ClippingThresProj = descriptor->clippingThresProj();
3907 desc.m_CifgEnabled = descriptor->cifgEnabled();
3908 desc.m_PeepholeEnabled = descriptor->peepholeEnabled();
3909 desc.m_ProjectionEnabled = descriptor->projectionEnabled();
3910 desc.m_LayerNormEnabled = descriptor->layerNormEnabled();
3911 desc.m_TimeMajor = descriptor->timeMajor();
3912
3913 return desc;
3914}
LstmDescriptor UnidirectionalSequenceLstmDescriptor
bool m_TimeMajor
Enable/disable time major.

References LstmDescriptor::m_ActivationFunc, LstmDescriptor::m_CifgEnabled, LstmDescriptor::m_ClippingThresCell, LstmDescriptor::m_ClippingThresProj, LstmDescriptor::m_LayerNormEnabled, LstmDescriptor::m_PeepholeEnabled, LstmDescriptor::m_ProjectionEnabled, and LstmDescriptor::m_TimeMajor.

◆ LoadGraphFromBinary()

GraphPtr LoadGraphFromBinary ( const uint8_t * binaryContent,
size_t len )
static

Definition at line 901 of file Deserializer.cpp.

902{
903 if (binaryContent == nullptr)
904 {
905 throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
906 CHECK_LOCATION().AsString()));
907 }
908 flatbuffers::Verifier verifier(binaryContent, len);
909 if (verifier.VerifyBuffer<SerializedGraph>() == false)
910 {
911 throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn "
912 "flatbuffers format. size:{0} {1}",
913 len,
914 CHECK_LOCATION().AsString()));
915 }
916 return GetSerializedGraph(binaryContent);
917}

References CHECK_LOCATION.

Referenced by CreateNetworkFromBinary(), and CreateNetworkFromBinary().

◆ operator=()

DeserializerImpl & operator= ( const DeserializerImpl & )
delete

References DeserializerImpl().

◆ OutputShapeOfReshape()

armnn::TensorInfo OutputShapeOfReshape ( const armnn::TensorInfo & inputTensorInfo,
const std::vector< uint32_t > & targetDimsIn )
static

Definition at line 2640 of file Deserializer.cpp.

2642{
2643 std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end());
2644 const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1);
2645
2646 if (stretchDim != targetDimsIn.end())
2647 {
2648 if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end())
2649 {
2650 throw ParseException(fmt::format("At most one component of shape can be -1 {}",
2651 CHECK_LOCATION().AsString()));
2652 }
2653
2654 auto targetNumElements =
2655 armnn::numeric_cast<unsigned int>(
2656 std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
2657
2658 auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
2659 if (targetNumElements == 0)
2660 {
2661 if (inputTensorInfo.GetNumElements() == 0)
2662 {
2663 outputDims[stretchIndex] = 0;
2664 }
2665 else
2666 {
2667 throw ParseException(
2668 fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
2669 CHECK_LOCATION().AsString()));
2670 }
2671 }
2672 else
2673 {
2674 outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
2675 }
2676 }
2677
2678 TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
2679
2680 armnn::TensorInfo reshapeInfo = inputTensorInfo;
2681 reshapeInfo.SetShape(outputShape);
2682
2683 return reshapeInfo;
2684}
unsigned int GetNumElements() const
Definition Tensor.hpp:198
void SetShape(const TensorShape &newShape)
Definition Tensor.hpp:195
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)

References CHECK_LOCATION, TensorInfo::GetNumElements(), armnn::numeric_cast(), and TensorInfo::SetShape().


The documentation for this class was generated from the following files: