Compute Library
 19.08
NESplit Class Reference

Basic function to split a tensor along a given axis.

#include <NESplit.h>


Public Member Functions

 NESplit ()
 Default constructor.

void configure (const ITensor *input, const std::vector< ITensor * > &outputs, unsigned int axis)
 Initialise the kernel's input and outputs.

void run () override
 Run the kernels contained in the function.

- Public Member Functions inherited from IFunction
virtual ~IFunction ()=default
 Destructor.

virtual void prepare ()
 Prepare the function for execution.


Static Public Member Functions

static Status validate (const ITensorInfo *input, const std::vector< ITensorInfo * > &outputs, unsigned int axis)
 Static function to check if given info will lead to a valid configuration of NESplit.
 

Detailed Description

Basic function to split a tensor along a given axis.

Definition at line 41 of file NESplit.h.
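To make the semantics concrete, the following library-free C++ sketch splits a 2-D buffer into equal parts along a chosen axis. It assumes the Compute Library convention that dimension 0 is the innermost, fastest-varying (x) dimension. `Tensor2D` and `split2d` are hypothetical names used only for illustration; the sketch copies elements directly, whereas NESplit itself configures one NESlice per output.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical minimal 2-D tensor: dims[0] is the innermost (x) extent and
// dims[1] the outer (y) extent; element (x, y) lives at data[y * dims[0] + x].
struct Tensor2D
{
    std::size_t        dims[2];
    std::vector<float> data;
};

// Split `in` into `num_outputs` equal parts along `axis` (0 or 1).
// Throws if the axis is invalid or its extent is not divisible by num_outputs.
std::vector<Tensor2D> split2d(const Tensor2D &in, unsigned int axis, std::size_t num_outputs)
{
    if(axis > 1 || num_outputs == 0 || in.dims[axis] % num_outputs != 0)
    {
        throw std::invalid_argument("invalid split");
    }
    const std::size_t step = in.dims[axis] / num_outputs;

    std::vector<Tensor2D> outs;
    for(std::size_t i = 0; i < num_outputs; ++i)
    {
        Tensor2D out;
        out.dims[0] = (axis == 0) ? step : in.dims[0];
        out.dims[1] = (axis == 1) ? step : in.dims[1];
        out.data.reserve(out.dims[0] * out.dims[1]);
        // Copy the window [i * step, (i + 1) * step) along `axis`, full range elsewhere.
        for(std::size_t y = 0; y < out.dims[1]; ++y)
        {
            for(std::size_t x = 0; x < out.dims[0]; ++x)
            {
                const std::size_t src_x = (axis == 0) ? i * step + x : x;
                const std::size_t src_y = (axis == 1) ? i * step + y : y;
                out.data.push_back(in.data[src_y * in.dims[0] + src_x]);
            }
        }
        outs.push_back(out);
    }
    return outs;
}
```

For a 6x2 input (6 elements along dimension 0) split into 3 parts along axis 0, each output is 2x2; splitting the same input into 2 parts along axis 1 yields its two rows.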

Constructor & Destructor Documentation

◆ NESplit()

NESplit ( )

Default constructor.

Definition at line 37 of file NESplit.cpp.

    : _outputs_vector(), _slice_functions(), _num_outputs(0)
{
}

Member Function Documentation

◆ configure()

void configure (const ITensor *input, const std::vector< ITensor * > &outputs, unsigned int axis)

Initialise the kernel's input and outputs.

Parameters
[in]  input    The input tensor. Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32.
[out] outputs  A vector containing the output tensors. Data types supported: Same as input. The output tensors should match the input tensor dimensions for all shape dimensions apart from the split dimension, along which each output receives an equal share of the input (the input extent divided by the number of outputs).
[in]  axis     Axis on which to split the input.

Definition at line 42 of file NESplit.cpp.

{
    // Create Slice functions
    _num_outputs = outputs.size();
    _slice_functions.resize(_num_outputs);

    // Get output shape
    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs);

    // Extract output tensor info
    std::vector<ITensorInfo *> outputs_info;
    for(auto &output : outputs)
    {
        ARM_COMPUTE_ERROR_ON_NULLPTR(output);
        outputs_info.emplace_back(output->info());
    }

    // Validate
    ARM_COMPUTE_ERROR_THROW_ON(NESplit::validate(input->info(), outputs_info, axis));

    const size_t axis_split_step = output_shape[axis];
    unsigned int axis_offset = 0;

    // Start/End coordinates
    Coordinates start_coords;
    Coordinates end_coords;
    for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
    {
        end_coords.set(d, -1);
    }

    for(unsigned int i = 0; i < _num_outputs; i++)
    {
        // Update coordinate on axis
        start_coords.set(axis, axis_offset);
        end_coords.set(axis, axis_offset + axis_split_step);

        // Configure slice function
        _slice_functions[i].configure(input, outputs[i], start_coords, end_coords);

        // Set valid region from shape
        outputs[i]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape));

        // Update axis offset
        axis_offset += axis_split_step;
    }
}

References ARM_COMPUTE_ERROR_ON_NULLPTR, ARM_COMPUTE_ERROR_THROW_ON, arm_compute::test::validation::axis, arm_compute::misc::shape_calculator::compute_split_shape(), ITensor::info(), arm_compute::test::validation::output_shape, Dimensions< T >::set(), and NESplit::validate().
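The coordinate bookkeeping in configure() can be modelled with plain vectors. This sketch assumes, as the listing suggests, that an end coordinate of -1 selects everything through the end of that dimension (the exact convention is defined by NESlice). `SliceWindow` and `split_windows` are hypothetical names, and the caller is assumed to have already checked that the axis extent divides evenly, which validate() enforces in the real function.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical per-output slice window: one start and one end coordinate per dimension.
struct SliceWindow
{
    std::vector<long> start;
    std::vector<long> end; // -1 is assumed to mean "through the end of that dimension"
};

// Models the loop in NESplit::configure(): every output starts with the full
// range on all dimensions, then the split axis is narrowed to
// [axis_offset, axis_offset + step). Assumes input_shape[axis] % num_outputs == 0.
std::vector<SliceWindow> split_windows(const std::vector<std::size_t> &input_shape,
                                       unsigned int axis, std::size_t num_outputs)
{
    const std::size_t step = input_shape[axis] / num_outputs;
    SliceWindow base{std::vector<long>(input_shape.size(), 0),
                     std::vector<long>(input_shape.size(), -1)};

    std::vector<SliceWindow> windows;
    std::size_t axis_offset = 0;
    for(std::size_t i = 0; i < num_outputs; ++i)
    {
        SliceWindow w = base;
        w.start[axis] = static_cast<long>(axis_offset);
        w.end[axis]   = static_cast<long>(axis_offset + step);
        windows.push_back(w);
        axis_offset += step;
    }
    return windows;
}
```

For an input of shape (4, 6, 2) split into 3 outputs along axis 1, the windows on axis 1 are [0, 2), [2, 4) and [4, 6), while every other dimension keeps start 0 and end -1.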

◆ run()

void run ( )
override virtual

Run the kernels contained in the function.

For NEON kernels:

  • Multi-threading is used for the kernels which are parallelisable.
  • By default std::thread::hardware_concurrency() threads are used.
Note
CPPScheduler::set_num_threads() can be used to manually set the number of threads.

For OpenCL kernels:

  • All the kernels are enqueued on the queue associated with CLScheduler.
  • The queue is then flushed.
Note
The function will not block until the kernels are executed. It is the user's responsibility to wait.
Will call prepare() on the first run if it has not been done already.

Implements IFunction.

Definition at line 131 of file NESplit.cpp.

{
    for(unsigned i = 0; i < _num_outputs; ++i)
    {
        _slice_functions[i].run();
    }
}
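run() only executes the per-output slice functions in output order; the split adds no scheduling of its own, so any parallelism comes from how each slice's kernel is scheduled. A library-free sketch of this composite pattern follows. `Function`, `RecordingSlice` and `SplitLike` are hypothetical stand-ins for IFunction, NESlice and NESplit, and the recorded order exists only to make execution observable.

```cpp
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for IFunction.
struct Function
{
    virtual ~Function() = default;
    virtual void run() = 0;
};

// Hypothetical stand-in for a per-output slice function; it only records its id when run.
struct RecordingSlice : Function
{
    RecordingSlice(int id, std::vector<int> &order) : _id(id), _order(order) {}
    void run() override { _order.push_back(_id); }

private:
    int               _id;
    std::vector<int> &_order;
};

// Composite function: owns one sub-function per output and runs them
// sequentially, mirroring the loop in NESplit::run().
struct SplitLike : Function
{
    void add(std::unique_ptr<Function> f) { _slices.push_back(std::move(f)); }
    void run() override
    {
        for(auto &slice : _slices)
        {
            slice->run();
        }
    }

private:
    std::vector<std::unique_ptr<Function>> _slices;
};
```

Because all configuration happens up front, run() can be called repeatedly and simply replays the same sequence of sub-functions.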

◆ validate()

Status validate (const ITensorInfo *input, const std::vector< ITensorInfo * > &outputs, unsigned int axis)
static

Static function to check if given info will lead to a valid configuration of NESplit.

Parameters
[in] input    The input tensor info. Data types supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32.
[in] outputs  A vector containing the output tensors' info. Data types supported: Same as input. The output tensors should match the input tensor dimensions for all shape dimensions apart from the split dimension.
[in] axis     Axis on which to split the input.
Returns
a Status, which holds an error if the configuration is invalid.

Definition at line 90 of file NESplit.cpp.

{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
    ARM_COMPUTE_RETURN_ERROR_ON(axis >= input->num_dimensions());
    ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2);

    // Get output shape
    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_split_shape(input, axis, outputs.size());
    ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0);

    const size_t axis_split_step = output_shape[axis];
    unsigned int axis_offset = 0;

    // Start/End coordinates
    Coordinates start_coords;
    Coordinates end_coords;
    for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d)
    {
        end_coords.set(d, -1);
    }

    // Validate output tensors
    for(const auto &output : outputs)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);

        // Output auto initialization if not yet initialized
        TensorInfo tmp_output_info = *output->clone();
        auto_init_if_empty(tmp_output_info, input->clone()->set_is_resizable(true).set_tensor_shape(output_shape));

        // Update coordinate on axis
        start_coords.set(axis, axis_offset);
        end_coords.set(axis, axis_offset + axis_split_step);

        ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(input, output, start_coords, end_coords));
        axis_offset += axis_split_step;
    }

    return Status{};
}

References ARM_COMPUTE_RETURN_ERROR_ON, ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR, ARM_COMPUTE_RETURN_ON_ERROR, arm_compute::auto_init_if_empty(), arm_compute::test::validation::axis, ICloneable< T >::clone(), TensorInfo::clone(), arm_compute::misc::shape_calculator::compute_split_shape(), ITensorInfo::num_dimensions(), arm_compute::test::validation::output_shape, Dimensions< T >::set(), and NESlice::validate().

Referenced by NESplit::configure().
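The checks performed by validate() can be summarised in a small library-free model. It is an approximation under stated assumptions: shapes are plain vectors, an empty shape stands for an output whose info has not been initialised yet, an uneven split is assumed to yield no valid output shape, and the per-output shape comparison stands in for the checks NESlice::validate() performs (the exact conditions live in NESlice). `CheckStatus`, `Shape` and `validate_split` are hypothetical names.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for arm_compute::Status: an empty message means OK.
struct CheckStatus
{
    std::string error;
    bool ok() const { return error.empty(); }
};

using Shape = std::vector<std::size_t>;

// Approximates NESplit::validate(). An empty Shape in `outputs` models an
// output whose info is not initialised yet; it is not compared against the
// expected shape.
CheckStatus validate_split(const Shape &input, const std::vector<Shape> &outputs, unsigned int axis)
{
    if(axis >= input.size())
    {
        return {"axis out of range"};
    }
    if(outputs.size() < 2)
    {
        return {"at least two outputs are required"};
    }
    // Assumed: an uneven split has no valid per-output shape.
    if(input[axis] % outputs.size() != 0)
    {
        return {"axis extent not divisible by the number of outputs"};
    }

    // Expected per-output shape: the input shape with the axis extent divided evenly.
    Shape expected = input;
    expected[axis] = input[axis] / outputs.size();

    // Mirrors the total_size() == 0 check on the computed output shape.
    std::size_t total = 1;
    for(std::size_t d : expected)
    {
        total *= d;
    }
    if(total == 0)
    {
        return {"empty output shape"};
    }

    // Stand-in for NESlice::validate(): initialised outputs must match the expected shape.
    for(const Shape &out : outputs)
    {
        if(!out.empty() && out != expected)
        {
            return {"output shape mismatch"};
        }
    }
    return {};
}
```

For an input of shape (4, 6) split along axis 1, three outputs of shape (4, 2) pass, while four outputs fail because 6 is not divisible by 4.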


The documentation for this class was generated from the following files: