tesseract 5.2.0
Loading...
Searching...
No Matches
tesseract::LSTMTrainer Class Reference

#include <lstmtrainer.h>

Inheritance diagram for tesseract::LSTMTrainer:
tesseract::LSTMRecognizer

Public Member Functions

 LSTMTrainer ()
 
 LSTMTrainer (const char *model_base, const char *checkpoint_name, int debug_interval, int64_t max_memory)
 
virtual ~LSTMTrainer ()
 
bool TryLoadingCheckpoint (const char *filename, const char *old_traineddata)
 
bool InitCharSet (const std::string &traineddata_path)
 
void InitCharSet (const TessdataManager &mgr)
 
bool InitNetwork (const char *network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
 
int InitTensorFlowNetwork (const std::string &tf_proto)
 
void InitIterations ()
 
double ActivationError () const
 
double CharError () const
 
const double * error_rates () const
 
double best_error_rate () const
 
int best_iteration () const
 
int learning_iteration () const
 
int32_t improvement_steps () const
 
void set_perfect_delay (int delay)
 
const std::vector< char > & best_trainer () const
 
double NewSingleError (ErrorTypes type) const
 
double LastSingleError (ErrorTypes type) const
 
const DocumentCache & training_data () const
 
DocumentCache * mutable_training_data ()
 
Trainability GridSearchDictParams (const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, std::string &results)
 
void DebugNetwork ()
 
bool LoadAllTrainingData (const std::vector< std::string > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
 
bool MaintainCheckpoints (const TestCallback &tester, std::string &log_msg)
 
bool MaintainCheckpointsSpecific (int iteration, const std::vector< char > *train_model, const std::vector< char > *rec_model, TestCallback tester, std::string &log_msg)
 
void PrepareLogMsg (std::string &log_msg) const
 
void LogIterations (const char *intro_str, std::string &log_msg) const
 
bool TransitionTrainingStage (float error_threshold)
 
int CurrentTrainingStage () const
 
bool Serialize (SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
void StartSubtrainer (std::string &log_msg)
 
SubTrainerResult UpdateSubtrainer (std::string &log_msg)
 
void ReduceLearningRates (LSTMTrainer *samples_trainer, std::string &log_msg)
 
int ReduceLayerLearningRates (TFloat factor, int num_samples, LSTMTrainer *samples_trainer)
 
bool EncodeString (const std::string &str, std::vector< int > *labels) const
 
const ImageData * TrainOnLine (LSTMTrainer *samples_trainer, bool batch)
 
Trainability TrainOnLine (const ImageData *trainingdata, bool batch)
 
Trainability PrepareForBackward (const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
 
bool SaveTrainingDump (SerializeAmount serialize_amount, const LSTMTrainer &trainer, std::vector< char > *data) const
 
bool ReadTrainingDump (const std::vector< char > &data, LSTMTrainer &trainer) const
 
bool ReadSizedTrainingDump (const char *data, int size, LSTMTrainer &trainer) const
 
bool ReadLocalTrainingDump (const TessdataManager *mgr, const char *data, int size)
 
void SetupCheckpointInfo ()
 
bool SaveTraineddata (const char *filename)
 
void SaveRecognitionDump (std::vector< char > *data) const
 
std::string DumpFilename () const
 
void FillErrorBuffer (double new_error, ErrorTypes type)
 
std::vector< int > MapRecoder (const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
 
- Public Member Functions inherited from tesseract::LSTMRecognizer
 LSTMRecognizer ()
 
 LSTMRecognizer (const std::string &language_data_path_prefix)
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
float learning_rate () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
bool IsTensorFlow () const
 
std::vector< std::string > EnumerateLayers () const
 
Network * GetLayer (const std::string &id) const
 
float GetLayerLearningRate (const std::string &id) const
 
const char * GetNetwork () const
 
float GetAdamBeta () const
 
float GetMomentum () const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const std::string &id, double factor)
 
void SetLearningRate (float learning_rate)
 
void SetLayerLearningRate (const std::string &id, float learning_rate)
 
void ConvertToInt ()
 
const UNICHARSET & GetUnicharset () const
 
UNICHARSET & GetUnicharset ()
 
const UnicharCompress & GetRecoder () const
 
const Dict * GetDict () const
 
Dict * GetDict ()
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Load (const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
 
bool Serialize (const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
bool LoadCharsets (const TessdataManager *mgr)
 
bool LoadRecoder (TFile *fp)
 
bool LoadDictionary (const ParamsVectors *params, const std::string &lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, float invert_threshold, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
std::string DecodeLabels (const std::vector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
 
void LabelsFromOutputs (const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
 

Static Public Member Functions

static bool EncodeString (const std::string &str, const UNICHARSET &unicharset, const UnicharCompress *recoder, bool simple_text, int null_char, std::vector< int > *labels)
 

Protected Member Functions

void InitCharSet ()
 
void SetNullChar ()
 
void EmptyConstructor ()
 
bool DebugLSTMTraining (const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const std::vector< int > &truth_labels, const NetworkIO &outputs)
 
void DisplayTargets (const NetworkIO &targets, const char *window_name, ScrollView **window)
 
bool ComputeTextTargets (const NetworkIO &outputs, const std::vector< int > &truth_labels, NetworkIO *targets)
 
bool ComputeCTCTargets (const std::vector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
 
double ComputeErrorRates (const NetworkIO &deltas, double char_error, double word_error)
 
double ComputeRMSError (const NetworkIO &deltas)
 
double ComputeWinnerError (const NetworkIO &deltas)
 
double ComputeCharError (const std::vector< int > &truth_str, const std::vector< int > &ocr_str)
 
double ComputeWordError (std::string *truth_str, std::string *ocr_str)
 
void UpdateErrorBuffer (double new_error, ErrorTypes type)
 
void RollErrorBuffers ()
 
std::string UpdateErrorGraph (int iteration, double error_rate, const std::vector< char > &model_data, const TestCallback &tester)
 
- Protected Member Functions inherited from tesseract::LSTMRecognizer
void SetRandomSeed ()
 
void DisplayLSTMOutput (const std::vector< int > &labels, const std::vector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsViaReEncode (const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, std::vector< int > *labels, std::vector< int > *xcoords)
 
const char * DecodeLabel (const std::vector< int > &labels, unsigned start, unsigned *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

ScrollView * align_win_
 
ScrollView * target_win_
 
ScrollView * ctc_win_
 
ScrollView * recon_win_
 
int debug_interval_
 
int checkpoint_iteration_
 
std::string model_base_
 
std::string checkpoint_name_
 
bool randomly_rotate_
 
DocumentCache training_data_
 
std::string best_model_name_
 
int num_training_stages_
 
double best_error_rate_
 
double best_error_rates_ [ET_COUNT]
 
int best_iteration_
 
double worst_error_rate_
 
double worst_error_rates_ [ET_COUNT]
 
int worst_iteration_
 
int stall_iteration_
 
std::vector< char > best_model_data_
 
std::vector< char > worst_model_data_
 
std::vector< char > best_trainer_
 
std::unique_ptr< LSTMTrainer > sub_trainer_
 
float error_rate_of_last_saved_best_
 
int training_stage_
 
std::vector< double > best_error_history_
 
std::vector< int32_t > best_error_iterations_
 
int32_t improvement_steps_
 
int learning_iteration_
 
int prev_sample_iteration_
 
int perfect_delay_
 
int last_perfect_training_iteration_
 
std::vector< double > error_buffers_ [ET_COUNT]
 
double error_rates_ [ET_COUNT]
 
TessdataManager mgr_
 
- Protected Attributes inherited from tesseract::LSTMRecognizer
Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
std::string network_str_
 
int32_t training_flags_
 
int32_t training_iteration_
 
int32_t sample_iteration_
 
int32_t null_char_
 
float learning_rate_
 
float momentum_
 
float adam_beta_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Static Protected Attributes

static const int kRollingBufferSize_ = 1000
 

Detailed Description

Definition at line 83 of file lstmtrainer.h.

Constructor & Destructor Documentation

◆ LSTMTrainer() [1/2]

tesseract::LSTMTrainer::LSTMTrainer ( )

Definition at line 75 of file lstmtrainer.cpp.

76 : randomly_rotate_(false), training_data_(0), sub_trainer_(nullptr) {
77 EmptyConstructor();
78 debug_interval_ = 0;
79}
DocumentCache training_data_
Definition: lstmtrainer.h:425
std::unique_ptr< LSTMTrainer > sub_trainer_
Definition: lstmtrainer.h:454

◆ LSTMTrainer() [2/2]

tesseract::LSTMTrainer::LSTMTrainer ( const char *  model_base,
const char *  checkpoint_name,
int  debug_interval,
int64_t  max_memory 
)

Definition at line 81 of file lstmtrainer.cpp.

83 : randomly_rotate_(false),
84 training_data_(max_memory),
85 sub_trainer_(nullptr) {
86 EmptyConstructor();
87 debug_interval_ = debug_interval;
88 model_base_ = model_base;
89 checkpoint_name_ = checkpoint_name;
90}
std::string model_base_
Definition: lstmtrainer.h:420
std::string checkpoint_name_
Definition: lstmtrainer.h:422

◆ ~LSTMTrainer()

tesseract::LSTMTrainer::~LSTMTrainer ( )
virtual

Definition at line 92 of file lstmtrainer.cpp.

92 {
93#ifndef GRAPHICS_DISABLED
94 delete align_win_;
95 delete target_win_;
96 delete ctc_win_;
97 delete recon_win_;
98#endif
99}
ScrollView * target_win_
Definition: lstmtrainer.h:409
ScrollView * recon_win_
Definition: lstmtrainer.h:413
ScrollView * ctc_win_
Definition: lstmtrainer.h:411
ScrollView * align_win_
Definition: lstmtrainer.h:407

Member Function Documentation

◆ ActivationError()

double tesseract::LSTMTrainer::ActivationError ( ) const
inline

Definition at line 129 of file lstmtrainer.h.

129 {
130 return error_rates_[ET_DELTA];
131 }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:485

◆ best_error_rate()

double tesseract::LSTMTrainer::best_error_rate ( ) const
inline

Definition at line 138 of file lstmtrainer.h.

138 {
139 return best_error_rate_;
140 }

◆ best_iteration()

int tesseract::LSTMTrainer::best_iteration ( ) const
inline

Definition at line 141 of file lstmtrainer.h.

141 {
142 return best_iteration_;
143 }

◆ best_trainer()

const std::vector< char > & tesseract::LSTMTrainer::best_trainer ( ) const
inline

Definition at line 153 of file lstmtrainer.h.

153 {
154 return best_trainer_;
155 }
std::vector< char > best_trainer_
Definition: lstmtrainer.h:451

◆ CharError()

double tesseract::LSTMTrainer::CharError ( ) const
inline

Definition at line 132 of file lstmtrainer.h.

132 {
133 return error_rates_[ET_CHAR_ERROR];
134 }
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:44

◆ ComputeCharError()

double tesseract::LSTMTrainer::ComputeCharError ( const std::vector< int > &  truth_str,
const std::vector< int > &  ocr_str 
)
protected

Definition at line 1324 of file lstmtrainer.cpp.

1325 {
1326 std::vector<int> label_counts(NumOutputs());
1327 unsigned truth_size = 0;
1328 for (auto ch : truth_str) {
1329 if (ch != null_char_) {
1330 ++label_counts[ch];
1331 ++truth_size;
1332 }
1333 }
1334 for (auto ch : ocr_str) {
1335 if (ch != null_char_) {
1336 --label_counts[ch];
1337 }
1338 }
1339 unsigned char_errors = 0;
1340 for (auto label_count : label_counts) {
1341 char_errors += abs(label_count);
1342 }
1343 // Limit BCER to interval [0,1] and avoid division by zero.
1344 if (truth_size <= char_errors) {
1345 return (char_errors == 0) ? 0.0 : 1.0;
1346 }
1347 return static_cast<double>(char_errors) / truth_size;
1348}

◆ ComputeCTCTargets()

bool tesseract::LSTMTrainer::ComputeCTCTargets ( const std::vector< int > &  truth_labels,
NetworkIO *  outputs,
NetworkIO *  targets 
)
protected

Definition at line 1255 of file lstmtrainer.cpp.

1256 {
1257 // Bottom-clip outputs to a minimum probability.
1258 CTC::NormalizeProbs(outputs);
1259 return CTC::ComputeCTCTargets(truth_labels, null_char_,
1260 outputs->float_array(), targets);
1261}
static bool ComputeCTCTargets(const std::vector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:53
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36

◆ ComputeErrorRates()

double tesseract::LSTMTrainer::ComputeErrorRates ( const NetworkIO &  deltas,
double  char_error,
double  word_error 
)
protected

Definition at line 1266 of file lstmtrainer.cpp.

1267 {
1268 UpdateErrorBuffer(ComputeRMSError(deltas), ET_RMS);
1269 // Delta error is the fraction of timesteps with >0.5 error in the top choice
1270 // score. If zero, then the top choice characters are guaranteed correct,
1271 // even when there is residue in the RMS error.
1272 double delta_error = ComputeWinnerError(deltas);
1273 UpdateErrorBuffer(delta_error, ET_DELTA);
1274 UpdateErrorBuffer(word_error, ET_WORD_RECERR);
1275 UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
1276 // Skip ratio measures the difference between sample_iteration_ and
1277 // training_iteration_, which reflects the number of unusable samples,
1278 // usually due to unencodable truth text, or the text not fitting in the
1279 // space for the output.
1280 double skip_count = sample_iteration_ - prev_sample_iteration_;
1281 UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
1282 return delta_error;
1283}
@ ET_WORD_RECERR
Definition: lstmtrainer.h:43
@ ET_SKIP_RATIO
Definition: lstmtrainer.h:45
double ComputeRMSError(const NetworkIO &deltas)
double ComputeWinnerError(const NetworkIO &deltas)
void UpdateErrorBuffer(double new_error, ErrorTypes type)

◆ ComputeRMSError()

double tesseract::LSTMTrainer::ComputeRMSError ( const NetworkIO &  deltas)
protected

Definition at line 1286 of file lstmtrainer.cpp.

1286 {
1287 double total_error = 0.0;
1288 int width = deltas.Width();
1289 int num_classes = deltas.NumFeatures();
1290 for (int t = 0; t < width; ++t) {
1291 const float *class_errs = deltas.f(t);
1292 for (int c = 0; c < num_classes; ++c) {
1293 double error = class_errs[c];
1294 total_error += error * error;
1295 }
1296 }
1297 return sqrt(total_error / (width * num_classes));
1298}

◆ ComputeTextTargets()

bool tesseract::LSTMTrainer::ComputeTextTargets ( const NetworkIO &  outputs,
const std::vector< int > &  truth_labels,
NetworkIO *  targets 
)
protected

Definition at line 1233 of file lstmtrainer.cpp.

1235 {
1236 if (truth_labels.size() > targets->Width()) {
1237 tprintf("Error: transcription %s too long to fit into target of width %d\n",
1238 DecodeLabels(truth_labels).c_str(), targets->Width());
1239 return false;
1240 }
1241 int i = 0;
1242 for (auto truth_label : truth_labels) {
1243 targets->SetActivations(i, truth_label, 1.0);
1244 ++i;
1245 }
1246 for (i = truth_labels.size(); i < targets->Width(); ++i) {
1247 targets->SetActivations(i, null_char_, 1.0);
1248 }
1249 return true;
1250}
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
std::string DecodeLabels(const std::vector< int > &labels)

◆ ComputeWinnerError()

double tesseract::LSTMTrainer::ComputeWinnerError ( const NetworkIO &  deltas)
protected

Definition at line 1305 of file lstmtrainer.cpp.

1305 {
1306 int num_errors = 0;
1307 int width = deltas.Width();
1308 int num_classes = deltas.NumFeatures();
1309 for (int t = 0; t < width; ++t) {
1310 const float *class_errs = deltas.f(t);
1311 for (int c = 0; c < num_classes; ++c) {
1312 float abs_delta = std::fabs(class_errs[c]);
1313 // TODO(rays) Filtering cases where the delta is very large to cut out
1314 // GT errors doesn't work. Find a better way or get better truth.
1315 if (0.5 <= abs_delta) {
1316 ++num_errors;
1317 }
1318 }
1319 }
1320 return static_cast<double>(num_errors) / width;
1321}

◆ ComputeWordError()

double tesseract::LSTMTrainer::ComputeWordError ( std::string *  truth_str,
std::string *  ocr_str 
)
protected

Definition at line 1352 of file lstmtrainer.cpp.

1353 {
1354 using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
1355 std::vector<std::string> truth_words = split(*truth_str, ' ');
1356 if (truth_words.empty()) {
1357 return 0.0;
1358 }
1359 std::vector<std::string> ocr_words = split(*ocr_str, ' ');
1360 StrMap word_counts;
1361 for (const auto &truth_word : truth_words) {
1362 std::string truth_word_string(truth_word.c_str());
1363 auto it = word_counts.find(truth_word_string);
1364 if (it == word_counts.end()) {
1365 word_counts.insert(std::make_pair(truth_word_string, 1));
1366 } else {
1367 ++it->second;
1368 }
1369 }
1370 for (const auto &ocr_word : ocr_words) {
1371 std::string ocr_word_string(ocr_word.c_str());
1372 auto it = word_counts.find(ocr_word_string);
1373 if (it == word_counts.end()) {
1374 word_counts.insert(std::make_pair(ocr_word_string, -1));
1375 } else {
1376 --it->second;
1377 }
1378 }
1379 int word_recall_errs = 0;
1380 for (const auto &word_count : word_counts) {
1381 if (word_count.second > 0) {
1382 word_recall_errs += word_count.second;
1383 }
1384 }
1385 return static_cast<double>(word_recall_errs) / truth_words.size();
1386}
const std::vector< std::string > split(const std::string &s, char c)
Definition: helpers.h:41

◆ CurrentTrainingStage()

int tesseract::LSTMTrainer::CurrentTrainingStage ( ) const
inline

Definition at line 216 of file lstmtrainer.h.

216 {
217 return training_stage_;
218 }

◆ DebugLSTMTraining()

bool tesseract::LSTMTrainer::DebugLSTMTraining ( const NetworkIO &  inputs,
const ImageData &  trainingdata,
const NetworkIO &  fwd_outputs,
const std::vector< int > &  truth_labels,
const NetworkIO &  outputs 
)
protected

Definition at line 1155 of file lstmtrainer.cpp.

1159 {
1160 const std::string &truth_text = DecodeLabels(truth_labels);
1161 if (truth_text.c_str() == nullptr || truth_text.length() <= 0) {
1162 tprintf("Empty truth string at decode time!\n");
1163 return false;
1164 }
1165 if (debug_interval_ != 0) {
1166 // Get class labels, xcoords and string.
1167 std::vector<int> labels;
1168 std::vector<int> xcoords;
1169 LabelsFromOutputs(outputs, &labels, &xcoords);
1170 std::string text = DecodeLabels(labels);
1171 tprintf("Iteration %d: GROUND TRUTH : %s\n", training_iteration(),
1172 truth_text.c_str());
1173 if (truth_text != text) {
1174 tprintf("Iteration %d: ALIGNED TRUTH : %s\n", training_iteration(),
1175 text.c_str());
1176 }
1177 if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1178 tprintf("TRAINING activation path for truth string %s\n",
1179 truth_text.c_str());
1180 DebugActivationPath(outputs, labels, xcoords);
1181#ifndef GRAPHICS_DISABLED
1182 DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1183 if (OutputLossType() == LT_CTC) {
1184 DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1185 DisplayTargets(outputs, "CTC Targets", &target_win_);
1186 }
1187#endif
1188 }
1189 }
1190 return true;
1191}
void DebugActivationPath(const NetworkIO &outputs, const std::vector< int > &labels, const std::vector< int > &xcoords)
LossType OutputLossType() const
void LabelsFromOutputs(const NetworkIO &outputs, std::vector< int > *labels, std::vector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const std::vector< int > &labels, const std::vector< int > &label_coords, const char *window_name, ScrollView **window)
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)

◆ DebugNetwork()

void tesseract::LSTMTrainer::DebugNetwork ( )

Definition at line 287 of file lstmtrainer.cpp.

287 {
288 network_->DebugWeights();
289}
virtual void DebugWeights()=0

◆ DeSerialize()

bool tesseract::LSTMTrainer::DeSerialize ( const TessdataManager *  mgr,
TFile *  fp 
)

Definition at line 511 of file lstmtrainer.cpp.

511 {
512 if (!LSTMRecognizer::DeSerialize(mgr, fp)) {
513 return false;
514 }
515 if (!fp->DeSerialize(&learning_iteration_)) {
516 // Special case. If we successfully decoded the recognizer, but fail here
517 // then it means we were just given a recognizer, so issue a warning and
518 // allow it.
519 tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
520 learning_iteration_ = 0;
521 network_->SetEnableTraining(TS_ENABLED);
522 return true;
523 }
524 if (!fp->DeSerialize(&prev_sample_iteration_)) {
525 return false;
526 }
527 if (!fp->DeSerialize(&perfect_delay_)) {
528 return false;
529 }
530 if (!fp->DeSerialize(&last_perfect_training_iteration_)) {
531 return false;
532 }
533 for (auto &error_buffer : error_buffers_) {
534 if (!fp->DeSerialize(error_buffer)) {
535 return false;
536 }
537 }
538 if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) {
539 return false;
540 }
541 if (!fp->DeSerialize(&training_stage_)) {
542 return false;
543 }
544 uint8_t amount;
545 if (!fp->DeSerialize(&amount)) {
546 return false;
547 }
548 if (amount == LIGHT) {
549 return true; // Don't read the rest.
550 }
551 if (!fp->DeSerialize(&best_error_rate_)) {
552 return false;
553 }
554 if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) {
555 return false;
556 }
557 if (!fp->DeSerialize(&best_iteration_)) {
558 return false;
559 }
560 if (!fp->DeSerialize(&worst_error_rate_)) {
561 return false;
562 }
563 if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
564 return false;
565 }
566 if (!fp->DeSerialize(&worst_iteration_)) {
567 return false;
568 }
569 if (!fp->DeSerialize(&stall_iteration_)) {
570 return false;
571 }
572 if (!fp->DeSerialize(best_model_data_)) {
573 return false;
574 }
575 if (!fp->DeSerialize(worst_model_data_)) {
576 return false;
577 }
578 if (amount != NO_BEST_TRAINER && !fp->DeSerialize(best_trainer_)) {
579 return false;
580 }
581 std::vector<char> sub_data;
582 if (!fp->DeSerialize(sub_data)) {
583 return false;
584 }
585 if (sub_data.empty()) {
586 sub_trainer_ = nullptr;
587 } else {
588 sub_trainer_ = std::make_unique<LSTMTrainer>();
589 if (!ReadTrainingDump(sub_data, *sub_trainer_)) {
590 return false;
591 }
592 }
593 if (!fp->DeSerialize(best_error_history_)) {
594 return false;
595 }
596 if (!fp->DeSerialize(best_error_iterations_)) {
597 return false;
598 }
599 return fp->DeSerialize(&improvement_steps_);
600}
constexpr size_t countof(T const (&)[N]) noexcept
Definition: serialis.h:34
@ TS_ENABLED
Definition: network.h:93
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:61
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:113
std::vector< int32_t > best_error_iterations_
Definition: lstmtrainer.h:462
std::vector< char > worst_model_data_
Definition: lstmtrainer.h:449
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:442
std::vector< char > best_model_data_
Definition: lstmtrainer.h:448
std::vector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:483
bool ReadTrainingDump(const std::vector< char > &data, LSTMTrainer &trainer) const
Definition: lstmtrainer.h:299
std::vector< double > best_error_history_
Definition: lstmtrainer.h:461
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:436

◆ DisplayTargets()

void tesseract::LSTMTrainer::DisplayTargets ( const NetworkIO &  targets,
const char *  window_name,
ScrollView **  window 
)
protected

Definition at line 1196 of file lstmtrainer.cpp.

1197 {
1198 int width = targets.Width();
1199 int num_features = targets.NumFeatures();
1200 Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
1201 window);
1202 for (int c = 0; c < num_features; ++c) {
1203 int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
1204 (*window)->Pen(static_cast<ScrollView::Color>(color));
1205 int start_t = -1;
1206 for (int t = 0; t < width; ++t) {
1207 double target = targets.f(t)[c];
1208 target *= kTargetYScale;
1209 if (target >= 1) {
1210 if (start_t < 0) {
1211 (*window)->SetCursor(t - 1, 0);
1212 start_t = t;
1213 }
1214 (*window)->DrawTo(t, target);
1215 } else if (start_t >= 0) {
1216 (*window)->DrawTo(t, 0);
1217 (*window)->DrawTo(start_t - 1, 0);
1218 start_t = -1;
1219 }
1220 }
1221 if (start_t >= 0) {
1222 (*window)->DrawTo(width, 0);
1223 (*window)->DrawTo(start_t - 1, 0);
1224 }
1225 }
1226 (*window)->Update();
1227}
const int kTargetYScale
Definition: lstmtrainer.cpp:72
const int kTargetXScale
Definition: lstmtrainer.cpp:71
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:350

◆ DumpFilename()

std::string tesseract::LSTMTrainer::DumpFilename ( ) const

Definition at line 1055 of file lstmtrainer.cpp.

1055 {
1056 std::string filename;
1057 filename += model_base_.c_str();
1058 filename += "_" + std::to_string(best_error_rate_);
1059 filename += "_" + std::to_string(best_iteration_);
1060 filename += "_" + std::to_string(training_iteration_);
1061 filename += ".checkpoint";
1062 return filename;
1063}

◆ EmptyConstructor()

void tesseract::LSTMTrainer::EmptyConstructor ( )
protected

Definition at line 1138 of file lstmtrainer.cpp.

1138 {
1139#ifndef GRAPHICS_DISABLED
1140 align_win_ = nullptr;
1141 target_win_ = nullptr;
1142 ctc_win_ = nullptr;
1143 recon_win_ = nullptr;
1144#endif
1145 checkpoint_iteration_ = 0;
1146 training_stage_ = 0;
1147 num_training_stages_ = 2;
1148 InitIterations();
1149}

◆ EncodeString() [1/2]

bool tesseract::LSTMTrainer::EncodeString ( const std::string &  str,
const UNICHARSET &  unicharset,
const UnicharCompress *  recoder,
bool  simple_text,
int  null_char,
std::vector< int > *  labels 
)
static

Definition at line 815 of file lstmtrainer.cpp.

818 {
819 if (str.c_str() == nullptr || str.length() <= 0) {
820 tprintf("Empty truth string!\n");
821 return false;
822 }
823 unsigned err_index;
824 std::vector<int> internal_labels;
825 labels->clear();
826 if (!simple_text) {
827 labels->push_back(null_char);
828 }
829 std::string cleaned = unicharset.CleanupString(str.c_str());
830 if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
831 &err_index)) {
832 bool success = true;
833 for (auto internal_label : internal_labels) {
834 if (recoder != nullptr) {
835 // Re-encode labels via recoder.
836 RecodedCharID code;
837 int len = recoder->EncodeUnichar(internal_label, &code);
838 if (len > 0) {
839 for (int j = 0; j < len; ++j) {
840 labels->push_back(code(j));
841 if (!simple_text) {
842 labels->push_back(null_char);
843 }
844 }
845 } else {
846 success = false;
847 err_index = 0;
848 break;
849 }
850 } else {
851 labels->push_back(internal_label);
852 if (!simple_text) {
853 labels->push_back(null_char);
854 }
855 }
856 }
857 if (success) {
858 return true;
859 }
860 }
861 tprintf("Encoding of string failed! Failure bytes:");
862 while (err_index < cleaned.size()) {
863 tprintf(" %x", cleaned[err_index++] & 0xff);
864 }
865 tprintf("\n");
866 return false;
867}

◆ EncodeString() [2/2]

bool tesseract::LSTMTrainer::EncodeString ( const std::string &  str,
std::vector< int > *  labels 
) const
inline

Definition at line 253 of file lstmtrainer.h.

253 {
254 return EncodeString(str, GetUnicharset(),
255 IsRecoding() ? &recoder_ : nullptr, SimpleTextOutput(),
256 null_char_, labels);
257 }
const UNICHARSET & GetUnicharset() const
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:253

◆ error_rates()

const double * tesseract::LSTMTrainer::error_rates ( ) const
inline

Definition at line 135 of file lstmtrainer.h.

135 {
136 return error_rates_;
137 }

◆ FillErrorBuffer()

void tesseract::LSTMTrainer::FillErrorBuffer ( double  new_error,
ErrorTypes  type 
)

Definition at line 1066 of file lstmtrainer.cpp.

1066 {
1067 for (int i = 0; i < kRollingBufferSize_; ++i) {
1068 error_buffers_[type][i] = new_error;
1069 }
1070 error_rates_[type] = 100.0 * new_error;
1071}
static const int kRollingBufferSize_
Definition: lstmtrainer.h:482

◆ GridSearchDictParams()

Trainability tesseract::LSTMTrainer::GridSearchDictParams ( const ImageData *  trainingdata,
int  iteration,
double  min_dict_ratio,
double  dict_ratio_step,
double  max_dict_ratio,
double  min_cert_offset,
double  cert_offset_step,
double  max_cert_offset,
std::string &  results 
)

Definition at line 234 of file lstmtrainer.cpp.

237 {
238 sample_iteration_ = iteration;
239 NetworkIO fwd_outputs, targets;
240 Trainability result =
241 PrepareForBackward(trainingdata, &fwd_outputs, &targets);
242 if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr) {
243 return result;
244 }
245
246 // Encode/decode the truth to get the normalization.
247 std::vector<int> truth_labels, ocr_labels, xcoords;
248 ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
249 // NO-dict error.
250 RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(),
251 nullptr);
252 base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
253 nullptr);
254 base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
255 std::string truth_text = DecodeLabels(truth_labels);
256 std::string ocr_text = DecodeLabels(ocr_labels);
257 double baseline_error = ComputeWordError(&truth_text, &ocr_text);
258 results += "0,0=" + std::to_string(baseline_error);
259
260 RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
261 for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
262 for (double c = min_cert_offset; c < max_cert_offset;
263 c += cert_offset_step) {
264 search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty,
265 nullptr);
266 search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
267 truth_text = DecodeLabels(truth_labels);
268 ocr_text = DecodeLabels(ocr_labels);
269 // This is destructive on both strings.
270 double word_error = ComputeWordError(&truth_text, &ocr_text);
271 if ((r == min_dict_ratio && c == min_cert_offset) ||
272 !std::isfinite(word_error)) {
273 std::string t = DecodeLabels(truth_labels);
274 std::string o = DecodeLabels(ocr_labels);
275 tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
276 t.c_str(), o.c_str(), word_error, truth_labels[0]);
277 }
278 results += " " + std::to_string(r);
279 results += "," + std::to_string(c);
280 results += "=" + std::to_string(word_error);
281 }
282 }
283 return result;
284}
#define ASSERT_HOST(x)
Definition: errcode.h:54
@ HI_PRECISION_ERR
Definition: lstmtrainer.h:54
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:211
static constexpr float kMinCertainty
Definition: recodebeam.h:243
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
double ComputeWordError(std::string *truth_str, std::string *ocr_str)

◆ improvement_steps()

int32_t tesseract::LSTMTrainer::improvement_steps ( ) const
inline

Definition at line 147 of file lstmtrainer.h.

147 {
148 return improvement_steps_;
149 }

◆ InitCharSet() [1/3]

void tesseract::LSTMTrainer::InitCharSet ( )
protected

Definition at line 1116 of file lstmtrainer.cpp.

1116 {
1119 // Initialize the unicharset and recoder.
1120 if (!LoadCharsets(&mgr_)) {
1122 "Must provide a traineddata containing lstm_unicharset and"
1123 " lstm_recoder!\n" != nullptr);
1124 }
1125 SetNullChar();
1126}
@ TF_COMPRESS_UNICHARSET
bool LoadCharsets(const TessdataManager *mgr)
TessdataManager mgr_
Definition: lstmtrainer.h:487

◆ InitCharSet() [2/3]

bool tesseract::LSTMTrainer::InitCharSet ( const std::string &  traineddata_path)
inline

Definition at line 99 of file lstmtrainer.h.

99 {
100 bool success = mgr_.Init(traineddata_path.c_str());
101 if (success) {
102 InitCharSet();
103 }
104 return success;
105 }
bool Init(const char *data_file_name)

◆ InitCharSet() [3/3]

void tesseract::LSTMTrainer::InitCharSet ( const TessdataManager &  mgr)
inline

Definition at line 106 of file lstmtrainer.h.

106 {
107 mgr_ = mgr;
108 InitCharSet();
109 }

◆ InitIterations()

void tesseract::LSTMTrainer::InitIterations ( )

Definition at line 206 of file lstmtrainer.cpp.

206 {
211 best_error_rate_ = 100.0;
212 best_iteration_ = 0;
213 worst_error_rate_ = 0.0;
216 best_error_history_.clear();
219 perfect_delay_ = 0;
221 for (int i = 0; i < ET_COUNT; ++i) {
222 best_error_rates_[i] = 100.0;
223 worst_error_rates_[i] = 0.0;
224 error_buffers_[i].clear();
226 error_rates_[i] = 100.0;
227 }
229}
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:60
const int kMinStallIterations
Definition: lstmtrainer.cpp:47
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:456

◆ InitNetwork()

bool tesseract::LSTMTrainer::InitNetwork ( const char *  network_spec,
int  append_index,
int  net_flags,
float  weight_range,
float  learning_rate,
float  momentum,
float  adam_beta 
)

Definition at line 162 of file lstmtrainer.cpp.

165 {
166 mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec);
167 adam_beta_ = adam_beta;
169 momentum_ = momentum;
170 SetNullChar();
172 append_index, net_flags, weight_range,
173 &randomizer_, &network_)) {
174 return false;
175 }
176 network_str_ += network_spec;
177 tprintf("Built network:%s from request %s\n", network_->spec().c_str(),
178 network_spec);
179 tprintf(
180 "Training parameters:\n Debug interval = %d,"
181 " weights = %g, learning rate = %g, momentum=%g\n",
183 tprintf("null char=%d\n", null_char_);
184 return true;
185}
std::string VersionString() const
void SetVersionString(const std::string &v_str)
virtual std::string spec() const
Definition: network.h:143
static bool InitNetwork(int num_outputs, const char *network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)

◆ InitTensorFlowNetwork()

int tesseract::LSTMTrainer::InitTensorFlowNetwork ( const std::string &  tf_proto)

◆ LastSingleError()

double tesseract::LSTMTrainer::LastSingleError ( ErrorTypes  type) const
inline

Definition at line 163 of file lstmtrainer.h.

163 {
164 return error_buffers_[type]
167 }

◆ learning_iteration()

int tesseract::LSTMTrainer::learning_iteration ( ) const
inline

Definition at line 144 of file lstmtrainer.h.

144 {
145 return learning_iteration_;
146 }

◆ LoadAllTrainingData()

bool tesseract::LSTMTrainer::LoadAllTrainingData ( const std::vector< std::string > &  filenames,
CachingStrategy  cache_strategy,
bool  randomly_rotate 
)

Definition at line 294 of file lstmtrainer.cpp.

296 {
297 randomly_rotate_ = randomly_rotate;
299 return training_data_.LoadDocuments(filenames, cache_strategy,
301}
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
TESS_API bool LoadDocuments(const std::vector< std::string > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:614

◆ LogIterations()

void tesseract::LSTMTrainer::LogIterations ( const char *  intro_str,
std::string &  log_msg 
) const

Definition at line 407 of file lstmtrainer.cpp.

408 {
409 log_msg += intro_str;
410 log_msg += " iteration " + std::to_string(learning_iteration());
411 log_msg += "/" + std::to_string(training_iteration());
412 log_msg += "/" + std::to_string(sample_iteration());
413}
int learning_iteration() const
Definition: lstmtrainer.h:144

◆ MaintainCheckpoints()

bool tesseract::LSTMTrainer::MaintainCheckpoints ( const TestCallback &  tester,
std::string &  log_msg 
)

Definition at line 307 of file lstmtrainer.cpp.

308 {
309 PrepareLogMsg(log_msg);
310 double error_rate = CharError();
311 int iteration = learning_iteration();
312 if (iteration >= stall_iteration_ &&
313 error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
315 // It hasn't got any better in a long while, and is a margin worse than the
316 // best, so go back to the best model and try a different learning rate.
317 StartSubtrainer(log_msg);
318 }
319 SubTrainerResult sub_trainer_result = STR_NONE;
320 if (sub_trainer_ != nullptr) {
321 sub_trainer_result = UpdateSubtrainer(log_msg);
322 if (sub_trainer_result == STR_REPLACED) {
323 // Reset the inputs, as we have overwritten *this.
324 error_rate = CharError();
325 iteration = learning_iteration();
326 PrepareLogMsg(log_msg);
327 }
328 }
329 bool result = true; // Something interesting happened.
330 std::vector<char> rec_model_data;
331 if (error_rate < best_error_rate_) {
332 SaveRecognitionDump(&rec_model_data);
333 log_msg += " New best BCER = " + std::to_string(error_rate);
334 log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
335 // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
336 // just overwrote *this. In either case, we have finished with it.
337 sub_trainer_.reset();
340 log_msg +=
341 " Transitioned to stage " + std::to_string(CurrentTrainingStage());
342 }
345 std::string best_model_name = DumpFilename();
346 if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
347 log_msg += " failed to write best model:";
348 } else {
349 log_msg += " wrote best model:";
351 }
352 log_msg += best_model_name;
353 }
354 } else if (error_rate > worst_error_rate_) {
355 SaveRecognitionDump(&rec_model_data);
356 log_msg += " New worst BCER = " + std::to_string(error_rate);
357 log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
360 // Error rate has ballooned. Go back to the best model.
361 log_msg += "\nDivergence! ";
362 // Copy best_trainer_ before reading it, as it will get overwritten.
363 std::vector<char> revert_data(best_trainer_);
364 if (ReadTrainingDump(revert_data, *this)) {
365 LogIterations("Reverted to", log_msg);
366 ReduceLearningRates(this, log_msg);
367 } else {
368 LogIterations("Failed to Revert at", log_msg);
369 }
370 // If it fails again, we will wait twice as long before reverting again.
371 stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
372 // Re-save the best trainer with the new learning rates and stall
373 // iteration.
375 }
376 } else {
377 // Something interesting happened only if the sub_trainer_ was trained.
378 result = sub_trainer_result != STR_NONE;
379 }
380 if (checkpoint_name_.length() > 0) {
381 // Write a current checkpoint.
382 std::vector<char> checkpoint;
383 if (!SaveTrainingDump(FULL, *this, &checkpoint) ||
384 !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
385 log_msg += " failed to write checkpoint.";
386 } else {
387 log_msg += " wrote checkpoint.";
388 }
389 }
390 log_msg += "\n";
391 return result;
392}
@ STR_REPLACED
Definition: lstmtrainer.h:69
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:50
bool SaveDataToFile(const GenericVector< char > &data, const char *filename)
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:45
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:68
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:62
bool TransitionTrainingStage(float error_threshold)
std::string UpdateErrorGraph(int iteration, double error_rate, const std::vector< char > &model_data, const TestCallback &tester)
void ReduceLearningRates(LSTMTrainer *samples_trainer, std::string &log_msg)
double CharError() const
Definition: lstmtrainer.h:132
void PrepareLogMsg(std::string &log_msg) const
void SaveRecognitionDump(std::vector< char > *data) const
void LogIterations(const char *intro_str, std::string &log_msg) const
void StartSubtrainer(std::string &log_msg)
SubTrainerResult UpdateSubtrainer(std::string &log_msg)
int CurrentTrainingStage() const
Definition: lstmtrainer.h:216
std::string DumpFilename() const
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer &trainer, std::vector< char > *data) const

◆ MaintainCheckpointsSpecific()

bool tesseract::LSTMTrainer::MaintainCheckpointsSpecific ( int  iteration,
const std::vector< char > *  train_model,
const std::vector< char > *  rec_model,
TestCallback  tester,
std::string &  log_msg 
)

◆ MapRecoder()

std::vector< int > tesseract::LSTMTrainer::MapRecoder ( const UNICHARSET &  old_chset,
const UnicharCompress &  old_recoder 
) const

Definition at line 1075 of file lstmtrainer.cpp.

1076 {
1077 int num_new_codes = recoder_.code_range();
1078 int num_new_unichars = GetUnicharset().size();
1079 std::vector<int> code_map(num_new_codes, -1);
1080 for (int c = 0; c < num_new_codes; ++c) {
1081 int old_code = -1;
1082 // Find all new unichar_ids that recode to something that includes c.
1083 // The <= is to include the null char, which may be beyond the unicharset.
1084 for (int uid = 0; uid <= num_new_unichars; ++uid) {
1085 RecodedCharID codes;
1086 int length = recoder_.EncodeUnichar(uid, &codes);
1087 int code_index = 0;
1088 while (code_index < length && codes(code_index) != c) {
1089 ++code_index;
1090 }
1091 if (code_index == length) {
1092 continue;
1093 }
1094 // The old unicharset must have the same unichar.
1095 int old_uid =
1096 uid < num_new_unichars
1097 ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
1098 : old_chset.size() - 1;
1099 if (old_uid == INVALID_UNICHAR_ID) {
1100 continue;
1101 }
1102 // The encoding of old_uid at the same code_index is the old code.
1103 RecodedCharID old_codes;
1104 if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
1105 old_code = old_codes(code_index);
1106 break;
1107 }
1108 }
1109 code_map[c] = old_code;
1110 }
1111 return code_map;
1112}
int EncodeUnichar(unsigned unichar_id, RecodedCharID *code) const
size_t size() const
Definition: unicharset.h:355

◆ mutable_training_data()

DocumentCache * tesseract::LSTMTrainer::mutable_training_data ( )
inline

Definition at line 171 of file lstmtrainer.h.

171 {
172 return &training_data_;
173 }

◆ NewSingleError()

double tesseract::LSTMTrainer::NewSingleError ( ErrorTypes  type) const
inline

Definition at line 157 of file lstmtrainer.h.

157 {
159 }

◆ PrepareForBackward()

Trainability tesseract::LSTMTrainer::PrepareForBackward ( const ImageData *  trainingdata,
NetworkIO *  fwd_outputs,
NetworkIO *  targets 
)

Definition at line 904 of file lstmtrainer.cpp.

906 {
907 if (trainingdata == nullptr) {
908 tprintf("Null trainingdata.\n");
909 return UNENCODABLE;
910 }
911 // Ensure repeatability of random elements even across checkpoints.
912 bool debug =
914 std::vector<int> truth_labels;
915 if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
916 tprintf("Can't encode transcription: '%s' in language '%s'\n",
917 trainingdata->transcription().c_str(),
918 trainingdata->language().c_str());
919 return UNENCODABLE;
920 }
921 bool upside_down = false;
922 if (randomly_rotate_) {
923 // This ensures consistent training results.
925 upside_down = randomizer_.SignedRand(1.0) > 0.0;
926 if (upside_down) {
927 // Modify the truth labels to match the rotation:
928 // Apart from space and null, increment the label. This changes the
929 // script-id to the same script-id but upside-down.
930 // The labels need to be reversed in order, as the first is now the last.
931 for (auto truth_label : truth_labels) {
932 if (truth_label != UNICHAR_SPACE && truth_label != null_char_) {
933 ++truth_label;
934 }
935 }
936 std::reverse(truth_labels.begin(), truth_labels.end());
937 }
938 }
939 unsigned w = 0;
940 while (w < truth_labels.size() &&
941 (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_)) {
942 ++w;
943 }
944 if (w == truth_labels.size()) {
945 tprintf("Blank transcription: %s\n", trainingdata->transcription().c_str());
946 return UNENCODABLE;
947 }
948 float image_scale;
949 NetworkIO inputs;
950 bool invert = trainingdata->boxes().empty();
951 if (!RecognizeLine(*trainingdata, invert ? 0.5f : 0.0f, debug, invert, upside_down,
952 &image_scale, &inputs, fwd_outputs)) {
953 tprintf("Image %s not trainable\n", trainingdata->imagefilename().c_str());
954 return UNENCODABLE;
955 }
956 targets->Resize(*fwd_outputs, network_->NumOutputs());
957 LossType loss_type = OutputLossType();
958 if (loss_type == LT_SOFTMAX) {
959 if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
960 tprintf("Compute simple targets failed for %s!\n",
961 trainingdata->imagefilename().c_str());
962 return UNENCODABLE;
963 }
964 } else if (loss_type == LT_CTC) {
965 if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
966 tprintf("Compute CTC targets failed for %s!\n",
967 trainingdata->imagefilename().c_str());
968 return UNENCODABLE;
969 }
970 } else {
971 tprintf("Logistic outputs not implemented yet!\n");
972 return UNENCODABLE;
973 }
974 std::vector<int> ocr_labels;
975 std::vector<int> xcoords;
976 LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
977 // CTC does not produce correct target labels to begin with.
978 if (loss_type != LT_CTC) {
979 LabelsFromOutputs(*targets, &truth_labels, &xcoords);
980 }
981 if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
982 *targets)) {
983 tprintf("Input width was %d\n", inputs.Width());
984 return UNENCODABLE;
985 }
986 std::string ocr_text = DecodeLabels(ocr_labels);
987 std::string truth_text = DecodeLabels(truth_labels);
988 targets->SubtractAllFromFloat(*fwd_outputs);
989 if (debug_interval_ != 0) {
990 if (truth_text != ocr_text) {
991 tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(),
992 ocr_text.c_str());
993 }
994 }
995 double char_error = ComputeCharError(truth_labels, ocr_labels);
996 double word_error = ComputeWordError(&truth_text, &ocr_text);
997 double delta_error = ComputeErrorRates(*targets, char_error, word_error);
998 if (debug_interval_ != 0) {
999 tprintf("File %s line %d %s:\n", trainingdata->imagefilename().c_str(),
1000 trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
1001 }
1002 if (delta_error == 0.0) {
1003 return PERFECT;
1004 }
1005 if (targets->AnySuspiciousTruth(kHighConfidence)) {
1006 return HI_PRECISION_ERR;
1007 }
1008 return TRAINABLE;
1009}
@ UNICHAR_SPACE
Definition: unicharset.h:36
const double kHighConfidence
Definition: lstmtrainer.cpp:64
double SignedRand(double range)
Definition: helpers.h:76
void RecognizeLine(const ImageData &image_data, float invert_threshold, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
int NumOutputs() const
Definition: network.h:125
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
bool ComputeCTCTargets(const std::vector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
bool ComputeTextTargets(const NetworkIO &outputs, const std::vector< int > &truth_labels, NetworkIO *targets)
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const std::vector< int > &truth_labels, const NetworkIO &outputs)
double ComputeCharError(const std::vector< int > &truth_str, const std::vector< int > &ocr_str)

◆ PrepareLogMsg()

void tesseract::LSTMTrainer::PrepareLogMsg ( std::string &  log_msg) const

Definition at line 395 of file lstmtrainer.cpp.

395 {
396 LogIterations("At", log_msg);
397 log_msg += ", Mean rms=" + std::to_string(error_rates_[ET_RMS]);
398 log_msg += "%, delta=" + std::to_string(error_rates_[ET_DELTA]);
399 log_msg += "%, BCER train=" + std::to_string(error_rates_[ET_CHAR_ERROR]);
400 log_msg += "%, BWER train=" + std::to_string(error_rates_[ET_WORD_RECERR]);
401 log_msg += "%, skip ratio=" + std::to_string(error_rates_[ET_SKIP_RATIO]);
402 log_msg += "%, ";
403}

◆ ReadLocalTrainingDump()

bool tesseract::LSTMTrainer::ReadLocalTrainingDump ( const TessdataManager *  mgr,
const char *  data,
int  size 
)

Definition at line 1024 of file lstmtrainer.cpp.

1025 {
1026 if (size == 0) {
1027 tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
1028 return false;
1029 }
1030 TFile fp;
1031 fp.Open(data, size);
1032 return DeSerialize(mgr, &fp);
1033}
bool DeSerialize(const TessdataManager *mgr, TFile *fp)

◆ ReadSizedTrainingDump()

bool tesseract::LSTMTrainer::ReadSizedTrainingDump ( const char *  data,
int  size,
LSTMTrainer &  trainer 
) const
inline

Definition at line 306 of file lstmtrainer.h.

307 {
308 return trainer.ReadLocalTrainingDump(&mgr_, data, size);
309 }

◆ ReadTrainingDump()

bool tesseract::LSTMTrainer::ReadTrainingDump ( const std::vector< char > &  data,
LSTMTrainer &  trainer 
) const
inline

Definition at line 299 of file lstmtrainer.h.

300 {
301 if (data.empty()) {
302 return false;
303 }
304 return ReadSizedTrainingDump(&data[0], data.size(), trainer);
305 }
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer &trainer) const
Definition: lstmtrainer.h:306

◆ ReduceLayerLearningRates()

int tesseract::LSTMTrainer::ReduceLayerLearningRates ( TFloat  factor,
int  num_samples,
LSTMTrainer *  samples_trainer 
)

Definition at line 696 of file lstmtrainer.cpp.

697 {
698 enum WhichWay {
699 LR_DOWN, // Learning rate will go down by factor.
700 LR_SAME, // Learning rate will stay the same.
701 LR_COUNT // Size of arrays.
702 };
703 std::vector<std::string> layers = EnumerateLayers();
704 int num_layers = layers.size();
705 std::vector<int> num_weights(num_layers);
706 std::vector<TFloat> bad_sums[LR_COUNT];
707 std::vector<TFloat> ok_sums[LR_COUNT];
708 for (int i = 0; i < LR_COUNT; ++i) {
709 bad_sums[i].resize(num_layers, 0.0);
710 ok_sums[i].resize(num_layers, 0.0);
711 }
712 auto momentum_factor = 1 / (1 - momentum_);
713 std::vector<char> orig_trainer;
714 samples_trainer->SaveTrainingDump(LIGHT, *this, &orig_trainer);
715 for (int i = 0; i < num_layers; ++i) {
716 Network *layer = GetLayer(layers[i]);
717 num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
718 }
719 int iteration = sample_iteration();
720 for (int s = 0; s < num_samples; ++s) {
721 // Which way will we modify the learning rate?
722 for (int ww = 0; ww < LR_COUNT; ++ww) {
723 // Transfer momentum to learning rate and adjust by the ww factor.
724 auto ww_factor = momentum_factor;
725 if (ww == LR_DOWN) {
726 ww_factor *= factor;
727 }
728 // Make a copy of *this, so we can mess about without damaging anything.
729 LSTMTrainer copy_trainer;
730 samples_trainer->ReadTrainingDump(orig_trainer, copy_trainer);
731 // Clear the updates, doing nothing else.
732 copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
733 // Adjust the learning rate in each layer.
734 for (int i = 0; i < num_layers; ++i) {
735 if (num_weights[i] == 0) {
736 continue;
737 }
738 copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
739 }
740 copy_trainer.SetIteration(iteration);
741 // Train on the sample, but keep the update in updates_ instead of
742 // applying to the weights.
743 const ImageData *trainingdata =
744 copy_trainer.TrainOnLine(samples_trainer, true);
745 if (trainingdata == nullptr) {
746 continue;
747 }
748 // We'll now use this trainer again for each layer.
749 std::vector<char> updated_trainer;
750 samples_trainer->SaveTrainingDump(LIGHT, copy_trainer, &updated_trainer);
751 for (int i = 0; i < num_layers; ++i) {
752 if (num_weights[i] == 0) {
753 continue;
754 }
755 LSTMTrainer layer_trainer;
756 samples_trainer->ReadTrainingDump(updated_trainer, layer_trainer);
757 Network *layer = layer_trainer.GetLayer(layers[i]);
758 // Update the weights in just the layer, using Adam if enabled.
759 layer->Update(0.0, momentum_, adam_beta_,
760 layer_trainer.training_iteration_ + 1);
761 // Zero the updates matrix again.
762 layer->Update(0.0, 0.0, 0.0, 0);
763 // Train again on the same sample, again holding back the updates.
764 layer_trainer.TrainOnLine(trainingdata, true);
765 // Count the sign changes in the updates in layer vs in copy_trainer.
766 float before_bad = bad_sums[ww][i];
767 float before_ok = ok_sums[ww][i];
768 layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
769 &ok_sums[ww][i], &bad_sums[ww][i]);
770 float bad_frac =
771 bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
772 if (bad_frac > 0.0f) {
773 bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
774 }
775 }
776 }
777 ++iteration;
778 }
779 int num_lowered = 0;
780 for (int i = 0; i < num_layers; ++i) {
781 if (num_weights[i] == 0) {
782 continue;
783 }
784 Network *layer = GetLayer(layers[i]);
785 float lr = GetLayerLearningRate(layers[i]);
786 TFloat total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
787 TFloat total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
788 TFloat frac_down = bad_sums[LR_DOWN][i] / total_down;
789 TFloat frac_same = bad_sums[LR_SAME][i] / total_same;
790 tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(),
791 lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
792 if (frac_down < frac_same * kImprovementFraction) {
793 tprintf(" REDUCED\n");
794 ScaleLayerLearningRate(layers[i], factor);
795 ++num_lowered;
796 } else {
797 tprintf(" SAME\n");
798 }
799 }
800 if (num_lowered == 0) {
801 // Just lower everything to make sure.
802 for (int i = 0; i < num_layers; ++i) {
803 if (num_weights[i] > 0) {
804 ScaleLayerLearningRate(layers[i], factor);
805 ++num_lowered;
806 }
807 }
808 }
809 return num_lowered;
810}
const double kImprovementFraction
Definition: lstmtrainer.cpp:66
double TFloat
Definition: tesstypes.h:39
void ScaleLayerLearningRate(const std::string &id, double factor)
std::vector< std::string > EnumerateLayers() const
float GetLayerLearningRate(const std::string &id) const
Network * GetLayer(const std::string &id) const

◆ ReduceLearningRates()

void tesseract::LSTMTrainer::ReduceLearningRates ( LSTMTrainer *  samples_trainer,
std::string &  log_msg 
)

Definition at line 676 of file lstmtrainer.cpp.

677 {
679 int num_reduced = ReduceLayerLearningRates(
681 log_msg +=
682 "\nReduced learning rate on layers: " + std::to_string(num_reduced);
683 } else {
685 log_msg += "\nReduced learning rate to :" + std::to_string(learning_rate_);
686 }
687 log_msg += "\n";
688}
const double kLearningRateDecay
Definition: lstmtrainer.cpp:52
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:54
void ScaleLearningRate(double factor)
bool TestFlag(NetworkFlags flag) const
Definition: network.h:146
int ReduceLayerLearningRates(TFloat factor, int num_samples, LSTMTrainer *samples_trainer)

◆ RollErrorBuffers()

void tesseract::LSTMTrainer::RollErrorBuffers ( )
protected

Definition at line 1406 of file lstmtrainer.cpp.

1406 {
1408 if (NewSingleError(ET_DELTA) > 0.0) {
1410 } else {
1412 }
1414 if (debug_interval_ != 0) {
1415 tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1419 }
1420}
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:157

◆ SaveRecognitionDump()

void tesseract::LSTMTrainer::SaveRecognitionDump ( std::vector< char > *  data) const

Definition at line 1045 of file lstmtrainer.cpp.

1045 {
1046 TFile fp;
1047 fp.OpenWrite(data);
1051}
@ TS_TEMP_DISABLE
Definition: network.h:95
@ TS_RE_ENABLE
Definition: network.h:97
bool Serialize(const TessdataManager *mgr, TFile *fp) const

◆ SaveTraineddata()

bool tesseract::LSTMTrainer::SaveTraineddata ( const char *  filename)

Definition at line 1036 of file lstmtrainer.cpp.

1036 {
1037 std::vector<char> recognizer_data;
1038 SaveRecognitionDump(&recognizer_data);
1039 mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
1040 recognizer_data.size());
1041 return mgr_.SaveFile(filename, SaveDataToFile);
1042}
void OverwriteEntry(TessdataType type, const char *data, int size)
bool SaveFile(const char *filename, FileWriter writer) const

◆ SaveTrainingDump()

bool tesseract::LSTMTrainer::SaveTrainingDump ( SerializeAmount  serialize_amount,
const LSTMTrainer &  trainer,
std::vector< char > *  data 
) const

Definition at line 1015 of file lstmtrainer.cpp.

1017 {
1018 TFile fp;
1019 fp.OpenWrite(data);
1020 return trainer.Serialize(serialize_amount, &mgr_, &fp);
1021}

◆ Serialize()

bool tesseract::LSTMTrainer::Serialize ( SerializeAmount  serialize_amount,
const TessdataManager *  mgr,
TFile *  fp 
) const

Definition at line 427 of file lstmtrainer.cpp.

428 {
429 if (!LSTMRecognizer::Serialize(mgr, fp)) {
430 return false;
431 }
432 if (!fp->Serialize(&learning_iteration_)) {
433 return false;
434 }
435 if (!fp->Serialize(&prev_sample_iteration_)) {
436 return false;
437 }
438 if (!fp->Serialize(&perfect_delay_)) {
439 return false;
440 }
441 if (!fp->Serialize(&last_perfect_training_iteration_)) {
442 return false;
443 }
444 for (const auto &error_buffer : error_buffers_) {
445 if (!fp->Serialize(error_buffer)) {
446 return false;
447 }
448 }
449 if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) {
450 return false;
451 }
452 if (!fp->Serialize(&training_stage_)) {
453 return false;
454 }
455 uint8_t amount = serialize_amount;
456 if (!fp->Serialize(&amount)) {
457 return false;
458 }
459 if (serialize_amount == LIGHT) {
460 return true; // We are done.
461 }
462 if (!fp->Serialize(&best_error_rate_)) {
463 return false;
464 }
465 if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) {
466 return false;
467 }
468 if (!fp->Serialize(&best_iteration_)) {
469 return false;
470 }
471 if (!fp->Serialize(&worst_error_rate_)) {
472 return false;
473 }
474 if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) {
475 return false;
476 }
477 if (!fp->Serialize(&worst_iteration_)) {
478 return false;
479 }
480 if (!fp->Serialize(&stall_iteration_)) {
481 return false;
482 }
483 if (!fp->Serialize(best_model_data_)) {
484 return false;
485 }
486 if (!fp->Serialize(worst_model_data_)) {
487 return false;
488 }
489 if (serialize_amount != NO_BEST_TRAINER && !fp->Serialize(best_trainer_)) {
490 return false;
491 }
492 std::vector<char> sub_data;
493 if (sub_trainer_ != nullptr &&
494 !SaveTrainingDump(LIGHT, *sub_trainer_, &sub_data)) {
495 return false;
496 }
497 if (!fp->Serialize(sub_data)) {
498 return false;
499 }
500 if (!fp->Serialize(best_error_history_)) {
501 return false;
502 }
503 if (!fp->Serialize(best_error_iterations_)) {
504 return false;
505 }
506 return fp->Serialize(&improvement_steps_);
507}

◆ set_perfect_delay()

void tesseract::LSTMTrainer::set_perfect_delay ( int  delay)
inline

Definition at line 150 of file lstmtrainer.h.

150 {
151 perfect_delay_ = delay;
152 }

◆ SetNullChar()

void tesseract::LSTMTrainer::SetNullChar ( )
protected

Definition at line 1129 of file lstmtrainer.cpp.

1129 {
1131 : GetUnicharset().size();
1132 RecodedCharID code;
1134 null_char_ = code(0);
1135}
@ UNICHAR_BROKEN
Definition: unicharset.h:38
bool has_special_codes() const
Definition: unicharset.h:756

◆ SetupCheckpointInfo()

void tesseract::LSTMTrainer::SetupCheckpointInfo ( )

◆ StartSubtrainer()

void tesseract::LSTMTrainer::StartSubtrainer ( std::string &  log_msg)

Definition at line 605 of file lstmtrainer.cpp.

605 {
606 sub_trainer_ = std::make_unique<LSTMTrainer>();
608 log_msg += " Failed to revert to previous best for trial!";
609 sub_trainer_.reset();
610 } else {
611 log_msg += " Trial sub_trainer_ from iteration " +
612 std::to_string(sub_trainer_->training_iteration());
613 // Reduce learning rate so it doesn't diverge this time.
614 sub_trainer_->ReduceLearningRates(this, log_msg);
615 // If it fails again, we will wait twice as long before reverting again.
616 int stall_offset =
617 learning_iteration() - sub_trainer_->learning_iteration();
618 stall_iteration_ = learning_iteration() + 2 * stall_offset;
619 sub_trainer_->stall_iteration_ = stall_iteration_;
620 // Re-save the best trainer with the new learning rates and stall iteration.
622 }
623}

◆ training_data()

const DocumentCache & tesseract::LSTMTrainer::training_data ( ) const
inline

Definition at line 168 of file lstmtrainer.h.

168 {
169 return training_data_;
170 }

◆ TrainOnLine() [1/2]

Trainability tesseract::LSTMTrainer::TrainOnLine ( const ImageData *  trainingdata,
bool  batch 
)

Definition at line 871 of file lstmtrainer.cpp.

872 {
873 NetworkIO fwd_outputs, targets;
874 Trainability trainable =
875 PrepareForBackward(trainingdata, &fwd_outputs, &targets);
877 if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
878 return trainable; // Sample was unusable.
879 }
880 bool debug =
882 // Run backprop on the output.
883 NetworkIO bp_deltas;
884 if (network_->IsTraining() &&
885 (trainable != PERFECT ||
888 network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
891 }
892#ifndef GRAPHICS_DISABLED
893 if (debug_interval_ == 1 && debug_win_ != nullptr) {
895 }
896#endif // !GRAPHICS_DISABLED
897 // Roll the memory of past means.
899 return trainable;
900}
@ SVET_CLICK
Definition: scrollview.h:55
NetworkScratch scratch_space_
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
bool IsTraining() const
Definition: network.h:113
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
Definition: network.h:235
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:445

◆ TrainOnLine() [2/2]

const ImageData * tesseract::LSTMTrainer::TrainOnLine ( LSTMTrainer *  samples_trainer,
bool  batch 
)
inline

Definition at line 267 of file lstmtrainer.h.

267 {
268 int sample_index = sample_iteration();
269 const ImageData *image =
270 samples_trainer->training_data_.GetPageBySerial(sample_index);
271 if (image != nullptr) {
272 Trainability trainable = TrainOnLine(image, batch);
273 if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
274 return nullptr; // Sample was unusable.
275 }
276 } else {
278 }
279 return image;
280 }
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:267

◆ TransitionTrainingStage()

bool tesseract::LSTMTrainer::TransitionTrainingStage ( float  error_threshold)

Definition at line 417 of file lstmtrainer.cpp.

417 {
418 if (best_error_rate_ < error_threshold &&
421 return true;
422 }
423 return false;
424}

◆ TryLoadingCheckpoint()

bool tesseract::LSTMTrainer::TryLoadingCheckpoint ( const char *  filename,
const char *  old_traineddata 
)

Definition at line 103 of file lstmtrainer.cpp.

104 {
105 std::vector<char> data;
106 if (!LoadDataFromFile(filename, &data)) {
107 return false;
108 }
109 tprintf("Loaded file %s, unpacking...\n", filename);
110 if (!ReadTrainingDump(data, *this)) {
111 return false;
112 }
113 if (IsIntMode()) {
114 tprintf("Error, %s is an integer (fast) model, cannot continue training\n",
115 filename);
116 return false;
117 }
118 if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
120 filename == old_traineddata) {
121 return true; // Normal checkpoint load complete.
122 }
123 tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
125 if (old_traineddata == nullptr || *old_traineddata == '\0') {
126 tprintf("Must supply the old traineddata for code conversion!\n");
127 return false;
128 }
129 TessdataManager old_mgr;
130 ASSERT_HOST(old_mgr.Init(old_traineddata));
131 TFile fp;
132 if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) {
133 return false;
134 }
135 UNICHARSET old_chset;
136 if (!old_chset.load_from_file(&fp, false)) {
137 return false;
138 }
139 if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) {
140 return false;
141 }
142 UnicharCompress old_recoder;
143 if (!old_recoder.DeSerialize(&fp)) {
144 return false;
145 }
146 std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
147 // Set the null_char_ to the new value.
148 int old_null_char = null_char_;
149 SetNullChar();
150 // Map the softmax(s) in the network.
151 network_->RemapOutputs(old_recoder.code_range(), code_map);
152 tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
153 return true;
154}
@ TESSDATA_LSTM_UNICHARSET
@ TESSDATA_LSTM_RECODER
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
Definition: network.h:190
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const

◆ UpdateErrorBuffer()

void tesseract::LSTMTrainer::UpdateErrorBuffer ( double  new_error,
ErrorTypes  type 
)
protected

Definition at line 1390 of file lstmtrainer.cpp.

1390 {
1391 int index = training_iteration_ % kRollingBufferSize_;
1392 error_buffers_[type][index] = new_error;
1393 // Compute the mean error.
1394 int mean_count =
1395 std::min<int>(training_iteration_ + 1, error_buffers_[type].size());
1396 double buffer_sum = 0.0;
1397 for (int i = 0; i < mean_count; ++i) {
1398 buffer_sum += error_buffers_[type][i];
1399 }
1400 double mean = buffer_sum / mean_count;
1401 // Trim precision to 1/1000 of 1%.
1402 error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
1403}
int IntCastRounded(double x)
Definition: helpers.h:168

◆ UpdateErrorGraph()

std::string tesseract::LSTMTrainer::UpdateErrorGraph ( int  iteration,
double  error_rate,
const std::vector< char > &  model_data,
const TestCallback &  tester 
)
protected

Definition at line 1426 of file lstmtrainer.cpp.

1428 {
1429 if (error_rate > best_error_rate_ &&
1430 iteration < best_iteration_ + kErrorGraphInterval) {
1431 // Too soon to record a new point.
1432 if (tester != nullptr && !worst_model_data_.empty()) {
1433 mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
1434 worst_model_data_.size());
1435 return tester(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
1436 } else {
1437 return "";
1438 }
1439 }
1440 std::string result;
1441 // NOTE: there are 2 asymmetries here:
1442 // 1. We are computing the global minimum, but the local maximum in between.
1443 // 2. If the tester returns an empty string, indicating that it is busy,
1444 // call it repeatedly on new local maxima to test the previous min, but
1445 // not the other way around, as there is little point testing the maxima
1446 // between very frequent minima.
1447 if (error_rate < best_error_rate_) {
1448 // This is a new (global) minimum.
1449 if (tester != nullptr && !worst_model_data_.empty()) {
1450 mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
1451 worst_model_data_.size());
1452 result = tester(worst_iteration_, worst_error_rates_, mgr_,
1453 CurrentTrainingStage());
1454 worst_model_data_.clear();
1455 best_model_data_ = model_data;
1456 }
1457 best_error_rate_ = error_rate;
1458 memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
1459 best_iteration_ = iteration;
1460 best_error_history_.push_back(error_rate);
1461 best_error_iterations_.push_back(iteration);
1462 // Compute 2% decay time.
1463 double two_percent_more = error_rate + 2.0;
1464 int i;
1465 for (i = best_error_history_.size() - 1;
1466 i >= 0 && best_error_history_[i] < two_percent_more; --i) {
1467 }
1468 int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
1469 improvement_steps_ = iteration - old_iteration;
1470 tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
1471 improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
1472 old_iteration);
1473 } else if (error_rate > best_error_rate_) {
1474 // This is a new (local) maximum.
1475 if (tester != nullptr) {
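1476 if (!best_model_data_.empty()) {
1477 mgr_.OverwriteEntry(TESSDATA_LSTM, &best_model_data_[0],
1478 best_model_data_.size());
1479 result = tester(best_iteration_, best_error_rates_, mgr_,
1480 CurrentTrainingStage());
1481 } else if (!worst_model_data_.empty()) {
1482 // Allow for multiple data points with "worst" error rate.
1483 mgr_.OverwriteEntry(TESSDATA_LSTM, &worst_model_data_[0],
1484 worst_model_data_.size());
1485 result = tester(worst_iteration_, worst_error_rates_, mgr_,
1486 CurrentTrainingStage());
1487 }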
1488 if (result.length() > 0) {
1489 best_model_data_.clear();
1490 }
1491 worst_model_data_ = model_data;
1492 }
1493 }
1494 worst_error_rate_ = error_rate;
1495 memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
1496 worst_iteration_ = iteration;
1497 return result;
1498}
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:56

◆ UpdateSubtrainer()

SubTrainerResult tesseract::LSTMTrainer::UpdateSubtrainer ( std::string &  log_msg)

Definition at line 633 of file lstmtrainer.cpp.

633 {
634 double training_error = CharError();
635 double sub_error = sub_trainer_->CharError();
636 double sub_margin = (training_error - sub_error) / sub_error;
637 if (sub_margin >= kSubTrainerMarginFraction) {
638 log_msg += " sub_trainer=" + std::to_string(sub_error);
639 log_msg += " margin=" + std::to_string(100.0 * sub_margin);
640 log_msg += "\n";
641 // Catch up to current iteration.
642 int end_iteration = training_iteration();
643 while (sub_trainer_->training_iteration() < end_iteration &&
644 sub_margin >= kSubTrainerMarginFraction) {
645 int target_iteration =
646 sub_trainer_->training_iteration() + kNumPagesPerBatch;
647 while (sub_trainer_->training_iteration() < target_iteration) {
648 sub_trainer_->TrainOnLine(this, false);
649 }
650 std::string batch_log = "Sub:";
651 sub_trainer_->PrepareLogMsg(batch_log);
652 batch_log += "\n";
653 tprintf("UpdateSubtrainer:%s", batch_log.c_str());
654 log_msg += batch_log;
655 sub_error = sub_trainer_->CharError();
656 sub_margin = (training_error - sub_error) / sub_error;
657 }
658 if (sub_error < best_error_rate_ &&
659 sub_margin >= kSubTrainerMarginFraction) {
660 // The sub_trainer_ has won the race to a new best. Switch to it.
661 std::vector<char> updated_trainer;
662 SaveTrainingDump(LIGHT, *sub_trainer_, &updated_trainer);
663 ReadTrainingDump(updated_trainer, *this);
664 log_msg += " Sub trainer wins at iteration " +
665 std::to_string(training_iteration());
666 log_msg += "\n";
667 return STR_REPLACED;
668 }
669 return STR_UPDATED;
670 }
671 return STR_NONE;
672}
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:58

Member Data Documentation

◆ align_win_

ScrollView* tesseract::LSTMTrainer::align_win_
protected

Definition at line 407 of file lstmtrainer.h.

◆ best_error_history_

std::vector<double> tesseract::LSTMTrainer::best_error_history_
protected

Definition at line 461 of file lstmtrainer.h.

◆ best_error_iterations_

std::vector<int32_t> tesseract::LSTMTrainer::best_error_iterations_
protected

Definition at line 462 of file lstmtrainer.h.

◆ best_error_rate_

double tesseract::LSTMTrainer::best_error_rate_
protected

Definition at line 434 of file lstmtrainer.h.

◆ best_error_rates_

double tesseract::LSTMTrainer::best_error_rates_[ET_COUNT]
protected

Definition at line 436 of file lstmtrainer.h.

◆ best_iteration_

int tesseract::LSTMTrainer::best_iteration_
protected

Definition at line 438 of file lstmtrainer.h.

◆ best_model_data_

std::vector<char> tesseract::LSTMTrainer::best_model_data_
protected

Definition at line 448 of file lstmtrainer.h.

◆ best_model_name_

std::string tesseract::LSTMTrainer::best_model_name_
protected

Definition at line 427 of file lstmtrainer.h.

◆ best_trainer_

std::vector<char> tesseract::LSTMTrainer::best_trainer_
protected

Definition at line 451 of file lstmtrainer.h.

◆ checkpoint_iteration_

int tesseract::LSTMTrainer::checkpoint_iteration_
protected

Definition at line 418 of file lstmtrainer.h.

◆ checkpoint_name_

std::string tesseract::LSTMTrainer::checkpoint_name_
protected

Definition at line 422 of file lstmtrainer.h.

◆ ctc_win_

ScrollView* tesseract::LSTMTrainer::ctc_win_
protected

Definition at line 411 of file lstmtrainer.h.

◆ debug_interval_

int tesseract::LSTMTrainer::debug_interval_
protected

Definition at line 416 of file lstmtrainer.h.

◆ error_buffers_

std::vector<double> tesseract::LSTMTrainer::error_buffers_[ET_COUNT]
protected

Definition at line 483 of file lstmtrainer.h.

◆ error_rate_of_last_saved_best_

float tesseract::LSTMTrainer::error_rate_of_last_saved_best_
protected

Definition at line 456 of file lstmtrainer.h.

◆ error_rates_

double tesseract::LSTMTrainer::error_rates_[ET_COUNT]
protected

Definition at line 485 of file lstmtrainer.h.

◆ improvement_steps_

int32_t tesseract::LSTMTrainer::improvement_steps_
protected

Definition at line 464 of file lstmtrainer.h.

◆ kRollingBufferSize_

const int tesseract::LSTMTrainer::kRollingBufferSize_ = 1000
static protected

Definition at line 482 of file lstmtrainer.h.

◆ last_perfect_training_iteration_

int tesseract::LSTMTrainer::last_perfect_training_iteration_
protected

Definition at line 479 of file lstmtrainer.h.

◆ learning_iteration_

int tesseract::LSTMTrainer::learning_iteration_
protected

Definition at line 468 of file lstmtrainer.h.

◆ mgr_

TessdataManager tesseract::LSTMTrainer::mgr_
protected

Definition at line 487 of file lstmtrainer.h.

◆ model_base_

std::string tesseract::LSTMTrainer::model_base_
protected

Definition at line 420 of file lstmtrainer.h.

◆ num_training_stages_

int tesseract::LSTMTrainer::num_training_stages_
protected

Definition at line 429 of file lstmtrainer.h.

◆ perfect_delay_

int tesseract::LSTMTrainer::perfect_delay_
protected

Definition at line 476 of file lstmtrainer.h.

◆ prev_sample_iteration_

int tesseract::LSTMTrainer::prev_sample_iteration_
protected

Definition at line 470 of file lstmtrainer.h.

◆ randomly_rotate_

bool tesseract::LSTMTrainer::randomly_rotate_
protected

Definition at line 424 of file lstmtrainer.h.

◆ recon_win_

ScrollView* tesseract::LSTMTrainer::recon_win_
protected

Definition at line 413 of file lstmtrainer.h.

◆ stall_iteration_

int tesseract::LSTMTrainer::stall_iteration_
protected

Definition at line 446 of file lstmtrainer.h.

◆ sub_trainer_

std::unique_ptr<LSTMTrainer> tesseract::LSTMTrainer::sub_trainer_
protected

Definition at line 454 of file lstmtrainer.h.

◆ target_win_

ScrollView* tesseract::LSTMTrainer::target_win_
protected

Definition at line 409 of file lstmtrainer.h.

◆ training_data_

DocumentCache tesseract::LSTMTrainer::training_data_
protected

Definition at line 425 of file lstmtrainer.h.

◆ training_stage_

int tesseract::LSTMTrainer::training_stage_
protected

Definition at line 458 of file lstmtrainer.h.

◆ worst_error_rate_

double tesseract::LSTMTrainer::worst_error_rate_
protected

Definition at line 440 of file lstmtrainer.h.

◆ worst_error_rates_

double tesseract::LSTMTrainer::worst_error_rates_[ET_COUNT]
protected

Definition at line 442 of file lstmtrainer.h.

◆ worst_iteration_

int tesseract::LSTMTrainer::worst_iteration_
protected

Definition at line 444 of file lstmtrainer.h.

◆ worst_model_data_

std::vector<char> tesseract::LSTMTrainer::worst_model_data_
protected

Definition at line 449 of file lstmtrainer.h.


The documentation for this class was generated from the following files: