ArmNN
 25.11
Loading...
Searching...
No Matches
Broadcast.hpp
Go to the documentation of this file.
1//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "BaseIterator.hpp"
7#include <armnn/Tensor.hpp>
8
9#include <functional>
10
11namespace armnn
12{
13
15{
// Constructor taking two input shapes and one output shape. Declaration only:
// the body is not in this header (the listing index places it at Broadcast.cpp:11).
// NOTE(review): presumably fills m_DimData from the shapes - confirm in Broadcast.cpp.
16 BroadcastLoop(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape);
17
// Constructor taking a single input shape and one output shape. Declaration only;
// the body is not visible in this header.
18 BroadcastLoop(const TensorShape& inShape, const TensorShape& outShape);
19
20 unsigned int GetNumDimensions()
21 {
22 return static_cast<unsigned int>(m_DimData.size());
23 }
24
25 template <typename Func, typename DecoderOp, typename EncoderOp>
26 void Unroll(Func operationFunc,
27 unsigned int dimension,
28 DecoderOp& inData0,
29 DecoderOp& inData1,
30 EncoderOp& outData)
31 {
32 if (dimension >= GetNumDimensions())
33 {
34 outData.Set(operationFunc(inData0.Get(), inData1.Get()));
35 return;
36 }
37
38 unsigned int inData0Movement = 0;
39 unsigned int inData1Movement = 0;
40 unsigned int outDataMovement = 0;
41
42 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
43 {
44 Unroll(operationFunc, dimension + 1, inData0, inData1, outData);
45
46 inData0 += m_DimData[dimension].m_Stride1;
47 inData1 += m_DimData[dimension].m_Stride2;
48 outData += m_DimData[dimension].m_StrideOut;
49
50 inData0Movement += m_DimData[dimension].m_Stride1;
51 inData1Movement += m_DimData[dimension].m_Stride2;
52 outDataMovement += m_DimData[dimension].m_StrideOut;
53 }
54
55 // move iterator back to the start
56 inData0 -= inData0Movement;
57 inData1 -= inData1Movement;
58 outData -= outDataMovement;
59 }
60
61 template <typename Func, typename DecoderOp, typename EncoderOp>
62 void Unroll(Func operationFunc,
63 unsigned int dimension,
64 DecoderOp& inData,
65 EncoderOp& outData)
66 {
67 if (dimension >= GetNumDimensions())
68 {
69 outData.Set(operationFunc(inData.Get()));
70 return;
71 }
72
73 unsigned int inDataMovement = 0;
74 unsigned int outDataMovement = 0;
75
76 for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++)
77 {
78 Unroll(operationFunc, dimension + 1, inData, outData);
79
80 inData += m_DimData[dimension].m_Stride1;
81 outData += m_DimData[dimension].m_StrideOut;
82
83 inDataMovement += m_DimData[dimension].m_Stride1;
84 outDataMovement += m_DimData[dimension].m_StrideOut;
85 }
86
87 // move iterator back to the start
88 inData -= inDataMovement;
89 outData -= outDataMovement;
90 }
91
92private:
93 // Per-dimension loop data; Unroll() reads the entry for the dimension it is iterating.
94 struct BroadcastDimensionData
95 {
96 unsigned int m_DimSize; // Number of loop iterations Unroll() performs for this dimension.
97 unsigned int m_StrideOut; // Amount added to the output iterator after each iteration.
98 unsigned int m_Stride1; // Amount added to the first (or only) input iterator after each iteration.
99 unsigned int m_Stride2; // Amount added to the second input iterator; read only by the two-input Unroll().
100 };
101
// One entry per dimension; its size is the value returned by GetNumDimensions().
102 std::vector<BroadcastDimensionData> m_DimData;
103};
104
105} //namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
unsigned int GetNumDimensions()
Definition Broadcast.hpp:20
void Unroll(Func operationFunc, unsigned int dimension, DecoderOp &inData, EncoderOp &outData)
Definition Broadcast.hpp:62
BroadcastLoop(const TensorShape &inShape0, const TensorShape &inShape1, const TensorShape &outShape)
Definition Broadcast.cpp:11
void Unroll(Func operationFunc, unsigned int dimension, DecoderOp &inData0, DecoderOp &inData1, EncoderOp &outData)
Definition Broadcast.hpp:26