Skip to content

Commit

Permalink
add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Zach Bloss committed Jul 17, 2023
1 parent eecc543 commit 5403738
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def calc_dis_with_vector(self, data: list, train_data: list = None):

def calc_acc(
self, k: int, label: str, train_label: str = None, provided_distance_matrix: list = None, rand: bool = False
):
) -> tuple:
"""
Calculates the accuracy of the algorithm.
Expand Down Expand Up @@ -233,7 +233,7 @@ def calc_acc(
print("Accuracy is {}".format(sum(correct) / len(correct)))
return pred, correct

def combine_dis_acc(self, k, data, label, train_data=None, train_label=None):
def combine_dis_acc(self, k: int, data: list, label: str, train_data: list = None, train_label: str = None) -> tuple:
correct = []
pred = []
if train_label is not None:
Expand Down Expand Up @@ -272,7 +272,7 @@ def combine_dis_acc(self, k, data, label, train_data=None, train_label=None):
print("Accuracy is {}".format(sum(correct) / len(correct)))
return pred, correct

def combine_dis_acc_single(self, k, train_data, train_label, datum, label):
def combine_dis_acc_single(self, k: int, train_data: list, train_label: str, datum: list, label: str):
# Support multi processing - must provide train data and train label
distance4i = self.calc_dis_single_multi(train_data, datum)
sorted_idx = np.argpartition(np.array(distance4i), range(k))
Expand Down

0 comments on commit 5403738

Please sign in to comment.