
Commit

touching up predict
bmccann committed Oct 25, 2018
1 parent fb0dcaa commit eab80b6
Showing 2 changed files with 22 additions and 11 deletions.
25 changes: 17 additions & 8 deletions predict.py
@@ -64,6 +64,8 @@ def to_iter(data, bs, device):
 def run(args, field, val_sets, model):
     device = set_seed(args)
     print(f'Preparing iterators')
+    if len(args.val_batch_size) == 1 and len(val_sets) > 1:
+        args.val_batch_size *= len(val_sets)
     iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
 
     def mult(ps):
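
As an aside, the two added lines rely on Python's list-repetition semantics to broadcast a single --val_batch_size across every task. A minimal sketch of that behavior (task names and batch size here are invented for illustration):

val_batch_size = [32]                       # one size given on the command line
val_sets = ['squad', 'wikisql', 'sst']      # stand-ins for three validation sets
if len(val_batch_size) == 1 and len(val_sets) > 1:
    val_batch_size *= len(val_sets)         # [32] -> [32, 32, 32]
print(list(zip(val_sets, val_batch_size)))  # each task paired with a batch size

Without the broadcast, the zip in the iters line would silently drop every task after the first, since zip stops at its shortest argument.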
@@ -82,10 +84,11 @@ def mult(ps):
     model.eval()
     with torch.no_grad():
         for task, it in iters:
+            print(task)
             prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.txt')
             answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.gold.txt')
             results_file_name = answer_file_name.replace('gold', 'results')
-            if 'sql' in task:
+            if 'sql' in task or 'squad' in task:
                 ids_file_name = answer_file_name.replace('gold', 'ids')
             if os.path.exists(prediction_file_name):
                 print('** ', prediction_file_name, ' already exists -- this is where predictions are stored **')
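
For readers tracing the output layout: os.path.splitext strips the checkpoint extension, so all per-task files land in a directory named after the checkpoint. A small sketch with invented paths:

import os

best_checkpoint = '/decaNLP/results/best.pth'    # hypothetical checkpoint path
evaluate, task = 'validation', 'squad'
base = os.path.splitext(best_checkpoint)[0]      # '/decaNLP/results/best'
prediction_file_name = os.path.join(base, evaluate, task + '.txt')
answer_file_name = os.path.join(base, evaluate, task + '.gold.txt')
print(prediction_file_name)                      # /decaNLP/results/best/validation/squad.txt
print(answer_file_name.replace('gold', 'ids'))   # /decaNLP/results/best/validation/squad.ids.txt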
@@ -104,14 +107,15 @@ def mult(ps):
             if not os.path.exists(prediction_file_name) or args.overwrite_predictions:
                 with open(prediction_file_name, 'w') as prediction_file:
                     predictions = []
-                    wikisql_ids = []
+                    ids = []
                     for batch_idx, batch in enumerate(it):
                         _, p = model(batch)
                         p = field.reverse(p)
                         for i, pp in enumerate(p):
                             if 'sql' in task:
-                                wikisql_id = int(batch.wikisql_id[i])
-                                wikisql_ids.append(wikisql_id)
+                                ids.append(int(batch.wikisql_id[i]))
+                            if 'squad' in task:
+                                ids.append(it.dataset.q_ids[int(batch.squad_id[i])])
                             prediction_file.write(pp + '\n')
                             predictions.append(pp)
             else:
@@ -120,9 +124,14 @@ def mult(ps):
 
                 if 'sql' in task:
                     with open(ids_file_name, 'w') as id_file:
-                        for i in wikisql_ids:
+                        for i in ids:
                             id_file.write(json.dumps(i) + '\n')
 
+
+                if 'squad' in task:
+                    with open(ids_file_name, 'w') as id_file:
+                        for i in ids:
+                            id_file.write(i + '\n')
 
             def from_all_answers(an):
                 return [it.dataset.all_answers[sid] for sid in an.tolist()]
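
The squad branch above recovers official question ids in two hops: batch.squad_id holds integer indices into the dataset, and the new q_ids list (added in generic.py below) maps each index back to the SQuAD question-id string. A toy illustration with invented ids:

q_ids = ['5733be284776f41900661182', '5733be284776f4190066117f']  # invented ids
batch_squad_id = [1, 0]                  # indices as they might appear in a batch
ids = [q_ids[int(i)] for i in batch_squad_id]
print(ids)  # ['5733be284776f4190066117f', '5733be284776f41900661182']

Note that the two id writers differ: wikisql ids are ints serialized with json.dumps, while squad ids are already strings and are written as-is.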

@@ -165,7 +174,7 @@ def get_args():
     parser = ArgumentParser()
     parser.add_argument('--path', required=True)
     parser.add_argument('--evaluate', type=str, required=True)
-    parser.add_argument('--tasks', default=['wikisql', 'woz.en', 'cnn_dailymail', 'iwslt.en.de', 'zre', 'srl', 'squad', 'sst', 'multinli.in.out', 'schema'], nargs='+')
+    parser.add_argument('--tasks', default=['squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], nargs='+')
     parser.add_argument('--gpus', default=[0], nargs='+', type=int, help='a list of gpus that can be used (multi-gpu currently WIP)')
     parser.add_argument('--seed', default=123, type=int, help='Random seed.')
     parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.')
@@ -180,7 +189,7 @@ def get_args():
 
     with open(os.path.join(args.path, 'config.json')) as config_file:
        config = json.load(config_file)
-    retrieve = ['model', 'val_batch_size',
+    retrieve = ['model',
                 'transformer_layers', 'rnn_layers', 'transformer_hidden',
                 'dimension', 'load', 'max_val_context_length', 'val_batch_size',
                 'transformer_heads', 'max_output_length', 'max_generative_vocab',
8 changes: 5 additions & 3 deletions text/torchtext/datasets/generic.py
@@ -211,10 +211,10 @@ def __init__(self, path, field, subsample=None, **kwargs):
         fields = [(x, field) for x in self.fields]
         cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
 
-        examples, all_answers = [], []
+        examples, all_answers, q_ids = [], [], []
         if os.path.exists(cache_name):
             print(f'Loading cached data from {cache_name}')
-            examples, all_answers = torch.load(cache_name)
+            examples, all_answers, q_ids = torch.load(cache_name)
         else:
             with open(os.path.expanduser(path)) as f:
                 squad = json.load(f)['data']
@@ -226,6 +226,7 @@ def __init__(self, path, field, subsample=None, **kwargs):
                         qas = paragraph['qas']
                         for qa in qas:
                             question = ' '.join(qa['question'].split())
+                            q_ids.append(qa['id'])
                             squad_id = len(all_answers)
                             context_question = get_context_question(context, question)
                             if len(qa['answers']) == 0:
@@ -303,7 +304,7 @@ def __init__(self, path, field, subsample=None, **kwargs):
 
             os.makedirs(os.path.dirname(cache_name), exist_ok=True)
             print(f'Caching data to {cache_name}')
-            torch.save((examples, all_answers), cache_name)
+            torch.save((examples, all_answers, q_ids), cache_name)
 
 
         FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
@@ -315,6 +316,7 @@ def __init__(self, path, field, subsample=None, **kwargs):
 
         super(SQuAD, self).__init__(examples, fields, **kwargs)
         self.all_answers = all_answers
+        self.q_ids = q_ids
 
 
     @classmethod
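
One practical caveat: caches written before this commit were saved as two-tuples, so the new three-way unpack in torch.load will raise a ValueError on a stale cache. Deleting the .cache directory resolves it; alternatively, a defensive load along these lines would work (this guard is not part of the commit, just a sketch with a hypothetical cache path):

import torch

cache_name = '.cache/train-v1.1.json/None'  # hypothetical cache path
cached = torch.load(cache_name)
if len(cached) == 3:
    examples, all_answers, q_ids = cached
else:                                       # old (examples, all_answers) cache
    examples, all_answers = cached
    q_ids = []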
