ArmNN
 25.02
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
TosaLayerSupportRules.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022, 2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <tosa_serialization_handler.h>
9 
10 using namespace armnn;
11 using namespace tosa;
12 
13 // List of Layer Support Rules common to TOSA backends only, for use with CheckSupportRule()
14 
16 {
17  template<typename Container>
18  explicit TosaOperatorAttributeOfAny(TosaSerializationOperator* op, const Container& c)
19  {
20  m_Res = std::any_of(c.begin(), c.end(), [&op](Attribute attribute)
21  {
22  return attribute == op->GetAttributeType();
23  });
24  }
25 };
26 
27 struct TosaTypeAnyOf : public Rule
28 {
29  template<typename Container>
30  TosaTypeAnyOf(TosaSerializationTensor* tensor, const Container& c)
31  {
32  m_Res = std::any_of(c.begin(), c.end(), [&tensor](DType dt)
33  {
34  return dt == tensor->GetDtype();
35  });
36  }
37 };
38 
40 {
41  explicit TosaTensorNumDimensionsWithinBounds(TosaSerializationTensor* tensor)
42  {
43  m_Res = (tensor->GetShape().size() <= MaxNumOfTensorDimensions) || (!tensor->GetShape().empty());
44  }
45 };
46 
47 struct TosaAssertSize : public Rule
48 {
49  template<typename Container>
50  explicit TosaAssertSize(const Container& c1, const Container& c2)
51  {
52  m_Res = (c1.size() == c2.size());
53  }
54 };
55 
57 {
58  explicit TosaContainerContainsTwoTypes(std::tuple<DType, DType>& check,
59  const std::vector<std::tuple<DType, DType>>& c)
60  {
61  for (auto item: c)
62  {
63  if (std::get<0>(check) == std::get<0>(item) &&
64  std::get<1>(check) == std::get<1>(item))
65  {
66  m_Res = true;
67  return;
68  }
69  }
70  m_Res = false;
71  }
72 };
73 
75 {
76  explicit TosaContainerContainsThreeTypes(std::tuple<DType, DType, DType>& check,
77  const std::vector<std::tuple<DType, DType, DType>>& c)
78  {
79  for (auto item: c)
80  {
81  if (std::get<0>(check) == std::get<0>(item) &&
82  std::get<1>(check) == std::get<1>(item) &&
83  std::get<2>(check) == std::get<2>(item))
84  {
85  m_Res = true;
86  return;
87  }
88  }
89  m_Res = false;
90  }
91 };
Copyright (c) 2021 ARM Limited and Contributors.
constexpr unsigned int MaxNumOfTensorDimensions
Definition: Types.hpp:31
TosaAssertSize(const Container &c1, const Container &c2)
TosaContainerContainsThreeTypes(std::tuple< DType, DType, DType > &check, const std::vector< std::tuple< DType, DType, DType >> &c)
TosaContainerContainsTwoTypes(std::tuple< DType, DType > &check, const std::vector< std::tuple< DType, DType >> &c)
TosaOperatorAttributeOfAny(TosaSerializationOperator *op, const Container &c)
TosaTensorNumDimensionsWithinBounds(TosaSerializationTensor *tensor)
TosaTypeAnyOf(TosaSerializationTensor *tensor, const Container &c)