CNN with test-split

This commit is contained in:
Celina Korzer 2026-02-24 15:48:23 +01:00
parent 984ef89a07
commit 0f11a88ae7
3 changed files with 80 additions and 31 deletions

View File

@ -1706,18 +1706,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -1756,7 +1756,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": null,
"id": "5a09f80c",
"metadata": {},
"outputs": [
@ -1815,10 +1815,10 @@
"plt.show()\n",
"\n",
"\n",
"print(\"Precision:\", precision_score(y, y_pred_final)) \n",
"print(\"Recall:\", recall_score(y, y_pred_final)) \n",
"print(\"F1Score:\", f1_score(y, y_pred_final)) \n",
"print(\"Balanced Accuracy:\", balanced_accuracy_score(y, y_pred_final))"
"#print(\"Precision:\", precision_score(y, y_pred_final)) \n",
"#print(\"Recall:\", recall_score(y, y_pred_final)) \n",
"#print(\"F1Score:\", f1_score(y, y_pred_final)) \n",
"#print(\"Balanced Accuracy:\", balanced_accuracy_score(y, y_pred_final))"
]
},
{
@ -1921,18 +1921,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -20,8 +20,9 @@
"import matplotlib.pyplot as plt\n",
"import random \n",
"import joblib \n",
"import seaborn as sns\n",
"from pathlib import Path \n",
"from sklearn.metrics import confusion_matrix\n",
"from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, balanced_accuracy_score\n",
"\n",
"from sklearn.model_selection import GroupKFold \n",
"from sklearn.preprocessing import StandardScaler \n",
@ -1215,12 +1216,30 @@
" cm = confusion_matrix(y_val, y_pred)\n",
" all_conf_matrices.append(cm) \n",
" \n",
" #Heatmap \n",
" plt.figure(figsize=(6,5)) \n",
" sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", \n",
" xticklabels=[\"Pred 0\", \"Pred 1\"], \n",
" yticklabels=[\"True 0\", \"True 1\"]) \n",
" plt.title(f\"Konfusionsmatrix Fold {fold+1}\") \n",
" plt.xlabel(\"Predicted\") \n",
" plt.ylabel(\"True\") \n",
" plt.show()\n",
" \n",
" print(f\"Konfusionsmatrix Fold {fold+1}:\\n{cm}\\n\") \n",
" \n",
"# Aggregierte Matrix \n",
"agg_cm = sum(all_conf_matrices) \n",
"print(\"Aggregierte Konfusionsmatrix über alle Folds:\") \n",
"print(agg_cm)\n"
"print(agg_cm)\n",
"plt.figure(figsize=(6,5))\n",
"sns.heatmap(agg_cm, annot=True, fmt=\"d\", cmap=\"Purples\",\n",
" xticklabels=[\"Pred 0\", \"Pred 1\"],\n",
" yticklabels=[\"True 0\", \"True 1\"])\n",
"plt.title(\"Aggregierte Konfusionsmatrix alle Folds\")\n",
"plt.xlabel(\"Predicted\")\n",
"plt.ylabel(\"True\")\n",
"plt.show()\n"
]
},
{
@ -1640,6 +1659,60 @@
")"
]
},
{
"cell_type": "markdown",
"id": "445fe4c8",
"metadata": {},
"source": [
"Evalutation auf allen Trainingsdaten"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a14a3bda",
"metadata": {},
"outputs": [],
"source": [
"# 1. Preprocessing der Trainingsdaten\n",
"X_au_train_final = scaler_au_final.transform(\n",
" X_au_train.reshape(len(X_au_train), -1)\n",
").reshape(X_au_train.shape)\n",
"\n",
"X_eye_train_final = scaler_eye_final.transform(X_eye_train)\n",
"\n",
"# 2. Vorhersagen\n",
"y_prob_final = final_model.predict([X_au_train_final, X_eye_train_final]).flatten()\n",
"y_pred_final = (y_prob_final > 0.5).astype(int)\n",
"\n",
"# 3. Metriken\n",
"loss, acc, auc_val = final_model.evaluate(\n",
" [X_au_train_final, X_eye_train_final], y_train, verbose=0\n",
")\n",
"\n",
"print(f\"Final Loss: {loss:.4f}\")\n",
"print(f\"Final Accuracy: {acc:.4f}\")\n",
"print(f\"Final AUC: {auc_val:.4f}\")\n",
"\n",
"# 4. Konfusionsmatrix (Heatmap)\n",
"cm_final = confusion_matrix(y_train, y_pred_final)\n",
"\n",
"plt.figure(figsize=(6,5))\n",
"sns.heatmap(cm_final, annot=True, fmt=\"d\", cmap=\"Oranges\",\n",
" xticklabels=[\"Pred 0\", \"Pred 1\"],\n",
" yticklabels=[\"True 0\", \"True 1\"])\n",
"plt.title(\"Konfusionsmatrix Finales Hybrid-Modell (Trainingsdaten)\")\n",
"plt.xlabel(\"Predicted\")\n",
"plt.ylabel(\"True\")\n",
"plt.show()\n",
"\n",
"# 5. Weitere Metriken\n",
"#print(\"Precision:\", precision_score(y_train, y_pred_final))\n",
"#print(\"Recall:\", recall_score(y_train, y_pred_final))\n",
"#print(\"F1Score:\", f1_score(y_train, y_pred_final))\n",
"#print(\"Balanced Accuracy:\", balanced_accuracy_score(y_train, y_pred_final))\n"
]
},
{
"cell_type": "markdown",
"id": "7c7f9cc4",