7 #include "armnnTestUtils/TensorHelpers.hpp"
12 #include <reference/test/RefWorkloadFactoryHelper.hpp>
14 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
16 #include <armnnTestUtils/LayerTestResult.hpp>
17 #include <armnnTestUtils/TensorCopyUtils.hpp>
18 #include <armnnTestUtils/WorkloadTestUtils.hpp>
20 #include <doctest/doctest.h>
// Validates one layer-test result:
//  1. checks that the supportedness the result reports agrees with the test's name, then
//  2. if the result reports m_Supported == true, compares the actual tensor data/shape
//     against the expected tensor data/shape via CompareTensors and reports any mismatch.
//
// NOTE(review): this region is an incomplete extract. The function signature and braces
// are not present, and each line carries a fused source line number, so it does not
// compile as shown. From the body, `testName` supports std::string::find, and `testResult`
// exposes m_Supported, m_ActualData, m_ExpectedData, m_ActualShape, m_ExpectedShape and
// m_CompareBoolean.
42 template <
typename T, std::
size_t n>
// Naming convention: a test whose name contains "UNSUPPORTED" must report
// m_Supported == false; any other test must report m_Supported == true.
45 bool testNameIndicatesUnsupported = testName.find(
"UNSUPPORTED") != std::string::npos;
46 CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported,
47 "The test name does not match the supportedness it is reporting");
// Tensor comparison is skipped for results that report m_Supported == false.
48 if (testResult.m_Supported)
// m_CompareBoolean is forwarded to CompareTensors (presumably it selects boolean
// comparison semantics for boolean-valued outputs; confirm against CompareTensors).
50 auto result = CompareTensors(testResult.m_ActualData,
51 testResult.m_ExpectedData,
52 testResult.m_ActualShape,
53 testResult.m_ExpectedShape,
54 testResult.m_CompareBoolean);
55 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Vector overload: applies the same checks as the single-result comparison to every
// element of `testResult`:
//  1. each element's reported supportedness must agree with the test's name, and
//  2. each element that reports m_Supported == true has its actual tensor data/shape
//     compared against the expected data/shape via CompareTensors.
//
// NOTE(review): this region is an incomplete extract. The function signature and braces
// are not present, and each line carries a fused source line number, so it does not
// compile as shown.
59 template <
typename T, std::
size_t n>
// A name containing "UNSUPPORTED" means every element must report m_Supported == false;
// otherwise every element must report m_Supported == true.
62 bool testNameIndicatesUnsupported = testName.find(
"UNSUPPORTED") != std::string::npos;
// std::size_t matches the type returned by testResult.size().
63 for (
std::size_t i = 0; i < testResult.size(); ++i)
65 CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported,
66 "The test name does not match the supportedness it is reporting");
// Only elements that report m_Supported == true have their tensors compared.
67 if (testResult[i].m_Supported)
// Forward m_CompareBoolean exactly as the single-result overload does. Previously it
// was dropped here, so boolean-valued results in a vector were compared using
// CompareTensors' default for that argument (presumably non-boolean comparison).
69 auto result = CompareTensors(testResult[i].m_ActualData,
70 testResult[i].m_ExpectedData,
71 testResult[i].m_ActualShape,
72 testResult[i].m_ExpectedShape,
testResult[i].m_CompareBoolean);
73 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
// Runs one layer-test function against a workload factory of type FactoryType.
// Presumably this is RunTestFunction, the helper expanded by ARMNN_AUTO_TEST_CASE and
// ARMNN_AUTO_TEST_FIXTURE, since its template parameters match those call sites. Confirm
// this, because the signature line is not present in this extract.
//
// Template parameters:
//   FactoryType - workload factory type. Its memory manager and factory instance are
//                 obtained through WorkloadFactoryHelper<FactoryType>.
//   TFuncPtr    - pointer to the test function. It is dereferenced and called below.
//   Args...     - extra arguments passed to the test function after the workload
//                 factory and memory manager.
//
// NOTE(review): nothing in this extract registers or uses `profiler`, and nothing
// consumes `testResult`. The lines that do so (if any) are absent from this view.
78 template<
typename FactoryType,
typename TFuncPtr,
typename... Args>
// Profiler instance owned by this local unique_ptr.
81 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
84 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
85 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
// The test function receives (workloadFactory, memoryManager, args...).
87 auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
// Variant of the runner above that also supplies a tensor-handle factory to the test
// function. Presumably this is RunTestFunctionUsingTensorHandleFactory, the helper
// expanded by the ARMNN_AUTO_*_WITH_THF macros. Confirm this, because the signature line
// is not present in this extract.
//
// Template parameters:
//   FactoryType - workload factory type. Its memory manager, factory instance and
//                 tensor-handle factory all come from WorkloadFactoryHelper<FactoryType>.
//   TFuncPtr    - pointer to the test function. It is dereferenced and called below.
//   Args...     - extra arguments passed after the tensor-handle factory.
//
// NOTE(review): nothing in this extract registers or uses `profiler`, and nothing
// consumes `testResult`. The lines that do so (if any) are absent from this view.
94 template<
typename FactoryType,
typename TFuncPtr,
typename... Args>
// Profiler instance owned by this local unique_ptr.
97 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
100 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
101 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
// The tensor-handle factory is built from the same memory manager as the workload factory.
103 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
// The test function receives (workloadFactory, memoryManager, tensorHandleFactory, args...).
105 auto testResult = (*testFunction)(workloadFactory, memoryManager, tensorHandleFactory, args...);
// Test-registration macros. Each one opens a doctest test case named after the
// stringified TestName (TEST_CASE, or TEST_CASE_FIXTURE for the *_FIXTURE forms) and
// calls a runner with (#TestName, &TestFunction, extra args...):
//   ARMNN_SIMPLE_TEST_CASE                       - remainder of its body is not visible here
//   ARMNN_AUTO_TEST_CASE / ARMNN_AUTO_TEST_FIXTURE
//                                                - RunTestFunction<FactoryType>
//   ARMNN_AUTO_TEST_CASE_WITH_THF / ARMNN_AUTO_TEST_FIXTURE_WITH_THF
//                                                - RunTestFunctionUsingTensorHandleFactory<FactoryType>
// FactoryType is not a macro parameter. It must name a type visible where the macro is
// expanded (presumably a per-backend alias declared by the including test file; confirm).
// `##__VA_ARGS__` is the GNU comma-elision extension: when no extra arguments are given,
// the comma before it is removed.
// NOTE(review): in this extract every line ends with a line-continuation backslash and the
// closing braces are absent, so these #defines splice into one logical line. Comments are
// therefore kept above the first #define only.
111 #define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \
112 TEST_CASE(#TestName) \
117 #define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
118 TEST_CASE(#TestName) \
120 RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
123 #define ARMNN_AUTO_TEST_FIXTURE(TestName, Fixture, TestFunction, ...) \
124 TEST_CASE_FIXTURE(Fixture, #TestName) \
126 RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
129 #define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
130 TEST_CASE(#TestName) \
132 RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
135 #define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(TestName, Fixture, TestFunction, ...) \
136 TEST_CASE_FIXTURE(Fixture, #TestName) \
138 RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
141 template<
typename FactoryType,
typename TFuncPtr,
typename... Args>
// Runs a test function that takes both a backend workload factory and a reference
// workload factory. Presumably this is CompareRefTestFunction, the helper expanded by
// ARMNN_COMPARE_REF_AUTO_TEST_CASE and ARMNN_COMPARE_REF_FIXTURE_TEST_CASE. Confirm this,
// because the signature line is not present in this extract.
// (This comment sits after the template header because the preceding macro line ends
// with a continuation backslash.)
//
// Template parameters:
//   FactoryType - backend workload factory type, obtained via WorkloadFactoryHelper<FactoryType>.
//   TFuncPtr    - pointer to the test function. It is dereferenced and called below.
//   Args...     - extra arguments passed after refWorkloadFactory.
//
// NOTE(review): `refWorkloadFactory` is used below but not declared in this extract, and
// nothing consumes `testResult`. Those lines are absent from this view.
144 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
145 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
// The test function receives (workloadFactory, memoryManager, refWorkloadFactory, args...).
149 auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
// Variant of the reference-comparison runner that also supplies tensor-handle factories
// for both the backend under test and the reference backend. Presumably this is
// CompareRefTestFunctionUsingTensorHandleFactory, the helper expanded by the
// ARMNN_COMPARE_REF_*_WITH_THF macros. Confirm this, because the signature line is not
// present in this extract.
//
// Template parameters:
//   FactoryType - backend workload factory type. Its memory manager, factory instance and
//                 tensor-handle factory come from WorkloadFactoryHelper<FactoryType>.
//   TFuncPtr    - pointer to the test function. It is dereferenced and called below.
//   Args...     - extra arguments passed after refTensorHandleFactory.
//
// NOTE(review): `refWorkloadFactory` is used below but not declared in this extract, and
// nothing consumes `testResult`. Those lines are absent from this view.
153 template<
typename FactoryType,
typename TFuncPtr,
typename... Args>
156 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
158 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
// Reference-backend resources. The memory manager is fetched through
// WorkloadFactoryHelper<armnn::RefWorkloadFactory>, but the tensor-handle factory goes
// through RefWorkloadFactoryHelper. Presumably the latter is an alias of the former;
// confirm, and consider using one spelling for both.
161 auto refMemoryManager = WorkloadFactoryHelper<armnn::RefWorkloadFactory>::GetMemoryManager();
162 auto refTensorHandleFactory = RefWorkloadFactoryHelper::GetTensorHandleFactory(refMemoryManager);
// The test function receives (workloadFactory, memoryManager, refWorkloadFactory,
// tensorHandleFactory, refTensorHandleFactory, args...), in that order.
164 auto testResult = (*testFunction)(
165 workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, args...);
// Reference-comparison test-registration macros. Each one opens a doctest test case named
// after the stringified TestName (TEST_CASE, or TEST_CASE_FIXTURE for the *_FIXTURE_* forms)
// and calls a runner with (#TestName, &TestFunction, extra args...):
//   ARMNN_COMPARE_REF_AUTO_TEST_CASE / ARMNN_COMPARE_REF_FIXTURE_TEST_CASE
//                                          - CompareRefTestFunction<FactoryType>
//   ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF / ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF
//                                          - CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>
// FactoryType is not a macro parameter. It must name a type visible where the macro is
// expanded (presumably a per-backend alias declared by the including test file; confirm).
// `##__VA_ARGS__` is the GNU comma-elision extension: when no extra arguments are given,
// the comma before it is removed.
// NOTE(review): in this extract every line ends with a line-continuation backslash and the
// closing braces are absent, so these #defines splice into one logical line. Comments are
// therefore kept above the first #define only.
169 #define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
170 TEST_CASE(#TestName) \
172 CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
175 #define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
176 TEST_CASE(#TestName) \
178 CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
181 #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
182 TEST_CASE_FIXTURE(Fixture, #TestName) \
184 CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
187 #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...) \
188 TEST_CASE_FIXTURE(Fixture, #TestName) \
190 CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \