-
-
Notifications
You must be signed in to change notification settings - Fork 647
/
Copy pathmain.py
341 lines (275 loc) · 12.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
from datetime import datetime
from pathlib import Path
import fire
import torch
import torch.nn as nn
import torch.optim as optim
import utils
from torch.amp import autocast
from torch.cuda.amp import GradScaler
import ignite
import ignite.distributed as idist
from ignite.contrib.engines import common
from ignite.engine import create_supervised_evaluator, Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine, PiecewiseLinear
from ignite.metrics import Accuracy, Loss
from ignite.utils import manual_seed, setup_logger
def training(local_rank, config):
    """Run quantization-aware training of a CIFAR10 model in one (possibly distributed) process.

    ``run`` launches this function through ``idist.Parallel.run(training, config)``. Steps:

    * seed RNGs with ``config["seed"] + rank``;
    * on rank 0: create a timestamped output sub-directory under ``config["output_path"]``
      (rewriting ``config["output_path"]`` to point at it), record the CUDA device name when
      running on CUDA, and optionally register the run with ClearML;
    * build dataflow, model, optimizer, criterion, LR scheduler and trainer;
    * evaluate on the train and test sets every ``validate_every`` epochs and at completion;
    * on rank 0: log to TensorBoard;
    * keep the 2 best checkpoints by test accuracy once ``trainer.state.epoch > num_epochs // 2``;
    * train for ``config["num_epochs"]`` epochs.

    Args:
        local_rank (int): process rank on the current node; used for logger setup and to query
            the CUDA device name.
        config (dict): run configuration assembled by ``run``. Mutated in place: gains
            "num_iters_per_epoch" and, on rank 0, an updated "output_path" and possibly
            "cuda device name".

    Raises:
        Exception: anything raised by ``trainer.run`` is logged and re-raised unchanged.
    """
    rank = idist.get_rank()
    # Offset the seed by rank so each process gets a different (but reproducible) RNG stream.
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-QAT-Training", distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        # exist_ok avoids the check-then-create race of a separate exists() test.
        output_path.mkdir(parents=True, exist_ok=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            from clearml import Task

            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            # Log hyper parameters
            hyper_params = [
                "model",
                "batch_size",
                "momentum",
                "weight_decay",
                "num_epochs",
                "learning_rate",
                "num_warmup_epochs",
            ]
            task.connect({k: config[k] for k in hyper_params})
    # NOTE(review): on ranks != 0, config["output_path"] keeps the base path, so the save
    # handlers built below point there; this relies on them only writing from rank 0 — confirm.

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    # initialize() reads this to express the LR schedule milestones in iterations.
    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly similar roles:
    # - `evaluator` will save the best model based on validation score
    # - `train_evaluator` only reports metrics on the training set
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED, run_validation)

    # Only rank 0 owns a TensorBoard logger; other ranks keep None so cleanup is uniform.
    tb_logger = None
    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store 2 best models by test accuracy, considered only once epoch > num_epochs // 2:
    best_model_handler = Checkpoint(
        {"model": model},
        get_save_handler(config),
        filename_prefix="best",
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_accuracy",
        score_function=Checkpoint.get_default_score_fn("Accuracy"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
    )

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        logger.exception("")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        # Close the TensorBoard logger on success *and* on failure.
        if tb_logger is not None:
            tb_logger.close()
def run(
    seed=543,
    data_path="/tmp/cifar10",
    output_path="/tmp/output-cifar10/",
    model="resnet18_QAT_8b",
    batch_size=512,
    momentum=0.9,
    weight_decay=1e-4,
    num_workers=12,
    num_epochs=24,
    learning_rate=0.4,
    num_warmup_epochs=4,
    validate_every=3,
    checkpoint_every=1000,
    backend=None,
    resume_from=None,
    log_every_iters=15,
    nproc_per_node=None,
    with_clearml=False,
    with_amp=False,
    **spawn_kwargs,
):
    """Main entry to train a model on CIFAR10 dataset.

    Args:
        seed (int): random state seed to set. Default, 543.
        data_path (str): input dataset path. Default, "/tmp/cifar10".
        output_path (str): output path. Default, "/tmp/output-cifar10/".
        model (str): model name passed to ``utils.get_model`` to set up the model to train.
            Default, "resnet18_QAT_8b".
        batch_size (int): total batch size. Default, 512.
        momentum (float): optimizer's momentum. Default, 0.9.
        weight_decay (float): weight decay. Default, 1e-4.
        num_workers (int): number of workers in the data loader. Default, 12.
        num_epochs (int): number of epochs to train the model. Default, 24.
        learning_rate (float): peak of piecewise linear learning rate scheduler. Default, 0.4.
        num_warmup_epochs (int): number of warm-up epochs before learning rate decay. Must be
            strictly smaller than ``num_epochs``. Default, 4.
        validate_every (int): run model's validation every ``validate_every`` epochs. Default, 3.
        checkpoint_every (int): store training checkpoint every ``checkpoint_every`` iterations. Default, 1000.
        backend (str, optional): backend to use for distributed configuration. Possible values: None, "nccl", "xla-tpu",
            "gloo" etc. Default, None.
        resume_from (str, optional): path to checkpoint to use to resume the training from. Default, None.
        log_every_iters (int): argument to log batch loss every ``log_every_iters`` iterations.
            It can be 0 to disable it. Default, 15.
            NOTE(review): in this script only ``log_every_iters > 0`` is checked (it enables the
            "batch loss" output names); the value itself is not forwarded anywhere.
        nproc_per_node (int, optional): optional argument to setup number of processes per node. It is useful,
            when main python process is spawning training as child processes. Default, None.
        with_clearml (bool): if True, experiment ClearML logger is setup. Default, False.
        with_amp (bool): if True, enables native automatic mixed precision. Default, False.
        **spawn_kwargs: Other kwargs to spawn run in child processes: master_addr, master_port, node_rank, nnodes

    Raises:
        ValueError: if ``num_warmup_epochs >= num_epochs``.
    """
    # The LR schedule ramps up for `num_warmup_epochs` and decays until `num_epochs`,
    # so warm-up must end strictly before training does.
    if num_warmup_epochs >= num_epochs:
        raise ValueError(
            "num_epochs cannot be less than or equal to num_warmup_epochs, please increase num_epochs or decrease "
            "num_warmup_epochs"
        )

    # Snapshot every parameter as the run config. This must happen before any other local is
    # defined; dict() gives an independent copy rather than a view of the frame's locals.
    config = dict(locals())
    config.update(config["spawn_kwargs"])
    del config["spawn_kwargs"]

    spawn_kwargs["nproc_per_node"] = nproc_per_node

    with idist.Parallel(backend=backend, **spawn_kwargs) as parallel:
        parallel.run(training, config)
def get_dataflow(config):
    """Create the CIFAR10 train and test data loaders.

    Dataset construction runs inside ``idist.one_rank_first(local=True)``, so the first local
    rank builds the datasets before the other ranks on the node (presumably so any download or
    extraction happens only once per node). The loaders come from ``idist.auto_dataloader``,
    which adapts them to the active distributed configuration (nccl, gloo, xla-tpu, ...).

    Args:
        config (dict): reads "data_path", "batch_size" and "num_workers".

    Returns:
        tuple: ``(train_loader, test_loader)``. The train loader shuffles and drops the last
        incomplete batch. The test loader does not shuffle and uses twice the configured batch size.
    """
    with idist.one_rank_first(local=True):
        train_ds, test_ds = utils.get_train_test_datasets(config["data_path"])

    batch_size = config["batch_size"]
    num_workers = config["num_workers"]

    train_loader = idist.auto_dataloader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        drop_last=True,
    )
    test_loader = idist.auto_dataloader(
        test_ds,
        batch_size=2 * batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    return train_loader, test_loader
def initialize(config):
    """Build the model, optimizer, loss and learning-rate scheduler.

    The LR follows a piecewise-linear schedule over training iterations:
    0.0 at iteration 0, rising linearly to ``learning_rate`` at the end of the warm-up
    epochs, then falling linearly back to 0.0 at the end of ``num_epochs``.

    Args:
        config (dict): reads "model", "learning_rate", "momentum", "weight_decay",
            "num_iters_per_epoch", "num_warmup_epochs" and "num_epochs".

    Returns:
        tuple: ``(model, optimizer, criterion, lr_scheduler)``.
    """
    # Adapt model for distributed settings if configured.
    # NOTE(review): find_unused_parameters is forwarded to the distributed wrapper; presumably
    # needed because some parameters of the QAT model do not take part in every forward — confirm.
    net = idist.auto_model(utils.get_model(config["model"]), find_unused_parameters=True)

    sgd = optim.SGD(
        net.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )
    sgd = idist.auto_optim(sgd)

    loss_fn = nn.CrossEntropyLoss().to(idist.device())

    iters_per_epoch = config["num_iters_per_epoch"]
    peak_lr = config["learning_rate"]
    warmup_end_iter = iters_per_epoch * config["num_warmup_epochs"]
    last_iter = iters_per_epoch * config["num_epochs"]

    scheduler = PiecewiseLinear(
        sgd,
        param_name="lr",
        milestones_values=[(0, 0.0), (warmup_end_iter, peak_lr), (last_iter, 0.0)],
    )
    return net, sgd, loss_fn, scheduler
def log_metrics(logger, epoch, elapsed, tag, metrics):
    """Log the metrics of one evaluation run as a single multi-line INFO message.

    Args:
        logger: object exposing ``info(msg)``.
        epoch (int): trainer epoch the evaluation belongs to.
        elapsed (float): evaluation duration in seconds, printed with 2 decimals.
        tag (str): label of the evaluated split, e.g. "Train" or "Test".
        metrics (dict): metric name -> value; each pair is rendered on its own tab-indented line.
    """
    metric_lines = (f"\t{k}: {v}" for k, v in metrics.items())
    metrics_output = "\n".join(metric_lines)
    logger.info(f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}")
def log_basic_info(logger, config):
    """Log library versions, GPU/CUDA details, the full config and distributed settings.

    Args:
        logger: object exposing ``info(msg)``.
        config (dict): run configuration; every key/value pair is logged, and "model" is
            used in the banner line.
    """
    info = logger.info

    info(f"Quantization Aware Training {config['model']} on CIFAR10")
    info(f"- PyTorch version: {torch.__version__}")
    info(f"- Ignite version: {ignite.__version__}")

    if torch.cuda.is_available():
        # Import cudnn explicitly: `torch.backends.cudnn` cannot be pickled when
        # horovod (hvd) spawns the processes.
        from torch.backends import cudnn

        info(f"- GPU Device: {torch.cuda.get_device_name(idist.get_local_rank())}")
        info(f"- CUDA version: {torch.version.cuda}")
        info(f"- CUDNN version: {cudnn.version()}")

    info("\n")
    info("Configuration:")
    for name, val in config.items():
        info(f"\t{name}: {val}")
    info("\n")

    world_size = idist.get_world_size()
    if world_size > 1:
        info("\nDistributed setting:")
        info(f"\tbackend: {idist.backend()}")
        info(f"\tworld size: {world_size}")
        info("\n")
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    """Build the Ignite training engine, attach common handlers and optionally resume.

    Each iteration runs a forward pass (under CUDA autocast when ``with_amp`` is set),
    then a backward pass and optimizer step through a ``GradScaler``. It returns
    ``{"batch loss": float}``. Ignite's ``common.setup_common_training_handlers`` then attaches
    LR scheduling and periodic checkpointing of trainer/model/optimizer/scheduler.

    Args:
        model: model to train (as built by ``initialize``).
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss function called as ``criterion(y_pred, y)``.
        lr_scheduler: LR scheduler handler, attached to the trainer and checkpointed.
        train_sampler: sampler of the train loader, forwarded to the common handlers.
        config (dict): reads "with_amp", "checkpoint_every", "log_every_iters", "resume_from",
            plus whatever ``get_save_handler`` reads.
        logger: logger assigned to the engine and used for resume messages.

    Returns:
        Engine: the configured trainer.

    Raises:
        FileNotFoundError: if ``config["resume_from"]`` is set but the file does not exist.
    """
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - training checkpointing (`to_save` every `checkpoint_every` iterations)
    #    - RunningAverage` on `train_step` output
    #    - progress bars are disabled (`with_pbars=False`)
    with_amp = config["with_amp"]
    # With enabled=False, the scaler's scale/step/update reduce to a plain backward/step
    # (per the PyTorch GradScaler docs), so one code path serves both modes.
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        x, y = batch[0], batch[1]

        # Check inputs and targets independently, so a target left on another device
        # is moved even when the input is already on `device`.
        if x.device != device:
            x = x.to(device, non_blocking=True)
        if y.device != device:
            y = y.to(device, non_blocking=True)

        model.train()

        # NOTE(review): autocast is hard-wired to "cuda"; with_amp has no effect on other devices.
        with autocast("cuda", enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not checkpoint_fp.exists():
            raise FileNotFoundError(f"Checkpoint '{checkpoint_fp.as_posix()}' is not found")
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def get_save_handler(config):
    """Return the handler that checkpoints are written through.

    Args:
        config (dict): reads "with_clearml" and "output_path".

    Returns:
        A ``DiskSaver`` on ``config["output_path"]`` (existing contents allowed). When
        ``config["with_clearml"]`` is truthy, a ``ClearMLSaver`` with that directory is
        returned instead.
    """
    use_clearml = config["with_clearml"]
    if not use_clearml:
        return DiskSaver(config["output_path"], require_empty=False)

    # Imported lazily so ClearML is only needed when it is actually enabled.
    from ignite.handlers.clearml_logger import ClearMLSaver

    return ClearMLSaver(dirname=config["output_path"])
if __name__ == "__main__":
    # Expose `run` as the `run` subcommand of a Fire CLI,
    # e.g. `python main.py run --num_epochs=10`.
    cli_commands = {"run": run}
    fire.Fire(cli_commands)