// List of raw Layer pointers.
29 using LayerList = std::list<Layer*>;
// Read-only (const) iterator over a LayerList.
30 using Iterator = LayerList::const_iterator;
// Builds and returns a TensorInfo derived from `info`, carrying over its shape,
// quantization scale and quantization offset (the getters visible below).
// NOTE(review): the function name and the Optional<DataType> parameter suggest that the
// data type of the result is replaced by `type` when one is supplied; the lines that
// actually consume `type` are not visible in this extract -- confirm against the full file.
32 const TensorInfo OverrideDataType(
const TensorInfo& info, Optional<DataType> type)
39 return TensorInfo(info.GetShape(),
41 info.GetQuantizationScale(),
42 info.GetQuantizationOffset(),
48 bool IWorkloadFactory::IsLayerConfigurationSupported(
const BackendId& backendId,
49 const IConnectableLayer& connectableLayer,
50 Optional<DataType> dataType,
51 std::string& outReasonIfUnsupported,
54 Optional<std::string&> reason = outReasonIfUnsupported;
56 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
59 if (!backendRegistry.IsBackendRegistered(backendId))
62 ss << connectableLayer.GetName() <<
" is not supported on " << backendId
63 <<
" because this backend is not registered.";
65 outReasonIfUnsupported = ss.str();
69 auto backendFactory = backendRegistry.GetFactory(backendId);
70 auto backendObject = backendFactory();
71 auto layerSupportObject = LayerSupportHandle(backendObject->GetLayerSupport(modelOptions), backendId);
73 switch(layer.GetType())
77 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
78 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
79 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
80 result = layerSupportObject.IsActivationSupported(
81 OverrideDataType(input, dataType),
82 OverrideDataType(output, dataType),
83 cLayer->GetParameters(),
89 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
90 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
91 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
92 result = layerSupportObject.IsAdditionSupported(
93 OverrideDataType(input0, dataType),
94 OverrideDataType(input1, dataType),
95 OverrideDataType(output, dataType),
101 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
102 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
104 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
105 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
106 result = layerSupportObject.IsArgMinMaxSupported(
107 OverrideDataType(input, dataType),
115 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
116 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
119 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
120 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
121 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
122 result = layerSupportObject.IsBatchNormalizationSupported(
123 OverrideDataType(input, dataType),
124 OverrideDataType(output, dataType),
125 OverrideDataType(mean, dataType),
126 OverrideDataType(var, dataType),
127 OverrideDataType(beta, dataType),
128 OverrideDataType(gamma, dataType),
129 cLayer->GetParameters(),
135 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
136 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
137 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
139 result = layerSupportObject.IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
140 OverrideDataType(output, dataType),
141 cLayer->GetParameters(),
147 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
148 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
150 result = layerSupportObject.IsCastSupported(OverrideDataType(input, dataType),
151 OverrideDataType(output, dataType),
157 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
159 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
161 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
163 result = layerSupportObject.IsComparisonSupported(OverrideDataType(input0, dataType),
164 OverrideDataType(input1, dataType),
166 cLayer->GetParameters(),
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
173 result = layerSupportObject.IsConstantSupported(OverrideDataType(output, dataType), reason);
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180 result = layerSupportObject.IsConvertBf16ToFp32Supported(input, output, reason);
185 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
187 result = layerSupportObject.IsConvertFp16ToFp32Supported(input, output, reason);
192 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
193 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
194 result = layerSupportObject.IsConvertFp32ToBf16Supported(input, output, reason);
199 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
200 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
201 result = layerSupportObject.IsConvertFp32ToFp16Supported(input, output, reason);
206 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
208 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
210 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
213 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
216 Optional<TensorInfo> biases;
217 if (descriptor.m_BiasEnabled)
223 result = layerSupportObject.IsConvolution2dSupported(
227 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
234 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
235 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
237 result = layerSupportObject.IsDebugSupported(OverrideDataType(input, dataType),
238 OverrideDataType(output, dataType),
244 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
246 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
247 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
249 result = layerSupportObject.IsDepthToSpaceSupported(OverrideDataType(input, dataType),
250 OverrideDataType(output, dataType),
251 cLayer->GetParameters(),
257 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
258 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
260 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
263 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
266 Optional<TensorInfo> biases;
267 if (descriptor.m_BiasEnabled)
273 result = layerSupportObject.IsDepthwiseConvolutionSupported(
277 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
284 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
285 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
287 result = layerSupportObject.IsDequantizeSupported(input,
288 OverrideDataType(output, dataType),
294 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
295 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
296 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
297 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
299 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
300 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
301 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
302 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
304 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
305 result = layerSupportObject.IsDetectionPostProcessSupported(boxEncodings,
318 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
320 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
321 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
323 result = layerSupportObject.IsElementwiseUnarySupported(OverrideDataType(input, dataType),
324 OverrideDataType(output, dataType),
325 cLayer->GetParameters(),
331 auto cLayer = PolymorphicDowncast<const FillLayer*>(&layer);
332 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
333 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
334 const FillDescriptor& descriptor = cLayer->GetParameters();
336 result = layerSupportObject.IsFillSupported(
337 OverrideDataType(input, dataType),
338 OverrideDataType(output, dataType),
345 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
346 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
347 result = layerSupportObject.IsFakeQuantizationSupported(OverrideDataType(input, dataType),
348 cLayer->GetParameters(),
354 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
355 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
356 result = layerSupportObject.IsFloorSupported(OverrideDataType(input, dataType),
357 OverrideDataType(output, dataType),
363 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
364 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
365 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
367 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
368 TensorInfo weightsInfo;
369 const TensorInfo* weightsInfoPtr =
nullptr;
371 weightsInfo = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(), dataType);
372 weightsInfoPtr = &weightsInfo;
375 const TensorInfo* biasInfoPtr =
nullptr;
376 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}),
DataType::BFloat16);
377 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}),
DataType::Float16);
378 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}),
DataType::Float32);
381 if (descriptor.m_BiasEnabled)
383 biasInfo = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(), dataType);
384 biasInfoPtr = &biasInfo;
389 switch(input.GetDataType())
393 biasInfoPtr = &dummyBFloat16Bias;
398 biasInfoPtr = &dummyFloat16Bias;
403 biasInfoPtr = &dummyFloat32Bias;
411 biasInfoPtr = &dummyQA8Bias;
420 result = layerSupportObject.IsFullyConnectedSupported(
421 OverrideDataType(input, dataType),
422 OverrideDataType(output, dataType),
431 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
432 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
433 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
434 auto cLayer = PolymorphicDowncast<const GatherLayer*>(&layer);
435 const GatherDescriptor& descriptor = cLayer->GetParameters();
436 result = layerSupportObject.IsGatherSupported(OverrideDataType(input0, dataType),
438 OverrideDataType(output, dataType),
445 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
446 result = layerSupportObject.IsInputSupported(OverrideDataType(input, dataType), reason);
451 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
452 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
454 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
455 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
457 result = layerSupportObject.IsInstanceNormalizationSupported(
458 OverrideDataType(input, dataType),
459 OverrideDataType(output, dataType),
466 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
467 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
469 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
470 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
472 result = layerSupportObject.IsL2NormalizationSupported(
473 OverrideDataType(input, dataType),
474 OverrideDataType(output, dataType),
481 auto cLayer = PolymorphicDowncast<const LogicalBinaryLayer*>(&layer);
483 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
484 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
485 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
487 result = layerSupportObject.IsLogicalBinarySupported(input0,
490 cLayer->GetParameters(),
496 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
498 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
499 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
501 result = layerSupportObject.IsLogSoftmaxSupported(OverrideDataType(input, dataType),
502 OverrideDataType(output, dataType),
503 cLayer->GetParameters(),
509 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
510 const LstmDescriptor& descriptor = cLayer->GetParameters();
513 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
515 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
517 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
520 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
521 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
522 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
523 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
526 const TensorInfo& inputToForgetWeights
527 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
528 const TensorInfo& inputToCellWeights
529 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
530 const TensorInfo& inputToOutputWeights
531 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
532 const TensorInfo& recurrentToForgetWeights
533 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
534 const TensorInfo& recurrentToCellWeights
535 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
536 const TensorInfo& recurrentToOutputWeights
537 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
538 const TensorInfo& forgetGateBias
539 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
540 const TensorInfo& cellBias
541 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
542 const TensorInfo& outputGateBias
543 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
545 LstmInputParamsInfo paramsInfo;
547 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
548 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
549 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
550 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
551 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
552 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
553 paramsInfo.m_ForgetGateBias = &forgetGateBias;
554 paramsInfo.m_CellBias = &cellBias;
555 paramsInfo.m_OutputGateBias = &outputGateBias;
559 TensorInfo optInputToInputWeights;
560 TensorInfo optRecurrentToInputWeights;
561 TensorInfo optCellToInputWeights;
562 TensorInfo optInputGateBias;
563 TensorInfo optProjectionWeights;
564 TensorInfo optProjectionBias;
565 TensorInfo optCellToForgetWeights;
566 TensorInfo optCellToOutputWeights;
567 TensorInfo optInputLayerNormWeights;
568 TensorInfo optForgetLayerNormWeights;
569 TensorInfo optCellLayerNormWeights;
570 TensorInfo optOutputLayerNormWeights;
572 if(!descriptor.m_CifgEnabled)
574 optInputToInputWeights =
575 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
576 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
578 optRecurrentToInputWeights =
579 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
580 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
582 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
583 paramsInfo.m_InputGateBias = &optInputGateBias;
586 if(descriptor.m_ProjectionEnabled)
588 optProjectionWeights =
589 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
590 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
591 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
594 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
595 paramsInfo.m_ProjectionBias = &optProjectionBias;
599 if(descriptor.m_PeepholeEnabled)
601 if(!descriptor.m_CifgEnabled)
603 optCellToInputWeights =
604 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
606 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
608 optCellToForgetWeights =
609 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
610 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
611 optCellToOutputWeights =
612 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
613 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
616 if(descriptor.m_LayerNormEnabled)
618 if (!descriptor.m_CifgEnabled)
620 optInputLayerNormWeights = OverrideDataType(
621 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
622 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
625 optForgetLayerNormWeights = OverrideDataType(
626 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
627 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
629 optCellLayerNormWeights = OverrideDataType(
630 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
631 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
633 optOutputLayerNormWeights = OverrideDataType(
634 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
635 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
638 result = layerSupportObject.IsLstmSupported(
653 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
654 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
655 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
657 result = layerSupportObject.IsMaximumSupported(OverrideDataType(input0, dataType),
658 OverrideDataType(input1, dataType),
659 OverrideDataType(output, dataType),
665 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
666 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
668 result = layerSupportObject.IsMemCopySupported(OverrideDataType(input, dataType),
669 OverrideDataType(output, dataType),
675 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
676 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
678 result = layerSupportObject.IsMemImportSupported(OverrideDataType(input, dataType),
679 OverrideDataType(output, dataType),
685 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
686 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
687 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
689 result = layerSupportObject.IsMergeSupported(OverrideDataType(input0, dataType),
690 OverrideDataType(input1, dataType),
691 OverrideDataType(output, dataType),
697 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
700 auto getTensorInfo = [&dataType](
const InputSlot& slot)
702 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
707 std::vector<TensorInfo> inputs(beginI, endI);
709 auto getTensorInfoPtr = [](
const TensorInfo&
info)
716 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
718 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
720 result = layerSupportObject.IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
727 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
728 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
729 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
730 result = layerSupportObject.IsMultiplicationSupported(
731 OverrideDataType(input0, dataType),
732 OverrideDataType(input1, dataType),
733 OverrideDataType(output, dataType),
739 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
740 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
741 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
742 result = layerSupportObject.IsNormalizationSupported(OverrideDataType(input, dataType),
743 OverrideDataType(output, dataType),
744 cLayer->GetParameters(),
750 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
751 result = layerSupportObject.IsOutputSupported(OverrideDataType(output, dataType), reason);
756 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
757 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
758 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
759 result = layerSupportObject.IsPermuteSupported(OverrideDataType(input, dataType),
760 OverrideDataType(output, dataType),
761 cLayer->GetParameters(),
767 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
768 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
769 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
770 result = layerSupportObject.IsPadSupported(
771 OverrideDataType(input, dataType),
772 OverrideDataType(output, dataType),
773 cLayer->GetParameters(),
779 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
780 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
781 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
782 result = layerSupportObject.IsPooling2dSupported(OverrideDataType(input, dataType),
783 OverrideDataType(output, dataType),
784 cLayer->GetParameters(),
790 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
791 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
792 result = layerSupportObject.IsPreCompiledSupported(OverrideDataType(input, dataType),
793 cLayer->GetParameters(),
799 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
800 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
801 result = layerSupportObject.IsQuantizeSupported(input, output, reason);
806 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
807 const QLstmDescriptor& descriptor = cLayer->GetParameters();
810 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
811 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
812 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
815 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
816 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
817 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
820 LstmInputParamsInfo paramsInfo;
823 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToForgetWeights.get() !=
nullptr);
824 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToCellWeights.get() !=
nullptr);
825 ARMNN_ASSERT(cLayer->m_BasicParameters.m_InputToOutputWeights.get() !=
nullptr);
826 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
827 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
828 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
830 paramsInfo.m_RecurrentToForgetWeights =
831 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
832 paramsInfo.m_RecurrentToCellWeights =
833 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
834 paramsInfo.m_RecurrentToOutputWeights =
835 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
837 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
838 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
839 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
841 if(!descriptor.m_CifgEnabled)
843 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
844 paramsInfo.m_RecurrentToInputWeights =
845 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
846 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
849 if(descriptor.m_ProjectionEnabled)
851 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
854 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
856 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
860 if(descriptor.m_PeepholeEnabled)
862 if (!descriptor.m_CifgEnabled)
864 paramsInfo.m_CellToInputWeights =
865 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
868 paramsInfo.m_CellToForgetWeights =
869 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
870 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
873 if(descriptor.m_LayerNormEnabled)
875 if (!descriptor.m_CifgEnabled)
877 paramsInfo.m_InputLayerNormWeights =
878 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
881 paramsInfo.m_ForgetLayerNormWeights =
882 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
883 paramsInfo.m_CellLayerNormWeights =
884 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
885 paramsInfo.m_OutputLayerNormWeights =
886 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
889 result = layerSupportObject.IsQLstmSupported(input,
902 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
905 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
906 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
907 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
910 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
911 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
914 QuantizedLstmInputParamsInfo paramsInfo;
916 paramsInfo.m_InputToInputWeights =
917 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
918 paramsInfo.m_InputToForgetWeights =
919 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
920 paramsInfo.m_InputToCellWeights =
921 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
922 paramsInfo.m_InputToOutputWeights =
923 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
925 paramsInfo.m_RecurrentToInputWeights =
926 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
927 paramsInfo.m_RecurrentToForgetWeights =
928 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
929 paramsInfo.m_RecurrentToCellWeights =
930 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
931 paramsInfo.m_RecurrentToOutputWeights =
932 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
934 paramsInfo.m_InputGateBias =
935 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
936 paramsInfo.m_ForgetGateBias =
937 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
938 paramsInfo.m_CellBias =
939 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
940 paramsInfo.m_OutputGateBias =
941 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
943 result = layerSupportObject.IsQuantizedLstmSupported(input,
954 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
955 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
956 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
957 result = layerSupportObject.IsDivisionSupported(
958 OverrideDataType(input0, dataType),
959 OverrideDataType(input1, dataType),
960 OverrideDataType(output, dataType),
966 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
967 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
968 result = layerSupportObject.IsRankSupported(OverrideDataType(input, dataType),
969 OverrideDataType(output, dataType),
975 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
976 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
977 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
978 result = layerSupportObject.IsReshapeSupported(OverrideDataType(input, dataType),
979 OverrideDataType(output, dataType),
980 cLayer->GetParameters(),
986 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
987 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
988 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
989 result = layerSupportObject.IsResizeSupported(OverrideDataType(input, dataType),
990 OverrideDataType(output, dataType),
991 cLayer->GetParameters(),
997 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
998 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1000 result = layerSupportObject.IsShapeSupported(OverrideDataType(input, dataType),
1001 OverrideDataType(output, dataType),
1007 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
1009 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1010 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1012 result = layerSupportObject.IsSliceSupported(OverrideDataType(input, dataType),
1013 OverrideDataType(output, dataType),
1014 cLayer->GetParameters(),
1020 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
1021 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1022 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1023 result = layerSupportObject.IsSoftmaxSupported(OverrideDataType(input, dataType),
1024 OverrideDataType(output, dataType),
1025 cLayer->GetParameters(),
1031 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
1032 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1033 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1034 result = layerSupportObject.IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
1035 OverrideDataType(output, dataType),
1036 cLayer->GetParameters(),
1042 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
1044 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1045 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1047 result = layerSupportObject.IsSpaceToDepthSupported(OverrideDataType(input, dataType),
1048 OverrideDataType(output, dataType),
1049 cLayer->GetParameters(),
1055 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
1056 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1059 auto getTensorInfo = [&dataType](
const OutputSlot& slot)
1061 return OverrideDataType(slot.GetTensorInfo(), dataType);
1065 std::vector<TensorInfo> outputs(beginI, endI);
1067 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
1069 result = layerSupportObject.IsSplitterSupported(OverrideDataType(input, dataType),
1071 cLayer->GetParameters(),
1077 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
1080 auto getTensorInfo = [&dataType](
const InputSlot& slot)
1082 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1086 std::vector<TensorInfo> inputs(beginI, endI);
1088 auto getTensorInfoPtr = [](
const TensorInfo&
info)
1094 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1096 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1098 result = layerSupportObject.IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
1104 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
1107 auto getTensorInfoIn = [&dataType](
const InputSlot& slot)
1109 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1111 auto getTensorInfoOut = [&dataType](
const OutputSlot& slot)
1113 return OverrideDataType(slot.GetTensorInfo(), dataType);
1117 std::vector<TensorInfo> inputs(beginI, endI);
1121 std::vector<TensorInfo> outputs(beginO, endO);
1124 auto getTensorInfoPtr = [](
const TensorInfo&
info)
1130 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1134 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1137 result = layerSupportObject.IsStandInSupported(inputPtrs,
1139 cLayer->GetParameters(),
1145 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
1146 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1148 result = layerSupportObject.IsStridedSliceSupported(OverrideDataType(input, dataType),
1149 OverrideDataType(output, dataType),
1150 cLayer->GetParameters(),
1156 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1157 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1159 result = layerSupportObject.IsSubtractionSupported(
1160 OverrideDataType(input0, dataType),
1161 OverrideDataType(input1, dataType),
1162 OverrideDataType(output, dataType),
1168 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1169 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1170 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1171 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
1172 result = layerSupportObject.IsSwitchSupported(OverrideDataType(input0, dataType),
1173 OverrideDataType(input1, dataType),
1174 OverrideDataType(output0, dataType),
1175 OverrideDataType(output1, dataType),
1181 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
1182 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1183 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1184 result = layerSupportObject.IsMeanSupported(
1185 OverrideDataType(input, dataType),
1186 OverrideDataType(output, dataType),
1187 cLayer->GetParameters(),
1193 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1194 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1195 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1196 result = layerSupportObject.IsMinimumSupported(OverrideDataType(input0, dataType),
1197 OverrideDataType(input1, dataType),
1198 OverrideDataType(output, dataType),
1204 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1205 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1206 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1207 result = layerSupportObject.IsPreluSupported(OverrideDataType(input, dataType),
1208 OverrideDataType(alpha, dataType),
1209 OverrideDataType(output, dataType),
1215 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
1216 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1217 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1218 result = layerSupportObject.IsTransposeSupported(OverrideDataType(input, dataType),
1219 OverrideDataType(output, dataType),
1220 cLayer->GetParameters(),
1226 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
1228 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1230 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1232 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1234 Optional<TensorInfo> biases;
1235 if (descriptor.m_BiasEnabled)
1238 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1243 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1245 result = layerSupportObject.IsTransposeConvolution2dSupported(input,
1256 auto cLayer = PolymorphicDowncast<const ReduceLayer*>(&layer);
1257 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1258 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1260 result = layerSupportObject.IsReduceSupported(OverrideDataType(input, dataType),
1261 OverrideDataType(output, dataType),
1262 cLayer->GetParameters(),
1268 auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
1272 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1274 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
1276 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
1279 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1282 const TensorInfo& inputToForgetWeights
1283 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
1284 const TensorInfo& inputToCellWeights
1285 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
1286 const TensorInfo& inputToOutputWeights
1287 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
1288 const TensorInfo& recurrentToForgetWeights
1289 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
1290 const TensorInfo& recurrentToCellWeights
1291 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
1292 const TensorInfo& recurrentToOutputWeights
1293 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
1294 const TensorInfo& forgetGateBias
1295 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
1296 const TensorInfo& cellBias
1297 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
1298 const TensorInfo& outputGateBias
1299 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
1301 LstmInputParamsInfo paramsInfo;
1303 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
1304 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
1305 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
1306 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
1307 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
1308 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
1309 paramsInfo.m_ForgetGateBias = &forgetGateBias;
1310 paramsInfo.m_CellBias = &cellBias;
1311 paramsInfo.m_OutputGateBias = &outputGateBias;
1314 TensorInfo optInputToInputWeights;
1315 TensorInfo optRecurrentToInputWeights;
1316 TensorInfo optCellToInputWeights;
1317 TensorInfo optInputGateBias;
1318 TensorInfo optProjectionWeights;
1319 TensorInfo optProjectionBias;
1320 TensorInfo optCellToForgetWeights;
1321 TensorInfo optCellToOutputWeights;
1322 TensorInfo optInputLayerNormWeights;
1323 TensorInfo optForgetLayerNormWeights;
1324 TensorInfo optCellLayerNormWeights;
1325 TensorInfo optOutputLayerNormWeights;
1327 if(!descriptor.m_CifgEnabled)
1329 optInputToInputWeights =
1330 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
1331 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
1333 optRecurrentToInputWeights =
1334 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
1335 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
1337 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
1338 paramsInfo.m_InputGateBias = &optInputGateBias;
1341 if(descriptor.m_ProjectionEnabled)
1343 optProjectionWeights =
1344 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
1345 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
1346 if (cLayer->m_ProjectionParameters.m_ProjectionBias !=
nullptr)
1349 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
1350 paramsInfo.m_ProjectionBias = &optProjectionBias;
1354 if(descriptor.m_PeepholeEnabled)
1356 if(!descriptor.m_CifgEnabled)
1358 optCellToInputWeights =
1359 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
1361 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
1363 optCellToForgetWeights =
1364 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
1365 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
1366 optCellToOutputWeights =
1367 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
1368 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
1371 if(descriptor.m_LayerNormEnabled)
1373 if (!descriptor.m_CifgEnabled)
1375 optInputLayerNormWeights = OverrideDataType(
1376 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
1377 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
1380 optForgetLayerNormWeights = OverrideDataType(
1381 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
1382 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
1384 optCellLayerNormWeights = OverrideDataType(
1385 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
1386 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
1388 optOutputLayerNormWeights = OverrideDataType(
1389 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
1390 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
1393 Optional<TensorInfo> hiddenStateOut;
1394 Optional<TensorInfo> cellStateOut;
1396 result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
1409 ARMNN_ASSERT_MSG(
false,
"WorkloadFactory did not recognise type of layer.");
1410 reason.value() =
"Unrecognised layer type";
1421 std::string& outReasonIfUnsupported)
// Pure delegation: backendId, connectableLayer, dataType and outReasonIfUnsupported are passed
// through unchanged and the result of IsLayerConfigurationSupported is returned as-is.
// NOTE(review): only four arguments are passed; any further parameter of
// IsLayerConfigurationSupported presumably takes a default declared elsewhere - confirm.
1423 return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
1428 std::string& outReasonIfUnsupported)
// Downcast the IConnectableLayer to Layer so the layer's own backend id can be read.
1430 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
// The backend is taken from the layer (layer->GetBackendId()) rather than from a backendId
// argument; the remaining arguments are forwarded unchanged.
1431 return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1437 std::string& outReasonIfUnsupported,
// Downcast the IConnectableLayer to Layer so the layer's own backend id can be read.
1440 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
// Delegates to IsLayerConfigurationSupported using the backend id reported by the layer.
1441 return IsLayerConfigurationSupported(layer->GetBackendId(),
// outReasonIfUnsupported is forwarded; the trailing comma shows at least one more argument
// follows it (not visible in this listing).
1444 outReasonIfUnsupported,
1451 std::string& outReasonIfUnsupported,
// Delegates to IsLayerConfigurationSupported with the backendId named in this call.
1454 return IsLayerConfigurationSupported(backendId,
// outReasonIfUnsupported is forwarded; the trailing comma shows at least one more argument
// follows it (not visible in this listing).
1457 outReasonIfUnsupported,
1465 return std::unique_ptr<IWorkload>();
1471 return std::unique_ptr<IWorkload>();
1477 return std::unique_ptr<IWorkload>();
1483 return std::unique_ptr<IWorkload>();
1489 return std::unique_ptr<IWorkload>();
1495 return std::unique_ptr<IWorkload>();
1501 return std::unique_ptr<IWorkload>();
1507 return std::unique_ptr<IWorkload>();
1513 return std::unique_ptr<IWorkload>();
1519 return std::unique_ptr<IWorkload>();
1525 return std::unique_ptr<IWorkload>();
1531 return std::unique_ptr<IWorkload>();
1537 return std::unique_ptr<IWorkload>();
1543 return std::unique_ptr<IWorkload>();
1549 return std::unique_ptr<IWorkload>();
1555 return std::unique_ptr<IWorkload>();
1561 return std::unique_ptr<IWorkload>();
1567 return std::unique_ptr<IWorkload>();
1573 return std::unique_ptr<IWorkload>();
1579 return std::unique_ptr<IWorkload>();
1585 return std::unique_ptr<IWorkload>();
1591 return std::unique_ptr<IWorkload>();
1597 return std::unique_ptr<IWorkload>();
1603 return std::unique_ptr<IWorkload>();
1609 return std::unique_ptr<IWorkload>();
1615 return std::unique_ptr<IWorkload>();
1621 return std::unique_ptr<IWorkload>();
1627 return std::unique_ptr<IWorkload>();
1633 return std::unique_ptr<IWorkload>();
1640 return std::unique_ptr<IWorkload>();
1646 return std::unique_ptr<IWorkload>();
1652 return std::unique_ptr<IWorkload>();
1658 return std::unique_ptr<IWorkload>();
1664 return std::unique_ptr<IWorkload>();
1670 return std::unique_ptr<IWorkload>();
1676 return std::unique_ptr<IWorkload>();
1682 return std::unique_ptr<IWorkload>();
1688 return std::unique_ptr<IWorkload>();
1694 return std::unique_ptr<IWorkload>();
1700 return std::unique_ptr<IWorkload>();
1706 return std::unique_ptr<IWorkload>();
1712 return std::unique_ptr<IWorkload>();
1718 return std::unique_ptr<IWorkload>();
1724 return std::unique_ptr<IWorkload>();
1730 return std::unique_ptr<IWorkload>();
1736 return std::unique_ptr<IWorkload>();
1742 return std::unique_ptr<IWorkload>();
1748 return std::unique_ptr<IWorkload>();
1754 return std::unique_ptr<IWorkload>();
1760 return std::unique_ptr<IWorkload>();
1766 return std::unique_ptr<IWorkload>();
1772 return std::unique_ptr<IWorkload>();
1778 return std::unique_ptr<IWorkload>();
1783 return std::unique_ptr<IWorkload>();
1789 return std::unique_ptr<IWorkload>();
1795 return std::unique_ptr<IWorkload>();
1801 return std::unique_ptr<IWorkload>();
1807 return std::unique_ptr<IWorkload>();
1813 return std::unique_ptr<IWorkload>();
1819 return std::unique_ptr<IWorkload>();
1825 return std::unique_ptr<IWorkload>();
1831 return std::unique_ptr<IWorkload>();
1837 return std::unique_ptr<IWorkload>();
1843 return std::unique_ptr<IWorkload>();
1849 return std::unique_ptr<IWorkload>();
1855 return std::unique_ptr<IWorkload>();
1861 return std::unique_ptr<IWorkload>();
1867 return std::unique_ptr<IWorkload>();
1873 return std::unique_ptr<IWorkload>();
1879 return std::unique_ptr<IWorkload>();
1886 return std::unique_ptr<IWorkload>();
1893 return std::unique_ptr<IWorkload>();
virtual std::unique_ptr< IWorkload > CreateSplitter(const SplitterQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateBatchNormalization(const BatchNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDebug(const DebugQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMemCopy(const MemCopyQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateL2Normalization(const L2NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
virtual std::unique_ptr< IWorkload > CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateMultiplication(const MultiplicationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateGreater(const GreaterQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateArgMinMax(const ArgMinMaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMerger(const MergerQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateLogSoftmax(const LogSoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateResizeBilinear(const ResizeBilinearQueueDescriptor &descriptor, const WorkloadInfo &info) const
std::vector< BackendOptions > ModelOptions
virtual std::unique_ptr< IWorkload > CreateStridedSlice(const StridedSliceQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateStack(const StackQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLstm(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
constexpr TransformIterator< Function, Iterator > MakeTransformIterator(Iterator i, Function f)
virtual std::unique_ptr< IWorkload > CreateFakeQuantization(const FakeQuantizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateQuantizedLstm(const QuantizedLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateQLstm(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConstant(const ConstantQueueDescriptor &descriptor, const WorkloadInfo &info) const
BackendRegistry & BackendRegistryInstance()
virtual std::unique_ptr< IWorkload > CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateAbs(const AbsQueueDescriptor &descriptor, const WorkloadInfo &info) const
Copyright (c) 2021 ARM Limited and Contributors.
virtual std::unique_ptr< IWorkload > CreateActivation(const ActivationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateRsqrt(const RsqrtQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateTranspose(const TransposeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDivision(const DivisionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMaximum(const MaximumQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConcat(const ConcatQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateUnidirectionalSequenceLstm(const UnidirectionalSequenceLstmQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMerge(const MergeQueueDescriptor &descriptor, const WorkloadInfo &info) const
armnn::Optional< armnn::DataType > GetBiasTypeFromWeightsType(armnn::Optional< armnn::DataType > weightsType)
virtual std::unique_ptr< IWorkload > CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateEqual(const EqualQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateRank(const RankQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateResize(const ResizeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateCast(const CastQueueDescriptor &descriptor, const WorkloadInfo &Info) const
#define ARMNN_ASSERT_MSG(COND, MSG)
virtual std::unique_ptr< IWorkload > CreateQuantize(const QuantizeQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateReduce(const ReduceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSwitch(const SwitchQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreatePad(const PadQueueDescriptor &descriptor, const WorkloadInfo &Info) const
#define ARMNN_ASSERT(COND)
LstmDescriptor UnidirectionalSequenceLstmDescriptor
static bool IsLayerSupported(const BackendId &backendId, const IConnectableLayer &layer, Optional< DataType > dataType, std::string &outReasonIfUnsupported)
virtual std::unique_ptr< IWorkload > CreateNormalization(const NormalizationQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateLogicalBinary(const LogicalBinaryQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateReshape(const ReshapeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePermute(const PermuteQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateFill(const FillQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateComparison(const ComparisonQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreatePooling2d(const Pooling2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSpaceToDepth(const SpaceToDepthQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateGather(const GatherQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMinimum(const MinimumQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDepthToSpace(const DepthToSpaceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSlice(const SliceQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateAddition(const AdditionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMean(const MeanQueueDescriptor &descriptor, const WorkloadInfo &Info) const
virtual std::unique_ptr< IWorkload > CreateOutput(const OutputQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSoftmax(const SoftmaxQueueDescriptor &descriptor, const WorkloadInfo &info) const
Contains information about TensorInfos of a layer.
virtual std::unique_ptr< IWorkload > CreateFullyConnected(const FullyConnectedQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateFloor(const FloorQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateMemImport(const MemImportQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateSubtraction(const SubtractionQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePreCompiled(const PreCompiledQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateShape(const ShapeQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor &descriptor, const WorkloadInfo &info) const
Depthwise Convolution 2D layer workload data.
virtual std::unique_ptr< IWorkload > CreateConvolution2d(const Convolution2dQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreatePrelu(const PreluQueueDescriptor &descriptor, const WorkloadInfo &info) const
virtual std::unique_ptr< IWorkload > CreateDequantize(const DequantizeQueueDescriptor &descriptor, const WorkloadInfo &info) const