// Lowest finite/representable value for each element type this file names.
// Negative replacement lists are wrapped in parentheses so that each macro
// always expands to a single primary expression, whatever operator or
// postfix construct ends up next to it at the point of use.
#define MIN_VALUE_float (-FLT_MAX)
#define MIN_VALUE_half (-HALF_MAX)
#define MIN_VALUE_char CHAR_MIN
#define MIN_VALUE_uchar 0

// Two-level expansion: MIN_VALUE_TYPE first macro-expands its argument
// (e.g. DATA_TYPE -> float), and only then MIN_VALUE_TYPE_STR pastes the
// result onto the MIN_VALUE_ prefix to select one of the constants above.
#define MIN_VALUE_TYPE_STR(data_type) MIN_VALUE_##data_type
#define MIN_VALUE_TYPE(data_type) MIN_VALUE_TYPE_STR(data_type)

// Lowest value of the element type selected by the DATA_TYPE build define.
#define MIN_VALUE MIN_VALUE_TYPE(DATA_TYPE)
// softmax_x kernel (partial view).
// NOTE(review): only some lines of this kernel are visible in this listing; the stride parameters, the
// braces, the vectorised loop headers and several preprocessor guards are not shown. The comments below
// describe the visible statements only.
//
// Visible parameters:
//   src_ptr / dst_ptr / tmp_ptr   byte pointers (__global uchar *) to the source, destination and temporary data
//   *_offset_first_element        byte offset added to the matching pointer before any element is accessed
// The tmp_* parameters follow dst_offset_first_element without a visible separating comma; presumably they
// sit behind a preprocessor guard that is not shown here (TODO confirm).
70 __kernel
void softmax_x(
71 __global uchar *src_ptr,
75 uint src_offset_first_element,
77 __global uchar *dst_ptr,
81 uint dst_offset_first_element
85 __global uchar *tmp_ptr,
89 uint tmp_offset_first_element
// Each work-item reads its global id in three dimensions ...
93 const int dim_0 = get_global_id(0);
94 const int dim_1 = get_global_id(1);
95 const int dim_2 = get_global_id(2);
// ... and moves the source and destination pointers to its own data: first-element offset plus id * stride
// for every dimension. All terms are byte counts because the pointers are uchar *.
97 src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
98 dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
// Temporary storage: the branch before `#else` offsets the tmp_ptr argument like the pointers above;
// the `#else // IS_QUANTIZED` branch instead declares tmp_ptr as a local alias of dst_ptr.
101 tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
102 #else // IS_QUANTIZED
103 __global uchar *tmp_ptr = dst_ptr;
104 #endif // IS_QUANTIZED
// Scalar loop: runs from the current value of i (its initialisation is not visible here) up to LENGTH,
// reading one DATA_TYPE element per iteration and folding it into the running maximum max_value.
117 for (; i < LENGTH; ++i)
119 DATA_TYPE data = *(__global DATA_TYPE *)(src_ptr + i *
sizeof(DATA_TYPE));
121 max_value = max(max_value, data);
// Accumulator for the exponentials added by `sum_value += exp(data)` further down.
125 TMP_DATA_TYPE sum_value = 0;
// First REGULARIZE variant: max_value_f is max_value with SRC_OFFSET subtracted and SRC_SCALE applied.
// Folding the constants into regularize_offset makes
//   REGULARIZE(x) == ((x - SRC_OFFSET) * SRC_SCALE - max_value_f) * BETA
// while costing one multiply-add per element.
128 TMP_DATA_TYPE max_value_f = (
CONVERT(max_value, TMP_DATA_TYPE) - SRC_OFFSET) * SRC_SCALE;
129 TMP_DATA_TYPE regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
130 # define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
131 #else // IS_QUANTIZED
// Second REGULARIZE variant: subtract the maximum and scale by BETA.
132 # define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
133 #endif // IS_QUANTIZED
// Vector path (loop header and load are not visible): apply REGULARIZE to data, then write data to the
// temporary buffer at element index i through VSTORE(VEC_SIZE).
// NOTE(review): statements between these two lines are not visible, so the exact value stored cannot be
// confirmed from here.
139 data = REGULARIZE(data);
148 VSTORE(
VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i *
sizeof(TMP_DATA_TYPE)));
// Scalar loop over the remaining indices up to LENGTH: load one element converted to TMP_DATA_TYPE,
// apply REGULARIZE, add exp(data) to sum_value and write data to the temporary buffer.
// NOTE(review): lines between the visible statements are not shown; confirm what data holds when stored.
151 for (; i < LENGTH; ++i)
153 TMP_DATA_TYPE data =
CONVERT(*(__global DATA_TYPE *)(src_ptr + i *
sizeof(DATA_TYPE)), TMP_DATA_TYPE);
155 data = REGULARIZE(data);
158 sum_value += exp(data);
164 *(__global TMP_DATA_TYPE *)(tmp_ptr + i *
sizeof(TMP_DATA_TYPE)) = data;
// NORMALIZE(SIZE, x) has four visible definitions; the guards choosing within each pair are not visible.
// Variant 1: x / DST_SCALE + (DST_OFFSET - log(sum_value)), converted to DATA_TYPE with saturation and
// round-to-nearest-even (rte).
// NOTE(review): log(sum_value) is added without being divided by DST_SCALE while x is divided; confirm
// this is intended.
172 TMP_DATA_TYPE norm_offset = -log(sum_value) + DST_OFFSET;
173 # define NORMALIZE(SIZE, x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, SIZE), rte)
// Variant 2: x / (sum_value * DST_SCALE) rounded (rte) to int, DST_OFFSET added with add_sat, then a
// saturating conversion to DATA_TYPE.
175 TMP_DATA_TYPE norm_div = sum_value * DST_SCALE;
176 # define NORMALIZE(SIZE, x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, SIZE))
178 #else // IS_QUANTIZED
// Variant 3: subtract log(sum_value).
180 # define NORMALIZE(SIZE, x) ((x) - log(sum_value))
// Variant 4: divide by sum_value.
182 # define NORMALIZE(SIZE, x) ((x) / sum_value)
184 #endif // IS_QUANTIZED
// Vector path (loop header and computation of result are not visible): write result to the destination
// at element index i through VSTORE(VEC_SIZE).
192 VSTORE(
VEC_SIZE)(result, 0, (__global DATA_TYPE *)(dst_ptr + i *
sizeof(DATA_TYPE)));
// Scalar loop over the remaining indices up to LENGTH: read the value kept in the temporary buffer,
// apply NORMALIZE with SIZE = 1 and write the DATA_TYPE result to the destination.
195 for (; i < LENGTH; ++i)
197 TMP_DATA_TYPE data = *(__global TMP_DATA_TYPE *)(tmp_ptr + i *
sizeof(TMP_DATA_TYPE));
199 DATA_TYPE result = NORMALIZE(1, data);
201 *(__global DATA_TYPE *)(dst_ptr + i *
sizeof(DATA_TYPE)) = result;
// softmax_non_x kernel (partial view).
// NOTE(review): only some lines of this kernel are visible in this listing; most stride parameters, the
// braces, the definition of dim_0, the loads and several preprocessor guards are not shown. The comments
// below describe the visible statements only.
//
// Visible parameters:
//   src_ptr / dst_ptr / tmp_ptr   byte pointers (__global uchar *) to the source, destination and temporary data
//   *_offset_first_element        byte offset added to the matching pointer before any element is accessed
//   src_stride_axis               not used on any visible line; presumably the source counterpart of the
//                                 dst_stride_axis byte step used by the final store (TODO confirm)
244 __kernel
void softmax_non_x(
245 __global uchar *src_ptr,
249 uint src_offset_first_element,
251 __global uchar *dst_ptr,
255 uint dst_offset_first_element,
257 __global uchar *tmp_ptr,
261 uint tmp_offset_first_element,
263 uint src_stride_axis,
// Global ids for dimensions 1 and 2 (dim_0 is defined on a line that is not visible).
268 const int dim_1 = get_global_id(1);
269 const int dim_2 = get_global_id(2);
// Move all three pointers to this work-item's data: first-element offset plus id * stride per dimension,
// in bytes because the pointers are uchar *.
271 src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
272 dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
273 tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
// Byte offset equal to the size difference between LENGTH * VEC_SIZE elements of TMP_DATA_TYPE and the
// same number of DATA_TYPE elements. A DATA_TYPE region starting at tmp_ptr + tmp_extra_offset therefore
// ends exactly where a TMP_DATA_TYPE region of that element count starting at tmp_ptr would end.
// Presumably this lets the narrower copy be widened in place later (TODO confirm).
289 uint tmp_extra_offset = LENGTH *
VEC_SIZE * (
sizeof(TMP_DATA_TYPE) -
sizeof(DATA_TYPE));
// Pass 1 over the LENGTH steps (the load of data is not visible): fold data into max_value and keep a
// DATA_TYPE copy of it in the temporary buffer at tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE).
295 for (i = 0; i < LENGTH; ++i)
299 max_value = max(max_value, data);
301 VSTORE(
VEC_SIZE)(data, 0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i *
VEC_SIZE *
sizeof(DATA_TYPE)));
// First REGULARIZE variant: constants folded into a VEC_SIZE-wide regularize_offset so that
//   REGULARIZE(x) == (x - SRC_OFFSET) * SRC_SCALE * BETA - max_value_f * BETA
// (max_value_f is defined on a line that is not visible).
309 VEC_DATA_TYPE(TMP_DATA_TYPE,
VEC_SIZE) regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
310 # define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
311 #else // IS_QUANTIZED
// Second REGULARIZE variant: subtract the maximum and scale by BETA.
312 # define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
313 #endif // IS_QUANTIZED
// Pass 2 over the LENGTH steps (load and store are not visible): apply REGULARIZE and add exp(data)
// to sum_value.
315 for (i = 0; i < LENGTH; ++i)
319 data = REGULARIZE(data);
322 sum_value += exp(data);
// NORMALIZE(x) has four visible definitions; the guards choosing within each pair, and the definitions of
// norm_offset and norm_div, are not visible.
// Variant 1: x / DST_SCALE + norm_offset, converted to DATA_TYPE with saturation and rte rounding.
337 # define NORMALIZE(x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), rte)
// Variant 2: x / norm_div rounded (rte) to int, DST_OFFSET added with add_sat, then a saturating
// conversion to DATA_TYPE.
340 # define NORMALIZE(x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, VEC_SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))
342 #else // IS_QUANTIZED
// Variant 3: subtract log(sum_value).
344 # define NORMALIZE(x) ((x) - log(sum_value))
// Variant 4: divide by sum_value.
346 # define NORMALIZE(x) ((x) / sum_value)
348 #endif // IS_QUANTIZED
// Pass 3 over the LENGTH steps (computation of result is not visible): write result to the destination,
// stepping dst_stride_axis bytes per iteration. The last macro argument is true only when
// VEC_SIZE_LEFTOVER is non-zero and this work-item has global id 0 in dimension 0; presumably it selects
// a partial store of VEC_SIZE_LEFTOVER elements (TODO confirm against STORE_VECTOR_SELECT).
350 for (i = 0; i < LENGTH; ++i)
356 STORE_VECTOR_SELECT(result, DATA_TYPE, dst_ptr + i * dst_stride_axis,
VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
362 #endif // SOFTMAX_NON_X
// Clean-up: remove the MIN_VALUE helper macros defined at the top of this file so they do not remain
// defined for whatever follows.
// NOTE(review): no `#undef MIN_VALUE` appears among the visible lines; confirm it is handled on a line
// that is not shown here.
365 #undef MIN_VALUE_TYPE
366 #undef MIN_VALUE_TYPE_STR
// Per-type constants.
368 #undef MIN_VALUE_float
369 #undef MIN_VALUE_half
370 #undef MIN_VALUE_char
371 #undef MIN_VALUE_uchar