ArmNN
 25.11
Loading...
Searching...
No Matches
TensorIOUtils.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
9
10#include <fmt/format.h>
11#include <mapbox/variant.hpp>
12
13namespace armnnUtils
14{
15
16template<typename TContainer>
17inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18 const std::vector<TContainer>& inputDataContainers)
19{
20 armnn::InputTensors inputTensors;
21
22 const size_t numInputs = inputBindings.size();
23 if (numInputs != inputDataContainers.size())
24 {
25 throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26 "tensor data containers: {0} != {1}",
27 numInputs,
28 inputDataContainers.size()));
29 }
30
31 for (size_t i = 0; i < numInputs; i++)
32 {
33 const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34 const TContainer& inputData = inputDataContainers[i];
35
36 mapbox::util::apply_visitor([&](auto&& value)
37 {
38 if (value.size() != inputBinding.second.GetNumElements())
39 {
40 throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41 inputBinding.second.GetNumElements(),
42 value.size()));
43 }
44 armnn::TensorInfo inputTensorInfo = inputBinding.second;
45 inputTensorInfo.SetConstant(true);
46 armnn::ConstTensor inputTensor(inputTensorInfo, value.data());
47 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
48 },
49 inputData);
50 }
51
52 return inputTensors;
53}
54
55template<typename TContainer>
56inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
57 std::vector<TContainer>& outputDataContainers)
58{
59 armnn::OutputTensors outputTensors;
60
61 const size_t numOutputs = outputBindings.size();
62 if (numOutputs != outputDataContainers.size())
63 {
64 throw armnn::Exception(fmt::format("Number of outputs does not match number"
65 "of tensor data containers: {0} != {1}",
66 numOutputs,
67 outputDataContainers.size()));
68 }
69
70 for (size_t i = 0; i < numOutputs; i++)
71 {
72 const armnn::BindingPointInfo& outputBinding = outputBindings[i];
73 TContainer& outputData = outputDataContainers[i];
74
75 mapbox::util::apply_visitor([&](auto&& value)
76 {
77 if (value.size() != outputBinding.second.GetNumElements())
78 {
79 throw armnn::Exception("Output tensor has incorrect size");
80 }
81
82 armnn::Tensor outputTensor(outputBinding.second, value.data());
83 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
84 },
85 outputData);
86 }
87
88 return outputTensors;
89}
90
91} // namespace armnnUtils
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
Base class for all ArmNN exceptions so that users can filter to just those.
A tensor defined by a TensorInfo (shape and data type) and a mutable backing store.
Definition Tensor.hpp:322
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
Definition Tensor.cpp:518
std::pair< armnn::LayerBindingId, armnn::TensorInfo > BindingPointInfo
Definition Tensor.hpp:276
std::vector< std::pair< LayerBindingId, class ConstTensor > > InputTensors
Definition Tensor.hpp:394
std::vector< std::pair< LayerBindingId, class Tensor > > OutputTensors
Definition Tensor.hpp:395
mapbox::util::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char >, std::vector< int8_t > > TContainer
armnn::InputTensors MakeInputTensors(const std::vector< armnn::BindingPointInfo > &inputBindings, const std::vector< TContainer > &inputDataContainers)
armnn::OutputTensors MakeOutputTensors(const std::vector< armnn::BindingPointInfo > &outputBindings, std::vector< TContainer > &outputDataContainers)