ArmNN
 25.11
Loading...
Searching...
No Matches
TensorHandleFactoryRegistry.cpp
Go to the documentation of this file.
1//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
8
9namespace armnn
10{
11
12void TensorHandleFactoryRegistry::RegisterFactory(std::unique_ptr <ITensorHandleFactory> newFactory)
13{
14 if (!newFactory)
15 {
16 return;
17 }
18
19 ITensorHandleFactory::FactoryId id = newFactory->GetId();
20
21 // Don't register duplicates
22 for (auto& registeredFactory : m_Factories)
23 {
24 if (id == registeredFactory->GetId())
25 {
26 return;
27 }
28 }
29
30 // Take ownership of the new allocator
31 m_Factories.push_back(std::move(newFactory));
32}
33
34void TensorHandleFactoryRegistry::RegisterMemoryManager(std::shared_ptr<armnn::IMemoryManager> memoryManger)
35{
36 m_MemoryManagers.push_back(memoryManger);
37}
38
40{
// Body of GetFactory(id): linear search of the registered factories for one whose
// id equals `id`. RegisterFactory() never stores two factories with the same id,
// so at most one entry can match.
41 for (auto& factory : m_Factories)
42 {
43 if (factory->GetId() == id)
44 {
// Non-owning pointer: the unique_ptr held in m_Factories keeps ownership.
45 return factory.get();
46 }
47 }
48
// No factory with this id has been registered.
49 return nullptr;
50}
51
53 MemorySource memSource) const
54{
// Returns the registered factory whose id equals `id` AND whose import flags
// equal static_cast<MemorySourceFlags>(memSource). Returns nullptr if none matches.
// The result is non-owning; m_Factories retains ownership.
55 for (auto& factory : m_Factories)
56 {
// NOTE(review): this is an exact equality test on the whole flag set, not a
// bitwise membership test. A factory whose import flags contain memSource
// alongside other sources will NOT match. Presumably MemorySource enumerators
// are single-bit values. Confirm that exact match is intended, and not
// (GetImportFlags() & flag) != 0.
57 if (factory->GetId() == id && factory->GetImportFlags() == static_cast<MemorySourceFlags>(memSource))
58 {
59 return factory.get();
60 }
61 }
62
63 return nullptr;
64}
65
67 ITensorHandleFactory::FactoryId importFactoryId)
68{
// Record that importFactoryId is the import counterpart of copyFactoryId.
// The pairing is written via operator[] and assignment. Assuming m_FactoryMappings
// is a std::map/unordered_map (type not visible here), registering the same
// copyFactoryId again overwrites the previous pairing.
69 m_FactoryMappings[copyFactoryId] = importFactoryId;
70}
71
77
79{
// Body of AquireMemory(): call Acquire() on every registered memory manager,
// in the order they were registered. Entries are dereferenced without a null check.
80 for (auto& mgr : m_MemoryManagers)
81 {
82 mgr->Acquire();
83 }
84}
85
87{
// Body of ReleaseMemory(): call Release() on every registered memory manager,
// in the order they were registered. Entries are dereferenced without a null check.
88 for (auto& mgr : m_MemoryManagers)
89 {
90 mgr->Release();
91 }
92}
93
94} // namespace armnn
void RegisterFactory(std::unique_ptr< ITensorHandleFactory > allocator)
Register a TensorHandleFactory and transfer ownership.
void AquireMemory()
Acquire memory required for inference.
void ReleaseMemory()
Release memory required for inference.
void RegisterMemoryManager(std::shared_ptr< IMemoryManager > memoryManger)
Register a memory manager with shared ownership.
ITensorHandleFactory * GetFactory(ITensorHandleFactory::FactoryId id) const
Find a TensorHandleFactory by Id. Returns nullptr if not found.
ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId)
Get the matching TensorHandleFactory Id for Memory Import, given a TensorHandleFactory Id for Memory Copy.
void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId, ITensorHandleFactory::FactoryId importFactoryId)
Register a pair of TensorHandleFactory Ids: one for Memory Copy and the matching one for Memory Import.
Copyright (c) 2021 ARM Limited and Contributors.
MemorySource
Define the Memory Source to reduce copies.
Definition Types.hpp:246
unsigned int MemorySourceFlags