75 {
76 tesseract::CheckSharedLibraryVersion();
78#if defined(__USE_GNU)
79 if (FLAGS_debug_float) {
80
81 feenableexcept(FE_DIVBYZERO | FE_OVERFLOW | FE_INVALID);
82 }
83#endif
84 if (FLAGS_model_output.empty()) {
85 tprintf(
"Must provide a --model_output!\n");
86 return EXIT_FAILURE;
87 }
88 if (FLAGS_traineddata.empty()) {
89 tprintf(
"Must provide a --traineddata see training documentation\n");
90 return EXIT_FAILURE;
91 }
92
93
94 std::string test_file = FLAGS_model_output.c_str();
95 test_file += "_wtest";
96 FILE *f = fopen(test_file.c_str(), "wb");
97 if (f != nullptr) {
98 fclose(f);
99 if (remove(test_file.c_str()) != 0) {
100 tprintf(
"Error, failed to remove %s: %s\n", test_file.c_str(), strerror(errno));
101 return EXIT_FAILURE;
102 }
103 } else {
104 tprintf(
"Error, model output cannot be written: %s\n", strerror(errno));
105 return EXIT_FAILURE;
106 }
107
108
109 std::string checkpoint_file = FLAGS_model_output.c_str();
110 checkpoint_file += "_checkpoint";
111 std::string checkpoint_bak = checkpoint_file + ".bak";
113 FLAGS_debug_interval,
114 static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
115 if (!trainer.InitCharSet(FLAGS_traineddata.c_str())) {
116 tprintf(
"Error, failed to read %s\n", FLAGS_traineddata.c_str());
117 return EXIT_FAILURE;
118 }
119
120
121
122 if (FLAGS_stop_training || FLAGS_debug_network) {
123 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
124 tprintf(
"Failed to read continue from: %s\n", FLAGS_continue_from.c_str());
125 return EXIT_FAILURE;
126 }
127 if (FLAGS_debug_network) {
128 trainer.DebugNetwork();
129 } else {
130 if (FLAGS_convert_to_int) {
131 trainer.ConvertToInt();
132 }
133 if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
134 tprintf(
"Failed to write recognition model : %s\n", FLAGS_model_output.c_str());
135 }
136 }
137 return EXIT_SUCCESS;
138 }
139
140
141 if (FLAGS_train_listfile.empty()) {
142 tprintf(
"Must supply a list of training filenames! --train_listfile\n");
143 return EXIT_FAILURE;
144 }
145 std::vector<std::string> filenames;
147 tprintf(
"Failed to load list of training filenames from %s\n", FLAGS_train_listfile.c_str());
148 return EXIT_FAILURE;
149 }
150
151
152 if (trainer.TryLoadingCheckpoint(checkpoint_file.c_str(), nullptr) ||
153 trainer.TryLoadingCheckpoint(checkpoint_bak.c_str(), nullptr)) {
154 tprintf(
"Successfully restored trainer from %s\n", checkpoint_file.c_str());
155 } else {
156 if (!FLAGS_continue_from.empty()) {
157
158 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
159 FLAGS_append_index >= 0 ? FLAGS_continue_from.c_str()
160 : FLAGS_old_traineddata.c_str())) {
161 tprintf(
"Failed to continue from: %s\n", FLAGS_continue_from.c_str());
162 return EXIT_FAILURE;
163 }
164 tprintf(
"Continuing from %s\n", FLAGS_continue_from.c_str());
165 if (FLAGS_reset_learning_rate) {
166 trainer.SetLearningRate(FLAGS_learning_rate);
167 tprintf(
"Set learning rate to %f\n",
static_cast<float>(FLAGS_learning_rate));
168 }
169 trainer.InitIterations();
170 }
171 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
172 if (FLAGS_append_index >= 0) {
173 tprintf(
"Appending a new network to an old one!!");
174 if (FLAGS_continue_from.empty()) {
175 tprintf(
"Must set --continue_from for appending!\n");
176 return EXIT_FAILURE;
177 }
178 }
179
180 if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index, FLAGS_net_mode,
181 FLAGS_weight_range, FLAGS_learning_rate, FLAGS_momentum,
182 FLAGS_adam_beta)) {
183 tprintf(
"Failed to create network from spec: %s\n", FLAGS_net_spec.c_str());
184 return EXIT_FAILURE;
185 }
186 trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
187 }
188 }
189 if (!trainer.LoadAllTrainingData(
190 filenames,
192 FLAGS_randomly_rotate)) {
193 tprintf(
"Load of images failed!!\n");
194 return EXIT_FAILURE;
195 }
196
199 if (!FLAGS_eval_listfile.empty()) {
200 using namespace std::placeholders;
201 if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
202 tprintf(
"Failed to load eval data from: %s\n", FLAGS_eval_listfile.c_str());
203 return EXIT_FAILURE;
204 }
206 }
207
208 int max_iterations = FLAGS_max_iterations;
209 if (max_iterations < 0) {
210
211 max_iterations = filenames.size() * (-max_iterations);
212 } else if (max_iterations == 0) {
213
214 max_iterations = INT_MAX;
215 }
216
217 do {
218
219 int iteration = trainer.training_iteration();
221 iteration < target_iteration && iteration < max_iterations;
222 iteration = trainer.training_iteration()) {
223 trainer.TrainOnLine(&trainer, false);
224 }
225 std::string log_str;
226 trainer.MaintainCheckpoints(tester_callback, log_str);
227 tprintf(
"%s\n", log_str.c_str());
228 } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
229 (trainer.training_iteration() < max_iterations));
230 tprintf(
"Finished! Selected model with minimal training error rate (BCER) = %g\n",
231 trainer.best_error_rate());
232 return EXIT_SUCCESS;
233}
const int kNumPagesPerBatch
void tprintf(const char *format,...)
void ParseArguments(int *argc, char ***argv)
std::function< std::string(int, const double *, const TessdataManager &, int)> TestCallback
bool LoadFileLinesToStrings(const char *filename, std::vector< std::string > *lines)
std::string RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)