ArmNN
 25.11
Loading...
Searching...
No Matches
Lstm.hpp
Go to the documentation of this file.
1//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
10
#include "Encoders.hpp"
#include "Decoders.hpp"

#include <memory>
13
14namespace armnn
15{
16
17void LstmImpl(const LstmDescriptor& descriptor,
18 const TensorInfo& inputInfo,
19 const TensorInfo& outputInfo,
20 const TensorShape& inputToOutputWeightsShape,
21 const TensorShape& recurrentToOutputWeightsShape,
22 std::unique_ptr<Decoder<float>>& inputData,
23 std::unique_ptr<Decoder<float>>& outputStateIn,
24 std::unique_ptr<Decoder<float>>& cellStateIn,
25 std::unique_ptr<Encoder<float>>& outputStateOut,
26 std::unique_ptr<Encoder<float>>& cellStateOut,
27 std::unique_ptr<Encoder<float>>& output,
28 std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
29 std::unique_ptr<Decoder<float>>& outputDecoder,
30 std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
31 std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
32 std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
33 std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
34 std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
35 std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
36 std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
37 std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
38 std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
39 std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
40 std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
41 std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
42 std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
43 std::unique_ptr<Decoder<float>>& cellBiasTensor,
44 std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
45 std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
46 std::unique_ptr<Decoder<float>>& projectionBiasTensor,
47 std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
48 std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
49 std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
50 std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
51 std::unique_ptr<Encoder<float>>& inputGateScratch,
52 std::unique_ptr<Encoder<float>>& cellScratch,
53 std::unique_ptr<Encoder<float>>& forgetGateScratch,
54 std::unique_ptr<Encoder<float>>& outputGateScratch,
55 std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
56 std::unique_ptr<Decoder<float>>& cellScratchDecoder,
57 std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
58 std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
59 float layerNormEpsilon);
60
61} //namespace armnn
Copyright (c) 2021 ARM Limited and Contributors.
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float > > &inputData, std::unique_ptr< Decoder< float > > &outputStateIn, std::unique_ptr< Decoder< float > > &cellStateIn, std::unique_ptr< Encoder< float > > &outputStateOut, std::unique_ptr< Encoder< float > > &cellStateOut, std::unique_ptr< Encoder< float > > &output, std::unique_ptr< Decoder< float > > &cellStateOutDecoder, std::unique_ptr< Decoder< float > > &outputDecoder, std::unique_ptr< Decoder< float > > &inputToInputWeightsTensor, std::unique_ptr< Decoder< float > > &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &inputToCellWeightsTensor, std::unique_ptr< Decoder< float > > &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float > > &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &cellToInputWeightsTensor, std::unique_ptr< Decoder< float > > &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float > > &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float > > &inputGateBiasTensor, std::unique_ptr< Decoder< float > > &forgetGateBiasTensor, std::unique_ptr< Decoder< float > > &cellBiasTensor, std::unique_ptr< Decoder< float > > &outputGateBiasTensor, std::unique_ptr< Decoder< float > > &projectionWeightsTensor, std::unique_ptr< Decoder< float > > &projectionBiasTensor, std::unique_ptr< Decoder< float > > &inputLayerNormWeights, std::unique_ptr< Decoder< float > > &forgetLayerNormWeights, std::unique_ptr< Decoder< float > > &cellLayerNormWeights, std::unique_ptr< Decoder< float > > &outputLayerNormWeights, std::unique_ptr< Encoder< float > > &inputGateScratch, std::unique_ptr< 
Encoder< float > > &cellScratch, std::unique_ptr< Encoder< float > > &forgetGateScratch, std::unique_ptr< Encoder< float > > &outputGateScratch, std::unique_ptr< Decoder< float > > &inputGateScratchDecoder, std::unique_ptr< Decoder< float > > &cellScratchDecoder, std::unique_ptr< Decoder< float > > &forgetGateScratchDecoder, std::unique_ptr< Decoder< float > > &outputGateScratchDecoder, float layerNormEpsilon)
Definition Lstm.cpp:13
An LstmDescriptor for the LstmLayer.