39 template<
typename FactoryType>
41 const FactoryType& factory,
49 if (factory.SupportsSubTensors())
55 && ((concatAxis == numberOfDimensions - 1) || (concatAxis == numberOfDimensions - 2));
59 std::queue<ConcatLayer*> m_ConcatLayers;
61 m_ConcatLayers.push(
this);
62 while (!m_ConcatLayers.empty())
72 bool canUseSubTensorOnXorY =
true;
73 bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
74 if (isTensorHandleFactory)
76 for (
unsigned int i = 0; i < numInputSlots; ++i)
80 std::vector<Capability> capabilities =
86 canUseSubTensorOnXorY =
false;
87 if (capabilities.empty())
89 canUseSubTensorOnXorY =
true;
96 && (PolymorphicDowncast<const Layer*>(currentLayer))->GetType() ==
LayerType::Concat)
98 canUseSubTensorOnXorY =
false;
101 if (!canUseSubTensorOnXorY)
108 std::vector<std::unique_ptr<ITensorHandle>> subTensors(0);
109 subTensors.reserve(numInputSlots);
110 for (
unsigned int i = 0; i < numInputSlots; ++i)
115 auto CreateSubTensor = [&]()
126 factoryId == slot->GetTensorHandleFactoryId() &&
129 slot->GetNumConnections() == 1 &&
130 canUseSubTensorOnXorY &&
135 return factory.CreateSubTensorHandle(*parentTensor,
140 return std::unique_ptr<ITensorHandle>();
143 auto subTensor = CreateSubTensor();
150 subTensors.push_back(std::move(subTensor));
155 if (subTensors.size() < numInputSlots)
162 for (
auto& subTensor : subTensors)
169 throw armnn::Exception(
"ConcatLayer: Expected a valid sub-tensor for substitution.");
172 outputHandler.SetData(std::move(subTensor));
174 Layer& inputLayer = slot->GetOwningLayer();
178 m_ConcatLayers.push(PolymorphicDowncast<ConcatLayer*>(&inputLayer));
188 const bool isMemoryManaged)
195 CreateTensors(registry, workloadFactory, isMemoryManaged);
204 CreateTensors(registry, *handleFactory, isMemoryManaged);
217 throw armnn::Exception(
"inputShapes' and m_NumViews' sizes do not match (\""
218 + std::to_string(inputShapes.size()) +
224 for (
unsigned int i=0; i< inputShapes.size(); i++)
226 auto& inputShape = inputShapes[i];
228 ConditionalThrowIfNotEqual<LayerValidationException>(
229 "ConcatLayer: Num Dimensions must match all inputs.",
231 inputShape.GetNumDimensions());
235 std::vector<unsigned int> extentMin(numDims);
236 std::vector<unsigned int> extentMax(numDims);
237 for (
unsigned int i = 0; i < inputShapes.size(); i++)
241 for (
unsigned int d = 0; d < numDims; d++)
243 extentMin[d] = std::min(extentMin[d], origin[d]);
244 extentMax[d] = std::max(extentMax[d], origin[d] + shape[d]);
249 if (!std::all_of(extentMin.begin(), extentMin.end(), [](
unsigned int s) { return s == 0; }))
257 for (
unsigned int a = 0; a < inputShapes.size(); a++)
261 for (
unsigned int b = 0; b < a; b++)
266 bool allAxesOverlap =
true;
267 for (
unsigned int d = 0; d < numDims && allAxesOverlap; d++)
269 unsigned int a1 = aOrigin[d];
270 unsigned int a2 = aOrigin[d] + aShape[d];
272 unsigned int b1 = bOrigin[d];
273 unsigned int b2 = bOrigin[d] + bShape[d];
275 if (a2 <= b1 || b2 <= a1)
277 allAxesOverlap =
false;
290 unsigned int totalViewsVolume = 0;
291 for (
unsigned int i = 0; i < inputShapes.size(); i++)
293 totalViewsVolume += inputShapes[i].GetNumElements();
295 unsigned int outputVolume = 1;
296 for (
unsigned int d = 0; d < numDims; d++)
298 outputVolume *= (extentMax[d] - extentMin[d]);
301 ConditionalThrowIfNotEqual<LayerValidationException>(
302 "ConcatLayer: there are some gaps between views",
306 return std::vector<TensorShape>({
TensorShape({numDims, extentMax.data()}) });
312 ConditionalThrowIfNotEqual<LayerValidationException>(
313 "ConcatLayer: Num Inputs must match num views.",
323 std::vector<TensorShape> inputShapes;
331 if (inferredShapes.size() != 1)
334 + std::to_string(inferredShapes.size()) +
335 " elements - should only have 1.");
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
#define ARMNN_NO_DEPRECATE_WARN_END
This layer represents a merge operation.
void ExecuteStrategy(IStrategy &strategy) const override
Apply a visitor to this layer.
std::vector< TensorShape > InferOutputShapes(const std::vector< TensorShape > &inputShapes) const override
By default returns inputShapes if the number of inputs are equal to number of outputs,...
virtual void CreateTensorHandles(const TensorHandleFactoryRegistry &registry, const IWorkloadFactory &factory, const bool IsMemoryManaged=true) override
Set the outputs to be appropriate sub tensors of the input if sub tensors are supported otherwise cre...
void ValidateTensorShapesFromInputs() override
Check if the input tensor shape(s) will lead to a valid configuration of ConcatLayer.
ConcatLayer * Clone(Graph &graph) const override
Creates a dynamically-allocated copy of this layer.
ConcatLayer(const OriginsDescriptor &param, const char *name)
Constructor to create a ConcatLayer.
virtual std::unique_ptr< IWorkload > CreateWorkload(const IWorkloadFactory &factory) const override
Makes a workload for the Concat type.
Base class for all ArmNN exceptions so that users can filter to just those.
virtual void ExecuteStrategy(const IConnectableLayer *layer, const armnn::BaseDescriptor &descriptor, const std::vector< armnn::ConstTensor > &constants, const char *name, const armnn::LayerBindingId id=0)=0
virtual std::vector< Capability > GetCapabilities(const IConnectableLayer *layer, const IConnectableLayer *connectedLayer, CapabilityClass capabilityClass)
static const FactoryId LegacyFactoryId
virtual std::unique_ptr< IWorkload > CreateWorkload(LayerType type, const QueueDescriptor &descriptor, const WorkloadInfo &info) const =0
Backends should implement their own CreateWorkload function with a switch statement.
void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation &location) const
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
void VerifyShapeInferenceType(const TensorShape &outputShape, ShapeInferenceMethod shapeInferenceMethod)
Layer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char *name)
const char * GetName() const override
Returns the name of the layer.
std::vector< OutputHandler > m_OutputHandlers
unsigned int GetNumInputSlots() const override
Returns the number of connectable input slots.
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
void ValidateAndCopyShape(const TensorShape &outputShape, const TensorShape &inferredShape, const ShapeInferenceMethod shapeInferenceMethod, const std::string &layerName, const unsigned int outputSlotIndex=0)
const OutputHandler & GetOutputHandler(unsigned int i=0) const
void SetAdditionalInfo(QueueDescriptor &descriptor) const
ShapeInferenceMethod m_ShapeInferenceMethod
WorkloadInfo PrepInfoAndDesc(QueueDescriptor &descriptor) const
Helper function to reduce duplication in *Layer::CreateWorkload.
OriginsDescriptor m_Param
The parameters for the layer (not including tensor-valued weights etc.).
const OriginsDescriptor & GetParameters() const override
const TensorInfo & GetTensorInfo() const
Gets the matching TensorInfo for the output.
ITensorHandle * GetData() const
Gets the allocated tensor memory.
const InputSlot * GetConnection(unsigned int index) const override
Layer & GetOwningLayer() const
const OutputHandler & GetOutputHandler() const
const TensorInfo & GetTensorInfo() const override
ITensorHandleFactory::FactoryId GetTensorHandleFactoryId() const
ITensorHandleFactory * GetFactory(ITensorHandleFactory::FactoryId id) const
Find a TensorHandleFactory by Id Returns nullptr if not found.
bool IsTypeSpaceMatch(const TensorInfo &other) const
Check that the types are the same and, if quantize, that the quantization parameters are the same.
const TensorShape & GetShape() const
Copyright (c) 2021 ARM Limited and Contributors.
LayerType
When adding a new layer, adapt also the LastLayer enum value in the enum class LayerType below.
std::vector< ViewOrigin > m_ViewOrigins
An OriginsDescriptor for the ConcatLayer.
uint32_t GetNumViews() const
Get the number of views.
unsigned int GetConcatAxis() const
Get the concatenation axis value.
uint32_t GetNumDimensions() const
Get the number of dimensions.
const uint32_t * GetViewOrigin(uint32_t idx) const
Return the view origin at the int value idx.