Skip to content

Commit 4fa897c

Browse files
authored
Merge pull request #34 from vaquierm/marine/BNB_Issue1
implementation of bernouilli naive bayes
2 parents a612e67 + 65f43c8 commit 4fa897c

File tree

4 files changed

+67
-8
lines changed

4 files changed

+67
-8
lines changed

src/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# These are all the different vectorizers to run ("BINARY", "TFIDF")
1414
vectorizers_to_run = ["TFIDF"]
1515

16-
# These are all the models to run and compare performance on a k fold cross validation ("LR", "NB", "MNNB", "KNN", "DT", "RF", "SVM", "SUPER")
16+
# These are all the models to run and compare performance on a k fold cross validation ("LR", "NB", "NB_SKLEARN", "MNNB", "KNN", "DT", "RF", "SVM", "SUPER")
1717
models_to_run = ["MNNB"]
1818

1919
# If this is true, run gridsearch on each model (This will significantly increase the runtime of the validation pipeline for model types that support gridsearch)

src/models/NaiveBayes.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1+
import numpy as np
2+
3+
from sklearn.preprocessing import LabelBinarizer
4+
from sklearn.utils.extmath import safe_sparse_dot
15
from src.models.Model import Model
26

37

48
class NaiveBayes(Model):
59

10+
def __init__(self, alpha: float = 1):
11+
if alpha < 0:
12+
raise Exception("Alpha must be greater than zero")
13+
14+
self.alpha = alpha
15+
616
def fit(self, X, Y):
717
"""
818
Fit the model with the training data
@@ -11,14 +21,61 @@ def fit(self, X, Y):
1121
:return: None
1222
"""
1323
super().fit(X, Y)
14-
# TODO: Marine Implement this
1524
# https://github.com/vaquierm/RedditCommentTextClassification/issues/1
1625

26+
subreddits = np.unique(Y)
27+
28+
# fit the model
29+
self.parameters = {}
30+
total_per_class = []
31+
thetak = [] # parameter theta k = nb comment of class 1 / total number of comments
32+
alpha = 1
33+
34+
# compute theta k
35+
# for each class
36+
for i in range(len(subreddits)):
37+
feature = subreddits[i]
38+
numbExamples = 0
39+
40+
# loop through all the comments
41+
for j in range(len(Y)):
42+
if (Y[j] == feature):
43+
numbExamples += 1
44+
45+
total_per_class.append(float(numbExamples))
46+
thetak_i = float(numbExamples) / float(X.shape[0])
47+
thetak.append(thetak_i)
48+
49+
binarizer = LabelBinarizer()
50+
Y = binarizer.fit_transform(Y)
51+
52+
# parameter thate of kj using sparse matrices
53+
# add 1 for Laplace Smoothing
54+
kj_numerator = safe_sparse_dot(Y.T, X) + alpha
55+
# kj_denominator == # of comments from that class
56+
total_per_class = np.array(total_per_class)
57+
58+
# add 2 for Laplace Smoothing
59+
kj_denominator = total_per_class.reshape(-1, 1) + 2*alpha
60+
61+
log_thetakj = np.log(kj_numerator) - np.log(kj_denominator)
62+
63+
self.parameters.update({'parameter_k': thetak})
64+
self.parameters.update({'parameter_log_kj': log_thetakj})
65+
1766
def predict(self, X):
1867
"""
1968
Predict the labels based on the inputs
2069
:param X: Inputs
2170
:return: The predicted labels based on the training
2271
"""
2372
super().predict(X)
24-
# TODO: Marine Implement this
73+
log_one_minus_thatakj = np.log(1 - np.exp(self.parameters["parameter_log_kj"]))
74+
first_summation = self.parameters["parameter_log_kj"] - log_one_minus_thatakj
75+
76+
first_term = np.log(self.parameters["parameter_k"])
77+
second_term = safe_sparse_dot(X, first_summation.T)
78+
third_term = log_one_minus_thatakj.sum(axis=1)
79+
prediction = first_term + second_term + third_term
80+
81+
return np.argmax(prediction, axis=1)

src/utils/factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.tree import DecisionTreeClassifier
99
from sklearn.ensemble import RandomForestClassifier
1010
from sklearn.neighbors import KNeighborsClassifier
11-
from sklearn.naive_bayes import MultinomialNB
11+
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
1212

1313

1414
def get_vectorizer(vectorizer_name):
@@ -33,6 +33,8 @@ def get_model(model_name: str, grid_search: bool = False):
3333
return GridSearchCV(LogisticRegression(multi_class='auto'), param_grid, cv=5)
3434
elif model_name == "NB":
3535
return NaiveBayes()
36+
elif model_name == "NB_SKLEARN":
37+
return BernoulliNB()
3638
elif model_name == "MNNB":
3739
if not grid_search:
3840
return MultinomialNB(alpha=0.0001)

src/validation_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ def run_validation_pipeline(linear_correlation: bool = True):
6262

6363
print("\t\t\tThe best parameters for model: " + model_to_run + " are ", model.best_params_)
6464
print("\t\t\tRunning k fold validation with the best model")
65-
acc, conf_mat = k_fold_validation(model.best_estimator_, X_trains, X_tests, Y_trains, Y_tests, linear_correlation)
65+
accuracy, conf_mat = k_fold_validation(model.best_estimator_, X_trains, X_tests, Y_trains, Y_tests, linear_correlation)
6666

6767
results_confusion_matrix_file = os.path.join(results_dir_path, vocabulary + "_"+ vec + "_" + model_to_run + "_" + "confusion.png")
6868
save_confusion_matrix(conf_mat, "Confusion Matrix for vocabulary " + vocabulary + ", vectorizer " + vec + "and model " + model_to_run, list(map(lambda pred: int_to_subreddit[pred], unique_labels(Y))), results_confusion_matrix_file)
69-
print("\t\t\t\tAccuracy of model " + model_to_run + ": ", acc)
69+
print("\t\t\t\tAccuracy of model " + model_to_run + ": ", accuracy)
7070

71-
append_results(model_to_run + ": " + str(acc), results_data_file)
72-
accuracies = accuracies.append(pd.DataFrame({"Model": [model_to_run], "Vectorizer": [vec], "Accuracy": [acc]}), ignore_index=True)
71+
append_results(model_to_run + ": " + str(accuracy), results_data_file)
72+
accuracies = accuracies.append(pd.DataFrame({"Model": [model_to_run], "Vectorizer": [vec], "Accuracy": [accuracy]}), ignore_index=True)
7373

7474
# save the accuracies of vocab for each model
7575
results_model_accuracy_file = os.path.join(results_dir_path, "accuracies_" + vocabulary + ".png")

0 commit comments

Comments
 (0)