Skip to content

Commit

Permalink
make funcs more readable in knn
Browse files Browse the repository at this point in the history
  • Loading branch information
zlotus committed Dec 5, 2016
1 parent 94fcefa commit 883e2ea
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions assignment1/cs231n/classifiers/k_nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ def compute_distances_one_loop(self, X):
# Compute the l2 distance between the ith test point and all training #
# points, and store the result in dists[i, :]. #
#######################################################################
xtr_square = (self.X_train**2).dot(np.ones(self.X_train.shape[1]))
xte_square = X[i].dot(X[i].T)
xtr_dot_xte = self.X_train.dot(X[i])
dists[i] = (xtr_square+xte_square-2*xtr_dot_xte)**.5
dists[i] = np.sum((X[i]-self.X_train)**2, axis=1)**.5
#######################################################################
# END OF YOUR CODE #
#######################################################################
Expand Down Expand Up @@ -125,8 +122,8 @@ def compute_distances_no_loops(self, X):
# HINT: Try to formulate the l2 distance using matrix multiplication #
# and two broadcast sums. #
#########################################################################
Xtr_square = (self.X_train**2).dot(np.ones(self.X_train.shape[1])).reshape([1, self.X_train.shape[0]])
Xte_square = (X**2).dot(np.ones(X.shape[1])).reshape([X.shape[0], 1])
Xtr_square = np.sum(self.X_train**2, axis=1).reshape([1, self.X_train.shape[0]])
Xte_square = np.sum(X**2, axis=1).reshape([X.shape[0], 1])
Xtr_dot_Xte = X.dot(self.X_train.T)
dists = (Xtr_square+Xte_square-2*Xtr_dot_Xte)**.5
#########################################################################
Expand Down

0 comments on commit 883e2ea

Please sign in to comment.