50 std::memset(&out[0], 0, out.num_elements() *
sizeof(T));
52 const unsigned int N = in.
shape()[0];
53 const unsigned int K = in.
shape()[1];
54 const unsigned int B = in.
shape()[2];
56 const unsigned int num_tiles_x = std::ceil(N / static_cast<float>(rhs_info.
n0));
57 const unsigned int num_tiles_y = std::ceil(K / static_cast<float>(rhs_info.
k0));
71 const unsigned int offset_output_x = rhs_info.
interleave ? tile_to_use->
shape()[0] : tile_to_use->
shape()[0] * tile_to_use->
shape()[1];
72 const unsigned int step_output_x = rhs_info.
interleave ? tile_to_use->
shape()[0] * rhs_info.
h0 : tile_to_use->
shape()[0];
73 #ifdef ARM_COMPUTE_OPENMP 74 #pragma omp parallel for schedule(dynamic, 1) collapse(3) 76 for(
unsigned int z = 0; z < B; ++z)
78 for(
unsigned int y = 0; y < num_tiles_y; ++y)
80 for(
unsigned int x = 0; x < num_tiles_x; ++x)
83 get_tile<T>(in, src_tile,
Coordinates(x * rhs_info.
n0, y * rhs_info.
k0, z, 0));
88 transpose_matrix<T>(src_tile, src_tile_transposed);
92 const unsigned int offset_output = (y * rhs_info.
k0 * rhs_info.
n0 * rhs_info.
h0) + ((x % rhs_info.
h0) * offset_output_x) + ((x / rhs_info.
h0) * out.shape()[0]) + (z * out.shape()[0] * out.shape()[1]);
94 for(
unsigned int i = 0; i < tile_to_use->
shape()[1]; ++i)
96 const unsigned int offset_tile = i * tile_to_use->
shape()[0];
99 std::copy(&(*tile_to_use)[offset_tile], &(*tile_to_use)[offset_tile + tile_to_use->
shape()[0]], &out[offset_output + i * step_output_x]);
DataType data_type() const override
Data type of the tensor.
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true then an error message is printed and an exception thrown.
TensorShape shape() const override
Shape of the tensor.
SimpleTensor< T > copy(const SimpleTensor< T > &src, const TensorShape &output_shape)
bool transpose
True if the (k0xn0) block has to be transposed before being stored.
Copyright (c) 2017-2021 Arm Limited.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
GEMM RHS (Right Hand Side) matrix information.
unsigned int n0
Number of columns processed by the matrix multiplication.
Simple tensor object that stores elements in a consecutive chunk of memory.
bool interleave
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
SimpleTensor< T > gemm_reshape_rhs_matrix(const SimpleTensor< T > &in, const TensorShape &output_shape, const GEMMRHSMatrixInfo &rhs_info)