70 {
71
73 UNICHARSET unicharset;
74 ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
76 std::vector<std::string> words;
77 EXPECT_EQ(0,
CombineLangModel(unicharset, script_dir,
"", FLAGS_test_tmpdir, kLang, !recode,
78 words, words, words, false, nullptr, nullptr));
79 std::string model_path =
file::JoinPath(FLAGS_test_tmpdir, model_name);
80 std::string checkpoint_path = model_path + "_checkpoint";
81 trainer_ = std::make_unique<LSTMTrainer>(model_path.c_str(), checkpoint_path.c_str(), 0, 0);
84 int net_mode = adam ?
NF_ADAM : 0;
85
86
87 if (adam) {
88 learning_rate *= 20.0f;
89 }
90 if (layer_specific) {
92 }
93 EXPECT_TRUE(
94 trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, learning_rate, 0.9, 0.999));
95 std::vector<std::string> filenames;
98 LOG(
INFO) <<
"Setup network:" << model_name <<
"\n";
99 }
// Signature only; the definition is not part of this listing.
// The setup code above calls this with:
//   - an empty version_str;
//   - the same empty vector for words, puncs and numbers;
//   - lang_is_rtl = false and nullptr reader/writer;
//   - pass_through_recoder = !recode.
// It then asserts a return value of 0 (presumably success — TODO confirm).
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
// Signature only. Used above as file::JoinPath(FLAGS_test_tmpdir, model_name)
// to build the model path; presumably joins two path components — confirm.
static std::string JoinPath(const std::string &s1, const std::string &s2)
// Owning pointer to the LSTMTrainer that the setup code above creates with
// std::make_unique and then configures through InitNetwork().
std::unique_ptr< LSTMTrainer > trainer_
// Signature only; not referenced in the visible lines of this listing.
std::string TestDataNameToPath(const std::string &name)