Compute Library 22.08 - ClFusedKernelGraph.cpp
/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION

namespace arm_compute
{
namespace experimental
{
namespace dynamic_fusion
{
namespace
{
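/** Return every candidate pair of fusion groups, pairing each group with all groups that follow it
 *  in the topologically sorted list @p sorted_fgs.
 */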
std::vector<std::pair<ClKernelFusionGroup *, ClKernelFusionGroup *>> get_combinations(const std::vector<ClKernelFusionGroup *> &sorted_fgs)
{
    ARM_COMPUTE_ERROR_ON(sorted_fgs.size() <= 1);
    std::vector<std::pair<ClKernelFusionGroup *, ClKernelFusionGroup *>> combo;
    for(size_t i = 0; i < sorted_fgs.size() - 1; ++i)
    {
        for(size_t j = i + 1; j < sorted_fgs.size(); ++j)
        {
            combo.push_back(std::make_pair(sorted_fgs.at(i), sorted_fgs.at(j)));
        }
    }
    return combo;
}
} // namespace
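/** Return the fused kernels of fusion group @p group in topological order. */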
std::vector<const ClKernel *> traverse(const ClKernelFusionGroup &group)
{
    std::vector<const ClKernel *> kernels;
    const auto sorted = group.graph.topological_sort();
    for(const auto &pack : sorted.second)
    {
        kernels.push_back(group.fused_kernels.at(pack.op));
    }
    return kernels;
}

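/** Return the fusion groups of @p graph in topological order (const view). */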
std::vector<const ClKernelFusionGroup *> traverse(const ClFusedKernelGraph &graph)
{
    std::vector<const ClKernelFusionGroup *> kernels;
    const auto sorted = graph.fg_dependency.topological_sort();
    for(const auto &pack : sorted.second)
    {
        kernels.push_back(graph.fusion_groups.at(pack.op).get());
    }
    return kernels;
}

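/** Return the fusion groups of @p graph in topological order (mutable view). */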
std::vector<ClKernelFusionGroup *> traverse(ClFusedKernelGraph &graph)
{
    std::vector<ClKernelFusionGroup *> kernels;
    const auto sorted = graph.fg_dependency.topological_sort();
    for(const auto &pack : sorted.second)
    {
        kernels.push_back(graph.fusion_groups.at(pack.op).get());
    }
    return kernels;
}

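/** Initialize a ClFusedKernelGraph from @p kernel_graph, placing every kernel in its own fusion group. */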
std::pair<Status, ClFusedKernelGraph> init_fusion_graph(const ClKernelGraph &kernel_graph)
{
    ClFusedKernelGraph fused_kernel_graph{};
    fused_kernel_graph.original_graph = &kernel_graph; // Keep a pointer back to the original kernel graph
    fused_kernel_graph.fg_dependency  = DependencyGraph();
    // Initialize all fusion groups
    for(const auto &kernel : traverse(kernel_graph))
    {
        fused_kernel_graph.add_fusion_group({ kernel });
    }
    return { Status{}, fused_kernel_graph };
}

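/** Greedily merge fusion groups of @p fused_kernel_graph until no pair of groups can be legally fused. */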
Status fuse(ClFusedKernelGraph &fused_kernel_graph)
{
    // A naive fusion algorithm that is guaranteed to find the optimal pattern if there are no branches.
    // If there are branches, the algorithm cannot guarantee optimality, as it does not perform any search.

    bool fusion_found = false;
    do
    {
        fusion_found          = false;
        const auto sorted_fgs = traverse(fused_kernel_graph);
        if(sorted_fgs.size() <= 1)
        {
            // Only one or zero fusion group, thus no need to perform fusion
            return Status{};
        }
        auto fgs_combo = get_combinations(sorted_fgs);
        for(auto fgs : fgs_combo)
        {
            auto       fg0 = fgs.first;
            auto       fg1 = fgs.second;
            const auto st  = fused_kernel_graph.can_fuse(*fg0, *fg1);
            if(bool(st))
            {
                const auto st = fused_kernel_graph.fuse(*fg0, *fg1);
                if(!bool(st))
                {
                    return st;
                }
                fusion_found = true;
                break;
            }
        }
    }
    while(fusion_found);
    return Status{};
}
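
/** Add the store kernel component(s) for the destination tensors of fusion group @p fg to blueprint @p bp. */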
Status generate_store(ClKernelBlueprint &bp, const ClFusedKernelGraph &fused_kernel_graph, const ClKernelFusionGroup &fg)
{
    Status st{};
    for(const auto &dst_t_id : fused_kernel_graph.fg_dependency.dst_tensors(fg.id))
    {
        const auto dst_t = fused_kernel_graph.original_graph->get_tensor(dst_t_id);

        /// NOTE: The dst tensor must have already been added to the blueprint at this point
        ArgumentID dst_id;
        st = add_tensor(bp, dst_t->desc, dst_id, dst_t->id);
        if(!bool(st))
        {
            return st;
        }
        /// NOTE: The extra dst tensor is needed because the store kernel component requires 2 tensors. This is irrelevant to the
        /// fused kernel graph, since both tensors share the exact same info and kernel arg descriptor
        ArgumentID dst_dst_id;
        st = add_tensor(bp, dst_t->desc, dst_dst_id);
        if(!bool(st))
        {
            return st;
        }
        /// NOTE: Update the merge point map to link dst_dst_id with dst_t->id instead.
        /// This is required because get_arguments() of the blueprint returns the dst tensor added by the store component
        st = update_merge_point(bp, dst_dst_id, dst_t->id);
        if(!bool(st))
        {
            return st;
        }
        st = add_kcomp_store(bp, fg.get_root_kernel()->config().store_type, dst_id, dst_dst_id);
        if(!bool(st))
        {
            return st;
        }
    }
    return st;
}

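/** Generate a ClWorkload from @p fused_kernel_graph: for each fusion group, build its kernel code and
 *  register the corresponding source and destination workload tensors.
 */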
Status generate(ClWorkload &workload, const ClWorkloadContext &ctx, const ClFusedKernelGraph &fused_kernel_graph)
{
    workload.context = ctx;
    for(const auto &fg : traverse(fused_kernel_graph))
    {
        ClKernelBlueprint bp{};
        for(const auto &kernel : traverse(*fg))
        {
            const auto st = kernel->generate(bp);
            if(!bool(st))
            {
                return st;
            }
        }
        auto st = set_tile_info(bp, fg->get_root_kernel()->config().tile_desc);
        if(!bool(st))
        {
            return st;
        }
        st = generate_store(bp, fused_kernel_graph, *fg);
        if(!bool(st))
        {
            return st;
        }

        ClKernelCode code{};
        st = build(code, ClCodeBuilderContext{ ctx.gpu_info }, bp);
        if(!bool(st))
        {
            return st;
        }
        const auto bp_graph = get_dependency_graph(bp);

        // Get tensor info
        std::vector<Id> workload_src_tensors{};
        for(const auto &src_t_id : fused_kernel_graph.fg_dependency.src_tensors(fg->id))
        {
            const auto src_t = fused_kernel_graph.original_graph->get_tensor(src_t_id);
            // Get corresponding kernel arg descriptor
            const auto arg_desc    = code.arguments.at(bp_graph.get_merge_points().at(src_t->id));
            const auto kernel_t_id = workload.add_workload_tensor(src_t->desc, src_t->memory_type, src_t->memory_info, arg_desc, src_t->id);
            workload_src_tensors.push_back(kernel_t_id);
        }
        std::vector<Id> workload_dst_tensors{};
        for(const auto &dst_t_id : fused_kernel_graph.fg_dependency.dst_tensors(fg->id))
        {
            const auto dst_t = fused_kernel_graph.original_graph->get_tensor(dst_t_id);
            // Get corresponding kernel arg descriptor
            const auto arg_desc    = code.arguments.at(bp_graph.get_merge_points().at(dst_t->id));
            const auto kernel_t_id = workload.add_workload_tensor(dst_t->desc, dst_t->memory_type, dst_t->memory_info, arg_desc, dst_t->id);
            workload_dst_tensors.push_back(kernel_t_id);
        }

        workload.add_unit_workload(fg->get_root_kernel()->config().stage, code, workload_src_tensors, workload_dst_tensors);
    }

    return Status{};
}
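
// A minimal usage sketch of the entry points in this file (illustrative only: the caller, the
// ClWorkloadContext contents and the construction of the ClKernelGraph are assumptions, not
// taken from this file):
//
//     ClWorkload        workload{};
//     ClWorkloadContext ctx{ /* gpu_info, ... */ };
//     ClKernelGraph     kernel_graph = /* graph previously translated from an OperatorGraph */;
//
//     auto result = init_fusion_graph(kernel_graph);              // -> { Status, ClFusedKernelGraph }
//     if(bool(result.first) && bool(fuse(result.second)))
//     {
//         const auto st = generate(workload, ctx, result.second); // fills `workload` on success
//     }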

} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */