Compute Library
 21.02
IWeightsManager Class Reference

Weights manager interface to handle weights transformations. More...

#include <IWeightsManager.h>

Public Member Functions

 IWeightsManager ()
 Constructor. More...
 
virtual ~IWeightsManager ()=default
 Default Destructor. More...
 
 IWeightsManager (const IWeightsManager &)=delete
 Prevent instances of this class from being copy constructed. More...
 
IWeightsManager & operator= (const IWeightsManager &)=delete
 Prevent instances of this class from being copied. More...
 
 IWeightsManager (IWeightsManager &&)=default
 Allow instances of this class to be move constructed. More...
 
IWeightsManager & operator= (IWeightsManager &&)=default
 Allow instances of this class to be moved. More...
 
void manage (const ITensor *weights, ITransformWeights *parent=nullptr)
 Start managing a weights tensor. More...
 
ITensor * run (const ITensor *weights, ITransformWeights *weights_transform)
 Run the reshape function. More...
 
ITensor * acquire (const ITensor *weights, ITransformWeights *weights_transform)
 Acquire the requested reshape tensor of the selected weights. More...
 
bool are_weights_managed (const ITensor *weights)
 Check if the weights are managed. More...
 

Detailed Description

Weights manager interface to handle weights transformations.

Definition at line 36 of file IWeightsManager.h.

Constructor & Destructor Documentation

◆ IWeightsManager() [1/3]

Constructor.

Definition at line 28 of file IWeightsManager.cpp.

29  : _managed_weights(), _managed_weights_parents()
30 {
31 }

◆ ~IWeightsManager()

virtual ~IWeightsManager ( )
virtual, default

Default Destructor.

◆ IWeightsManager() [2/3]

IWeightsManager ( const IWeightsManager & )
delete

Prevent instances of this class from being copy constructed.

◆ IWeightsManager() [3/3]

IWeightsManager ( IWeightsManager &&  )
default

Allow instances of this class to be move constructed.

Member Function Documentation

◆ acquire()

ITensor * acquire ( const ITensor * weights,
ITransformWeights * weights_transform 
)

Acquire the requested reshape tensor of the selected weights.

Parameters
[in] weights Pointer to the managed weights tensor whose transformed version is requested
[in] weights_transform Weights transformation object

Definition at line 117 of file IWeightsManager.cpp.

References IWeightsManager::are_weights_managed(), ARM_COMPUTE_ERROR_ON_MSG, ITransformWeights::get_weights(), ITransformWeights::increase_refcount(), IWeightsManager::manage(), and ITransformWeights::uid().

Referenced by NEFullyConnectedLayer::configure(), CLFullyConnectedLayer::configure(), NEGEMMConvolutionLayer::configure(), and CLGEMMConvolutionLayer::configure().

118 {
119  ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot acquire weights. Weights are not managed");
120 
121  ITensor *transformed_weights{ nullptr };
122  auto item = _managed_weights.find(weights);
123 
124  // Check if I already have the requested transform. If I do,
125  // increase the refcount of the transformed weights object and
126  // reuse the tensor
127  for(auto it : item->second)
128  {
129  if(it->uid() == weights_transform->uid())
130  {
131  transformed_weights = it->get_weights();
132  it->increase_refcount();
133  break;
134  }
135  }
136 
137  if(transformed_weights == nullptr)
138  {
139  transformed_weights = weights_transform->get_weights();
140  weights_transform->increase_refcount();
141  item->second.emplace_back(weights_transform);
142  }
143 
144  // Manage the weights and store link to the parent node
145  manage(transformed_weights, weights_transform);
146 
147  return transformed_weights;
148 }
void manage(const ITensor *weights, ITransformWeights *parent=nullptr)
Start managing a weights tensor.
bool are_weights_managed(const ITensor *weights)
Check if the weights are managed.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456

◆ are_weights_managed()

bool are_weights_managed ( const ITensor * weights)

Check if the weights are managed.

Parameters
[in] weights Pointer to the weights tensor to check
Returns
True if the weights tensor is managed, false otherwise

Definition at line 112 of file IWeightsManager.cpp.

