5 #if !defined(ARMNN_DISABLE_THREADS)
13 namespace experimental
18 std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
19 : m_RuntimePtr(runtimePtr)
21 for (
auto i = 0u; i < numThreads; ++i)
23 m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities,
this, i));
31 if (memHandles.size() == 0)
33 throw armnn::RuntimeException(
"Threadpool::UnloadMemHandles: Size of memHandles vector must be greater than 0");
36 if (memHandles.size() != m_Threads.size())
39 "Threadpool::UnloadMemHandles: Size of memHandles vector must match the number of threads");
42 NetworkId networkId = memHandles[0]->GetNetworkId();
43 for (uint32_t i = 1; i < memHandles.size(); ++i)
45 if (networkId != memHandles[i]->GetNetworkId())
48 "Threadpool::UnloadMemHandles: All network ids must be identical in memHandles");
52 std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles};
54 m_WorkingMemHandleMap.insert(pair);
59 if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end())
61 m_WorkingMemHandleMap.erase(networkId);
73 std::shared_ptr<IAsyncExecutionCallback> cb)
75 if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end())
81 ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb);
83 std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
86 std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
90 m_HighPriorityQueue.push(operation);
93 m_LowPriorityQueue.push(operation);
97 m_MediumPriorityQueue.push(operation);
99 m_ThreadPoolEvent.notify_one();
105 std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
106 m_TerminatePool =
true;
109 m_ThreadPoolEvent.notify_all();
111 for (
auto &thread : m_Threads)
117 void Threadpool::ProcessExecPriorities(uint32_t index)
120 int highPriorityCount = 0;
121 int mediumPriorityCount = 0;
125 std::shared_ptr<ExecutionTuple> currentExecInProgress(
nullptr);
129 std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
131 m_ThreadPoolEvent.wait(lock,
134 return m_TerminatePool || !m_HighPriorityQueue.empty() ||
135 !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
138 if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
139 m_LowPriorityQueue.empty())
146 if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
148 currentExecInProgress = m_HighPriorityQueue.front();
149 m_HighPriorityQueue.pop();
150 highPriorityCount += 1;
153 else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
155 currentExecInProgress = m_MediumPriorityQueue.front();
156 m_MediumPriorityQueue.pop();
157 mediumPriorityCount += 1;
159 highPriorityCount = 0;
162 else if (!m_LowPriorityQueue.empty())
164 currentExecInProgress = m_LowPriorityQueue.front();
165 m_LowPriorityQueue.pop();
167 highPriorityCount = 0;
168 mediumPriorityCount = 0;
173 highPriorityCount = 0;
174 mediumPriorityCount = 0;
180 auto networkId = std::get<0>(*currentExecInProgress);
181 auto inputTensors = std::get<1>(*currentExecInProgress);
182 auto outputTensors = std::get<2>(*currentExecInProgress);
183 auto cb = std::get<3>(*currentExecInProgress);
190 IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index];
198 catch (
const RuntimeException&)