ArmNN
 24.08
AddBroadcastReshapeLayerImpl Class Reference

#include <AddBroadcastReshapeLayer.hpp>

Public Member Functions

void Run (Graph &graph, Layer &layer) const
 Run for every ElementwiseBaseLayer. Adds a broadcast Reshape layer if the two inputs have a different number of dimensions. More...
 

Protected Member Functions

 AddBroadcastReshapeLayerImpl ()=default
 
 ~AddBroadcastReshapeLayerImpl ()=default
 

Detailed Description

Definition at line 23 of file AddBroadcastReshapeLayer.hpp.

Constructor & Destructor Documentation

◆ AddBroadcastReshapeLayerImpl()

AddBroadcastReshapeLayerImpl ( )
protected, default

◆ ~AddBroadcastReshapeLayerImpl()

~AddBroadcastReshapeLayerImpl ( )
protected, default

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
Layer & layer 
) const
inline

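Run for every ElementwiseBaseLayer. Adds a broadcast Reshape layer if the two inputs have a different number of dimensions: the lower-rank input is reshaped to the higher rank by prepending dimensions of size 1. If that lower-rank input is produced by a Constant layer with a single consumer, the constant's tensor info is reshaped in place instead of inserting a Reshape layer.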

Definition at line 27 of file AddBroadcastReshapeLayer.hpp.

28  {
29  if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
30  {
31  layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet();
32  layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet();
33 
34  const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
35  const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
36 
37  if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
38  {
39  return;
40  }
41 
42  unsigned int reshapeSlot = 1;
43  TensorInfo reshapeInfo = inputInfo1;
44  TensorInfo inputInfo = inputInfo0;
45 
46  if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
47  {
48  reshapeSlot = 0;
49  reshapeInfo = inputInfo0;
50  inputInfo = inputInfo1;
51  }
52 
53  uint32_t numDimensions = inputInfo.GetNumDimensions();
54 
55  std::vector<unsigned> reshapedDim;
56  for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
57  {
58  reshapedDim.push_back(reshapeInfo.GetShape()[i]);
59  }
60 
61  std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
62  std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
63 
64  reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
65 
66  // If the parent layer is a Constant layer and it is only used once we can short circuit by just
67  // changing the tensor info rather than adding a reshape layer.
68  Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
69  if ((parentLayer.GetType() == armnn::LayerType::Constant) &&
70  (parentLayer.GetOutputSlot(0).GetNumConnections() == 1))
71  {
72  ConstantLayer& constantLayer = static_cast<ConstantLayer&>(parentLayer);
73 
74  constantLayer.m_LayerOutput = std::make_unique<ScopedTensorHandle>(
75  ConstTensor(reshapeInfo, constantLayer.m_LayerOutput.get()->GetConstTensor<void>()));
76  constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
77  }
78  else
79  {
80  const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
81  const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
82  ReshapeLayer* reshapeLayer =
83  graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), descriptor, layerName.c_str());
84  reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
85  }
86  }
87  }

References armnn::Constant, InputSlot::GetConnectedOutputSlot(), Layer::GetInputSlot(), Layer::GetNameStr(), OutputSlot::GetNumConnections(), TensorInfo::GetNumDimensions(), Layer::GetOutputSlot(), OutputSlot::GetOwningLayer(), TensorInfo::GetShape(), OutputSlot::GetTensorInfo(), Layer::GetType(), Graph::InsertNewLayer(), OutputSlot::IsTensorInfoSet(), ConstantLayer::m_LayerOutput, TensorInfo::SetShape(), and OutputSlot::SetTensorInfo().


The documentation for this class was generated from the following file: AddBroadcastReshapeLayer.hpp
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::LayerType::Constant
@ Constant