Skip to content

Commit

Permalink
Fix opt param bug: store opt as self.opt in Instructor.__init__ and replace bare opt references with self.opt in reset_parameters and run
Browse files Browse the repository at this point in the history
  • Loading branch information
songyouwei committed May 18, 2018
1 parent 795628e commit 52502d4
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class Instructor:
def __init__(self, opt):
self.opt = opt
print('> training arguments:')
for arg in vars(opt):
print('>>> {0}: {1}'.format(arg, getattr(opt, arg)))
Expand All @@ -38,7 +39,7 @@ def reset_parameters(self):
if p.requires_grad:
n_trainable_params += n_params
if len(p.shape) > 1:
nn.init.xavier_uniform_(p)
self.opt.initializer(p)
else:
n_nontrainable_params += n_params
print('n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
Expand All @@ -47,11 +48,11 @@ def run(self):
# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
params = filter(lambda p: p.requires_grad, self.model.parameters())
optimizer = opt.optimizer(params, lr=opt.learning_rate)
optimizer = self.opt.optimizer(params, lr=self.opt.learning_rate)

max_test_acc = 0
global_step = 0
for epoch in range(opt.num_epoch):
for epoch in range(self.opt.num_epoch):
print('>' * 100)
print('epoch: ', epoch)
n_correct, n_total = 0, 0
Expand All @@ -62,15 +63,15 @@ def run(self):
self.model.train()
optimizer.zero_grad()

inputs = [sample_batched[col].to(device) for col in opt.inputs_cols]
inputs = [sample_batched[col].to(device) for col in self.opt.inputs_cols]
targets = sample_batched['polarity'].to(device)
outputs = self.model(inputs)

loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

if global_step % opt.log_step == 0:
if global_step % self.opt.log_step == 0:
n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
n_total += len(outputs)
train_acc = n_correct / n_total
Expand All @@ -80,7 +81,7 @@ def run(self):
n_test_correct, n_test_total = 0, 0
with torch.no_grad():
for t_batch, t_sample_batched in enumerate(self.test_data_loader):
t_inputs = [t_sample_batched[col].to(device) for col in opt.inputs_cols]
t_inputs = [t_sample_batched[col].to(device) for col in self.opt.inputs_cols]
t_targets = t_sample_batched['polarity'].to(device)
t_outputs = self.model(t_inputs)

Expand Down

0 comments on commit 52502d4

Please sign in to comment.