ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
IDeserializer::DeserializerImpl Class Reference

#include <Deserializer.hpp>

Public Member Functions

armnn::INetworkPtr CreateNetworkFromBinary (const std::vector< uint8_t > &binaryContent)
 Create an input network from binary file contents. More...
 
armnn::INetworkPtr CreateNetworkFromBinary (std::istream &binaryContent)
 Create an input network from a binary input stream. More...
 
BindingPointInfo GetNetworkInputBindingInfo (unsigned int layerId, const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name. More...
 
BindingPointInfo GetNetworkOutputBindingInfo (unsigned int layerId, const std::string &name) const
 Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name. More...
 
 DeserializerImpl ()
 
 ~DeserializerImpl ()=default
 
 DeserializerImpl (const DeserializerImpl &)=delete
 
DeserializerImpl & operator= (const DeserializerImpl &)=delete
 

Static Public Member Functions

static GraphPtr LoadGraphFromBinary (const uint8_t *binaryContent, size_t len)
 
static TensorRawPtrVector GetInputs (const GraphPtr &graph, unsigned int layerIndex)
 
static TensorRawPtrVector GetOutputs (const GraphPtr &graph, unsigned int layerIndex)
 
static LayerBaseRawPtr GetBaseLayer (const GraphPtr &graphPtr, unsigned int layerIndex)
 
static int32_t GetBindingLayerInfo (const GraphPtr &graphPtr, unsigned int layerIndex)
 
static std::string GetLayerName (const GraphPtr &graph, unsigned int index)
 
static armnn::Pooling2dDescriptor GetPooling2dDescriptor (Pooling2dDescriptor pooling2dDescriptor, unsigned int layerIndex)
 
static armnn::Pooling3dDescriptor GetPooling3dDescriptor (Pooling3dDescriptor pooling3dDescriptor, unsigned int layerIndex)
 
static armnn::NormalizationDescriptor GetNormalizationDescriptor (NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex)
 
static armnn::LstmDescriptor GetLstmDescriptor (LstmDescriptorPtr lstmDescriptor)
 
static armnn::LstmInputParams GetLstmInputParams (LstmDescriptorPtr lstmDescriptor, LstmInputParamsPtr lstmInputParams)
 
static armnn::QLstmDescriptor GetQLstmDescriptor (QLstmDescriptorPtr qLstmDescriptorPtr)
 
static armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor (UnidirectionalSequenceLstmDescriptorPtr descriptor)
 
static armnn::TensorInfo OutputShapeOfReshape (const armnn::TensorInfo &inputTensorInfo, const std::vector< uint32_t > &targetDimsIn)
 

Detailed Description

Definition at line 34 of file Deserializer.hpp.

Constructor & Destructor Documentation

◆ DeserializerImpl() [1/2]

Definition at line 207 of file Deserializer.cpp.

208 : m_Network(nullptr, nullptr),
209 //May require LayerType_Max to be included
210 m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupportedLayer)
211 {
212  // register supported layers
213  m_ParserFunctions[Layer_AbsLayer] = &DeserializerImpl::ParseAbs;
214  m_ParserFunctions[Layer_ActivationLayer] = &DeserializerImpl::ParseActivation;
215  m_ParserFunctions[Layer_AdditionLayer] = &DeserializerImpl::ParseAdd;
216  m_ParserFunctions[Layer_ArgMinMaxLayer] = &DeserializerImpl::ParseArgMinMax;
217  m_ParserFunctions[Layer_BatchMatMulLayer] = &DeserializerImpl::ParseBatchMatMul;
218  m_ParserFunctions[Layer_BatchToSpaceNdLayer] = &DeserializerImpl::ParseBatchToSpaceNd;
219  m_ParserFunctions[Layer_BatchNormalizationLayer] = &DeserializerImpl::ParseBatchNormalization;
220  m_ParserFunctions[Layer_CastLayer] = &DeserializerImpl::ParseCast;
221  m_ParserFunctions[Layer_ChannelShuffleLayer] = &DeserializerImpl::ParseChannelShuffle;
222  m_ParserFunctions[Layer_ComparisonLayer] = &DeserializerImpl::ParseComparison;
223  m_ParserFunctions[Layer_ConcatLayer] = &DeserializerImpl::ParseConcat;
224  m_ParserFunctions[Layer_ConstantLayer] = &DeserializerImpl::ParseConstant;
225  m_ParserFunctions[Layer_Convolution2dLayer] = &DeserializerImpl::ParseConvolution2d;
226  m_ParserFunctions[Layer_Convolution3dLayer] = &DeserializerImpl::ParseConvolution3d;
227  m_ParserFunctions[Layer_DepthToSpaceLayer] = &DeserializerImpl::ParseDepthToSpace;
228  m_ParserFunctions[Layer_DepthwiseConvolution2dLayer] = &DeserializerImpl::ParseDepthwiseConvolution2d;
229  m_ParserFunctions[Layer_DequantizeLayer] = &DeserializerImpl::ParseDequantize;
230  m_ParserFunctions[Layer_DetectionPostProcessLayer] = &DeserializerImpl::ParseDetectionPostProcess;
231  m_ParserFunctions[Layer_DivisionLayer] = &DeserializerImpl::ParseDivision;
232  m_ParserFunctions[Layer_ElementwiseBinaryLayer] = &DeserializerImpl::ParseElementwiseBinary;
233  m_ParserFunctions[Layer_ElementwiseUnaryLayer] = &DeserializerImpl::ParseElementwiseUnary;
234  m_ParserFunctions[Layer_EqualLayer] = &DeserializerImpl::ParseEqual;
235  m_ParserFunctions[Layer_FullyConnectedLayer] = &DeserializerImpl::ParseFullyConnected;
236  m_ParserFunctions[Layer_FillLayer] = &DeserializerImpl::ParseFill;
237  m_ParserFunctions[Layer_FloorLayer] = &DeserializerImpl::ParseFloor;
238  m_ParserFunctions[Layer_GatherLayer] = &DeserializerImpl::ParseGather;
239  m_ParserFunctions[Layer_GatherNdLayer] = &DeserializerImpl::ParseGatherNd;
240  m_ParserFunctions[Layer_GreaterLayer] = &DeserializerImpl::ParseGreater;
241  m_ParserFunctions[Layer_InstanceNormalizationLayer] = &DeserializerImpl::ParseInstanceNormalization;
242  m_ParserFunctions[Layer_L2NormalizationLayer] = &DeserializerImpl::ParseL2Normalization;
243  m_ParserFunctions[Layer_LogicalBinaryLayer] = &DeserializerImpl::ParseLogicalBinary;
244  m_ParserFunctions[Layer_LogSoftmaxLayer] = &DeserializerImpl::ParseLogSoftmax;
245  m_ParserFunctions[Layer_LstmLayer] = &DeserializerImpl::ParseLstm;
246  m_ParserFunctions[Layer_MaximumLayer] = &DeserializerImpl::ParseMaximum;
247  m_ParserFunctions[Layer_MeanLayer] = &DeserializerImpl::ParseMean;
248  m_ParserFunctions[Layer_MinimumLayer] = &DeserializerImpl::ParseMinimum;
249  m_ParserFunctions[Layer_MergeLayer] = &DeserializerImpl::ParseMerge;
250  m_ParserFunctions[Layer_MergerLayer] = &DeserializerImpl::ParseConcat;
251  m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializerImpl::ParseMultiplication;
252  m_ParserFunctions[Layer_NormalizationLayer] = &DeserializerImpl::ParseNormalization;
253  m_ParserFunctions[Layer_PadLayer] = &DeserializerImpl::ParsePad;
254  m_ParserFunctions[Layer_PermuteLayer] = &DeserializerImpl::ParsePermute;
255  m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializerImpl::ParsePooling2d;
256  m_ParserFunctions[Layer_Pooling3dLayer] = &DeserializerImpl::ParsePooling3d;
257  m_ParserFunctions[Layer_PreluLayer] = &DeserializerImpl::ParsePrelu;
258  m_ParserFunctions[Layer_QLstmLayer] = &DeserializerImpl::ParseQLstm;
259  m_ParserFunctions[Layer_QuantizeLayer] = &DeserializerImpl::ParseQuantize;
260  m_ParserFunctions[Layer_QuantizedLstmLayer] = &DeserializerImpl::ParseQuantizedLstm;
261  m_ParserFunctions[Layer_RankLayer] = &DeserializerImpl::ParseRank;
262  m_ParserFunctions[Layer_ReduceLayer] = &DeserializerImpl::ParseReduce;
263  m_ParserFunctions[Layer_ReshapeLayer] = &DeserializerImpl::ParseReshape;
264  m_ParserFunctions[Layer_ResizeBilinearLayer] = &DeserializerImpl::ParseResizeBilinear;
265  m_ParserFunctions[Layer_ResizeLayer] = &DeserializerImpl::ParseResize;
266  m_ParserFunctions[Layer_ReverseV2Layer] = &DeserializerImpl::ParseReverseV2;
267  m_ParserFunctions[Layer_RsqrtLayer] = &DeserializerImpl::ParseRsqrt;
268  m_ParserFunctions[Layer_ScatterNdLayer] = &DeserializerImpl::ParseScatterNd;
269  m_ParserFunctions[Layer_ShapeLayer] = &DeserializerImpl::ParseShape;
270  m_ParserFunctions[Layer_SliceLayer] = &DeserializerImpl::ParseSlice;
271  m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializerImpl::ParseSoftmax;
272  m_ParserFunctions[Layer_SpaceToBatchNdLayer] = &DeserializerImpl::ParseSpaceToBatchNd;
273  m_ParserFunctions[Layer_SpaceToDepthLayer] = &DeserializerImpl::ParseSpaceToDepth;
274  m_ParserFunctions[Layer_SplitterLayer] = &DeserializerImpl::ParseSplitter;
275  m_ParserFunctions[Layer_StackLayer] = &DeserializerImpl::ParseStack;
276  m_ParserFunctions[Layer_StandInLayer] = &DeserializerImpl::ParseStandIn;
277  m_ParserFunctions[Layer_StridedSliceLayer] = &DeserializerImpl::ParseStridedSlice;
278  m_ParserFunctions[Layer_SubtractionLayer] = &DeserializerImpl::ParseSubtraction;
279  m_ParserFunctions[Layer_SwitchLayer] = &DeserializerImpl::ParseSwitch;
280  m_ParserFunctions[Layer_TileLayer] = &DeserializerImpl::ParseTile;
281  m_ParserFunctions[Layer_TransposeConvolution2dLayer] = &DeserializerImpl::ParseTransposeConvolution2d;
282  m_ParserFunctions[Layer_TransposeLayer] = &DeserializerImpl::ParseTranspose;
283  m_ParserFunctions[Layer_UnidirectionalSequenceLstmLayer] = &DeserializerImpl::ParseUnidirectionalSequenceLstm;
284 }

◆ ~DeserializerImpl()

~DeserializerImpl ( )
default

◆ DeserializerImpl() [2/2]

DeserializerImpl ( const DeserializerImpl & )
delete

Member Function Documentation

◆ CreateNetworkFromBinary() [1/2]

INetworkPtr CreateNetworkFromBinary ( const std::vector< uint8_t > &  binaryContent)

Create an input network from binary file contents.

Definition at line 878 of file Deserializer.cpp.

879 {
880  ResetParser();
881  GraphPtr graph = LoadGraphFromBinary(binaryContent.data(), binaryContent.size());
882  return CreateNetworkFromGraph(graph);
883 }
static GraphPtr LoadGraphFromBinary(const uint8_t *binaryContent, size_t len)
const armnnSerializer::SerializedGraph * GraphPtr

◆ CreateNetworkFromBinary() [2/2]

armnn::INetworkPtr CreateNetworkFromBinary ( std::istream &  binaryContent)

Create an input network from a binary input stream.

Definition at line 885 of file Deserializer.cpp.

886 {
887  ResetParser();
888  if (binaryContent.fail()) {
889  ARMNN_LOG(error) << (std::string("Cannot read input"));
890  throw ParseException("Unable to read Input stream data");
891  }
892  binaryContent.seekg(0, std::ios::end);
893  const std::streamoff size = binaryContent.tellg();
894  std::vector<char> content(static_cast<size_t>(size));
895  binaryContent.seekg(0);
896  binaryContent.read(content.data(), static_cast<std::streamsize>(size));
897  GraphPtr graph = LoadGraphFromBinary(reinterpret_cast<uint8_t*>(content.data()), static_cast<size_t>(size));
898  return CreateNetworkFromGraph(graph);
899 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:212

References ARMNN_LOG, and armnn::error.

◆ GetBaseLayer()

LayerBaseRawPtr GetBaseLayer ( const GraphPtr &  graphPtr,
unsigned int  layerIndex 
)
static

Definition at line 286 of file Deserializer.cpp.

287 {
288  auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
289 
290  switch(layerType)
291  {
292  case Layer::Layer_AbsLayer:
293  return graphPtr->layers()->Get(layerIndex)->layer_as_AbsLayer()->base();
294  case Layer::Layer_ActivationLayer:
295  return graphPtr->layers()->Get(layerIndex)->layer_as_ActivationLayer()->base();
296  case Layer::Layer_AdditionLayer:
297  return graphPtr->layers()->Get(layerIndex)->layer_as_AdditionLayer()->base();
298  case Layer::Layer_ArgMinMaxLayer:
299  return graphPtr->layers()->Get(layerIndex)->layer_as_ArgMinMaxLayer()->base();
300  case Layer::Layer_BatchMatMulLayer:
301  return graphPtr->layers()->Get(layerIndex)->layer_as_BatchMatMulLayer()->base();
302  case Layer::Layer_BatchToSpaceNdLayer:
303  return graphPtr->layers()->Get(layerIndex)->layer_as_BatchToSpaceNdLayer()->base();
304  case Layer::Layer_BatchNormalizationLayer:
305  return graphPtr->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer()->base();
306  case Layer::Layer_CastLayer:
307  return graphPtr->layers()->Get(layerIndex)->layer_as_CastLayer()->base();
308  case Layer::Layer_ChannelShuffleLayer:
309  return graphPtr->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->base();
310  case Layer::Layer_ComparisonLayer:
311  return graphPtr->layers()->Get(layerIndex)->layer_as_ComparisonLayer()->base();
312  case Layer::Layer_ConcatLayer:
313  return graphPtr->layers()->Get(layerIndex)->layer_as_ConcatLayer()->base();
314  case Layer::Layer_ConstantLayer:
315  return graphPtr->layers()->Get(layerIndex)->layer_as_ConstantLayer()->base();
316  case Layer::Layer_Convolution2dLayer:
317  return graphPtr->layers()->Get(layerIndex)->layer_as_Convolution2dLayer()->base();
318  case Layer::Layer_Convolution3dLayer:
319  return graphPtr->layers()->Get(layerIndex)->layer_as_Convolution3dLayer()->base();
320  case Layer::Layer_DepthToSpaceLayer:
321  return graphPtr->layers()->Get(layerIndex)->layer_as_DepthToSpaceLayer()->base();
322  case Layer::Layer_DepthwiseConvolution2dLayer:
323  return graphPtr->layers()->Get(layerIndex)->layer_as_DepthwiseConvolution2dLayer()->base();
324  case Layer::Layer_DequantizeLayer:
325  return graphPtr->layers()->Get(layerIndex)->layer_as_DequantizeLayer()->base();
326  case Layer::Layer_DetectionPostProcessLayer:
327  return graphPtr->layers()->Get(layerIndex)->layer_as_DetectionPostProcessLayer()->base();
328  case Layer::Layer_DivisionLayer:
329  return graphPtr->layers()->Get(layerIndex)->layer_as_DivisionLayer()->base();
330  case Layer::Layer_EqualLayer:
331  return graphPtr->layers()->Get(layerIndex)->layer_as_EqualLayer()->base();
332  case Layer::Layer_ElementwiseBinaryLayer:
333  return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseBinaryLayer()->base();
334  case Layer::Layer_ElementwiseUnaryLayer:
335  return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseUnaryLayer()->base();
336  case Layer::Layer_FullyConnectedLayer:
337  return graphPtr->layers()->Get(layerIndex)->layer_as_FullyConnectedLayer()->base();
338  case Layer::Layer_FillLayer:
339  return graphPtr->layers()->Get(layerIndex)->layer_as_FillLayer()->base();
340  case Layer::Layer_FloorLayer:
341  return graphPtr->layers()->Get(layerIndex)->layer_as_FloorLayer()->base();
342  case Layer::Layer_GatherLayer:
343  return graphPtr->layers()->Get(layerIndex)->layer_as_GatherLayer()->base();
344  case Layer::Layer_GatherNdLayer:
345  return graphPtr->layers()->Get(layerIndex)->layer_as_GatherNdLayer()->base();
346  case Layer::Layer_GreaterLayer:
347  return graphPtr->layers()->Get(layerIndex)->layer_as_GreaterLayer()->base();
348  case Layer::Layer_InputLayer:
349  return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base();
350  case Layer::Layer_InstanceNormalizationLayer:
351  return graphPtr->layers()->Get(layerIndex)->layer_as_InstanceNormalizationLayer()->base();
352  case Layer::Layer_L2NormalizationLayer:
353  return graphPtr->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer()->base();
354  case Layer::Layer_LogicalBinaryLayer:
355  return graphPtr->layers()->Get(layerIndex)->layer_as_LogicalBinaryLayer()->base();
356  case Layer::Layer_LogSoftmaxLayer:
357  return graphPtr->layers()->Get(layerIndex)->layer_as_LogSoftmaxLayer()->base();
358  case Layer::Layer_LstmLayer:
359  return graphPtr->layers()->Get(layerIndex)->layer_as_LstmLayer()->base();
360  case Layer::Layer_MeanLayer:
361  return graphPtr->layers()->Get(layerIndex)->layer_as_MeanLayer()->base();
362  case Layer::Layer_MinimumLayer:
363  return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base();
364  case Layer::Layer_MaximumLayer:
365  return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base();
366  case Layer::Layer_MergeLayer:
367  return graphPtr->layers()->Get(layerIndex)->layer_as_MergeLayer()->base();
368  case Layer::Layer_MergerLayer:
369  return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base();
370  case Layer::Layer_MultiplicationLayer:
371  return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base();
372  case Layer::Layer_NormalizationLayer:
373  return graphPtr->layers()->Get(layerIndex)->layer_as_NormalizationLayer()->base();
374  case Layer::Layer_OutputLayer:
375  return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
376  case Layer::Layer_PadLayer:
377  return graphPtr->layers()->Get(layerIndex)->layer_as_PadLayer()->base();
378  case Layer::Layer_PermuteLayer:
379  return graphPtr->layers()->Get(layerIndex)->layer_as_PermuteLayer()->base();
380  case Layer::Layer_Pooling2dLayer:
381  return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
382  case Layer::Layer_Pooling3dLayer:
383  return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling3dLayer()->base();
384  case Layer::Layer_PreluLayer:
385  return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
386  case Layer::Layer_QLstmLayer:
387  return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base();
388  case Layer::Layer_QuantizeLayer:
389  return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
390  case Layer::Layer_QuantizedLstmLayer:
391  return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizedLstmLayer()->base();
392  case Layer::Layer_RankLayer:
393  return graphPtr->layers()->Get(layerIndex)->layer_as_RankLayer()->base();
394  case Layer::Layer_ReduceLayer:
395  return graphPtr->layers()->Get(layerIndex)->layer_as_ReduceLayer()->base();
396  case Layer::Layer_ReshapeLayer:
397  return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base();
398  case Layer::Layer_ResizeBilinearLayer:
399  return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeBilinearLayer()->base();
400  case Layer::Layer_ResizeLayer:
401  return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeLayer()->base();
402  case Layer::Layer_ReverseV2Layer:
403  return graphPtr->layers()->Get(layerIndex)->layer_as_ReverseV2Layer()->base();
404  case Layer::Layer_RsqrtLayer:
405  return graphPtr->layers()->Get(layerIndex)->layer_as_RsqrtLayer()->base();
406  case Layer::Layer_ScatterNdLayer:
407  return graphPtr->layers()->Get(layerIndex)->layer_as_ScatterNdLayer()->base();
408  case Layer::Layer_ShapeLayer:
409  return graphPtr->layers()->Get(layerIndex)->layer_as_ShapeLayer()->base();
410  case Layer::Layer_SliceLayer:
411  return graphPtr->layers()->Get(layerIndex)->layer_as_SliceLayer()->base();
412  case Layer::Layer_SoftmaxLayer:
413  return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
414  case Layer::Layer_SpaceToBatchNdLayer:
415  return graphPtr->layers()->Get(layerIndex)->layer_as_SpaceToBatchNdLayer()->base();
416  case Layer::Layer_SpaceToDepthLayer:
417  return graphPtr->layers()->Get(layerIndex)->layer_as_SpaceToDepthLayer()->base();
418  case Layer::Layer_SplitterLayer:
419  return graphPtr->layers()->Get(layerIndex)->layer_as_SplitterLayer()->base();
420  case Layer::Layer_StackLayer:
421  return graphPtr->layers()->Get(layerIndex)->layer_as_StackLayer()->base();
422  case Layer::Layer_StandInLayer:
423  return graphPtr->layers()->Get(layerIndex)->layer_as_StandInLayer()->base();
424  case Layer::Layer_StridedSliceLayer:
425  return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base();
426  case Layer::Layer_SubtractionLayer:
427  return graphPtr->layers()->Get(layerIndex)->layer_as_SubtractionLayer()->base();
428  case Layer::Layer_SwitchLayer:
429  return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base();
430  case Layer::Layer_TileLayer:
431  return graphPtr->layers()->Get(layerIndex)->layer_as_TileLayer()->base();
432  case Layer::Layer_TransposeConvolution2dLayer:
433  return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer()->base();
434  case Layer::Layer_TransposeLayer:
435  return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeLayer()->base();
436  case Layer::Layer_UnidirectionalSequenceLstmLayer:
437  return graphPtr->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer()->base();
438  case Layer::Layer_NONE:
439  default:
440  throw ParseException(fmt::format("Layer type {} not recognized", layerType));
441  }
442 }

◆ GetBindingLayerInfo()

int32_t GetBindingLayerInfo ( const GraphPtr &  graphPtr,
unsigned int  layerIndex 
)
static

Definition at line 451 of file Deserializer.cpp.

452 {
453  auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
454 
455  if (layerType == Layer::Layer_InputLayer)
456  {
457  return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->layerBindingId();
458  }
459  else if ( layerType == Layer::Layer_OutputLayer )
460  {
461  return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->layerBindingId();
462  }
463  return 0;
464 }

◆ GetInputs()

TensorRawPtrVector GetInputs ( const GraphPtr &  graph,
unsigned int  layerIndex 
)
static

Definition at line 827 of file Deserializer.cpp.

828 {
829  CHECK_LAYERS(graphPtr, 0, layerIndex);
830  auto layer = GetBaseLayer(graphPtr, layerIndex);
831  const auto& numInputs = layer->inputSlots()->size();
832 
833  TensorRawPtrVector result(numInputs);
834 
835  for (unsigned int i=0; i<numInputs; ++i)
836  {
837  auto inputId = CHECKED_NON_NEGATIVE(static_cast<int32_t>
838  (layer->inputSlots()->Get(i)->connection()->sourceLayerIndex()));
839  result[i] = GetBaseLayer(graphPtr, inputId)->outputSlots()->Get(0)->tensorInfo();
840  }
841  return result;
842 }
#define CHECK_LAYERS(GRAPH, LAYERS_INDEX, LAYER_INDEX)
#define CHECKED_NON_NEGATIVE(VALUE)
static LayerBaseRawPtr GetBaseLayer(const GraphPtr &graphPtr, unsigned int layerIndex)
std::vector< TensorRawPtr > TensorRawPtrVector

References CHECK_LAYERS, and CHECKED_NON_NEGATIVE.

◆ GetLayerName()

std::string GetLayerName ( const GraphPtr &  graph,
unsigned int  index 
)
static

Definition at line 444 of file Deserializer.cpp.

445 {
446  auto layer = GetBaseLayer(graph, index);
447  assert(layer);
448  return layer->layerName()->str();
449 }

◆ GetLstmDescriptor()

armnn::LstmDescriptor GetLstmDescriptor ( LstmDescriptorPtr  lstmDescriptor)
static

Definition at line 3309 of file Deserializer.cpp.

3310 {
3311  armnn::LstmDescriptor desc;
3312 
3313  desc.m_ActivationFunc = lstmDescriptor->activationFunc();
3314  desc.m_ClippingThresCell = lstmDescriptor->clippingThresCell();
3315  desc.m_ClippingThresProj = lstmDescriptor->clippingThresProj();
3316  desc.m_CifgEnabled = lstmDescriptor->cifgEnabled();
3317  desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled();
3318  desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled();
3319  desc.m_LayerNormEnabled = lstmDescriptor->layerNormEnabled();
3320 
3321  return desc;
3322 }
An LstmDescriptor for the LstmLayer.
bool m_PeepholeEnabled
Enable/disable peephole.
bool m_LayerNormEnabled
Enable/disable layer normalization.
float m_ClippingThresCell
Clipping threshold value for the cell state.
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_ClippingThresProj
Clipping threshold value for the projection.
bool m_CifgEnabled
Enable/disable cifg (coupled input & forget gate).
uint32_t m_ActivationFunc
The activation function to use.

References LstmDescriptor::m_ActivationFunc, LstmDescriptor::m_CifgEnabled, LstmDescriptor::m_ClippingThresCell, LstmDescriptor::m_ClippingThresProj, LstmDescriptor::m_LayerNormEnabled, LstmDescriptor::m_PeepholeEnabled, and LstmDescriptor::m_ProjectionEnabled.

◆ GetLstmInputParams()

static armnn::LstmInputParams GetLstmInputParams ( LstmDescriptorPtr  lstmDescriptor,
LstmInputParamsPtr  lstmInputParams 
)
static

◆ GetNetworkInputBindingInfo()

BindingPointInfo GetNetworkInputBindingInfo ( unsigned int  layerId,
const std::string &  name 
) const

Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.

Definition at line 963 of file Deserializer.cpp.

965 {
966  IgnoreUnused(layerIndex);
967  for (auto inputBinding : m_InputBindings)
968  {
969  if (inputBinding.first == name)
970  {
971  return inputBinding.second;
972  }
973  }
974  throw ParseException(fmt::format("No input binding found for layer:{0} / {1}",
975  name,
976  CHECK_LOCATION().AsString()));
977 }
#define CHECK_LOCATION()
Definition: Exceptions.hpp:203
void IgnoreUnused(Ts &&...)

References CHECK_LOCATION, and armnn::IgnoreUnused().

◆ GetNetworkOutputBindingInfo()

BindingPointInfo GetNetworkOutputBindingInfo ( unsigned int  layerId,
const std::string &  name 
) const

Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.

Definition at line 979 of file Deserializer.cpp.

981 {
982  IgnoreUnused(layerIndex);
983  for (auto outputBinding : m_OutputBindings)
984  {
985  if (outputBinding.first == name)
986  {
987  return outputBinding.second;
988  }
989  }
990  throw ParseException(fmt::format("No output binding found for layer:{0} / {1}",
991  name,
992  CHECK_LOCATION().AsString()));
993 }

References CHECK_LOCATION, and armnn::IgnoreUnused().

◆ GetNormalizationDescriptor()

armnn::NormalizationDescriptor GetNormalizationDescriptor ( NormalizationDescriptorPtr  normalizationDescriptor,
unsigned int  layerIndex 
)
static

Definition at line 2974 of file Deserializer.cpp.

2977 {
2978  IgnoreUnused(layerIndex);
2979  armnn::NormalizationDescriptor desc;
2980
2981  switch (normalizationDescriptor->normChannelType())
2982  {
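2983  case NormalizationAlgorithmChannel_Across:
2984  {
2985  desc.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
2986  break;
2987  }
2988  case NormalizationAlgorithmChannel_Within:
2989  {
2990  desc.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Within;
2991  break;
2992  }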
2993  default:
2994  {
2995  throw ParseException("Unsupported normalization channel type");
2996  }
2997  }
2998 
2999  switch (normalizationDescriptor->normMethodType())
3000  {
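3001  case NormalizationAlgorithmMethod_LocalBrightness:
3002  {
3003  desc.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness;
3004  break;
3005  }
3006  case NormalizationAlgorithmMethod_LocalContrast:
3007  {
3008  desc.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalContrast;
3009  break;
3010  }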
3011  default:
3012  {
3013  throw ParseException("Unsupported normalization method type");
3014  }
3015  }
3016 
3017  switch (normalizationDescriptor->dataLayout())
3018  {
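3019  case DataLayout_NCHW:
3020  {
3021  desc.m_DataLayout = armnn::DataLayout::NCHW;
3022  break;
3023  }
3024  case DataLayout_NHWC:
3025  {
3026  desc.m_DataLayout = armnn::DataLayout::NHWC;
3027  break;
3028  }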
3029  default:
3030  {
3031  throw ParseException("Unsupported data layout");
3032  }
3033  }
3034 
3035  desc.m_Alpha = normalizationDescriptor->alpha();
3036  desc.m_Beta = normalizationDescriptor->beta();
3037  desc.m_K = normalizationDescriptor->k();
3038  desc.m_NormSize = normalizationDescriptor->normSize();
3039 
3040  return desc;
3041 }
@ LocalContrast
Jarrett 2009: Local Contrast Normalization.
@ LocalBrightness
Krizhevsky 2012: Local Brightness Normalization.
A NormalizationDescriptor for the NormalizationLayer.
NormalizationAlgorithmMethod m_NormMethodType
Normalization method algorithm to use (LocalBrightness, LocalContrast).
float m_Alpha
Alpha value for the normalization equation.
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
float m_Beta
Beta value for the normalization equation.
float m_K
Kappa value used for the across channel normalization equation.
uint32_t m_NormSize
Depth radius value.
NormalizationAlgorithmChannel m_NormChannelType
Normalization channel algorithm to use (Across, Within).

References armnn::Across, armnn::IgnoreUnused(), armnn::LocalBrightness, armnn::LocalContrast, NormalizationDescriptor::m_Alpha, NormalizationDescriptor::m_Beta, NormalizationDescriptor::m_DataLayout, NormalizationDescriptor::m_K, NormalizationDescriptor::m_NormChannelType, NormalizationDescriptor::m_NormMethodType, NormalizationDescriptor::m_NormSize, armnn::NCHW, armnn::NHWC, and armnn::Within.

◆ GetOutputs()

TensorRawPtrVector GetOutputs ( const GraphPtr &  graph,
unsigned int  layerIndex 
)
static

Definition at line 844 of file Deserializer.cpp.

845 {
846  CHECK_LAYERS(graphPtr, 0, layerIndex);
847  auto layer = GetBaseLayer(graphPtr, layerIndex);
848  const auto& numOutputs = layer->outputSlots()->size();
849 
850  TensorRawPtrVector result(numOutputs);
851 
852  for (unsigned int i=0; i<numOutputs; ++i)
853  {
854  result[i] = layer->outputSlots()->Get(i)->tensorInfo();
855  }
856  return result;
857 }

References CHECK_LAYERS.

◆ GetPooling2dDescriptor()

armnn::Pooling2dDescriptor GetPooling2dDescriptor ( Pooling2dDescriptor  pooling2dDescriptor,
unsigned int  layerIndex 
)
static

Definition at line 2385 of file Deserializer.cpp.

2387 {
2388  IgnoreUnused(layerIndex);
2389  armnn::Pooling2dDescriptor desc;
2390
2391  switch (pooling2dDesc->poolType())
2392  {
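2393  case PoolingAlgorithm_Average:
2394  {
2395  desc.m_PoolType = armnn::PoolingAlgorithm::Average;
2396  break;
2397  }
2398  case PoolingAlgorithm_Max:
2399  {
2400  desc.m_PoolType = armnn::PoolingAlgorithm::Max;
2401  break;
2402  }
2403  case PoolingAlgorithm_L2:
2404  {
2405  desc.m_PoolType = armnn::PoolingAlgorithm::L2;
2406  break;
2407  }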
2408  default:
2409  {
2410  throw ParseException("Unsupported pooling algorithm");
2411  }
2412  }
2413 
2414  switch (pooling2dDesc->outputShapeRounding())
2415  {
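2416  case OutputShapeRounding_Floor:
2417  {
2418  desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
2419  break;
2420  }
2421  case OutputShapeRounding_Ceiling:
2422  {
2423  desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Ceiling;
2424  break;
2425  }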
2426  default:
2427  {
2428  throw ParseException("Unsupported output shape rounding");
2429  }
2430  }
2431 
2432  switch (pooling2dDesc->paddingMethod())
2433  {
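2434  case PaddingMethod_Exclude:
2435  {
2436  desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
2437  break;
2438  }
2439  case PaddingMethod_IgnoreValue:
2440  {
2441  desc.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
2442  break;
2443  }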
2444  default:
2445  {
2446  throw ParseException("Unsupported padding method");
2447  }
2448  }
2449 
2450  switch (pooling2dDesc->dataLayout())
2451  {
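2452  case DataLayout_NCHW:
2453  {
2454  desc.m_DataLayout = armnn::DataLayout::NCHW;
2455  break;
2456  }
2457  case DataLayout_NHWC:
2458  {
2459  desc.m_DataLayout = armnn::DataLayout::NHWC;
2460  break;
2461  }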
2462  default:
2463  {
2464  throw ParseException("Unsupported data layout");
2465  }
2466  }
2467 
2468  desc.m_PadRight = pooling2dDesc->padRight();
2469  desc.m_PadLeft = pooling2dDesc->padLeft();
2470  desc.m_PadBottom = pooling2dDesc->padBottom();
2471  desc.m_PadTop = pooling2dDesc->padTop();
2472  desc.m_StrideX = pooling2dDesc->strideX();
2473  desc.m_StrideY = pooling2dDesc->strideY();
2474  desc.m_PoolWidth = pooling2dDesc->poolWidth();
2475  desc.m_PoolHeight = pooling2dDesc->poolHeight();
2476 
2477  return desc;
2478 }
@ Exclude
The padding fields don't count and are ignored.
@ IgnoreValue
The padding fields count, but are ignored.
A Pooling2dDescriptor for the Pooling2dLayer.
uint32_t m_PadRight
Padding right value in the width dimension.
PoolingAlgorithm m_PoolType
The pooling algorithm to use (Max, Average, L2).
uint32_t m_PoolHeight
Pooling height value.
uint32_t m_PadTop
Padding top value in the height dimension.
DataLayout m_DataLayout
The data layout to be used (NCHW, NHWC).
uint32_t m_PoolWidth
Pooling width value.
PaddingMethod m_PaddingMethod
The padding method to be used. (Exclude, IgnoreValue).
uint32_t m_PadBottom
Padding bottom value in the height dimension.
uint32_t m_PadLeft
Padding left value in the width dimension.
uint32_t m_StrideY
Stride value when proceeding through input for the height dimension.
uint32_t m_StrideX
Stride value when proceeding through input for the width dimension.
OutputShapeRounding m_OutputShapeRounding
The rounding method for the output shape. (Floor, Ceiling).

References armnn::Average, armnn::Ceiling, armnn::Exclude, armnn::Floor, armnn::IgnoreUnused(), armnn::IgnoreValue, armnn::L2, Pooling2dDescriptor::m_DataLayout, Pooling2dDescriptor::m_OutputShapeRounding, Pooling2dDescriptor::m_PadBottom, Pooling2dDescriptor::m_PaddingMethod, Pooling2dDescriptor::m_PadLeft, Pooling2dDescriptor::m_PadRight, Pooling2dDescriptor::m_PadTop, Pooling2dDescriptor::m_PoolHeight, Pooling2dDescriptor::m_PoolType, Pooling2dDescriptor::m_PoolWidth, Pooling2dDescriptor::m_StrideX, Pooling2dDescriptor::m_StrideY, armnn::Max, armnn::NCHW, and armnn::NHWC.

◆ GetPooling3dDescriptor()

armnn::Pooling3dDescriptor GetPooling3dDescriptor ( Pooling3dDescriptor  pooling3dDescriptor,
unsigned int  layerIndex 
)
static

Definition at line 2480 of file Deserializer.cpp.

2482 {
2483  IgnoreUnused(layerIndex);
2484  armnn::Pooling3dDescriptor desc;
2485 
2486  switch (pooling3dDesc->poolType())
2487  {
2488  case PoolingAlgorithm_Average:
2489  {
2490  desc.m_PoolType = armnn::PoolingAlgorithm::Average;
2491  break;
2492  }
2493  case PoolingAlgorithm_Max:
2494  {
2495  desc.m_PoolType = armnn::PoolingAlgorithm::Max;
2496  break;
2497  }
2498  case PoolingAlgorithm_L2:
2499  {
2500  desc.m_PoolType = armnn::PoolingAlgorithm::L2;
2501  break;
2502  }
2503  default:
2504  {
2505  throw ParseException("Unsupported pooling algorithm");
2506  }
2507  }
2508 
2509  switch (pooling3dDesc->outputShapeRounding())
2510  {
2511  case OutputShapeRounding_Floor:
2512  {
2513  desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
2514  break;
2515  }
2516  case OutputShapeRounding_Ceiling:
2517  {
2518  desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Ceiling;
2519  break;
2520  }
2521  default:
2522  {
2523  throw ParseException("Unsupported output shape rounding");
2524  }
2525  }
2526 
2527  switch (pooling3dDesc->paddingMethod())
2528  {
2529  case PaddingMethod_Exclude:
2530  {
2531  desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
2532  break;
2533  }
2534  case PaddingMethod_IgnoreValue:
2535  {
2536  desc.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
2537  break;
2538  }
2539  default:
2540  {
2541  throw ParseException("Unsupported padding method");
2542  }
2543  }
2544 
2545  switch (pooling3dDesc->dataLayout())
2546  {
2547  case DataLayout_NCDHW:
2548  {
2549  desc.m_DataLayout = armnn::DataLayout::NCDHW;
2550  break;
2551  }
2552  case DataLayout_NDHWC:
2553  {
2554  desc.m_DataLayout = armnn::DataLayout::NDHWC;
2555  break;
2556  }
2557  default:
2558  {
2559  throw ParseException("Unsupported data layout");
2560  }
2561  }
2562 
2563  desc.m_PadRight = pooling3dDesc->padRight();
2564  desc.m_PadLeft = pooling3dDesc->padLeft();
2565  desc.m_PadBottom = pooling3dDesc->padBottom();
2566  desc.m_PadTop = pooling3dDesc->padTop();
2567  desc.m_PadFront = pooling3dDesc->padFront();
2568  desc.m_PadBack = pooling3dDesc->padBack();
2569  desc.m_StrideX = pooling3dDesc->strideX();
2570  desc.m_StrideY = pooling3dDesc->strideY();
2571  desc.m_StrideZ = pooling3dDesc->strideZ();
2572  desc.m_PoolWidth = pooling3dDesc->poolWidth();
2573  desc.m_PoolHeight = pooling3dDesc->poolHeight();
2574  desc.m_PoolDepth = pooling3dDesc->poolDepth();
2575 
2576  return desc;
2577 }
A Pooling3dDescriptor for the Pooling3dLayer.
uint32_t m_PadRight
Padding right value in the width dimension.
PoolingAlgorithm m_PoolType
The pooling algorithm to use (Max, Average, L2).
uint32_t m_PadBack
Padding back value in the depth dimension.
uint32_t m_StrideZ
Stride value when proceeding through input for the depth dimension.
uint32_t m_PoolHeight
Pooling height value.
uint32_t m_PadTop
Padding top value in the height dimension.
DataLayout m_DataLayout
The data layout to be used (NCDHW, NDHWC).
uint32_t m_PoolWidth
Pooling width value.
uint32_t m_PadFront
Padding front value in the depth dimension.
PaddingMethod m_PaddingMethod
The padding method to be used. (Exclude, IgnoreValue).
uint32_t m_PadBottom
Padding bottom value in the height dimension.
uint32_t m_PadLeft
Padding left value in the width dimension.
uint32_t m_StrideY
Stride value when proceeding through input for the height dimension.
uint32_t m_PoolDepth
Pooling depth value.
uint32_t m_StrideX
Stride value when proceeding through input for the width dimension.
OutputShapeRounding m_OutputShapeRounding
The rounding method for the output shape. (Floor, Ceiling).

References armnn::Average, armnn::Ceiling, armnn::Exclude, armnn::Floor, armnn::IgnoreUnused(), armnn::IgnoreValue, armnn::L2, Pooling3dDescriptor::m_DataLayout, Pooling3dDescriptor::m_OutputShapeRounding, Pooling3dDescriptor::m_PadBack, Pooling3dDescriptor::m_PadBottom, Pooling3dDescriptor::m_PaddingMethod, Pooling3dDescriptor::m_PadFront, Pooling3dDescriptor::m_PadLeft, Pooling3dDescriptor::m_PadRight, Pooling3dDescriptor::m_PadTop, Pooling3dDescriptor::m_PoolDepth, Pooling3dDescriptor::m_PoolHeight, Pooling3dDescriptor::m_PoolType, Pooling3dDescriptor::m_PoolWidth, Pooling3dDescriptor::m_StrideX, Pooling3dDescriptor::m_StrideY, Pooling3dDescriptor::m_StrideZ, armnn::Max, armnn::NCDHW, and armnn::NDHWC.

◆ GetQLstmDescriptor()

armnn::QLstmDescriptor GetQLstmDescriptor ( QLstmDescriptorPtr  qLstmDescriptorPtr)
static

Definition at line 3440 of file Deserializer.cpp.

3441 {
3442  armnn::QLstmDescriptor desc;
3443 
3444  desc.m_CifgEnabled = qLstmDescriptor->cifgEnabled();
3445  desc.m_PeepholeEnabled = qLstmDescriptor->peepholeEnabled();
3446  desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled();
3447  desc.m_LayerNormEnabled = qLstmDescriptor->layerNormEnabled();
3448 
3449  desc.m_CellClip = qLstmDescriptor->cellClip();
3450  desc.m_ProjectionClip = qLstmDescriptor->projectionClip();
3451 
3452  desc.m_InputIntermediateScale = qLstmDescriptor->inputIntermediateScale();
3453  desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale();
3454  desc.m_CellIntermediateScale = qLstmDescriptor->cellIntermediateScale();
3455  desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale();
3456 
3457  desc.m_HiddenStateScale = qLstmDescriptor->hiddenStateScale();
3458  desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint();
3459 
3460  return desc;
3461 }
A QLstmDescriptor for the QLstmLayer.
float m_CellIntermediateScale
Cell intermediate quantization scale.
float m_InputIntermediateScale
Input intermediate quantization scale.
bool m_PeepholeEnabled
Enable/disable peephole.
int32_t m_HiddenStateZeroPoint
Hidden State zero point.
bool m_LayerNormEnabled
Enable/disable layer normalization.
bool m_ProjectionEnabled
Enable/disable the projection layer.
float m_OutputIntermediateScale
Output intermediate quantization scale.
float m_ProjectionClip
Clipping threshold value for the projection.
float m_CellClip
Clipping threshold value for the cell state.
bool m_CifgEnabled
Enable/disable CIFG (coupled input & forget gate).
float m_HiddenStateScale
Hidden State quantization scale.
float m_ForgetIntermediateScale
Forget intermediate quantization scale.

References QLstmDescriptor::m_CellClip, QLstmDescriptor::m_CellIntermediateScale, QLstmDescriptor::m_CifgEnabled, QLstmDescriptor::m_ForgetIntermediateScale, QLstmDescriptor::m_HiddenStateScale, QLstmDescriptor::m_HiddenStateZeroPoint, QLstmDescriptor::m_InputIntermediateScale, QLstmDescriptor::m_LayerNormEnabled, QLstmDescriptor::m_OutputIntermediateScale, QLstmDescriptor::m_PeepholeEnabled, QLstmDescriptor::m_ProjectionClip, and QLstmDescriptor::m_ProjectionEnabled.

◆ GetUnidirectionalSequenceLstmDescriptor()

armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor ( UnidirectionalSequenceLstmDescriptorPtr  descriptor)
static

Definition at line 3899 of file Deserializer.cpp.

3901 {
3902  armnn::UnidirectionalSequenceLstmDescriptor desc;
3903 
3904  desc.m_ActivationFunc = descriptor->activationFunc();
3905  desc.m_ClippingThresCell = descriptor->clippingThresCell();
3906  desc.m_ClippingThresProj = descriptor->clippingThresProj();
3907  desc.m_CifgEnabled = descriptor->cifgEnabled();
3908  desc.m_PeepholeEnabled = descriptor->peepholeEnabled();
3909  desc.m_ProjectionEnabled = descriptor->projectionEnabled();
3910  desc.m_LayerNormEnabled = descriptor->layerNormEnabled();
3911  desc.m_TimeMajor = descriptor->timeMajor();
3912 
3913  return desc;
3914 }
bool m_TimeMajor
Enable/disable time major.

References LstmDescriptor::m_ActivationFunc, LstmDescriptor::m_CifgEnabled, LstmDescriptor::m_ClippingThresCell, LstmDescriptor::m_ClippingThresProj, LstmDescriptor::m_LayerNormEnabled, LstmDescriptor::m_PeepholeEnabled, LstmDescriptor::m_ProjectionEnabled, and LstmDescriptor::m_TimeMajor.

◆ LoadGraphFromBinary()

GraphPtr LoadGraphFromBinary ( const uint8_t *  binaryContent,
size_t  len 
)
static

Definition at line 901 of file Deserializer.cpp.

902 {
903  if (binaryContent == nullptr)
904  {
905  throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
906  CHECK_LOCATION().AsString()));
907  }
908  flatbuffers::Verifier verifier(binaryContent, len);
909  if (verifier.VerifyBuffer<SerializedGraph>() == false)
910  {
911  throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn "
912  "flatbuffers format. size:{0} {1}",
913  len,
914  CHECK_LOCATION().AsString()));
915  }
916  return GetSerializedGraph(binaryContent);
917 }

References CHECK_LOCATION.

◆ operator=()

DeserializerImpl& operator= ( const DeserializerImpl & )
delete

◆ OutputShapeOfReshape()

armnn::TensorInfo OutputShapeOfReshape ( const armnn::TensorInfo &  inputTensorInfo,
const std::vector< uint32_t > &  targetDimsIn 
)
static

Definition at line 2640 of file Deserializer.cpp.

2642 {
2643  std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end());
2644  const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1);
2645 
2646  if (stretchDim != targetDimsIn.end())
2647  {
2648  if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end())
2649  {
2650  throw ParseException(fmt::format("At most one component of shape can be -1 {}",
2651  CHECK_LOCATION().AsString()));
2652  }
2653 
2654  auto targetNumElements =
2655  armnn::numeric_cast<unsigned int>(
2656  std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
2657 
2658  auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
2659  if (targetNumElements == 0)
2660  {
2661  if (inputTensorInfo.GetNumElements() == 0)
2662  {
2663  outputDims[stretchIndex] = 0;
2664  }
2665  else
2666  {
2667  throw ParseException(
2668  fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
2669  CHECK_LOCATION().AsString()));
2670  }
2671  }
2672  else
2673  {
2674  outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
2675  }
2676  }
2677 
2678  TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
2679 
2680  armnn::TensorInfo reshapeInfo = inputTensorInfo;
2681  reshapeInfo.SetShape(outputShape);
2682 
2683  return reshapeInfo;
2684 }
unsigned int GetNumElements() const
Definition: Tensor.hpp:198
void SetShape(const TensorShape &newShape)
Definition: Tensor.hpp:195

References CHECK_LOCATION, TensorInfo::GetNumElements(), and TensorInfo::SetShape().


The documentation for this class was generated from the following files: