forked from optuna/optuna
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch_lightning_simple.py
203 lines (156 loc) · 6.81 KB
/
pytorch_lightning_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
Optuna example that optimizes multi-layer perceptrons using PyTorch Lightning.
In this example, we optimize the validation accuracy of hand-written digit recognition using
PyTorch Lightning, and MNIST. We optimize the neural network architecture. As it is too time
consuming to use the whole MNIST dataset, we here use a small subset of it.
We have the following two ways to execute this example:
(1) Execute this code directly. Pruning can be turned on and off with the `--pruning` argument.
$ python pytorch_lightning_simple.py [--pruning]
(2) Execute through CLI. Pruning is enabled automatically.
$ STUDY_NAME=`optuna create-study --direction maximize --storage sqlite:///example.db`
$ optuna study optimize pytorch_lightning_simple.py objective --n-trials=100 --study \
$STUDY_NAME --storage sqlite:///example.db
"""
import argparse
import os
import pkg_resources
import shutil
import pytorch_lightning as pl
from pytorch_lightning.logging import LightningLoggerBase
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import optuna
from optuna.integration import PyTorchLightningPruningCallback
# Refuse to run against PyTorch Lightning releases older than the one this example targets.
_installed_pl_version = pkg_resources.parse_version(pl.__version__)
_minimum_pl_version = pkg_resources.parse_version("0.6.0")
if _installed_pl_version < _minimum_pl_version:
    raise RuntimeError("PyTorch Lightning>=0.6.0 is required for this example.")

# Passed to the trainer as `val_percent_check`.
PERCENT_TEST_EXAMPLES = 0.1
# Batch size used by both the training and the validation data loader.
BATCHSIZE = 128
# Output width of the final linear layer of `Net`.
CLASSES = 10
# Passed to the trainer as `max_epochs`.
EPOCHS = 10
# Current working directory; used as the root directory of the MNIST dataset.
DIR = os.getcwd()
# Directory holding the per-trial checkpoints; removed at the end of the script.
MODEL_DIR = os.path.join(DIR, "result")
class DictLogger(LightningLoggerBase):
    """In-memory PyTorch Lightning logger that keeps every logged metrics `dict` in a list."""

    def __init__(self, version):
        """Create an empty logger.

        Args:
            version: Value reported through the `version` property.
        """
        super().__init__()
        self._version = version
        # Metric dictionaries, in the order they were handed to `log_metrics`.
        self.metrics = []

    @property
    def version(self):
        """Return the version given at construction time."""
        return self._version

    def log_metrics(self, metric, step=None):
        """Append `metric` to `self.metrics`; `step` is accepted but not used."""
        self.metrics.append(metric)
class Net(nn.Module):
    """Multi-layer perceptron whose depth, layer widths and dropout rate come from a trial."""

    def __init__(self, trial):
        """Build the network from hyperparameters suggested by `trial`.

        Args:
            trial: Object providing `suggest_int`, `suggest_uniform` and `suggest_loguniform`.
        """
        super().__init__()

        # We optimize the number of layers, the hidden units in each layer and the dropout rate.
        n_layers = trial.suggest_int("n_layers", 1, 3)
        dropout_rate = trial.suggest_uniform("dropout", 0.2, 0.5)

        self.layers = []
        self.dropouts = []
        in_features = 28 * 28
        for i in range(n_layers):
            # The suggested width is a float; it is truncated to an integer unit count.
            out_features = int(trial.suggest_loguniform("n_units_l{}".format(i), 4, 128))
            self.layers.append(nn.Linear(in_features, out_features))
            self.dropouts.append(nn.Dropout(dropout_rate))
            in_features = out_features
        # Output layer: one unit per class. It has no matching dropout entry.
        self.layers.append(nn.Linear(in_features, CLASSES))

        # Assigning the layers and dropouts as attributes (PyTorch requirement): modules held
        # only in a plain Python list are not registered with the parent module.
        for idx, linear in enumerate(self.layers):
            setattr(self, "fc{}".format(idx), linear)
        for idx, drop in enumerate(self.dropouts):
            setattr(self, "drop{}".format(idx), drop)

    def forward(self, data):
        """Flatten `data` to rows of 28 * 28 values and return log-softmax scores over dim 1."""
        hidden = data.view(-1, 28 * 28)
        # `zip` stops at the shorter list, so only the hidden layers are visited here.
        for linear, drop in zip(self.layers, self.dropouts):
            hidden = drop(F.relu(linear(hidden)))
        logits = self.layers[-1](hidden)
        return F.log_softmax(logits, dim=1)
class LightningNet(pl.LightningModule):
    """Lightning module that trains and validates a trial-configured `Net` on MNIST."""

    def __init__(self, trial):
        """Create the wrapped `Net` from the hyperparameters suggested by `trial`."""
        super().__init__()
        # Be careful not to overwrite `pl.LightningModule` attributes such as `self.model`.
        self._model = Net(trial)

    def forward(self, data):
        """Run `data` through the wrapped network."""
        return self._model(data)

    def training_step(self, batch, batch_nb):
        """Return the negative log-likelihood loss of one batch under the `"loss"` key."""
        images, labels = batch
        log_probs = self.forward(images)
        return {"loss": F.nll_loss(log_probs, labels)}

    def validation_step(self, batch, batch_nb):
        """Return the fraction of correctly classified samples in one batch."""
        images, labels = batch
        predicted = self.forward(images).argmax(dim=1, keepdim=True)
        n_correct = predicted.eq(labels.view_as(predicted)).sum().item()
        return {"validation_accuracy": n_correct / images.size(0)}

    def validation_end(self, outputs):
        """Average the per-batch accuracies returned by `validation_step`."""
        batch_accuracies = [output["validation_accuracy"] for output in outputs]
        mean_accuracy = sum(batch_accuracies) / len(outputs)
        # Pass the accuracy to the `DictLogger` via the `'log'` key.
        return {"log": {"accuracy": mean_accuracy}}

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped network's parameters."""
        optimizer = Adam(self._model.parameters())
        return optimizer

    @pl.data_loader
    def train_dataloader(self):
        """Return a shuffled loader over the MNIST training split (downloaded into `DIR`)."""
        train_set = datasets.MNIST(DIR, train=True, download=True, transform=transforms.ToTensor())
        return torch.utils.data.DataLoader(train_set, batch_size=BATCHSIZE, shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
        """Return an unshuffled loader over the MNIST test split (downloaded into `DIR`)."""
        val_set = datasets.MNIST(DIR, train=False, download=True, transform=transforms.ToTensor())
        return torch.utils.data.DataLoader(val_set, batch_size=BATCHSIZE, shuffle=False)
def objective(trial):
    """Train one trial-configured `LightningNet` and return its last logged accuracy.

    Args:
        trial: Optuna trial used to suggest hyperparameters; its `number` names the checkpoint
            directory and is used as the logger version.

    Returns:
        The `"accuracy"` entry of the last metrics dictionary recorded by the logger.
    """
    # PyTorch Lightning will try to restore model parameters from previous trials if checkpoint
    # filenames match. Therefore, the filenames for each trial must be made unique.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        os.path.join(MODEL_DIR, "trial_{}".format(trial.number)), monitor="accuracy"
    )

    # The default logger in PyTorch Lightning writes to event files to be consumed by
    # TensorBoard. We create a simple logger instead that holds the log in memory so that the
    # final accuracy can be obtained after optimization. When using the default logger, the
    # final accuracy could be stored in an attribute of the `Trainer` instead.
    logger = DictLogger(trial.number)

    trainer = pl.Trainer(
        logger=logger,
        val_percent_check=PERCENT_TEST_EXAMPLES,
        checkpoint_callback=checkpoint_callback,
        max_epochs=EPOCHS,
        # An integer `gpus` is a device count, so the previous value of 0 requested CPU-only
        # training even when CUDA was available. Request one GPU in that case instead.
        gpus=1 if torch.cuda.is_available() else None,
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="accuracy"),
    )

    model = LightningNet(trial)
    trainer.fit(model)

    return logger.metrics[-1]["accuracy"]
if __name__ == "__main__":
    # Script entry point: run the study, report the best trial, then delete the checkpoints.
    parser = argparse.ArgumentParser(description="PyTorch Lightning example.")
    parser.add_argument(
        "--pruning",
        "-p",
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising "
        "trials at the early stages of training.",
    )
    args = parser.parse_args()

    pruner = optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()

    study = optuna.create_study(direction="maximize", pruner=pruner)
    try:
        study.optimize(objective, n_trials=100, timeout=600)

        print("Number of finished trials: {}".format(len(study.trials)))

        print("Best trial:")
        trial = study.best_trial

        print(" Value: {}".format(trial.value))

        print(" Params: ")
        for key, value in trial.params.items():
            print(" {}: {}".format(key, value))
    finally:
        # Clean up the per-trial checkpoints even if optimization or reporting raised, and do
        # not fail when the directory was never created.
        if os.path.isdir(MODEL_DIR):
            shutil.rmtree(MODEL_DIR)