forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext2vocabulary.py
executable file
·83 lines (73 loc) · 2.44 KB
/
text2vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import logging
import sys
from collections import Counter

import six
# True under a Python 2 interpreter. The script uses it to choose between
# sys.stdin/sys.stdout (Python 2) and their ``.buffer`` byte streams
# (Python 3) before wrapping them with a UTF-8 codecs reader/writer.
is_python2 = sys.version_info.major == 2
def get_parser():
parser = argparse.ArgumentParser(
description="create a vocabulary file from text files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--output", "-o", default="", type=str, help="output a vocabulary file"
)
parser.add_argument("--cutoff", "-c", default=0, type=int, help="cut-off frequency")
parser.add_argument(
"--vocabsize", "-s", default=20000, type=int, help="vocabulary size"
)
parser.add_argument("text_files", nargs="*", help="input text files")
return parser
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
# count the word occurrences
counts = {}
exclude = ["<sos>", "<eos>", "<unk>"]
if len(args.text_files) == 0:
args.text_files.append("-")
for fn in args.text_files:
fd = (
codecs.open(fn, "r", encoding="utf-8")
if fn != "-"
else codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer
)
)
for ln in fd.readlines():
for tok in ln.split():
if tok not in exclude:
if tok not in counts:
counts[tok] = 1
else:
counts[tok] += 1
if fn != "-":
fd.close()
# limit the vocabulary size
total_count = sum(counts.values())
invocab_count = 0
vocabulary = []
for w, c in sorted(counts.items(), key=lambda x: -x[1]):
if c <= args.cutoff:
break
if len(vocabulary) >= args.vocabsize:
break
vocabulary.append(w)
invocab_count += c
logging.warning(
"OOV rate = %.2f %%" % (float(total_count - invocab_count) / total_count * 100)
)
# write the vocabulary
fd = (
codecs.open(args.output, "w", encoding="utf-8")
if args.output
else codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
)
six.print_("<unk> 1", file=fd)
for n, w in enumerate(sorted(vocabulary)):
six.print_("%s %d" % (w, n + 2), file=fd)
if args.output:
fd.close()