24 #if defined(__aarch64__)
42 switch (_input->info()->data_type())
46 const int ksize_rows_elements = _xmax * _ksize;
47 const int jump_rows = ksize_rows_elements * window.x().start();
48 const int k_start = window.x().start() * _ksize;
49 const int k_end = std::min(window.x().end() * _ksize, _kmax);
50 const int stride = _kmax;
57 switch (_output->info()->data_type())
60 arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(
61 reinterpret_cast<float *
>(_output->buffer()) + jump_rows,
62 reinterpret_cast<float *
>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
65 arm_gemm::Transform<4, 4, true, arm_gemm::VLType::None>(
66 reinterpret_cast<bfloat16 *
>(_output->buffer()) + jump_rows,
67 reinterpret_cast<float *
>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
74 #if defined(ARM_COMPUTE_ENABLE_SVE)
77 switch (_output->info()->data_type())
80 arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(
81 reinterpret_cast<float *
>(_output->buffer()) + jump_rows,
82 reinterpret_cast<float *
>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
85 arm_gemm::Transform<2, 4, true, arm_gemm::VLType::SVE>(
86 reinterpret_cast<bfloat16 *
>(_output->buffer()) + jump_rows,
87 reinterpret_cast<float *
>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
109 NEReorderKernel::NEReorderKernel()
132 _input_wf = input_wf;
133 _output_wf = output_wf;
136 auto dims =
input->info()->num_dimensions();
141 _xmax =
input->info()->dimension(0);
142 _kmax =
input->info()->dimension(1);
147 _xmax =
input->info()->dimension(2);
148 _kmax =
input->info()->dimension(3);
163 #if defined(ARM_COMPUTE_ENABLE_SVE)
167 window_size = _kmax / _ksize;
174 window_size = _kmax / _ksize;
183 if (_kmax % _ksize != 0)
188 win.set(
Window::DimX, Window::Dimension(0, window_size, 1));
190 INEKernel::configure(win);
194 const ITensorInfo *output,
200 if (output->tensor_shape().total_size() != 0)
211 auto dims = output->num_dimensions();
216 input_x_dim =
input->dimension(0);
217 input_k_dim =
input->dimension(1);
218 output_x_dim = output->dimension(0);
219 output_k_dim = output->dimension(1);
224 input_x_dim =
input->dimension(2);
225 input_k_dim =
input->dimension(3);
226 output_x_dim = output->dimension(2);
227 output_k_dim = output->dimension(3);
239 #if defined(ARM_COMPUTE_ENABLE_SVE)
259 int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_k_dim, ksize);
269 #endif // defined(__aarch64__)