Compute Library
 22.11
PostOpUtils.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_EXPERIMENTAL_POSTOPUTILS
25 #define ARM_COMPUTE_EXPERIMENTAL_POSTOPUTILS
26 
29 
31 #include "support/Cast.h"
32 
33 #include <vector>
34 
35 /** (EXPERIMENTAL_POST_OPS) */
36 namespace arm_compute
37 {
38 namespace experimental
39 {
40 /** Transform a PostOpList of type FromTensorT to one of type ToTensorT */
41 template <typename FromTensorT, typename ToTensorT>
42 PostOpList<ToTensorT> transform_post_op_list_arguments(const PostOpList<FromTensorT> &post_ops, std::function<ToTensorT(FromTensorT)> transform_arg)
43 {
44  PostOpList<ToTensorT> transformed_post_ops;
45  int op_idx = 0;
46  for(const auto &post_op : post_ops.get_list())
47  {
48  switch(post_op->type())
49  {
51  {
52  const auto _post_op = utils::cast::polymorphic_downcast<const PostOpAct<FromTensorT> *>(post_op.get());
53  transformed_post_ops.template push_back_op<PostOpAct<ToTensorT>>(_post_op->_act_info);
54  break;
55  }
57  {
58  const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwiseAdd<FromTensorT> *>(post_op.get());
59  transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy);
60  break;
61  }
63  {
64  const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwisePRelu<FromTensorT> *>(post_op.get());
65  transformed_post_ops.template push_back_op<PostOpEltwisePRelu<ToTensorT>>(transform_arg(_post_op->_alpha_param), _post_op->_prev_dst_pos, _post_op->_policy);
66  break;
67  }
68  default:
69  {
70  ARM_COMPUTE_ERROR("Unsupported PostOpType");
71  }
72  }
73  ++op_idx;
74  }
75  return transformed_post_ops;
76 }
77 
78 /** Get post op argument TensorType from post op argument index in a flattened, ordered post op argument list */
79 inline TensorType get_post_op_arg_type(size_t index)
80 {
81  ARM_COMPUTE_ERROR_ON_MSG(static_cast<int>(index) > EXPERIMENTAL_ACL_POST_OP_ARG_LAST - EXPERIMENTAL_ACL_POST_OP_ARG_FIRST, "Post Op argument index is out of range");
82  return static_cast<TensorType>(EXPERIMENTAL_ACL_POST_OP_ARG_FIRST + static_cast<int>(index));
83 }
84 
85 /** Get a sequence of PostOp Types from PostOpList */
86 template <typename T>
88 {
89  PostOpTypeSequence post_op_sequence;
90  for(const auto &op : post_ops.get_list())
91  {
92  post_op_sequence.push_back(op->type());
93  }
94  return post_op_sequence;
95 }
96 
97 } // namespace experimental
98 } // namespace arm_compute
99 #endif //ARM_COMPUTE_EXPERIMENTAL_POSTOPUTILS
std::vector< std::unique_ptr< IPostOp< TensorRelatedT > > > & get_list()
Get the underlying post op list.
Definition: IPostOp.h:165
experimental::PostOpList< ITensorInfo * > post_ops
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
Definition: Error.h:352
TensorType
Memory type.
Definition: Types.h:38
std::vector< PostOpType > PostOpTypeSequence
An ordered sequence of type of Post Ops.
Definition: IPostOp.h:43
Copyright (c) 2017-2022 Arm Limited.
TensorType get_post_op_arg_type(size_t index)
Get post op argument TensorType from post op argument index in a flattened, ordered post op argument list.
Definition: PostOpUtils.h:79
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
Definition: Error.h:456
PostOpList< ToTensorT > transform_post_op_list_arguments(const PostOpList< FromTensorT > &post_ops, std::function< ToTensorT(FromTensorT)> transform_arg)
Transform a PostOpList of type FromTensorT to one of type ToTensorT.
Definition: PostOpUtils.h:42
PostOpTypeSequence get_post_op_sequence(const PostOpList< T > &post_ops)
Get a sequence of PostOp Types from PostOpList.
Definition: PostOpUtils.h:87
A sequence of PostOps that can be appended to the end of other operators.
Definition: IPostOp.h:119