Compute Library
 21.02
CLGEMMReshapeRHSMatrixKernel.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
32 #include "arm_compute/core/Utils.h"
35 #include "src/core/CL/CLValidate.h"
39 #include "support/StringSupport.h"
40 
41 namespace arm_compute
42 {
44 
45 namespace
46 {
// Validates the RHS reshape parameters together with the input/output tensor infos.
// Each check returns an error Status as soon as its condition is true; an empty
// Status{} is returned when every visible check passes.
// NOTE(review): the embedded listing numbers skip 58, 64, 70 and 71, so some
// statements of the original function are not visible here - confirm against the real source.
47 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const GEMMRHSMatrixInfo &rhs_info)
48 {
    // n0, k0 and h0 must all be non-zero.
49  ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.n0 == 0);
50  ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.k0 == 0);
51  ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.h0 == 0);
    // (x & (x - 1)) is non-zero when x is not a power of two; 3 is allowed explicitly.
    // NOTE(review): n0 == 1 also passes this bit test although the message does not list it.
52  ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.n0 & (rhs_info.n0 - 1)) && rhs_info.n0 != 3), "Only 2,3,4,8,16 are supported for n0");
53  ARM_COMPUTE_RETURN_ERROR_ON_MSG(((rhs_info.k0 & (rhs_info.k0 - 1)) && (rhs_info.k0 != 1) && (rhs_info.k0 != 3)), "Only 1,2,3,4,8,16 are supported for k0");
    // Upper bound of 16 for both block sizes.
54  ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.n0 > 16);
55  ARM_COMPUTE_RETURN_ERROR_ON(rhs_info.k0 > 16);
    // Transposition is rejected when k0 == 1.
56  ARM_COMPUTE_RETURN_ERROR_ON((rhs_info.k0 == 1) && (rhs_info.transpose));
57 
59  ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
60 
61  if(rhs_info.export_to_cl_image)
62  {
    // Describes the reshaped RHS tensor (reshaped shape, 1 channel, input data type).
    // NOTE(review): nothing in the visible listing uses this object; presumably the missing
    // line (64) passes it to an image2d support check - confirm.
63  const TensorInfo tensor_reshaped_info(compute_rhs_reshaped_shape(*input, rhs_info), 1, input->data_type());
65  }
66 
    // An output with total_size() == 0 is skipped; otherwise its shape must match the reshaped shape.
67  if(output->total_size() != 0)
68  {
69  ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_rhs_reshaped_shape(*input, rhs_info));
72  }
73 
74  return Status{};
75 }
76 
// Auto-initialises the output info if it is empty, builds the execution window over the
// input and returns {status, window collapsed along Z}. The status carries a RUNTIME_ERROR
// ("Insufficient Padding!") when update_window_and_padding() reports a change.
// NOTE(review): embedded listing number 97 is skipped, so the body of the
// export_to_cl_image branch is not visible here - confirm against the real source.
77 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const GEMMRHSMatrixInfo &rhs_info)
78 {
    // Window steps: n0 elements along X and k0 elements along Y per iteration.
79  const unsigned int num_elems_processed_per_iteration_x = rhs_info.n0;
80  const unsigned int num_elems_processed_per_iteration_y = rhs_info.k0;
81  bool window_changed = false;
82 
83  // Output auto initialization if not yet initialized
84  auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*input, rhs_info)));
85 
86  // Configure window
87  Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
88 
89  AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
90  AccessWindowStatic output_access(output, 0, 0, output->dimension(0), output->dimension(1));
91 
    // Only the input access pattern is passed to update_window_and_padding(); the output
    // access object is used solely to set the valid region to the full output shape.
