Skip to content

Commit

Permalink
Move the training example
Browse files Browse the repository at this point in the history
  • Loading branch information
anton-l committed Jun 14, 2022
1 parent 418888a commit bb30664
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := tests src utils
check_dirs := examples tests src utils

modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,23 @@
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import InterpolationMode, CenterCrop, Compose, Lambda, RandomRotation, RandomHorizontalFlip, Resize, ToTensor
from torchvision.transforms import (
Compose,
InterpolationMode,
Lambda,
RandomCrop,
RandomHorizontalFlip,
RandomVerticalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup


def set_seed(seed):
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
Expand All @@ -33,7 +42,7 @@ def set_seed(seed):
dropout=0.0,
num_res_blocks=2,
resamp_with_conv=True,
resolution=32
resolution=32,
)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
Expand All @@ -44,15 +53,15 @@ def set_seed(seed):

augmentations = Compose(
[
RandomHorizontalFlip(),
RandomRotation(15, interpolation=InterpolationMode.BILINEAR, fill=1),
Resize(32, interpolation=InterpolationMode.BILINEAR),
CenterCrop(32),
RandomHorizontalFlip(),
RandomVerticalFlip(),
RandomCrop(32),
ToTensor(),
Lambda(lambda x: x * 2 - 1),
]
)
dataset = load_dataset("huggan/pokemon", split="train")
dataset = load_dataset("huggan/flowers-102-categories", split="train")


def transforms(examples):
Expand Down Expand Up @@ -127,5 +136,5 @@ def transforms(examples):
image_pil = PIL.Image.fromarray(image_processed[0])

# save image
pipeline.save_pretrained("./poke-ddpm")
image_pil.save(f"./poke-ddpm/test_{epoch}.png")
pipeline.save_pretrained("./flowers-ddpm")
image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
2 changes: 1 addition & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import torch

from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
Expand Down

0 comments on commit bb30664

Please sign in to comment.