ArmNN
 25.11
Loading...
Searching...
No Matches
ClConvolution2dWorkload.hpp
Go to the documentation of this file.
1//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include <armnn/Tensor.hpp>
10
11#include "ClBaseWorkload.hpp"
12
13#include <arm_compute/runtime/CL/functions/CLConvolutionLayer.h>
14#include <arm_compute/runtime/MemoryManagerOnDemand.h>
15
16#include <cl/ICLTensorProxy.hpp>
17
18#include <memory>
19
20namespace armnn
21{
22
23arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
24 const TensorInfo& output,
25 const Convolution2dDescriptor& descriptor,
26 const TensorInfo& weights,
27 const Optional<TensorInfo>& biases,
28 bool isFastMathEnabled = false,
29 const ActivationDescriptor* activationDescriptor = nullptr);
30
31class ClConvolution2dWorkload : public ClBaseWorkload<Convolution2dQueueDescriptor>
32{
33public:
35 const WorkloadInfo& info,
36 std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager,
37 const arm_compute::CLCompileContext& clCompileContext,
38 const bool isFastMathEnabled = false);
39 void Execute() const override;
40
41 arm_compute::ConvolutionMethod GetConvolutionMethod() const;
42
43 bool SupportsTensorHandleReplacement() const override
44 {
45 // NCHW DataLayout on ACL still uses paddding for alignment on the Conv2d workload so importing is unreliable.
46 if (m_Data.m_Parameters.m_DataLayout == DataLayout::NCHW)
47 {
48 return false;
49 }
50 else
51 {
52 return true;
53 }
54 }
55
56
57protected:
58 void Reconfigure() override;
59
60private:
61 mutable arm_compute::CLConvolutionLayer m_ConvolutionLayer;
62
63 arm_compute::ConvolutionMethod m_ConvolutionMethod;
64
65 std::unique_ptr<ICLTensorProxy> m_InputProxy;
66 std::unique_ptr<ICLTensorProxy> m_WeightsProxy;
67 std::unique_ptr<ICLTensorProxy> m_BiasProxy;
68 std::unique_ptr<ICLTensorProxy> m_OutputProxy;
69};
70
71} //namespace armnn
72
ClBaseWorkload(const Convolution2dQueueDescriptor &descriptor, const WorkloadInfo &info)
ClConvolution2dWorkload(const Convolution2dQueueDescriptor &descriptor, const WorkloadInfo &info, std::shared_ptr< arm_compute::MemoryManagerOnDemand > &memoryManager, const arm_compute::CLCompileContext &clCompileContext, const bool isFastMathEnabled=false)
arm_compute::ConvolutionMethod GetConvolutionMethod() const
bool SupportsTensorHandleReplacement() const override
Copyright (c) 2021 ARM Limited and Contributors.
arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo &input, const TensorInfo &output, const Convolution2dDescriptor &descriptor, const TensorInfo &weights, const Optional< TensorInfo > &biases, bool isFastMathEnabled, const ActivationDescriptor *activationDescriptor)
An ActivationDescriptor for the ActivationLayer.
A Convolution2dDescriptor for the Convolution2dLayer.
Contains information about TensorInfos of a layer.