Compute Library
 21.02
CLGEMMDefaultConfigReshapedRHSOnlyValhall.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
33 
34 #include <map>
35 #include <utility>
36 
37 namespace arm_compute
38 {
39 namespace cl_gemm
40 {
42 
45 {
46 }
47 
48 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
49 {
50  using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMDefaultConfigReshapedRHSOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k,
51  unsigned int b);
52 
53  // Configurations for Mali-G77
54  static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G77 =
55  {
56  { DataType::F32, &CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_f32 },
57  { DataType::F16, &CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_f16 },
58  { DataType::QASYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_u8 },
59  { DataType::QSYMM8, &CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_u8 },
60  { DataType::QASYMM8_SIGNED, &CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_u8 },
61  { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_u8 }
62  };
63 
64  switch(_target)
65  {
66  case GPUTarget::G77:
67  default:
68  if(gemm_configs_G77.find(data_type) != gemm_configs_G77.end())
69  {
70  return (this->*gemm_configs_G77[data_type])(m, n, k, b);
71  }
72  else
73  {
74  ARM_COMPUTE_ERROR("Not supported data type");
75  }
76  }
77 }
78 
79 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
80 {
81  if(m == 1)
82  {
83  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
84  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
85 
86  if(r_mk <= 0.0064484127797186375)
87  {
88  if(r_mn <= 0.0028273810748942196)
89  {
90  GEMMLHSMatrixInfo lhs_info_buf;
91  GEMMRHSMatrixInfo rhs_info_buf;
92  GEMMLHSMatrixInfo lhs_info_img;
93  GEMMRHSMatrixInfo rhs_info_img;
94 
95  const unsigned int h0 = std::max(n / 4, 1U);
96  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, false, true);
97  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true, false);
98 
99  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
100  std::make_pair(lhs_info_buf, rhs_info_buf),
101  n, k, b, DataType::F32);
102  }
103  else
104  {
105  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, false, true, false, false, false);
106  }
107  }
108  else
109  {
110  if(r_mk <= 0.020312500186264515)
111  {
112  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, false, false);
113  }
114  else
115  {
116  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, true, false);
117  }
118  }
119  }
120  else
121  {
122  const float r_mn = static_cast<float>(m) / static_cast<float>(n);
123  const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
124  const float r_mk = static_cast<float>(m) / static_cast<float>(k);
125 
126  if(workload <= 1999.2000122070312)
127  {
128  if(workload <= 747.1999816894531)
129  {
130  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
131  }
132  else
133  {
134  GEMMLHSMatrixInfo lhs_info_buf;
135  GEMMRHSMatrixInfo rhs_info_buf;
136  GEMMLHSMatrixInfo lhs_info_img;
137  GEMMRHSMatrixInfo rhs_info_img;
138  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true);
139  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
140 
141  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
142  std::make_pair(lhs_info_buf, rhs_info_buf),
143  n, k, b, DataType::F32);
144  }
145  }
146  else
147  {
148  if(r_mn <= 0.03348214365541935)
149  {
150  if(r_mk <= 0.028125000186264515)
151  {
152  return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
153  }
154  else
155  {
156  GEMMLHSMatrixInfo lhs_info_buf;
157  GEMMRHSMatrixInfo rhs_info_buf;
158  GEMMLHSMatrixInfo lhs_info_img;
159  GEMMRHSMatrixInfo rhs_info_img;
160  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true);
161  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false);
162 
163  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
164  std::make_pair(lhs_info_buf, rhs_info_buf),
165  n, k, b, DataType::F32);
166  }
167  }
168  else
169  {
170  GEMMLHSMatrixInfo lhs_info_buf;
171  GEMMRHSMatrixInfo rhs_info_buf;
172  GEMMLHSMatrixInfo lhs_info_img;
173  GEMMRHSMatrixInfo rhs_info_img;
174  std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, false, true);
175  std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, false, true, false, true, false);
176 
177  return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
178  std::make_pair(lhs_info_buf, rhs_info_buf),
179  n, k, b, DataType::F32);
180  }
181  }
182  }
183 }
184 
185 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
186 {
189 
190  if(m == 1)
191  {
192  const unsigned int h0 = std::max(n / 2, 1U);
193  if(n <= 836.0)
194  {
195  return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false);
196  }
197  else
198  {
199  return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true, false);
200  }
201  }
202  else if(m < 128)
203  {
204  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
205  if(k >= 512)
206  {
207  return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, false);
208  }
209  else
210  {
211  return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, false);
212  }
213  }
214  else
215  {
216  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
217  if(n >= 64)
218  {
219  return configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, false, true, false, false);
220  }
221  else
222  {
223  if(k >= 512)
224  {
225  return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, false);
226  }
227  else
228  {
229  return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, false);
230  }
231  }
232  }
233 }
234 
235 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMDefaultConfigReshapedRHSOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
236 {
239 
240  if(m == 1)
241  {
242  const unsigned int h0 = std::max(n / 2, 1U);
243  return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
244  }
245  else
246  {
247  const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
248  if(m >= 28)
249  {
250  return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
251  }
252  else
253  {
254  return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, false, true, false, true);
255  }
256  }
257 }
258 } // namespace cl_gemm
259 } // namespace arm_compute
Basic interface for the GEMM kernel configuration.
SimpleTensor< float > b
Definition: DFT.cpp:157
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Definition: Error.h:352
1 channel, 1 F32 per channel
GEMM LHS (Left Hand Side) matrix information.
Definition: Types.h:1968
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
const DataType data_type
Definition: Im2Col.cpp:150
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override
Given M, N, K and B, this method returns the GEMMLHSMatrixInfo and GEMMRHSMatrixInfo to be used...
#define ARM_COMPUTE_UNUSED(...)
Avoids unused-variable warnings.
Definition: Error.h:152
GEMM RHS (Right Hand Side) matrix information.
Definition: Types.h:1983
quantized, asymmetric fixed-point 8-bit number unsigned
quantized, symmetric fixed-point 8-bit number
quantized, symmetric per channel fixed-point 8-bit number
GPUTarget
Available GPU Targets.
Definition: GPUTarget.h:34
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
quantized, asymmetric fixed-point 8-bit number signed
DataType
Available data types.
Definition: Types.h:77
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.