import numpy as np
import pandas as pd
import joblib
import wandb
from dotenv import load_dotenv
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, balanced_accuracy_score
from sklearn.model_selection import (
    GridSearchCV,
    StratifiedKFold,
    cross_val_predict,
)
from sklearn.neural_network import MLPClassifier

load_dotenv()
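# Assumption: the .env file loaded above supplies the credentials this script
# needs at runtime (e.g. WANDB_API_KEY for the wandb runs below).
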
#### Import the data
df = pd.read_csv("../data/text_dataset_translate.csv")

# Class distribution before filtering
print(df["diag"].value_counts())
# Drop the rows with an unclear diagnosis (CFTD rows are kept)
df = df[df["diag"] != "UNCLEAR"]
# Keep the surviving row indices so the embedding arrays (X) loaded below can be
# filtered the same way
keep_idx = df.index.to_numpy()

cv_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
# Class distribution after filtering
print(df["diag"].value_counts())


LANGUAGE = ["fr", "en"]
EMBEDDING_MODEL = ["instructor", "openai"]

# Labels, taken after filtering
Y = df["diag"].values

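# The same stratified 10-fold split (cv_fold) is reused for both the grid searches
# and the cross_val_predict evaluations below, so every model is tuned and scored
# on identical folds.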
for lang in LANGUAGE:
    for embedding_method in EMBEDDING_MODEL:
        X = np.load(f"../data/embeddings/{embedding_method}_{lang}_embeddings.npy")
        # Subset the embeddings to the filtered rows; only needed if the .npy files
        # were generated from the full CSV, before the UNCLEAR rows were dropped
        if X.shape[0] != len(Y):
            X = X[keep_idx]

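        # Sanity check, assuming each .npy file stores one embedding per CSV row in
        # the original row order (nothing else in the script verifies this)
        assert X.shape[0] == len(Y), "embeddings and labels are misaligned"
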
        #########################################
        # MLPC
        param_grid = {
            "hidden_layer_sizes": [(400,), (200,), (100, 100), (200, 200)],
            "activation": ["tanh", "relu"],
            "solver": ["adam"],
            "learning_rate_init": [0.001, 0.01],
            "max_iter": [800, 1500, 2500],
        }
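
        # 4 * 2 * 1 * 2 * 3 = 48 parameter combinations; with 10-fold CV the grid
        # search below fits 480 MLPs per (embedding, language) pair.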
        # Create grid search
        cls = MLPClassifier(random_state=42)
        gs_mlpc = GridSearchCV(
            cls, param_grid, scoring="accuracy", cv=cv_fold, verbose=1
        )
        gs_mlpc.fit(X, Y)
        best_mlpc = gs_mlpc.best_estimator_
        df_cv_search_mlpc = pd.DataFrame(gs_mlpc.cv_results_)
        # Print the best parameters and score
        print("Best parameters:", gs_mlpc.best_params_)
        print("Best score:", gs_mlpc.best_score_)
        joblib.dump(
            gs_mlpc, f"../models/{embedding_method}_{lang}_gridsearch_mlpc.joblib"
        )
        joblib.dump(best_mlpc, f"../models/{embedding_method}_{lang}_model_mlpc.joblib")

        # Reload the grid search and the best model from disk, so the evaluation
        # below also works when rerun from the saved files
        gs_mlpc = joblib.load(
            f"../models/{embedding_method}_{lang}_gridsearch_mlpc.joblib"
        )
        best_mlpc = joblib.load(
            f"../models/{embedding_method}_{lang}_model_mlpc.joblib"
        )

        # Use cross_val_predict to get predicted labels and probabilities
        y_pred = cross_val_predict(best_mlpc, X, Y, cv=cv_fold)
        y_probas = cross_val_predict(
            best_mlpc, X, Y, cv=cv_fold, method="predict_proba"
        )
        # Compute classification report
        report = classification_report(
            Y, y_pred, target_names=best_mlpc.classes_, output_dict=True
        )
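        # Note: cross_val_predict clones the estimator and refits it on each training
        # fold, so y_pred / y_probas are out-of-fold predictions made with the tuned
        # hyperparameters, not predictions from the refit-on-all-data best_mlpc.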

        run = wandb.init(
            project="myo-text-classify",
            name=f"{embedding_method}_{lang}_mlpc",
            config={
                "embedding": embedding_method,
                "doc_lang": lang,
                "corpus": "complete_1704023_190reports",
                "model": "MLPClassifier",
            },
        )
        config = wandb.config
        best_params = gs_mlpc.best_params_
        best_score = gs_mlpc.best_score_
        best_std = gs_mlpc.cv_results_["std_test_score"][gs_mlpc.best_index_]
        balanced_accuracy_metric = balanced_accuracy_score(Y, y_pred)

        wandb.log(
            {
                "Classification Report": report,
                "Best Params": best_params,
                "Best Score (gs)": best_score,
                "CV Std Devs (gs)": best_std,
                "Balanced Accuracy": balanced_accuracy_metric,
            }
        )
        wandb.sklearn.plot_confusion_matrix(Y, y_pred, best_mlpc.classes_)
        # The full dataset is passed as both the train and test split here; the
        # headline metrics above come from the out-of-fold CV predictions instead
        wandb.sklearn.plot_classifier(
            best_mlpc,
            X,
            X,
            Y,
            Y,
            y_pred,
            y_probas,
            labels=best_mlpc.classes_,
            model_name=f"{embedding_method}_{lang}_model",
            feature_names=None,
        )
        # Create artifact for best model
        model_artifact = wandb.Artifact(
            f"{embedding_method}_{lang}_model_mlpc", type="model"
        )
        # Add best estimator to artifact
        model_artifact.add_file(
            f"../models/{embedding_method}_{lang}_model_mlpc.joblib"
        )
        # Log artifact to WandB
        wandb.run.log_artifact(model_artifact)
        wandb.finish()

        #############################################
        # RANDOM FOREST
        param_grid_rf = {
            "n_estimators": [10, 50, 100, 200],
            "max_depth": [None, 5, 10, 20],
            "min_samples_split": [2, 5, 10],
            "min_samples_leaf": [1, 2, 4],
            "class_weight": ["balanced", "balanced_subsample"],
        }
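
        # 4 * 4 * 3 * 3 * 2 = 288 parameter combinations; with 10-fold CV the grid
        # search below fits 2880 forests per (embedding, language) pair.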

        # Create grid search
        cls_rf = RandomForestClassifier(random_state=42)
        gs_rf = GridSearchCV(
            cls_rf, param_grid_rf, scoring="accuracy", cv=cv_fold, verbose=1
        )
        gs_rf.fit(X, Y)
        best_rf = gs_rf.best_estimator_
        df_cv_search_rf = pd.DataFrame(gs_rf.cv_results_)
        # Print the best parameters and score
        print("Best parameters:", gs_rf.best_params_)
        print("Best score:", gs_rf.best_score_)
        joblib.dump(gs_rf, f"../models/{embedding_method}_{lang}_gridsearch_rf.joblib")
        joblib.dump(best_rf, f"../models/{embedding_method}_{lang}_model_rf.joblib")

        # Reload the saved grid search and best random forest from disk
        gs_rf = joblib.load(f"../models/{embedding_method}_{lang}_gridsearch_rf.joblib")
        best_rf = joblib.load(f"../models/{embedding_method}_{lang}_model_rf.joblib")

        # Use cross_val_predict to get predicted labels and probabilities
        y_pred = cross_val_predict(best_rf, X, Y, cv=cv_fold)
        y_probas = cross_val_predict(best_rf, X, Y, cv=cv_fold, method="predict_proba")
        # Compute classification report
        report = classification_report(
            Y, y_pred, target_names=best_rf.classes_, output_dict=True
        )

        run = wandb.init(
            project="myo-text-classify",
            name=f"{embedding_method}_{lang}_rf",
            config={
                "embedding": embedding_method,
                "doc_lang": lang,
                "corpus": "complete_1704023_190reports",
                "model": "RandomForest",
            },
        )
        config = wandb.config
        # Log the random-forest grid search results (not the MLP ones)
        best_params = gs_rf.best_params_
        best_score = gs_rf.best_score_
        best_std = gs_rf.cv_results_["std_test_score"][gs_rf.best_index_]
        balanced_accuracy_metric = balanced_accuracy_score(Y, y_pred)

        wandb.log(
            {
                "Classification Report": report,
                "Best Params": best_params,
                "Best Score (gs)": best_score,
                "CV Std Devs (gs)": best_std,
                "Balanced Accuracy": balanced_accuracy_metric,
            }
        )
        wandb.sklearn.plot_confusion_matrix(Y, y_pred, best_rf.classes_)
        wandb.sklearn.plot_classifier(
            best_rf,
            X,
            X,
            Y,
            Y,
            y_pred,
            y_probas,
            labels=best_rf.classes_,
            model_name=f"{embedding_method}_{lang}_model",
            feature_names=None,
        )
        # Create artifact for best model
        model_artifact = wandb.Artifact(
            f"{embedding_method}_{lang}_model_rf", type="model"
        )
        # Add best estimator to artifact
        model_artifact.add_file(f"../models/{embedding_method}_{lang}_model_rf.joblib")
        # Log artifact to WandB
        wandb.run.log_artifact(model_artifact)
        wandb.finish()
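

# A minimal sketch of how a logged model could be pulled back down later; this
# assumes the same W&B project/entity and the artifact names used above:
#
#     import os
#     run = wandb.init(project="myo-text-classify", job_type="inference")
#     art = run.use_artifact("openai_en_model_mlpc:latest", type="model")
#     model = joblib.load(os.path.join(art.download(), "openai_en_model_mlpc.joblib"))
#     run.finish()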