ArmNN 24.08
Threadpool.cpp
1 //
2 // Copyright © 2021, 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #if !defined(ARMNN_DISABLE_THREADS)
6 
7 #include <armnn/Threadpool.hpp>
8 
9 #include <armnn/utility/Timer.hpp>
10 
11 namespace armnn
12 {
13 namespace experimental
14 {
15 
16 Threadpool::Threadpool(std::size_t numThreads,
17  IRuntime* runtimePtr,
18  std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
19  : m_RuntimePtr(runtimePtr)
20 {
21  for (auto i = 0u; i < numThreads; ++i)
22  {
23  m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities, this, i));
24  }
25 
26  LoadMemHandles(memHandles);
27 }
28 
29 void Threadpool::LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
30 {
31  if (memHandles.size() == 0)
32  {
33  throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Size of memHandles vector must be greater than 0");
34  }
35 
36  if (memHandles.size() != m_Threads.size())
37  {
38  throw armnn::RuntimeException(
39  "Threadpool::UnloadMemHandles: Size of memHandles vector must match the number of threads");
40  }
41 
42  NetworkId networkId = memHandles[0]->GetNetworkId();
43  for (uint32_t i = 1; i < memHandles.size(); ++i)
44  {
45  if (networkId != memHandles[i]->GetNetworkId())
46  {
47  throw armnn::RuntimeException(
48  "Threadpool::UnloadMemHandles: All network ids must be identical in memHandles");
49  }
50  }
51 
52  std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles};
53 
54  m_WorkingMemHandleMap.insert(pair);
55 }
56 
57 void Threadpool::UnloadMemHandles(NetworkId networkId)
58 {
59  if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end())
60  {
61  m_WorkingMemHandleMap.erase(networkId);
62  }
63  else
64  {
65  throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
66  }
67 }
68 
69 void Threadpool::Schedule(NetworkId networkId,
70  const InputTensors& inputTensors,
71  const OutputTensors& outputTensors,
72  const QosExecPriority priority,
73  std::shared_ptr<IAsyncExecutionCallback> cb)
74 {
75  if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end())
76  {
77  throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
78  }
79 
80  // Group execution parameters so that they can be easily added to the queue
81  ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb);
82 
83  std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
84 
85  // Add a message to the queue and notify the request thread
86  std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
87  switch (priority)
88  {
89  case QosExecPriority::High:
90  m_HighPriorityQueue.push(operation);
91  break;
92  case QosExecPriority::Low:
93  m_LowPriorityQueue.push(operation);
94  break;
95  case QosExecPriority::Medium:
96  default:
97  m_MediumPriorityQueue.push(operation);
98  }
99  m_ThreadPoolEvent.notify_one();
100 }
101 
102 void Threadpool::TerminateThreadPool() noexcept
103 {
104  {
105  std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
106  m_TerminatePool = true;
107  }
108 
109  m_ThreadPoolEvent.notify_all();
110 
111  for (auto &thread : m_Threads)
112  {
113  thread->join();
114  }
115 }
116 
117 void Threadpool::ProcessExecPriorities(uint32_t index)
118 {
119  int expireRate = EXPIRE_RATE;
120  int highPriorityCount = 0;
121  int mediumPriorityCount = 0;
122 
123  while (true)
124  {
125  std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
126  {
127  // Wait for a message to be added to the queue
128  // This is in a separate scope to minimise the lifetime of the lock
129  std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
130 
131  m_ThreadPoolEvent.wait(lock,
132  [=]
133  {
134  return m_TerminatePool || !m_HighPriorityQueue.empty() ||
135  !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
136  });
137 
138  if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
139  m_LowPriorityQueue.empty())
140  {
141  break;
142  }
143 
144  // Get the message to process from the front of each queue based on priority from high to low
145  // Get high priority first if it does not exceed the expire rate
146  if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
147  {
148  currentExecInProgress = m_HighPriorityQueue.front();
149  m_HighPriorityQueue.pop();
150  highPriorityCount += 1;
151  }
152  // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
153  else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
154  {
155  currentExecInProgress = m_MediumPriorityQueue.front();
156  m_MediumPriorityQueue.pop();
157  mediumPriorityCount += 1;
158  // Reset high priority count
159  highPriorityCount = 0;
160  }
161  // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
162  else if (!m_LowPriorityQueue.empty())
163  {
164  currentExecInProgress = m_LowPriorityQueue.front();
165  m_LowPriorityQueue.pop();
166  // Reset high and medium priority count
167  highPriorityCount = 0;
168  mediumPriorityCount = 0;
169  }
170  else
171  {
172  // Reset high and medium priority count
173  highPriorityCount = 0;
174  mediumPriorityCount = 0;
175  continue;
176  }
177  }
178 
179  // invoke the asynchronous execution method
180  auto networkId = std::get<0>(*currentExecInProgress);
181  auto inputTensors = std::get<1>(*currentExecInProgress);
182  auto outputTensors = std::get<2>(*currentExecInProgress);
183  auto cb = std::get<3>(*currentExecInProgress);
184 
185  // Get time at start of inference
186  HighResolutionClock startTime = armnn::GetTimeNow();
187 
188  try // executing the inference
189  {
190  IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index];
191  ARMNN_NO_DEPRECATE_WARN_BEGIN
192  // Execute and populate the time at end of inference in the callback
193  m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ?
194  cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
195  cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
196  ARMNN_NO_DEPRECATE_WARN_END
197  }
198  catch (const RuntimeException&)
199  {
200  cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
201  }
202  }
203 }
204 
205 } // namespace experimental
206 
207 } // namespace armnn
208 
209 #endif
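For context, below is a minimal usage sketch of this class; it is not part of the file above. RunAsyncInferences and InferenceCallback are illustrative names, the network is assumed to already be optimised and loaded into the runtime under networkId, and the input and output tensors are assumed to be bound to valid buffers by the caller. The sketch shows the contract enforced by LoadMemHandles (one IWorkingMemHandle per pool thread, all created for the same network via IRuntime::CreateWorkingMemHandle) and how Schedule queues work at a given QosExecPriority.

#include <armnn/IAsyncExecutionCallback.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Threadpool.hpp>

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Illustrative callback: reports the status delivered by the worker thread.
class InferenceCallback : public armnn::experimental::IAsyncExecutionCallback
{
public:
    void Notify(armnn::Status status,
                std::pair<armnn::HighResolutionClock, armnn::HighResolutionClock> /*timeTaken*/) override
    {
        std::cout << "Inference " << (status == armnn::Status::Success ? "succeeded" : "failed") << "\n";
    }
};

// Assumes networkId refers to a network already loaded into *runtime.
void RunAsyncInferences(armnn::IRuntime* runtime,
                        armnn::NetworkId networkId,
                        const armnn::InputTensors& inputTensors,
                        const armnn::OutputTensors& outputTensors)
{
    constexpr std::size_t numThreads = 2;

    // LoadMemHandles requires exactly one working-memory handle per pool thread,
    // all created for the same network.
    std::vector<std::shared_ptr<armnn::experimental::IWorkingMemHandle>> memHandles;
    for (std::size_t i = 0; i < numThreads; ++i)
    {
        memHandles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
    }

    armnn::experimental::Threadpool threadpool(numThreads, runtime, memHandles);

    // Queue two inferences; worker threads drain High before Medium before Low,
    // subject to EXPIRE_RATE so the lower-priority queues are not starved.
    auto cb = std::make_shared<InferenceCallback>();
    threadpool.Schedule(networkId, inputTensors, outputTensors, armnn::QosExecPriority::High, cb);
    threadpool.Schedule(networkId, inputTensors, outputTensors, armnn::QosExecPriority::Low, cb);

    // When threadpool goes out of scope, its destructor is expected to call
    // TerminateThreadPool(), which lets the workers drain the queues and then joins them.
}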
Symbols referenced in this file:

armnn::InputTensors: std::vector<std::pair<LayerBindingId, class ConstTensor>> InputTensors (Definition: Tensor.hpp:394)
armnn::QosExecPriority::Medium: enumerator Medium
armnn::experimental::Threadpool::Threadpool: Threadpool(std::size_t numThreads, IRuntime* runtimePtr, std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles) (Definition: Threadpool.cpp:16)
armnn::GetTimeNow: std::chrono::high_resolution_clock::time_point GetTimeNow() (Definition: Timer.hpp:14)
ARMNN_NO_DEPRECATE_WARN_BEGIN: #define ARMNN_NO_DEPRECATE_WARN_BEGIN (Definition: Deprecated.hpp:33)
armnn::QosExecPriority::High: enumerator High
armnn::OutputTensors: std::vector<std::pair<LayerBindingId, class Tensor>> OutputTensors (Definition: Tensor.hpp:395)
armnn::IRuntime (Definition: IRuntime.hpp:75)
armnn::QosExecPriority::Low: enumerator Low
armnn::IRuntime::Execute: Status Execute(IWorkingMemHandle& workingMemHandle, const InputTensors& inputTensors, const OutputTensors& outputTensors, std::vector<ImportedInputId> preImportedInputs = {}, std::vector<ImportedOutputId> preImportedOutputs = {}). This is an experimental function. (Definition: Runtime.cpp:123)
armnn::experimental::Threadpool::LoadMemHandles: void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles) (Definition: Threadpool.cpp:29)
armnn::NetworkId: int NetworkId (Definition: IRuntime.hpp:35)
armnn::QosExecPriority: QosExecPriority (Definition: Types.hpp:79)
armnn::Status::Success: enumerator Success
armnn::RuntimeException (Definition: Exceptions.hpp:120)
ARMNN_NO_DEPRECATE_WARN_END: #define ARMNN_NO_DEPRECATE_WARN_END (Definition: Deprecated.hpp:34)
armnn::experimental::Threadpool::Schedule: void Schedule(NetworkId networkId, const InputTensors& inputTensors, const OutputTensors& outputTensors, const QosExecPriority priority, std::shared_ptr<IAsyncExecutionCallback> cb). Schedule an asynchronous execution on the loaded network. (Definition: Threadpool.cpp:69)
armnn::experimental::Threadpool::UnloadMemHandles: void UnloadMemHandles(NetworkId networkId) (Definition: Threadpool.cpp:57)
armnn: Copyright (c) 2021 ARM Limited and Contributors. (Definition: 01_00_quick_start.dox:6)
Timer.hpp
armnn::experimental::Threadpool::TerminateThreadPool: void TerminateThreadPool() noexcept (Definition: Threadpool.cpp:102)
armnn::EXPIRE_RATE: constexpr unsigned int EXPIRE_RATE. Variable to control expire rate of priority queue. (Definition: Types.hpp:37)
armnn::HighResolutionClock: std::chrono::high_resolution_clock::time_point HighResolutionClock. Define a timer and associated inference ID for recording execution times. (Definition: Types.hpp:401)
Threadpool.hpp
armnn::Status::Failure: enumerator Failure