30 #include "utils/Utils.h"
// Graph-frontend example that builds the DeepSpeech v0.4.1 network; it plugs into the
// shared example runner through the Example base class.
38 class GraphDeepSpeechExample :
public Example
// Constructor: binds the common option set to the command-line parser and creates the
// graph stream named "DeepSpeech v0.4.1" (first argument presumably the graph id — confirm).
41 GraphDeepSpeechExample() : cmd_parser(), common_opts(cmd_parser), common_params(), graph(0,
"DeepSpeech v0.4.1")
/** Parses the command line and builds/finalizes the DeepSpeech graph.
 *
 * Steps: parse and validate options, optionally print help, resolve the model data
 * path, split the input into n_steps time slices, chain 16 LSTM cells, stack their
 * hidden-state outputs, then finalize the graph for the requested target.
 *
 * @param[in] argc Number of command-line arguments
 * @param[in] argv Command-line arguments
 */
44 bool do_setup(
int argc,
char **argv)
override
// Parse and validate the command-line options.
47 cmd_parser.parse(argc, argv);
48 cmd_parser.validate();
// Print usage when help was requested.
54 if (common_params.help)
56 cmd_parser.print_help(argv[0]);
// Echo the resolved common parameters.
61 std::cout << common_params << std::endl;
// Model files live under <data_path>/cnn_data/deepspeech_model/; the suffix is only
// appended when a data path was supplied.
64 std::string data_path = common_params.data_path;
65 const std::string model_path =
"/cnn_data/deepspeech_model/";
67 if (!data_path.empty())
69 data_path += model_path;
// Number of unrolled time steps; also the number of outputs of the "unstack" split below.
74 const unsigned int n_steps = 16;
// NOTE(review): presumably the LSTM cell-state clipping bound — confirm where it is consumed.
77 const float cell_clip = 20.f;
// Apply the requested execution target and fast-math hint to the graph stream.
88 graph << common_params.target << common_params.fast_math_hint
// Split the current tail tensor into n_steps slices along axis 2 (one per time step),
// using a node named "unstack" on the stream's target.
122 NodeParams unstack_params = {
"unstack", graph.hints().target_hint};
124 GraphBuilder::add_split_node(graph.graph(), unstack_params, {graph.tail_node(), 0}, n_steps, 2);
145 lstm_weights_descriptor,
// Hand-unrolled LSTM chain: step 0 seeds both state inputs with previous_state; every later
// step consumes the (first, second) pair returned by the previous step.
// NOTE(review): this chain and the StackLayer below are hard-coded to 16 steps and must be
// kept in sync with n_steps manually — changing n_steps alone adds/removes no cells.
153 std::pair<SubStream, SubStream> new_state_1 =
154 add_lstm_cell(unstack_nid, 0, previous_state, previous_state, add_y, lstm_fc_weights, lstm_fc_bias);
155 std::pair<SubStream, SubStream> new_state_2 =
156 add_lstm_cell(unstack_nid, 1, new_state_1.first, new_state_1.second, add_y, lstm_fc_weights, lstm_fc_bias);
157 std::pair<SubStream, SubStream> new_state_3 =
158 add_lstm_cell(unstack_nid, 2, new_state_2.first, new_state_2.second, add_y, lstm_fc_weights, lstm_fc_bias);
159 std::pair<SubStream, SubStream> new_state_4 =
160 add_lstm_cell(unstack_nid, 3, new_state_3.first, new_state_3.second, add_y, lstm_fc_weights, lstm_fc_bias);
161 std::pair<SubStream, SubStream> new_state_5 =
162 add_lstm_cell(unstack_nid, 4, new_state_4.first, new_state_4.second, add_y, lstm_fc_weights, lstm_fc_bias);
163 std::pair<SubStream, SubStream> new_state_6 =
164 add_lstm_cell(unstack_nid, 5, new_state_5.first, new_state_5.second, add_y, lstm_fc_weights, lstm_fc_bias);
165 std::pair<SubStream, SubStream> new_state_7 =
166 add_lstm_cell(unstack_nid, 6, new_state_6.first, new_state_6.second, add_y, lstm_fc_weights, lstm_fc_bias);
167 std::pair<SubStream, SubStream> new_state_8 =
168 add_lstm_cell(unstack_nid, 7, new_state_7.first, new_state_7.second, add_y, lstm_fc_weights, lstm_fc_bias);
169 std::pair<SubStream, SubStream> new_state_9 =
170 add_lstm_cell(unstack_nid, 8, new_state_8.first, new_state_8.second, add_y, lstm_fc_weights, lstm_fc_bias);
171 std::pair<SubStream, SubStream> new_state_10 =
172 add_lstm_cell(unstack_nid, 9, new_state_9.first, new_state_9.second, add_y, lstm_fc_weights, lstm_fc_bias);
173 std::pair<SubStream, SubStream> new_state_11 = add_lstm_cell(
174 unstack_nid, 10, new_state_10.first, new_state_10.second, add_y, lstm_fc_weights, lstm_fc_bias);
175 std::pair<SubStream, SubStream> new_state_12 = add_lstm_cell(
176 unstack_nid, 11, new_state_11.first, new_state_11.second, add_y, lstm_fc_weights, lstm_fc_bias);
177 std::pair<SubStream, SubStream> new_state_13 = add_lstm_cell(
178 unstack_nid, 12, new_state_12.first, new_state_12.second, add_y, lstm_fc_weights, lstm_fc_bias);
179 std::pair<SubStream, SubStream> new_state_14 = add_lstm_cell(
180 unstack_nid, 13, new_state_13.first, new_state_13.second, add_y, lstm_fc_weights, lstm_fc_bias);
181 std::pair<SubStream, SubStream> new_state_15 = add_lstm_cell(
182 unstack_nid, 14, new_state_14.first, new_state_14.second, add_y, lstm_fc_weights, lstm_fc_bias);
183 std::pair<SubStream, SubStream> new_state_16 = add_lstm_cell(
184 unstack_nid, 15, new_state_15.first, new_state_15.second, add_y, lstm_fc_weights, lstm_fc_bias);
// Stack the .second element of every step's pair (add_lstm_cell returns
// (new_state_c, new_state_h), so these are the per-step hidden states) along `axis`.
188 graph <<
StackLayer(axis, std::move(new_state_1.second), std::move(new_state_2.second),
189 std::move(new_state_3.second), std::move(new_state_4.second), std::move(new_state_5.second),
190 std::move(new_state_6.second), std::move(new_state_7.second), std::move(new_state_8.second),
191 std::move(new_state_9.second), std::move(new_state_10.second),
192 std::move(new_state_11.second), std::move(new_state_12.second),
193 std::move(new_state_13.second), std::move(new_state_14.second),
194 std::move(new_state_15.second), std::move(new_state_16.second))
// Carry the tuner and MLGO heuristics settings into the graph configuration.
212 config.
use_tuner = common_params.enable_tuner;
214 config.
mlgo_file = common_params.mlgo_file;
// Finalize the graph for the selected target using the assembled configuration.
218 graph.finalize(common_params.target, config);
/** Run hook overriding the Example base class.
 *
 * NOTE(review): presumably executes the graph finalized in do_setup — confirm.
 */
222 void do_run()
override
/** Adds one unrolled LSTM cell (time step @p unstack_idx) to the graph.
 *
 * All nodes of the cell are named with the prefix "rnn/lstm_cell_<unstack_idx>".
 *
 * @param[in] unstack_nid ID of the unstack split node feeding the cells
 *                        (presumably output @p unstack_idx is this step's input — confirm)
 * @param[in] unstack_idx Time-step index of this cell
 *
 * @return Pair (new_state_c, new_state_h): first = cell state, second = hidden state.
 */
244 std::pair<SubStream, SubStream> add_lstm_cell(
NodeID unstack_nid,
245 unsigned int unstack_idx,
252 const std::string cell_name(
"rnn/lstm_cell_" +
std::to_string(unstack_idx));
// Apply concat_params to the concat node.
261 set_node_params(graph.
graph(), concat_nid, concat_params);
// Split the tail tensor into 4 parts along axis 0, one per gate
// (NOTE(review): presumably TensorFlow's i, j, f, o gate order — confirm).
267 const unsigned int num_splits = 4;
268 const unsigned int split_axis = 0;
272 GraphBuilder::add_split_node(graph.
graph(), split_params, {graph.tail_node(), 0}, num_splits, split_axis);
// Attach params to the per-gate nodes: sigmoid_1, tanh, add, sigmoid_2.
283 set_node_params(graph.
graph(), sigmoid_1_nid, sigmoid_1_params);
289 set_node_params(graph.
graph(), tanh_nid, tanh_params);
292 tanh_ss.forward_tail(tanh_nid);
299 set_node_params(graph.
graph(), add_nid, add_params);
305 set_node_params(graph.
graph(), sigmoid_2_nid, sigmoid_2_params);
// mul_1 = sigmoid_1 * tanh (element-wise).
308 sigmoid_1_ss.forward_tail(sigmoid_1_nid);
310 mul_1_ss <<
EltwiseLayer(std::move(sigmoid_1_ss), std::move(tanh_ss), EltwiseOperation::Mul)
// Branch from the add node's output and multiply it element-wise by previous_state_c
// (NOTE(review): presumably the forget-gate term — confirm).
314 tanh_1_ss_tmp.forward_tail(add_nid);
319 tanh_1_ss_tmp2 <<
EltwiseLayer(std::move(tanh_1_ss_tmp), std::move(previous_state_c), EltwiseOperation::Mul)
// new_state_c = (previous-state term) + mul_1.
322 tanh_1_ss <<
EltwiseLayer(std::move(tanh_1_ss_tmp2), std::move(mul_1_ss), EltwiseOperation::Add)
323 .
set_name(cell_name +
"/new_state_c");
// new_state_h = sigmoid_2 * tanh_1_ss (element-wise).
330 sigmoid_2_ss.forward_tail(sigmoid_2_nid);
331 graph <<
EltwiseLayer(std::move(sigmoid_2_ss), std::move(tanh_1_ss), EltwiseOperation::Mul)
332 .
set_name(cell_name +
"/new_state_h");
// Order matters: first = cell state (c), second = hidden state (h).
335 return std::pair<SubStream, SubStream>(new_state_c, new_state_h);
355 int main(
int argc,
char **argv)
357 return arm_compute::utils::run_example<GraphDeepSpeechExample>(argc, argv);