Compute Library
 21.08
SplitLayerSubTensorMutator.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
32 
33 #include "support/Cast.h"
34 #include "support/Iterable.h"
35 
36 namespace arm_compute
37 {
38 namespace graph
39 {
41 {
42  return "SplitLayerSubTensorMutator";
43 }
44 
// Mutation type reported by this mutator (declared as `MutationType type() const override`).
// NOTE(review): the function header (listing line 45) and the body line (listing
// line 47) are missing from this extracted listing, so the MutationType value
// returned here cannot be determined from this view — TODO restore both lines
// from the upstream source.
46 {
48 }
49 
51 {
52  // Early exit if no Split layers exist in graph
53  if(g.nodes(NodeType::SplitLayer).empty())
54  {
55  return;
56  }
57 
58  // Perform topological sort
59  std::vector<NodeID> topological_sorted_node_ids = dfs(g);
60 
61  // Should be in reverse order of execution
62  for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
63  {
64  INode *node = g.node(node_id);
65  if(node != nullptr && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
66  {
67  // Get output tensor
68  Tensor *input_tensor = node->input(0);
69 
70  // Check that all tensor have the same target and are valid
71  bool is_valid = std::all_of(node->outputs().cbegin(), node->outputs().cend(),
72  [&](const TensorID & tid)
73  {
74  return (g.tensor(tid) != nullptr) && (g.tensor(tid)->desc().target == input_tensor->desc().target);
75  });
76 
77  // Create subtensors
78  if(is_valid && is_target_supported(input_tensor->desc().target))
79  {
80  ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
81  << node->id() << " and name : " << node->name() << std::endl);
82 
83  auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node);
84 
85  const int axis = split_node->axis();
86  const unsigned int num_splits = split_node->num_splits();
87  const bool extend_parent = (axis < 2);
88 
89  // Create sub-tensor handles
90  for(unsigned int i = 0; i < node->outputs().size(); ++i)
91  {
92  Tensor *output_tensor = node->output(i);
93  const TensorShape output_shape = output_tensor->desc().shape;
94  Coordinates coords;
95  std::tie(std::ignore, coords) = split_node->compute_output_descriptor(input_tensor->desc(), num_splits, axis, i);
96 
98  std::unique_ptr<ITensorHandle> handle = backend.create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent);
99  output_tensor->set_handle(std::move(handle));
100  }
101  }
102  }
103  }
104 }
105 } // namespace graph
106 } // namespace arm_compute
const char * name() override
Returns mutator name.
std::string name() const
Returns node's name.
Definition: INode.cpp:107
Shape of a tensor.
Definition: TensorShape.h:39
void set_handle(std::unique_ptr< ITensorHandle > backend_tensor)
Sets the backend tensor.
Definition: Tensor.cpp:50
ITensorHandle * handle()
Backend tensor handle accessor.
Definition: Tensor.cpp:55
IDeviceBackend & get_backend(Target target)
Get a backend from the registry.
Copyright (c) 2017-2021 Arm Limited.
TensorDescriptor & desc()
Tensor descriptor (metadata) accessor.
Definition: Tensor.cpp:40
Node interface.
Definition: INode.h:45
Tensor * output(size_t idx) const
Returns the tensor of a given output of the node.
Definition: INode.cpp:158
virtual void mutate(Graph &g) override
Walk the graph and perform a specific mutation.
Coordinates of an item.
Definition: Coordinates.h:37
bool is_target_supported(Target target)
Checks if a specific target is supported.
Definition: Utils.cpp:34
NodeID id() const
Returns node's ID.
Definition: INode.cpp:102
const std::vector< TensorID > & outputs() const
Returns outputs of the node.
Definition: INode.cpp:122
static BackendRegistry & get()
Gets backend registry instance.
MutationType type() const override
Returns mutation type.
std::vector< NodeID > dfs(Graph &g)
Depth first search traversal.
virtual std::unique_ptr< ITensorHandle > create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent)=0
Create a backend Sub-Tensor.
Graph class.
Definition: Graph.h:53
const std::vector< NodeID > & nodes(NodeType type)
Returns the IDs of the graph nodes of the given type.
Definition: Graph.cpp:174
const INode * node(NodeID id) const
Get node object given its id.
Definition: Graph.cpp:204
#define ARM_COMPUTE_LOG_GRAPH_VERBOSE(x)
Definition: Logger.h:50
virtual NodeType type() const =0
Returns node's type.
reverse_iterable< T > reverse_iterate(T &val)
Creates a reverse iterable for a given type.
Definition: Iterable.h:101
Tensor * input(size_t idx) const
Returns the tensor of a given input of the node.
Definition: INode.cpp:150
unsigned int TensorID
Definition: Types.h:67
const Tensor * tensor(TensorID id) const
Get tensor object given its id.
Definition: Graph.cpp:224
Tensor object.
Definition: Tensor.h:41