42 return "SplitLayerSubTensorMutator";
59 std::vector<NodeID> topological_sorted_node_ids =
dfs(g);
71 bool is_valid = std::all_of(node->
outputs().cbegin(), node->
outputs().cend(),
81 << node->
id() <<
" and name : " << node->
name() << std::endl);
83 auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node);
85 const int axis = split_node->axis();
86 const unsigned int num_splits = split_node->num_splits();
87 const bool extend_parent = (axis < 2);
90 for(
unsigned int i = 0; i < node->
outputs().size(); ++i)
95 std::tie(std::ignore, coords) = split_node->compute_output_descriptor(input_tensor->
desc(), num_splits, axis, i);
const char * name() override
Returns mutator name.
TensorShape shape
Tensor shape.
std::string name() const
Returns node's name.
void set_handle(std::unique_ptr< ITensorHandle > backend_tensor)
Sets the backend tensor.
ITensorHandle * handle()
Backend tensor handle accessor.
IDeviceBackend & get_backend(Target target)
Get a backend from the registry.
Copyright (c) 2017-2021 Arm Limited.
TensorDescriptor & desc()
TensorInfo metadata accessor.
Tensor * output(size_t idx) const
Returns the tensor of a given output of the node.
virtual void mutate(Graph &g) override
Walk the graph and perform a specific mutation.
bool is_target_supported(Target target)
Checks if a specific target is supported.
NodeID id() const
Returns node's ID.
const std::vector< TensorID > & outputs() const
Returns outputs of the node.
static BackendRegistry & get()
Gets backend registry instance.
MutationType type() const override
Returns mutation type.
std::vector< NodeID > dfs(Graph &g)
Depth first search traversal.
virtual std::unique_ptr< ITensorHandle > create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent)=0
Create a backend Sub-Tensor.
const std::vector< NodeID > & nodes(NodeType type)
Returns graph input nodes.
MutationType
Mutation type.
const INode * node(NodeID id) const
Get node object given its id.
#define ARM_COMPUTE_LOG_GRAPH_VERBOSE(x)
Device backend interface.
virtual NodeType type() const =0
Returns node's type.
reverse_iterable< T > reverse_iterate(T &val)
Creates a reverse iterable for a given type.
Tensor * input(size_t idx) const
Returns the tensor of a given input of the node.
const Tensor * tensor(TensorID id) const
Get tensor object given its id.