Compare commits

...

2 Commits

View File

@ -7,13 +7,331 @@
"metadata": {},
"outputs": [],
"source": [
"import xgboost as xgb"
"import pandas as pd\n",
"from pathlib import Path\n",
"from sklearn.preprocessing import StandardScaler, MinMaxScaler"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13ad96f5",
"metadata": {},
"outputs": [],
"source": [
"data_path = Path(r\"~/Fahrsimulator_MSY2526_AI/model_training/xgboost/output_windowed.parquet\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95e1a351",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_parquet(path=data_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68afd83e",
"metadata": {},
"outputs": [],
"source": [
"subjects = df['subjectID'].unique()\n",
"print(subjects)\n",
"print(len(subjects))\n",
"print(len(subjects)*0.66)\n",
"print(len(subjects)*0.33)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52dfd885",
"metadata": {},
"outputs": [],
"source": [
"low_all = df[\n",
" ((df[\"PHASE\"] == \"baseline\") |\n",
" ((df[\"STUDY\"] == \"n-back\") & (df[\"PHASE\"] != \"baseline\") & (df[\"LEVEL\"].isin([1, 4]))))\n",
"]\n",
"print(f\"low all: {low_all.shape}\")\n",
"\n",
"high_nback = df[\n",
" (df[\"STUDY\"]==\"n-back\") &\n",
" (df[\"LEVEL\"].isin([2, 3, 5, 6])) &\n",
" (df[\"PHASE\"].isin([\"train\", \"test\"]))\n",
"]\n",
"print(f\"high n-back: {high_nback.shape}\")\n",
"\n",
"high_kdrive = df[\n",
" (df[\"STUDY\"] == \"k-drive\") & (df[\"PHASE\"] != \"baseline\")\n",
"]\n",
"print(f\"high k-drive: {high_kdrive.shape}\")\n",
"\n",
"high_all = pd.concat([high_nback, high_kdrive])\n",
"print(f\"high all: {high_all.shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8fba6edf",
"metadata": {},
"outputs": [],
"source": [
"def fit_normalizer(train_data, au_columns, method='standard', scope='global'):\n",
" if method == 'standard':\n",
" Scaler = StandardScaler\n",
" elif method == 'minmax':\n",
" Scaler = MinMaxScaler\n",
" else:\n",
" raise ValueError(\"method must be 'standard' or 'minmax'\")\n",
" \n",
" scalers = {}\n",
" \n",
" if scope == 'subject':\n",
" for subject in train_data['subjectID'].unique():\n",
" subject_mask = train_data['subjectID'] == subject\n",
" scaler = Scaler()\n",
" scaler.fit(train_data.loc[subject_mask, au_columns])\n",
" scalers[subject] = scaler\n",
"\n",
" elif scope == 'global':\n",
" scaler = Scaler()\n",
" scaler.fit(train_data[au_columns])\n",
" scalers['global'] = scaler\n",
"\n",
" else:\n",
" raise ValueError(\"scope must be 'subject' or 'global'\")\n",
" \n",
" return {'scalers': scalers, 'method': method, 'scope': scope}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24e3a77b",
"metadata": {},
"outputs": [],
"source": [
"%pip install xgboost"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e7fa0fa",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report, confusion_matrix\n",
"import xgboost as xgb\n",
"import joblib\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "325ef71c",
"metadata": {},
"outputs": [],
"source": [
"low = low_all.copy()\n",
"high = high_all.copy()\n",
"\n",
"low[\"label\"] = 0\n",
"high[\"label\"] = 1\n",
"\n",
"data = pd.concat([low, high], ignore_index=True)\n",
"data = data.drop_duplicates()\n",
"\n",
"print(\"Label distribution:\")\n",
"print(data[\"label\"].value_counts())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67d70e84",
"metadata": {},
"outputs": [],
"source": [
"au_columns = [col for col in data.columns if col.lower().startswith(\"au\")]\n",
"print(\"Gefundene AU-Spalten:\", au_columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "960bb8c7",
"metadata": {},
"outputs": [],
"source": [
"subjects = np.random.permutation(data[\"subjectID\"].unique())\n",
"\n",
"n = len(subjects)\n",
"n_train = int(n * 0.66)\n",
"\n",
"train_subjects = subjects[:n_train]\n",
"test_subjects = subjects[n_train:]\n",
"train_subs, val_subs = train_test_split(train_subjects, test_size=0.2, random_state=42)\n",
"\n",
"train_df = data[data.subjectID.isin(train_subs)]\n",
"val_df = data[data.subjectID.isin(val_subs)]\n",
"test_df = data[data.subjectID.isin(test_subjects)]\n",
"\n",
"print(train_df.shape, val_df.shape, test_df.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "802a45c9",
"metadata": {},
"outputs": [],
"source": [
"def apply_normalizer(df_to_transform, normalizer_dict, au_columns):\n",
" scalers = normalizer_dict[\"scalers\"]\n",
" scope = normalizer_dict[\"scope\"]\n",
" df_out = df_to_transform.copy()\n",
"\n",
" if scope == \"global\":\n",
" scaler = scalers[\"global\"]\n",
" df_out[au_columns] = scaler.transform(df_out[au_columns])\n",
"\n",
" elif scope == \"subject\":\n",
" for subj, subdf in df_out.groupby(\"subjectID\"):\n",
" if subj in scalers:\n",
" df_out.loc[subdf.index, au_columns] = scalers[subj].transform(subdf[au_columns])\n",
" elif \"global\" in scalers:\n",
" df_out.loc[subdf.index, au_columns] = scalers[\"global\"].transform(subdf[au_columns])\n",
"\n",
" return df_out"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "289f6b89",
"metadata": {},
"outputs": [],
"source": [
"normalizer = fit_normalizer(train_df, au_columns, method=\"standard\", scope=\"global\")\n",
"\n",
"train_scaled = apply_normalizer(train_df, normalizer, au_columns)\n",
"val_scaled = apply_normalizer(val_df, normalizer, au_columns)\n",
"test_scaled = apply_normalizer(test_df, normalizer, au_columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5df30e8d",
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train = train_scaled[au_columns].values, train_scaled[\"label\"].values\n",
"X_val, y_val = val_scaled[au_columns].values, val_scaled[\"label\"].values\n",
"X_test, y_test = test_scaled[au_columns].values, test_scaled[\"label\"].values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fb7c86a",
"metadata": {},
"outputs": [],
"source": [
"model = xgb.XGBClassifier(\n",
" objective=\"binary:logistic\",\n",
" eval_metric=\"auc\",\n",
" learning_rate=0.05,\n",
" max_depth=6,\n",
" n_estimators=500,\n",
" subsample=0.8,\n",
" colsample_bytree=0.8,\n",
" random_state=42\n",
")\n",
"\n",
"model.fit(\n",
" X_train, y_train,\n",
" eval_set=[(X_val, y_val)],\n",
" #early_stopping_rounds=30,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09a8cd21",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, roc_auc_score, classification_report, ConfusionMatrixDisplay\n",
"\n",
"def evaluate(model, X, y, title=\"Evaluation\"):\n",
" # Vorhersagen\n",
" preds_proba = model.predict_proba(X)[:, 1]\n",
" preds = (preds_proba > 0.5).astype(int)\n",
"\n",
" # Metriken ausgeben\n",
" print(\"Accuracy:\", accuracy_score(y, preds))\n",
" print(\"F1:\", f1_score(y, preds))\n",
" print(\"AUC:\", roc_auc_score(y, preds))\n",
" print(\"Confusion:\\n\", confusion_matrix(y, preds))\n",
" print(classification_report(y, preds))\n",
"\n",
" # Confusion Matrix plotten\n",
" def plot_confusion_matrix(true_labels, predictions, label_names):\n",
" for normalize in [None, 'true']:\n",
" cm = confusion_matrix(true_labels, predictions, normalize=normalize)\n",
" cm_disp = ConfusionMatrixDisplay(cm, display_labels=label_names)\n",
" cm_disp.plot(cmap=\"Blues\")\n",
" #cm = confusion_matrix(y, preds)\n",
" plot_confusion_matrix(y,preds, label_names=['Low','High'])\n",
" # plt.figure(figsize=(5,4))\n",
" # sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\", cbar=False,\n",
" # xticklabels=[\"Predicted low\", \"Predicted high\"],\n",
" # yticklabels=[\"Actual low\", \"Actual high\"])\n",
" # plt.title(f\"Confusion Matrix - {title}\")\n",
" # plt.ylabel(\"True label\")\n",
" # plt.xlabel(\"Predicted label\")\n",
" # plt.show()\n",
"\n",
"# Aufrufen für Train/Val/Test\n",
"print(\"TRAIN:\")\n",
"evaluate(model, X_train, y_train, title=\"Train\")\n",
"\n",
"print(\"VAL:\")\n",
"evaluate(model, X_val, y_val, title=\"Validation\")\n",
"\n",
"print(\"TEST:\")\n",
"evaluate(model, X_test, y_test, title=\"Test\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c43b0c80",
"metadata": {},
"outputs": [],
"source": [
"joblib.dump(model, \"xgb_model.joblib\")\n",
"joblib.dump(normalizer, \"normalizer.joblib\")\n",
"print(\"Model gespeichert.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "project_fahrsimulator",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -27,7 +345,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.19"
"version": "3.12.10"
}
},
"nbformat": 4,