Skip to content

Commit

Permalink
ROC plot for milticlass classification updated
Browse files Browse the repository at this point in the history
  • Loading branch information
rvdeo committed Jan 27, 2025
1 parent 85c9b74 commit c6c35c6
Showing 1 changed file with 36 additions and 40 deletions.
76 changes: 36 additions & 40 deletions examples/tabpfn_for_multiclass_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def calc_roc_auc(y_test, y_onehot_test, y_score,n_classes):
return fpr, tpr, roc_auc


def plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc):
def plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc,target_names):


fig, ax = plt.subplots(figsize=(6, 6))
Expand Down Expand Up @@ -158,49 +158,45 @@ def run_classifier(X,y,target_names):
#y_onehot_test.shape # (n_samples, n_classes)

fpr, tpr, roc_auc = calc_roc_auc(y_test,y_onehot_test, y_score,n_classes)
plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc)



#iris = load_iris()
#target_names = iris.target_names
#X, y = iris.data, iris.target
#y = iris.target_names[y]

#run_classifier(X,y,target_names)


df = pd.read_csv("data/exp_325/Unbalanced_325.csv")

# Prepare features and target
X = df.drop('Species', axis=1)
y = df[['Species']]
target_names = y['Species'].unique()
run_classifier(X,y,target_names)

'''
class_of_interest = "virginica"
class_id = np.flatnonzero(label_binarizer.classes_ == class_of_interest)[0]
class_id
plot_roc_auc(y_onehot_test, y_score,n_classes,fpr,tpr,roc_auc,target_names)

def run_iris():
print("Running classification on Iris Dataset")
iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
y = iris.target_names[y]

run_classifier(X,y,target_names)

display = RocCurveDisplay.from_predictions(
y_onehot_test[:, class_id],
y_score[:, class_id],
name=f"{class_of_interest} vs the rest",
color="darkorange",
plot_chance_level=True,
despine=True,
)
_ = display.ax_.set(
xlabel="False Positive Rate",
ylabel="True Positive Rate",
title="One-vs-Rest ROC curves:\nVirginica vs (Setosa & Versicolor)",
)
'''
def run_exp325():
df = pd.read_csv("data/exp_325/Unbalanced_325.csv")

# Prepare features and target
X = df.drop('Species', axis=1).to_numpy()
y = df[['Species']].convert_dtypes()
#y['Species'] = y['Species'].astype("str")
target_names = y['Species'].unique().to_numpy().astype(str)
y = y.to_numpy().astype(str)
run_classifier(X,y,target_names)


def run_exp310():
df = pd.read_csv("data/exp_310/Unbalanced_310.csv")

# Prepare features and target
X = df.drop('Species', axis=1).to_numpy()
y = df[['Species']].convert_dtypes()
#y['Species'] = y['Species'].astype("str")
target_names = y['Species'].unique().to_numpy().astype(str)
y = y.to_numpy().astype(str)
run_classifier(X,y,target_names)

def main():

run_exp310()


if __name__ == "__main__":
main()

0 comments on commit c6c35c6

Please sign in to comment.