Skip to content

Commit

Permalink
start with pytorch-template
Browse files Browse the repository at this point in the history
  • Loading branch information
fyjl authored and fyjl committed May 22, 2019
0 parents commit 2f27e6f
Show file tree
Hide file tree
Showing 21 changed files with 1,003 additions and 0 deletions.
112 changes: 112 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# input data, saved log, checkpoints
data/
input/
saved/
datasets/

# editor, os cache directory
.vscode/
.idea/
__MACOSX/
3 changes: 3 additions & 0 deletions base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base_data_loader import *
from .base_model import *
from .base_trainer import *
61 changes: 61 additions & 0 deletions base/base_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders

    Extends ``torch.utils.data.DataLoader`` with an optional train/validation
    split. When ``validation_split`` is non-zero the dataset indices are
    permuted with a fixed seed, the first part is held out for validation and
    this loader iterates over the remainder only. The held-out part is exposed
    through :meth:`split_validation`.

    :param dataset: Dataset to load from; must support ``len()``
    :param batch_size: Number of samples per batch
    :param shuffle: Whether to shuffle; forced to ``False`` when a split is made,
        because the subset samplers already randomise the order
    :param validation_split: ``0.0`` for no split, an ``int`` for an absolute
        number of validation samples, otherwise a fraction of the dataset
    :param num_workers: Number of worker processes passed to ``DataLoader``
    :param collate_fn: Function merging a list of samples into a batch
    """
    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        # Must run before init_kwargs is built: a split switches self.shuffle off
        # and shrinks self.n_samples to the size of the training subset.
        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        # Kept so split_validation() can build a sibling loader with the same settings.
        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
        }
        super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """
        Build the samplers for the training and validation subsets

        :param split: ``0.0`` for no split, an ``int`` sample count, or a fraction
        :return: ``(train_sampler, valid_sampler)``, or ``(None, None)`` when
            ``split`` is zero
        """
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # A private RandomState(0) produces the same permutation as seeding the
        # global generator with 0, but leaves the process-wide NumPy RNG state
        # untouched, so creating a loader does not reset randomness elsewhere.
        np.random.RandomState(0).shuffle(idx_full)

        if isinstance(split, int):
            assert split > 0
            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        # The first len_valid shuffled indices are held out; the rest are for training.
        valid_idx = idx_full[:len_valid]
        train_idx = idx_full[len_valid:]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        """
        Create the loader for the held-out validation subset

        :return: A ``DataLoader`` over the validation indices, or ``None`` when
            no validation split was configured
        """
        if self.valid_sampler is None:
            return None
        return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
25 changes: 25 additions & 0 deletions base/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models

    Subclasses implement :meth:`forward`. Converting a model to a string gives
    the usual ``nn.Module`` description followed by the number of trainable
    parameters.
    """
    @abstractmethod
    def forward(self, *input):
        """
        Forward pass logic

        :param input: Positional inputs of the model
        :return: Model output
        :raises NotImplementedError: Always, unless overridden by a subclass
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters

        :return: Module description with a ``Trainable parameters: N`` line appended
        """
        # numel() is an exact int for every shape, including 0-dim parameters,
        # where a product over an empty size would come back as the float 1.0.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params)
176 changes: 176 additions & 0 deletions base/base_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import torch
from abc import abstractmethod
from numpy import inf
from logger import WriterTensorboardX


class BaseTrainer:
    """
    Base class for all trainers

    Owns the epoch loop, best-model monitoring, early stopping and checkpoint
    saving/resuming. Subclasses provide the per-epoch logic in
    :meth:`_train_epoch`.
    """
    def __init__(self, model, loss, metrics, optimizer, config):
        """
        :param model: Model to train; moved to the selected device
        :param loss: Loss callable, stored for use by subclasses
        :param metrics: Metric functions; their ``__name__`` is used as the log key
        :param optimizer: Optimizer whose state is stored in checkpoints
        :param config: Configuration object; read by key (``'trainer'``, ``'n_gpu'``)
            and by attribute (``get_logger``, ``save_dir``, ``log_dir``, ``resume``)
        """
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            # monitor is "<mode> <metric>", e.g. "min val_loss"
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir
        # setup visualization writer instance
        self.writer = WriterTensorboardX(config.log_dir, self.logger, cfg_trainer['tensorboardX'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        :return: Dict of values to log; the ``'metrics'`` and ``'val_metrics'``
            entries are sequences aligned with ``self.metrics``
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic

        Runs ``_train_epoch`` from ``start_epoch`` to ``epochs`` inclusive, logs
        the results, tracks the monitored metric, stops early once it has not
        improved for more than ``early_stop`` epochs and saves a checkpoint
        every ``save_period`` epochs.
        """
        # Initialised up front: the first epoch run may fail to improve (for
        # example after resuming with a better monitor_best, or with a NaN
        # metric), and the counter is incremented in that branch.
        not_improved_count = 0

        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged information into log dict
            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'val_metrics':
                    log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                else:
                    log[key] = value

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info(' {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. "
                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False
                    not_improved_count = 0

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self, n_gpu_use):
        """
        Select the device to train on and the GPU ids to use

        :param n_gpu_use: Number of GPUs requested
        :return: ``(device, list_ids)`` where ``device`` is ``cuda:0`` or ``cpu``
            and ``list_ids`` are the GPU indices to use (empty on CPU)
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning("Warning: There\'s no GPU available on this machine, "
                                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
                                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, additionally save the same state as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        Restores the start epoch, the best monitored value, the model weights
        and, when the optimizer type is unchanged, the optimizer state.

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
Loading

0 comments on commit 2f27e6f

Please sign in to comment.