Skip to content

Commit

Permalink
Replace conditional expressions with easier-to-read if-else statements.
Browse files Browse the repository at this point in the history
  • Loading branch information
fheinsen committed Feb 12, 2023
1 parent e730ce3 commit adf02de
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions heinsen_routing/heinsen_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,11 @@ def forward(self, x_inp: torch.Tensor) -> Union[torch.Tensor, dict]:

# M-step.
phi = beta_use * D_use - beta_ign * D_ign # [...ij] "bang per bit" coefficients
x_out = einsum('...jd,jd,dh->...jh', einsum('...ij,...id->...jd', phi, scaled_x_inp), self.W_F1, self.W_F2) \
+ einsum('...ij,jh->...jh', phi, self.B_F2) if V is None else einsum('...ij,...ijh->...jh', phi, V) # use precomputed V if available
if V is None:
_einsum_phi_scaled_x_inp = einsum('...ij,...id->...jd', phi, scaled_x_inp)
x_out = einsum('...jd,jd,dh->...jh', _einsum_phi_scaled_x_inp, self.W_F1, self.W_F2) + einsum('...ij,jh->...jh', phi, self.B_F2)
else:
x_out = einsum('...ij,...ijh->...jh', phi, V)

if self.normalize:
x_out = self.N(x_out)
Expand Down Expand Up @@ -297,10 +300,10 @@ def forward(self, a_inp: torch.Tensor, mu_inp: torch.Tensor) -> Union[Tuple[torc
if iter_num == 0:
R = (self.CONST_one / self.n_out).expand(V.shape[:-2]) # [...ij]
else:
log_p = \
- einsum('...ijch,...jch->...ij', V_less_mu_out_2, 1.0 / (2.0 * sig2_out)) \
- sig2_out.sqrt().log().sum((-2, -1)).unsqueeze(-2) if self.p_model == 'gaussian' \
else self.log_softmax(-V_less_mu_out_2.sum((-2, -1))) # soft k-means otherwise
if self.p_model == 'gaussian':
log_p = - einsum('...ijch,...jch->...ij', V_less_mu_out_2, 1.0 / (2.0 * sig2_out)) - sig2_out.sqrt().log().sum((-2, -1)).unsqueeze(-2)
else:
log_p = self.log_softmax(-V_less_mu_out_2.sum((-2, -1))) # soft k-means
R = self.softmax(self.log_f(a_out).unsqueeze(-2) + log_p) # [...ij]

# D-step.
Expand Down

0 comments on commit adf02de

Please sign in to comment.