forked from salesforce/BLIP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpretrain_dataset.py
59 lines (41 loc) · 1.59 KB
/
pretrain_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
import os
import random
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from data.utils import pre_caption
import os,glob
class pretrain_dataset(Dataset):
    """Image-caption dataset assembled from JSON annotation files.

    The annotation list is the concatenation of a fixed set of annotations
    (loaded once from ``ann_file``) and, optionally, one shard of annotations
    taken from a directory of ``*.json`` files (``laion_path``). The active
    shard can be swapped with :meth:`reload_laion`.

    Each annotation is indexed with the keys ``'image'`` (passed to
    ``Image.open``) and ``'caption'`` (passed to ``pre_caption``).
    """

    def __init__(self, ann_file, laion_path, transform):
        """Load the fixed annotations and, if configured, the first shard.

        Args:
            ann_file: Iterable of paths to JSON files, each holding a list of
                annotations. A single path given as a ``str`` is also
                accepted and treated as a one-element list.
            laion_path: Directory containing ``*.json`` shard files, or a
                falsy value to disable shard loading.
            transform: Callable applied to every RGB ``PIL.Image`` returned
                by ``__getitem__``.

        Raises:
            FileNotFoundError: If ``laion_path`` is set but contains no
                ``*.json`` files.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(ann_file, str):
            ann_file = [ann_file]

        self.ann_pretrain = []
        for path in ann_file:
            self.ann_pretrain += self._load_json(path)

        self.laion_path = laion_path
        self.transform = transform

        if self.laion_path:
            # glob returns paths in arbitrary order; sort so that the
            # epoch -> shard mapping is the same in every process and run.
            self.laion_files = sorted(
                glob.glob(os.path.join(laion_path, '*.json'))
            )
            if not self.laion_files:
                raise FileNotFoundError(
                    'no *.json files found in ' + str(laion_path)
                )
            self.reload_laion(0)
        else:
            self.laion_files = []
            self.ann_laion = []
            self.annotation = self.ann_pretrain

    @staticmethod
    def _load_json(path):
        """Announce and parse one JSON file, closing it afterwards."""
        print('loading '+path)
        with open(path, 'r') as handle:
            return json.load(handle)

    def reload_laion(self, epoch):
        """Replace the active shard with the one selected by ``epoch``.

        The shard index is ``epoch % len(self.laion_files)``. When no shard
        files are configured this is a no-op.
        """
        if not self.laion_files:
            return
        shard = self.laion_files[epoch % len(self.laion_files)]
        self.ann_laion = self._load_json(shard)
        self.annotation = self.ann_pretrain + self.ann_laion

    def __len__(self):
        """Return the number of currently active annotations."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return ``(transformed_image, caption)`` for annotation ``index``."""
        ann = self.annotation[index]

        # convert() reads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(ann['image']) as raw:
            image = raw.convert('RGB')
        image = self.transform(image)

        # NOTE(review): 30 is presumably a maximum word count for the
        # caption; confirm against data.utils.pre_caption.
        caption = pre_caption(ann['caption'], 30)

        return image, caption