{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0f3c9a71",
   "metadata": {},
   "source": [
    "# Evaluation – Hybrid-Fusion-Modell V2\n",
    "\n",
    "Lädt das trainierte Hybrid-Fusion-Modell samt AU- und Eye-Scalern und bewertet es auf den Testdaten: Konfusionsmatrix, ROC- und Precision-Recall-Kurve, F1-Score sowie Balanced Accuracy gesamt und pro Proband."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d48f2e13",
   "metadata": {},
   "source": [
    "## Importe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e34b838d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import tensorflow as tf\n",
    "from sklearn.metrics import (\n",
    "    auc,\n",
    "    average_precision_score,\n",
    "    balanced_accuracy_score,\n",
    "    confusion_matrix,\n",
    "    f1_score,\n",
    "    precision_recall_curve,\n",
    "    roc_curve,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b7e2d90",
   "metadata": {},
   "source": [
    "## Konfiguration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c41f6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Artifact paths, relative to the notebook's working directory.\n",
    "MODEL_PATH = Path(\"hybrid_fusion_model_V2.keras\")\n",
    "SCALER_AU_PATH = Path(\"scaler_au_V2.joblib\")\n",
    "SCALER_EYE_PATH = Path(\"scaler_eye_V2.joblib\")\n",
    "TEST_DATA_PATH = Path(\"test_data.csv\")  # or export it directly from notebook 1\n",
    "\n",
    "# Column names in the test CSV.\n",
    "LABEL_COL = \"label\"\n",
    "SUBJECT_COL = \"subjectID\"\n",
    "\n",
    "# Decision threshold on the model output; assumes a single sigmoid probability\n",
    "# output per sample (TODO confirm against the model's final layer).\n",
    "THRESHOLD = 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "324554b5",
   "metadata": {},
   "source": [
    "## Modell und Scaler laden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4acc3d2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# joblib.load unpickles the file: only load scaler artifacts from a trusted source.\n",
    "model = tf.keras.models.load_model(str(MODEL_PATH))\n",
    "scaler_au = joblib.load(SCALER_AU_PATH)\n",
    "scaler_eye = joblib.load(SCALER_EYE_PATH)\n",
    "\n",
    "print(\"Modell & Scaler erfolgreich geladen.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4271cbee",
   "metadata": {},
   "source": [
    "## Features laden\n",
    "\n",
    "**TODO:** Die Platzhalter `[...]` durch die exakten Spaltenlisten ersetzen (gleiche Reihenfolge wie beim Fitten der Scaler und beim Training)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8342ea10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: fill in the exact feature lists, in the order used to fit the scalers and train\n",
    "# the model. `[...]` is a list containing Ellipsis, not column names, so the column check\n",
    "# in \"Testdaten laden\" fails until these are set.\n",
    "au_columns = [...]\n",
    "eye_columns = [...]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a58b20c",
   "metadata": {},
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b683be47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye):\n",
    "    \"\"\"Scale AU and eye features and shape them as the model's two inputs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Rows to score; must contain every column in `au_columns` and `eye_columns`.\n",
    "    au_columns, eye_columns : list of str\n",
    "        Feature columns, in the order the respective scaler was fitted on.\n",
    "    scaler_au, scaler_eye\n",
    "        Fitted scalers exposing `transform`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    X_au : np.ndarray, shape (len(df), len(au_columns), 1)\n",
    "        Scaled AU features with a trailing channel axis.\n",
    "    X_eye : np.ndarray\n",
    "        Scaled eye features, one row per sample.\n",
    "    \"\"\"\n",
    "    X_au = scaler_au.transform(df[au_columns].to_numpy())\n",
    "    # Add a trailing channel axis: (samples, features) -> (samples, features, 1).\n",
    "    X_au = X_au.reshape(len(df), len(au_columns), 1)\n",
    "\n",
    "    X_eye = scaler_eye.transform(df[eye_columns].to_numpy())\n",
    "\n",
    "    return X_au, X_eye"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dc99a3d",
   "metadata": {},
   "source": [
    "## Predict-Funktion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00295aa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_workload(df, model, au_columns, eye_columns, scaler_au, scaler_eye, threshold=0.5):\n",
    "    \"\"\"Predict binary workload labels for every row of `df`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df, au_columns, eye_columns, scaler_au, scaler_eye\n",
    "        Passed through to `preprocess_sample`.\n",
    "    model : tf.keras.Model\n",
    "        Two-input model, called as `model.predict([X_au, X_eye])`.\n",
    "    threshold : float, default 0.5\n",
    "        Samples whose output is strictly greater than this are labelled 1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    preds : np.ndarray of int\n",
    "        0/1 predictions.\n",
    "    probs : np.ndarray of float\n",
    "        Flattened model output; assumed to be one probability per sample (TODO confirm).\n",
    "    \"\"\"\n",
    "    X_au, X_eye = preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye)\n",
    "\n",
    "    probs = model.predict([X_au, X_eye]).flatten()\n",
    "    preds = (probs > threshold).astype(int)\n",
    "\n",
    "    return preds, probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5753516b",
   "metadata": {},
   "source": [
    "## Testdaten laden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8875b0ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = pd.read_csv(TEST_DATA_PATH)\n",
    "\n",
    "# Fail fast with a readable message instead of a KeyError deep inside preprocessing\n",
    "# (e.g. while au_columns / eye_columns are still placeholders).\n",
    "required_columns = [*au_columns, *eye_columns, LABEL_COL, SUBJECT_COL]\n",
    "missing_columns = [col for col in required_columns if col not in test_data.columns]\n",
    "assert not missing_columns, f\"columns missing from test_data: {missing_columns}\"\n",
    "\n",
    "# The scalers must have been fitted on exactly as many features as are listed above.\n",
    "for name, columns, scaler in [(\"AU\", au_columns, scaler_au), (\"eye\", eye_columns, scaler_eye)]:\n",
    "    n_fitted = getattr(scaler, \"n_features_in_\", len(columns))\n",
    "    assert n_fitted == len(columns), (\n",
    "        f\"{name} scaler was fitted on {n_fitted} features, but {len(columns)} columns are listed\"\n",
    "    )\n",
    "\n",
    "y_test = test_data[LABEL_COL].to_numpy()\n",
    "groups_test = test_data[SUBJECT_COL].to_numpy()\n",
    "\n",
    "# Thresholded predictions and the 2x2 confusion matrix below assume 0/1 labels.\n",
    "assert set(np.unique(y_test)) <= {0, 1}, f\"unexpected labels: {np.unique(y_test)}\"\n",
    "\n",
    "test_data.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "332a3a07",
   "metadata": {},
   "source": [
    "## Vorhersagen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5f58ece",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same preprocessing and threshold path as for new samples, so evaluation and\n",
    "# inference cannot drift apart.\n",
    "y_pred, y_prob = predict_workload(\n",
    "    test_data, model, au_columns, eye_columns, scaler_au, scaler_eye, threshold=THRESHOLD\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bc5c66c",
   "metadata": {},
   "source": [
    "## Konfusionsmatrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40648dd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels=[0, 1] keeps the matrix 2x2 (matching the tick labels) even if one class\n",
    "# is absent from y_test or y_pred.\n",
    "cm = confusion_matrix(y_test, y_pred, labels=[0, 1])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 5))\n",
    "sns.heatmap(\n",
    "    cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n",
    "    xticklabels=[\"Pred 0\", \"Pred 1\"],\n",
    "    yticklabels=[\"True 0\", \"True 1\"],\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_title(\"Konfusionsmatrix - Testdaten\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e79ad8a6",
   "metadata": {},
   "source": [
    "## ROC-Kurve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd93f15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fpr, tpr, _ = roc_curve(y_test, y_prob)\n",
    "roc_auc = auc(fpr, tpr)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 6))\n",
    "ax.plot(fpr, tpr, label=f\"AUC = {roc_auc:.3f}\")\n",
    "ax.plot([0, 1], [0, 1], \"k--\", label=\"Zufall\")\n",
    "ax.set_xlabel(\"False Positive Rate\")\n",
    "ax.set_ylabel(\"True Positive Rate\")\n",
    "ax.set_title(\"ROC‑Kurve – Testdaten\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eaaf2a0",
   "metadata": {},
   "source": [
    "## Precision-Recall-Kurve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "601e5dc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "precision, recall, _ = precision_recall_curve(y_test, y_prob)\n",
    "average_precision = average_precision_score(y_test, y_prob)\n",
    "# A no-skill classifier's precision equals the share of positive samples.\n",
    "prevalence = np.mean(y_test == 1)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 6))\n",
    "ax.plot(recall, precision, label=f\"AP = {average_precision:.3f}\")\n",
    "ax.axhline(prevalence, color=\"k\", linestyle=\"--\", label=f\"Baseline = {prevalence:.3f}\")\n",
    "ax.set_xlabel(\"Recall\")\n",
    "ax.set_ylabel(\"Precision\")\n",
    "ax.set_title(\"Precision‑Recall‑Kurve – Testdaten\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "270af771",
   "metadata": {},
   "source": [
    "## Scores (gesamt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e7da5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"F1‑Score:\", f1_score(y_test, y_pred))\n",
    "print(\"Balanced Accuracy:\", balanced_accuracy_score(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e22e1a",
   "metadata": {},
   "source": [
    "## Balanced Accuracy pro Proband\n",
    "\n",
    "Enthält ein Proband nur eine Klasse, entspricht sein Wert dem Recall dieser Klasse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "731aaf73",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_eval = pd.DataFrame({\n",
    "    \"subject\": groups_test,\n",
    "    \"y_true\": y_test,\n",
    "    \"y_pred\": y_pred,\n",
    "})\n",
    "\n",
    "# Select the score columns explicitly so apply() does not also receive the\n",
    "# grouping column (deprecated behaviour since pandas 2.2).\n",
    "subject_perf = (\n",
    "    df_eval.groupby(\"subject\")[[\"y_true\", \"y_pred\"]]\n",
    "    .apply(lambda g: balanced_accuracy_score(g[\"y_true\"], g[\"y_pred\"]))\n",
    "    .rename(\"balanced_accuracy\")\n",
    ")\n",
    "\n",
    "subject_perf.sort_values().to_frame()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}