38 #include "compute_kernel_writer/include/ckw/KernelWriter.h"
44 namespace experimental
46 namespace dynamic_fusion
72 const auto dst_h =
static_cast<int32_t
>(_dst->
dimension(1));
76 auto const_dst_h_i32 = writer->declare_constant_tile(ckw::ConstantData({{dst_h}}, ckw::DataType::Int32));
77 auto const_pos_1_i32 = writer->declare_constant_tile(ckw::ConstantData({{1}}, ckw::DataType::Int32));
78 auto const_0_i32 = writer->declare_constant_tile(ckw::ConstantData({{0}}, ckw::DataType::Int32));
79 auto const_neg_1_fp = writer->declare_constant_tile(ckw::ConstantData({{-1.0f}}, dst_dt));
80 auto const_pos_1_fp = writer->declare_constant_tile(ckw::ConstantData({{1.0f}}, dst_dt));
81 auto const_0_fp = writer->declare_constant_tile(ckw::ConstantData({{0.0f}}, dst_dt));
82 auto const_A_fp = writer->declare_constant_tile(ckw::ConstantData({{_attributes.
a()}}, dst_dt));
83 auto const_B_fp = writer->declare_constant_tile(ckw::ConstantData({{_attributes.
b()}}, dst_dt));
97 int32_t dst_n0_partial = -1;
98 int32_t dst_m0_partial = -1;
101 int32_t dst_shift_back = -1;
103 if (!
dst->has_tile())
109 dst_n0 = root_window.
x().
step();
110 dst_m0 = root_window.y().step();
111 dst_n0_partial = _dst->
dimension(0) % dst_n0;
113 dst_shift_back = (dst_n0 - dst_n0_partial) % dst_n0;
115 ckw::TensorSampler sampler_dst;
116 sampler_dst.format(ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1);
118 if (dst_n0_partial == 0)
124 sampler_dst.address_mode_x(ckw::TensorSamplerAddressModeX::OverlappingMin);
127 if (dst_m0_partial == 0)
133 sampler_dst.address_mode_y(ckw::TensorSamplerAddressModeY::ClampToBorderMaxOnly);
137 sampler_dst.storage(ckw::TensorStorageType::BufferUint8Ptr);
140 auto tile_dst = writer->declare_tile(
"dst", ckw::TileInfo(dst_dt, dst_m0, dst_n0));
143 dst->init_virtual_tensor(tile_dst, sampler_dst);
148 dst_n0 =
dst->tile().tile_info().width();
149 dst_m0 =
dst->tile().tile_info().height();
150 dst_n0_partial = _dst->
dimension(0) % dst_n0;
152 ckw::TensorSampler sampler_dst =
dst->tensor_sampler();
154 if (sampler_dst.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
158 else if (sampler_dst.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
160 dst_m0_partial = _dst->
dimension(1) % dst_m0;
164 dst_shift_back = (dst_n0 - dst_n0_partial) % dst_n0;
167 const auto &tile_dst =
dst->tile();
173 auto const_dst_n0 = writer->declare_constant_tile(ckw::ConstantData({{dst_n0}}, ckw::DataType::Int32));
174 auto const_dst_m0 = writer->declare_constant_tile(ckw::ConstantData({{dst_m0}}, ckw::DataType::Int32));
175 auto const_dst_shift_back_n0 =
176 writer->declare_constant_tile(ckw::ConstantData({{dst_shift_back}}, ckw::DataType::Int32));
181 if (!
src->has_tile())
184 ckw::TensorSampler sampler_src =
dst->tensor_sampler();
186 auto tile_gid_0 = writer->declare_tile(
"gid_0_src", ckw::TileInfo(ckw::DataType::Int32));
187 auto tile_gid_1 = writer->declare_tile(
"gid_1_src", ckw::TileInfo(ckw::DataType::Int32));
188 auto tile_gid_2 = writer->declare_tile(
"gid_2_src", ckw::TileInfo(ckw::DataType::Int32));
190 writer->op_get_global_id(tile_gid_0, 0);
191 writer->op_get_global_id(tile_gid_1, 1);
192 writer->op_get_global_id(tile_gid_2, 2);
194 auto tile_nout0 = writer->declare_tile(
"nout0_src", ckw::TileInfo(ckw::DataType::Int32));
196 writer->declare_tile(
"mout0_src", ckw::TileInfo(ckw::DataType::Int32));
197 auto tile_mout1 = writer->declare_tile(
"mout1_src", ckw::TileInfo(ckw::DataType::Int32));
198 auto tile_bout0 = writer->declare_tile(
"bout0_src", ckw::TileInfo(ckw::DataType::Int32));
205 if (sampler_src.format() == ckw::TensorSamplerFormat::Dim0_Dim1xDim2_1)
207 writer->op_assign(tile_mout1, const_0_i32);
210 else if (sampler_src.format() == ckw::TensorSamplerFormat::Dim0_Dim1_Dim2)
212 writer->op_binary(tile_mout1, ckw::BinaryOp::Mod, tile_gid_2, const_dst_h_i32);
213 writer->op_binary(tile_bout0, ckw::BinaryOp::Div, tile_gid_2, const_dst_h_i32);
216 auto tile_src = writer->declare_tile(
"src", ckw::TileInfo(dst_dt, dst_m0, dst_n0));
218 writer->op_load(tile_src,
src->tensor(), sampler_src, tile_nout0, tile_mout0, tile_mout1, tile_bout0);
221 src->init_virtual_tensor(tile_src, sampler_src);
224 const auto &tile_src =
src->tile();
231 case ActivationLayerInfo::ActivationFunction::LOGISTIC:
234 writer->op_binary(tile_dst, ckw::BinaryOp::Mul, tile_src, const_neg_1_fp);
236 writer->op_unary(tile_dst, ckw::UnaryOp::Exp, tile_dst);
238 writer->op_binary(tile_dst, ckw::BinaryOp::Add, tile_dst, const_pos_1_fp);
240 writer->op_binary(tile_dst, ckw::BinaryOp::Div, const_pos_1_fp, tile_dst);
243 case ActivationLayerInfo::ActivationFunction::TANH:
245 writer->op_unary(tile_dst, ckw::UnaryOp::Tanh, tile_src);
248 case ActivationLayerInfo::ActivationFunction::RELU:
251 writer->op_binary(tile_dst, ckw::BinaryOp::Max, tile_src, const_0_fp);
254 case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
257 writer->op_binary(tile_dst, ckw::BinaryOp::Max, tile_src, const_0_fp);
259 writer->op_binary(tile_dst, ckw::BinaryOp::Min, tile_dst, const_A_fp);
262 case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
265 writer->op_binary(tile_dst, ckw::BinaryOp::Max, tile_src, const_B_fp);
267 writer->op_binary(tile_dst, ckw::BinaryOp::Min, tile_dst, const_A_fp);