-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathaligner.py
43 lines (32 loc) · 1.17 KB
/
aligner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import webdataset as wds
import os
print('Specify the tar file you want to align.')
tar_filepath = input() # 'watermark_ds_v1.tar.gz'
assert os.path.isfile(tar_filepath), 'The specified file ({}) is not a valid .tar(.gz) file.'.format(tar_filepath)
print('Specify the keys you want to keep in the aligend .tar(.gz) file (e.g., jpg,txt for jpg and txt).')
cols = input()
cols_set = set(cols.split(','))
ds = wds.WebDataset(tar_filepath)
values = {}
for col in cols_set:
values[col] = {}
for i, row in enumerate(ds):
remaining_keys = set(row.keys()).intersection(cols_set)
for key in remaining_keys:
values[key][row['__key__']] = row[key]
if i % 10000 == 0:
print('Processed {} rows.'.format(i), end='\r')
for col in cols_set:
print('Processed {} values for {}'.format(len(values[col].keys()), col))
sink = wds.TarWriter(tar_filepath + '_aligned', encoder=False)
sink.write
for i, key in enumerate(values[list(cols_set)[0]].keys()):
sample = {
"__key__": key,
}
for col in cols_set:
sample[col] = values[col][key]
sink.write(sample)
if i % 10000 == 0:
print('Saved {} aligned rows.'.format(i))
sink.close()