CNN with test-split

This commit is contained in:
Celina Korzer 2026-02-24 15:48:23 +01:00
parent 984ef89a07
commit 0f11a88ae7
3 changed files with 80 additions and 31 deletions

View File

@ -1706,18 +1706,6 @@
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1756,7 +1756,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": null,
"id": "5a09f80c", "id": "5a09f80c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1815,10 +1815,10 @@
"plt.show()\n", "plt.show()\n",
"\n", "\n",
"\n", "\n",
"print(\"Precision:\", precision_score(y, y_pred_final)) \n", "#print(\"Precision:\", precision_score(y, y_pred_final)) \n",
"print(\"Recall:\", recall_score(y, y_pred_final)) \n", "#print(\"Recall:\", recall_score(y, y_pred_final)) \n",
"print(\"F1Score:\", f1_score(y, y_pred_final)) \n", "#print(\"F1Score:\", f1_score(y, y_pred_final)) \n",
"print(\"Balanced Accuracy:\", balanced_accuracy_score(y, y_pred_final))" "#print(\"Balanced Accuracy:\", balanced_accuracy_score(y, y_pred_final))"
] ]
}, },
{ {
@ -1921,18 +1921,6 @@
"display_name": "Python 3 (ipykernel)", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -20,8 +20,9 @@
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import random \n", "import random \n",
"import joblib \n", "import joblib \n",
"import seaborn as sns\n",
"from pathlib import Path \n", "from pathlib import Path \n",
"from sklearn.metrics import confusion_matrix\n", "from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, balanced_accuracy_score\n",
"\n", "\n",
"from sklearn.model_selection import GroupKFold \n", "from sklearn.model_selection import GroupKFold \n",
"from sklearn.preprocessing import StandardScaler \n", "from sklearn.preprocessing import StandardScaler \n",
@ -1215,12 +1216,30 @@
" cm = confusion_matrix(y_val, y_pred)\n", " cm = confusion_matrix(y_val, y_pred)\n",
" all_conf_matrices.append(cm) \n", " all_conf_matrices.append(cm) \n",
" \n", " \n",
" #Heatmap \n",
" plt.figure(figsize=(6,5)) \n",
" sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", \n",
" xticklabels=[\"Pred 0\", \"Pred 1\"], \n",
" yticklabels=[\"True 0\", \"True 1\"]) \n",
" plt.title(f\"Konfusionsmatrix Fold {fold+1}\") \n",
" plt.xlabel(\"Predicted\") \n",
" plt.ylabel(\"True\") \n",
" plt.show()\n",
" \n",
" print(f\"Konfusionsmatrix Fold {fold+1}:\\n{cm}\\n\") \n", " print(f\"Konfusionsmatrix Fold {fold+1}:\\n{cm}\\n\") \n",
" \n", " \n",
"# Aggregierte Matrix \n", "# Aggregierte Matrix \n",
"agg_cm = sum(all_conf_matrices) \n", "agg_cm = sum(all_conf_matrices) \n",
"print(\"Aggregierte Konfusionsmatrix über alle Folds:\") \n", "print(\"Aggregierte Konfusionsmatrix über alle Folds:\") \n",
"print(agg_cm)\n" "print(agg_cm)\n",
"plt.figure(figsize=(6,5))\n",
"sns.heatmap(agg_cm, annot=True, fmt=\"d\", cmap=\"Purples\",\n",
" xticklabels=[\"Pred 0\", \"Pred 1\"],\n",
" yticklabels=[\"True 0\", \"True 1\"])\n",
"plt.title(\"Aggregierte Konfusionsmatrix alle Folds\")\n",
"plt.xlabel(\"Predicted\")\n",
"plt.ylabel(\"True\")\n",
"plt.show()\n"
] ]
}, },
{ {
@ -1640,6 +1659,60 @@
")" ")"
] ]
}, },
{
"cell_type": "markdown",
"id": "445fe4c8",
"metadata": {},
"source": [
"Evalutation auf allen Trainingsdaten"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a14a3bda",
"metadata": {},
"outputs": [],
"source": [
"# 1. Preprocessing der Trainingsdaten\n",
"X_au_train_final = scaler_au_final.transform(\n",
" X_au_train.reshape(len(X_au_train), -1)\n",
").reshape(X_au_train.shape)\n",
"\n",
"X_eye_train_final = scaler_eye_final.transform(X_eye_train)\n",
"\n",
"# 2. Vorhersagen\n",
"y_prob_final = final_model.predict([X_au_train_final, X_eye_train_final]).flatten()\n",
"y_pred_final = (y_prob_final > 0.5).astype(int)\n",
"\n",
"# 3. Metriken\n",
"loss, acc, auc_val = final_model.evaluate(\n",
" [X_au_train_final, X_eye_train_final], y_train, verbose=0\n",
")\n",
"\n",
"print(f\"Final Loss: {loss:.4f}\")\n",
"print(f\"Final Accuracy: {acc:.4f}\")\n",
"print(f\"Final AUC: {auc_val:.4f}\")\n",
"\n",
"# 4. Konfusionsmatrix (Heatmap)\n",
"cm_final = confusion_matrix(y_train, y_pred_final)\n",
"\n",
"plt.figure(figsize=(6,5))\n",
"sns.heatmap(cm_final, annot=True, fmt=\"d\", cmap=\"Oranges\",\n",
" xticklabels=[\"Pred 0\", \"Pred 1\"],\n",
" yticklabels=[\"True 0\", \"True 1\"])\n",
"plt.title(\"Konfusionsmatrix Finales Hybrid-Modell (Trainingsdaten)\")\n",
"plt.xlabel(\"Predicted\")\n",
"plt.ylabel(\"True\")\n",
"plt.show()\n",
"\n",
"# 5. Weitere Metriken\n",
"#print(\"Precision:\", precision_score(y_train, y_pred_final))\n",
"#print(\"Recall:\", recall_score(y_train, y_pred_final))\n",
"#print(\"F1Score:\", f1_score(y_train, y_pred_final))\n",
"#print(\"Balanced Accuracy:\", balanced_accuracy_score(y_train, y_pred_final))\n"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "7c7f9cc4", "id": "7c7f9cc4",