-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdict-utils.h
125 lines (97 loc) · 3.34 KB
/
dict-utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#pragma once
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "dynet/dict.h"
using namespace std;
using namespace dynet;
/// Reads the source and target vocabulary files line by line into `sd` and
/// `td`, then adds the "<s>", "</s>" and "<unk>" markers to each dictionary.
/// Returns without doing anything when either file name is empty.
/// When `freeze` is true, both dictionaries are frozen after loading.
inline void load_vocabs(const std::string& src_vocab_file, const std::string& trg_vocab_file
, dynet::Dict& sd, dynet::Dict& td, bool freeze=true);
/// Reads one vocabulary file line by line into `d`, then adds the "<s>",
/// "</s>" and "<unk>" markers. Returns without doing anything when the file
/// name is empty. When `freeze` is true, `d` is frozen after loading.
inline void load_vocab(const std::string& vocab_file
, dynet::Dict& d, bool freeze=true);
/// Reads one shared vocabulary file into `sd` (plus the "<s>", "</s>" and
/// "<unk>" markers), optionally freezes `sd`, and then assigns `sd` to `td`.
/// Returns without doing anything when the file name is empty.
inline void load_joint_vocab(const std::string& vocab_file
, dynet::Dict& sd, dynet::Dict& td, bool freeze=true);
/// Writes the words of `sd` and `td`, one per line, to the source and target
/// vocabulary files. Returns without doing anything when either file name is empty.
inline void save_vocabs(const std::string& src_vocab_file, const std::string& trg_vocab_file
, dynet::Dict& sd, dynet::Dict& td);
/// Writes the words of `d`, one per line, to `vocab_file`.
/// Returns without doing anything when the file name is empty.
inline void save_vocab(const std::string& vocab_file
, dynet::Dict& d);
/// Loads separate source and target vocabularies from text files.
///
/// Each line of a file is passed as one word to `Dict::convert`. After both
/// files are read, the sentinel markers "<s>", "</s>" and "<unk>" are
/// converted into each dictionary as well.
///
/// @param src_vocab_file  Path of the source vocabulary file.
/// @param trg_vocab_file  Path of the target vocabulary file.
/// @param sd              Dictionary receiving the source words.
/// @param td              Dictionary receiving the target words.
/// @param freeze          When true, both dictionaries are frozen at the end.
///
/// The function is a no-op when either path is empty. A file that cannot be
/// opened is reported on stderr; loading then continues, so the affected
/// dictionary receives only the sentinel markers.
inline void load_vocabs(const std::string& src_vocab_file, const std::string& trg_vocab_file
    , dynet::Dict& sd, dynet::Dict& td, bool freeze)
{
    if (src_vocab_file.empty() || trg_vocab_file.empty()) return;

    std::cerr << std::endl << "Loading vocabularies from files..." << std::endl;
    std::cerr << "Source vocabulary file: " << src_vocab_file << std::endl;
    std::cerr << "Target vocabulary file: " << trg_vocab_file << std::endl;

    std::ifstream if_src_vocab(src_vocab_file), if_trg_vocab(trg_vocab_file);
    // Surface open failures instead of silently producing a near-empty vocabulary.
    if (!if_src_vocab)
        std::cerr << "WARNING: could not open source vocabulary file: " << src_vocab_file << std::endl;
    if (!if_trg_vocab)
        std::cerr << "WARNING: could not open target vocabulary file: " << trg_vocab_file << std::endl;

    std::string sword, tword;
    while (std::getline(if_src_vocab, sword)) sd.convert(sword);
    while (std::getline(if_trg_vocab, tword)) td.convert(tword);

    // automatically add sentinel markers (after the file words, in this order)
    sd.convert("<s>");// source
    sd.convert("</s>");
    sd.convert("<unk>");
    td.convert("<s>");// target
    td.convert("</s>");
    td.convert("<unk>");

    std::cerr << "Source vocabulary size: " << sd.size() << std::endl;
    std::cerr << "Target vocabulary size: " << td.size() << std::endl;

    if (freeze) {
        sd.freeze();
        td.freeze();
    }
}
/// Loads one vocabulary shared by the source and target sides.
///
/// Each line of `vocab_file` is passed as one word to `sd.convert`, followed
/// by the sentinel markers "<s>", "</s>" and "<unk>". `sd` is optionally
/// frozen and then assigned to `td`, so both end up with the same contents.
///
/// @param vocab_file  Path of the joint vocabulary file.
/// @param sd          Dictionary that is filled from the file.
/// @param td          Dictionary overwritten with a copy of `sd`.
/// @param freeze      When true, `sd` is frozen before it is copied to `td`.
///
/// The function is a no-op when `vocab_file` is empty. A file that cannot be
/// opened is reported on stderr; loading then continues with only the
/// sentinel markers.
inline void load_joint_vocab(const std::string& vocab_file
    , dynet::Dict& sd, dynet::Dict& td, bool freeze)
{
    if (vocab_file.empty()) return;

    std::cerr << "Loading joint source and target vocabulary from file: " << vocab_file << std::endl;

    std::ifstream if_vocab(vocab_file);
    // Surface open failures instead of silently producing a near-empty vocabulary.
    if (!if_vocab)
        std::cerr << "WARNING: could not open joint vocabulary file: " << vocab_file << std::endl;

    std::string word;
    while (std::getline(if_vocab, word)) sd.convert(word);

    // automatically add sentinel markers (after the file words, in this order)
    sd.convert("<s>");
    sd.convert("</s>");
    sd.convert("<unk>");

    std::cerr << "Joint vocabulary size: " << sd.size() << std::endl;

    if (freeze) sd.freeze();

    // The target side shares the source dictionary's contents.
    td = sd;
}
/// Loads a single vocabulary from a text file.
///
/// Each line of `vocab_file` is passed as one word to `d.convert`, followed
/// by the sentinel markers "<s>", "</s>" and "<unk>".
///
/// @param vocab_file  Path of the vocabulary file.
/// @param d           Dictionary receiving the words.
/// @param freeze      When true, `d` is frozen at the end.
///
/// The function is a no-op when `vocab_file` is empty. A file that cannot be
/// opened is reported on stderr; loading then continues with only the
/// sentinel markers.
inline void load_vocab(const std::string& vocab_file
    , dynet::Dict& d, bool freeze)
{
    if (vocab_file.empty()) return;

    std::cerr << std::endl << "Loading vocabulary from file..." << std::endl;
    std::cerr << "Vocabulary file: " << vocab_file << std::endl;

    std::ifstream if_vocab(vocab_file);
    // Surface open failures instead of silently producing a near-empty vocabulary.
    if (!if_vocab)
        std::cerr << "WARNING: could not open vocabulary file: " << vocab_file << std::endl;

    std::string word;
    while (std::getline(if_vocab, word)) d.convert(word);

    // automatically add sentinel markers (after the file words, in this order)
    d.convert("<s>");
    d.convert("</s>");
    d.convert("<unk>");

    std::cerr << "Vocabulary size: " << d.size() << std::endl;

    if (freeze) d.freeze();
}
/// Writes the source and target vocabularies to text files, one word per line.
///
/// @param src_vocab_file  Output path for the words of `sd`.
/// @param trg_vocab_file  Output path for the words of `td`.
/// @param sd              Source dictionary to save.
/// @param td              Target dictionary to save.
///
/// The function is a no-op when either path is empty. A file that cannot be
/// opened for writing is reported on stderr and nothing is written to it.
inline void save_vocabs(const std::string& src_vocab_file, const std::string& trg_vocab_file
    , dynet::Dict& sd, dynet::Dict& td)
{
    if (src_vocab_file.empty() || trg_vocab_file.empty()) return;

    const auto& swords = sd.get_words();
    const auto& twords = td.get_words();

    std::ofstream of_svocab(src_vocab_file);
    if (!of_svocab)
        std::cerr << "WARNING: could not open source vocabulary file for writing: " << src_vocab_file << std::endl;
    // '\n' rather than std::endl: same bytes, but no flush after every word.
    for (const auto& sword : swords)
        of_svocab << sword << '\n';

    std::ofstream of_tvocab(trg_vocab_file);
    if (!of_tvocab)
        std::cerr << "WARNING: could not open target vocabulary file for writing: " << trg_vocab_file << std::endl;
    for (const auto& tword : twords)
        of_tvocab << tword << '\n';
}
/// Writes a vocabulary to a text file, one word per line.
///
/// @param vocab_file  Output path for the words of `d`.
/// @param d           Dictionary to save.
///
/// The function is a no-op when `vocab_file` is empty. A file that cannot be
/// opened for writing is reported on stderr and nothing is written.
inline void save_vocab(const std::string& vocab_file
    , dynet::Dict& d)
{
    if (vocab_file.empty()) return;

    const auto& words = d.get_words();

    std::ofstream of_vocab(vocab_file);
    if (!of_vocab)
        std::cerr << "WARNING: could not open vocabulary file for writing: " << vocab_file << std::endl;
    // '\n' rather than std::endl: same bytes, but no flush after every word.
    for (const auto& word : words)
        of_vocab << word << '\n';
}