Skip to content

Commit

Permalink
adding some options
Browse files Browse the repository at this point in the history
  • Loading branch information
bmccann committed Jan 24, 2017
1 parent 10cc0b7 commit 623cd7f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 19 deletions.
10 changes: 3 additions & 7 deletions snli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,12 @@ def __init__(self, config):
self.rnn = nn.LSTM(input_size=config.d_embed, hidden_size=config.d_hidden,
num_layers=config.n_layers, dropout=config.dp_ratio,
bidirectional=config.bidirectional)
self.init_config = self.config.n_cells, self.config.batch_size, self.config.d_hidden
# self.register_buffer('h0', h0)
# self.register_buffer('c0', c0)

def forward(self, inputs):
batch_size = inputs.size()[1]
h0 = Variable(torch.zeros(*self.init_config)).cuda()
c0 = Variable(torch.zeros(*self.init_config)).cuda()

_, (hn, _) = self.rnn(inputs, (h0[:, :batch_size].contiguous(), c0[:, :batch_size].contiguous()))
h0 = Variable(torch.zeros(self.config.n_cells, batch_size, self.config.d_hidden)).cuda()
c0 = Variable(torch.zeros(self.config.n_cells, batch_size, self.config.d_hidden)).cuda()
_, (hn, _) = self.rnn(inputs, (h0, c0))
return hn[-1] if not self.config.bidirectional else hn[-2:].view(batch_size, -1)


Expand Down
20 changes: 8 additions & 12 deletions snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
model = SNLIClassifier(config)
model.cuda()
criterion = nn.CrossEntropyLoss()
opt = O.Adam(model.parameters())
opt = O.Adam(model.parameters(), lr=args.lr)

iterations = 0
start = time.time()
Expand All @@ -47,27 +47,23 @@
log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
print(header)
for batch_idx, batch in enumerate(train_iter):
model.train()
model.train(); opt.zero_grad()
iterations += 1
opt.zero_grad()
answer = model(batch)
n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
n_total += batch.batch_size
train_acc = 100. * n_correct/n_total
loss = criterion(answer, batch.label)
loss.backward()
opt.step()
loss.backward(); opt.step()
if iterations % args.save_every == 0:
torch.save(model, os.path.join(args.save_path, 'snapshot_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)))
if iterations % args.val_every == 0:
val_iter.init_epoch()
model.eval(); val_iter.init_epoch()
n_dev_correct, dev_loss = 0, 0
model.eval()
for dev_batch_idx, batch in enumerate(val_iter):
opt.zero_grad()
answer = model(batch)
n_dev_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
dev_loss = criterion(answer, batch.label)
for dev_batch_idx, dev_batch in enumerate(val_iter):
answer = model(dev_batch)
n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
dev_loss = criterion(answer, dev_batch.label)
dev_acc = 100. * n_dev_correct / len(val)
print(val_log_template.format(time.time()-start,
epoch, iterations, batch_idx, len(train_iter),
Expand Down
19 changes: 19 additions & 0 deletions snli/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from argparse import ArgumentParser

def get_args(argv=None):
    """Parse the command-line options of the SNLI example.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior of calling ``get_args()`` with no arguments.

    Returns:
        The ``argparse.Namespace`` holding the parsed option values.
    """
    parser = ArgumentParser(description='PyTorch/torchtext SNLI example')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--d_embed', type=int, default=10)
    parser.add_argument('--d_hidden', type=int, default=10)
    parser.add_argument('--n_layers', type=int, default=1)
    parser.add_argument('--log_every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=.1)
    parser.add_argument('--val_every', type=int, default=1000000)
    parser.add_argument('--save_every', type=int, default=100000)
    # The default is the float 0.0, so the option must be parsed as a float;
    # with type=int a fractional value such as "0.2" was rejected by argparse.
    parser.add_argument('--dp_ratio', type=float, default=0.0)
    parser.add_argument('--bidirectional', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--save_path', type=str, default='')
    return parser.parse_args(argv)

0 comments on commit 623cd7f

Please sign in to comment.