# cqa_gen_batches.py -- from a fork of prdwb/bert_hae
"""Batch generators for conversational QA (CQA) training and evaluation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import math
import os

# modeling, optimization, and tokenization come from the surrounding BERT-based
# codebase; they are imported here but not used in this module.
import modeling
import optimization
import tokenization
import six
import tensorflow as tf
import numpy as np
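
# Illustrative only: the batch generators below assume each element of
# `features` is an object exposing the attributes read in cqa_gen_batches
# (unique_id, input_ids, input_mask, segment_ids, start_position, end_position,
# history_answer_marker, metadata). The real feature objects are built
# elsewhere in the repository; this hypothetical stand-in only makes the usage
# sketches further down self-contained. It is deliberately a plain object (not
# a tuple), so np.asarray() keeps each feature as a single element of a 1-D
# object array instead of expanding its fields.
class _ToyFeature(object):
    def __init__(self, **fields):
        self.__dict__.update(fields)
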
def cqa_gen_batches(features, batch_size, num_epoches, shuffle=False):
    """Yields fixed-size batches of features, with each field collected into its own list."""
    num_examples = len(features)
    if shuffle:
        # Fixed seed so the shuffle order is reproducible across runs.
        np.random.seed(0)
        idx = np.random.permutation(num_examples)
        features_shuffled = np.asarray(features)[idx]
    else:
        features_shuffled = np.asarray(features)
    num_steps = math.ceil(num_examples / batch_size)

    for _ in range(int(num_epoches)):
        i = 0
        for _ in range(int(num_steps)):
            batch_features = features_shuffled[i: i + batch_size]

            batch_unique_ids = []
            batch_input_ids = []
            batch_input_mask = []
            batch_segment_ids = []
            batch_start_positions = []
            batch_end_positions = []
            batch_history_answer_marker = []
            batch_metadata = []
            for feature in batch_features:
                batch_unique_ids.append(feature.unique_id)
                batch_input_ids.append(feature.input_ids)
                batch_input_mask.append(feature.input_mask)
                batch_segment_ids.append(feature.segment_ids)
                batch_start_positions.append(feature.start_position)
                batch_end_positions.append(feature.end_position)
                batch_history_answer_marker.append(feature.history_answer_marker)
                batch_metadata.append(feature.metadata)
            i += batch_size

            yield (batch_unique_ids, batch_input_ids, batch_input_mask, batch_segment_ids,
                   batch_start_positions, batch_end_positions, batch_history_answer_marker,
                   batch_metadata)
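
# Hypothetical usage sketch for cqa_gen_batches (not part of the original
# training pipeline): builds a few _ToyFeature objects with made-up values and
# shows that each yielded batch is a tuple of per-field lists.
def _demo_cqa_gen_batches():
    features = [
        _ToyFeature(unique_id=1000 + i, input_ids=[101, 102, 103], input_mask=[1, 1, 1],
                    segment_ids=[0, 0, 0], start_position=0, end_position=1,
                    history_answer_marker=[0, 0, 0], metadata={'turn': i})
        for i in range(5)
    ]
    batches = cqa_gen_batches(features, batch_size=2, num_epoches=1, shuffle=False)
    for (unique_ids, input_ids, input_mask, segment_ids, start_positions,
         end_positions, history_answer_marker, metadata) in batches:
        # With 5 features and batch_size=2, the batches hold 2, 2, and 1 features.
        print(unique_ids)
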
def cqa_gen_example_batches(examples, batch_size, num_epoches, shuffle=False):
    """Yields batches of raw examples (no per-field unpacking), batch_size examples at a time."""
    num_examples = len(examples)
    if shuffle:
        # Fixed seed so the shuffle order is reproducible across runs.
        np.random.seed(0)
        idx = np.random.permutation(num_examples)
        examples_shuffled = np.asarray(examples)[idx]
    else:
        examples_shuffled = np.asarray(examples)
    num_steps = math.ceil(num_examples / batch_size)

    for _ in range(int(num_epoches)):
        i = 0
        for _ in range(int(num_steps)):
            batch_examples = examples_shuffled[i: i + batch_size]
            i += batch_size
            yield batch_examples
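
# Hypothetical usage sketch for cqa_gen_example_batches: the function only
# slices the (optionally shuffled) example array, so any objects can stand in
# for examples here.
def _demo_cqa_gen_example_batches():
    examples = ['example_%d' % i for i in range(5)]
    for batch_examples in cqa_gen_example_batches(examples, batch_size=2, num_epoches=1):
        # Batch sizes are 2, 2, and 1 for 5 examples with batch_size=2.
        print(len(batch_examples))
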
def cqa_gen_example_aware_batches(features, example_tracker, variation_tracker,
                                  example_features_nums, batch_size, num_epoches,
                                  shuffle=False):
    """Yields example-aware batches: each batch holds all the features generated by a
    consecutive run of examples, chosen greedily so that the feature count stays within
    batch_size (FLAGS.train_batch_size) whenever possible.

    The training examples have been shuffled before this function is called, so no
    shuffling is done here.
    """
    # num_examples = len(features)
    # if shuffle:
    #     np.random.seed(0)
    #     idx = np.random.permutation(num_examples)
    #     features_shuffled = np.asarray(features)[idx]
    # else:
    #     features_shuffled = np.asarray(features)
    # num_steps = math.ceil(num_examples / batch_size)

    for _ in range(int(num_epoches)):
        # Greedily take the features of the next examples, as long as the total number
        # of features does not exceed batch_size.
        start_example_index, end_example_index = 0, 0
        while start_example_index in example_tracker:
            features_sum = example_features_nums[start_example_index]
            while features_sum <= batch_size:
                end_example_index += 1
                try:
                    features_sum += example_features_nums[end_example_index]
                except IndexError:
                    # Ran past the last example; take everything that is left.
                    break

            start_index = example_tracker.index(start_example_index)
            # Sometimes a single example generates more features than a batch can hold;
            # in that case it forms an oversized batch on its own.
            if end_example_index == start_example_index:
                end_example_index += 1
            try:
                end_index = example_tracker.index(end_example_index)
            except ValueError:
                # end_example_index is past the last example, so slice to the end.
                end_index = None

            batch_features = features[start_index: end_index]
            batch_example_tracker = example_tracker[start_index: end_index]
            batch_variation_tracker = variation_tracker[start_index: end_index]

            start_example_index = end_example_index
            assert len(batch_features) > 0
            yield batch_features, batch_example_tracker, batch_variation_tracker

        print('epoch finished!')

# Alternative batching scheme, kept commented out: a fixed number of examples
# (example_batch_size) per batch instead of capping the total feature count.
# for _ in range(int(num_epoches)):
#     start_example_index = 0
#     # end_example_index is actually the first example index in the next batch.
#     end_example_index = start_example_index + example_batch_size
#     while start_example_index in example_tracker:
#         start_index = example_tracker.index(start_example_index)
#         try:
#             end_index = example_tracker.index(end_example_index)
#         except ValueError:
#             end_index = None
#         batch_features = features[start_index: end_index]
#         batch_example_tracker = example_tracker[start_index: end_index]
#         batch_variation_tracker = variation_tracker[start_index: end_index]
#         start_example_index += example_batch_size
#         end_example_index += example_batch_size
#         yield batch_features, batch_example_tracker, batch_variation_tracker
#     print('epoch finished!')
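
# Hypothetical usage sketch for cqa_gen_example_aware_batches (the inputs below
# are invented purely to illustrate the greedy grouping). The trackers are
# assumed to be parallel to `features`: example_tracker[i] is the index of the
# example that produced feature i (non-decreasing, starting at 0),
# variation_tracker[i] is the variation of that example the feature came from,
# and example_features_nums[k] is how many features example k produced.
def _demo_cqa_gen_example_aware_batches():
    features = ['feature_%d' % i for i in range(10)]
    example_tracker = [0, 0, 1, 1, 1, 2, 2, 2, 2, 3]
    variation_tracker = [0, 1, 0, 1, 2, 0, 1, 2, 3, 0]
    example_features_nums = [2, 3, 4, 1]
    batches = cqa_gen_example_aware_batches(
        features, example_tracker, variation_tracker, example_features_nums,
        batch_size=5, num_epoches=1)
    for batch_features, batch_example_tracker, batch_variation_tracker in batches:
        # With batch_size=5, examples 0-1 (2 + 3 features) form the first batch
        # and examples 2-3 (4 + 1 features) form the second.
        print(len(batch_features), batch_example_tracker)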