// Checks whether the tensor carried by `input_edge` differs from the tensor carried
// by every other output edge of the node that produces it (see the std::all_of below).
//
// g          : graph used to resolve edge ids into edges.
// input_edge : edge to inspect; it is dereferenced without a visible null check.
//
// NOTE(review): this listing looks like a lossy extraction. Braces, the end iterator
// of std::all_of and the bodies of the three `if` statements are not visible, and the
// number leading each line is the original file's line number, not code.
45 bool output_edges_are_separate_tensors(Graph &g,
const Edge *input_edge)
47 const auto parent_node = input_edge->producer();
48 const auto input_tensor = input_edge->tensor();
49 const auto input_edge_id = input_edge->id();
// The edge has no producer node. The value returned in this case is not visible
// here - presumably `false`; TODO confirm against the full source.
51 if(parent_node ==
nullptr)
56 const auto output_edges = parent_node->output_edges();
// The producer has exactly one output edge. The value returned is not visible
// here - presumably `true`, since nothing else could share the tensor; TODO confirm.
60 if(output_edges.size() == 1)
// Otherwise every output edge of the producer has to satisfy the predicate below.
65 return std::all_of(output_edges.begin(),
67 [&](
const EdgeID & edge_id)
// The edge being inspected is special-cased; the body of this branch is not
// visible (presumably `return true` so that it is skipped) - TODO confirm.
70 if(edge_id == input_edge_id)
// Any other output edge must carry a tensor different from the input edge's tensor.
// NOTE(review): the edge returned by g.edge() is dereferenced without a visible null check.
75 auto edge = g.edge(edge_id);
76 return edge->tensor() != input_tensor;
// Switches `node` to write into `new_output`: the accessor extracted from
// `orig_output` is handed to `new_output`, and output slot 0 of the node is
// re-pointed at `new_output`'s id.
//
// node        : node whose first output is replaced.
// orig_output : tensor whose accessor is extracted.
// new_output  : tensor that receives the accessor and becomes the node's output 0.
81 void set_new_output_and_inherit_accessor(std::unique_ptr<INode> &node, Tensor *orig_output, Tensor *new_output)
// Tail of a streamed log statement that prints the node id and name.
// NOTE(review): the opening of this statement (macro and leading text) is not visible in this listing.
84 << node->id() <<
" and name : " << node->name() << std::endl);
// Hand over the accessor first, then redirect the node's output.
86 new_output->set_accessor(orig_output->extract_accessor());
88 node->set_output_tensor(new_output->id(), 0);
// Attempts to make a depthwise-convolution node (plain or fused with batch
// normalization) compute in place by reusing its input tensor as output 0.
// The decision is accumulated in `input_can_in_place`; when it holds, the node is
// rewired through set_new_output_and_inherit_accessor().
//
// node : node to rewire; input edge 0 is treated as the input, edge 1 as the weights.
//
// NOTE(review): this listing looks like a lossy extraction. Braces, the `else`
// keyword, and the declarations of `conv_info`, `input_can_in_place`,
// `weights_width_idx` and `weights_height_idx` are not visible, so the initial value
// of `input_can_in_place` cannot be confirmed from here.
92 void try_in_place_depthwiseconv(std::unique_ptr<INode> &node)
// Edges and tensors are dereferenced below; any null checks are not visible in this listing.
95 Edge *input_edge = node->input_edge(0);
96 Edge *weight_edge = node->input_edge(1);
99 auto input_tensor = input_edge->tensor();
100 auto weight_tensor = weight_edge->tensor();
// Descriptor fields of the input and of the weights read for the checks below.
103 const auto input_shape = input_tensor->desc().shape;
104 const auto qinfo_input = input_tensor->desc().quant_info;
106 const auto weight_shape = weight_tensor->desc().shape;
107 const auto weight_layout = weight_tensor->desc().layout;
// Convolution info and depth multiplier come from the concrete node class, which
// is selected by the node type; for any other type depth_multiplier stays value-initialized.
111 unsigned int depth_multiplier{};
112 if(node->type() == NodeType::FusedDepthwiseConvolutionBatchNormalizationLayer)
114 conv_info = polymorphic_downcast<FusedDepthwiseConvolutionBatchNormalizationNode *>(node.get())->convolution_info();
115 depth_multiplier = polymorphic_downcast<FusedDepthwiseConvolutionBatchNormalizationNode *>(node.get())->depth_multiplier();
117 else if(node->type() == NodeType::DepthwiseConvolutionLayer)
119 conv_info = polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node.get())->convolution_info();
120 depth_multiplier = polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node.get())->depth_multiplier();
// Descriptor fields of the current output.
// NOTE(review): the use of out_shape, qinfo_out, input_shape and qinfo_input is not
// visible here - presumably in the missing initialisation of input_can_in_place; TODO confirm.
124 auto current_output_tensor = node->output(0);
126 const auto out_shape = current_output_tensor->desc().shape;
127 const auto qinfo_out = current_output_tensor->desc().quant_info;
// Weights and input must share the same layout, and that layout must be NHWC.
132 input_can_in_place &= weight_layout == input_tensor->desc().layout && weight_layout ==
DataLayout::NHWC;
// The kernel must have extent 1 on both indexed weight dimensions.
// NOTE(review): the `1` / `U` line splits look like an extraction artifact of the literal `1U`.
136 const bool is_1x1 = weight_shape[weights_width_idx] == 1
U && weight_shape[weights_height_idx] == 1
U;
137 input_can_in_place &= is_1x1;
// Further requirements: depth multiplier of 1, stride of (1, 1) and no padding.
139 input_can_in_place &= depth_multiplier == 1;
140 input_can_in_place &=
conv_info.stride() == std::make_pair(1
U, 1
U);
141 input_can_in_place &= !
conv_info.has_padding();
// All conditions hold: reuse the input tensor as the node's output.
144 if(input_can_in_place)
146 set_new_output_and_inherit_accessor(node, current_output_tensor, input_tensor);
// Presumably the `else` branch (keyword not visible): only a verbose log is emitted.
// NOTE(review): the message names only the accessor and quantization info, while the
// visible conditions also cover layout, kernel size, depth multiplier, stride and padding.
150 ARM_COMPUTE_LOG_GRAPH_VERBOSE(
"Prevented in-place operation as there is an accessor bound to the input tensor or the quantization info are different.\n");
// Attempts to make a two-input element-wise node compute in place by reusing one
// of its input tensors as output 0. The first input is preferred; the second is
// only considered when the first does not qualify.
//
// node : node to rewire; input edges 0 and 1 are the two operands.
//
// NOTE(review): this listing looks like a lossy extraction. Braces, the declaration
// of `out_shape`, the body of the size check, and the leading parts of the
// `input0_can_in_place` / `input1_can_in_place` expressions are not visible.
155 void try_in_place_elementwise(std::unique_ptr<INode> &node)
// Edges and tensors are dereferenced below; any null checks are not visible in this listing.
158 Edge *input0_edge = node->input_edge(0);
159 Edge *input1_edge = node->input_edge(1);
162 auto input0_tensor = input0_edge->tensor();
163 auto input1_tensor = input1_edge->tensor();
// Shapes and quantization info of both operands.
// NOTE(review): their use is not visible here - presumably in the missing
// `out_shape` computation and the missing parts of the conditions below; TODO confirm.
166 const auto shape0 = input0_tensor->desc().shape;
167 const auto shape1 = input1_tensor->desc().shape;
168 const auto qinfo0 = input0_tensor->desc().quant_info;
169 const auto qinfo1 = input1_tensor->desc().quant_info;
// An output shape with zero total size is special-cased; the body of this branch
// is not visible (presumably a log and an early return) - TODO confirm.
173 if(out_shape.total_size() == 0)
179 auto current_output_tensor = node->output(0);
181 const auto qinfo_out = current_output_tensor->desc().quant_info;
// Visible tail of the input0 condition: same data type as the output and no accessor bound to input0.
185 && (input0_tensor->desc().data_type == current_output_tensor->desc().data_type) && (input0_tensor->accessor() ==
nullptr);
// Visible tail of the input1 condition: same data type as the output and no accessor bound to input1.
187 && (input1_tensor->desc().data_type == current_output_tensor->desc().data_type) && (input1_tensor->accessor() ==
nullptr);
// Prefer reusing input0, fall back to input1.
189 if(input0_can_in_place)
191 set_new_output_and_inherit_accessor(node, current_output_tensor, input0_tensor);
193 else if(input1_can_in_place)
195 set_new_output_and_inherit_accessor(node, current_output_tensor, input1_tensor);
// Presumably the final `else` branch (keyword not visible): only a verbose log is emitted.
199 ARM_COMPUTE_LOG_GRAPH_VERBOSE(
"Prevented in-place operation as there is an accessor bound to the input tensor or the quantization info are different.\n");
206 return "InPlaceOperationMutator";
211 return IGraphMutator::MutationType::Backend;
// Walks every node of the graph and, for the node types listed in `in_place_nodes`,
// tries to let the node write its output into its first input tensor.
//
// g : graph whose nodes are inspected and possibly rewired.
//
// NOTE(review): this listing looks like a lossy extraction. Braces and `else`
// keywords are not visible, so the nesting described below is inferred from the
// original line numbers and should be confirmed against the full source.
214 void InPlaceOperationMutator::mutate(
Graph &g)
// Node types that are candidates for in-place execution.
216 std::set<NodeType> in_place_nodes =
218 NodeType::ActivationLayer,
219 NodeType::BatchNormalizationLayer,
220 NodeType::EltwiseLayer,
221 NodeType::UnaryEltwiseLayer,
222 NodeType::DepthwiseConvolutionLayer,
223 NodeType::FusedDepthwiseConvolutionBatchNormalizationLayer,
228 for(
auto &node : g.
nodes())
// Skip empty node slots and node types that are not in the candidate set.
230 if(node && in_place_nodes.find(node->type()) !=
std::end(in_place_nodes))
233 Edge *input_edge = node->input_edge(0);
// Only proceed when the first input edge exists and output_edges_are_separate_tensors() accepts it.
236 if((input_edge !=
nullptr) && output_edges_are_separate_tensors(g, input_edge))
// Element-wise and depthwise-convolution nodes have dedicated handlers.
238 if(node->type() == NodeType::EltwiseLayer)
240 try_in_place_elementwise(node);
242 else if(node->type() == NodeType::FusedDepthwiseConvolutionBatchNormalizationLayer || node->type() == NodeType::DepthwiseConvolutionLayer)
244 try_in_place_depthwiseconv(node);
// Presumably the fall-through branch for the remaining candidate types (`else`
// not visible): the input edge's tensor is the candidate replacement for output 0.
249 auto current_output_tensor = node->output(0);
250 auto new_output_tensor = input_edge->
tensor();
// In-place is refused when the input tensor has an accessor bound or when the
// quantization info of the current output and of the input tensor differ.
255 if(new_output_tensor->accessor() !=
nullptr || current_output_tensor->desc().quant_info != new_output_tensor->desc().quant_info)
257 ARM_COMPUTE_LOG_GRAPH_VERBOSE(
"Prevented in-place operation as there is an accessor bound to the input tensor or the quantization info are different.\n");
// Presumably the `else` of the check above (keyword not visible): rewire the node to the input tensor.
261 set_new_output_and_inherit_accessor(node, current_output_tensor, new_output_tensor);
Tensor * tensor() const
Returns the tensor associated with this edge.
static TensorShape broadcast_shape(const Shapes &... shapes)
If shapes are broadcast compatible, return the broadcasted shape.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
decltype(strategy::transforms) typedef type
#define ARM_COMPUTE_LOG_GRAPH_INFO(x)
Copyright (c) 2017-2021 Arm Limited.
TensorShape input_shape
Validate test suite is to test ARM_COMPUTE_RETURN_ON_* macros we use to check the validity of given a...
bool have_different_dimensions(const Dimensions< T > &dim1, const Dimensions< T > &dim2, unsigned int upper_dim)
void end(TokenStream &in, bool &valid)
const std::vector< NodeID > & nodes(NodeType type)
Returns graph input nodes.
MutationType
Mutation type.
Num samples, height, width, channels.
#define ARM_COMPUTE_LOG_GRAPH_VERBOSE(x)
size_t get_data_layout_dimension_index(const DataLayout data_layout, const DataLayoutDimension data_layout_dimension)
Get the index of the given dimension.