// Matrix-multiply-accumulate over batched operands for element types T that
// satisfy is_floating_point<T>: every output element is written as
//     dst = alpha * sum_k(a[row, k] * b[k, col]) + beta * c[row, col].
// NOTE(review): the leading integers look like line numbers carried over from
// the original listing. The gaps between them (e.g. 38-43) mean the function
// signature, the braces, and the declarations of `dst`, `acc`, `alpha` and
// `beta` are not visible in this excerpt - confirm against the full file.
37 template <typename T, typename std::enable_if<is_floating_point<T>::value,
int>
::type>
// Problem sizes taken from the operand shapes:
//   M = a.y (rows of a), N = b.x (columns of b), K = a.x (reduction length),
//   D = a.z and W = a[3] (the two outer batch extents).
44 const int M = a.
shape().y();
45 const int N =
b.shape().x();
46 const int K = a.
shape().x();
47 const int D = a.
shape().z();
48 const int W = a.
shape()[3];
// Element offsets between consecutive z-slices / w-slices of `a`.
50 const int a_stride_z =
K *
M;
51 const int a_stride_w =
K *
M * D;
// `b` only advances along z (resp. w) when it actually has more than 2 (resp. 3)
// dimensions; a stride of 0 makes every batch reuse the same `b` slice.
53 const int b_stride_z =
b.shape().num_dimensions() > 2 ?
N *
K : 0;
54 int b_stride_w =
b.shape().num_dimensions() > 3 ?
K *
N * D : 0;
// True when `b` is 3D, `a` and `c` are 4D, and dimension 2 of both `a` and `c` is 1.
58 const bool is_batched_gemm =
b.shape().num_dimensions() == 3 && a.
shape().num_dimensions() == 4 && c.
shape().num_dimensions() == 4 && a.
shape()[2] == 1 && c.
shape()[2] == 1;
// Makes the w index step through `b` with the z stride.
// NOTE(review): presumably guarded by `if(is_batched_gemm)` on a line that is
// not visible here - confirm.
63 b_stride_w = b_stride_z;
// Element offsets between consecutive z-slices / w-slices of `c`; the same
// offsets are used to index `dst` below.
66 const int c_stride_z =
N *
M;
67 const int c_stride_w =
N *
M * D;
// The OpenMP pragma is compiled in only when _OPENMP is defined and the target
// is not the __arm__ + __ANDROID__ combination; collapse(2) applies to the
// `w` and `depth` loops that follow.
// NOTE(review): the matching #endif is not visible in this excerpt.
69 #if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
70 #pragma omp parallel for collapse(2)
72 for(
int w = 0;
w < W; ++
w)
74 for(
int depth = 0; depth < D; ++depth)
// Flat offsets of the current (depth, w) slice in a, b and c/dst.
76 const int base_addr_a = depth * a_stride_z +
w * a_stride_w;
77 const int base_addr_b = depth * b_stride_z +
w * b_stride_w;
78 const int base_addr_c = depth * c_stride_z +
w * c_stride_w;
80 for(
int row = 0; row <
M; ++row)
82 for(
int col = 0; col <
N; ++col)
// Dot product of row `row` of a (row length K) with column `col` of b (row length N).
// NOTE(review): `acc` is declared on a line not shown here; presumably it is
// reset for every (row, col) pair - confirm.
86 for(
int k = 0; k <
K; ++k)
88 acc += a[base_addr_a + k + row *
K] *
b[base_addr_b + col + k *
N];
// Scale the accumulated product by alpha and add beta times the matching element of c.
92 dst[base_addr_c + col + row *
N] = alpha * acc + beta * c[base_addr_c + col + row *
N];
// Matrix-multiply-accumulate over batched operands for element types T that
// satisfy is_floating_point<T>. Same loop structure as a plain
//     dst = alpha * sum_k(a[row, k] * b[k, col]) + beta * c[row, col],
// except that each product is converted to float before it is accumulated and
// the final value is converted back to T when stored.
// NOTE(review): the leading integers look like line numbers carried over from
// the original listing. The gaps between them (e.g. 102-108) mean the function
// signature, the braces, and the declarations of `dst`, `acc`, `alpha` and
// `beta` are not visible in this excerpt, and the block continues past the
// last visible line - confirm against the full file.
101 template <typename T, typename std::enable_if<is_floating_point<T>::value,
int>
::type>
// Problem sizes taken from the operand shapes:
//   M = a.y (rows of a), N = b.x (columns of b), K = a.x (reduction length),
//   D = a.z and W = a[3] (the two outer batch extents).
109 const int M = a.
shape().y();
110 const int N =
b.shape().x();
111 const int K = a.
shape().x();
112 const int D = a.
shape().z();
113 const int W = a.
shape()[3];
// Element offsets between consecutive z-slices / w-slices of `a`.
115 const int a_stride_z =
K *
M;
116 const int a_stride_w =
K *
M * D;
// `b` only advances along z (resp. w) when it actually has more than 2 (resp. 3)
// dimensions; a stride of 0 makes every batch reuse the same `b` slice.
118 const int b_stride_z =
b.shape().num_dimensions() > 2 ?
N *
K : 0;
119 int b_stride_w =
b.shape().num_dimensions() > 3 ?
K *
N * D : 0;
// True when `b` is 3D, `a` and `c` are 4D, and dimension 2 of both `a` and `c` is 1.
123 const bool is_batched_gemm =
b.shape().num_dimensions() == 3 && a.
shape().num_dimensions() == 4 && c.
shape().num_dimensions() == 4 && a.
shape()[2] == 1 && c.
shape()[2] == 1;
// Makes the w index step through `b` with the z stride.
// NOTE(review): presumably guarded by `if(is_batched_gemm)` on a line that is
// not visible here - confirm.
128 b_stride_w = b_stride_z;
// Element offsets between consecutive z-slices / w-slices of `c`; the same
// offsets are used to index `dst` below.
131 const int c_stride_z =
N *
M;
132 const int c_stride_w =
N *
M * D;
// The OpenMP pragma is compiled in only when _OPENMP is defined and the target
// is not the __arm__ + __ANDROID__ combination; collapse(2) applies to the
// `w` and `depth` loops that follow.
// NOTE(review): the matching #endif is not visible in this excerpt.
134 #if defined(_OPENMP) && !(defined(__arm__) && defined(__ANDROID__))
135 #pragma omp parallel for collapse(2)
137 for(
int w = 0;
w < W; ++
w)
139 for(
int depth = 0; depth < D; ++depth)
// Flat offsets of the current (depth, w) slice in a, b and c/dst.
141 const int base_addr_a = depth * a_stride_z +
w * a_stride_w;
142 const int base_addr_b = depth * b_stride_z +
w * b_stride_w;
143 const int base_addr_c = depth * c_stride_z +
w * c_stride_w;
145 for(
int row = 0; row <
M; ++row)
147 for(
int col = 0; col <
N; ++col)
// Dot product of row `row` of a (row length K) with column `col` of b (row length N).
// The a*b product is evaluated first and only then converted to float.
// NOTE(review): `acc` is declared on a line not shown here; presumably a float
// accumulator reset for every (row, col) pair - confirm.
151 for(
int k = 0; k <
K; ++k)
153 acc +=
static_cast<float>(a[base_addr_a + k + row *
K] *
b[base_addr_b + col + k *
N]);
// Scale the accumulated product by alpha, add beta times the matching element
// of c, and convert the result to T for storage.
157 dst[base_addr_c + col + row *
N] =
static_cast<T
>(alpha * acc + beta * c[base_addr_c + col + row *
N]);