template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
{
    // Create reference
    SimpleTensor<T> dst{c.shape(), c.data_type(), 1};

    // Compute reference
    const int M = a.shape().y();
    const int N = b.shape().x();
    const int K = a.shape().x();
    const int D = a.shape().z(); // Number of matrices in a batch
    const int W = a.shape()[3];  // Number of batched-gemm (Winograd case)

    const int a_stride_z = K * M;
    const int a_stride_w = K * M * D;

    // Do not slide matrix B along the 3rd/4th dimension if it has fewer dimensions
    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;
    int       b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0;

    // Batched-GEMM is detected when B has exactly 3 dimensions while A and C have 4
    // with a singleton 3rd dimension: the 3rd dimension of B then acts as the batch dimension
    const bool is_batched_gemm = b.shape().num_dimensions() == 3 &&
                                 a.shape().num_dimensions() == 4 &&
                                 c.shape().num_dimensions() == 4 &&
                                 a.shape()[2] == 1 && c.shape()[2] == 1;

    if(is_batched_gemm)
    {
        b_stride_w = b_stride_z;
    }

    const int c_stride_z = N * M;
    const int c_stride_w = N * M * D;

#if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
    #pragma omp parallel for collapse(2)
#endif /* _OPENMP */
    for(int w = 0; w < W; ++w)
    {
        for(int depth = 0; depth < D; ++depth)
        {
            // Base addresses of the current A, B and C slices
            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
            const int base_addr_c = depth * c_stride_z + w * c_stride_w;

            for(int row = 0; row < M; ++row)
            {
                for(int col = 0; col < N; ++col)
                {
                    T acc(0);

                    for(int k = 0; k < K; ++k)
                    {
                        acc += a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N];
                    }

                    // Finalize the result: alpha * A * B + beta * C
                    dst[base_addr_c + col + row * N] = alpha * acc + beta * c[base_addr_c + col + row * N];
                }
            }
        }
    }

    return dst;
}
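A quick way to sanity-check the indexing in gemm: the loops treat A as a row-major M x K matrix (element (row, k) at offset k + row * K) and B as a row-major K x N matrix (element (k, col) at offset col + k * N). The following standalone sketch runs the same inner loops on plain std::vector data; the 2x3 / 3x2 shapes and values are made up for illustration and are not taken from the library's tests:

#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 3;
    // Row-major A (M x K) and B (K x N), laid out as in the reference loops above
    const std::vector<float> a = {1, 2, 3,
                                  4, 5, 6};
    const std::vector<float> b = {7,  8,
                                  9,  10,
                                  11, 12};
    std::vector<float> dst(M * N, 0.f);

    for(int row = 0; row < M; ++row)
    {
        for(int col = 0; col < N; ++col)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                acc += a[k + row * K] * b[col + k * N];
            }
            dst[col + row * N] = acc; // alpha = 1, beta = 0 case
        }
    }

    // Prints: 58 64
    //         139 154
    for(int row = 0; row < M; ++row)
    {
        std::printf("%g %g\n", dst[row * N + 0], dst[row * N + 1]);
    }
    return 0;
}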
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
{
    // GEMM mixed-precision combines F32 accumulation with F16 multiplication
    // Create reference
    SimpleTensor<T> dst{c.shape(), c.data_type(), 1};

    // Compute reference
    const int M = a.shape().y();
    const int N = b.shape().x();
    const int K = a.shape().x();
    const int D = a.shape().z(); // Number of matrices in a batch
    const int W = a.shape()[3];  // Number of batched-gemm (Winograd case)

    const int a_stride_z = K * M;
    const int a_stride_w = K * M * D;

    // Do not slide matrix B along the 3rd/4th dimension if it has fewer dimensions
    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;
    int       b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0;

    // Batched-GEMM is detected when B has exactly 3 dimensions while A and C have 4
    // with a singleton 3rd dimension: the 3rd dimension of B then acts as the batch dimension
    const bool is_batched_gemm = b.shape().num_dimensions() == 3 &&
                                 a.shape().num_dimensions() == 4 &&
                                 c.shape().num_dimensions() == 4 &&
                                 a.shape()[2] == 1 && c.shape()[2] == 1;

    if(is_batched_gemm)
    {
        b_stride_w = b_stride_z;
    }

    const int c_stride_z = N * M;
    const int c_stride_w = N * M * D;

#if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
    #pragma omp parallel for collapse(2)
#endif /* _OPENMP */
    for(int w = 0; w < W; ++w)
    {
        for(int depth = 0; depth < D; ++depth)
        {
            // Base addresses of the current A, B and C slices
            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
            const int base_addr_c = depth * c_stride_z + w * c_stride_w;

            for(int row = 0; row < M; ++row)
            {
                for(int col = 0; col < N; ++col)
                {
                    float acc(0.f);

                    for(int k = 0; k < K; ++k)
                    {
                        acc += static_cast<float>(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]);
                    }

                    // Finalize the result: alpha * A * B + beta * C
                    dst[base_addr_c + col + row * N] = static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]);
                }
            }
        }
    }

    return dst;
}
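The only difference from gemm is the accumulator: each T-typed product (F16 in practice) is widened to float before accumulation, and the result is narrowed back to T only after the final alpha/beta combination. The motivation is that a half-precision accumulator silently stalls once the running sum's ulp grows past the addend. A minimal sketch of that effect, assuming a C++23 toolchain where std::float16_t from <stdfloat> is available (e.g. recent GCC on AArch64); the constants are chosen only to make the rounding visible:

#include <cstdio>
#include <stdfloat> // C++23, assumed available

int main()
{
    const int            K = 4096;
    const std::float16_t prod{0.25f}; // stand-in for a single a*b product

    std::float16_t acc_f16{0.0f};
    float          acc_f32 = 0.f;
    for(int k = 0; k < K; ++k)
    {
        acc_f16 = acc_f16 + prod;            // rounded back to F16 every step
        acc_f32 += static_cast<float>(prod); // widened accumulation, as in gemm_mixed_precision
    }

    // The F16 accumulator stalls at 512: its ulp there is 0.5, so adding 0.25
    // rounds back to the same value. The F32 accumulator reaches 1024 exactly.
    std::printf("f16 accumulator: %f\n", static_cast<float>(acc_f16));
    std::printf("f32 accumulator: %f\n", acc_f32);
    return 0;
}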