29 if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
31 layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet();
32 layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet();
34 const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
35 const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
37 if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
42 unsigned int reshapeSlot = 1;
43 TensorInfo reshapeInfo = inputInfo1;
44 TensorInfo inputInfo = inputInfo0;
46 if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
49 reshapeInfo = inputInfo0;
50 inputInfo = inputInfo1;
53 uint32_t numDimensions = inputInfo.GetNumDimensions();
55 std::vector<unsigned> reshapedDim;
56 for (
unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
58 reshapedDim.push_back(reshapeInfo.GetShape()[i]);
61 std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
62 std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
68 Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
70 (parentLayer.GetOutputSlot(0).GetNumConnections() == 1))
72 ConstantLayer& constantLayer =
static_cast<ConstantLayer&
>(parentLayer);
74 constantLayer.m_LayerOutput = std::make_unique<ScopedTensorHandle>(
75 ConstTensor(reshapeInfo, constantLayer.m_LayerOutput.get()->GetConstTensor<
void>()));
76 constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
80 const std::string layerName =
"Reshape_for:" + layer.GetNameStr() +
"-" + std::to_string(reshapeSlot);
81 const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
82 ReshapeLayer* reshapeLayer =
83 graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), descriptor, layerName.c_str());
84 reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);