Skip to content

Commit

Permalink
functionaliza loss computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kun Qian committed Dec 29, 2018
1 parent f1685c6 commit ed65d6f
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ def _convert_batch(self, py_batch, prev_z_py=None):
return u_input, u_input_np, z_input, m_input, m_input_np,u_len, m_len, \
degree_input, kw_ret

def supervised_loss(self, pz_proba, pm_dec_proba, z_input, m_input):
# pz_proba = torch.log(pz_proba)
# pm_dec_proba = torch.log(pm_dec_proba)
pz_proba, pm_dec_proba = pz_proba[:, :, :cfg.vocab_size].contiguous(), pm_dec_proba[:, :,
:cfg.vocab_size].contiguous()
pr_loss = self.pr_loss(pz_proba.view(-1, pz_proba.size(2)), z_input.view(-1))
m_loss = self.dec_loss(pm_dec_proba.view(-1, pm_dec_proba.size(2)), m_input.view(-1))
loss = pr_loss + m_loss
return loss, pr_loss, m_loss

def train(self):
lr = cfg.lr
prev_min_loss, early_stop_count = 1 << 30, cfg.early_stop_count
Expand Down Expand Up @@ -146,6 +156,7 @@ def train(self):
# m_input_np=m_input_np,
# turn_states=turn_states,
# u_len=u_len, m_len=m_len, mode='train', **kw_ret)

pz_proba, pm_dec_proba, turn_states = self.m(u_input=u_input,
z_input=z_input,
m_input=m_input,
Expand All @@ -157,19 +168,12 @@ def train(self):
m_len=m_len,
mode='train',
**kw_ret)
# loss, pr_loss, m_loss = self.supervised_loss(torch.log(pz_proba),
# torch.log(pm_dec_proba),
# z_input,
# m_input)
# def supervised_loss(self, pz_proba, pm_dec_proba, z_input, m_input):
pz_proba = torch.log(pz_proba)
pm_dec_proba = torch.log(pm_dec_proba)
pz_proba, pm_dec_proba = pz_proba[:, :, :cfg.vocab_size].contiguous(), pm_dec_proba[:, :,
:cfg.vocab_size].contiguous()
pr_loss = self.pr_loss(pz_proba.view(-1, pz_proba.size(2)), z_input.view(-1))
m_loss = self.dec_loss(pm_dec_proba.view(-1, pm_dec_proba.size(2)), m_input.view(-1))
loss = pr_loss + m_loss
# return loss, pr_loss, m_loss

loss, pr_loss, m_loss = self.supervised_loss(torch.log(pz_proba),
torch.log(pm_dec_proba),
z_input,
m_input)



# ##################
Expand Down Expand Up @@ -269,13 +273,10 @@ def validate(self, data='dev'):
mode='train',
**kw_ret)

pz_proba = torch.log(pz_proba)
pm_dec_proba = torch.log(pm_dec_proba)
pz_proba, pm_dec_proba = pz_proba[:, :, :cfg.vocab_size].contiguous(), pm_dec_proba[:, :,
:cfg.vocab_size].contiguous()
pr_loss = self.pr_loss(pz_proba.view(-1, pz_proba.size(2)), z_input.view(-1))
m_loss = self.dec_loss(pm_dec_proba.view(-1, pm_dec_proba.size(2)), m_input.view(-1))
loss = pr_loss + m_loss
loss, pr_loss, m_loss = self.supervised_loss(torch.log(pz_proba),
torch.log(pm_dec_proba),
z_input,
m_input)


sup_loss += loss.data[0]
Expand Down

0 comments on commit ed65d6f

Please sign in to comment.