Compute Library
 22.08
arm_gemm.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #pragma once
25 
26 #include <cstring>
27 #include <memory>
28 #include <vector>
29 
30 #include "arm_gemm_local.hpp"
31 #include "gemm_common.hpp"
32 
33 namespace arm_gemm
34 {
// Enumerates GEMM methods; presumably used to select or report which
// implementation strategy a kernel uses (it is carried by the kernel
// description and configuration structures initialised from it in this file).
// NOTE(review): only DEFAULT is visible here. The embedded source line numbers
// jump from 37 to 48, so ten enumerator lines appear to have been dropped when
// this listing was extracted. Restore them from the original header before
// treating this enum as complete.
35 enum class GemmMethod
36 {
37  DEFAULT,
48 };
49 
/** Memory layouts that can be requested for / reported about the weights tensor.
 *
 * UNSPECIFIED and ANY are small sentinel values (0x1 and 0x2).
 *
 * Every OHWIo<N>i<M>[_bf16] enumerator below is numerically equal to
 *
 *     (M << 20) | (N << 8) | (0x10 if the name ends in _bf16)
 *
 * where N is the number following 'o' and M the number following 'i'
 * (each taken as 1 when the letter is absent, e.g. OHWI == 0x100100).
 * NOTE(review): this encoding is read off the constants themselves; confirm
 * the intended meaning of each field against the library documentation before
 * decoding values programmatically.
 */
enum class WeightFormat
{
    UNSPECIFIED = 0x1,
    ANY         = 0x2,

    // i-factor 1 (no 'i' suffix)
    OHWI     = 0x100100,
    OHWIo2   = 0x100200,
    OHWIo4   = 0x100400,
    OHWIo8   = 0x100800,
    OHWIo16  = 0x101000,
    OHWIo32  = 0x102000,
    OHWIo64  = 0x104000,
    OHWIo128 = 0x108000,

    // i-factor 2
    OHWIo4i2       = 0x200400,
    OHWIo4i2_bf16  = 0x200410,
    OHWIo8i2       = 0x200800,
    OHWIo8i2_bf16  = 0x200810,
    OHWIo16i2      = 0x201000,
    OHWIo16i2_bf16 = 0x201010,
    OHWIo32i2      = 0x202000,
    OHWIo32i2_bf16 = 0x202010,
    OHWIo64i2      = 0x204000,
    OHWIo64i2_bf16 = 0x204010,

    // i-factor 4
    OHWIo4i4       = 0x400400,
    OHWIo4i4_bf16  = 0x400410,
    OHWIo8i4       = 0x400800,
    OHWIo8i4_bf16  = 0x400810,
    OHWIo16i4      = 0x401000,
    OHWIo16i4_bf16 = 0x401010,
    OHWIo32i4      = 0x402000,
    OHWIo32i4_bf16 = 0x402010,
    OHWIo64i4      = 0x404000,
    OHWIo64i4_bf16 = 0x404010,

    // i-factor 8
    OHWIo2i8  = 0x800200,
    OHWIo4i8  = 0x800400,
    OHWIo8i8  = 0x800800,
    OHWIo16i8 = 0x801000,
    OHWIo32i8 = 0x802000,
    OHWIo64i8 = 0x804000
};
89 
91 {
93  std::string name = "";
94  bool is_default = false;
95  uint64_t cycle_estimate = 0;
96 
97  KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0)
98  : method(m), name(n), is_default(d), cycle_estimate(c)
99  {
100  }
101  KernelDescription() noexcept
102  {
103  }
104 };
105 
// GemmConfig: user-supplied configuration carrying a method, a filter string
// and inner/outer block sizes (all defaulting to empty / 0 where visible).
// NOTE(review): this block is an incomplete extract. Going by the embedded
// source line numbers, lines 106, 108, 112, 114 and 118 are missing:
//   - 106 is presumably the `struct GemmConfig` header;
//   - 108 is presumably the `method` member that the constructor below
//     initialises;
//   - 114 is the `GemmConfig(GemmMethod method)` signature (listed at that line
//     in this page's cross-reference index);
//   - 118 is presumably the default constructor signature;
//   - 112 cannot be determined from what is visible here.
// Restore these lines from the original header; do not compile this block as is.
107 {
109  std::string filter = "";
110  unsigned int inner_block_size = 0;
111  unsigned int outer_block_size = 0;
113 
// Constructor body for the (missing) signature taking a GemmMethod: stores the
// requested method and leaves every other member at its default.
115  : method(method)
116  {
117  }
// Body of a second, empty constructor whose signature line is missing.
119  {
120  }
121 };
122 
/** An activation type together with two float parameters.
 *
 * The interpretation of param1/param2 for each Type is not visible in this
 * header; consult the kernels that consume it.
 *
 * NOTE(review): the `struct Activation` line and the `type` member were absent
 * from the extracted listing (embedded source lines 123 and 132) and are
 * restored here from the constructor, which is named Activation and
 * initialises `type` from a `Type`.
 */
struct Activation
{
    /** Supported activation kinds. */
    enum class Type
    {
        None,
        ReLU,
        BoundedReLU
    };

    Type  type;
    float param1;
    float param2;

    /** Construct an activation description.
     *
     * Deliberately not `explicit`, so a bare `Type` still converts implicitly.
     *
     * @param type Activation kind (defaults to Type::None).
     * @param p1   First parameter, stored in param1 (defaults to 0.0f).
     * @param p2   Second parameter, stored in param2 (defaults to 0.0f).
     */
    Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f)
        : type(type), param1(p1), param2(p2)
    {
    }
};
141 
142 struct GemmArgs
143 {
144 public:
145  const CPUInfo *_ci;
146  unsigned int _Msize; // num of tiles
147  unsigned int _Nsize; // output channels
148  unsigned int _Ksize; // input channels
149  unsigned int _Ksections;
150  unsigned int _nbatches;
151  unsigned int _nmulti; // n_gemms to be performed
157  const GemmConfig *_cfg;
158 
159  GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N,
160  unsigned int K, unsigned int Ksections, unsigned int nbatches,
161  unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads,
162  bool fixed_format = false, bool fast_mode = false, const GemmConfig *cfg = nullptr)
163  : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads),
164  _fixed_format(fixed_format), _fast_mode(fast_mode), _cfg(cfg)
165  {
166  }
167 };
168 
/** Parameters for a 32-bit requantization output stage: bias, offsets, shift
 * and multiplier values (either one set for the whole layer or one array per
 * channel) and the output clamp range.
 *
 * `per_channel_requant` records which of the two parameter sets is populated;
 * the constructor used decides it. Pointer members are stored as given and not
 * owned by this struct.
 *
 * NOTE(review): the `struct Requantize32` line was absent from the extracted
 * listing (embedded source line 169) and is restored here from the constructor
 * names.
 */
struct Requantize32
{
public:
    const int32_t *bias                     = nullptr;
    size_t         bias_multi_stride        = 0;
    int32_t        a_offset                 = 0;
    int32_t        b_offset                 = 0;
    int32_t        c_offset                 = 0;
    bool           per_channel_requant      = false;
    int32_t        per_layer_left_shift     = 0;
    int32_t        per_layer_right_shift    = 0;
    int32_t        per_layer_mul            = 0;
    const int32_t *per_channel_left_shifts  = nullptr;
    const int32_t *per_channel_right_shifts = nullptr;
    const int32_t *per_channel_muls         = nullptr;
    int32_t        minval                   = 0;
    int32_t        maxval                   = 0;

    Requantize32() = default;

