23 lines
753 B
Python
23 lines
753 B
Python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
|
|
|
|
def plot_confusion_matrix(true_labels, predictions, label_names, *, labels=None):
    """Plot the confusion matrix twice: raw counts, then normalized over true labels.

    ``ConfusionMatrixDisplay.plot`` is called without an ``ax``, so each of the
    two plots gets its own new figure. Each figure is titled so the two can be
    told apart.

    Parameters
    ----------
    true_labels : array-like
        Ground-truth class labels.
    predictions : array-like
        Predicted class labels, same length as ``true_labels``.
    label_names : list of str or None
        Tick labels for the matrix axes. They must follow the same order as
        the class labels used for the matrix: ``labels`` if given, otherwise
        the sorted set of labels appearing in ``true_labels`` or
        ``predictions``.
    labels : array-like, optional (keyword-only)
        Explicit class order forwarded to ``confusion_matrix``. Use it to keep
        ``label_names`` aligned, or to include classes absent from the data.
        Defaults to ``None`` (original behavior).

    Returns
    -------
    list of ConfusionMatrixDisplay
        The counts display followed by the normalized display.
    """
    variants = (
        (None, 'Confusion matrix (counts)'),
        # 'true' normalizes each row (true class) to sum to 1.
        ('true', 'Confusion matrix (normalized over true labels)'),
    )
    displays = []
    for normalize, title in variants:
        cm = confusion_matrix(true_labels, predictions, labels=labels, normalize=normalize)
        cm_disp = ConfusionMatrixDisplay(cm, display_labels=label_names)
        cm_disp.plot(cmap="Blues")
        cm_disp.ax_.set_title(title)
        displays.append(cm_disp)
    return displays
|
|
|
|
|
|
def plot_roc_curve_IF(true_labels, scores):
    """Plot the ROC curve for anomaly scores, treating label -1 as the positive class.

    The raw scores are negated before being passed to ``roc_curve``, so that a
    LOWER raw score counts as MORE positive. This presumably targets
    scikit-learn ``IsolationForest`` output, whose ``predict`` marks outliers
    as -1 and whose ``score_samples`` / ``decision_function`` are lower for
    more abnormal points. Confirm against callers.

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        Ground-truth labels; samples equal to -1 are the positives.
    scores : array-like of shape (n_samples,)
        Raw scores, lower meaning "more likely -1". Lists, pandas Series and
        NumPy arrays are all accepted.

    Returns
    -------
    float
        Area under the ROC curve. Previously discarded; callers that ignore
        the return value are unaffected.
    """
    # np.asarray lets plain lists work: unary minus on a list raises TypeError.
    negated_scores = -np.asarray(scores, dtype=float)
    fpr, tpr, _ = roc_curve(true_labels, negated_scores, pos_label=-1)
    auc_score = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, '-', label=f'AUC: {auc_score:.4f}')
    # Reference diagonal: the ROC curve of a random classifier.
    plt.plot([0, 1], [0, 1], '--', color='grey', linewidth=1)
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    # A legend avoids the AUC text overlapping the curve at fixed data coordinates.
    plt.legend(loc='lower right')
    plt.show()
    return auc_score
|
|
|
|
|