42 #include "compute_kernel_writer/include/ckw/KernelWriter.h"
43 #include "compute_kernel_writer/include/ckw/types/ConstantData.h"
44 #include "compute_kernel_writer/include/ckw/types/TensorSamplerTypes.h"
50 namespace experimental
52 namespace dynamic_fusion
79 const auto dst_h =
static_cast<int32_t
>(_dst->
dimension(1));
82 auto const_dst_h_i32 = writer->declare_constant_tile(ckw::ConstantData({{dst_h}}, ckw::DataType::Int32));
83 auto const_pos_1_i32 = writer->declare_constant_tile(ckw::ConstantData({{1}}, ckw::DataType::Int32));
84 auto const_0_i32 = writer->declare_constant_tile(ckw::ConstantData({{0}}, ckw::DataType::Int32));
98 int32_t dst_n0_partial = -1;
99 int32_t dst_m0_partial = -1;
101 if (!
dst->has_tile())
107 dst_n0 = root_window.
x().
step();
108 dst_m0 = root_window.y().step();
109 dst_n0_partial = _dst->
dimension(0) % dst_n0;
112 ckw::TensorSampler sampler_dst;
113 sampler_dst.format(ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1);
114 if (dst_n0_partial == 0)
120 sampler_dst.address_mode_x(ckw::TensorSamplerAddressModeX::OverlappingMin);
123 if (dst_m0_partial == 0)
129 sampler_dst.address_mode_y(ckw::TensorSamplerAddressModeY::ClampToBorderMaxOnly);
132 sampler_dst.storage(ckw::TensorStorageType::BufferUint8Ptr);
136 auto tile_dst = writer->declare_tile(
"dst", ckw::TileInfo(dst_dt, dst_m0, dst_n0));
139 dst->init_virtual_tensor(tile_dst, sampler_dst);
144 dst_n0 =
dst->tile().tile_info().width();
145 dst_m0 =
dst->tile().tile_info().height();
152 const auto &tile_dst =
dst->tile();
167 ckw::TensorSampler sampler_lhs =
dst->tensor_sampler();
169 bool broadcast_x =
false;
170 bool broadcast_y =
false;
172 int32_t lhs_n0 = dst_n0;
173 int32_t lhs_m0 = dst_m0;
184 if (sampler_lhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
192 else if (sampler_lhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
201 const int32_t lhs_partial_n0 = _lhs->
dimension(0) % lhs_n0;
202 const int32_t lhs_shift_back = (lhs_n0 - lhs_partial_n0) % lhs_n0;
205 auto const_lhs_n0_i32 = writer->declare_constant_tile(ckw::ConstantData({{lhs_n0}}, ckw::DataType::Int32));
206 auto const_lhs_m0_i32 = writer->declare_constant_tile(ckw::ConstantData({{lhs_m0}}, ckw::DataType::Int32));
207 auto const_lhs_shift_back_n0_i32 =
208 writer->declare_constant_tile(ckw::ConstantData({{lhs_shift_back}}, ckw::DataType::Int32));
210 auto tile_gid_0 = writer->declare_tile(
"gid_0_lhs", ckw::TileInfo(ckw::DataType::Int32));
211 auto tile_gid_1 = writer->declare_tile(
"gid_1_lhs", ckw::TileInfo(ckw::DataType::Int32));
212 auto tile_gid_2 = writer->declare_tile(
"gid_2_lhs", ckw::TileInfo(ckw::DataType::Int32));
214 writer->op_get_global_id(tile_gid_0, 0);
215 writer->op_get_global_id(tile_gid_1, 1);
216 writer->op_get_global_id(tile_gid_2, 2);
218 auto tile_cout0 = writer->declare_tile(
"cout0_lhs", ckw::TileInfo(ckw::DataType::Int32));
220 writer->declare_tile(
"mout0_lhs", ckw::TileInfo(ckw::DataType::Int32));
221 auto tile_mout1 = writer->declare_tile(
"mout1_lhs", ckw::TileInfo(ckw::DataType::Int32));
222 auto tile_bout0 = writer->declare_tile(
"bout0_lhs", ckw::TileInfo(ckw::DataType::Int32));
228 const_lhs_shift_back_n0_i32, const_0_i32);
232 writer->op_assign(tile_cout0, const_0_i32);
241 writer->op_assign(tile_mout0, const_0_i32);
245 if (sampler_lhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
247 writer->op_assign(tile_mout1, const_0_i32);
250 else if (sampler_lhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
255 writer->op_binary(tile_mout1, ckw::BinaryOp::Mod, tile_gid_2, const_dst_h_i32);
261 writer->op_assign(tile_mout1, const_0_i32);
264 writer->op_binary(tile_bout0, ckw::BinaryOp::Div, tile_gid_2, const_dst_h_i32);
268 auto tile_lhs = writer->declare_tile(
"lhs", ckw::TileInfo(lhs_dt, lhs_m0, lhs_n0));
270 writer->op_load(tile_lhs, lhs->
tensor(), sampler_lhs, tile_cout0, tile_mout0, tile_mout1, tile_bout0);
281 ckw::TensorSampler sampler_rhs =
dst->tensor_sampler();
283 bool broadcast_x =
false;
284 bool broadcast_y =
false;
286 int32_t rhs_n0 = dst_n0;
287 int32_t rhs_m0 = dst_m0;
298 if (sampler_rhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
306 else if (sampler_rhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
315 const int32_t rhs_partial_n0 = _rhs->
dimension(0) % rhs_n0;
316 const int32_t rhs_shift_back = (rhs_n0 - rhs_partial_n0) % rhs_n0;
319 auto const_rhs_n0_i32 = writer->declare_constant_tile(ckw::ConstantData({{rhs_n0}}, ckw::DataType::Int32));
320 auto const_rhs_m0_i32 = writer->declare_constant_tile(ckw::ConstantData({{rhs_m0}}, ckw::DataType::Int32));
321 auto const_rhs_shift_back_n0_i32 =
322 writer->declare_constant_tile(ckw::ConstantData({{rhs_shift_back}}, ckw::DataType::Int32));
324 auto tile_gid_0 = writer->declare_tile(
"gid_0_rhs", ckw::TileInfo(ckw::DataType::Int32));
325 auto tile_gid_1 = writer->declare_tile(
"gid_1_rhs", ckw::TileInfo(ckw::DataType::Int32));
326 auto tile_gid_2 = writer->declare_tile(
"gid_2_rhs", ckw::TileInfo(ckw::DataType::Int32));
328 writer->op_get_global_id(tile_gid_0, 0);
329 writer->op_get_global_id(tile_gid_1, 1);
330 writer->op_get_global_id(tile_gid_2, 2);
332 auto tile_cout0 = writer->declare_tile(
"cout0_rhs", ckw::TileInfo(ckw::DataType::Int32));
334 writer->declare_tile(
"mout0_rhs", ckw::TileInfo(ckw::DataType::Int32));
335 auto tile_mout1 = writer->declare_tile(
"mout1_rhs", ckw::TileInfo(ckw::DataType::Int32));
336 auto tile_bout0 = writer->declare_tile(
"bout0_rhs", ckw::TileInfo(ckw::DataType::Int32));
342 const_rhs_shift_back_n0_i32, const_0_i32);
346 writer->op_assign(tile_cout0, const_0_i32);
355 writer->op_assign(tile_mout0, const_0_i32);
359 if (sampler_rhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
361 writer->op_assign(tile_mout1, const_0_i32);
364 else if (sampler_rhs.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
367 const auto src_w =
static_cast<int32_t
>(_rhs->
dimension(1));
368 auto const_src_w = writer->declare_constant_tile(ckw::ConstantData({{src_w}}, ckw::DataType::Int32));
371 writer->op_binary(tile_mout1, ckw::BinaryOp::Mod, tile_mout1, const_src_w);
377 writer->op_assign(tile_mout1, const_0_i32);
380 writer->op_binary(tile_bout0, ckw::BinaryOp::Div, tile_mout1, const_src_w);
384 auto tile_rhs = writer->declare_tile(
"rhs", ckw::TileInfo(rhs_dt, rhs_m0, rhs_n0));
386 writer->op_load(tile_rhs, rhs->
tensor(), sampler_rhs, tile_cout0, tile_mout0, tile_mout1, tile_bout0);
392 const auto &tile_lhs = lhs->
tile();
393 const auto &tile_rhs = rhs->
tile();
399 writer->op_binary(tile_dst,
to_ckw(_attributes), tile_lhs, tile_rhs);
412 constexpr uint32_t vector_size_byte_opencl = 16;
423 const std::vector<std::string> build_params = {
424 "elementwise_binary",
430 return join(build_params,
"_");