Language model: replace the optimizer and LR-decay scheduler
Replace the original "homebrew" optimizer and LR-decay schedule with
PyTorch's SGD and ReduceLROnPlateau.
SGD with momentum=0 and weight_decay=0, and ReduceLROnPlateau with
patience=0 and factor=0.5 will give the same behavior as in the
original PyTorch example.

Having a standard optimizer and LR-decay schedule gives us the
flexibility to experiment with these during the training process.
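
For quick reference, a minimal sketch of the new setup (not part of the diff itself); `model`, `args`, `evaluate`, and `val_data` are assumed to be the objects already defined in examples/word_language_model/main.py, and the arguments mirror the lines added below:

    import torch

    # Standard SGD; with momentum=0 and weight_decay=0 this reduces to the
    # plain update p <- p - lr * grad performed by the original example.
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ReduceLROnPlateau: patience=0 reacts to the first epoch without
    # improvement, and factor=0.5 halves the learning rate each time.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=0, verbose=True, factor=0.5)

    # Once per epoch, after validation (replaces the manual annealing):
    val_loss = evaluate(val_data)
    lr_scheduler.step(val_loss)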
nzmora committed Jun 13, 2018
1 parent d6ffeaf commit a9b2892
Showing 1 changed file with 56 additions and 47 deletions.
103 changes: 56 additions & 47 deletions examples/word_language_model/main.py
@@ -16,6 +16,7 @@
import data
import model

# Distiller imports
import os
import sys
script_dir = os.path.dirname(__file__)
@@ -24,8 +25,8 @@
sys.path.append(module_path)
import distiller
import apputils
from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSparsityCollector
import torchnet.meter as tnt
from distiller.data_loggers import TensorBoardLogger, PythonLogger


parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/wikitext-2',
@@ -58,20 +59,25 @@
help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='report interval')
parser.add_argument('--save', type=str, default='model.pt',
parser.add_argument('--save', type=str, default='checkpoint.pth.tar',
help='path to save the final model')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--onnx-export', type=str, default='',
help='path to export the final model in onnx format')

# Distiller-related arguments
SUMMARY_CHOICES = ['sparsity', 'compute', 'optimizer', 'model', 'modules', 'png', 'percentile']
SUMMARY_CHOICES = ['sparsity', 'model', 'modules', 'png', 'percentile']
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
help='print a summary of the model, and exit - options: ' +
' | '.join(SUMMARY_CHOICES))
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--momentum', default=0., type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=0., type=float,
metavar='W', help='weight decay (default: 0.)')

args = parser.parse_args()

# Set the random seed manually for reproducibility.
@@ -84,6 +90,11 @@


def draw_lang_model_to_file(model, png_fname, dataset):
"""Draw a language model graph to a PNG file.
Caveat: the PNG that is produced has some problems, which we suspect are due to
PyTorch issues related to RNN ONNX export.
"""
try:
if dataset == 'wikitext2':
batch_size = 20
@@ -92,16 +103,16 @@ def draw_lang_model_to_file(model, png_fname, dataset):
hidden = model.init_hidden(batch_size)
dummy_input = (dummy_input, hidden)
else:
print("Unsupported dataset (%s) - aborting draw operation" % dataset)
msglogger.info("Unsupported dataset (%s) - aborting draw operation" % dataset)
return
g = apputils.SummaryGraph(model, dummy_input)
apputils.draw_model_to_file(g, png_fname)
print("Network PNG image generation completed")
msglogger.info("Network PNG image generation completed")

except FileNotFoundError as e:
print("An error has occured while generating the network PNG image.")
print("Please check that you have graphviz installed.")
print("\t$ sudo apt-get install graphviz")
msglogger.info("An error has occured while generating the network PNG image.")
msglogger.info("Please check that you have graphviz installed.")
msglogger.info("\t$ sudo apt-get install graphviz")
raise e

###############################################################################
@@ -229,18 +240,13 @@ def train(epoch, optimizer, compression_scheduler=None):
regularizer_loss = compression_scheduler.before_backward_pass(epoch, minibatch_id=batch,
minibatches_per_epoch=steps_per_epoch, loss=loss)
loss += regularizer_loss
#losses['regularizer_loss'].add(regularizer_loss.item())

model.zero_grad()
#optimizer.zero_grad()
optimizer.zero_grad()
loss.backward()


# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
for p in model.parameters():
p.data.add_(-lr, p.grad.data)
#optimizer.step()
optimizer.step()

total_loss += loss.item()

@@ -250,8 +256,10 @@ def train(epoch, optimizer, compression_scheduler=None):
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss / args.log_interval
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(
lr = optimizer.param_groups[0]['lr']
msglogger.info(
'| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.4f} | ms/batch {:5.2f} '
'| loss {:5.2f} | ppl {:8.2f}'.format(
epoch, batch, len(train_data) // args.bptt, lr,
elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
total_loss = 0
@@ -264,24 +272,20 @@ def train(epoch, optimizer, compression_scheduler=None):
('Batch Time', elapsed * 1000)])
)
steps_completed = batch + 1
#tflogger.log_training_progress(stats, epoch, steps_completed, total=steps_per_epoch, freq=args.log_interval)
distiller.log_training_progress(stats, model.named_parameters(), epoch, steps_completed,
steps_per_epoch, args.log_interval, [tflogger])


def export_onnx(path, batch_size, seq_len):
print('The model is also exported in ONNX format at {}'.
msglogger.info('The model is also exported in ONNX format at {}'.
format(os.path.realpath(args.onnx_export)))
model.eval()
dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device)
hidden = model.init_hidden(batch_size)
torch.onnx.export(model, (dummy_input, hidden), path)


# Loop over epochs.
lr = args.lr
best_val_loss = None

# Distiller loggers
msglogger = apputils.config_pylogger('logging.conf', None)
tflogger = TensorBoardLogger(msglogger.logdir)
tflogger.log_gradients = True
@@ -297,28 +301,31 @@ def export_onnx(path, batch_size, seq_len):
if param.dim() < 2:
# Skip biases
continue
bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()), largest=False, sorted=True)
bottomk, _ = torch.topk(param.abs().view(-1), int(percentile * param.numel()),
largest=False, sorted=True)
threshold = bottomk.data[-1]
print("parameter %s: q = %.2f" %(name, threshold))
msglogger.info("parameter %s: q = %.2f" %(name, threshold))
else:
distiller.model_summary(model, None, which_summary, 'wikitext2')

exit(0)

compression_scheduler = None

if args.compress:
# The main use-case for this sample application is CNN compression. Compression
# requires a compression schedule configuration file in YAML.
# Create a CompressionScheduler and configure it from a YAML schedule file
source = args.compress
compression_scheduler = distiller.CompressionScheduler(model)
distiller.config.fileConfig(model, None, compression_scheduler, args.compress, msglogger)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
#optimizer = optim.SparseAdam(model.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])

optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
patience=0, verbose=True, factor=0.5)

# Loop over epochs.
# At any point you can hit Ctrl + C to break out of training early.
best_val_loss = float("inf")
try:
for epoch in range(0, args.epochs):
epoch_start_time = time.time()
@@ -328,11 +335,12 @@ def export_onnx(path, batch_size, seq_len):
train(epoch, optimizer, compression_scheduler)

val_loss = evaluate(val_data)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
msglogger.info('-' * 89)
msglogger.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(val_loss)))
print('-' * 89)
msglogger.info('-' * 89)

distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])

stats = ('Performance/Validation/',
@@ -341,23 +349,24 @@ def export_onnx(path, batch_size, seq_len):
('Perplexity', math.exp(val_loss))]))
tflogger.log_training_progress(stats, epoch, 0, total=1, freq=1)

with open(args.save, 'wb') as f:
torch.save(model, f)

# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
with open(args.save, 'wb') as f:
if val_loss < best_val_loss:
with open(args.save+".best", 'wb') as f:
torch.save(model, f)
best_val_loss = val_loss
else:
# Anneal the learning rate if no improvement has been seen in the validation dataset.
lr /= 4 #1.2
lr_scheduler.step(val_loss)

if compression_scheduler:
compression_scheduler.on_epoch_end(epoch)

except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early')
msglogger.info('-' * 89)
msglogger.info('Exiting from training early')

# Load the best saved model.
# Load the last saved model.
with open(args.save, 'rb') as f:
model = torch.load(f)
# after load the rnn params are not a continuous chunk of memory
@@ -366,10 +375,10 @@ def export_onnx(path, batch_size, seq_len):

# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
msglogger.info('=' * 89)
msglogger.info('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)
msglogger.info('=' * 89)

if len(args.onnx_export) > 0:
# Export the model in ONNX format.
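As a sanity check on the commit message's claim that SGD with momentum=0 and weight_decay=0 matches the original hand-rolled update, a small self-contained sketch (toy tensors rather than the example's LSTM; not part of this commit):

    import torch

    # The old train() loop updated parameters by hand:
    #     for p in model.parameters():
    #         p.data.add_(-lr, p.grad.data)
    # torch.optim.SGD with momentum=0 and weight_decay=0 applies the same
    # per-parameter update p <- p - lr * grad on optimizer.step().

    torch.manual_seed(0)
    lr = 0.1
    w_manual = torch.randn(3, requires_grad=True)
    w_sgd = w_manual.detach().clone().requires_grad_(True)

    def loss_fn(w):
        return (w ** 2).sum()

    # Hand-rolled update, as in the old code.
    loss_fn(w_manual).backward()
    w_manual.data.add_(w_manual.grad.data, alpha=-lr)

    # The same step through torch.optim.SGD.
    opt = torch.optim.SGD([w_sgd], lr=lr, momentum=0., weight_decay=0.)
    loss_fn(w_sgd).backward()
    opt.step()

    print(torch.allclose(w_manual, w_sgd))  # expected: True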
