200 lines
5.4 KiB
Plaintext
200 lines
5.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cc08936c",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Insights into the dataset with histogramms and scatter plots"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1014c5e0",
|
|
"metadata": {},
|
|
"source": [
|
|
"Imports"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e42f3011",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from pathlib import Path"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0a834496",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"path = Path(r\".parquet\") # TODO: enter path to dataset\n",
|
|
"df = pd.read_parquet(path=path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "aa4759fa",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"high_nback = df[\n",
|
|
" (df[\"STUDY\"]==\"n-back\") &\n",
|
|
" (df[\"LEVEL\"].isin([2, 3, 5, 6])) &\n",
|
|
" (df[\"PHASE\"].isin([\"train\", \"test\"]))\n",
|
|
"]\n",
|
|
"high_nback.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a2aa0596",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"low_all = df[\n",
|
|
" ((df[\"PHASE\"] == \"baseline\") |\n",
|
|
" ((df[\"STUDY\"] == \"n-back\") & (df[\"PHASE\"] != \"baseline\") & (df[\"LEVEL\"].isin([1,4]))))\n",
|
|
"]\n",
|
|
"print(low_all.shape)\n",
|
|
"high_kdrive = df[\n",
|
|
" (df[\"STUDY\"] == \"k-drive\") & (df[\"PHASE\"] != \"baseline\")\n",
|
|
"]\n",
|
|
"print(high_kdrive.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f7d446a1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print((df.shape[0]==(high_kdrive.shape[0]+high_nback.shape[0]+low_all.shape[0])))\n",
|
|
"print(df.shape[0])\n",
|
|
"print((high_kdrive.shape[0]+high_nback.shape[0]+low_all.shape[0]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "474e144a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"high_all = pd.concat([high_nback, high_kdrive])\n",
|
|
"high_all.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5dd585c2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"df.dtypes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0bd39d9f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"face_au_cols = [c for c in low_all.columns if c.startswith(\"FACE_AU\")]\n",
|
|
"eye_cols = ['Fix_count_short_66_150', 'Fix_count_medium_300_500',\n",
|
|
" 'Fix_count_long_gt_1000', 'Fix_count_100', 'Fix_mean_duration',\n",
|
|
" 'Fix_median_duration', 'Sac_count', 'Sac_mean_amp', 'Sac_mean_dur',\n",
|
|
" 'Sac_median_dur', 'Blink_count', 'Blink_mean_dur', 'Blink_median_dur',\n",
|
|
" 'Pupil_mean', 'Pupil_IPA']\n",
|
|
"\n",
|
|
"cols = face_au_cols+eye_cols\n",
|
|
"\n",
|
|
"# Calculate number of rows and columns for subplots\n",
|
|
"n_cols = len(cols)\n",
|
|
"n_rows = 7\n",
|
|
"n_cols_subplot = 5\n",
|
|
"\n",
|
|
"# Create figure with subplots\n",
|
|
"fig, axes = plt.subplots(n_rows, n_cols_subplot, figsize=(20, 16))\n",
|
|
"axes = axes.flatten()\n",
|
|
"fig.suptitle('Feature Distributions: Low vs High', fontsize=20, fontweight='bold', y=0.995)\n",
|
|
"\n",
|
|
"# Create histogram for each AU column\n",
|
|
"for idx, col in enumerate(cols):\n",
|
|
" ax = axes[idx]\n",
|
|
" \n",
|
|
" # Plot overlapping histograms\n",
|
|
" ax.hist(low_all[col].dropna(), bins=30, alpha=0.6, color='blue', label='low_all', edgecolor='black')\n",
|
|
" ax.hist(high_all[col].dropna(), bins=30, alpha=0.6, color='red', label='high_all', edgecolor='black')\n",
|
|
" \n",
|
|
" # Set title and labels\n",
|
|
" ax.set_title(col, fontsize=10, fontweight='bold')\n",
|
|
" ax.set_xlabel('Value', fontsize=8)\n",
|
|
" ax.set_ylabel('Frequency', fontsize=8)\n",
|
|
" ax.legend(fontsize=8)\n",
|
|
" ax.grid(True, alpha=0.3)\n",
|
|
"\n",
|
|
"# Hide any unused subplots\n",
|
|
"for idx in range(len(cols), len(axes)):\n",
|
|
" axes[idx].set_visible(False)\n",
|
|
"\n",
|
|
"# Adjust layout\n",
|
|
"plt.tight_layout()\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "6cd53cdb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create figure with subplots\n",
|
|
"fig, axes = plt.subplots(n_rows, n_cols_subplot, figsize=(20, 16))\n",
|
|
"axes = axes.flatten()\n",
|
|
"fig.suptitle('Feature Scatter: Low vs High', fontsize=20, fontweight='bold', y=0.995)\n",
|
|
"\n",
|
|
"for idx, col in enumerate(cols):\n",
|
|
" ax = axes[idx]\n",
|
|
"\n",
|
|
" # Scatterplots\n",
|
|
" ax.scatter(range(len(low_all[col])), low_all[col], alpha=0.6, color='blue', label='low_all', s=10)\n",
|
|
" ax.scatter(range(len(high_all[col])), high_all[col], alpha=0.6, color='red', label='high_all', s=10)\n",
|
|
"\n",
|
|
" ax.set_title(col, fontsize=10, fontweight='bold')\n",
|
|
" ax.set_xlabel('Sample index', fontsize=8)\n",
|
|
" ax.set_ylabel('Value', fontsize=8)\n",
|
|
" ax.legend(fontsize=8)\n",
|
|
" ax.grid(True, alpha=0.3)\n",
|
|
"\n",
|
|
"\n",
|
|
"plt.tight_layout()\n",
|
|
"plt.show()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|