39 template <
typename T,
typename TDeltas>
45 const size_t num_classes = deltas.
shape()[0] / 4;
46 const size_t num_boxes = deltas.
shape()[1];
47 const TDeltas *deltas_ptr = deltas.
data();
48 T *pred_boxes_ptr = pred_boxes.data();
54 const auto scale_before = T(info.
scale());
58 const size_t box_fields = 4;
59 const size_t class_fields = 4;
61 #pragma omp parallel for 63 for(
size_t i = 0; i < num_boxes; ++i)
66 const size_t start_box = box_fields * i;
67 const T width = (boxes[start_box + 2] / scale_before) - (boxes[start_box] / scale_before) + T(1.f);
68 const T height = (boxes[start_box + 3] / scale_before) - (boxes[start_box + 1] / scale_before) + T(1.f);
69 const T ctr_x = (boxes[start_box] / scale_before) + T(0.5f) * width;
70 const T ctr_y = (boxes[start_box + 1] / scale_before) + T(0.5f) * height;
72 for(
size_t j = 0; j < num_classes; ++j)
75 const size_t start_delta = i * num_classes * class_fields + class_fields * j;
76 const TDeltas dx = deltas_ptr[start_delta] / TDeltas(info.
weights()[0]);
77 const TDeltas dy = deltas_ptr[start_delta + 1] / TDeltas(info.
weights()[1]);
78 TDeltas dw = deltas_ptr[start_delta + 2] / TDeltas(info.
weights()[2]);
79 TDeltas dh = deltas_ptr[start_delta + 3] / TDeltas(info.
weights()[3]);
86 const T pred_ctr_x = dx * width + ctr_x;
87 const T pred_ctr_y = dy * height + ctr_y;
88 const T pred_w = T(std::exp(dw)) * width;
89 const T pred_h = T(std::exp(dh)) * height;
92 pred_boxes_ptr[start_delta] = scale_after * utility::clamp<T>(pred_ctr_x - T(0.5f) * pred_w, T(0), T(img_w - 1));
93 pred_boxes_ptr[start_delta + 1] = scale_after * utility::clamp<T>(pred_ctr_y - T(0.5f) * pred_h, T(0), T(img_h - 1));
94 pred_boxes_ptr[start_delta + 2] = scale_after * utility::clamp<T>(pred_ctr_x + T(0.5f) * pred_w -
offset, T(0), T(img_w - 1));
95 pred_boxes_ptr[start_delta + 3] = scale_after * utility::clamp<T>(pred_ctr_y + T(0.5f) * pred_h -
offset, T(0), T(img_h - 1));
__global uchar * offset(const Image *img, int x, int y)
Get the pointer position of an Image.
DataType data_type() const override
Data type of the tensor.
#define ARM_COMPUTE_ERROR_ON(cond)
If the condition is true, an error message is printed and an exception is thrown.
SimpleTensor< float > convert_from_asymmetric(const SimpleTensor< uint8_t > &src)
TensorShape shape() const override
Shape of the tensor.
Copyright (c) 2017-2021 Arm Limited.
SimpleTensor< T > bounding_box_transform(const SimpleTensor< T > &boxes, const SimpleTensor< TDeltas > &deltas, const BoundingBoxTransformInfo &info)
Simple tensor object that stores elements in a consecutive chunk of memory.
ScaleKernelInfo info(interpolation_policy, default_border_mode, PixelValue(), sampling_policy, false)
QuantizationInfo quantization_info() const override
Quantization info in case of an asymmetric quantized type.
DataType
Available data types.
const T * data() const
Constant pointer to the underlying buffer.