better logging and code structure
- main() is now at the top of "train.py", and the housekeeping code that previously ran at module level has moved into setup()
- fix a bug in drqa/layers.py: call "contiguous()" before "view()" (a short repro follows that file's diff below)
- stdout logging is no longer flooded by thousands of training log lines, and evaluation now shows a progress bar (a minimal sketch of the overwrite trick appears after the change summary below)
hitvoice committed Apr 10, 2018
1 parent 90f8a34 commit 4ad4452
Showing 2 changed files with 141 additions and 108 deletions.
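About the logging change: the new ProgressHandler in train.py (visible in the diff below) treats any message that starts with "> " as a progress line and ends it with a carriage return instead of a newline, so the next progress line overwrites it in place. Here is a minimal standalone sketch of that trick; the handler and logger names are illustrative and nothing below depends on this repo:

import logging
import sys
import time

class OverwritingHandler(logging.Handler):
    """Progress lines (prefixed '> ') are redrawn in place; normal lines scroll."""
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)
        self.is_overwrite = False  # True while the current stdout line is a progress line

    def emit(self, record):
        log_entry = self.format(record)  # side effect: populates record.message
        if record.message.startswith('> '):
            # carriage return, no newline: the next progress line overwrites this one
            sys.stdout.write('{}\r'.format(log_entry.rstrip()))
            sys.stdout.flush()
            self.is_overwrite = True
        else:
            if self.is_overwrite:
                sys.stdout.write('\n')  # keep the final progress line on screen
                self.is_overwrite = False
            sys.stdout.write('{}\n'.format(log_entry))

log = logging.getLogger('progress_demo')
log.setLevel(logging.DEBUG)
log.addHandler(OverwritingHandler())

for i in range(20):
    time.sleep(0.1)
    log.debug('> evaluating [{}/{}]'.format(i + 1, 20))  # redrawn on one line
log.info('evaluation finished.')  # printed on its own line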
drqa/layers.py (2 changes: 1 addition & 1 deletion)
@@ -235,7 +235,7 @@ def forward(self, x, x_mask):
x = batch * len * hdim
x_mask = batch * len
"""
x_flat = x.view(-1, x.size(-1))
x_flat = x.contiguous().view(-1, x.size(-1))
scores = self.linear(x_flat).view(x.size(0), x.size(1))
scores.data.masked_fill_(x_mask.data, -float('inf'))
alpha = F.softmax(scores, dim=1)
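Why this one-line fix matters: Tensor.view() requires contiguous memory, and tensors produced by transpose(), permute(), or certain slicing are not contiguous, so the old x.view(...) could fail with a RuntimeError depending on how x was built upstream. A minimal repro, independent of this repo (the shapes are arbitrary):

import torch

x = torch.randn(4, 5, 8).transpose(0, 1)  # transpose() returns a non-contiguous view
print(x.is_contiguous())                   # False

try:
    x_flat = x.view(-1, x.size(-1))        # what the old code did
except RuntimeError as err:
    print('view() failed:', err)

# the fix: copy into contiguous memory first, then view() is safe
x_flat = x.contiguous().view(-1, x.size(-1))
print(x_flat.shape)                        # torch.Size([20, 8])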
train.py (247 changes: 140 additions & 107 deletions)
@@ -13,120 +13,30 @@
from drqa.model import DocReaderModel
from drqa.utils import str2bool

parser = argparse.ArgumentParser(
description='Train a Document Reader model.'
)
# system
parser.add_argument('--log_file', default='output.log',
help='path for log file.')
parser.add_argument('--log_per_updates', type=int, default=3,
help='log model loss per x updates (mini-batches).')
parser.add_argument('--data_file', default='SQuAD/data.msgpack',
help='path to preprocessed data file.')
parser.add_argument('--model_dir', default='models',
help='path to store saved models.')
parser.add_argument('--save_last_only', action='store_true',
help='only save the final models.')
parser.add_argument('--eval_per_epoch', type=int, default=1,
help='perform evaluation per x epochs.')
parser.add_argument('--seed', type=int, default=1013,
help='random seed for data shuffling, dropout, etc.')
parser.add_argument("--cuda", type=str2bool, nargs='?',
const=True, default=torch.cuda.is_available(),
help='whether to use GPU acceleration.')
# training
parser.add_argument('-e', '--epochs', type=int, default=40)
parser.add_argument('-bs', '--batch_size', type=int, default=32)
parser.add_argument('-rs', '--resume', default='',
help='previous model file name (in `model_dir`). '
'e.g. "checkpoint_epoch_11.pt"')
parser.add_argument('-ro', '--resume_options', action='store_true',
help='use previous model options, ignore the cli and defaults.')
parser.add_argument('-rlr', '--reduce_lr', type=float, default=0.,
help='reduce initial (resumed) learning rate by this factor.')
parser.add_argument('-op', '--optimizer', default='adamax',
help='supported optimizer: adamax, sgd')
parser.add_argument('-gc', '--grad_clipping', type=float, default=10)
parser.add_argument('-wd', '--weight_decay', type=float, default=0)
parser.add_argument('-lr', '--learning_rate', type=float, default=0.1,
help='only applied to SGD.')
parser.add_argument('-mm', '--momentum', type=float, default=0,
help='only applied to SGD.')
parser.add_argument('-tp', '--tune_partial', type=int, default=1000,
help='finetune top-x embeddings.')
parser.add_argument('--fix_embeddings', action='store_true',
help='if true, `tune_partial` will be ignored.')
parser.add_argument('--rnn_padding', action='store_true',
help='perform rnn padding (much slower but more accurate).')
# model
parser.add_argument('--question_merge', default='self_attn')
parser.add_argument('--doc_layers', type=int, default=3)
parser.add_argument('--question_layers', type=int, default=3)
parser.add_argument('--hidden_size', type=int, default=128)
parser.add_argument('--num_features', type=int, default=4)
parser.add_argument('--pos', type=str2bool, nargs='?', const=True, default=True,
help='use pos tags as a feature.')
parser.add_argument('--ner', type=str2bool, nargs='?', const=True, default=True,
help='use named entity tags as a feature.')
parser.add_argument('--use_qemb', type=str2bool, nargs='?', const=True, default=True)
parser.add_argument('--concat_rnn_layers', type=str2bool, nargs='?',
const=True, default=True)
parser.add_argument('--dropout_emb', type=float, default=0.4)
parser.add_argument('--dropout_rnn', type=float, default=0.4)
parser.add_argument('--dropout_rnn_output', type=str2bool, nargs='?',
const=True, default=True)
parser.add_argument('--max_len', type=int, default=15)
parser.add_argument('--rnn_type', default='lstm',
help='supported types: rnn, gru, lstm')

args = parser.parse_args()

# set model dir
model_dir = args.model_dir
os.makedirs(model_dir, exist_ok=True)
model_dir = os.path.abspath(model_dir)

# set random seed
random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)

# setup logger
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
fh = logging.FileHandler(args.log_file)
fh.setLevel(logging.DEBUG)
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
log.addHandler(fh)
log.addHandler(ch)


def main():
log.info('[program starts.]')
args, log = setup()
log.info('[Program starts. Loading data...]')
train, dev, dev_y, embedding, opt = load_data(vars(args))
log.info(opt)
log.info('[Data loaded.]')

if args.resume:
log.info('[loading previous model...]')
checkpoint = torch.load(os.path.join(model_dir, args.resume))
checkpoint = torch.load(os.path.join(args.model_dir, args.resume))
if args.resume_options:
opt = checkpoint['config']
state_dict = checkpoint['state_dict']
model = DocReaderModel(opt, embedding, state_dict)
epoch_0 = checkpoint['epoch'] + 1
for i in range(checkpoint['epoch']):
# synchronize random seed
random.setstate(checkpoint['random_state'])
torch.random.set_rng_state(checkpoint['torch_state'])
# synchronize random seed
random.setstate(checkpoint['random_state'])
torch.random.set_rng_state(checkpoint['torch_state'])
if args.cuda:
torch.cuda.set_rng_state(checkpoint['torch_cuda_state'])
if args.reduce_lr:
lr_decay(model.optimizer, lr_decay=args.reduce_lr)
log.info('[learning rate reduced by {}]'.format(args.reduce_lr))
else:
model = DocReaderModel(opt, embedding)
epoch_0 = 1
@@ -137,8 +47,9 @@ def main():
if args.resume:
batches = BatchGen(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda)
predictions = []
for batch in batches:
for i, batch in enumerate(batches):
predictions.extend(model.predict(batch))
log.debug('> evaluating [{}/{}]'.format(i, len(batches)))
em, f1 = score(predictions, dev_y)
log.info("[dev EM: {} F1: {}]".format(em, f1))
best_val_score = f1
@@ -153,33 +64,150 @@ def main():
for i, batch in enumerate(batches):
model.update(batch)
if i % args.log_per_updates == 0:
log.info('epoch [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'.format(
log.info('> epoch [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'.format(
epoch, model.updates, model.train_loss.value,
str((datetime.now() - start) / (i + 1) * (len(batches) - i - 1)).split('.')[0]))
# eval
if epoch % args.eval_per_epoch == 0:
batches = BatchGen(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda)
predictions = []
for batch in batches:
for i, batch in enumerate(batches):
predictions.extend(model.predict(batch))
log.debug('> evaluating [{}/{}]'.format(i, len(batches)))
em, f1 = score(predictions, dev_y)
log.warning("dev EM: {} F1: {}".format(em, f1))
# save
if not args.save_last_only or epoch == epoch_0 + args.epochs - 1:
model_file = os.path.join(model_dir, 'checkpoint_epoch_{}.pt'.format(epoch))
model_file = os.path.join(args.model_dir, 'checkpoint_epoch_{}.pt'.format(epoch))
model.save(model_file, epoch)
if f1 > best_val_score:
best_val_score = f1
copyfile(
model_file,
os.path.join(model_dir, 'best_model.pt'))
os.path.join(args.model_dir, 'best_model.pt'))
log.info('[new best model saved.]')


def setup():
parser = argparse.ArgumentParser(
description='Train a Document Reader model.'
)
# system
parser.add_argument('--log_file', default='output.log',
help='path for log file.')
parser.add_argument('--log_per_updates', type=int, default=3,
help='log model loss per x updates (mini-batches).')
parser.add_argument('--data_file', default='SQuAD/data.msgpack',
help='path to preprocessed data file.')
parser.add_argument('--model_dir', default='models',
help='path to store saved models.')
parser.add_argument('--save_last_only', action='store_true',
help='only save the final models.')
parser.add_argument('--eval_per_epoch', type=int, default=1,
help='perform evaluation per x epochs.')
parser.add_argument('--seed', type=int, default=1013,
help='random seed for data shuffling, dropout, etc.')
parser.add_argument("--cuda", type=str2bool, nargs='?',
const=True, default=torch.cuda.is_available(),
help='whether to use GPU acceleration.')
# training
parser.add_argument('-e', '--epochs', type=int, default=40)
parser.add_argument('-bs', '--batch_size', type=int, default=32)
parser.add_argument('-rs', '--resume', default='best_model.pt',
help='previous model file name (in `model_dir`). '
'e.g. "checkpoint_epoch_11.pt"')
parser.add_argument('-ro', '--resume_options', action='store_true',
help='use previous model options, ignore the cli and defaults.')
parser.add_argument('-rlr', '--reduce_lr', type=float, default=0.,
help='reduce initial (resumed) learning rate by this factor.')
parser.add_argument('-op', '--optimizer', default='adamax',
help='supported optimizer: adamax, sgd')
parser.add_argument('-gc', '--grad_clipping', type=float, default=10)
parser.add_argument('-wd', '--weight_decay', type=float, default=0)
parser.add_argument('-lr', '--learning_rate', type=float, default=0.1,
help='only applied to SGD.')
parser.add_argument('-mm', '--momentum', type=float, default=0,
help='only applied to SGD.')
parser.add_argument('-tp', '--tune_partial', type=int, default=1000,
help='finetune top-x embeddings.')
parser.add_argument('--fix_embeddings', action='store_true',
help='if true, `tune_partial` will be ignored.')
parser.add_argument('--rnn_padding', action='store_true',
help='perform rnn padding (much slower but more accurate).')
# model
parser.add_argument('--question_merge', default='self_attn')
parser.add_argument('--doc_layers', type=int, default=3)
parser.add_argument('--question_layers', type=int, default=3)
parser.add_argument('--hidden_size', type=int, default=128)
parser.add_argument('--num_features', type=int, default=4)
parser.add_argument('--pos', type=str2bool, nargs='?', const=True, default=True,
help='use pos tags as a feature.')
parser.add_argument('--ner', type=str2bool, nargs='?', const=True, default=True,
help='use named entity tags as a feature.')
parser.add_argument('--use_qemb', type=str2bool, nargs='?', const=True, default=True)
parser.add_argument('--concat_rnn_layers', type=str2bool, nargs='?',
const=True, default=True)
parser.add_argument('--dropout_emb', type=float, default=0.4)
parser.add_argument('--dropout_rnn', type=float, default=0.4)
parser.add_argument('--dropout_rnn_output', type=str2bool, nargs='?',
const=True, default=True)
parser.add_argument('--max_len', type=int, default=15)
parser.add_argument('--rnn_type', default='lstm',
help='supported types: rnn, gru, lstm')

args = parser.parse_args()

# set model dir
model_dir = args.model_dir
os.makedirs(model_dir, exist_ok=True)
args.model_dir = os.path.abspath(model_dir)

if args.resume == 'best_model.pt' and not os.path.exists(os.path.join(args.model_dir, args.resume)):
# means we're starting fresh
args.resume = ''

# set random seed
random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)

# setup logger
class ProgressHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
super().__init__(level)
self.is_overwrite = False

def emit(self, record):
log_entry = self.format(record)
if record.message.startswith('> '):
sys.stdout.write('{}\r'.format(log_entry.rstrip()))
sys.stdout.flush()
self.is_overwrite = True
else:
if self.is_overwrite:
sys.stdout.write('\n')
self.is_overwrite = False
sys.stdout.write('{}\n'.format(log_entry))

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
fh = logging.FileHandler(args.log_file)
fh.setLevel(logging.INFO)
ch = ProgressHandler()
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
log.addHandler(fh)
log.addHandler(ch)

return args, log


def lr_decay(optimizer, lr_decay):
for param_group in optimizer.param_groups:
param_group['lr'] *= lr_decay
log.info('[learning rate reduced by {}]'.format(lr_decay))
return optimizer


@@ -192,7 +220,9 @@ def load_data(opt):
opt['embedding_dim'] = embedding.size(1)
opt['pos_size'] = len(meta['vocab_tag'])
opt['ner_size'] = len(meta['vocab_ent'])
with open(args.data_file, 'rb') as f:
BatchGen.pos_size = opt['pos_size']
BatchGen.ner_size = opt['ner_size']
with open(opt['data_file'], 'rb') as f:
data = msgpack.load(f, encoding='utf8')
train = data['train']
data['dev'].sort(key=lambda x: len(x[1]))
@@ -202,6 +232,9 @@


class BatchGen:
pos_size = None
ner_size = None

def __init__(self, data, batch_size, gpu, evaluation=False):
"""
input:
@@ -245,12 +278,12 @@ def __iter__(self):
for j, feature in enumerate(doc):
context_feature[i, j, :] = torch.Tensor(feature)

context_tag = torch.Tensor(batch_size, context_len, args.pos_size).fill_(0)
context_tag = torch.Tensor(batch_size, context_len, self.pos_size).fill_(0)
for i, doc in enumerate(batch[3]):
for j, tag in enumerate(doc):
context_tag[i, j, tag] = 1

context_ent = torch.Tensor(batch_size, context_len, args.ner_size).fill_(0)
context_ent = torch.Tensor(batch_size, context_len, self.ner_size).fill_(0)
for i, doc in enumerate(batch[4]):
for j, ent in enumerate(doc):
context_ent[i, j, ent] = 1
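A note on the resume path above: the old code replayed random.setstate() once per completed epoch inside a loop, while the new code restores each generator state exactly once. Here is a minimal sketch of the save/restore pairing that the checkpoint keys in the diff imply; model.save() presumably writes these same keys, and the helper names here are illustrative:

import random
import torch

def save_rng_checkpoint(path, state_dict, epoch, cuda=False):
    # capture every RNG the training loop depends on, alongside the model weights
    checkpoint = {
        'state_dict': state_dict,
        'epoch': epoch,
        'random_state': random.getstate(),
        'torch_state': torch.random.get_rng_state(),
    }
    if cuda:
        checkpoint['torch_cuda_state'] = torch.cuda.get_rng_state()
    torch.save(checkpoint, path)

def restore_rng(checkpoint, cuda=False):
    # restore each generator exactly once; no per-epoch replay loop is needed
    random.setstate(checkpoint['random_state'])
    torch.random.set_rng_state(checkpoint['torch_state'])
    if cuda:
        torch.cuda.set_rng_state(checkpoint['torch_cuda_state'])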
