40 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
configure_lhs_rhs_info(
unsigned int m,
unsigned int n,
unsigned int m0,
unsigned int n0,
unsigned int k0,
unsigned int v0,
unsigned int h0,
41 bool lhs_interleave,
bool rhs_interleave,
bool lhs_transpose,
bool rhs_transpose,
bool export_to_cl_image)
44 v0 = std::max(std::min(static_cast<int>(m / m0), static_cast<int>(v0)), static_cast<int>(1));
45 h0 = std::max(std::min(static_cast<int>(n / n0), static_cast<int>(h0)), static_cast<int>(1));
48 const GEMMRHSMatrixInfo rhs_info(n0, k0, h0, rhs_transpose, rhs_interleave, export_to_cl_image);
50 return std::make_pair(lhs_info, rhs_info);
53 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
select_lhs_rhs_info(std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_img,
54 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> info_buf,
73 constexpr
unsigned int num_floats_per_pixel = 4;
79 if(pixel_alignment == 0)
84 const unsigned int row_pitch_alignment = pixel_alignment * num_floats_per_pixel;
85 const unsigned int round_up_width = ((stride_y_in_elements + row_pitch_alignment - 1) / row_pitch_alignment) * row_pitch_alignment;
86 const unsigned int padding = round_up_width - stride_y_in_elements;
bool image2d_from_buffer_supported(const cl::Device &device)
Helper function to check whether the cl_khr_image2d_from_buffer extension is supported.
Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info)
Utility function to validate the image2d OpenCL object support on the RHS reshaped matrix.
1 channel, 1 F32 per channel
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true, an error message is printed and an exception is thrown.
static CLKernelLibrary & get()
Access the KernelLibrary singleton.
GEMM LHS (Left Hand Side) matrix information.
Store the tensor's metadata.
bool export_to_cl_image
True if the reshaped rhs has to be exported to cl_image.
Copyright (c) 2017-2021 Arm Limited.
1 channel, 1 F16 per channel
unsigned int k0
Number of partial accumulations performed by the matrix multiplication.
void update_padding_for_cl_image(ITensorInfo *tensor)
Update padding required to export the OpenCL buffer to OpenCL image2d.
GEMM RHS (Right Hand Side) matrix information.
virtual const TensorShape & tensor_shape() const =0
Size for each dimension of the tensor.
unsigned int n0
Number of columns processed by the matrix multiplication.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
TensorShape compute_rhs_reshaped_shape(const ITensorInfo &a, const GEMMRHSMatrixInfo &rhs_info)
Calculate the Right Hand Side matrix reshaped shape.
virtual size_t element_size() const =0
Element size in bytes, calculated as data_size() * num_channels().
BorderSize PaddingSize
Container for 2D padding size.
size_t get_cl_image_pitch_alignment(const cl::Device &device)
Helper function to get the cl_image pitch alignment in pixels.
Manages the compilation and caching of all OpenCL kernels, and provides accessors for the OpenCL context.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > configure_lhs_rhs_info(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image)
Configure GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
Wrapper to configure the Khronos OpenCL C++ header.
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Store the tensor's metadata.
virtual const Strides & strides_in_bytes() const =0
The strides in bytes for accessing each dimension of the tensor.
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(t,...)
DataType
Available data types.
virtual bool extend_padding(const PaddingSize &padding)=0
Update the offset to the first element, the strides and the total size.
std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > select_lhs_rhs_info(std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_img, std::pair< GEMMLHSMatrixInfo, GEMMRHSMatrixInfo > info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
Select GEMMLHSMatrixInfo and GEMMRHSMatrixInfo.
const cl::Device & get_device()
Gets the CL device for which the programs are created.