Compute Library
 22.08
ClDirectConv2dKernel Struct Reference

#include <ClKernelGraph.h>

Collaboration diagram for ClDirectConv2dKernel:
[legend]

Public Member Functions

Complexity complexity () const override
 
 ClDirectConv2dKernel ()=default
 
 ~ClDirectConv2dKernel () override=default
 
 ClDirectConv2dKernel (const ClKernelGraph *graph, Id id, const ClKernelConfig config, const ClDirectConv2dKernelDescriptor &desc, const ITensorDescPack< ClKernelTensor > tensors)
 
bool operator== (const ClKernel &other) const override
 
Status generate (ClKernelBlueprint &bp) const override
 
- Public Member Functions inherited from ClKernel
 ClKernel ()=default
 
virtual ~ClKernel ()=default
 
 ClKernel (const ClKernel &kernel)=default
 
ClKernel & operator= (const ClKernel &kernel)=default
 
 ClKernel (ClKernel &&kernel)=default
 
ClKernel & operator= (ClKernel &&kernel)=default
 
 ClKernel (const ClKernelGraph *graph, Id id, const ClKernelConfig &config, const ITensorDescPack< ClKernelTensor > &tensors)
 
Id id () const
 
ITensorDescPack< ClKernelTensor > tensors () const
 
ClKernelConfig config () const
 

Static Public Member Functions

static Status validate (const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ClDirectConv2dKernelDescriptor &conv2d_desc)
 

Data Fields

ClDirectConv2dKernelDescriptor desc {}
 

Additional Inherited Members

- Public Types inherited from ClKernel
using Id = DependencyGraph::Id
 

Detailed Description

Definition at line 122 of file ClKernelGraph.h.

Constructor & Destructor Documentation

◆ ClDirectConv2dKernel() [1/2]

ClDirectConv2dKernel ( )
default

◆ ~ClDirectConv2dKernel()

~ClDirectConv2dKernel ( )
override default

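◆ ClDirectConv2dKernel() [2/2]

ClDirectConv2dKernel ( const ClKernelGraph * graph, Id id, const ClKernelConfig config, const ClDirectConv2dKernelDescriptor & desc, const ITensorDescPack< ClKernelTensor > tensors )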

Member Function Documentation

◆ complexity()

Complexity complexity ( ) const
inline override virtual

◆ generate()

Status generate ( ClKernelBlueprint & bp) const
override virtual

Implements ClKernel.

Definition at line 39 of file ClKernelGraph.cpp.

References arm_compute::ACL_DST_0, arm_compute::ACL_SRC_0, arm_compute::ACL_SRC_1, arm_compute::ACL_SRC_2, arm_compute::experimental::dynamic_fusion::add_kcomp_direct_conv2d(), arm_compute::experimental::dynamic_fusion::add_tensor(), ARM_COMPUTE_ERROR_ON_NULLPTR, bias, ClDirectConv2dKernel::desc, arm_compute::test::validation::dst, and arm_compute::test::validation::input.

40 {
41  const auto input = _tensors.get_const_tensor(TensorType::ACL_SRC_0);
42  const auto weight = _tensors.get_const_tensor(TensorType::ACL_SRC_1);
43  const auto bias = _tensors.get_const_tensor(TensorType::ACL_SRC_2);
44  const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0);
46  ArgumentID input_id;
47  add_tensor(bp, input->desc, input_id, input->id);
48  ArgumentID weight_id;
49  add_tensor(bp, weight->desc, weight_id, weight->id);
50  ArgumentID bias_id = g_arg_placeholder;
51  if(bias != nullptr)
52  {
53  add_tensor(bp, bias->desc, bias_id, bias->id);
54  }
55  ArgumentID dst_id;
56  add_tensor(bp, dst->desc, dst_id, dst->id);
57 
58  add_kcomp_direct_conv2d(bp, desc, input_id, weight_id, bias_id, dst_id);
59  return Status{};
60 }
Status add_kcomp_direct_conv2d(ClKernelBlueprint &kernel_blueprint, const ClDirectConv2dKernelDescriptor &direct_conv2d_desc, ArgumentID src_id, ArgumentID weight_id, ArgumentID bias_id, ArgumentID &dst_id)
Component: Direct Convolution.
OpTensor add_tensor(OperatorGraph &graph, ITensorInfo &info)
Associate a TensorInfo with a newly created OpTensor in the graph.
#define ARM_COMPUTE_ERROR_ON_NULLPTR(...)
Definition: Validate.h:157
const int32_t * bias

◆ operator==()

bool operator== ( const ClKernel & other) const
override virtual

Implements ClKernel.

Definition at line 121 of file ClKernelGraph.cpp.

References ClKernel::config(), ClDirectConv2dKernel::desc, and ClKernel::tensors().

122 {
123  const auto converted = *utils::cast::polymorphic_downcast<const ClDirectConv2dKernel *>(&other);
124  return config() == other.config() && tensors() == other.tensors() && desc == converted.desc;
125 }
ITensorDescPack< ClKernelTensor > tensors() const

◆ validate()

Status validate ( const ITensorInfo * src,
const ITensorInfo * weights,
const ITensorInfo * biases,
const ITensorInfo * dst,
const ClDirectConv2dKernelDescriptor & conv2d_desc 
)
static

Definition at line 61 of file ClKernelGraph.cpp.

References ARM_COMPUTE_RETURN_ERROR_ON, ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN, ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN, ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS, ARM_COMPUTE_RETURN_ERROR_ON_MSG, ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR, Padding2D::bottom, arm_compute::CHANNEL, arm_compute::misc::shape_calculator::compute_deep_convolution_shape(), ClDirectConv2dKernelDescriptor::conv2d, ITensorInfo::data_layout(), arm_compute::test::validation::data_layout, ITensorInfo::dimension(), arm_compute::F16, arm_compute::F32, arm_compute::get_data_layout_dimension_index(), Padding2D::left, arm_compute::NHWC, ITensorInfo::num_dimensions(), Conv2dDescriptor::pad, Padding2D::right, Conv2dDescriptor::stride, ITensorInfo::tensor_shape(), Padding2D::top, TensorShape::total_size(), Size2D::x(), and Size2D::y().

Referenced by Conv2dContent::select_conv_method(), and Conv2dContent::translate().

62 {
63  // 1. Check validity
65  // Matching data type
68  if(biases != nullptr)
69  {
71  }
72 
73  // Matching data layout
76  if(biases != nullptr)
77  {
79  }
80 
81  // All tensor infos are initialized
82  ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0);
83  ARM_COMPUTE_RETURN_ERROR_ON(weights->tensor_shape().total_size() == 0);
84  ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0);
85  if(biases != nullptr)
86  {
87  ARM_COMPUTE_RETURN_ERROR_ON(biases->tensor_shape().total_size() == 0);
88  }
89  // Device requirements are met
91  // weights shape is correct
92  const DataLayout data_layout = src->data_layout();
93  const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
94  ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(channel_idx) != src->dimension(channel_idx), "Weights feature map dimension should match the respective src's one");
95  ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional");
96 
97  // dst shape is correct
98  PadStrideInfo legacy_pad_stride(conv2d_desc.conv2d.stride.x(), conv2d_desc.conv2d.stride.y(), conv2d_desc.conv2d.pad.left, conv2d_desc.conv2d.pad.right, conv2d_desc.conv2d.pad.top,
99  conv2d_desc.conv2d.pad.bottom, DimensionRoundingType{});
101  misc::shape_calculator::compute_deep_convolution_shape(*src, *weights, legacy_pad_stride));
102 
103  // biases shape is correct
104  if(biases != nullptr)
105  {
106  ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->dimension(0) != weights->dimension(3),
107  "Biases size and number of dst feature maps should match");
108  ARM_COMPUTE_RETURN_ERROR_ON_MSG(biases->num_dimensions() > 1,
109  "Biases should be one dimensional");
110  }
111 
112  // 2. Check support level
113  // Data type
115  // Data layout
117 
118  return Status{};
119 }
#define ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(tensor)
Definition: CLValidate.h:35
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(t,...)
Definition: Validate.h:742
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(...)
Definition: Validate.h:490
F32: 1 channel, 1 F32 per channel
DimensionRoundingType
Dimension rounding type when down-scaling on CNNs.
Definition: Types.h:550
#define ARM_COMPUTE_RETURN_ERROR_ON(cond)
If the condition is true, an error is returned.
Definition: Error.h:296
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(...)
Definition: Validate.h:284
SimpleTensor< float > src
Definition: DFT.cpp:155
F16: 1 channel, 1 F16 per channel
#define ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(...)
Definition: Validate.h:159
size_t get_data_layout_dimension_index(const DataLayout &data_layout, const DataLayoutDimension &data_layout_dimension)
Get the index of the given dimension.
Definition: Helpers.inl:193
#define ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(...)
Definition: Validate.h:541
NHWC: Num samples, height, width, channels.
#define ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(t, c,...)
Definition: Validate.h:788
#define ARM_COMPUTE_RETURN_ERROR_ON_MSG(cond, msg)
If the condition is true, an error is returned.
Definition: Error.h:244
DataLayout
[DataLayout enum definition]
Definition: Types.h:113
TensorShape compute_deep_convolution_shape(const TensorShape &input_shape, DataLayout input_data_layout, const TensorShape &weights_shape, const PadStrideInfo &conv_info)
Calculate the deep convolution shape output shape of a tensor.

Field Documentation

◆ desc

ClDirectConv2dKernelDescriptor desc {}

The documentation for this struct was generated from the following files:
- ClKernelGraph.h
- ClKernelGraph.cpp