Fahrsimulator_MSY2526_AI/model_training/CNN/deployment_pipeline.ipynb
2026-02-24 11:42:58 +01:00

309 lines
6.9 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"id": "d48f2e13",
"metadata": {},
"source": [
"Importe"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e34b838d",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np \n",
"import pandas as pd \n",
"import joblib \n",
"import seaborn as sns \n",
"import matplotlib.pyplot as plt \n",
"\n",
"from sklearn.metrics import ( \n",
" confusion_matrix, \n",
" roc_curve, auc, \n",
" precision_recall_curve, \n",
" f1_score, \n",
" balanced_accuracy_score \n",
")\n",
" \n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"id": "324554b5",
"metadata": {},
"source": [
"Modell und Scaler laden"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4acc3d2f",
"metadata": {},
"outputs": [],
"source": [
"# Load the trained hybrid fusion model and the two feature scalers from disk.\n",
"# NOTE(review): paths are relative to the notebook's working directory — confirm\n",
"# the V2 artifacts live next to this notebook before a full run.\n",
"model = tf.keras.models.load_model(\"hybrid_fusion_model_V2.keras\") \n",
"scaler_au = joblib.load(\"scaler_au_V2.joblib\") \n",
"scaler_eye = joblib.load(\"scaler_eye_V2.joblib\")\n",
"\n",
"print(\"Modell & Scaler erfolgreich geladen.\")"
]
},
{
"cell_type": "markdown",
"id": "4271cbee",
"metadata": {},
"source": [
"Features laden"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8342ea10",
"metadata": {},
"outputs": [],
"source": [
"# TODO: fill in the real feature column names before running.\n",
"# As committed, `[...]` is the Ellipsis literal — `test_data[au_columns]` below\n",
"# will raise a KeyError until these placeholders are replaced with the actual\n",
"# AU (action unit) and eye-tracking column lists used at training time.\n",
"au_columns = [...] \n",
"eye_columns = [...]"
]
},
{
"cell_type": "markdown",
"id": "4a58b20c",
"metadata": {},
"source": [
"Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b683be47",
"metadata": {},
"outputs": [],
"source": [
"def preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye):\n",
" \"\"\"Scale AU and eye-tracking features for the fusion model.\n",
"\n",
" Returns a tuple (X_au, X_eye): X_au is scaled and reshaped to\n",
" (n_samples, n_au_features, 1) — the trailing channel axis the CNN\n",
" branch expects — while X_eye stays 2-D after scaling.\n",
" \"\"\"\n",
" # AUs\n",
" X_au = df[au_columns].values\n",
" X_au = scaler_au.transform(X_au).reshape(len(df), len(au_columns), 1)\n",
"\n",
" # Eye\n",
" X_eye = df[eye_columns].values\n",
" X_eye = scaler_eye.transform(X_eye)\n",
"\n",
" return X_au, X_eye"
]
},
{
"cell_type": "markdown",
"id": "9dc99a3d",
"metadata": {},
"source": [
"Predict-Funktion"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00295aa6",
"metadata": {},
"outputs": [],
"source": [
"def predict_workload(df, model, au_columns, eye_columns, scaler_au, scaler_eye, threshold=0.5):\n",
" \"\"\"Run the fusion model on raw feature rows and binarize the output.\n",
"\n",
" Parameters mirror preprocess_sample; `threshold` (default 0.5, the\n",
" previously hard-coded cut-off) controls the probability -> class mapping\n",
" so callers can tune operating points without editing this function.\n",
"\n",
" Returns (preds, probs): int 0/1 class labels and the raw sigmoid scores.\n",
" \"\"\"\n",
" X_au, X_eye = preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye)\n",
"\n",
" probs = model.predict([X_au, X_eye]).flatten()\n",
" preds = (probs > threshold).astype(int)\n",
"\n",
" return preds, probs"
]
},
{
"cell_type": "markdown",
"id": "5753516b",
"metadata": {},
"source": [
"Testdaten laden"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8875b0ee",
"metadata": {},
"outputs": [],
"source": [
"# Load the held-out test set and rebuild the two model inputs.\n",
"test_data = pd.read_csv(\"test_data.csv\") # or export directly from Notebook 1 \n",
"\n",
"# AU features get a trailing channel axis for the CNN branch: (n, n_au, 1).\n",
"X_au_test = test_data[au_columns].values[..., np.newaxis] \n",
"X_eye_test = test_data[eye_columns].values \n",
"y_test = test_data[\"label\"].values \n",
"groups_test = test_data[\"subjectID\"].values \n",
"\n",
"# The scaler was fit on 2-D data, so flatten the channel axis for transform()\n",
"# and restore the original 3-D shape afterwards.\n",
"X_au_test_scaled = scaler_au.transform(X_au_test.reshape(len(X_au_test), -1)).reshape(X_au_test.shape) \n",
"X_eye_test_scaled = scaler_eye.transform(X_eye_test)"
]
},
{
"cell_type": "markdown",
"id": "332a3a07",
"metadata": {},
"source": [
"Vorhersagen"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5f58ece",
"metadata": {},
"outputs": [],
"source": [
"# Sigmoid scores from the fusion model; a fixed 0.5 cut-off yields class labels.\n",
"y_prob = model.predict([X_au_test_scaled, X_eye_test_scaled]).flatten() \n",
"y_pred = (y_prob > 0.5).astype(int)"
]
},
{
"cell_type": "markdown",
"id": "3bc5c66c",
"metadata": {},
"source": [
"Konfusionsmatrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40648dd7",
"metadata": {},
"outputs": [],
"source": [
"# Confusion matrix on the test set; axis labels added so the figure\n",
"# stands alone when skimmed.\n",
"cm = confusion_matrix(y_test, y_pred)\n",
"plt.figure(figsize=(6,5))\n",
"sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n",
"            xticklabels=[\"Pred 0\", \"Pred 1\"],\n",
"            yticklabels=[\"True 0\", \"True 1\"])\n",
"plt.xlabel(\"Vorhergesagte Klasse\")\n",
"plt.ylabel(\"Wahre Klasse\")\n",
"plt.title(\"Konfusionsmatrix - Testdaten\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "e79ad8a6",
"metadata": {},
"source": [
"ROC"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd93f15c",
"metadata": {},
"outputs": [],
"source": [
"# ROC curve and AUC on the test set.\n",
"fpr, tpr, _ = roc_curve(y_test, y_prob)\n",
"roc_auc = auc(fpr, tpr)\n",
"\n",
"plt.figure(figsize=(7,6))\n",
"plt.plot(fpr, tpr, label=f\"AUC = {roc_auc:.3f}\")\n",
"plt.plot([0,1], [0,1], \"k--\")  # chance diagonal\n",
"plt.xlabel(\"False Positive Rate\")\n",
"plt.ylabel(\"True Positive Rate\")\n",
"plt.title(\"ROC-Kurve Testdaten\")\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2eaaf2a0",
"metadata": {},
"source": [
"Precision-Recall"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "601e5dc9",
"metadata": {},
"outputs": [],
"source": [
"# Precision-recall curve on the test set (threshold-independent view,\n",
"# more informative than ROC under class imbalance).\n",
"precision, recall, _ = precision_recall_curve(y_test, y_prob)\n",
"plt.figure(figsize=(7,6))\n",
"plt.plot(recall, precision)\n",
"plt.xlabel(\"Recall\")\n",
"plt.ylabel(\"Precision\")\n",
"plt.title(\"Precision-Recall-Kurve Testdaten\")\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "270af771",
"metadata": {},
"source": [
"Scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2e7da5b",
"metadata": {},
"outputs": [],
"source": [
"# Threshold-dependent summary metrics on the test set.\n",
"print(\"F1-Score:\", f1_score(y_test, y_pred))\n",
"print(\"Balanced Accuracy:\", balanced_accuracy_score(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"id": "c6e22e1a",
"metadata": {},
"source": [
"Subject-Performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "731aaf73",
"metadata": {},
"outputs": [],
"source": [
"# Per-subject balanced accuracy: reveals whether errors cluster on a few\n",
"# participants instead of spreading evenly.\n",
"df_eval = pd.DataFrame({\n",
"    \"subject\": groups_test,\n",
"    \"y_true\": y_test,\n",
"    \"y_pred\": y_pred\n",
"})\n",
"\n",
"# Select the value columns explicitly so the grouping column is not passed\n",
"# to the lambda (avoids the pandas DeprecationWarning about operating on\n",
"# grouping columns inside GroupBy.apply).\n",
"subject_perf = df_eval.groupby(\"subject\")[[\"y_true\", \"y_pred\"]].apply(\n",
"    lambda g: balanced_accuracy_score(g[\"y_true\"], g[\"y_pred\"])\n",
")\n",
"\n",
"print(\"\\n=== Balanced Accuracy pro Proband ===\")\n",
"print(subject_perf.sort_values())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}