Compute Library 20.08
NEConvertFullyConnectedWeights.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2019 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H
25 #define ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H
26 
32 
33 namespace arm_compute
34 {
35 // Forward declarations
36 class ITensor;
37 
38 /** Basic function to run @ref NEConvertFullyConnectedWeightsKernel. */
40 {
41 public:
42  /** Default constructor */
44  /** Initialize the function.
45  *
 * Sets up the conversion of 2D fully connected weights from NCHW to NHWC ordering or vice versa.
 * The conversion itself is performed when run() is called.
 *
46  * @param[in] input Source weights tensor to convert. Must be 2 dimensional. Data types supported: All.
47  * @param[out] output The converted weights tensor. Shape and Data Type: Same as @p input.
48  * @param[in] original_input_shape Shape of the original input tensor (the one entering the fully connected layer).
 *            NOTE(review): presumably needed to recover the spatial layout the weights were flattened from; confirm.
49  * @param[in] data_layout The data layout the weights have been trained in.
50  */
51  void configure(const ITensor *input, ITensor *output, const TensorShape &original_input_shape, DataLayout data_layout);
52  /** Static function to check if given info will lead to a valid configuration of @ref NEConvertFullyConnectedWeights
53  *
 * Takes the same arguments as configure(), but as tensor infos (metadata only). This lets a configuration
 * be checked without real tensors.
 *
54  * @param[in] input Source weights tensor info to convert. Must be 2 dimensional. Data types supported: All.
55  * @param[in] output The converted weights tensor info. Shape and Data Type: Same as @p input.
56  * @param[in] original_input_shape Shape of the original input tensor (the one entering the fully connected layer).
57  * @param[in] data_layout The data layout the weights have been trained in.
58  *
59  * @return A Status indicating whether the given configuration is valid.
60  */
61  static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout);
62 
63  // Inherited methods overridden:
// run() executes the conversion set up by configure().
64  void run() override;
65 
66 private:
68 };
69 
70 namespace weights_transformations
71 {
72 /** Basic function to run @ref NEConvertFullyConnectedWeightsKernel. */
74 {
75 public:
76  void run() override
77  {
78  _output.allocator()->allocate();
79  _func.run();
80  _reshape_run = true;
81  }
82 
83  void release() override
84  {
85  _output.allocator()->free();
86  }
87 
88  ITensor *get_weights() override
89  {
90  return &_output;
91  }
92 
/** Return the unique id identifying this reshape function (the value stored in _uid). */
93  uint32_t uid() override
94  {
95  return _uid;
96  }
97 
98  void configure(const ITensor *input, const TensorShape &original_input_shape, DataLayout data_layout)
99  {
100  _func.configure(input, &_output, original_input_shape, data_layout);
101  }
102 
103 private:
// Unique id returned by uid(). NOTE(review): the value presumably follows the ITransformWeights id-encoding convention; confirm.
104  static constexpr uint32_t _uid = 0x4;
// Destination of the converted weights: allocated in run(), freed in release(), exposed via get_weights().
105  Tensor _output{};
107 };
108 } // namespace weights_transformations
109 } // namespace arm_compute
110 #endif /* ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H */
Shape of a tensor.
Definition: TensorShape.h:39
const DataLayout data_layout
Definition: Im2Col.cpp:146
Base class for all functions.
Definition: IFunction.h:30
Store the tensor's metadata.
Definition: ITensorInfo.h:40
Status class.
Definition: Error.h:52
Interface for NEON tensor.
Definition: ITensor.h:36
Copyright (c) 2017-2020 Arm Limited.
TensorAllocator * allocator()
Return a pointer to the tensor's allocator.
Definition: Tensor.cpp:48
Basic function to run NEConvertFullyConnectedWeightsKernel.
ITensor * get_weights() override
Get a pointer to the transformed weights.
void allocate() override
Allocate CPU memory of the size specified by the TensorInfo.
Interface to convert the 2D Fully Connected weights from NCHW to NHWC or vice versa.
void configure(const ITensor *input, const TensorShape &original_input_shape, DataLayout data_layout)
Basic implementation of the tensor interface.
Definition: Tensor.h:37
void free() override
Free allocated CPU memory.
void configure(const ITensor *input, ITensor *output, const TensorShape &original_input_shape, DataLayout data_layout)
Initialize the function.
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout)
Static function to check if given info will lead to a valid configuration of NEConvertFullyConnectedWeights.
Weights tensor transform interface. In order to identify the different reshape functions, each one exposes a unique id (see uid()).
uint32_t uid() override
Function that returns a unique id of the reshape function.
DataLayout
Supported tensor data layouts (e.g. NCHW, NHWC).
Definition: Types.h:120
void run() override
Run the kernels contained in the function.