    /** Constructor for per-tensor quantization.
     *
     * The single signed @p requant_shift is split in two: a positive value is
     * stored as per_layer_left_shift (right shift becomes 0), a negative value
     * is stored unchanged (still negative) as per_layer_right_shift (left
     * shift becomes 0). The per-channel pointers stay null.
     *
     * @param bias              Stored in bias.
     * @param bias_multi_stride Stored in bias_multi_stride.
     * @param a_offset          Stored in a_offset.
     * @param b_offset          Stored in b_offset.
     * @param c_offset          Stored in c_offset.
     * @param requant_shift     Signed shift, split as described above.
     * @param requant_mul       Stored in per_layer_mul.
     * @param minv              Stored in minval.
     * @param maxv              Stored in maxval.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max<int32_t>(requant_shift, 0)),
          per_layer_right_shift(std::min<int32_t>(requant_shift, 0)), per_layer_mul(requant_mul), minval(minv), maxval(maxv)
    {
    }

    /** Constructor for per-channel quantization.
     *
     * Sets per_channel_requant to true and stores the three per-channel arrays
     * as given; the per-layer shift/multiplier members keep their default of 0.
     *
     * @param bias                 Stored in bias.
     * @param bias_multi_stride    Stored in bias_multi_stride.
     * @param a_offset             Stored in a_offset.
     * @param b_offset             Stored in b_offset.
     * @param c_offset             Stored in c_offset.
     * @param requant_left_shifts  Stored in per_channel_left_shifts.
     * @param requant_right_shifts Stored in per_channel_right_shifts.
     * @param requant_muls         Stored in per_channel_muls.
     * @param minv                 Stored in minval.
     * @param maxv                 Stored in maxval.
     */
    Requantize32(const int32_t *bias, size_t bias_multi_stride,
                 int32_t a_offset, int32_t b_offset, int32_t c_offset,
                 const int32_t *requant_left_shifts,
                 const int32_t *requant_right_shifts,
                 const int32_t *requant_muls,
                 int32_t minv, int32_t maxv)
        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_left_shifts(requant_left_shifts),
          per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls), minval(minv), maxval(maxv)
    {
    }
};
210 
/** Empty tag type. It is the default `OutputStage` template argument of the
 * low-level API function templates declared in this header. */
struct Nothing
{
};
214 
215 template <typename Top, typename Tret>
216 using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
217 
/* Low level API calls.
 * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
220 
221 /* get_gemm_method(): Given the templated types and provided parameters,
222  * which is the preferred method to implement this GEMM? */
223 template <typename Top, typename Tret, class OutputStage = Nothing>
224 KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
225 
226 template <typename Top, typename Tret, class OutputStage = Nothing>
227 UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
228 
229 template <typename Top, typename Tret, class OutputStage = Nothing>
230 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
231 
232 template <typename Top, typename Tret, class OutputStage = Nothing>
233 bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
234 
235 } // namespace arm_gemm
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os)
const CPUInfo * _ci
Definition: arm_gemm.hpp:145
std::vector< KernelDescription > get_compatible_kernels(const GemmArgs &args, const OutputStage &os)
unsigned int _nmulti
Definition: arm_gemm.hpp:151
unsigned int _Nsize
Definition: arm_gemm.hpp:147
Activation _act
Definition: arm_gemm.hpp:153
GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, unsigned int K, unsigned int Ksections, unsigned int nbatches, unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads, bool fixed_format=false, bool fast_mode=false, const GemmConfig *cfg=nullptr)
Definition: arm_gemm.hpp:159
unsigned int M
bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os)
const GemmConfig * _cfg
Definition: arm_gemm.hpp:157
const CPUInfo & ci
GemmConfig(GemmMethod method)
Definition: arm_gemm.hpp:114
Activation(Type type=Type::None, float p1=0.0f, float p2=0.0f)
Definition: arm_gemm.hpp:136
unsigned int N
const char * name
Requantize32(const int32_t *bias, size_t bias_multi_stride, int32_t a_offset, int32_t b_offset, int32_t c_offset, int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
Definition: arm_gemm.hpp:190
UniqueGemmCommon< Top, Tret > gemm(const GemmArgs &args, const OutputStage &os)
unsigned int _Msize
Definition: arm_gemm.hpp:146
std::unique_ptr< GemmCommon< Top, Tret > > UniqueGemmCommon
Definition: arm_gemm.hpp:216
KernelDescription(GemmMethod m, std::string n, bool d=false, uint64_t c=0)
Definition: arm_gemm.hpp:97
const int32_t * requant_muls
unsigned int _Ksections
Definition: arm_gemm.hpp:149
unsigned int _Ksize
Definition: arm_gemm.hpp:148
unsigned int _nbatches
Definition: arm_gemm.hpp:150
Requantize32(const int32_t *bias, size_t bias_multi_stride, int32_t a_offset, int32_t b_offset, int32_t c_offset, const int32_t *requant_left_shifts, const int32_t *requant_right_shifts, const int32_t *requant_muls, int32_t minv, int32_t maxv)
Definition: arm_gemm.hpp:199
const int32_t * bias
unsigned int K