24.02
|
Go to the documentation of this file.
37 template<
typename FactoryType>
39 const FactoryType& factory,
44 bool useSubTensors = factory.SupportsSubTensors();
55 std::vector<std::unique_ptr<ITensorHandle>> subTensors;
65 std::set<unsigned int> splitAxis;
67 for (
unsigned int i = 0; i < numSplit; ++i)
69 for (
unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
73 splitAxis.insert(dimIdx);
81 std::set<unsigned int>::iterator axisIt = axis.begin();
84 ((*axisIt == numberOfDimensions - 1) ||
85 (*axisIt == numberOfDimensions - 2));
98 bool canUseSubTensorOnXorY =
true;
99 bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
100 if (isTensorHandleFactory)
102 for (
unsigned int it = 0; it < numOutputSlots; ++it)
105 ITensorHandleFactory* handleFactory = registry.
GetFactory(factoryId);
106 std::vector<Capability> capabilities =
112 canUseSubTensorOnXorY =
false;
113 if (capabilities.empty())
115 canUseSubTensorOnXorY =
true;
119 if (!canUseSubTensorOnXorY)
126 auto CreateSubTensor = [&]()
137 canUseSubTensorOnXorY)
140 return factory.CreateSubTensorHandle(*inputData,
145 return std::unique_ptr<ITensorHandle>();
148 auto subTensor = CreateSubTensor();
151 useSubTensors =
false;
154 subTensors.push_back(std::move(subTensor));
160 for (
auto& subTensor : subTensors)
179 const bool isMemoryManaged)
186 CreateTensors(registry, workloadFactory, isMemoryManaged);
192 CreateTensors(registry, *handleFactory, isMemoryManaged);
205 std::vector<TensorShape> outShapes;
222 std::vector<TensorShape> views;
236 inferredShapes[viewIdx],
#define ARMNN_ASSERT(COND)
A ViewsDescriptor for the SplitterLayer.
void ExecuteStrategy(IStrategy &strategy) const override
Apply a visitor to this layer.
This layer represents a split operation.
const TensorInfo & GetTensorInfo() const override
void Splitter(const SplitterQueueDescriptor &data, std::vector< ITensorHandle * > inputs, std::vector< ITensorHandle * > outputs)
ITensorHandleFactory * GetFactory(ITensorHandleFactory::FactoryId id) const
Find a TensorHandleFactory by Id Returns nullptr if not found.
SplitterLayer(const ViewsDescriptor &param, const char *name)
Constructor to create a SplitterLayer.
void ValidateTensorShapesFromInputs() override
Check if the input tensor shape(s) will lead to a valid configuration of SplitterLayer.
void ValidateAndCopyShape(const TensorShape &outputShape, const TensorShape &inferredShape, const ShapeInferenceMethod shapeInferenceMethod, const std::string &layerName, const unsigned int outputSlotIndex=0)
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
std::vector< OutputHandler > m_OutputHandlers
const std::vector< InputSlot > & GetInputSlots() const
std::vector< ViewOrigin > m_ViewOrigins
std::set< unsigned int > ComputeSplitAxis(const armnn::SplitterDescriptor &desc, const TensorShape &input)
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
const ViewsDescriptor & GetParameters() const override
ITensorHandle * GetData() const
Gets the allocated tensor memory.
virtual void CreateTensorHandles(const TensorHandleFactoryRegistry &registry, const IWorkloadFactory &factory, const bool IsMemoryManaged=true) override
Set the outputs to be appropriate sub tensors of the input if sub tensors are supported otherwise cre...
const char * GetName() const override
Returns the name of the layer.
static const FactoryId LegacyFactoryId
virtual std::vector< Capability > GetCapabilities(const IConnectableLayer *layer, const IConnectableLayer *connectedLayer, CapabilityClass capabilityClass)
const uint32_t * GetViewSizes(uint32_t idx) const
Get the view sizes at the int value idx.
SplitterLayer * Clone(Graph &graph) const override
Creates a dynamically-allocated copy of this layer.
ViewsDescriptor m_Param
The parameters for the layer (not including tensor-valued weights etc.).
WorkloadInfo PrepInfoAndDesc(QueueDescriptor &descriptor) const
Helper function to reduce duplication in *Layer::CreateWorkload.
bool IsTypeSpaceMatch(const TensorInfo &other) const
Check that the types are the same and, if quantize, that the quantization parameters are the same.
unsigned int GetNumOutputSlots() const override
Returns the number of connectable output slots.
void VerifyShapeInferenceType(const TensorShape &outputShape, ShapeInferenceMethod shapeInferenceMethod)
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
void SetAdditionalInfo(QueueDescriptor &descriptor) const
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
std::vector< TensorShape > InferOutputShapes(const std::vector< TensorShape > &inputShapes) const override
By default returns inputShapes if the number of inputs are equal to number of outputs,...
std::vector< OutputSlot >::iterator BeginOutputSlots()
const TensorShape & GetShape() const
#define ARMNN_NO_DEPRECATE_WARN_END
void IgnoreUnused(Ts &&...)
uint32_t GetNumDimensions() const
Get the number of dimensions.
ITensorHandleFactory::FactoryId GetTensorHandleFactoryId() const
Copyright (c) 2021 ARM Limited and Contributors.
uint32_t GetNumViews() const
Get the number of views.
ShapeInferenceMethod m_ShapeInferenceMethod
LayerType
When adding a new layer, adapt also the LastLayer enum value in the enum class LayerType below.
const InputSlot * GetConnection(unsigned int index) const override
std::vector< OutputSlot >::iterator EndOutputSlots()
virtual std::unique_ptr< IWorkload > CreateWorkload(LayerType type, const QueueDescriptor &descriptor, const WorkloadInfo &info) const =0
Backends should implement their own CreateWorkload function with a switch statement.
virtual void ExecuteStrategy(const IConnectableLayer *layer, const armnn::BaseDescriptor &descriptor, const std::vector< armnn::ConstTensor > &constants, const char *name, const armnn::LayerBindingId id=0)=0
const uint32_t * GetViewOrigin(uint32_t idx) const
Get the view origin at the int value idx.
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
Makes a workload for the Splitter type.