ArmNN
 24.02
ArgMinMax.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ArgMinMax.hpp"
7 
9 
12 
13 namespace armnn
14 {
15 
16 template <typename OUT>
17 void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
18  const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis)
19 {
20  IgnoreUnused(outputTensorInfo);
21 
22  unsigned int uAxis = armnnUtils::GetUnsignedAxis(inputTensorInfo.GetNumDimensions(), axis);
23 
24  const unsigned int outerElements = armnnUtils::GetNumElementsBetween(inputTensorInfo.GetShape(), 0, uAxis);
25  const unsigned int axisSize = inputTensorInfo.GetShape()[uAxis];
26  const unsigned int innerElements = armnnUtils::GetNumElementsBetween(inputTensorInfo.GetShape(),
27  uAxis + 1,
28  inputTensorInfo.GetNumDimensions());
29 
30  for (unsigned int outer = 0; outer < outerElements; ++outer) {
31  for (unsigned int inner = 0; inner < innerElements; ++inner) {
32  in[outer * axisSize * innerElements + inner];
33  auto tmpValue = in.Get();
34  unsigned int tmpIndex = 0;
35  for (unsigned int i = 1; i < axisSize; ++i) {
36  in[(outer * axisSize * innerElements) + (i * innerElements) + inner];
37  const auto& value = in.Get();
38  if ((function == armnn::ArgMinMaxFunction::Min && value < tmpValue) ||
39  (function == armnn::ArgMinMaxFunction::Max && value > tmpValue)) {
40  tmpValue = value;
41  tmpIndex = i;
42  }
43  }
44 
45  out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex);
46  }
47  }
48 }
49 
50 template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
51  const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
52 
53 template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo,
54  const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
55 
56 } //namespace armnn
armnn::ArgMinMaxFunction::Max
@ Max
armnn::Decoder< float >
armnnUtils::GetUnsignedAxis
unsigned int GetUnsignedAxis(const unsigned int inputDimension, const int axis)
Definition: TensorUtils.cpp:236
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:197
armnn::ArgMinMaxFunction
ArgMinMaxFunction
Definition: Types.hpp:103
IgnoreUnused.hpp
NumericCast.hpp
TensorUtils.hpp
armnn::ArgMinMax
void ArgMinMax(Decoder< float > &in, OUT *out, const TensorInfo &inputTensorInfo, const TensorInfo &outputTensorInfo, ArgMinMaxFunction function, int axis)
Definition: ArgMinMax.cpp:17
armnn::Decoder::Get
virtual IType Get() const =0
armnn::ArgMinMaxFunction::Min
@ Min
ArgMinMax.hpp
armnnUtils::GetNumElementsBetween
unsigned int GetNumElementsBetween(const armnn::TensorShape &shape, unsigned int firstAxisInclusive, unsigned int lastAxisExclusive)
Definition: TensorUtils.cpp:209
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:193
armnn::IgnoreUnused
void IgnoreUnused(Ts &&...)
Definition: IgnoreUnused.hpp:14
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6