Skip to content

Commit

Permalink
Fix for data_loader in pl version 0.7
Browse files Browse the repository at this point in the history
  • Loading branch information
AntixK committed Mar 12, 2020
1 parent 6d6a3f3 commit 90172be
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
10 changes: 6 additions & 4 deletions experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from torch import optim
from models import BaseVAE
from models.types_ import *
from utils import data_loader
import pytorch_lightning as pl
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader



class VAEXperiment(pl.LightningModule):

def __init__(self,
Expand Down Expand Up @@ -50,7 +50,7 @@ def validation_step(self, batch, batch_idx, optimizer_idx = 0):

results = self.forward(real_img, labels = labels)
val_loss = self.model.loss_function(*results,
M_N = self.params['batch_size']/ self.num_train_imgs,
M_N = self.params['batch_size']/ self.num_val_imgs,
optimizer_idx = optimizer_idx,
batch_idx = batch_idx)

Expand Down Expand Up @@ -132,7 +132,7 @@ def configure_optimizers(self):
except:
return optims

@pl.data_loader
@data_loader
def train_dataloader(self):
transform = self.data_transforms()

Expand All @@ -150,7 +150,7 @@ def train_dataloader(self):
shuffle = True,
drop_last=True)

@pl.data_loader
@data_loader
def val_dataloader(self):
transform = self.data_transforms()

Expand All @@ -162,8 +162,10 @@ def val_dataloader(self):
batch_size= 144,
shuffle = True,
drop_last=True)
self.num_val_imgs = len(self.sample_dataloader)
else:
raise ValueError('Undefined dataset type')

return self.sample_dataloader

def data_transforms(self):
Expand Down
22 changes: 22 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytorch_lightning as pl


## Utils to handle newer PyTorch Lightning changes from version 0.6
## ==================================================================================================== ##


def data_loader(fn):
"""
Decorator to handle the deprecation of data_loader from 0.7
:param fn: User defined data loader function
:return: A wrapper for the data_loader function
"""

def func_wrapper(self):
try: # Works for version 0.6.0
return pl.data_loader(fn)(self)

except: # Works for version > 0.6.0
return fn(self)

return func_wrapper

0 comments on commit 90172be

Please sign in to comment.