ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
DeleteBroadcastTo.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
14 {
15 public:
16  /// Run for every BroadcastToLayer. Remove it if it is before an elementwise binary layer.
17  /// Since ElementwiseBinary uses a broadcast loop, a separate BroadcastTo layer is
18  /// not necessary, so it is deleted by reconnecting its consumers to the producer of its input.
19  void Run(Graph&, BroadcastToLayer& layer) const
20  {
21  if(layer.GetType() == LayerType::BroadcastTo)
22  { // NOTE(review): source line 23 is missing from this listing
24  Layer& next = layer.GetOutputSlot(0).GetConnection(0)->GetOwningLayer(); // NOTE(review): inspects only connection 0; assumes the output slot has at least one consumer
26  { // NOTE(review): the condition guarding this branch (source line 25) is missing from this listing
27  Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer(); // layer producing the BroadcastTo input
28  auto tensorInfo = connectedLayer.GetOutputSlot().GetTensorInfo(); // value copy of the producer's current output TensorInfo
29  layer.GetOutputSlot().MoveAllConnections(connectedLayer.GetOutputSlot()); // rewire every consumer of the BroadcastTo output to the producer's output slot
30  connectedLayer.GetOutputSlot().GetOutputHandler().SetTensorInfo(tensorInfo); // re-apply the saved TensorInfo; NOTE(review): presumably guards against the move altering it — confirm
31  }
32  }
33  }
34 protected:
35  DeleteBroadcastToImpl() = default;
37 };
39 }
40 }
Layer & GetOwningLayer() const
Definition: Layer.hpp:53
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:339
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:286
const std::vector< InputSlot > & GetInputSlots() const
Definition: Layer.hpp:258
void SetTensorInfo(const TensorInfo &tensorInfo)
Sets the TensorInfo used by this output handler.
const InputSlot * GetConnection(unsigned int index) const override
Definition: Layer.cpp:83
void MoveAllConnections(OutputSlot &destination)
Moves all connections to another OutputSlot.
Definition: Layer.cpp:156
const OutputHandler & GetOutputHandler() const
Definition: Layer.hpp:139
const TensorInfo & GetTensorInfo() const override
Definition: Layer.cpp:100
void Run(Graph &, BroadcastToLayer &layer) const
Run for every BroadcastToLayer.
Copyright (c) 2021 ARM Limited and Contributors.