28 {
29 if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
30 {
31 layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet();
32 layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet();
33
34 const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
35 const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
36
37 if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
38 {
39 return;
40 }
41
42 unsigned int reshapeSlot = 1;
43 TensorInfo reshapeInfo = inputInfo1;
44 TensorInfo inputInfo = inputInfo0;
45
46 if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
47 {
48 reshapeSlot = 0;
49 reshapeInfo = inputInfo0;
50 inputInfo = inputInfo1;
51 }
52
53 uint32_t numDimensions = inputInfo.GetNumDimensions();
54
55 std::vector<unsigned> reshapedDim;
56 for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
57 {
58 reshapedDim.push_back(reshapeInfo.GetShape()[i]);
59 }
60
61 std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
62 std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
63
64 reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
65
66
67
68
69 bool elementWiseMul = false;
71 {
72 auto& binaryLayerOp = *PolymorphicDowncast<ElementwiseBinaryLayer*>(&layer);
73 elementWiseMul = binaryLayerOp.GetParameters().m_Operation==BinaryOperation::Mul;
74 }
75 auto parentName = std::string(layer.GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer().GetName());
76 Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
78 ((parentLayer.GetOutputSlot(0).GetNumConnections() == 1) ||
79 (parentName.find("Quantize")!=std::string::npos && elementWiseMul)))
80 {
81 ConstantLayer& constantLayer = static_cast<ConstantLayer&>(parentLayer);
82
83 constantLayer.m_LayerOutput = std::make_unique<ScopedTensorHandle>(
84 ConstTensor(reshapeInfo, constantLayer.m_LayerOutput.get()->GetConstTensor<void>()));
85 constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
86 }
87 else
88 {
89 const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
90 const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
91 ReshapeLayer* reshapeLayer =
92 graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), descriptor, layerName.c_str());
93 reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
94 }
95 }
96 }