10 #include <arm_compute/dynamic_fusion/sketch/gpu/operators/GpuReshape.h>
11 #include <arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h>
12 #include <arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h>
13 #include <arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h>
15 using namespace arm_compute::experimental::dynamic_fusion;
20 using namespace armcomputetensorutils;
// NOTE(review): this is a fragment. The enclosing function's signature, its
// braces, and the statement that assigns `aclStatus` are not visible here.
// The log message below suggests the routine is GpuFsaReshapeValidate - confirm.
//
// Obtain the compile context from CLKernelLibrary and build a
// GpuWorkloadContext from its address.
24 auto compileContext = arm_compute::CLKernelLibrary::get().get_compile_context();
25 auto workloadContext = GpuWorkloadContext(&compileContext);
// Local sketch constructed on the workload context created above.
27 GpuWorkloadSketch sketch(&workloadContext);
// Build the ACL tensor info for `input`, passing the input's own dimension
// count, then copy the input's IsConstant() flag onto it.
29 arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, input.
GetNumDimensions());
30 aclInputInfo.set_are_values_constant(input.
IsConstant());
// Ask the workload context for an ITensorInfo* created from aclInputInfo.
32 arm_compute::ITensorInfo* inputInfo = workloadContext.create_tensor_info(aclInputInfo);
// Reshape attributes: the shape is converted from descriptor.m_TargetShape.
34 ReshapeAttributes attributes;
35 attributes.shape(BuildArmComputeTensorShape(descriptor.
m_TargetShape));
// `aclStatus` is defined on a line not shown in this fragment.
// NOTE(review): presumably the result of a GpuReshape validation call using
// `sketch`, `inputInfo` and `attributes` - confirm against the full file.
40 if (aclStatus.error_code() != arm_compute::ErrorCode::OK)
// Any status other than OK is reported on stdout with the ACL error description.
42 std::cout <<
"GpuFsaReshapeValidate failed: " << aclStatus.error_description() << std::endl;
// NOTE(review): this is a fragment of a second routine whose signature is not
// visible. `blob`, `input`, `descriptor` and `workloadContext` are all defined
// outside the lines shown here.
//
// Take the GpuWorkloadSketch pointer exposed by blob->sketch.get().
51 GpuWorkloadSketch* sketch = blob->
sketch.get();
// Local lists of the tensor infos created below; both are copied into `blob`
// at the end of the fragment.
54 std::vector<arm_compute::ITensorInfo*> inputTensorInfos;
55 std::vector<arm_compute::ITensorInfo*> outputTensorInfos;
// Build the ACL tensor info for `input`, passing the input's own dimension
// count, then copy the input's IsConstant() flag onto it.
57 arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, input.
GetNumDimensions());
59 aclInputInfo.set_are_values_constant(input.
IsConstant());
// `workloadContext` is dereferenced with -> here (it is a pointer-like object
// in this routine); the ITensorInfo* it returns becomes inputTensorInfos[0].
61 inputTensorInfos.emplace_back(workloadContext->create_tensor_info(aclInputInfo));
// Reshape attributes: the shape is converted from descriptor.m_TargetShape.
63 ReshapeAttributes attributes;
64 attributes.shape(BuildArmComputeTensorShape(descriptor.
m_TargetShape));
// Call GpuReshape::create_op on the sketch with the input tensor info and the
// attributes, keeping the ITensorInfo* it returns.
// NOTE(review): the local is called addOutputInfo although the operator is a
// reshape - the name looks carried over from another operator; consider renaming.
66 arm_compute::ITensorInfo* addOutputInfo = GpuReshape::create_op(*sketch, inputTensorInfos[0], attributes);
// The output entry comes from create_tensor_info() called with no arguments;
// it is then passed, together with the reshape result, to GpuOutput::create_op
// on the same sketch.
69 outputTensorInfos.emplace_back(workloadContext->create_tensor_info());
70 GpuOutput::create_op(*sketch, addOutputInfo, outputTensorInfos[0]);
// Store the tensor-info lists on `blob`. make_unique is given the local vectors
// as lvalues, so each call copy-constructs a new heap-allocated vector.
// NOTE(review): the locals are not used again in the visible lines, so they
// could be moved instead of copied - confirm nothing later reads them.
73 blob->
inputTensorInfos = std::make_unique<std::vector<arm_compute::ITensorInfo*>>(inputTensorInfos);
74 blob->
outputTensorInfos = std::make_unique<std::vector<arm_compute::ITensorInfo*>>(outputTensorInfos);