CNNs
This commit is contained in:
parent
42965a4733
commit
15190ac52e
1625
model_training/CNN/CNN_crossVal.ipynb
Normal file
1625
model_training/CNN/CNN_crossVal.ipynb
Normal file
File diff suppressed because one or more lines are too long
1750
model_training/CNN/CNN_crossVal_HybridFusion.ipynb
Normal file
1750
model_training/CNN/CNN_crossVal_HybridFusion.ipynb
Normal file
File diff suppressed because one or more lines are too long
1619
model_training/CNN/CNN_crossVal_faceAUs.ipynb
Normal file
1619
model_training/CNN/CNN_crossVal_faceAUs.ipynb
Normal file
File diff suppressed because one or more lines are too long
1668
model_training/CNN/CNN_crossVal_faceAUs_eyeFeatures.ipynb
Normal file
1668
model_training/CNN/CNN_crossVal_faceAUs_eyeFeatures.ipynb
Normal file
File diff suppressed because one or more lines are too long
663
model_training/CNN/CNN_simple.ipynb
Normal file
663
model_training/CNN/CNN_simple.ipynb
Normal file
File diff suppressed because one or more lines are too long
308
model_training/CNN/deployment_pipeline.ipynb
Normal file
308
model_training/CNN/deployment_pipeline.ipynb
Normal file
@ -0,0 +1,308 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d48f2e13",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Importe"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e34b838d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np \n",
|
||||
"import pandas as pd \n",
|
||||
"import joblib \n",
|
||||
"import seaborn as sns \n",
|
||||
"import matplotlib.pyplot as plt \n",
|
||||
"\n",
|
||||
"from sklearn.metrics import ( \n",
|
||||
" confusion_matrix, \n",
|
||||
" roc_curve, auc, \n",
|
||||
" precision_recall_curve, \n",
|
||||
" f1_score, \n",
|
||||
" balanced_accuracy_score \n",
|
||||
")\n",
|
||||
" \n",
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "324554b5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Modell und Scaler laden"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4acc3d2f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = tf.keras.models.load_model(\"hybrid_fusion_model_V2.keras\") \n",
|
||||
"scaler_au = joblib.load(\"scaler_au_V2.joblib\") \n",
|
||||
"scaler_eye = joblib.load(\"scaler_eye_V2.joblib\")\n",
|
||||
"\n",
|
||||
"print(\"Modell & Scaler erfolgreich geladen.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4271cbee",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Features laden"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8342ea10",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"au_columns = [...] \n",
|
||||
"eye_columns = [...]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4a58b20c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Preprocessing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b683be47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye):\n",
|
||||
" # AUs\n",
|
||||
" X_au = df[au_columns].values\n",
|
||||
" X_au = scaler_au.transform(X_au).reshape(len(df), len(au_columns), 1)\n",
|
||||
"\n",
|
||||
" # Eye\n",
|
||||
" X_eye = df[eye_columns].values\n",
|
||||
" X_eye = scaler_eye.transform(X_eye)\n",
|
||||
"\n",
|
||||
" return X_au, X_eye"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9dc99a3d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Predict-Funktion"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "00295aa6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict_workload(df, model, au_columns, eye_columns, scaler_au, scaler_eye):\n",
|
||||
" X_au, X_eye = preprocess_sample(df, au_columns, eye_columns, scaler_au, scaler_eye)\n",
|
||||
"\n",
|
||||
" probs = model.predict([X_au, X_eye]).flatten()\n",
|
||||
" preds = (probs > 0.5).astype(int)\n",
|
||||
" \n",
|
||||
" return preds, probs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5753516b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Testdaten laden"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8875b0ee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_data = pd.read_csv(\"test_data.csv\") # oder direkt aus Notebook 1 exportieren \n",
|
||||
"\n",
|
||||
"X_au_test = test_data[au_columns].values[..., np.newaxis] \n",
|
||||
"X_eye_test = test_data[eye_columns].values \n",
|
||||
"y_test = test_data[\"label\"].values \n",
|
||||
"groups_test = test_data[\"subjectID\"].values \n",
|
||||
"\n",
|
||||
"X_au_test_scaled = scaler_au.transform(X_au_test.reshape(len(X_au_test), -1)).reshape(X_au_test.shape) \n",
|
||||
"X_eye_test_scaled = scaler_eye.transform(X_eye_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "332a3a07",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Vorhersagen"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b5f58ece",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_prob = model.predict([X_au_test_scaled, X_eye_test_scaled]).flatten() \n",
|
||||
"y_pred = (y_prob > 0.5).astype(int)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3bc5c66c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Konfusionsmatrix"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "40648dd7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cm = confusion_matrix(y_test, y_pred) \n",
|
||||
"plt.figure(figsize=(6,5)) \n",
|
||||
"sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", \n",
|
||||
" xticklabels=[\"Pred 0\", \"Pred 1\"], \n",
|
||||
" yticklabels=[\"True 0\", \"True 1\"]) \n",
|
||||
"plt.title(\"Konfusionsmatrix - Testdaten\") \n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e79ad8a6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"ROC"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dd93f15c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fpr, tpr, _ = roc_curve(y_test, y_prob) \n",
|
||||
"roc_auc = auc(fpr, tpr) \n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(7,6)) \n",
|
||||
"plt.plot(fpr, tpr, label=f\"AUC = {roc_auc:.3f}\") \n",
|
||||
"plt.plot([0,1], [0,1], \"k--\") \n",
|
||||
"plt.xlabel(\"False Positive Rate\") \n",
|
||||
"plt.ylabel(\"True Positive Rate\") \n",
|
||||
"plt.title(\"ROC‑Kurve – Testdaten\") \n",
|
||||
"plt.legend() \n",
|
||||
"plt.grid(True) \n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2eaaf2a0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Precision-Recall"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "601e5dc9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"precision, recall, _ = precision_recall_curve(y_test, y_prob) \n",
|
||||
"plt.figure(figsize=(7,6)) \n",
|
||||
"plt.plot(recall, precision) \n",
|
||||
"plt.xlabel(\"Recall\") \n",
|
||||
"plt.ylabel(\"Precision\") \n",
|
||||
"plt.title(\"Precision‑Recall‑Kurve – Testdaten\")\n",
|
||||
"plt.grid(True) \n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "270af771",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Scores"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e2e7da5b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"F1‑Score:\", f1_score(y_test, y_pred)) \n",
|
||||
"print(\"Balanced Accuracy:\", balanced_accuracy_score(y_test, y_pred))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c6e22e1a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Subject-Performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "731aaf73",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df_eval = pd.DataFrame({ \n",
|
||||
" \"subject\": groups_test, \n",
|
||||
" \"y_true\": y_test, \n",
|
||||
" \"y_pred\": y_pred \n",
|
||||
"}) \n",
|
||||
"\n",
|
||||
"subject_perf = df_eval.groupby(\"subject\").apply( \n",
|
||||
" lambda x: balanced_accuracy_score(x[\"y_true\"], x[\"y_pred\"]) \n",
|
||||
") \n",
|
||||
"\n",
|
||||
"print(\"\\n=== Balanced Accuracy pro Proband ===\") \n",
|
||||
"print(subject_perf.sort_values())"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user