ArmNN
 25.11
Loading...
Searching...
No Matches
BatchMatMulImpl.hpp
Go to the documentation of this file.
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Encoders.hpp"
#include "Decoders.hpp"

#include <functional>
#include <utility>
#include <vector>

13namespace armnn
14{
15
17public:
19 const TensorInfo& inputXInfo,
20 const TensorInfo& inputYInfo,
21 const TensorInfo& outputInfo,
22 Decoder<float>& inputXDecoder,
23 Decoder<float>& inputYDecoder,
24 Encoder<float>& outputEncoder);
25
26private:
27 enum DataSlot
28 {
29 InputX = 0,
30 InputY = 1,
31 Output = 2
32 };
33
34 const BatchMatMulDescriptor& params;
35 TensorInfo inputXInfo;
36 TensorInfo inputYInfo;
37 TensorInfo outputInfo;
38 Decoder<float>& inputXDecoder;
39 Decoder<float>& inputYDecoder;
40 Encoder<float>& outputEncoder;
41
42 std::vector<float> inputXData;
43 std::vector<float> inputYData;
44
45 void ApplyBatchMatMul();
46
47 void ApplyParams();
48
49 void Transpose(DataSlot type);
50
51 void Adjoint(DataSlot type);
52
53 void RecurseTensor(const TensorInfo& tensorInfo,
54 std::function<void(const std::vector<unsigned int>&)> const& operation,
55 std::vector<unsigned int>& curIdx,
56 unsigned int curDim);
57
58 // Adjusts it for when input tensors are of unequal rank
59 void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
60 std::pair<unsigned int, unsigned int>& axesYToMul);
61
62 float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
63
64 void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
65
66 // Takes into account broadcasting
67 void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
68
69 unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
70};
71
72} // namespace armnn
BatchMatMul(const BatchMatMulDescriptor &params, const TensorInfo &inputXInfo, const TensorInfo &inputYInfo, const TensorInfo &outputInfo, Decoder< float > &inputXDecoder, Decoder< float > &inputYDecoder, Encoder< float > &outputEncoder)
Copyright (c) 2021 ARM Limited and Contributors.
A BatchMatMulDescriptor for the BatchMatMul operator.