92  window_changed = update_window_and_padding(win, input_access);
93  output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
94 
95  if(rhs_info.export_to_cl_image)
96  {
98  }
99 
100  // Collapse along the Z direction
101  // This collapse needs to be here in order to tune the Z dimension of LWS
102  Window collapsed = win.collapse(win, Window::DimZ);
103 
104  Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
105  return std::make_pair(err, collapsed);
106 }
107 } // namespace
108 
// Constructor body: starts with no tensors attached (_input and _output are nullptr).
// NOTE(review): the signature line (embedded number 109) is missing from this listing.
110  : _input(nullptr), _output(nullptr)
111 {
112 }
113 
// Convenience overload body: forwards to the compile-context overload of configure(), using
// the compile context obtained from the CLKernelLibrary singleton.
// NOTE(review): the signature line (embedded number 114) is missing from this listing;
// presumably configure(const ICLTensor *input, ICLTensor *output, const GEMMRHSMatrixInfo &rhs_info) - confirm.
115 {
116  configure(CLKernelLibrary::get().get_compile_context(), input, output, rhs_info);
117 }
118 
// Configures the kernel: validates the arguments, stores the tensors, assembles the OpenCL
// build options, creates the kernel and sets up its execution window.
//   compile_context  Compile context handed to create_kernel().
//   input            Source tensor; must not be null.
//   output           Destination tensor; must not be null.
//   rhs_info         Reshape parameters (n0, k0, h0, transpose, interleave are read here).
119 void CLGEMMReshapeRHSMatrixKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, const GEMMRHSMatrixInfo &rhs_info)
120 {
121  ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
122 
123  // Perform validate step
124  ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), rhs_info));
125 
126  _input = input;
127  _output = output;
128 
129  // Create build options
    // N0/K0/H0 are always defined; TRANSPOSE and INTERLEAVE only when the matching flag is set.
130  CLBuildOptions build_opts;
131  build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
132  build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
133  build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
134  build_opts.add_option_if(rhs_info.transpose, "-DTRANSPOSE");
135  build_opts.add_option_if(rhs_info.interleave, "-DINTERLEAVE");
    // SRC_HEIGHT is dimension 1 of the input; DATA_TYPE is chosen from the element size in
    // bytes (an unsigned type of that size), not from the tensor's own data type.
136  build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1)));
137  build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(input->info()->element_size()));
138 
    // Kernel name is "gemm_reshape_rhs_matrix_t" when transposing, "gemm_reshape_rhs_matrix_nt" otherwise.
