24.02.1
|
Go to the documentation of this file.
42 return "SplitLayerSubTensorMutator";
59 std::vector<NodeID> topological_sorted_node_ids =
dfs(g);
73 return (g.tensor(tid) != nullptr) &&
74 (g.tensor(tid)->desc().target == input_tensor->desc().target);
81 << node->
id() <<
" and name : " << node->
name() << std::endl);
83 auto *split_node = arm_compute::utils::cast::polymorphic_downcast<SplitLayerNode *>(node);
85 const int axis = split_node->axis();
86 const unsigned int num_splits = split_node->num_splits();
87 const bool extend_parent = (axis < 2);
90 for (
unsigned int i = 0; i < node->
outputs().size(); ++i)
95 std::tie(std::ignore, coords) =
96 split_node->compute_output_descriptor(input_tensor->
desc(), num_splits, axis, i);
100 std::unique_ptr<ITensorHandle> handle =
virtual void mutate(Graph &g) override
Walk the graph and perform a specific mutation.
void set_handle(std::unique_ptr< ITensorHandle > backend_tensor)
Sets the backend tensor.
MutationType
Mutation type.
const char * name() override
Returns mutator name.
reverse_iterable< T > reverse_iterate(T &val)
Creates a reverse iterable for a given type.
Device backend interface.
TensorDescriptor & desc()
Tensor descriptor (metadata) accessor.
Tensor * output(size_t idx) const
Returns the tensor of a given output of the node.
const std::vector< TensorID > & outputs() const
Returns outputs of the node.
#define ARM_COMPUTE_LOG_GRAPH_VERBOSE(x)
ITensorHandle * handle()
Backend tensor handle accessor.
std::string name() const
Returns node's name.
std::vector< NodeID > dfs(Graph &g)
Depth first search traversal.
bool is_target_supported(Target target)
Checks if a specific target is supported.
@ Backend
Backend specific mutation.
TensorShape shape
Tensor shape.
MutationType type() const override
Returns mutation type.
virtual NodeType type() const =0
Returns node's type.
const INode * node(NodeID id) const
Get node object given its id.
virtual std::unique_ptr< ITensorHandle > create_subtensor(ITensorHandle *parent, TensorShape shape, Coordinates coords, bool extend_parent)=0
Create a backend Sub-Tensor.
Copyright (c) 2017-2024 Arm Limited.
Tensor * input(size_t idx) const
Returns the tensor of a given input of the node.
IDeviceBackend & get_backend(Target target)
Get a backend from the registry.
NodeID id() const
Returns node's ID.
const std::vector< NodeID > & nodes(NodeType type)
Returns the IDs of all graph nodes of the given type.
static BackendRegistry & get()
Gets backend registry instance.