forked from facebookresearch/fairseq
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract.py
executable file
·90 lines (71 loc) · 2.85 KB
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Extracts random constraints from reference files."""
import argparse
import random
import sys
def get_phrase(words, index, length):
    """Remove ``length`` words starting at ``index`` and return them joined.

    ``words`` is modified in place: the selected span is deleted from it.

    Args:
        words: list of word strings; mutated by this call.
        index: position of the first word of the phrase (must be >= 0).
        length: number of words in the phrase (must be >= 0).

    Returns:
        The selected words joined by single spaces.

    Raises:
        ValueError: if the span ``[index, index + length)`` does not lie
            inside ``words``.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``;
    # negative values are rejected too, since slicing and popping would then
    # disagree about which words make up the phrase.
    if index < 0 or length < 0 or index + length > len(words):
        raise ValueError(
            f"span [{index}, {index + length}) out of range for {len(words)} words"
        )
    phrase = " ".join(words[index : index + length])
    del words[index : index + length]
    return phrase
def _sample_constraints(tokens, number, length):
    """Randomly pick up to ``number`` non-overlapping phrases from ``tokens``.

    Each phrase starts at a random not-yet-used token and extends for up to
    ``length`` tokens, stopping early at the end of ``tokens`` or at the start
    of a previously chosen phrase. Sampling stops early once every token has
    been used.

    Returns:
        The phrases (tokens joined by single spaces), ordered by their
        position in ``tokens``.
    """
    # Pool of still-unused spans, each a half-open [start, end) token range.
    # Tracking token indices (instead of re-locating phrase strings in the
    # target text) keeps positions exact even when a phrase is a substring of
    # another word or occurs more than once.
    spans = [(0, len(tokens))] if tokens else []
    chosen = {}
    for _ in range(number):
        if not spans:
            break
        start, end = spans.pop(random.randrange(len(spans)))
        begin = random.randrange(start, end)
        stop = min(end, begin + length)
        chosen[begin] = " ".join(tokens[begin:stop])
        # Return the unused remainder on each side to the pool; a one-token
        # leftover is still a valid candidate for later picks.
        if begin > start:
            spans.append((start, begin))
        if stop < end:
            spans.append((stop, end))
    return [chosen[pos] for pos in sorted(chosen)]


def main(args):
    """Read ``source<TAB>target`` lines from stdin and print sampled constraints.

    For each line containing a tab, prints the source field followed by up to
    ``args.number`` randomly chosen, non-overlapping phrases of at most
    ``args.len`` target tokens, all tab-separated and ordered by position in
    the target. The target itself is not printed. If the target (after the
    optional ``<s>``/``</s>`` tokens are added) has fewer than ``args.len``
    tokens, no constraints are printed. Lines without a tab are echoed with
    trailing whitespace removed.

    Args:
        args: options object with ``seed``, ``number``, ``len``, ``add_sos``
            and ``add_eos`` attributes.

    Raises:
        ValueError: if ``args.len`` is less than 1, or an input line contains
            more than one tab.
    """
    if args.len < 1:
        raise ValueError(f"phrase length must be positive, got {args.len}")
    # A seed of 0 (the CLI default) is falsy and therefore means "do not seed".
    if args.seed:
        random.seed(args.seed)

    for line in sys.stdin:
        source = line.rstrip()
        constraints = []
        if "\t" in line:
            # Strip only the line terminator before splitting so it cannot end
            # up inside the target (e.g. between the last word and ``</s>``).
            source, target = line.rstrip("\r\n").split("\t")
            tokens = target.split()
            if args.add_sos:
                tokens.insert(0, "<s>")
            if args.add_eos:
                tokens.append("</s>")
            if len(tokens) >= args.len:
                constraints = _sample_constraints(tokens, args.number, args.len)
        print(source, *constraints, sep="\t")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases")
parser.add_argument("--len", "-l", type=int, default=1, help="phrase length")
parser.add_argument(
"--add-sos", default=False, action="store_true", help="add <s> token"
)
parser.add_argument(
"--add-eos", default=False, action="store_true", help="add </s> token"
)
parser.add_argument("--seed", "-s", default=0, type=int)
args = parser.parse_args()
main(args)