Skip to content

Commit

Permalink
style(wmd): adjust style in wmf
Browse files Browse the repository at this point in the history
  • Loading branch information
domainxz committed Feb 23, 2020
1 parent ba7fb19 commit 3a9d159
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions single/wmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def train(self, max_iter: int = 200, tol: float = 1e-4, model_path: str = None)
if len(self.usm[i]) > 0:
Vi = self.fie[np.array(self.usm[i]), :]
self.fue[i, :] = np.linalg.solve(np.dot(Vi.T, Vi)*(self.a-self.b)+XX, np.sum(Vi, axis=0)*self.a)
loss += 0.5 * self.lu * np.sum(self.fue[i,:]**2)
loss += 0.5 * self.lu * np.sum(self.fue[i,:]**2)
Ur = self.fue[np.array(self.u_rated), :]
XX = np.dot(Ur.T, Ur)*self.b
for j in self.ism:
Expand All @@ -84,7 +84,7 @@ def train(self, max_iter: int = 200, tol: float = 1e-4, model_path: str = None)
loss += 0.5 * len(self.ism[j])*self.a
loss += 0.5 * np.linalg.multi_dot((self.fie[j, :], B, self.fie[j, :]))
loss -= np.sum(np.multiply(Uj, self.fie[j, :]))*self.a
loss += 0.5 * self.lv * np.sum(self.fie[j, :]**2)
loss += 0.5 * self.lv * np.sum(self.fie[j, :]**2)
cond = np.abs(loss_old - loss) / loss_old
tprint('Iter %3d, loss %.6f, converge %.6f, time %.2fs'%(it, loss, cond, time.time()-t1))
if cond < tol:
Expand Down

0 comments on commit 3a9d159

Please sign in to comment.