139  std::string kernel_name("gemm_reshape_rhs_matrix_");
140  kernel_name += rhs_info.transpose ? "t" : "nt";
141 
142  // Create kernel
143  _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
144 
145  // Configure kernel window
146  auto win_config = validate_and_configure_window(input->info(), output->info(), rhs_info);
147  ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
148  ICLKernel::configure_internal(win_config.second);
149 }
150 
// Static validation body: runs the argument checks, then the window configuration on clones
// of the tensor infos (the caller's objects are not passed in), returning the first error
// encountered or an empty Status{}.
// NOTE(review): the signature line (embedded number 151) is missing from this listing;
// presumably CLGEMMReshapeRHSMatrixKernel::validate(input, output, rhs_info) - confirm.
152 {
153  ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, rhs_info));
154  ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), rhs_info).first);
155 
156  return Status{};
157 }
158 
// Enqueues the kernel on the given command queue once per 3D slice of the window.
// NOTE(review): embedded listing numbers 161, 162 and 164 are skipped, so the statements
// that precede the loop - including the declaration of `slice` - are not visible here;
// presumably `slice` comes from window.first_slice_window_3D() - confirm against the real source.
159 void CLGEMMReshapeRHSMatrixKernel::run(const Window &window, cl::CommandQueue &queue)
160 {
163 
165 
166  do
167  {
    // Argument index restarts at 0 for every slice; input is bound first, then output.
168  unsigned int idx = 0;
169  add_3D_tensor_argument(idx, _input, slice);
170  add_3D_tensor_argument(idx, _output, slice);
171  enqueue(queue, *this, slice, lws_hint());
172  }
    // Continue for as long as the window reports another 3D slice.
173  while(window.slide_window_slice_3D(slice));
174 }
175 } // namespace arm_compute
Window calculate_max_window(const ValidRegion &valid_region, const Steps &steps, bool skip_border, BorderSize border_size)
#define ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(tensor)
Definition: CLValidate.h:35
const Window & window() const
The maximum window the kernel can be executed on.
Definition: IKernel.cpp:28
void enqueue(IGCKernel &kernel, const Window &window, const gles::NDRange &lws=gles::NDRange(1U, 1U, 1U))
Add the kernel to the command queue with the given window.
Definition: IGCKernel.cpp:41
virtual size_t dimension(size_t index) const =0
Return the size of the requested dimension.
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(...)
Definition: Validate.h:610
const StringSet & options() const
Gets the current options list set.
cl::NDRange lws_hint() const
Return the Local-Workgroup-Size hint.
Definition: ICLKernel.h:276
#define ARM_COMPUTE_RETURN_ON_ERROR(status)
Checks if a status contains an error and returns it.
Definition: Error.h:204
std::string to_string(T &&value)
Convert integer and float values to string.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix...
unsigned int h0
Number of horizontal blocks of size (k0xn0) stored on the same output row.
Definition: Types.h:1992
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
Store the tensor's metadata.
Definition: ITensorInfo.h:40
#define ARM_COMPUTE_ERROR_THROW_ON(status)
Definition: Error.h:455
Status class.
Definition: Error.h:52
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:296
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(...)
Definition: Validate.h:288
void run(const Window &window, cl::CommandQueue &queue) override
Enqueue the OpenCL kernel to process the given window on the passed OpenCL command queue...
void add_3D_tensor_argument(unsigned int &idx, const ICLTensor *tensor, const Window &window)
Add the passed 3D tensor's parameters to the object's kernel's arguments starting from the index idx...
Definition: ICLKernel.h:172
bool transpose
True if the (k0xn0) block has to be transposed before being stored.
Definition: Types.h:1993
Copyright (c) 2017-2021 Arm Limited.
void add_option(std::string option)
Adds option to the existing build option list.
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
Definition: Types.h:1991
void update_padding_for_cl_image(ITensorInfo *tensor)
Update padding required to export the OpenCL buffer to OpenCL image2d.
cl::Kernel create_kernel(const CLCompileContext &ctx, const std::string &kernel_name, const std::set< std::string > &build_opts=std::set< std::string >())
Creates an opencl kernel using a compile context.
Definition: CLHelpers.cpp:403
bool update_window_and_padding(Window &win, Ts &&... patterns)
Update window and padding size for each of the access patterns.
Definition: WindowHelpers.h:46
GEMM RHS (Right Hand Side) matrix information.
Definition: Types.h:1983
unsigned int n0
Number of columns processed by the matrix multiplication.
Definition: Types.h:1990
std::string kernel_name
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
bool auto_init_if_empty(ITensorInfo &info, const TensorShape &shape, int num_channels, DataType data_type, QuantizationInfo quantization_info=QuantizationInfo())
Auto initialize the tensor info (shape, number of channels and data type) if the current assignment i...
virtual std::unique_ptr< T > clone() const =0
Provide a clone of the current object of class T.
virtual ITensorInfo * info() const =0
Interface to be implemented by the child class to return the tensor's metadata.
void add_option_if(bool cond, std::string option)
Adds the option if the given condition is true.
virtual size_t element_size() const =0
Element size in bytes calculated as data_size() * num_channels()
bool slide_window_slice_3D(Window &slice) const
Slide the passed 3D window slice.
Definition: Window.h:335
#define ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(k)
Definition: Validate.h:941
CLCompileContext class.
Interface for OpenCL tensor.
Definition: ICLTensor.h:42
#define ARM_COMPUTE_CREATE_ERROR(error_code, msg)
Creates an error with a given message.
Definition: Error.h:159
static constexpr size_t DimZ
Alias for dimension 2 also known as Z dimension.
Definition: Window.h:47
Manages all the OpenCL kernels compilation and caching, provides accessors for the OpenCL Context...
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Definition: Validate.h:545
std::string get_cl_unsigned_type_from_element_size(size_t element_size)
Translates the element size to an unsigned integer data type.
Definition: CLHelpers.cpp:103
void configure(const ICLTensor *input, ICLTensor *output, const GEMMRHSMatrixInfo &rhs_info)
Initialise the kernel's input and output.
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const GEMMLowpOutputStageInfo *output_stage)
Wrapper to configure the Khronos OpenCL C++ header.
bool interleave
True if the h0 (k0xn0) blocks have to be interleaved in the output row.
Definition: Types.h:1994
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Definition: Error.h:244
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:161
Window first_slice_window_3D() const
First 3D slice of the window.
Definition: Window.h:291
Describe a multidimensional execution window.
Definition: Window.h:39
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const GEMMRHSMatrixInfo &rhs_info)
Static function to check if given info will lead to a valid configuration of CLGEMMReshapeRHSMatrixKe...
#define ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(f, s)
Definition: Validate.h:205
SimpleTensor< T > slice(const SimpleTensor< T > &src, Coordinates starts, Coordinates ends)