44 const int M = src.shape()[0];
45 const int N = src.shape()[1];
46 const int num_channels = src.shape()[2];
48 const int MxN = M * N;
49 const int channels_in_group = num_channels / num_groups;
54 #pragma omp parallel for collapse(2)
56 for(int n = 0; n < batches; ++n)
62 const T *src_ptr = src_ref + g * channels_in_group * MxN + n * num_channels * MxN;
63 T *dst_ptr = dst_ref + g * MxN + n * num_channels * MxN;
64 for(int i = 0; i < channels_in_group; ++i)
67 src_ptr + (i + 1) * MxN,
68 dst_ptr + i * num_groups * MxN);
DataType data_type() const override
Data type of the tensor.
SimpleTensor< T > channel_shuffle(const SimpleTensor< T > &src, int num_groups)
TensorShape shape() const override
Shape of the tensor.
SimpleTensor< T > copy(const SimpleTensor< T > &src, const TensorShape &output_shape)
SimpleTensor< float > src
Copyright (c) 2017-2021 Arm Limited.
const unsigned int num_groups
SimpleTensor< float > src_ref
Simple tensor object that stores elements in a consecutive chunk of memory.
int num_channels() const override
Number of channels of the tensor.
QuantizationInfo quantization_info() const override
Quantization info in case of asymmetric quantized type.
const T * data() const
Constant pointer to the underlying buffer.