24 #ifndef ARM_COMPUTE_EXPERIMENTAL_POSTOPUTILS 25 #define ARM_COMPUTE_EXPERIMENTAL_POSTOPUTILS 38 namespace experimental
41 template <
typename FromTensorT,
typename ToTensorT>
46 for(
const auto &post_op : post_ops.
get_list())
48 switch(post_op->type())
52 const auto _post_op = utils::cast::polymorphic_downcast<const PostOpAct<FromTensorT> *>(post_op.get());
53 transformed_post_ops.template push_back_op<PostOpAct<ToTensorT>>(_post_op->_act_info);
58 const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwiseAdd<FromTensorT> *>(post_op.get());
59 transformed_post_ops.template push_back_op<PostOpEltwiseAdd<ToTensorT>>(transform_arg(_post_op->_addend), _post_op->_prev_dst_pos, _post_op->_policy);
64 const auto _post_op = utils::cast::polymorphic_downcast<const PostOpEltwisePRelu<FromTensorT> *>(post_op.get());
65 transformed_post_ops.template push_back_op<PostOpEltwisePRelu<ToTensorT>>(transform_arg(_post_op->_alpha_param), _post_op->_prev_dst_pos, _post_op->_policy);
75 return transformed_post_ops;
90 for(
const auto &op : post_ops.
get_list())
92 post_op_sequence.push_back(op->type());
94 return post_op_sequence;
99 #endif //ARM_COMPUTE_EXPERIMENTAL_POSTOPUTILS std::vector< std::unique_ptr< IPostOp< TensorRelatedT > > > & get_list()
Get the underlying post op list.
experimental::PostOpList< ITensorInfo * > post_ops
#define ARM_COMPUTE_ERROR(msg)
Print the given message then throw an std::runtime_error.
std::vector< PostOpType > PostOpTypeSequence
An ordered sequence of Post Op types.
Copyright (c) 2017-2022 Arm Limited.
TensorType get_post_op_arg_type(size_t index)
Get post op argument TensorType from post op argument index in a flattened, ordered post op argument list.
#define ARM_COMPUTE_ERROR_ON_MSG(cond, msg)
PostOpList< ToTensorT > transform_post_op_list_arguments(const PostOpList< FromTensorT > &post_ops, std::function< ToTensorT(FromTensorT)> transform_arg)
Transform a PostOpList of type FromTensorT to one of type ToTensorT.
PostOpTypeSequence get_post_op_sequence(const PostOpList< T > &post_ops)
Get a sequence of PostOp Types from PostOpList.
A sequence of PostOps that can be appended to the end of other operators.