Compute Library
 21.02
IWeightsManager.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
25 
26 namespace arm_compute
27 {
29  : _managed_weights(), _managed_weights_parents()
30 {
31 }
32 
33 void IWeightsManager::manage(const ITensor *weights, ITransformWeights *parent)
34 {
35  if(!are_weights_managed(weights))
36  {
37  _managed_weights[weights];
38  }
39 
40  // In case the weights are an output of a previous reshape function
41  // store the parent's link
42  if(parent != nullptr)
43  {
44  if(_managed_weights_parents.find(weights) == _managed_weights_parents.end())
45  {
46  _managed_weights_parents[weights] = parent;
47  }
48  }
49 }
50 
51 ITensor *IWeightsManager::run(const ITensor *weights, ITransformWeights *weights_transform)
52 {
53  ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot run function. Weights are not managed");
54 
55  // Find if I have the same weights with weights transform. If I do, don't run the reshape
56  auto item = _managed_weights.find(weights);
57  bool perform_run{ true };
58  ITensor *weights_tensor{ nullptr };
59 
60  // Check if I already have the requested transform and I have run the reshape function
61  for(auto it : item->second)
62  {
63  if(it->is_reshape_run() && (it->uid() == weights_transform->uid()))
64  {
65  weights_tensor = it->get_weights();
66  perform_run = false;
67  break;
68  }
69  }
70 
71  if(perform_run)
72  {
73  weights_transform->run();
74  weights_tensor = weights_transform->get_weights();
75  }
76 
77  // Check if we can release memory from parent
78  auto parent_item = _managed_weights_parents.find(weights);
79  if(parent_item != _managed_weights_parents.end())
80  {
81  int32_t refcount = parent_item->second->decrease_refcount();
82  if(refcount == 0)
83  {
84  parent_item->second->release();
85  }
86  }
87 
88  // Check top level weights. If all the transformations are done
89  // mark the weights as unused
90  if(_managed_weights_parents.find(weights) == _managed_weights_parents.end())
91  {
92  auto item = _managed_weights.find(weights);
93  bool mark_as_unused = true;
94  for(auto it : item->second)
95  {
96  if(!it->is_reshape_run())
97  {
98  mark_as_unused = false;
99  break;
100  }
101  }
102 
103  if(mark_as_unused)
104  {
105  weights->mark_as_unused();
106  }
107  }
108 
109  return weights_tensor;
110 }
111 
113 {
114  return (_managed_weights.find(weights) != _managed_weights.end());
115 }
116 
117 ITensor *IWeightsManager::acquire(const ITensor *weights, ITransformWeights *weights_transform)
118 {
119  ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot acquire weights. Weights are not managed");
120 
121  ITensor *transformed_weights{ nullptr };
122  auto item = _managed_weights.find(weights);
123 
124  // Check if I already have the requested transform. If I do,
125  // increase the refcount of the transformed weights object and
126  // reuse the tensor
127  for(auto it : item->second)
128  {
129  if(it->uid() == weights_transform->uid())
130  {
131  transformed_weights = it->get_weights();
132  it->increase_refcount();
133  break;
134  }
135  }
136 
137  if(transformed_weights == nullptr)
138  {
139  transformed_weights = weights_transform->get_weights();
140  weights_transform->increase_refcount();
141  item->second.emplace_back(weights_transform);
142  }
143 
144  // Manage the weights and store link to the parent node
145  manage(transformed_weights, weights_transform);
146 
147  return transformed_weights;
148 }
149 } // namespace arm_compute
virtual ITensor * get_weights()=0
Get a pointer to the transformed weights.
void manage(const ITensor *weights, ITransformWeights *parent=nullptr)
Start managing a weights tensor.
Interface for Neon tensor.
Definition: ITensor.h:36
virtual void run()=0
Run the transformation function.
Copyright (c) 2017-2021 Arm Limited.
void mark_as_unused() const
Marks a tensor as unused.
Definition: ITensor.cpp:168
bool are_weights_managed(const ITensor *weights)
Check if the weights are managed.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
virtual uint32_t uid()=0
Function that returns a unique id of the reshape function.
void increase_refcount()
Increase the object's refcount.
Weights tensor transform interface. In order to identify the different reshape functions, each reshape function has to generate a unique id.
ITensor * run(const ITensor *weights, ITransformWeights *weights_transform)
Run the reshape function.
ITensor * acquire(const ITensor *weights, ITransformWeights *weights_transform)
Acquire the requested reshape tensor of the selected weights.