38 #include "compute_kernel_writer/include/ckw/KernelWriter.h"
43 namespace experimental
45 namespace dynamic_fusion
// Single-value constant tiles used by the statements below:
//   - k, k0, 0 and 1 as Int32
//   - 0.0f in the destination data type (dst_dt)
//   - the host-side precomputed difference k - k0 as Int32, which serves as the
//     bound of the k0-stepped loop further down.
80 auto const_k_i32 = writer->declare_constant_tile(ckw::ConstantData({{k}}, ckw::DataType::Int32));
81 auto const_k0_i32 = writer->declare_constant_tile(ckw::ConstantData({{k0}}, ckw::DataType::Int32));
82 auto const_0_i32 = writer->declare_constant_tile(ckw::ConstantData({{0}}, ckw::DataType::Int32));
83 auto const_pos_1_i32 = writer->declare_constant_tile(ckw::ConstantData({{1}}, ckw::DataType::Int32));
84 auto const_0_fp = writer->declare_constant_tile(ckw::ConstantData({{0.0f}}, dst_dt));
85 auto const_k_minus_k0_i32 = writer->declare_constant_tile(ckw::ConstantData({{k - k0}}, ckw::DataType::Int32));
// Destination block sizes are taken from the x and y steps of root_window.
// NOTE(review): both values are used as divisors below, so they are assumed to be non-zero — confirm.
96 const int32_t dst_n0 = root_window.
x().
step();
97 const int32_t dst_m0 = root_window.y().step();
// Remainders of destination dimensions 0 and 1 with respect to the block sizes;
// a value of 0 means the dimension is an exact multiple of the block size.
100 const int32_t dst_n0_partial = _dst->
dimension(0) % dst_n0;
101 const int32_t dst_m0_partial = _dst->
dimension(1) % dst_m0;
// Evaluates to 0 when dst_n0_partial is 0, and to dst_n0 - dst_n0_partial otherwise.
104 const int32_t dst_shift_back = (dst_n0 - dst_n0_partial) % dst_n0;
// Destination sampler: Dim0_Dim1_Dim2 format, BufferUint8Ptr storage.
// The x and y address modes are selected depending on whether dst_n0_partial and
// dst_m0_partial are zero; OverlappingMin (x) and ClampToBorderMaxOnly (y) are among
// the modes that get set.
106 ckw::TensorSampler sampler_dst;
107 sampler_dst.format(ckw::TensorSamplerFormat::Dim0_Dim1_Dim2);
108 if (dst_n0_partial == 0)
114 sampler_dst.address_mode_x(ckw::TensorSamplerAddressModeX::OverlappingMin);
117 if (dst_m0_partial == 0)
123 sampler_dst.address_mode_y(ckw::TensorSamplerAddressModeY::ClampToBorderMaxOnly);
127 sampler_dst.storage(ckw::TensorStorageType::BufferUint8Ptr);
// Destination tile: data type dst_dt, TileInfo dimensions (dst_m0, dst_n0).
130 auto tile_dst = writer->declare_tile(
"dst", ckw::TileInfo(dst_dt, dst_m0, dst_n0));
// Start from zero (const_0_fp holds 0.0f in dst_dt).
133 writer->op_assign(tile_dst, const_0_fp);
// Hand the tile and its sampler to dst as its virtual tensor.
136 dst->init_virtual_tensor(tile_dst, sampler_dst);
// Int32 constant tiles holding the destination block sizes (dst_n0, dst_m0)
// and the host-side computed dst_shift_back value.
142 auto const_dst_n0_i32 = writer->declare_constant_tile(ckw::ConstantData({{dst_n0}}, ckw::DataType::Int32));
143 auto const_dst_m0_i32 = writer->declare_constant_tile(ckw::ConstantData({{dst_m0}}, ckw::DataType::Int32));
144 auto const_shift_back_dst_n0_i32 =
145 writer->declare_constant_tile(ckw::ConstantData({{dst_shift_back}}, ckw::DataType::Int32));
// Samplers for the two inputs. Both use the same format (Dim0_Dim1_Dim2) and the same
// storage type (BufferUint8Ptr) as the destination sampler.
153 ckw::TensorSampler sampler_lhs;
154 sampler_lhs.format(ckw::TensorSamplerFormat::Dim0_Dim1_Dim2);
158 sampler_lhs.storage(ckw::TensorStorageType::BufferUint8Ptr);
161 ckw::TensorSampler sampler_rhs;
162 sampler_rhs.format(ckw::TensorSamplerFormat::Dim0_Dim1_Dim2);
166 sampler_rhs.storage(ckw::TensorStorageType::BufferUint8Ptr);
// Int32 tiles receiving the global ids of dimensions 0, 1 and 2.
177 auto tile_gid_0 = writer->declare_tile(
"gid_0", ckw::TileInfo(ckw::DataType::Int32));
178 auto tile_gid_1 = writer->declare_tile(
"gid_1", ckw::TileInfo(ckw::DataType::Int32));
179 auto tile_gid_2 = writer->declare_tile(
"gid_2", ckw::TileInfo(ckw::DataType::Int32));
181 writer->op_get_global_id(tile_gid_0, 0);
182 writer->op_get_global_id(tile_gid_1, 1);
183 writer->op_get_global_id(tile_gid_2, 2);
// Int32 index tiles idx_n, idx_m and idx_b; they are later passed as coordinates to the
// lhs/rhs loads (idx_m for lhs, idx_n for rhs, idx_b for both).
185 auto tile_idx_n = writer->declare_tile(
"idx_n", ckw::TileInfo(ckw::DataType::Int32));
186 auto tile_idx_m = writer->declare_tile(
"idx_m", ckw::TileInfo(ckw::DataType::Int32));
187 auto tile_idx_b = writer->declare_tile(
"idx_b", ckw::TileInfo(ckw::DataType::Int32));
191 const_shift_back_dst_n0_i32, const_0_i32);
// Loop index along k, initialised to 0 before the loops that follow.
198 auto tile_idx_k = writer->declare_tile(
"idx_k", ckw::TileInfo(ckw::DataType::Int32));
200 writer->op_assign(tile_idx_k, const_0_i32);
// k0-stepped loop: runs while idx_k <= k - k0 and increments idx_k by k0 per iteration.
// Since idx_k starts at 0, each iteration has a full k0-wide slice available.
203 writer->op_for_loop(tile_idx_k, ckw::BinaryOp::LessEqual, const_k_minus_k0_i32, tile_idx_k, ckw::AssignmentOp::Increment, const_k0_i32,
// Per-iteration input tiles: lhs with TileInfo dimensions (dst_m0, k0) and
// rhs with (dst_n0, k0), each in its own tensor's data type.
206 auto tile_lhs = writer->declare_tile(
"lhs", ckw::TileInfo(
to_ckw(_lhs->
data_type()), dst_m0, k0));
207 auto tile_rhs = writer->declare_tile(
"rhs", ckw::TileInfo(
to_ckw(_rhs->
data_type()), dst_n0, k0));
// Zero both tiles before loading.
// NOTE(review): const_0_fp is declared with dst_dt while these tiles use the lhs/rhs
// data types — confirm op_assign accepts the differing type.
208 writer->op_assign(tile_lhs, const_0_fp);
209 writer->op_assign(tile_rhs, const_0_fp);
// Load the current slice: coordinates are passed in the order idx_k, idx_m (lhs) or
// idx_n (rhs), idx_b, and the constant 0.
211 writer->op_load(tile_lhs, lhs->
tensor(), sampler_lhs, tile_idx_k, tile_idx_m, tile_idx_b, const_0_i32);
212 writer->op_load(tile_rhs, rhs->
tensor(), sampler_rhs, tile_idx_k, tile_idx_n, tile_idx_b, const_0_i32);
// MatMul_Nt_T with tile_dst as destination.
// NOTE(review): presumably accumulates into tile_dst (zero-initialised before the loop)
// with lhs not transposed and rhs transposed — confirm against the CKW definition.
214 writer->op_binary(tile_dst, ckw::BinaryOp::MatMul_Nt_T, tile_lhs, tile_rhs);
// Unit-stepped loop over the same idx_k tile: runs while idx_k < k and increments idx_k by 1.
// idx_k is not reset here, so it continues from the value left by the k0-stepped loop.
221 writer->op_for_loop(tile_idx_k, ckw::BinaryOp::Less, const_k_i32, tile_idx_k, ckw::AssignmentOp::Increment, const_pos_1_i32, [&]()
// Same input tiles as in the k0-stepped loop, but with the last TileInfo dimension set to 1:
// lhs is (dst_m0, 1) and rhs is (dst_n0, 1).
223 auto tile_lhs = writer->declare_tile(
"lhs", ckw::TileInfo(
to_ckw(_lhs->
data_type()), dst_m0, 1));
224 auto tile_rhs = writer->declare_tile(
"rhs", ckw::TileInfo(
to_ckw(_rhs->
data_type()), dst_n0, 1));
// Zero both tiles before loading.
// NOTE(review): const_0_fp is declared with dst_dt while these tiles use the lhs/rhs
// data types — confirm op_assign accepts the differing type.
225 writer->op_assign(tile_lhs, const_0_fp);
226 writer->op_assign(tile_rhs, const_0_fp);
// Load at coordinates idx_k, idx_m (lhs) or idx_n (rhs), idx_b, and the constant 0.
228 writer->op_load(tile_lhs, lhs->
tensor(), sampler_lhs, tile_idx_k, tile_idx_m, tile_idx_b, const_0_i32);
229 writer->op_load(tile_rhs, rhs->
tensor(), sampler_rhs, tile_idx_k, tile_idx_n, tile_idx_b, const_0_i32);
// Same MatMul_Nt_T operation into tile_dst as in the k0-stepped loop.
231 writer->op_binary(tile_dst, ckw::BinaryOp::MatMul_Nt_T, tile_lhs, tile_rhs);
// Flag read from the operator attributes.
243 const bool adj_lhs = _attributes.
adj_lhs();
// m0 starts from the configured _settings.m0(): when adj_lhs is set it is passed through
// adjust_vec_size(..., m); otherwise it is capped at m with std::min.
245 const int32_t m0 = adj_lhs ?
adjust_vec_size(_settings.
m0(), m) : std::min(_settings.
m0(), m);