26 #if defined(VEC_SIZE) && defined(DATA_TYPE) && defined(INTERNAL_DATA_TYPE) && defined(GAMMA) && defined(BETA) && defined(EPSILON) && defined(DIM_X) && defined(DIM_Y) && defined(DIM_Z) 66 INTERNAL_DATA_TYPE
sum = 0.f;
67 INTERNAL_DATA_TYPE sum_sq = 0.f;
71 const int ch = get_global_id(0);
72 const int batch = get_global_id(2);
73 const int elements_plane = DIM_Y * DIM_Z;
75 for(
int i_w = 0; i_w < DIM_Y; ++i_w)
77 for(
int i_h = 0; i_h < DIM_Z; ++i_h)
81 sum_sq += data * data;
85 #else // !defined(NHWC) 86 const int ch = get_global_id(2) % DIM_Z;
87 const int batch = get_global_id(2) / DIM_Z;
88 const int elements_plane = DIM_X * DIM_Y;
95 for(
int y = 0; y < DIM_Y; ++y)
104 part_sum_sq += data * data;
107 for(; x < DIM_X; ++x)
111 part_sum_sq.s0 += data * data;
116 part_sum.s01234567 += part_sum.s89abcdef;
117 part_sum_sq.s01234567 += part_sum_sq.s89abcdef;
118 #endif // VEC_SIZE > 8 120 part_sum.s0123 += part_sum.s4567;
121 part_sum_sq.s0123 += part_sum_sq.s4567;
122 #endif // VEC_SIZE > 4 124 part_sum.s01 += part_sum.s23;
125 part_sum_sq.s01 += part_sum_sq.s23;
126 #endif // VEC_SIZE > 2 127 part_sum.s0 += part_sum.s1;
128 part_sum_sq.s0 += part_sum_sq.s1;
130 sum = (INTERNAL_DATA_TYPE)part_sum.s0;
131 sum_sq = (INTERNAL_DATA_TYPE)part_sum_sq.s0;
133 #endif // defined(NHWC) 135 const INTERNAL_DATA_TYPE mean = (sum / elements_plane);
136 const INTERNAL_DATA_TYPE var = (sum_sq / elements_plane) - (mean * mean);
137 const INTERNAL_DATA_TYPE multip = GAMMA / sqrt(var + EPSILON);
141 for(
int i_w = 0; i_w < DIM_Y; ++i_w)
143 for(
int i_h = 0; i_h < DIM_Z; ++i_h)
147 __global
DATA_TYPE *output_address = input_address;
151 *(output_address) = (*(input_address) - mean) * multip + (INTERNAL_DATA_TYPE)BETA;
155 #else // !defined(NHWC) 156 for(
int y = 0; y < DIM_Y; ++y)
163 __global
DATA_TYPE *output_address = input_address;
172 res = (data - mean) * multip + (INTERNAL_DATA_TYPE)BETA;
177 for(; x < DIM_X; ++x)
181 __global
DATA_TYPE *output_address = input_address;
185 *(output_address) = (*(input_address) - mean) * multip + (INTERNAL_DATA_TYPE)BETA;
188 #endif // defined(NHWC)
DATA_TYPE sum(__global const DATA_TYPE *input)
Calculate sum of a vector.
Structure to hold 4D tensor information.
#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size)
SimpleTensor< T > instance_normalization(const SimpleTensor< T > &src, float gamma, float beta, float epsilon)
__global const uchar * tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
Get the pointer position of a Tensor4D.
#define TENSOR4D_DECLARATION(name)
#define VEC_DATA_TYPE(type, size)