ArmNN
 25.11
ModelAccuracyChecker Class Reference

#include <ModelAccuracyChecker.hpp>

Public Member Functions

 ModelAccuracyChecker (const std::map< std::string, std::string > &validationLabelSet, const std::vector< LabelCategoryNames > &modelOutputLabels)
 Constructor for a model top k accuracy checker.
float GetAccuracy (unsigned int k)
 Get Top K accuracy.
template<typename TContainer>
void AddImageResult (const std::string &imageName, std::vector< TContainer > outputTensor)
 Record the prediction result of an image.

Detailed Description

Definition at line 45 of file ModelAccuracyChecker.hpp.

Constructor & Destructor Documentation

◆ ModelAccuracyChecker()

ModelAccuracyChecker ( const std::map< std::string, std::string > & validationLabelSet,
const std::vector< LabelCategoryNames > & modelOutputLabels )

Constructor for a model top k accuracy checker.

Parameters
[in] validationLabelSet  Mapping from the names of the images to be validated to the category names of their corresponding ground-truth labels.
[in] modelOutputLabels  Mapping from output nodes to the category names of their corresponding labels. Note that an output node can have multiple category names.

Definition at line 17 of file ModelAccuracyChecker.cpp.

    : m_GroundTruthLabelSet(validationLabels)
    , m_ModelOutputLabels(modelOutputLabels)
{}
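
The shape of the two constructor arguments can be illustrated with standard containers alone. The sketch below is not ArmNN code: the image names, category names, and the helpers `OutputMatchesGroundTruth`, `ExampleValidationLabelSet`, and `ExampleModelOutputLabels` are made up for illustration, and `LabelCategoryNames` is redeclared locally so the snippet builds without ArmNN headers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Local stand-in for the ArmNN alias (assumption: redeclared only so this sketch is self-contained).
using LabelCategoryNames = std::vector<std::string>;

// Hypothetical helper: true if output node `outputIndex` carries the ground-truth
// category recorded for `imageName`.
inline bool OutputMatchesGroundTruth(const std::map<std::string, std::string>& validationLabelSet,
                                     const std::vector<LabelCategoryNames>& modelOutputLabels,
                                     const std::string& imageName,
                                     std::size_t outputIndex)
{
    assert(outputIndex < modelOutputLabels.size());
    // std::map::at throws std::out_of_range if the image name is unknown.
    const std::string& truth = validationLabelSet.at(imageName);
    const LabelCategoryNames& names = modelOutputLabels[outputIndex];
    return std::find(names.begin(), names.end(), truth) != names.end();
}

// Made-up data: image name -> ground-truth category name.
inline std::map<std::string, std::string> ExampleValidationLabelSet()
{
    std::map<std::string, std::string> labels;
    labels["img_0001.jpg"] = "tabby";
    labels["img_0002.jpg"] = "beagle";
    return labels;
}

// Made-up data: output node index -> one or more category names.
inline std::vector<LabelCategoryNames> ExampleModelOutputLabels()
{
    std::vector<LabelCategoryNames> labels;
    labels.push_back(LabelCategoryNames{"background"});
    labels.push_back(LabelCategoryNames{"tabby", "tabby cat"});
    labels.push_back(LabelCategoryNames{"beagle"});
    return labels;
}
```

Output node 1 in this made-up data has two category names, which is the case the parameter description calls out: a match on either name counts.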

Member Function Documentation

◆ AddImageResult()

template<typename TContainer>
void AddImageResult ( const std::string & imageName,
std::vector< TContainer > outputTensor )
inline

Record the prediction result of an image.

Parameters
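[in] imageName  Name of the image. It must be a key of the validation label set passed to the constructor.
[in] outputTensor  Output tensors produced by running the network on imageName. Only the first element is read.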

Definition at line 73 of file ModelAccuracyChecker.hpp.

{
    // Increment the total number of images processed
    ++m_ImagesProcessed;

    std::map<int, float> confidenceMap;
    auto& output = outputTensor[0];

    // Create a map of all predictions
    mapbox::util::apply_visitor([&confidenceMap](auto && value)
        {
            int index = 0;
            for (const auto & o : value)
            {
                if (o > 0)
                {
                    confidenceMap.insert(std::pair<int, float>(index, static_cast<float>(o)));
                }
                ++index;
            }
        },
        output);

    // Create a comparator for sorting the map in order of highest probability
    typedef std::function<bool(std::pair<int, float>, std::pair<int, float>)> Comparator;

    Comparator compFunctor =
        [](std::pair<int, float> element1, std::pair<int, float> element2)
        {
            return element1.second > element2.second;
        };

    // Do the sorting and store in an ordered set
    std::set<std::pair<int, float>, Comparator> setOfPredictions(
        confidenceMap.begin(), confidenceMap.end(), compFunctor);

    const std::string correctLabel = m_GroundTruthLabelSet.at(imageName);

    unsigned int index = 1;
    for (std::pair<int, float> element : setOfPredictions)
    {
        if (index >= m_TopK.size())
        {
            break;
        }
        // Check if the ground truth label value is included in the topi prediction.
        // Note that a prediction can have multiple prediction labels.
        const LabelCategoryNames predictionLabels = m_ModelOutputLabels[static_cast<size_t>(element.first)];
        if (std::find(predictionLabels.begin(), predictionLabels.end(), correctLabel) != predictionLabels.end())
        {
            ++m_TopK[index];
            break;
        }
        ++index;
    }
}
LabelCategoryNames is an alias for std::vector< std::string >.
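
The ranking step in AddImageResult() can be reproduced with standard containers alone. The sketch below is not the ArmNN implementation: `RankOfIndex` is a hypothetical helper, and a plain `std::vector<float>` stands in for the visited output tensor. It follows the same approach as the listing, collecting positive scores into a map and then ordering them in a `std::set` whose comparator looks only at the confidence value.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Hypothetical helper: returns the 1-based rank of `targetIndex` among the positive
// scores, or 0 if that index does not appear in the ranked set.
inline unsigned int RankOfIndex(const std::vector<float>& scores, int targetIndex)
{
    assert(targetIndex >= 0);

    // Keep only positive scores, keyed by output index.
    std::map<int, float> confidenceMap;
    int index = 0;
    for (float score : scores)
    {
        if (score > 0)
        {
            confidenceMap.emplace(index, score);
        }
        ++index;
    }

    using Prediction = std::pair<int, float>;
    using Comparator = std::function<bool(const Prediction&, const Prediction&)>;

    // Order by confidence only, highest first.
    Comparator byConfidence = [](const Prediction& a, const Prediction& b)
    {
        return a.second > b.second;
    };

    std::set<Prediction, Comparator> ranked(confidenceMap.begin(), confidenceMap.end(), byConfidence);

    unsigned int rank = 1;
    for (const Prediction& prediction : ranked)
    {
        if (prediction.first == targetIndex)
        {
            return rank;
        }
        ++rank;
    }
    return 0;
}
```

Because the comparator ignores the output index, two predictions with exactly equal confidence compare as equivalent, and `std::set` keeps only the first one inserted. In this sketch, scores of `{0.5f, 0.5f, 0.3f}` therefore rank index 0 first and index 2 second, while index 1 does not appear at all.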

◆ GetAccuracy()

float GetAccuracy ( unsigned int k)

Get Top K accuracy.

Parameters
[in] k  The number of top predictions to check against the ground-truth label. For example, if k is 3, a prediction counts as correct as long as the ground truth appears among the top 3 predictions. Values greater than 10 are clamped to 10 and a warning is logged.
Returns
The top-k accuracy, as a percentage of the images processed so far.

Definition at line 23 of file ModelAccuracyChecker.cpp.

{
    if (k > 10)
    {
        ARMNN_LOG(warning) << "Accuracy Tool only supports a maximum of Top 10 Accuracy. "
                              "Printing Top 10 Accuracy result!";
        k = 10;
    }
    unsigned int total = 0;
    for (unsigned int i = k; i > 0; --i)
    {
        total += m_TopK[i];
    }
    return static_cast<float>(total * 100) / static_cast<float>(m_ImagesProcessed);
}
ARMNN_LOG(severity) is a macro defined at Logging.hpp:212.

References ARMNN_LOG and armnn::warning.
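
The accumulation in GetAccuracy() amounts to summing per-rank hit counts from rank 1 up to k and dividing by the number of images. The sketch below is a standalone illustration, not the ArmNN implementation: `TopKAccuracyPercent` and `hitsAtRank` are hypothetical names, with `hitsAtRank[i]` playing the role of `m_TopK[i]` (slot 0 unused). The listing shows no check for zero processed images, so the sketch asserts on that case instead of dividing.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for GetAccuracy(). hitsAtRank[i] counts the images whose
// ground truth was first matched at rank i; slot 0 is unused.
inline float TopKAccuracyPercent(const std::vector<unsigned int>& hitsAtRank,
                                 unsigned int imagesProcessed,
                                 unsigned int k)
{
    assert(hitsAtRank.size() >= 2); // slot 0 plus at least rank 1
    assert(imagesProcessed > 0);    // sketch-only guard against dividing by zero

    // Clamp k to the highest recorded rank (the listing clamps to 10).
    const unsigned int maxK = static_cast<unsigned int>(hitsAtRank.size() - 1);
    if (k > maxK)
    {
        k = maxK;
    }

    // A top-k hit is any image matched at rank 1..k.
    unsigned int total = 0;
    for (unsigned int i = 1; i <= k; ++i)
    {
        total += hitsAtRank[i];
    }
    return static_cast<float>(total * 100) / static_cast<float>(imagesProcessed);
}
```

With made-up counts of 6, 2, and 1 hits at ranks 1, 2, and 3 over 10 images, top-1 accuracy is 60%, top-2 is 80%, and top-3 is 90%. Since each image increments at most one rank slot, the result is non-decreasing in k.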


The documentation for this class was generated from the following files: