ArmNN
 24.08
TensorIOUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Tensor.hpp>
9 
10 #include <fmt/format.h>
11 #include <mapbox/variant.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 template<typename TContainer>
17 inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18  const std::vector<TContainer>& inputDataContainers)
19 {
20  armnn::InputTensors inputTensors;
21 
22  const size_t numInputs = inputBindings.size();
23  if (numInputs != inputDataContainers.size())
24  {
25  throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26  "tensor data containers: {0} != {1}",
27  numInputs,
28  inputDataContainers.size()));
29  }
30 
31  for (size_t i = 0; i < numInputs; i++)
32  {
33  const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34  const TContainer& inputData = inputDataContainers[i];
35 
36  mapbox::util::apply_visitor([&](auto&& value)
37  {
38  if (value.size() != inputBinding.second.GetNumElements())
39  {
40  throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41  inputBinding.second.GetNumElements(),
42  value.size()));
43  }
44  armnn::TensorInfo inputTensorInfo = inputBinding.second;
45  inputTensorInfo.SetConstant(true);
46  armnn::ConstTensor inputTensor(inputTensorInfo, value.data());
47  inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
48  },
49  inputData);
50  }
51 
52  return inputTensors;
53 }
54 
55 template<typename TContainer>
56 inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
57  std::vector<TContainer>& outputDataContainers)
58 {
59  armnn::OutputTensors outputTensors;
60 
61  const size_t numOutputs = outputBindings.size();
62  if (numOutputs != outputDataContainers.size())
63  {
64  throw armnn::Exception(fmt::format("Number of outputs does not match number"
65  "of tensor data containers: {0} != {1}",
66  numOutputs,
67  outputDataContainers.size()));
68  }
69 
70  for (size_t i = 0; i < numOutputs; i++)
71  {
72  const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73  TContainer& outputData = outputDataContainers[i];
74 
75  mapbox::util::apply_visitor([&](auto&& value)
76  {
77  if (value.size() != outputBinding.second.GetNumElements())
78  {
79  throw armnn::Exception("Output tensor has incorrect size");
80  }
81 
82  armnn::Tensor outputTensor(outputBinding.second, value.data());
83  outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84  },
85  outputData);
86  }
87 
88  return outputTensors;
89 }
90 
91 } // namespace armnnUtils
armnn::BindingPointInfo
std::pair< armnn::LayerBindingId, armnn::TensorInfo > BindingPointInfo
Definition: Tensor.hpp:276
armnn::Tensor
A tensor defined by a TensorInfo (shape and data type) and a mutable backing store.
Definition: Tensor.hpp:321
armnn::InputTensors
std::vector< std::pair< LayerBindingId, class ConstTensor > > InputTensors
Definition: Tensor.hpp:394
armnnUtils::MakeInputTensors
armnn::InputTensors MakeInputTensors(const std::vector< armnn::BindingPointInfo > &inputBindings, const std::vector< TContainer > &inputDataContainers)
Definition: TensorIOUtils.hpp:17
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::OutputTensors
std::vector< std::pair< LayerBindingId, class Tensor > > OutputTensors
Definition: Tensor.hpp:395
armnnUtils::MakeOutputTensors
armnn::OutputTensors MakeOutputTensors(const std::vector< armnn::BindingPointInfo > &outputBindings, std::vector< TContainer > &outputDataContainers)
Definition: TensorIOUtils.hpp:56
armnnUtils
Definition: CompatibleTypes.hpp:10
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
Tensor.hpp
armnn::ConstTensor
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:329
armnn::TensorInfo::SetConstant
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
Definition: Tensor.cpp:518
armnnUtils::TContainer
mapbox::util::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char >, std::vector< int8_t > > TContainer
Definition: TContainer.hpp:18