- added earlyFusionTest
- changed to group split
This commit is contained in:
parent
7a63c7acd3
commit
e69000fbd8
529
model_training/CNN/CNN_crossVal_EarlyFusion_Test_Eval.ipynb
Normal file
529
model_training/CNN/CNN_crossVal_EarlyFusion_Test_Eval.ipynb
Normal file
@ -0,0 +1,529 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "47f6de7b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Bibliotheken importieren"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "99294260",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd \n",
|
||||||
|
"import numpy as np \n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import seaborn as sns \n",
|
||||||
|
"import random \n",
|
||||||
|
"import joblib \n",
|
||||||
|
"from pathlib import Path \n",
|
||||||
|
"\n",
|
||||||
|
"from sklearn.model_selection import GroupKFold, GroupShuffleSplit\n",
|
||||||
|
"from sklearn.preprocessing import StandardScaler \n",
|
||||||
|
"from sklearn.metrics import ( \n",
|
||||||
|
" precision_score, recall_score,\n",
|
||||||
|
" confusion_matrix, roc_curve, auc, \n",
|
||||||
|
" precision_recall_curve, f1_score, \n",
|
||||||
|
" balanced_accuracy_score, accuracy_score\n",
|
||||||
|
") \n",
|
||||||
|
"\n",
|
||||||
|
"import tensorflow as tf \n",
|
||||||
|
"from tensorflow.keras import Input, layers, models, regularizers"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "52b4ca8c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Seed festlegen"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "6e49d281",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"SEED = 42 \n",
|
||||||
|
"np.random.seed(SEED) \n",
|
||||||
|
"tf.random.set_seed(SEED) \n",
|
||||||
|
"random.seed(SEED)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ae1a715f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Daten laden"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "870f01c3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"data_path = Path(r\"~/data-paulusjafahrsimulator-gpu/new_datasets/50s_25Hz_dataset.parquet\") \n",
|
||||||
|
"\n",
|
||||||
|
"data = pd.read_parquet(path=data_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bedbc23b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Labels erstellen"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "38848515",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"low_all = data[((data[\"PHASE\"] == \"baseline\") | \n",
|
||||||
|
" ((data[\"STUDY\"] == \"n-back\") & (data[\"PHASE\"] != \"baseline\") & (data[\"LEVEL\"].isin([1,4]))))].copy() \n",
|
||||||
|
"\n",
|
||||||
|
"high_all = pd.concat([ \n",
|
||||||
|
" data[(data[\"STUDY\"]==\"n-back\") & (data[\"LEVEL\"].isin([2,3,5,6])) & (data[\"PHASE\"].isin([\"train\",\"test\"]))], \n",
|
||||||
|
" data[(data[\"STUDY\"]==\"k-drive\") & (data[\"PHASE\"]!=\"baseline\")] \n",
|
||||||
|
"]).copy() \n",
|
||||||
|
"\n",
|
||||||
|
"low_all[\"label\"] = 0 \n",
|
||||||
|
"high_all[\"label\"] = 1 \n",
|
||||||
|
"data = pd.concat([low_all, high_all], ignore_index=True).drop_duplicates() "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0b282acf",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Features und Labels"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5edb00a0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#Face AUs\n",
|
||||||
|
"au_columns = [col for col in data.columns if \"face\" in col.lower()] \n",
|
||||||
|
"\n",
|
||||||
|
"#Eye Features\n",
|
||||||
|
"eye_columns = [ \n",
|
||||||
|
" 'Fix_count_short_66_150', \n",
|
||||||
|
" 'Fix_count_medium_300_500', \n",
|
||||||
|
" 'Fix_count_long_gt_1000', \n",
|
||||||
|
" 'Fix_count_100', \n",
|
||||||
|
" 'Fix_mean_duration', \n",
|
||||||
|
" 'Fix_median_duration', \n",
|
||||||
|
" 'Sac_count', \n",
|
||||||
|
" 'Sac_mean_amp', \n",
|
||||||
|
" 'Sac_mean_dur', \n",
|
||||||
|
" 'Sac_median_dur', \n",
|
||||||
|
" 'Blink_count', \n",
|
||||||
|
" 'Blink_mean_dur', \n",
|
||||||
|
" 'Blink_median_dur', \n",
|
||||||
|
" 'Pupil_mean', \n",
|
||||||
|
" 'Pupil_IPA' \n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"#Early Fusion\n",
|
||||||
|
"feature_columns = au_columns + eye_columns\n",
|
||||||
|
"\n",
|
||||||
|
"#NaNs entfernen \n",
|
||||||
|
"data = data.dropna(subset=feature_columns + [\"label\"])\n",
|
||||||
|
"\n",
|
||||||
|
"X = data[feature_columns].values[..., np.newaxis] \n",
|
||||||
|
"y = data[\"label\"].values \n",
|
||||||
|
"\n",
|
||||||
|
"groups = data[\"subjectID\"].values\n",
|
||||||
|
"print(data.columns.tolist())\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Gefundene FACE_AU-Spalten:\", au_columns)\n",
|
||||||
|
"print(\"Gefundene Eye Features:\" , eye_columns)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Anzahl FACE_AUs:\", len(au_columns)) \n",
|
||||||
|
"print(\"Anzahl EYE Features:\", len(eye_columns)) \n",
|
||||||
|
"print(\"Gesamtzahl Features:\", len(feature_columns))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d8689679",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Train-Test-Split"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b5cf88c3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)\n",
|
||||||
|
"train_idx, test_idx = next(gss.split(X, y, groups))\n",
|
||||||
|
"\n",
|
||||||
|
"feature_columns_train, feature_columns_test = X[train_idx], X[test_idx]\n",
|
||||||
|
"y_train, y_test = y[train_idx], y[test_idx]\n",
|
||||||
|
"groups_train, groups_test = groups[train_idx], groups[test_idx]\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Train:\", len(y_train), \" | Test:\", len(y_test))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a539b83b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"CNN-Modell"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e4a7f496",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def build_model(input_shape, lr=1e-4): \n",
|
||||||
|
" model = models.Sequential([ \n",
|
||||||
|
" Input(shape=input_shape), \n",
|
||||||
|
" layers.Conv1D(32, kernel_size=3, activation=\"relu\", kernel_regularizer=regularizers.l2(0.001)), \n",
|
||||||
|
" layers.BatchNormalization(), \n",
|
||||||
|
" layers.MaxPooling1D(pool_size=2),\n",
|
||||||
|
"\n",
|
||||||
|
" layers.Conv1D(64, kernel_size=3, activation=\"relu\", kernel_regularizer=regularizers.l2(0.001)), \n",
|
||||||
|
" layers.BatchNormalization(), \n",
|
||||||
|
" layers.GlobalAveragePooling1D(), \n",
|
||||||
|
" \n",
|
||||||
|
" layers.Dense(32, activation=\"relu\", kernel_regularizer=regularizers.l2(0.001)), \n",
|
||||||
|
" layers.Dropout(0.5), \n",
|
||||||
|
" layers.Dense(1, activation=\"sigmoid\") \n",
|
||||||
|
" ]) \n",
|
||||||
|
" \n",
|
||||||
|
" model.compile( \n",
|
||||||
|
" optimizer=tf.keras.optimizers.Adam(learning_rate=lr), \n",
|
||||||
|
" loss=\"binary_crossentropy\", \n",
|
||||||
|
" metrics=[\"accuracy\", tf.keras.metrics.AUC(name=\"auc\")] \n",
|
||||||
|
" ) \n",
|
||||||
|
" return model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5905871b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Cross-Validation"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "90658000",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"gkf = GroupKFold(n_splits=5) \n",
|
||||||
|
"cv_histories = [] \n",
|
||||||
|
"cv_results = [] \n",
|
||||||
|
"fold_subjects = []\n",
|
||||||
|
"all_conf_matrices = []\n",
|
||||||
|
"\n",
|
||||||
|
"for fold, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):\n",
|
||||||
|
" train_subjects = np.unique(groups[train_idx]) \n",
|
||||||
|
" val_subjects = np.unique(groups[val_idx]) \n",
|
||||||
|
" fold_subjects.append({\"Fold\": fold+1, \n",
|
||||||
|
" \"Train_Subjects\": train_subjects, \n",
|
||||||
|
" \"Val_Subjects\": val_subjects}) \n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"\\n--- Fold {fold+1} ---\") \n",
|
||||||
|
" print(\"Train-Subjects:\", train_subjects) \n",
|
||||||
|
" print(\"Val-Subjects:\", val_subjects) \n",
|
||||||
|
"\n",
|
||||||
|
" #Split\n",
|
||||||
|
" X_train, X_val = X[train_idx], X[val_idx] \n",
|
||||||
|
" y_train, y_val = y[train_idx], y[val_idx] # Normalisierung pro Fold \n",
|
||||||
|
"\n",
|
||||||
|
" #Normalisierung pro Fold\n",
|
||||||
|
" scaler = StandardScaler() \n",
|
||||||
|
" X_train = scaler.fit_transform(X_train.reshape(len(X_train), -1)).reshape(X_train.shape) \n",
|
||||||
|
" X_val = scaler.transform(X_val.reshape(len(X_val), -1)).reshape(X_val.shape) \n",
|
||||||
|
"\n",
|
||||||
|
" # Plausibilitäts-Check \n",
|
||||||
|
" print(\"Train Mittelwerte (erste 5 Features):\", X_train.mean(axis=0)[:5]) \n",
|
||||||
|
" print(\"Train Std (erste 5 Features):\", X_train.std(axis=0)[:5]) \n",
|
||||||
|
" print(\"Val Mittelwerte (erste 5 Features):\", X_val.mean(axis=0)[:5]) \n",
|
||||||
|
" print(\"Val Std (erste 5 Features):\", X_val.std(axis=0)[:5]) \n",
|
||||||
|
"\n",
|
||||||
|
" # Modell \n",
|
||||||
|
" model = build_model(input_shape=(len(feature_columns_train),1), lr=1e-4) \n",
|
||||||
|
" model.summary() \n",
|
||||||
|
"\n",
|
||||||
|
" callbacks = [ \n",
|
||||||
|
" tf.keras.callbacks.EarlyStopping(monitor=\"val_loss\", patience=10, restore_best_weights=True), \n",
|
||||||
|
" tf.keras.callbacks.ReduceLROnPlateau(monitor=\"val_loss\", factor=0.5, patience=5, min_lr=1e-6) \n",
|
||||||
|
" ] \n",
|
||||||
|
"\n",
|
||||||
|
" history = model.fit( \n",
|
||||||
|
" X_train, y_train, \n",
|
||||||
|
" validation_data=(X_val, y_val), \n",
|
||||||
|
" epochs=100, \n",
|
||||||
|
" batch_size=16, \n",
|
||||||
|
" callbacks=callbacks, \n",
|
||||||
|
" verbose=0 \n",
|
||||||
|
" ) \n",
|
||||||
|
"\n",
|
||||||
|
" cv_histories.append(history.history) \n",
|
||||||
|
" scores = model.evaluate(X_val, y_val, verbose=0) \n",
|
||||||
|
" cv_results.append(scores) \n",
|
||||||
|
" print(f\"Fold {fold+1} - Val Loss: {scores[0]:.4f}, Val Acc: {scores[1]:.4f}, Val AUC: {scores[2]:.4f}\")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" #Konfusionsmatrix \n",
|
||||||
|
" y_pred = (model.predict(X_val) > 0.5).astype(int) \n",
|
||||||
|
" cm = confusion_matrix(y_val, y_pred) \n",
|
||||||
|
" all_conf_matrices.append(cm) \n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"Konfusionsmatrix Fold {fold+1}:\\n{cm}\\n\") \n",
|
||||||
|
" \n",
|
||||||
|
"# Aggregierte Matrix \n",
|
||||||
|
"agg_cm = sum(all_conf_matrices) \n",
|
||||||
|
"print(\"Aggregierte Konfusionsmatrix über alle Folds:\") \n",
|
||||||
|
"print(agg_cm)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "d10b7e78",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9aeba7f4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#results\n",
|
||||||
|
"cv_results = np.array(cv_results) \n",
|
||||||
|
"print(\"\\n=== Cross-Validation Ergebnisse ===\") \n",
|
||||||
|
"print(f\"Durchschnittlicher Val-Loss: {cv_results[:,0].mean():.4f}\") \n",
|
||||||
|
"print(f\"Durchschnittliche Val-Accuracy: {cv_results[:,1].mean():.4f}\") \n",
|
||||||
|
"print(f\"Durchschnittliche Val-AUC: {cv_results[:,2].mean():.4f}\")\n",
|
||||||
|
"\n",
|
||||||
|
"#Ergebnis-Tabelle erstellen\n",
|
||||||
|
"results_table = pd.DataFrame({ \n",
|
||||||
|
" \"Fold\": np.arange(1, len(cv_results)+1), \n",
|
||||||
|
" \"Val Loss\": cv_results[:,0], \n",
|
||||||
|
" \"Val Accuracy\": cv_results[:,1], \n",
|
||||||
|
" \"Val AUC\": cv_results[:,2] }) \n",
|
||||||
|
"\n",
|
||||||
|
"# Durchschnittszeile hinzufügen \n",
|
||||||
|
"avg_row = pd.DataFrame({ \n",
|
||||||
|
" \"Fold\": [\"Ø\"], \n",
|
||||||
|
" \"Val Loss\": [cv_results[:,0].mean()], \n",
|
||||||
|
" \"Val Accuracy\": [cv_results[:,1].mean()], \n",
|
||||||
|
" \"Val AUC\": [cv_results[:,2].mean()] \n",
|
||||||
|
"}) \n",
|
||||||
|
"\n",
|
||||||
|
"results_table = pd.concat([results_table, avg_row], ignore_index=True) \n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n=== Ergebnis-Tabelle ===\") \n",
|
||||||
|
"print(results_table) \n",
|
||||||
|
"\n",
|
||||||
|
"#Tabelle speichern \n",
|
||||||
|
"results_table.to_csv(\"cnn_crossVal_results.csv\", index=False) \n",
|
||||||
|
"print(\"Ergebnisse gespeichert als 'cnn_crossVal_results.csv'\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "fae5df7a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Finales Modell trainieren"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5b3eab61",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"scaler_final = StandardScaler() \n",
|
||||||
|
"X_scaled = scaler_final.fit_transform(feature_columns_train.reshape(len(feature_columns_train), -1)).reshape(feature_columns_train.shape) \n",
|
||||||
|
"\n",
|
||||||
|
"final_model = build_model(input_shape=(len(feature_columns_train),1), lr=1e-4) \n",
|
||||||
|
"final_model.summary() \n",
|
||||||
|
"\n",
|
||||||
|
"final_model.fit( \n",
|
||||||
|
" X_scaled, y_train, \n",
|
||||||
|
" epochs=150, \n",
|
||||||
|
" batch_size=16, \n",
|
||||||
|
" verbose=1 \n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "7c7f9cc4",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Speichern des Modells"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "2d3af5be",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# final_model.save(\"cnn_crossVal_EarlyFusion_V2.keras\") \n",
|
||||||
|
"# joblib.dump(scaler_final, \"scaler_crossVal_EarlyFusion_V2.joblib\") \n",
|
||||||
|
"\n",
|
||||||
|
"# print(\"Finales Modell und Scaler gespeichert als 'cnn_crossVal_EarlyFusion_V2.keras' und 'scaler_crossVal_EarlyFusion_V2.joblib'\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "c11891e0",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Plots"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "9f6a8584",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#plots\n",
|
||||||
|
"def plot_cv_histories(cv_histories, metric): \n",
|
||||||
|
" plt.figure(figsize=(10,6)) \n",
|
||||||
|
" \n",
|
||||||
|
" for i, hist in enumerate(cv_histories): \n",
|
||||||
|
" plt.plot(hist[metric], label=f\"Fold {i+1} Train\", alpha=0.7) \n",
|
||||||
|
" plt.plot(hist[f\"val_{metric}\"], label=f\"Fold {i+1} Val\", linestyle=\"--\", alpha=0.7) \n",
|
||||||
|
" plt.xlabel(\"Epochs\") \n",
|
||||||
|
" plt.ylabel(metric.capitalize()) \n",
|
||||||
|
" plt.title(f\"Cross-Validation {metric.capitalize()} Verläufe\") \n",
|
||||||
|
" plt.legend() \n",
|
||||||
|
" plt.grid(True) \n",
|
||||||
|
" plt.show()\n",
|
||||||
|
" \n",
|
||||||
|
"plot_cv_histories(cv_histories, \"loss\") \n",
|
||||||
|
"plot_cv_histories(cv_histories, \"accuracy\") \n",
|
||||||
|
"plot_cv_histories(cv_histories, \"auc\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "4aebe6c6",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Test"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0d34d6b7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Preprocessing Testdaten \n",
|
||||||
|
"X_test_scaled = scaler.transform( \n",
|
||||||
|
" feature_columns_test.reshape(len(feature_columns_test), -1) \n",
|
||||||
|
").reshape(feature_columns_test.shape) \n",
|
||||||
|
"\n",
|
||||||
|
"# Vorhersagen \n",
|
||||||
|
"y_prob_test = model.predict(X_test_scaled).flatten() \n",
|
||||||
|
"y_pred_test = (y_prob_test > 0.5).astype(int) \n",
|
||||||
|
"\n",
|
||||||
|
"# Konfusionsmatrix \n",
|
||||||
|
"cm_test = confusion_matrix(y_test, y_pred_test) \n",
|
||||||
|
"\n",
|
||||||
|
"plt.figure(figsize=(6,5)) \n",
|
||||||
|
"sns.heatmap(cm_test, annot=True, fmt=\"d\", cmap=\"Greens\", \n",
|
||||||
|
" xticklabels=[\"Pred 0\", \"Pred 1\"], \n",
|
||||||
|
" yticklabels=[\"True 0\", \"True 1\"]) \n",
|
||||||
|
"plt.title(\"Konfusionsmatrix - Testdaten\") \n",
|
||||||
|
"plt.show() \n",
|
||||||
|
"\n",
|
||||||
|
"# ROC \n",
|
||||||
|
"fpr, tpr, _ = roc_curve(y_test, y_prob_test) \n",
|
||||||
|
"roc_auc = auc(fpr, tpr) \n",
|
||||||
|
"\n",
|
||||||
|
"plt.figure(figsize=(7,6)) \n",
|
||||||
|
"plt.plot(fpr, tpr, label=f\"AUC = {roc_auc:.3f}\") \n",
|
||||||
|
"plt.plot([0,1], [0,1], \"k--\") \n",
|
||||||
|
"plt.title(\"ROC - Testdaten\") \n",
|
||||||
|
"plt.legend() \n",
|
||||||
|
"plt.grid(True) \n",
|
||||||
|
"plt.show() \n",
|
||||||
|
"\n",
|
||||||
|
"# Precision-Recall \n",
|
||||||
|
"precision, recall, _ = precision_recall_curve(y_test, y_prob_test) \n",
|
||||||
|
"plt.figure(figsize=(7,6)) \n",
|
||||||
|
"plt.plot(recall, precision) \n",
|
||||||
|
"plt.title(\"Precision-Recall - Testdaten\") \n",
|
||||||
|
"plt.grid(True) \n",
|
||||||
|
"plt.show() \n",
|
||||||
|
"\n",
|
||||||
|
"# Metriken \n",
|
||||||
|
"print(\"Accuracy:\", accuracy_score(y_test, y_pred_test))\n",
|
||||||
|
"print(\"F1-Score:\", f1_score(y_test, y_pred_test)) \n",
|
||||||
|
"print(\"Balanced Accuracy:\", balanced_accuracy_score(y_test, y_pred_test)) \n",
|
||||||
|
"print(\"Precision:\", precision_score(y_test, y_pred_test)) \n",
|
||||||
|
"print(\"Recall:\", recall_score(y_test, y_pred_test)) \n",
|
||||||
|
"print(\"AUC:\", roc_auc)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user