// Convenience aliases: LayerList is a std::list of Layer pointers, and Iterator is that
// list's const_iterator (read-only traversal).
// NOTE(review): neither alias is referenced in the visible part of this file; confirm they
// are still used elsewhere before keeping them.
29 using LayerList = std::list<Layer*>;
30 using Iterator = LayerList::const_iterator;
// Builds a TensorInfo that copies `info` but uses the data type held in `type`.
// The shape, quantization scale and quantization offset are taken unchanged from `info`.
//
// @param info  Source tensor description to copy.
// @param type  Optional replacement data type.
// @return      A new TensorInfo (returned by value).
//
// NOTE(review): this excerpt omits the lines between the signature and the return statement.
// The return below calls `type.value()` unconditionally, so presumably an earlier guard
// returns `info` unchanged when `type` is empty. Confirm against the full file.
// NOTE(review): the return type is `const TensorInfo`, a const-qualified value. That prevents
// callers from moving from the result. Several call sites bind the result to
// `const TensorInfo&`, which relies on temporary lifetime extension.
32 const TensorInfo OverrideDataType(
const TensorInfo& info, Optional<DataType> type)
39 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
44 bool IWorkloadFactory::IsLayerConfigurationSupported(
const BackendId& backendId,
45 const IConnectableLayer& connectableLayer,
46 Optional<DataType> dataType,
47 std::string& outReasonIfUnsupported,
50 Optional<std::string&> reason = outReasonIfUnsupported;
52 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
55 if (!backendRegistry.IsBackendRegistered(backendId))
58 ss << connectableLayer.GetName() <<
" is not supported on " << backendId
59 <<
" because this backend is not registered.";
61 outReasonIfUnsupported = ss.str();
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
67 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
69 switch(layer.GetType())
73 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
75 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
76 result = layerSupportObject.IsActivationSupported(
77 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
88 result = layerSupportObject.IsAdditionSupported(
89 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
97 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
98 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
102 result = layerSupportObject.IsArgMinMaxSupported(
103 OverrideDataType(input, dataType),
111 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
118 result = layerSupportObject.IsBatchNormalizationSupported(
119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
133 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
135 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
146 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
147 OverrideDataType(output, dataType),
153 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
155 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
156 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
157 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
159 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
160 OverrideDataType(input1, dataType),
162 cLayer->GetParameters(),
168 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
169 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
174 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
175 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
176 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
181 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
182 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
183 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
188 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
189 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
190 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
195 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
196 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
197 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
202 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
204 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
206 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
209 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
212 Optional<TensorInfo> biases;
213 if (descriptor.m_BiasEnabled)
219 result = layerSupportObject.IsConvolution2dSupported(
223 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
230 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
231 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
233 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
234 OverrideDataType(output, dataType),
240 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
242 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
243 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
245 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
246 OverrideDataType(output, dataType),
247 cLayer->GetParameters(),
253 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
254 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
256 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
259 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
262 Optional<TensorInfo> biases;
263 if (descriptor.m_BiasEnabled)
269 result = layerSupportObject.IsDepthwiseConvolutionSupported(
273 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
280 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
281 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
283 result = layerSupportObject.IsDequantizeSupported(input,
284 OverrideDataType(output, dataType),
290 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
291 const TensorInfo&
boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
292 const TensorInfo&
scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
293 const TensorInfo&
anchors = cLayer->m_Anchors->GetTensorInfo();
295 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
296 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
297 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
298 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
300 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
301 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
314 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
316 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
317 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
319 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
320 OverrideDataType(output, dataType),
321 cLayer->GetParameters(),
327 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
328 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
329 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
330 const FillDescriptor& descriptor = cLayer->GetParameters();
332 result = layerSupportObject.IsFillSupported(
333 OverrideDataType(input, dataType),
334 OverrideDataType(output, dataType),
341 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
342 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
343 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
344 cLayer->GetParameters(),
350 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
351 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
352 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
353 OverrideDataType(output, dataType),
359 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
360 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
361 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
363 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
364 TensorInfo weightsInfo;
365 const TensorInfo* weightsInfoPtr =
nullptr;
367 if (descriptor.m_ConstantWeights)
370 weightsInfo = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
374 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
377 weightsInfoPtr = &weightsInfo;
380 const TensorInfo* biasInfoPtr =
nullptr;
381 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}),
DataType::BFloat16);
382 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}),
DataType::Float16);
383 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}),
DataType::Float32);
386 if (descriptor.m_BiasEnabled)
388 if(descriptor.m_ConstantWeights)
392 biasInfoPtr = &biasInfo;
396 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
397 biasInfoPtr = &biasInfo;
403 switch(input.GetDataType())
407 biasInfoPtr = &dummyBFloat16Bias;
412 biasInfoPtr = &dummyFloat16Bias;
417 biasInfoPtr = &dummyFloat32Bias;
425 biasInfoPtr = &dummyQA8Bias;
434 result = layerSupportObject.IsFullyConnectedSupported(
435 OverrideDataType(input, dataType),
436 OverrideDataType(output, dataType),
445 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
446 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
447 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
448 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
449 const GatherDescriptor& descriptor = cLayer->GetParameters();
450 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
452 OverrideDataType(output, dataType),
459 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
460 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
465 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
466 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
468 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
469 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
471 result = layerSupportObject.IsInstanceNormalizationSupported(
472 OverrideDataType(input, dataType),
473 OverrideDataType(output, dataType),
480 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
481 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
483 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
484 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
486 result = layerSupportObject.IsL2NormalizationSupported(
487 OverrideDataType(input, dataType),
488 OverrideDataType(output, dataType),
495 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
497 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
498 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
499 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
501 result = layerSupportObject.IsLogicalBinarySupported(input0,
504 cLayer->GetParameters(),
510 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
512 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
513 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
515 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
516 OverrideDataType(output, dataType),
517 cLayer->GetParameters(),
523 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
524 const LstmDescriptor& descriptor = cLayer->GetParameters();
527 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
529 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
531 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
534 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
535 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
536 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
537 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
540 const TensorInfo& inputToForgetWeights
541 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
542 const TensorInfo& inputToCellWeights
543 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
544 const TensorInfo& inputToOutputWeights
545 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
546 const TensorInfo& recurrentToForgetWeights
547 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
548 const TensorInfo& recurrentToCellWeights
549 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
550 const TensorInfo& recurrentToOutputWeights
551 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
552 const TensorInfo& forgetGateBias
553 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
554 const TensorInfo& cellBias
555 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
556 const TensorInfo& outputGateBias
557 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
559 LstmInputParamsInfo paramsInfo;
561 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
562 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
563 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
564 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
565 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
566 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
567 paramsInfo.m_ForgetGateBias = &forgetGateBias;
568 paramsInfo.m_CellBias = &cellBias;
569 paramsInfo.m_OutputGateBias = &outputGateBias;
573 TensorInfo optInputToInputWeights;
574 TensorInfo optRecurrentToInputWeights;
575 TensorInfo optCellToInputWeights;
576 TensorInfo optInputGateBias;
577 TensorInfo optProjectionWeights;
578 TensorInfo optProjectionBias;
579 TensorInfo optCellToForgetWeights;
580 TensorInfo optCellToOutputWeights;
581 TensorInfo optInputLayerNormWeights;
582 TensorInfo optForgetLayerNormWeights;
583 TensorInfo optCellLayerNormWeights;
584 TensorInfo optOutputLayerNormWeights;
586 if(!descriptor.m_CifgEnabled)
588 optInputToInputWeights =
589 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
590 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
592 optRecurrentToInputWeights =
593 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
594 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
596 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
597 paramsInfo.m_InputGateBias = &optInputGateBias;
600 if(descriptor.m_ProjectionEnabled)
602 optProjectionWeights =
603 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
604 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
605 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
608 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
609 paramsInfo.m_ProjectionBias = &optProjectionBias;
613 if(descriptor.m_PeepholeEnabled)
615 if(!descriptor.m_CifgEnabled)
617 optCellToInputWeights =
618 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
620 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
622 optCellToForgetWeights =
623 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
624 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
625 optCellToOutputWeights =
626 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
627 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
630 if(descriptor.m_LayerNormEnabled)
632 if (!descriptor.m_CifgEnabled)
634 optInputLayerNormWeights = OverrideDataType(
635 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
636 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
639 optForgetLayerNormWeights = OverrideDataType(
640 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
641 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
643 optCellLayerNormWeights = OverrideDataType(
644 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
645 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
647 optOutputLayerNormWeights = OverrideDataType(
648 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
649 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
652 result = layerSupportObject.IsLstmSupported(
667 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
668 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
669 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
671 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
672 OverrideDataType(input1, dataType),
673 OverrideDataType(output, dataType),
679 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
680 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
682 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
683 OverrideDataType(output, dataType),
689 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
690 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
692 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
693 OverrideDataType(output, dataType),
699 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
700 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
701 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
703 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
704 OverrideDataType(input1, dataType),
705 OverrideDataType(output, dataType),
711 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
714 auto getTensorInfo = [&dataType](
const InputSlot& slot)
716 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
721 std::vector<TensorInfo> inputs(beginI, endI);
723 auto getTensorInfoPtr = [](
const TensorInfo&
info)
730 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
732 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
734 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
741 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
742 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
743 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
744 result = layerSupportObject.IsMultiplicationSupported(
745 OverrideDataType(input0, dataType),
746 OverrideDataType(input1, dataType),
747 OverrideDataType(output, dataType),
753 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
754 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
755 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
756 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
757 OverrideDataType(output, dataType),
758 cLayer->GetParameters(),
764 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
765 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
770 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
771 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
772 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
773 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
774 OverrideDataType(output, dataType),
775 cLayer->GetParameters(),
781 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
782 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
783 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
784 result = layerSupportObject.IsPadSupported(
785 OverrideDataType(input, dataType),
786 OverrideDataType(output, dataType),
787 cLayer->GetParameters(),
793 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
794 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
795 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
796 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
797 OverrideDataType(output, dataType),
798 cLayer->GetParameters(),
804 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
805 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
806 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
807 cLayer->GetParameters(),
813 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
814 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
815 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
820 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
821 const QLstmDescriptor& descriptor = cLayer->GetParameters();
824 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
825 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
826 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
829 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
830 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
831 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
834 LstmInputParamsInfo paramsInfo;
837 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
838 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
839 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
841 paramsInfo.m_RecurrentToForgetWeights =
842 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
843 paramsInfo.m_RecurrentToCellWeights =
844 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
845 paramsInfo.m_RecurrentToOutputWeights =
846 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
848 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
849 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
850 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
852 if(!descriptor.m_CifgEnabled)
854 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
855 paramsInfo.m_RecurrentToInputWeights =
856 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
857 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
860 if(descriptor.m_ProjectionEnabled)
862 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
865 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
867 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
871 if(descriptor.m_PeepholeEnabled)
873 if (!descriptor.m_CifgEnabled)
875 paramsInfo.m_CellToInputWeights =
876 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
879 paramsInfo.m_CellToForgetWeights =
880 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
881 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
884 if(descriptor.m_LayerNormEnabled)
886 if (!descriptor.m_CifgEnabled)
888 paramsInfo.m_InputLayerNormWeights =
889 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
892 paramsInfo.m_ForgetLayerNormWeights =
893 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
894 paramsInfo.m_CellLayerNormWeights =
895 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
896 paramsInfo.m_OutputLayerNormWeights =
897 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
900 result = layerSupportObject.IsQLstmSupported(input,
913 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
916 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
917 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
918 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
921 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
922 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
925 QuantizedLstmInputParamsInfo paramsInfo;
927 paramsInfo.m_InputToInputWeights =
928 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
929 paramsInfo.m_InputToForgetWeights =
930 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
931 paramsInfo.m_InputToCellWeights =
932 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
933 paramsInfo.m_InputToOutputWeights =
934 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
936 paramsInfo.m_RecurrentToInputWeights =
937 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
938 paramsInfo.m_RecurrentToForgetWeights =
939 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
940 paramsInfo.m_RecurrentToCellWeights =
941 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
942 paramsInfo.m_RecurrentToOutputWeights =
943 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
945 paramsInfo.m_InputGateBias =
946 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
947 paramsInfo.m_ForgetGateBias =
948 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
949 paramsInfo.m_CellBias =
950 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
951 paramsInfo.m_OutputGateBias =
952 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
954 result = layerSupportObject.IsQuantizedLstmSupported(input,
965 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
966 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
967 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
968 result = layerSupportObject.IsDivisionSupported(
969 OverrideDataType(input0, dataType),
970 OverrideDataType(input1, dataType),
971 OverrideDataType(output, dataType),
977 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
978 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
979 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
980 OverrideDataType(output, dataType),
986 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
987 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
988 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
989 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
990 OverrideDataType(output, dataType),
991 cLayer->GetParameters(),
997 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
998 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
999 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1000 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
1001 OverrideDataType(output, dataType),
1002 cLayer->GetParameters(),
1008 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
1010 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1011 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1013 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1014 OverrideDataType(output, dataType),
1015 cLayer->GetParameters(),
1021 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
1022 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1023 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1024 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1025 OverrideDataType(output, dataType),
1026 cLayer->GetParameters(),
1032 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
1033 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1034 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1035 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1036 OverrideDataType(output, dataType),
1037 cLayer->GetParameters(),
1043 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
1045 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1046 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1048 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1049 OverrideDataType(output, dataType),
1050 cLayer->GetParameters(),
1056 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
1057 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1060 auto getTensorInfo = [&dataType](
const OutputSlot& slot)
1062 return OverrideDataType(slot.GetTensorInfo(), dataType);
1066 std::vector<TensorInfo> outputs(beginI, endI);
1068 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1070 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1072 cLayer->GetParameters(),
1078 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
1081 auto getTensorInfo = [&dataType](
const InputSlot& slot)
1083 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1087 std::vector<TensorInfo> inputs(beginI, endI);
1089 auto getTensorInfoPtr = [](
const TensorInfo&
info)
1095 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1097 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1099 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
1105 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
1108 auto getTensorInfoIn = [&dataType](
const InputSlot& slot)
1110 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1112 auto getTensorInfoOut = [&dataType](
const OutputSlot& slot)
1114 return OverrideDataType(slot.GetTensorInfo(), dataType);
1118 std::vector<TensorInfo> inputs(beginI, endI);
1122 std::vector<TensorInfo> outputs(beginO, endO);
1125 auto getTensorInfoPtr = [](
const TensorInfo&
info)
1131 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1135 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1138 result = layerSupportObject.IsStandInSupported(inputPtrs,
1140 cLayer->GetParameters(),
1146 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
1147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1149 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1150 OverrideDataType(output, dataType),
1151 cLayer->GetParameters(),
1157 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1158 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1159 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1160 result = layerSupportObject.IsSubtractionSupported(
1161 OverrideDataType(input0, dataType),
1162 OverrideDataType(input1, dataType),
1163 OverrideDataType(output, dataType),
1169 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1170 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1171 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1172 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
1173 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1174 OverrideDataType(input1, dataType),
1175 OverrideDataType(output0, dataType),
1176 OverrideDataType(output1, dataType),
1182 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
1183 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1184 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1185 result = layerSupportObject.IsMeanSupported(
1186 OverrideDataType(input, dataType),
1187 OverrideDataType(output, dataType),
1188 cLayer->GetParameters(),
1194 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1195 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1196 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1197 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1198 OverrideDataType(input1, dataType),
1199 OverrideDataType(output, dataType),
1205 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1206 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1207 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1208 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1209 OverrideDataType(alpha, dataType),
1210 OverrideDataType(output, dataType),
1216 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
1217 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1218 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1219 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1220 OverrideDataType(output, dataType),
1221 cLayer->GetParameters(),
1227 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
1229 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1231 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1233 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1235 Optional<TensorInfo> biases;
1236 if (descriptor.m_BiasEnabled)
1239 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1244 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1246 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1257 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1258 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1259 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1261 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1262 OverrideDataType(output, dataType),
1263 cLayer->GetParameters(),
1269 ARMNN_ASSERT_MSG(
false,
"WorkloadFactory did not recognise type of layer.");
1270 reason.value() =
"Unrecognised layer type";
1281 std::string& outReasonIfUnsupported)
1283 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1288 std::string& outReasonIfUnsupported)
1290 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1291 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1297 std::string& outReasonIfUnsupported,
1300 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1301 return IsLayerConfigurationSupported(layer->GetBackendId(),
1304 outReasonIfUnsupported,
1311 std::string& outReasonIfUnsupported,
1314 return IsLayerConfigurationSupported(backendId,
1317 outReasonIfUnsupported,
1325 return std::unique_ptr<IWorkload>();
1331 return std::unique_ptr<IWorkload>();
1337 return std::unique_ptr<IWorkload>();
1343 return std::unique_ptr<IWorkload>();
1349 return std::unique_ptr<IWorkload>();
1355 return std::unique_ptr<IWorkload>();
1361 return std::unique_ptr<IWorkload>();
1367 return std::unique_ptr<IWorkload>();
1373 return std::unique_ptr<IWorkload>();
1379 return std::unique_ptr<IWorkload>();
1385 return std::unique_ptr<IWorkload>();
1391 return std::unique_ptr<IWorkload>();
1397 return std::unique_ptr<IWorkload>();
1403 return std::unique_ptr<IWorkload>();
1409 return std::unique_ptr<IWorkload>();
1415 return std::unique_ptr<IWorkload>();
1421 return std::unique_ptr<IWorkload>();
1427 return std::unique_ptr<IWorkload>();
1433 return std::unique_ptr<IWorkload>();
1439 return std::unique_ptr<IWorkload>();
1445 return std::unique_ptr<IWorkload>();
1451 return std::unique_ptr<IWorkload>();
1457 return std::unique_ptr<IWorkload>();
1463 return std::unique_ptr<IWorkload>();
1469 return std::unique_ptr<IWorkload>();
1475 return std::unique_ptr<IWorkload>();
1481 return std::unique_ptr<IWorkload>();
1487 return std::unique_ptr<IWorkload>();
1493 return std::unique_ptr<IWorkload>();
1500 return std::unique_ptr<IWorkload>();
1506 return std::unique_ptr<IWorkload>();
1512 return std::unique_ptr<IWorkload>();
1518 return std::unique_ptr<IWorkload>();
1524 return std::unique_ptr<IWorkload>();
1530 return std::unique_ptr<IWorkload>();
1536 return std::unique_ptr<IWorkload>();
1542 return std::unique_ptr<IWorkload>();
1548 return std::unique_ptr<IWorkload>();
1554 return std::unique_ptr<IWorkload>();
1560 return std::unique_ptr<IWorkload>();
1566 return std::unique_ptr<IWorkload>();
1572 return std::unique_ptr<IWorkload>();
1578 return std::unique_ptr<IWorkload>();
1584 return std::unique_ptr<IWorkload>();
1590 return std::unique_ptr<IWorkload>();
1596 return std::unique_ptr<IWorkload>();
1602 return std::unique_ptr<IWorkload>();
1608 return std::unique_ptr<IWorkload>();
1614 return std::unique_ptr<IWorkload>();
1620 return std::unique_ptr<IWorkload>();
1626 return std::unique_ptr<IWorkload>();
1632 return std::unique_ptr<IWorkload>();
1638 return std::unique_ptr<IWorkload>();
1643 return std::unique_ptr<IWorkload>();
1649 return std::unique_ptr<IWorkload>();
1655 return std::unique_ptr<IWorkload>();
1661 return std::unique_ptr<IWorkload>();
1667 return std::unique_ptr<IWorkload>();
1673 return std::unique_ptr<IWorkload>();
1679 return std::unique_ptr<IWorkload>();
1685 return std::unique_ptr<IWorkload>();
1691 return std::unique_ptr<IWorkload>();
1697 return std::unique_ptr<IWorkload>();
1703 return std::unique_ptr<IWorkload>();
1709 return std::unique_ptr<IWorkload>();
1715 return std::unique_ptr<IWorkload>();
1721 return std::unique_ptr<IWorkload>();
1727 return std::unique_ptr<IWorkload>();
1733 return std::unique_ptr<IWorkload>();
1740 return std::unique_ptr<IWorkload>();
virtual std::unique_ptr< IWorkload > CreateSplitter(const SplitterQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateBatchNormalization(const BatchNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDebug(const DebugQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMemCopy(const MemCopyQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateL2Normalization(const L2NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
virtual std::unique_ptr< IWorkload > CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateMultiplication(const MultiplicationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateGreater(const GreaterQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateArgMinMax(const ArgMinMaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMerger(const MergerQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateLogSoftmax(const LogSoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateResizeBilinear(const ResizeBilinearQueueDescriptor &descriptor, const WorkloadInfo &info) const
std::vector< BackendOptions > ModelOptions
virtual std::unique_ptr< IWorkload > CreateStridedSlice(const StridedSliceQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateStack(const StackQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLstm(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
constexpr TransformIterator< Function, Iterator > MakeTransformIterator(Iterator i, Function f)
virtual std::unique_ptr< IWorkload > CreateFakeQuantization(const FakeQuantizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateQuantizedLstm(const QuantizedLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateQLstm(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConstant(const ConstantQueueDescriptor &descriptor, const WorkloadInfo &info) const
BackendRegistry & BackendRegistryInstance()
virtual std::unique_ptr< IWorkload > CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
std::vector< float > boxEncodings({ 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, -1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f })
virtual std::unique_ptr< IWorkload > CreateAbs(const AbsQueueDescriptor &descriptor, const WorkloadInfo &info) const
Copyright (c) 2021 ARM Limited and Contributors.
virtual std::unique_ptr< IWorkload > CreateActivation(const ActivationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateRsqrt(const RsqrtQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateTranspose(const TransposeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDivision(const DivisionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMaximum(const MaximumQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConcat(const ConcatQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMerge(const MergeQueueDescriptor &descriptor, const WorkloadInfo &info) const
armnn::Optional< armnn::DataType > GetBiasTypeFromWeightsType(armnn::Optional< armnn::DataType > weightsType)
virtual std::unique_ptr< IWorkload > CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateEqual(const EqualQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateRank(const RankQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateResize(const ResizeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateCast(const CastQueueDescriptor &descriptor, const WorkloadInfo &Info) const
#define ARMNN_ASSERT_MSG(COND, MSG)
virtual std::unique_ptr< IWorkload > CreateQuantize(const QuantizeQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateReduce(const ReduceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSwitch(const SwitchQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreatePad(const PadQueueDescriptor &descriptor, const WorkloadInfo &Info) const
#define ARMNN_ASSERT(COND)
static bool IsLayerSupported(const BackendId &backendId, const IConnectableLayer &layer, Optional< DataType > dataType, std::string &outReasonIfUnsupported)
virtual std::unique_ptr< IWorkload > CreateNormalization(const NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLogicalBinary(const LogicalBinaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateReshape(const ReshapeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePermute(const PermuteQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateFill(const FillQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateComparison(const ComparisonQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreatePooling2d(const Pooling2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSpaceToDepth(const SpaceToDepthQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateGather(const GatherQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMinimum(const MinimumQueueDescriptor &descriptor, const WorkloadInfo &info) const
std::vector< float > scores({ 0.0f, 0.9f, 0.8f, 0.0f, 0.75f, 0.72f, 0.0f, 0.6f, 0.5f, 0.0f, 0.93f, 0.95f, 0.0f, 0.5f, 0.4f, 0.0f, 0.3f, 0.2f })
virtual std::unique_ptr< IWorkload > CreateDepthToSpace(const DepthToSpaceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSlice(const SliceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateAddition(const AdditionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMean(const MeanQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateOutput(const OutputQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSoftmax(const SoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
Contains information about inputs and outputs to a layer.
virtual std::unique_ptr< IWorkload > CreateFullyConnected(const FullyConnectedQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateFloor(const FloorQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMemImport(const MemImportQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSubtraction(const SubtractionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePreCompiled(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvolution2d(const Convolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePrelu(const PreluQueueDescriptor &descriptor, const WorkloadInfo &info) const
std::vector< float > anchors({ 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 10.5f, 1.0f, 1.0f, 0.5f, 10.5f, 1.0f, 1.0f, 0.5f, 100.5f, 1.0f, 1.0f })
virtual std::unique_ptr< IWorkload > CreateDequantize(const DequantizeQueueDescriptor &descriptor, const WorkloadInfo &info) const