Skip to content

Commit

Permalink
Merge pull request optuna#2983 from nzw0301/simplified-pl-test
Browse files Browse the repository at this point in the history
Simplify the DDP model definition in the test of pytorch-lightning
  • Loading branch information
himkt authored Oct 9, 2021
2 parents d9dcb3b + fe7163a commit 36a2f74
Showing 1 changed file with 3 additions and 34 deletions.
37 changes: 3 additions & 34 deletions tests/integration_tests/test_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def validation_step( # type: ignore
return {"validation_accuracy": accuracy}

def validation_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]) -> None:
if not len(outputs):
return

accuracy = sum(x["validation_accuracy"] for x in outputs) / len(outputs)
self.log("accuracy", accuracy)
Expand All @@ -69,24 +71,10 @@ def _generate_dummy_dataset(self) -> torch.utils.data.DataLoader:
return torch.utils.data.DataLoader(dataset, batch_size=1)


class ModelDDP(pl.LightningModule):
class ModelDDP(Model):
def __init__(self) -> None:

super().__init__()
self._model = nn.Sequential(nn.Linear(4, 8))

def forward(self, data: torch.Tensor) -> torch.Tensor: # type: ignore

return self._model(data)

def training_step( # type: ignore
self, batch: List[torch.Tensor], batch_nb: int
) -> Dict[str, torch.Tensor]:

data, target = batch
output = self.forward(data)
loss = F.nll_loss(output, target)
return {"loss": loss}

def validation_step( # type: ignore
self, batch: List[torch.Tensor], batch_nb: int
Expand All @@ -104,25 +92,6 @@ def validation_step( # type: ignore

self.log("accuracy", accuracy, sync_dist=True)

def configure_optimizers(self) -> torch.optim.Optimizer:

return torch.optim.SGD(self._model.parameters(), lr=1e-2)

def train_dataloader(self) -> torch.utils.data.DataLoader:

return self._generate_dummy_dataset()

def val_dataloader(self) -> torch.utils.data.DataLoader:

return self._generate_dummy_dataset()

def _generate_dummy_dataset(self) -> torch.utils.data.DataLoader:

data = torch.zeros(3, 4, dtype=torch.float32)
target = torch.zeros(3, dtype=torch.int64)
dataset = torch.utils.data.TensorDataset(data, target)
return torch.utils.data.DataLoader(dataset, batch_size=1)


def test_pytorch_lightning_pruning_callback() -> None:
def objective(trial: optuna.trial.Trial) -> float:
Expand Down

0 comments on commit 36a2f74

Please sign in to comment.