diff --git a/examples/tabpfn_for_multiclass_classification.py b/examples/tabpfn_for_multiclass_classification.py index bc1ecb99..4857fb6e 100644 --- a/examples/tabpfn_for_multiclass_classification.py +++ b/examples/tabpfn_for_multiclass_classification.py @@ -87,7 +87,7 @@ def calc_roc_auc(y_test, y_onehot_test, y_score,n_classes): return fpr, tpr, roc_auc -def plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc): +def plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc,target_names): fig, ax = plt.subplots(figsize=(6, 6)) @@ -158,49 +158,45 @@ def run_classifier(X,y,target_names): #y_onehot_test.shape # (n_samples, n_classes) fpr, tpr, roc_auc = calc_roc_auc(y_test,y_onehot_test, y_score,n_classes) - plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc) - - - -#iris = load_iris() -#target_names = iris.target_names -#X, y = iris.data, iris.target -#y = iris.target_names[y] - -#run_classifier(X,y,target_names) - - -df = pd.read_csv("data/exp_325/Unbalanced_325.csv") - -# Prepare features and target -X = df.drop('Species', axis=1) -y = df[['Species']] -target_names = y['Species'].unique() -run_classifier(X,y,target_names) - -''' -class_of_interest = "virginica" -class_id = np.flatnonzero(label_binarizer.classes_ == class_of_interest)[0] -class_id - + plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc,target_names) +def run_iris(): + print("Running classification on Iris Dataset") + iris = load_iris() + target_names = iris.target_names + X, y = iris.data, iris.target + y = iris.target_names[y] + + run_classifier(X,y,target_names) -display = RocCurveDisplay.from_predictions( - y_onehot_test[:, class_id], - y_score[:, class_id], - name=f"{class_of_interest} vs the rest", - color="darkorange", - plot_chance_level=True, - despine=True, -) -_ = display.ax_.set( - xlabel="False Positive Rate", - ylabel="True Positive Rate", - title="One-vs-Rest ROC curves:\nVirginica vs (Setosa & Versicolor)", -) -''' +def run_exp325(): + df = pd.read_csv("data/exp_325/Unbalanced_325.csv") + + # Prepare features and target + X = df.drop('Species', axis=1).to_numpy() + y = df[['Species']].convert_dtypes() + #y['Species'] = y['Species'].astype("str") + target_names = y['Species'].unique().to_numpy().astype(str) + y = y.to_numpy().astype(str) + run_classifier(X,y,target_names) +def run_exp310(): + df = pd.read_csv("data/exp_310/Unbalanced_310.csv") + + # Prepare features and target + X = df.drop('Species', axis=1).to_numpy() + y = df[['Species']].convert_dtypes() + #y['Species'] = y['Species'].astype("str") + target_names = y['Species'].unique().to_numpy().astype(str) + y = y.to_numpy().astype(str) + run_classifier(X,y,target_names) +def main(): + + run_exp310() + +if __name__ == "__main__": + main()