309 lines
6.9 KiB
Plaintext
309 lines
6.9 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d48f2e13",
|
||
"metadata": {},
|
||
"source": [
|
||
"Importe"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "e34b838d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np \n",
|
||
"import pandas as pd \n",
|
||
"import joblib \n",
|
||
"import seaborn as sns \n",
|
||
"import matplotlib.pyplot as plt \n",
|
||
"\n",
|
||
"from sklearn.metrics import ( \n",
|
||
" confusion_matrix, \n",
|
||
" roc_curve, auc, \n",
|
||
" precision_recall_curve, \n",
|
||
" f1_score, \n",
|
||
" balanced_accuracy_score \n",
|
||
")\n",
|
||
" \n",
|
||
"import tensorflow as tf"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "324554b5",
|
||
"metadata": {},
|
||
"source": [
|
||
"Modell und Scaler laden"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "4acc3d2f",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Artifact locations, named once so a retrained version can be swapped in here.\n",
"MODEL_PATH = \"hybrid_fusion_model_V2.keras\"\n",
"SCALER_AU_PATH = \"scaler_au_V2.joblib\"\n",
"SCALER_EYE_PATH = \"scaler_eye_V2.joblib\"\n",
"\n",
"model = tf.keras.models.load_model(MODEL_PATH)\n",
"\n",
"# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code.\n",
"# Only load scaler artifacts that come from a trusted source.\n",
"scaler_au = joblib.load(SCALER_AU_PATH)\n",
"scaler_eye = joblib.load(SCALER_EYE_PATH)\n",
"\n",
"print(\"Modell & Scaler erfolgreich geladen.\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "4271cbee",
|
||
"metadata": {},
|
||
"source": [
|
||
"Features laden"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "8342ea10",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# TODO: replace both placeholders with the real feature column names.\n",
"# NOTE(review): as written, each list holds a single Ellipsis object, so the\n",
"# column selections further down (df[au_columns], test_data[eye_columns]) cannot\n",
"# match real columns. Presumably names and order must match what the scalers and\n",
"# the model were fitted on - confirm against the training notebook.\n",
"au_columns = [...] \n",
"eye_columns = [...]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "4a58b20c",
|
||
"metadata": {},
|
||
"source": [
|
||
"Preprocessing"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "b683be47",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye):\n",
"    \"\"\"Scale the AU and eye features of `df` into the two model inputs.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df : pandas.DataFrame\n",
"        Rows to preprocess; must contain every name in `au_columns` and `eye_columns`.\n",
"    au_columns, eye_columns : list\n",
"        Column names selected for the AU input and the eye input.\n",
"    scaler_au, scaler_eye : objects with a `transform` method\n",
"        Applied to the AU / eye feature matrix respectively.\n",
"\n",
"    Returns\n",
"    -------\n",
"    X_au : array of shape (len(df), len(au_columns), 1)\n",
"        Scaled AU features with a trailing axis of length 1.\n",
"    X_eye : result of `scaler_eye.transform` on the eye feature matrix\n",
"\n",
"    Raises\n",
"    ------\n",
"    KeyError\n",
"        If any requested column is missing from `df`.\n",
"    \"\"\"\n",
"    # Fail early with one readable message instead of a partial pandas error.\n",
"    missing = [col for col in [*au_columns, *eye_columns] if col not in df.columns]\n",
"    if missing:\n",
"        raise KeyError(f\"Missing feature columns in input data: {missing}\")\n",
"\n",
"    # AUs: scale as a 2-D matrix, then append the trailing axis.\n",
"    X_au = scaler_au.transform(df[au_columns].to_numpy())\n",
"    X_au = np.asarray(X_au).reshape(len(df), len(au_columns), 1)\n",
"\n",
"    # Eye: stays 2-D.\n",
"    X_eye = scaler_eye.transform(df[eye_columns].to_numpy())\n",
"\n",
"    return X_au, X_eye"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "9dc99a3d",
|
||
"metadata": {},
|
||
"source": [
|
||
"Predict-Funktion"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "00295aa6",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"def predict_workload(df, model, au_columns, eye_columns, scaler_au, scaler_eye, threshold=0.5):\n",
"    \"\"\"Preprocess `df`, run the model and binarise its output.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df : pandas.DataFrame\n",
"        Rows to score; passed unchanged to `preprocess_sample`.\n",
"    model : object with a `predict` method\n",
"        Called with the list `[X_au, X_eye]`.\n",
"    au_columns, eye_columns, scaler_au, scaler_eye\n",
"        Forwarded to `preprocess_sample`.\n",
"    threshold : float, default 0.5\n",
"        A row is predicted as class 1 when its probability is strictly greater.\n",
"\n",
"    Returns\n",
"    -------\n",
"    preds : integer array, 1 where `probs > threshold`, else 0\n",
"    probs : flattened output of `model.predict`\n",
"    \"\"\"\n",
"    X_au, X_eye = preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye)\n",
"\n",
"    probs = model.predict([X_au, X_eye]).flatten()\n",
"    preds = (probs > threshold).astype(int)\n",
"\n",
"    return preds, probs"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "5753516b",
|
||
"metadata": {},
|
||
"source": [
|
||
"Testdaten laden"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "8875b0ee",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"test_data = pd.read_csv(\"test_data.csv\")  # or export it directly from notebook 1\n",
"\n",
"y_test = test_data[\"label\"].to_numpy()\n",
"groups_test = test_data[\"subjectID\"].to_numpy()\n",
"\n",
"# Unscaled feature arrays, kept under their original names.\n",
"X_au_test = test_data[au_columns].to_numpy()[..., np.newaxis]\n",
"X_eye_test = test_data[eye_columns].to_numpy()\n",
"\n",
"# Reuse preprocess_sample instead of repeating the scale-and-reshape steps by hand,\n",
"# so evaluation inputs are built the same way as in predict_workload.\n",
"X_au_test_scaled, X_eye_test_scaled = preprocess_sample(\n",
"    test_data, au_columns, eye_columns, scaler_au, scaler_eye\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "332a3a07",
|
||
"metadata": {},
|
||
"source": [
|
||
"Vorhersagen"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "b5f58ece",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Decision threshold on the predicted probability (named instead of a bare 0.5).\n",
"DECISION_THRESHOLD = 0.5\n",
"\n",
"y_prob = model.predict([X_au_test_scaled, X_eye_test_scaled]).flatten()\n",
"y_pred = (y_prob > DECISION_THRESHOLD).astype(int)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "3bc5c66c",
|
||
"metadata": {},
|
||
"source": [
|
||
"Konfusionsmatrix"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "40648dd7",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Pin the label order so the matrix is always 2x2 and lines up with the tick\n",
"# labels below, even if one class is absent from y_test / y_pred.\n",
"# NOTE(review): assumes the labels in y_test are exactly 0 and 1 - confirm.\n",
"cm = confusion_matrix(y_test, y_pred, labels=[0, 1])\n",
"\n",
"cm_fig, cm_ax = plt.subplots(figsize=(6, 5))\n",
"sns.heatmap(\n",
"    cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n",
"    xticklabels=[\"Pred 0\", \"Pred 1\"],\n",
"    yticklabels=[\"True 0\", \"True 1\"],\n",
"    ax=cm_ax,\n",
")\n",
"cm_ax.set_title(\"Konfusionsmatrix - Testdaten\")\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e79ad8a6",
|
||
"metadata": {},
|
||
"source": [
|
||
"ROC"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "dd93f15c",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# ROC curve and the area under it for the test-set probabilities.\n",
"fpr, tpr, _ = roc_curve(y_test, y_prob)\n",
"roc_auc = auc(fpr, tpr)\n",
"\n",
"roc_fig, roc_ax = plt.subplots(figsize=(7, 6))\n",
"roc_ax.plot(fpr, tpr, label=f\"AUC = {roc_auc:.3f}\")\n",
"roc_ax.plot([0, 1], [0, 1], \"k--\")  # dashed diagonal from (0, 0) to (1, 1)\n",
"roc_ax.set_xlabel(\"False Positive Rate\")\n",
"roc_ax.set_ylabel(\"True Positive Rate\")\n",
"roc_ax.set_title(\"ROC‑Kurve – Testdaten\")\n",
"roc_ax.legend()\n",
"roc_ax.grid(True)\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "2eaaf2a0",
|
||
"metadata": {},
|
||
"source": [
|
||
"Precision-Recall"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "601e5dc9",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Precision-recall curve for the test-set probabilities.\n",
"precision, recall, _ = precision_recall_curve(y_test, y_prob)\n",
"\n",
"pr_fig, pr_ax = plt.subplots(figsize=(7, 6))\n",
"pr_ax.plot(recall, precision)\n",
"pr_ax.set_xlabel(\"Recall\")\n",
"pr_ax.set_ylabel(\"Precision\")\n",
"pr_ax.set_title(\"Precision‑Recall‑Kurve – Testdaten\")\n",
"pr_ax.grid(True)\n",
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "270af771",
|
||
"metadata": {},
|
||
"source": [
|
||
"Scores"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "e2e7da5b",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"# Compute both summary metrics first, then report them.\n",
"f1_test = f1_score(y_test, y_pred)\n",
"balanced_acc_test = balanced_accuracy_score(y_test, y_pred)\n",
"\n",
"print(\"F1‑Score:\", f1_test)\n",
"print(\"Balanced Accuracy:\", balanced_acc_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c6e22e1a",
|
||
"metadata": {},
|
||
"source": [
|
||
"Subject-Performance"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "731aaf73",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
"df_eval = pd.DataFrame({\n",
"    \"subject\": groups_test,\n",
"    \"y_true\": y_test,\n",
"    \"y_pred\": y_pred,\n",
"})\n",
"\n",
"# Select the label columns before apply(): the callable then never receives the\n",
"# grouping column, which newer pandas versions warn about. The result is still\n",
"# one balanced-accuracy value per subject.\n",
"subject_perf = df_eval.groupby(\"subject\")[[\"y_true\", \"y_pred\"]].apply(\n",
"    lambda grp: balanced_accuracy_score(grp[\"y_true\"], grp[\"y_pred\"])\n",
")\n",
"\n",
"print(\"\\n=== Balanced Accuracy pro Proband ===\")\n",
"print(subject_perf.sort_values())"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|