9 const std::vector<const TensorInfo*>& inputs,
10 const std::vector<const TensorInfo*>& outputs,
13 std::string padInputName = std::string(
"input_");
15 std::string poolOutputName = std::string(
"output0_");
26 std::vector<int> paddings;
31 static_cast<int>(poolDescriptor->
m_PadTop),
33 static_cast<int>(poolDescriptor->
m_PadLeft),
45 static_cast<int>(poolDescriptor->
m_PadTop),
47 static_cast<int>(poolDescriptor->
m_PadLeft),
52 std::vector<TosaSerializationTensor*> tensors;
53 std::vector<TosaSerializationOperator*> operators;
55 TosaPadAttribute padAttribute(paddings, 0, 0.0f);
56 operators.push_back(
new TosaSerializationOperator(Op_PAD,
57 Attribute_PadAttribute,
62 std::vector<int> pad = {0, 0, 0, 0};
63 std::vector<int> kernel = {
static_cast<int>(poolDescriptor->
m_PoolHeight),
65 std::vector<int> stride = {
static_cast<int>(poolDescriptor->
m_StrideY),
66 static_cast<int>(poolDescriptor->
m_StrideX)};
67 std::vector<int> dilation = {1, 1};
70 DType inputDType =
ArmNNToDType(inputs[0]->GetDataType());
83 TosaPoolAttribute poolAttribute(pad, kernel, stride, 0, 0,
ArmNNToDType(inputs[0]->GetDataType()));
85 operators.push_back(
new TosaSerializationOperator(Op_AVG_POOL2D,
86 Attribute_PoolAttribute,
94 if(padInputName.find(
"input_") != std::string::npos)
96 tensors.push_back(
new TosaSerializationTensor(padInputName, inputShape, inputDType, {}));
100 DType outputDType =
ArmNNToDType(outputs[0]->GetDataType());
102 std::vector<int32_t> intermediateShape;
105 intermediateShape = {inputShape[0],
106 inputShape[1] + paddings[2] + paddings[3],
107 inputShape[2] + paddings[4] + paddings[5],
112 intermediateShape = {inputShape[0],
114 inputShape[2] + paddings[4] + paddings[5],
115 inputShape[3] + paddings[6] + paddings[7]};
118 tensors.push_back(
new TosaSerializationTensor(padOutputName, intermediateShape, inputDType, {}));
119 tensors.push_back(
new TosaSerializationTensor(poolOutputName, outputShape, outputDType, {}));
123 return new TosaSerializationBasicBlock(blockName,
/// @brief Declaration of GetInputSlicedToItsUsedSize (body not visible in this view).
///
/// Returns a std::string. NOTE(review): presumably the name of the tensor produced
/// after slicing the input down to the region actually consumed by the
/// kernel/pad/stride/dilation window — confirm against the definition.
///
/// @param inputShape   Shape of the input tensor.
/// @param inputName    Name of the input tensor.
/// @param layout       Data layout of the input (ArmNN DataLayout).
/// @param datatype     TOSA DType of the input.
/// @param kernel       Kernel sizes. NOTE(review): assumed spatial (H, W) order — confirm.
/// @param pad          Padding values. NOTE(review): element order not visible here — confirm.
/// @param stride       Stride values.
/// @param dilations    Dilation values.
/// @param tensors      Passed by non-const reference. Presumably the function appends any
///                     new TosaSerializationTensor it creates — confirm.
/// @param operators    Passed by non-const reference. Presumably the function appends any
///                     new TosaSerializationOperator it creates — confirm.
/// @param isPoolingOp  Defaults to false. NOTE(review): assumed to select pooling-specific
///                     handling — confirm.
/// NOTE(review): as captured here this line has no terminating ';' or body.
std::string GetInputSlicedToItsUsedSize(const std::vector< int32_t > &inputShape, const std::string &inputName, const DataLayout layout, const DType datatype, const std::vector< int32_t > &kernel, const std::vector< int32_t > &pad, const std::vector< int32_t > &stride, const std::vector< int32_t > &dilations, std::vector< TosaSerializationTensor * > &tensors, std::vector< TosaSerializationOperator * > &operators, const bool isPoolingOp=false)