38 inline Coordinates expand_coordinates(Coordinates in_coord,
size_t axis,
size_t slice,
size_t num_dimensions)
50 Coordinates expanded_coord;
51 expanded_coord.set_num_dimensions(num_dimensions);
52 expanded_coord.set(axis,
slice);
53 for(
size_t k = 0; k < axis; ++k)
55 expanded_coord.set(k, in_coord[k]);
57 for(
size_t k = axis + 1; k < num_dimensions; ++k)
59 expanded_coord.set(k, in_coord[k - 1]);
61 return expanded_coord;
// Copies the `slice`-th cross-section of `input_tensor` along `axis` into a new
// tensor `output`. Its shape is the input shape with dimension `axis` removed.
// NOTE(review): several pieces are not present in these lines:
//   - the `template <typename T>` header;
//   - the declaration of `win`;
//   - the loop construct that binds `id`;
//   - any `return` statement (the function is declared to return SimpleTensor<T>).
// The body looks truncated by extraction. The leading integers (65, 67, 68, ...)
// also look like line-number residue rather than code. Confirm against the
// original source.
65 SimpleTensor<T> get_slice(
const SimpleTensor<T> &input_tensor,
size_t axis,
size_t slice)
// Output shape: the input shape with the `axis` dimension dropped.
67 TensorShape out_shape = input_tensor.shape();
68 out_shape.remove_dimension(axis);
// Rank of the input. Used to rebuild full input coordinates below.
70 const size_t unpacked_num_dimensions(input_tensor.shape().num_dimensions());
// Destination tensor with the reduced shape and the input's data type.
72 SimpleTensor<T> output{ out_shape, input_tensor.data_type() };
// Configure `win` from the output shape (presumably so the element loop visits
// every output coordinate -- verify).
75 win.use_tensor_dimensions(out_shape);
// For an output coordinate `id`, build the matching input coordinate with
// expand_coordinates(): it re-inserts `slice` at position `axis`. Then copy
// one element of type T from the input to the output.
78 const Coordinates input_coords = expand_coordinates(
id, axis,
slice, unpacked_num_dimensions);
79 *
reinterpret_cast<T *
>(output(
id)) = *
reinterpret_cast<const T *
>(input_tensor(input_coords));
// NOTE(review): this looks like the body of a function whose signature is not
// visible here. `axis`, `input_tensor` and `output_tensors` are presumably its
// parameters -- confirm against the full file. The leading integers
// (90, 92, 96, 98) look like extraction residue, not code.
// Normalize `axis` against the input's rank with wrap_around(). This is
// presumably meant to map a negative axis onto its non-negative equivalent;
// verify wrap_around's contract.
90 const unsigned int axis_u =
wrap_around(axis,
static_cast<int>(input_tensor.
shape().num_dimensions()));
// One iteration per entry of `output_tensors`.
// NOTE(review): `output` and `kth_slice` are not declared in the visible lines.
// The loop's opening statements appear to be lost. Presumably `output` refers
// to output_tensors[k] and `kth_slice` is the k-th slice of the input along
// `axis_u`; confirm.
92 for(
size_t k = 0; k < output_tensors.size(); ++k)
96 output = copy_tensor<T>(kth_slice);
98 return output_tensors;