-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path extract_nlvr.py
87 lines (76 loc) · 2.62 KB
/
extract_nlvr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import json
import os
import yaml
from tqdm import tqdm
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Extract prompts YAML file from raw NLVR dataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "nlvr_metadata",
        type=str,
        help="Path to the NLVR metadata file (e.g., test1.json)",
    )
    parser.add_argument(
        "data_dir",
        type=str,
        help="Path to the directory containing the NLVR images",
    )
    parser.add_argument(
        "prompt_file",
        type=str,
        help="Path to the output YAML file",
    )
    return parser.parse_args()


def _load_metadata(path):
    """Read a JSON-lines file into a list of records, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        # A stray blank line (e.g. a double newline at EOF) would otherwise
        # make json.loads raise JSONDecodeError.
        return [json.loads(line) for line in f if line.strip()]


def _group_pairs_by_sentence(raw_data, data_dir):
    """Group samples by sentence id, keeping only pairs whose images exist.

    Args:
        raw_data: Parsed metadata records. Each record is read for its
            "identifier", "sentence" and "label" keys.
        data_dir: Directory in which the pair images are looked up.

    Returns:
        A tuple ``(prompts_by_sentence, extracted_pairs)``. The first item
        maps ``"<split>-<set_id>-<sentence_id>"`` to a dict with keys
        "sentence" and "pairs". The second item is the number of pairs kept.

    Raises:
        ValueError: If an identifier does not have exactly four "-"-separated
            parts, or if two samples sharing a sentence id carry different
            sentence text.
    """
    prompts_by_sentence = {}
    extracted_pairs = 0
    for sample in tqdm(raw_data):
        # The unpack requires exactly four "-"-separated fields.
        split_name, set_id, pair_id, sentence_id = sample["identifier"].split("-")
        # pair_id is dropped so that all pairs of one sentence share a key.
        sentence_uid = f"{split_name}-{set_id}-{sentence_id}"
        # NOTE(review): the entry is registered before the image check, so a
        # sentence whose images are all missing is still emitted with an
        # empty "pairs" list. Confirm whether that is intended.
        if sentence_uid not in prompts_by_sentence:
            prompts_by_sentence[sentence_uid] = dict(
                sentence=sample["sentence"],
                pairs=[],
            )

        # Equivalent to the identifier with its last field stripped.
        image_prefix = f"{split_name}-{set_id}-{pair_id}"
        left_image = f"{image_prefix}-img0.png"
        right_image = f"{image_prefix}-img1.png"
        if not os.path.exists(os.path.join(data_dir, left_image)):
            continue
        if not os.path.exists(os.path.join(data_dir, right_image)):
            continue

        # Raise rather than assert: this validates input data and must not
        # disappear when Python runs with -O.
        if prompts_by_sentence[sentence_uid]["sentence"] != sample["sentence"]:
            raise ValueError(
                f"Sentence mismatch for {sentence_uid} for pair {pair_id}"
            )

        extracted_pairs += 1
        prompts_by_sentence[sentence_uid]["pairs"].append(
            dict(
                id=int(pair_id),
                left_image=left_image,
                right_image=right_image,
                label=sample["label"].lower() == "true",
            )
        )
    return prompts_by_sentence, extracted_pairs


def _write_prompts(prompts, path):
    """Dump prompts as YAML to path, creating the parent directory if needed."""
    out_dir = os.path.dirname(path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        yaml.dump(prompts, f, sort_keys=False)


def main():
    """Convert raw NLVR metadata into a prompts YAML file.

    Steps:
        1. Read the JSON-lines metadata file given on the command line.
        2. Group samples by sentence, keeping only image pairs whose two
           images exist under ``data_dir``.
        3. Write one prompt entry per sentence to ``prompt_file``.

    Each prompt entry has the keys "id", "prompt" and "pairs".

    Raises:
        ValueError: If samples sharing a sentence id disagree on the sentence
            text, or if an identifier is malformed.
    """
    args = _parse_args()
    raw_data = _load_metadata(args.nlvr_metadata)
    prompts_by_sentence, extracted_pairs = _group_pairs_by_sentence(
        raw_data, args.data_dir
    )

    prompts = [
        dict(
            id=sentence_uid,
            prompt=dict(
                statement=sentence_data["sentence"],
            ),
            pairs=sentence_data["pairs"],
        )
        for sentence_uid, sentence_data in prompts_by_sentence.items()
    ]
    print(
        f"Extracted {len(prompts)} sentences with a total of {extracted_pairs} pairs."
    )
    _write_prompts(prompts, args.prompt_file)
if __name__ == "__main__":
    # Run the extraction only when executed as a script, not on import.
    main()