Referenced by IWeightsManager::acquire(), NEFullyConnectedLayer::configure(), CLFullyConnectedLayer::configure(), NEGEMMConvolutionLayer::configure(), CLGEMMConvolutionLayer::configure(), IWeightsManager::manage(), NEGEMM::prepare(), NEGEMMLowpMatrixMultiplyCore::prepare(), NEFullyConnectedLayer::prepare(), CLGEMM::prepare(), CLFullyConnectedLayer::prepare(), NEGEMMConvolutionLayer::prepare(), CLGEMMConvolutionLayer::prepare(), IWeightsManager::run(), and CLGEMM::run().

113 {
114  return (_managed_weights.find(weights) != _managed_weights.end());
115 }

◆ manage()

void manage ( const ITensor * weights,
ITransformWeights * parent = nullptr 
)

Start managing a weights tensor.

Parameters
[in] weights Pointer to the weights tensor to be managed
[in] parent Parent node, used when the weights are the output of a previous reshape function

Definition at line 33 of file IWeightsManager.cpp.

References IWeightsManager::are_weights_managed().

Referenced by IWeightsManager::acquire(), NEFullyConnectedLayer::configure(), and CLFullyConnectedLayer::configure().

34 {
35  if(!are_weights_managed(weights))
36  {
37  _managed_weights[weights];
38  }
39 
40  // In case the weights are an output of a previous reshape function
41  // store the parent's link
42  if(parent != nullptr)
43  {
44  if(_managed_weights_parents.find(weights) == _managed_weights_parents.end())
45  {
46  _managed_weights_parents[weights] = parent;
47  }
48  }
49 }
bool are_weights_managed(const ITensor *weights)
Check if the weights are managed.

◆ operator=() [1/2]

IWeightsManager& operator= ( const IWeightsManager & )
delete

Prevent instances of this class from being copied.

◆ operator=() [2/2]

IWeightsManager& operator= ( IWeightsManager &&  )
default

Allow instances of this class to be moved.

◆ run()

ITensor * run ( const ITensor * weights,
ITransformWeights * weights_transform 
)

Run the reshape function.

Parameters
[in] weights Pointer to the weights tensor to reshape
[in] weights_transform Weights transformation object
Returns
The reshaped tensor

Definition at line 51 of file IWeightsManager.cpp.

References IWeightsManager::are_weights_managed(), ARM_COMPUTE_ERROR_ON_MSG, ITransformWeights::get_weights(), ITensor::mark_as_unused(), ITransformWeights::run(), and ITransformWeights::uid().

Referenced by NEFullyConnectedLayer::prepare(), CLGEMM::prepare(), CLFullyConnectedLayer::prepare(), NEGEMMConvolutionLayer::prepare(), CLGEMMConvolutionLayer::prepare(), and CLGEMM::run().

52 {
53  ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot run function. Weights are not managed");
54 
55  // Find if I have the same weights with weights transform. If I do, don't run the reshape
56  auto item = _managed_weights.find(weights);
57  bool perform_run{ true };
58  ITensor *weights_tensor{ nullptr };
59 
60  // Check if I already have the requested transform and I have run the reshape function
61  for(auto it : item->second)
62  {
63  if(it->is_reshape_run() && (it->uid() == weights_transform->uid()))
64  {
65  weights_tensor = it->get_weights();
66  perform_run = false;
67  break;
68  }
69  }
70 
71  if(perform_run)
72  {
73  weights_transform->run();
74  weights_tensor = weights_transform->get_weights();
75  }
76 
77  // Check if we can release memory from parent
78  auto parent_item = _managed_weights_parents.find(weights);
79  if(parent_item != _managed_weights_parents.end())
80  {
81  int32_t refcount = parent_item->second->decrease_refcount();
82  if(refcount == 0)
83  {
84  parent_item->second->release();
85  }
86  }
87 
88  // Check top level weights. If all the transformations are done
89  // mark the weights as unused
90  if(_managed_weights_parents.find(weights) == _managed_weights_parents.end())
91  {
92  auto item = _managed_weights.find(weights);
93  bool mark_as_unused = true;
94  for(auto it : item->second)
95  {
96  if(!it->is_reshape_run())
97  {
98  mark_as_unused = false;
99  break;
100  }
101  }
102 
103  if(mark_as_unused)
104  {
105  weights->mark_as_unused();
106  }
107  }
108 
109  return weights_tensor;
110 }
bool are_weights_managed(const ITensor *weights)
Check if the weights are managed.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456

The documentation for this class was generated from the following files: