37 template <typename T, typename std::enable_if<is_floating_point<T>::value,
int>
::type>
44 const int M = a.
shape().y();
45 const int N = b.
shape().x();
46 const int K = a.
shape().x();
47 const int D = a.
shape().z();
48 const int W = a.
shape()[3];
50 const int a_stride_z = K *
M;
51 const int a_stride_w = K * M * D;
53 const int b_stride_z = b.
shape().num_dimensions() > 2 ? N *
K : 0;
54 const int b_stride_w = b.
shape().num_dimensions() > 3 ? K * N * D : 0;
56 const int c_stride_z = N *
M;
57 const int c_stride_w = N * M * D;
59 #if defined(_OPENMP) && !( defined(__arm__) && defined(__ANDROID__)) 60 #pragma omp parallel for collapse(2) 62 for(
int w = 0;
w < W; ++
w)
64 for(
int depth = 0; depth < D; ++depth)
66 const int base_addr_a = depth * a_stride_z +
w * a_stride_w;
67 const int base_addr_b = depth * b_stride_z +
w * b_stride_w;
68 const int base_addr_c = depth * c_stride_z +
w * c_stride_w;
70 for(
int row = 0; row <
M; ++row)
72 for(
int col = 0; col <
N; ++col)
76 for(
int k = 0; k <
K; ++k)
78 acc += a[base_addr_a + k + row *
K] * b[base_addr_b + col + k *
N];
82 dst[base_addr_c + col + row *
N] = alpha * acc + beta * c[base_addr_c + col + row *
N];
91 template <typename T, typename std::enable_if<is_floating_point<T>::value,
int>
::type>
99 const int M = a.
shape().y();
100 const int N = b.
shape().x();
101 const int K = a.
shape().x();
102 const int D = a.
shape().z();
103 const int W = a.
shape()[3];
105 const int a_stride_z = K *
M;
106 const int a_stride_w = K * M * D;
108 const int b_stride_z = b.
shape().num_dimensions() > 2 ? N *
K : 0;
109 const int b_stride_w = b.
shape().num_dimensions() > 3 ? K * N * D : 0;
111 const int c_stride_z = N *
M;
112 const int c_stride_w = N * M * D;
114 #if defined(_OPENMP) && !( defined(__arm__) && defined(__ANDROID__)) 115 #pragma omp parallel for collapse(2) 117 for(
int w = 0;
w < W; ++
w)
119 for(
int depth = 0; depth < D; ++depth)
121 const int base_addr_a = depth * a_stride_z +
w * a_stride_w;
122 const int base_addr_b = depth * b_stride_z +
w * b_stride_w;
123 const int base_addr_c = depth * c_stride_z +
w * c_stride_w;
125 for(
int row = 0; row <
M; ++row)
127 for(
int col = 0; col <
N; ++col)
131 for(
int k = 0; k <
K; ++k)
133 acc +=
static_cast<float>(a[base_addr_a + k + row *
K] * b[base_addr_b + col + k *
N]);
137 dst[base_addr_c + col + row *
N] =
static_cast<T
>(alpha * acc + beta * c[base_addr_c + col + row *
N]);
DataType data_type() const override
Data type of the tensor.
TensorShape shape() const override
Shape of the tensor.
decltype(strategy::transforms) typedef type
Copyright (c) 2017-2021 Arm Limited.
Simple tensor object that stores elements in a consecutive chunk of memory.
SimpleTensor< T > gemm_mixed_precision(const SimpleTensor< T > &a, const SimpleTensor< T > &b, const SimpleTensor< T > &c, float alpha, float beta)
SimpleTensor< T > gemm(const SimpleTensor< T > &a, const SimpleTensor< T > &b, const SimpleTensor< T > &c, float alpha, float beta)