38 #include "compute_kernel_writer/include/ckw/KernelWriter.h"
44 namespace experimental
46 namespace dynamic_fusion
56 "The source data type must be a floating-point data type");
72 const auto dst_h =
static_cast<int32_t
>(_dst->
dimension(1));
75 auto const_dst_h_i32 = writer->declare_constant_tile(ckw::ConstantData({{dst_h}}, ckw::DataType::Int32));
76 auto const_pos_1_i32 = writer->declare_constant_tile(ckw::ConstantData({{1}}, ckw::DataType::Int32));
77 auto const_0_i32 = writer->declare_constant_tile(ckw::ConstantData({{0}}, ckw::DataType::Int32));
91 int32_t dst_n0_partial = -1;
92 int32_t dst_m0_partial = -1;
95 int32_t dst_shift_back = -1;
103 dst_n0 = root_window.
x().
step();
104 dst_m0 = root_window.y().step();
105 dst_n0_partial = _dst->
dimension(0) % dst_n0;
107 dst_shift_back = (dst_n0 - dst_n0_partial) % dst_n0;
109 ckw::TensorSampler sampler_dst;
110 sampler_dst.format(ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1);
111 if (dst_n0_partial == 0)
117 sampler_dst.address_mode_x(ckw::TensorSamplerAddressModeX::OverlappingMin);
120 if (dst_m0_partial == 0)
126 sampler_dst.address_mode_y(ckw::TensorSamplerAddressModeY::ClampToBorderMaxOnly);
130 sampler_dst.storage(ckw::TensorStorageType::BufferUint8Ptr);
134 auto tile_dst = writer->declare_tile(
"dst", ckw::TileInfo(dst_dt, dst_m0, dst_n0));
137 dst->init_virtual_tensor(tile_dst, sampler_dst);
144 dst_n0 =
dst->tile().tile_info().width();
145 dst_m0 =
dst->tile().tile_info().height();
146 dst_n0_partial = _dst->
dimension(0) % dst_n0;
148 ckw::TensorSampler sampler_dst =
dst->tensor_sampler();
150 if (sampler_dst.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
154 else if (sampler_dst.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
156 dst_m0_partial = _dst->
dimension(1) % dst_m0;
160 dst_shift_back = (dst_n0 - dst_n0_partial) % dst_n0;
163 const auto &tile_dst =
dst->tile();
169 auto const_dst_n0_i32 = writer->declare_constant_tile(ckw::ConstantData({{dst_n0}}, ckw::DataType::Int32));
170 auto const_dst_m0_i32 = writer->declare_constant_tile(ckw::ConstantData({{dst_m0}}, ckw::DataType::Int32));
171 auto const_dst_shift_back_n0_i32 =
172 writer->declare_constant_tile(ckw::ConstantData({{dst_shift_back}}, ckw::DataType::Int32));
177 if (!
src->has_tile())
180 ckw::TensorSampler sampler_src =
dst->tensor_sampler();
182 auto tile_gid_0 = writer->declare_tile(
"gid_0", ckw::TileInfo(ckw::DataType::Int32));
183 auto tile_gid_1 = writer->declare_tile(
"gid_1", ckw::TileInfo(ckw::DataType::Int32));
184 auto tile_gid_2 = writer->declare_tile(
"gid_2", ckw::TileInfo(ckw::DataType::Int32));
186 writer->op_get_global_id(tile_gid_0, 0);
187 writer->op_get_global_id(tile_gid_1, 1);
188 writer->op_get_global_id(tile_gid_2, 2);
190 auto tile_cout0 = writer->declare_tile(
"cout0", ckw::TileInfo(ckw::DataType::Int32));
191 auto tile_mout0 = writer->declare_tile(
"mout0", ckw::TileInfo(ckw::DataType::Int32));
192 auto tile_mout1 = writer->declare_tile(
"mout1", ckw::TileInfo(ckw::DataType::Int32));
193 auto tile_bout0 = writer->declare_tile(
"bout0", ckw::TileInfo(ckw::DataType::Int32));
197 const_dst_shift_back_n0_i32, const_0_i32);
201 if (sampler_src.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
203 writer->op_assign(tile_mout1, const_0_i32);
206 else if (sampler_src.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
208 writer->op_binary(tile_mout1, ckw::BinaryOp::Mod, tile_gid_2, const_dst_h_i32);
209 writer->op_binary(tile_bout0, ckw::BinaryOp::Div, tile_gid_2, const_dst_h_i32);
212 auto tile_src = writer->declare_tile(
"src", ckw::TileInfo(src_dt, dst_m0, dst_n0));
214 writer->op_load(tile_src,
src->tensor(), sampler_src, tile_cout0, tile_mout0, tile_mout1, tile_bout0);
217 src->init_virtual_tensor(tile_src, sampler_src);
220 auto tile_src =
src->tile();
234 writer->op_cast(tile_dst, tile_src, convert_policy);