Compute Library 21.08
SplitLayerSubTensorMutator Class Reference (final)

Mutation pass to optimize split operations by using sub-tensors.

#include <SplitLayerSubTensorMutator.h>


Public Member Functions

virtual void mutate (Graph &g) override
 Walk the graph and perform a specific mutation.

MutationType type () const override
 Returns mutation type.

const char * name () override
 Returns mutator name.
 
- Public Member Functions inherited from IGraphMutator
virtual ~IGraphMutator ()=default
 Virtual destructor.
 

Additional Inherited Members

- Public Types inherited from IGraphMutator
enum MutationType { IR, Backend }
 Mutation type.
 

Detailed Description

Mutation pass to optimize split operations by using sub-tensors.

Warning
This pass must be run whenever the model contains Split layers.

Definition at line 37 of file SplitLayerSubTensorMutator.h.
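To make "using sub-tensors" concrete, the sketch below computes, for an equal split of a 4-D shape along one axis, the shape of each output and its start offset inside the parent tensor. A sub-tensor is then a view into the parent buffer at that offset, so the split itself needs no copy. This is a minimal sketch assuming equal-sized splits; Shape4, Coords4, Region and split_region are hypothetical names, not Compute Library API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical 4-D shape and coordinate types; not the Compute Library's TensorShape/Coordinates.
using Shape4  = std::array<std::size_t, 4>;
using Coords4 = std::array<std::size_t, 4>;

// Shape of one split output and its start offset inside the parent tensor.
struct Region
{
    Shape4  shape;
    Coords4 start;
};

// Assumes equal-sized splits: parent[axis] must be divisible by num_splits.
Region split_region(const Shape4 &parent, std::size_t axis, std::size_t num_splits, std::size_t idx)
{
    assert(axis < parent.size());
    assert(num_splits > 0 && idx < num_splits);
    assert(parent[axis] % num_splits == 0);

    const std::size_t split_size = parent[axis] / num_splits;

    Region r{parent, Coords4{}};      // start from the full parent shape at the origin
    r.shape[axis] = split_size;       // each output covers 1/num_splits of the parent along axis
    r.start[axis] = idx * split_size; // and begins where the previous output ends
    return r;
}
```

For a parent shape {8, 4, 6, 1} split three ways along axis 2, each output has size 2 along that axis and starts at 0, 2 and 4 respectively; all other dimensions match the parent.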

Member Function Documentation

◆ mutate()

void mutate ( Graph & g )
[override, virtual]

Walk the graph and perform a specific mutation.

Parameters
[in,out] g Graph to walk and mutate

Implements IGraphMutator.

Definition at line 50 of file SplitLayerSubTensorMutator.cpp.

References ARM_COMPUTE_LOG_GRAPH_VERBOSE, IDeviceBackend::create_subtensor(), Tensor::desc(), arm_compute::graph::dfs(), BackendRegistry::get(), BackendRegistry::get_backend(), Tensor::handle(), INode::id(), INode::input(), arm_compute::graph::is_target_supported(), INode::name(), Graph::node(), Graph::nodes(), INode::output(), INode::outputs(), arm_compute::utils::iterable::reverse_iterate(), Tensor::set_handle(), TensorDescriptor::shape, arm_compute::graph::SplitLayer, TensorDescriptor::target, Graph::tensor(), and INode::type().

{
    // Early exit if no Split layers exist in graph
    if(g.nodes(NodeType::SplitLayer).empty())
    {
        return;
    }

    // Perform topological sort
    std::vector<NodeID> topological_sorted_node_ids = dfs(g);

    // Should be in reverse order of execution
    for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
    {
        INode *node = g.node(node_id);
        if(node != nullptr && node->type() == NodeType::SplitLayer && node->input(0) != nullptr)
        {
            // Get input tensor
            Tensor *input_tensor = node->input(0);

            // Check that all output tensors are valid and have the same target as the input
            bool is_valid = std::all_of(node->outputs().cbegin(), node->outputs().cend(),
                                        [&](const TensorID & tid)
            {
                return (g.tensor(tid) != nullptr) && (g.tensor(tid)->desc().target == input_tensor->desc().target);
            });

            // Create subtensors
            if(is_valid && is_target_supported(input_tensor->desc().target))
            {
                ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
                                              << node->id() << " and name : " << node->name() << std::endl);

                auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node);

                const int          axis          = split_node->axis();
                const unsigned int num_splits    = split_node->num_splits();
                const bool         extend_parent = (axis < 2);

                // Create sub-tensor handles
                for(unsigned int i = 0; i < node->outputs().size(); ++i)
                {
                    Tensor           *output_tensor = node->output(i);
                    const TensorShape output_shape  = output_tensor->desc().shape;
                    Coordinates       coords;
                    std::tie(std::ignore, coords) = split_node->compute_output_descriptor(input_tensor->desc(), num_splits, axis, i);

                    backends::IDeviceBackend      &backend = backends::BackendRegistry::get().get_backend(output_tensor->desc().target);
                    std::unique_ptr<ITensorHandle> handle  = backend.create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent);
                    output_tensor->set_handle(std::move(handle));
                }
            }
        }
    }
}
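The control flow of mutate() can be modeled outside the library. The sketch below uses a toy graph made of hypothetical Tensor, Node and Graph structs (not the Compute Library types): it computes a DFS-based topological order, walks it in reverse as the pass does, applies the same "all outputs share the input's target" check, and then marks each output as a view of the input tensor. It is a sketch under those assumptions only; is_target_supported(), the backend registry and real tensor handles are deliberately left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy model of the graph; none of these are Compute Library types.
enum class Target { CPU, GPU };
enum class NodeType { Input, Split, Other };

struct Tensor
{
    Target target;
    int    parent = -1; // index of the tensor this one is a view into, or -1 if it owns its memory
};

struct Node
{
    NodeType                 type;
    std::vector<std::size_t> successors; // nodes that consume this node's outputs
    int                      input = -1; // input tensor index, or -1 if the node has no input
    std::vector<std::size_t> outputs;    // output tensor indices
};

struct Graph
{
    std::vector<Node>   nodes;
    std::vector<Tensor> tensors;
};

// DFS-based topological sort: post-order emits each node after all of its successors,
// so reversing the post-order puts producers before consumers.
std::vector<std::size_t> topological_order(const Graph &g)
{
    std::vector<bool>                visited(g.nodes.size(), false);
    std::vector<std::size_t>         post_order;
    std::function<void(std::size_t)> visit = [&](std::size_t n)
    {
        visited[n] = true;
        for(std::size_t s : g.nodes[n].successors)
        {
            if(!visited[s])
            {
                visit(s);
            }
        }
        post_order.push_back(n);
    };
    for(std::size_t n = 0; n < g.nodes.size(); ++n)
    {
        if(!visited[n])
        {
            visit(n);
        }
    }
    std::reverse(post_order.begin(), post_order.end());
    return post_order;
}

// Walk in reverse execution order and turn every split output that shares
// the input's target into a view of the input tensor.
void mutate_splits(Graph &g)
{
    const std::vector<std::size_t> order = topological_order(g);
    for(auto it = order.rbegin(); it != order.rend(); ++it)
    {
        const Node &node = g.nodes[*it];
        if(node.type != NodeType::Split || node.input < 0)
        {
            continue;
        }
        const std::size_t in = static_cast<std::size_t>(node.input);
        assert(in < g.tensors.size());
        const Target target = g.tensors[in].target;

        // Every output must exist and share the input's target.
        const bool is_valid = std::all_of(node.outputs.cbegin(), node.outputs.cend(),
                                          [&](std::size_t tid) { return tid < g.tensors.size() && g.tensors[tid].target == target; });
        if(!is_valid)
        {
            continue;
        }
        for(std::size_t tid : node.outputs)
        {
            g.tensors[tid].parent = node.input; // stands in for create_subtensor() + set_handle()
        }
    }
}
```

In this model a split whose outputs live on a different target than its input is skipped, so those outputs keep their own storage, matching the is_valid early-out in the listing.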

◆ name()

const char * name ( )
[override, virtual]

Returns mutator name.

Returns
Mutator name

Implements IGraphMutator.

Definition at line 40 of file SplitLayerSubTensorMutator.cpp.

{
    return "SplitLayerSubTensorMutator";
}

◆ type()

IGraphMutator::MutationType type ( ) const
[override, virtual]

Returns mutation type.

Returns
Mutation type enumeration

Implements IGraphMutator.

Definition at line 45 of file SplitLayerSubTensorMutator.cpp.

References IGraphMutator::Backend.
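The three overrides documented on this page follow a plain abstract-interface pattern. The sketch below mirrors only what this page states: the member names, that type() returns Backend, and that name() returns "SplitLayerSubTensorMutator". Graph is an empty stand-in, the mutate() body is a placeholder, and run_mutators_of_type is a hypothetical driver showing how a caller could use type() to select passes; it is not the library's own pass manager.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Graph {}; // hypothetical stand-in for the real graph type

class IGraphMutator
{
public:
    enum MutationType { IR, Backend };
    virtual ~IGraphMutator() = default;
    virtual void         mutate(Graph &g) = 0;
    virtual MutationType type() const     = 0;
    virtual const char  *name()           = 0;
};

class SplitLayerSubTensorMutator final : public IGraphMutator
{
public:
    void mutate(Graph &) override
    {
        // Placeholder: the real pass replaces split outputs with sub-tensor handles.
    }
    MutationType type() const override
    {
        return IGraphMutator::Backend;
    }
    const char *name() override
    {
        return "SplitLayerSubTensorMutator";
    }
};

// Hypothetical driver: runs only the mutators whose type() matches and returns their names.
std::vector<std::string> run_mutators_of_type(Graph &g, std::vector<std::unique_ptr<IGraphMutator>> &mutators,
                                              IGraphMutator::MutationType t)
{
    std::vector<std::string> ran;
    for(auto &m : mutators)
    {
        if(m->type() == t)
        {
            m->mutate(g);
            ran.emplace_back(m->name());
        }
    }
    return ran;
}
```

With this driver, a request for IR mutators skips SplitLayerSubTensorMutator, while a request for Backend mutators runs it.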

