{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cc08936c",
   "metadata": {},
   "source": [
    "## Insights into the dataset with histograms and scatter plots\n",
    "\n",
    "Compares the distributions of the facial action unit (`FACE_AU*`) and eye-tracking features between two groups of samples:\n",
    "\n",
    "- **low** (`low_all`): the `baseline` phase of every study, plus n-back levels 1 and 4\n",
    "- **high** (`high_all`): n-back levels 2, 3, 5 and 6 in the `train`/`test` phases, plus every non-baseline k-drive phase"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1014c5e0",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e42f3011",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b9e6f20",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a834496",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute, machine-specific location of the dataset; adjust when running elsewhere.\n",
    "DATA_PATH = Path(\"/home/jovyan/data-paulusjafahrsimulator-gpu/new_datasets/50s_25Hz_dataset.parquet\")\n",
    "df = pd.read_parquet(DATA_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b7e2c91",
   "metadata": {},
   "source": [
    "### Split the samples into low and high groups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa4759fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"high\" (n-back part): levels 2, 3, 5 and 6 in the train/test phases.\n",
    "high_nback = df[\n",
    "    (df[\"STUDY\"] == \"n-back\")\n",
    "    & (df[\"LEVEL\"].isin([2, 3, 5, 6]))\n",
    "    & (df[\"PHASE\"].isin([\"train\", \"test\"]))\n",
    "]\n",
    "high_nback.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2aa0596",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"low\": the baseline phase of every study, plus n-back levels 1 and 4.\n",
    "# Rows with PHASE == \"baseline\" already match the first term, so the n-back term\n",
    "# needs no extra PHASE != \"baseline\" condition.\n",
    "# NOTE(review): unlike high_nback, the n-back part is not limited to the\n",
    "# \"train\"/\"test\" phases - confirm this asymmetry is intended.\n",
    "low_all = df[\n",
    "    (df[\"PHASE\"] == \"baseline\")\n",
    "    | ((df[\"STUDY\"] == \"n-back\") & (df[\"LEVEL\"].isin([1, 4])))\n",
    "]\n",
    "print(low_all.shape)\n",
    "\n",
    "# \"high\" (k-drive part): every non-baseline k-drive phase.\n",
    "high_kdrive = df[(df[\"STUDY\"] == \"k-drive\") & (df[\"PHASE\"] != \"baseline\")]\n",
    "print(high_kdrive.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7d446a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The three groups are disjoint by construction, so their sizes add up to the\n",
    "# number of rows in df exactly when every row belongs to one of them.\n",
    "n_assigned = low_all.shape[0] + high_nback.shape[0] + high_kdrive.shape[0]\n",
    "print(f\"rows in df: {df.shape[0]}, rows assigned to a group: {n_assigned}\")\n",
    "assert n_assigned == df.shape[0], f\"{df.shape[0] - n_assigned} rows are not assigned to any group\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "474e144a",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_all = pd.concat([high_nback, high_kdrive])\n",
    "high_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd585c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.dtypes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d2f6a38",
   "metadata": {},
   "source": [
    "### Feature columns and plotting helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd39d9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "face_au_cols = [c for c in df.columns if c.startswith(\"FACE_AU\")]\n",
    "eye_cols = ['Fix_count_short_66_150', 'Fix_count_medium_300_500',\n",
    "            'Fix_count_long_gt_1000', 'Fix_count_100', 'Fix_mean_duration',\n",
    "            'Fix_median_duration', 'Sac_count', 'Sac_mean_amp', 'Sac_mean_dur',\n",
    "            'Sac_median_dur', 'Blink_count', 'Blink_mean_dur', 'Blink_median_dur',\n",
    "            'Pupil_mean', 'Pupil_IPA']\n",
    "\n",
    "# Fail early with a readable message instead of a KeyError in the middle of plotting.\n",
    "missing_cols = [c for c in eye_cols if c not in df.columns]\n",
    "assert not missing_cols, f\"eye-tracking columns missing from the dataset: {missing_cols}\"\n",
    "\n",
    "cols = face_au_cols + eye_cols\n",
    "len(cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c5e1b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_COLS_SUBPLOT = 5  # subplots per figure row\n",
    "\n",
    "\n",
    "def make_feature_grid(n_features, title, n_cols=N_COLS_SUBPLOT):\n",
    "    \"\"\"Create a figure with one subplot per feature, laid out ``n_cols`` per row.\n",
    "\n",
    "    The number of rows is derived from ``n_features`` so the grid never runs out of\n",
    "    axes; surplus axes in the last row are hidden.\n",
    "\n",
    "    Returns:\n",
    "        (fig, axes): the figure and a flat array with exactly ``n_features`` axes.\n",
    "    \"\"\"\n",
    "    n_rows = math.ceil(n_features / n_cols)\n",
    "    # squeeze=False keeps axes 2-D even for a single row, so flatten() always works.\n",
    "    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 2.3 * n_rows), squeeze=False)\n",
    "    fig.suptitle(title, fontsize=20, fontweight='bold', y=0.995)\n",
    "    axes = axes.flatten()\n",
    "    for ax in axes[n_features:]:\n",
    "        ax.set_visible(False)\n",
    "    return fig, axes[:n_features]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1a8c3d6",
   "metadata": {},
   "source": [
    "### Histograms: low vs high\n",
    "\n",
    "Both groups use the same bin edges for each feature, and bar heights are densities because the groups can differ in size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a6d9f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "HIST_BINS = 30\n",
    "# Densities keep the shapes comparable when the groups differ in size; set False for raw counts.\n",
    "HIST_DENSITY = True\n",
    "\n",
    "fig, axes = make_feature_grid(len(cols), 'Feature Distributions: Low vs High')\n",
    "\n",
    "for ax, col in zip(axes, cols):\n",
    "    low_vals = low_all[col].dropna()\n",
    "    high_vals = high_all[col].dropna()\n",
    "    # Shared bin edges: binning each group separately gives different bar widths,\n",
    "    # which makes the overlaid histograms misleading.\n",
    "    bins = np.histogram_bin_edges(pd.concat([low_vals, high_vals]), bins=HIST_BINS)\n",
    "\n",
    "    ax.hist(low_vals, bins=bins, density=HIST_DENSITY, alpha=0.6, color='blue', label='low_all', edgecolor='black')\n",
    "    ax.hist(high_vals, bins=bins, density=HIST_DENSITY, alpha=0.6, color='red', label='high_all', edgecolor='black')\n",
    "\n",
    "    ax.set_title(col, fontsize=10, fontweight='bold')\n",
    "    ax.set_xlabel('Value', fontsize=8)\n",
    "    ax.set_ylabel('Density' if HIST_DENSITY else 'Frequency', fontsize=8)\n",
    "    ax.legend(fontsize=8)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f3b8e20",
   "metadata": {},
   "source": [
    "### Scatter plots: feature value by sample index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd53cdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = make_feature_grid(len(cols), 'Feature Scatter: Low vs High')\n",
    "\n",
    "for ax, col in zip(axes, cols):\n",
    "    # x is the position within each group, so both groups start at 0 and overlap on the x-axis.\n",
    "    ax.scatter(range(len(low_all[col])), low_all[col], alpha=0.6, color='blue', label='low_all', s=10)\n",
    "    ax.scatter(range(len(high_all[col])), high_all[col], alpha=0.6, color='red', label='high_all', s=10)\n",
    "\n",
    "    ax.set_title(col, fontsize=10, fontweight='bold')\n",
    "    ax.set_xlabel('Sample index', fontsize=8)\n",
    "    ax.set_ylabel('Value', fontsize=8)\n",
    "    ax.legend(fontsize=8)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}