Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
kennymckormick committed Mar 26, 2024
1 parent 11cca06 commit 9ad6415
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,25 @@ def main():
keys=[x[0] for x in tups])
else:
sub_tups = tups[rank::world_size]
sub_out_file = f'results/{model_name}_{dname}_{rank}.pkl'
sub_res = {}
for t in tqdm(sub_tups):
index, prompt = t
sub_res[index] = model.generate(prompt)
import portalocker

with portalocker.Lock(out_file, timeout=20) as fh:
res = load(out_file)
res.update(sub_res)
dump(res, out_file)
fh.flush()
os.fsync(fh.fileno())
dump(sub_res, sub_out_file)

if world_size > 1:
dist.barrier()

if rank == 0:
res = {}
for i in range(world_size):
sub_out_file = f'results/{model_name}_{dname}_{i}.pkl'
res.update(load(sub_out_file))
if osp.exists(out_file):
res.update(load(out_file))
dump(res, out_file)

res = load(out_file)
meta['prediction'] = [res[k] for k in meta['index']]
dump(meta, f'results/{model_name}_{dname}.xlsx')
Expand All @@ -103,6 +106,7 @@ def main():
acc = dataset.evaluate(meta)
results[f'{model_name}_{dname}'] = acc
dump(results, RESULT_FILE)
shutil.remove(sub_out_file)

if __name__ == '__main__':
main()

0 comments on commit 9ad6415

Please sign in to comment.