image_bucket

An image-bucket dataset that batches images by resolution, mainly for multi-resolution training of Stable Diffusion (SD).
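To illustrate the general idea of resolution bucketing (a rough sketch only, not this library's actual algorithm; the bucket list below is made up): each image is assigned to the bucket whose aspect ratio is closest to its own, so every batch contains images of a single resolution.

from collections import defaultdict

# Hypothetical bucket resolutions; the real list depends on your training setup
BUCKETS = [(512, 512), (576, 448), (448, 576)]

def nearest_bucket(width, height):
    # Pick the bucket whose aspect ratio is closest to the image's
    ratio = width / height
    return min(BUCKETS, key=lambda wh: abs(wh[0] / wh[1] - ratio))

def group_by_bucket(sizes):
    # sizes: list of (width, height); returns bucket resolution -> sample indices
    groups = defaultdict(list)
    for idx, (w, h) in enumerate(sizes):
        groups[nearest_bucket(w, h)].append(idx)
    return groups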

1. Install

pip install git+https://github.com/aidings/image_bucket.git

2. Example

  1. Download 1.png and samples.jsonl (the expected JSONL format is sketched right after this list).
  2. Run the example code below.
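As a minimal sketch of the expected input, each line of samples.jsonl is a JSON object with a file_path and a caption field (field names taken from the example code; the caption text here is only a placeholder):

import json

samples = [
    {"file_path": "./1.png", "caption": "a placeholder caption"},  # caption is illustrative
]
with open("samples.jsonl", "w") as f:
    for s in samples:
        f.write(json.dumps(s, ensure_ascii=False) + "\n")

With samples.jsonl in place, the example code from step 2 is: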
import json
import torch
from typing import List
from image_bucket import ImageBuckets, BuckNode
from PIL import Image
from torchvision import transforms

class SDDataset(ImageBuckets):
    def __init__(self, json_path, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = tokenizer
        self.data = self.__load(json_path)

        # Convert PIL images to tensors normalized to [-1, 1]
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])

    def __load(self, json_path):
        # Read the JSONL file and register each image's size with the bucket manager
        data = []
        with open(json_path, 'r') as f:
            idx = 0
            for line in f:
                jdict = json.loads(line.strip())
                data.append(jdict)
                file_path = jdict['file_path']
                image = Image.open(file_path)  # opened only to read the size
                w, h = image.size
                self.inject(BuckNode(w, h, idx))  # assign sample idx to a resolution bucket
                idx += 1
        return data

    def __get_input_ids(self, caption):
        # Tokenize the caption to a fixed length of 77 tokens (CLIP's context length)
        input_ids = self.tokenizer(
            caption, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
        ).input_ids

        return input_ids

    def transforms(self, bidxs: List[int], bucket, resolution):
        # Receives the sample indices of one bucket batch; resize every image
        # to the bucket's resolution and tokenize its caption
        image_tensors = []
        input_ids = []
        for idx in bidxs:
            jdict = self.data[idx]
            file_path = jdict['file_path']
            tokens = self.__get_input_ids(jdict['caption'])
            image = Image.open(file_path).convert('RGB')
            image = image.resize(resolution)
            image_tensor = self.trans(image)
            image_tensors.append(image_tensor)
            input_ids.append(tokens)

        # Shapes: (B, 3, H, W) and (B, 1, 77); use torch.cat instead if (B, 77) token ids are needed
        return torch.stack(image_tensors), torch.stack(input_ids)

    def collate_fn(self, batch):
        # Each dataset item is already a pre-batched bucket, so just unwrap it
        image_tensors, input_ids = batch[0]
        return image_tensors, input_ids


if __name__ == '__main__':
    from transformers import CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    sd = SDDataset('./samples.jsonl', tokenizer)
    sd.make(4, shuffle=True)  # group samples into bucket batches of 4

    # batch_size=1 because each dataset item is already a bucket batch
    dataloader = torch.utils.data.DataLoader(sd, batch_size=1, shuffle=True, num_workers=4, collate_fn=sd.collate_fn)

    for data in dataloader:
        print(data[0].shape, data[1].shape)
