ArmNN
 25.11
Loading...
Searching...
No Matches
ConvertConstants.hpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "Optimization.hpp"
9
13
14#include <Half.hpp>
15
16namespace armnn
17{
18namespace optimizations
19{
20
22{
23 static void Func(std::shared_ptr<ConstTensorHandle>& handle)
24 {
25 const TensorInfo& info = handle->GetTensorInfo();
26
27 if (info.GetDataType() == DataType::Float16)
28 {
29 std::vector<float> newValues(info.GetNumElements());
30
32 info.GetNumElements(),
33 newValues.data());
34
35 TensorInfo newInfo(info.GetShape(), DataType::Float32, 0.0f, 0, true);
36 ConstTensor newInput(newInfo, newValues);
37 handle.reset(new ScopedTensorHandle(newInput));
38 }
39 }
40};
41
43{
44 static void Func(std::shared_ptr<ConstTensorHandle>& handle)
45 {
46 const TensorInfo& info = handle->GetTensorInfo();
47
48 if (info.GetDataType() == DataType::Float32)
49 {
50 std::vector<Half> newValues(info.GetNumElements());
51
52 armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetConstTensor<float>(),
53 info.GetNumElements(),
54 newValues.data());
55
56 TensorInfo newInfo(info.GetShape(), DataType::Float16, 0.0f, 0, true);
57 ConstTensor newInput(newInfo, newValues);
58 handle.reset(new ScopedTensorHandle(newInput));
59 }
60 }
61};
62
63template<typename Converter, typename Predicate>
65{
66public:
67 ConvertConstants() = default;
69 virtual ~ConvertConstants() = default;
70
71 void Run(Graph& graph, Layer& layer) const override
72 {
73 IgnoreUnused(graph);
74 if (Predicate::Test(layer))
75 {
76 layer.OperateOnConstantTensors(Converter::Func);
77 }
78 }
79protected:
80};
81
83{
84 static bool Test(const Layer& layer)
85 {
86 return layer.GetDataType() == DataType::Float32;
87 }
88};
89
91{
92 static bool Test(const Layer& layer)
93 {
94 return layer.GetDataType() == DataType::Float16;
95 }
96};
97
100
101} //namespace optimizations
102} //namespace armnn
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition Tensor.hpp:330
void OperateOnConstantTensors(Op op)
Definition Layer.hpp:319
DataType GetDataType() const
Definition Layer.cpp:345
Optimization()=default
ConvertConstants(const ConvertConstants &)=default
void Run(Graph &graph, Layer &layer) const override
static void ConvertFloat16To32(const void *srcFloat16Buffer, size_t numElements, float *dstFloat32Buffer)
static void ConvertFloat32To16(const float *srcFloat32Buffer, size_t numElements, void *dstFloat16Buffer)
Converts a buffer of FP32 values to FP16, and stores in the given dstFloat16Buffer.
ConvertConstants< Float16ToFloat32, IsFloat32Layer > ConvertConstantsHalfToFloat
ConvertConstants< Float32ToFloat16, IsFloat16Layer > ConvertConstantsFloatToHalf
Copyright (c) 2021 ARM Limited and Contributors.
half_float::half Half
Definition Half.hpp:22
void IgnoreUnused(Ts &&...)
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
static void Func(std::shared_ptr< ConstTensorHandle > &handle)
static bool Test(const Layer &layer)
static bool Test(const Layer &layer)