44 NodeID create_grouped_convolution(Graph &g,
45 const NodeParams ¶ms,
51 ActivationLayerInfo fused_act,
75 std::vector<NodeIdxPair> convolution_outputs;
78 NodeParams group_params = params;
79 NodeID conv_nid = g.add_node<ConvolutionLayerNode>(
conv_info, 1, method, fast_math_hint);
80 g.add_connection(input_split, i, conv_nid, 0);
81 g.add_connection(weights_split, i, conv_nid, 1);
84 g.add_connection(bias_split, i, conv_nid, 2);
88 if (!group_params.name.empty())
94 INode *node = g.node(conv_nid);
96 node->set_common_node_parameters(group_params);
99 auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
100 conv_node->set_fused_activation(fused_act);
102 convolution_outputs.push_back({conv_nid, 0});
112 return "GroupedConvolutionMutator";
129 size_t total_nodes = g.
nodes().size();
132 for (
unsigned int i = 0; i < total_nodes; ++i)
136 arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node)->num_groups() != 1)
146 auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
153 const FastMathHint fast_math_hint = conv_node->fast_math_hint();
154 const unsigned int num_groups = conv_node->num_groups();
155 const NodeParams params = conv_node->common_node_params();
156 const Target assigned_target = conv_node->assigned_target();
159 ARM_COMPUTE_ERROR_ON(conv_node->input_edge(0) ==
nullptr || conv_node->input_edge(1) ==
nullptr);
160 const NodeID input_id = conv_node->input_edge(0)->producer()->id();
161 const NodeID weights_id = conv_node->input_edge(1)->producer()->id();
163 (conv_node->input_edge(2) !=
nullptr) ? conv_node->input_edge(2)->producer()->id() :
EmptyNodeID;
169 auto node_accessor = conv_node->output(0)->extract_accessor();
177 create_grouped_convolution(g, params, {input_id, 0}, weights_id, bias_id,
conv_info, conv_method,
184 for (
auto &driving_node : driving_nodes)
186 g.
add_connection(grouped_conv_id, 0, driving_node.node_id, driving_node.index);
194 [](std::unique_ptr<Tensor> &
t) { configure_tensor(t.get()); });
196 [&assigned_target](std::unique_ptr<INode> &n)
200 n->set_assigned_target(assigned_target);