47 auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(output_edge->
producer());
48 auto *bn_node = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->
consumer());
51 if(conv_node->num_groups() > 1)
57 <<
" with BatchNormalization Layer node with ID : " << output_edge->
consumer_id() << std::endl);
60 if(conv_node->output(0)->accessor() ==
nullptr)
62 const Target assigned_target = conv_node->assigned_target();
65 const auto conv_input_id = conv_node->input_edge(0)->producer_id();
66 const auto conv_weights_id = conv_node->input_edge(1)->producer_id();
67 const auto conv_info = conv_node->convolution_info();
68 const auto conv_method = conv_node->convolution_method();
69 const auto num_groups = conv_node->num_groups();
70 const auto act_info = bn_node->fused_activation();
71 FastMathHint fast_math_hint = conv_node->fast_math_hint();
74 const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
75 const auto bn_var_id = bn_node->input_edge(2)->producer_id();
77 const auto epsilon = bn_node->epsilon();
82 if(conv_node->input_edge(2) !=
nullptr)
94 if(bn_node->input_edge(3) !=
nullptr)
96 const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
100 if(bn_node->input_edge(4) !=
nullptr)
102 const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
106 auto fused_node = g.
node(fused_id);
110 auto bn_node_accessor = bn_node->output(0)->extract_accessor();
111 auto bn_node_name = bn_node->name();
117 for(
auto &driving_node : bn_driving_nodes)
119 g.
add_connection(fused_id, 0, driving_node.node_id, driving_node.index);
123 fused_node->output(0)->set_accessor(std::move(bn_node_accessor));
124 fused_node->set_assigned_target(assigned_target);
125 fused_node->set_common_node_parameters(
NodeParams{ conv_node->
name() +
"+" + bn_node_name, assigned_target });
140 auto *depth_conv_node = arm_compute::utils::cast::polymorphic_downcast<DepthwiseConvolutionLayerNode *>(output_edge->
producer());
141 auto *bn_node = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->
consumer());
144 <<
" with BatchNormalization Layer node with ID : " << output_edge->
consumer_id() << std::endl);
147 if(depth_conv_node->output(0)->accessor() ==
nullptr)
149 const Target assigned_target = depth_conv_node->assigned_target();
152 const auto depth_conv_input_id = depth_conv_node->input_edge(0)->producer_id();
153 const auto conv_weights_id = depth_conv_node->input_edge(1)->producer_id();
154 const auto conv_info = depth_conv_node->convolution_info();
155 const auto depth_conv_method = depth_conv_node->depthwise_convolution_method();
156 const auto depth_multiplier = depth_conv_node->depth_multiplier();
157 const auto act_info = bn_node->fused_activation();
160 const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
161 const auto bn_var_id = bn_node->input_edge(2)->producer_id();
162 const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
163 const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
164 const auto epsilon = bn_node->epsilon();
169 if(depth_conv_node->input_edge(2) !=
nullptr)
183 auto fused_node = g.
node(fused_id);
187 auto bn_node_accessor = bn_node->output(0)->extract_accessor();
188 auto bn_node_name = bn_node->name();
194 for(
auto &driving_node : bn_driving_nodes)
196 g.
add_connection(fused_id, 0, driving_node.node_id, driving_node.index);
200 fused_node->output(0)->set_accessor(std::move(bn_node_accessor));
201 fused_node->set_assigned_target(assigned_target);
202 fused_node->set_common_node_parameters(
NodeParams{ depth_conv_node->
name() +
"+" + bn_node_name, assigned_target });
209 ARM_COMPUTE_LOG_GRAPH_VERBOSE(
"Prevented fusion of depthwise convolution with batch normalization due to the presence of an output accessor\n");
213 template <
typename N>
218 auto *n_node = arm_compute::utils::cast::polymorphic_downcast<N *>(output_edge->
producer());
219 auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(output_edge->
consumer());
224 if(supported_fused_activations.count(act_node->activation_info().activation()) == 0)
236 <<
" with Activation Layer node with ID : " << output_edge->
consumer_id() << std::endl);
239 if(n_node->output(0)->accessor() ==
nullptr)
245 n_node->set_fused_activation(act_node->activation_info());
248 auto act_node_accessor = act_node->output(0)->extract_accessor();
254 for(
auto &driving_node : act_driving_nodes)
256 g.
add_connection(n_node->id(), 0, driving_node.node_id, driving_node.index);
260 n_node->output(0)->set_accessor(std::move(act_node_accessor));
268 template <
typename N1,
typename N2,
typename F,
typename... Args>
269 void fuse_layer(
Graph &g, std::function<
bool(
INode &)>
const &prec,
const F fuse_fcn, Args &&... optional_arguments)
274 for(
unsigned int i = 0; i < g.
nodes().size(); ++i)
276 auto node = g.
node(i);
278 if(node && node->type() == N1::node_type && node->output_edges().size() == 1)
280 const auto output_edge_id = *node->
output_edges().begin();
281 const auto output_edge = g.
edge(output_edge_id);
284 if((output_edge !=
nullptr) && (output_edge->consumer() !=
nullptr) && (output_edge->consumer()->type() == N2::node_type) && prec(*output_edge->producer()))
286 fuse_fcn(g, output_edge, optional_arguments...);
295 return "NodeFusionMutator";
314 auto empty_prec = [](
INode &)
318 auto cl_target_prec = [](
INode & n)
322 auto qs8_prec = [&g](
INode & n)
326 const auto output_edge_id = *n.output_edges().begin();
327 const auto output_edge = g.
edge(output_edge_id);
332 return (output_qasymm8 && same_qinfo) || !output_qasymm8;
336 detail::fuse_layer<BatchNormalizationLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<BatchNormalizationLayerNode>, supported_fused_activations);
337 detail::fuse_layer<ConvolutionLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<ConvolutionLayerNode>, supported_fused_activations);
338 detail::fuse_layer<DepthwiseConvolutionLayerNode, ActivationLayerNode>(g, qs8_prec, detail::fuse_node_with_activation<DepthwiseConvolutionLayerNode>, supported_fused_activations);
339 detail::fuse_layer<FullyConnectedLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<FullyConnectedLayerNode>, supported_fused_activations);
340 detail::fuse_layer<EltwiseLayerNode, ActivationLayerNode>(g, cl_target_prec, detail::fuse_node_with_activation<EltwiseLayerNode>, supported_fused_activations);
Edge * input_edge(size_t idx) const
Returns the edge of a given input of the node.
void configure_tensor(Tensor *tensor)
Configures tensor.
INode * consumer() const
Returns consumer node.
const std::set< EdgeID > & output_edges() const
Returns output edge set.
Fused Depthwise Convolution Batch Normalization node.
std::vector< NodeIdxPair > get_driving_nodes(const INode &node)
Get the list of driving nodes of a given node.
NodeID add_node(Ts &&... args)
Adds a node to the graph.
void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge)
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
void fuse_depthwise_convolution_with_batch_normalization(Graph &g, const Edge *output_edge)
Copyright (c) 2017-2021 Arm Limited.
Batch Normalization node.
TensorDescriptor & desc()
TensorInfo metadata accessor.
QuantizationInfo quant_info
Quantization info.
Exponential Linear Unit ( )
Tensor * output(size_t idx) const
Returns the tensor of a given output of the node.
quantized, asymmetric fixed-point 8-bit number unsigned
const unsigned int num_groups
bool remove_node(NodeID nid)
Remove the node with the given ID.
NodeID producer_id() const
Returns producer node id.
void fuse_layer(Graph &g, std::function< bool(INode &)> const &prec, const F fuse_fcn, Args &&... optional_arguments)
EdgeID add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx)
Adds a connection between two nodes.
FastMathHint
Enable or disable fast math for Convolution layer.
Lower and Upper Bounded Rectifier ( )
const std::vector< NodeID > & nodes(NodeType type)
Returns graph input nodes.
const char * name() override
Returns mutator name.
Upper Bounded Rectifier ( )
MutationType
Mutation type.
MutationType type() const override
Returns mutation type.
const INode * node(NodeID id) const
Get node object given its id.
std::string name
Node name.
#define ARM_COMPUTE_LOG_GRAPH_VERBOSE(x)
const Edge * edge(EdgeID id) const
Get edge object given its id.
NodeID consumer_id() const
Returns sink node id.
void fuse_node_with_activation(Graph &g, const Edge *output_edge, const std::set< Activation > &supported_fused_activations)
OpenCL capable target device.
INode * producer() const
Returns producer node.
virtual void mutate(Graph &g) override
Walk the graph and perform a specific mutation.
bool is_data_type_float(DataType dt)
Check if a given data type is of floating point type.