This commit is contained in:
Celina Korzer 2026-02-24 11:42:58 +01:00
parent 42965a4733
commit 15190ac52e
6 changed files with 7633 additions and 0 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d48f2e13",
"metadata": {},
"source": [
"Importe"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e34b838d",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np \n",
"import pandas as pd \n",
"import joblib \n",
"import seaborn as sns \n",
"import matplotlib.pyplot as plt \n",
"\n",
"from sklearn.metrics import ( \n",
" confusion_matrix, \n",
" roc_curve, auc, \n",
" precision_recall_curve, \n",
" f1_score, \n",
" balanced_accuracy_score \n",
")\n",
" \n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"id": "324554b5",
"metadata": {},
"source": [
"Modell und Scaler laden"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4acc3d2f",
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.models.load_model(\"hybrid_fusion_model_V2.keras\") \n",
"scaler_au = joblib.load(\"scaler_au_V2.joblib\") \n",
"scaler_eye = joblib.load(\"scaler_eye_V2.joblib\")\n",
"\n",
"print(\"Modell & Scaler erfolgreich geladen.\")"
]
},
{
"cell_type": "markdown",
"id": "4271cbee",
"metadata": {},
"source": [
"Features laden"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8342ea10",
"metadata": {},
"outputs": [],
"source": [
"au_columns = [...] \n",
"eye_columns = [...]"
]
},
{
"cell_type": "markdown",
"id": "4a58b20c",
"metadata": {},
"source": [
"Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b683be47",
"metadata": {},
"outputs": [],
"source": [
"def preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye):\n",
" # AUs\n",
" X_au = df[au_columns].values\n",
" X_au = scaler_au.transform(X_au).reshape(len(df), len(au_columns), 1)\n",
"\n",
" # Eye\n",
" X_eye = df[eye_columns].values\n",
" X_eye = scaler_eye.transform(X_eye)\n",
"\n",
" return X_au, X_eye"
]
},
{
"cell_type": "markdown",
"id": "9dc99a3d",
"metadata": {},
"source": [
"Predict-Funktion"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00295aa6",
"metadata": {},
"outputs": [],
"source": [
"def predict_workload(df, model, au_columns, eye_columns, scaler_au, scaler_eye):\n",
" X_au, X_eye = preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye)\n",
"\n",
" probs = model.predict([X_au, X_eye]).flatten()\n",
" preds = (probs > 0.5).astype(int)\n",
" \n",
" return preds, probs"
]
},
{
"cell_type": "markdown",
"id": "5753516b",
"metadata": {},
"source": [
"Testdaten laden"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8875b0ee",
"metadata": {},
"outputs": [],
"source": [
"test_data = pd.read_csv(\"test_data.csv\") # oder direkt aus Notebook 1 exportieren \n",
"\n",
"X_au_test = test_data[au_columns].values[..., np.newaxis] \n",
"X_eye_test = test_data[eye_columns].values \n",
"y_test = test_data[\"label\"].values \n",
"groups_test = test_data[\"subjectID\"].values \n",
"\n",
"X_au_test_scaled = scaler_au.transform(X_au_test.reshape(len(X_au_test), -1)).reshape(X_au_test.shape) \n",
"X_eye_test_scaled = scaler_eye.transform(X_eye_test)"
]
},
{
"cell_type": "markdown",
"id": "332a3a07",
"metadata": {},
"source": [
"Vorhersagen"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5f58ece",
"metadata": {},
"outputs": [],
"source": [
"y_prob = model.predict([X_au_test_scaled, X_eye_test_scaled]).flatten() \n",
"y_pred = (y_prob > 0.5).astype(int)"
]
},
{
"cell_type": "markdown",
"id": "3bc5c66c",
"metadata": {},
"source": [
"Konfusionsmatrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40648dd7",
"metadata": {},
"outputs": [],
"source": [
"cm = confusion_matrix(y_test, y_pred) \n",
"plt.figure(figsize=(6,5)) \n",
"sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", \n",
" xticklabels=[\"Pred 0\", \"Pred 1\"], \n",
" yticklabels=[\"True 0\", \"True 1\"]) \n",
"plt.title(\"Konfusionsmatrix - Testdaten\") \n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "e79ad8a6",
"metadata": {},
"source": [
"ROC"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd93f15c",
"metadata": {},
"outputs": [],
"source": [
"fpr, tpr, _ = roc_curve(y_test, y_prob) \n",
"roc_auc = auc(fpr, tpr) \n",
"\n",
"plt.figure(figsize=(7,6)) \n",
"plt.plot(fpr, tpr, label=f\"AUC = {roc_auc:.3f}\") \n",
"plt.plot([0,1], [0,1], \"k--\") \n",
"plt.xlabel(\"False Positive Rate\") \n",
"plt.ylabel(\"True Positive Rate\") \n",
"plt.title(\"ROCKurve Testdaten\") \n",
"plt.legend() \n",
"plt.grid(True) \n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2eaaf2a0",
"metadata": {},
"source": [
"Precision-Recall"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "601e5dc9",
"metadata": {},
"outputs": [],
"source": [
"precision, recall, _ = precision_recall_curve(y_test, y_prob) \n",
"plt.figure(figsize=(7,6)) \n",
"plt.plot(recall, precision) \n",
"plt.xlabel(\"Recall\") \n",
"plt.ylabel(\"Precision\") \n",
"plt.title(\"PrecisionRecallKurve Testdaten\")\n",
"plt.grid(True) \n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "270af771",
"metadata": {},
"source": [
"Scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2e7da5b",
"metadata": {},
"outputs": [],
"source": [
"print(\"F1Score:\", f1_score(y_test, y_pred)) \n",
"print(\"Balanced Accuracy:\", balanced_accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"id": "c6e22e1a",
"metadata": {},
"source": [
"Subject-Performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "731aaf73",
"metadata": {},
"outputs": [],
"source": [
"df_eval = pd.DataFrame({ \n",
" \"subject\": groups_test, \n",
" \"y_true\": y_test, \n",
" \"y_pred\": y_pred \n",
"}) \n",
"\n",
"subject_perf = df_eval.groupby(\"subject\").apply( \n",
" lambda x: balanced_accuracy_score(x[\"y_true\"], x[\"y_pred\"]) \n",
") \n",
"\n",
"print(\"\\n=== Balanced Accuracy pro Proband ===\") \n",
"print(subject_perf.sort_values())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}