changed data preprocessing in i forest training
This commit is contained in:
parent
58faff8f68
commit
5f00341b1b
@ -28,7 +28,13 @@
|
||||
"sys.path.append(base_dir)\n",
|
||||
"print(base_dir)\n",
|
||||
"\n",
|
||||
"from tools import evaluation_tools"
|
||||
"from tools import evaluation_tools\n",
|
||||
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
|
||||
"from sklearn.ensemble import IsolationForest\n",
|
||||
"from sklearn.model_selection import GridSearchCV, KFold\n",
|
||||
"from sklearn.metrics import roc_auc_score\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -112,41 +118,113 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "20394aca",
|
||||
"id": "47a0f44d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_subjects, test_subjects = train_test_split(\n",
|
||||
" subjects, \n",
|
||||
" train_size=12, \n",
|
||||
" test_size=6, \n",
|
||||
" random_state=42\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Get all column names that start with 'AU'\n",
|
||||
"au_columns = [col for col in low_all.columns if col.startswith('AU')]\n",
|
||||
"\n",
|
||||
"# Create train set: only normal samples from train subjects, only AU columns\n",
|
||||
"X_train = low_all[low_all['subjectID'].isin(train_subjects)][au_columns].copy()\n",
|
||||
"y_train = np.ones(len(X_train)) # Label 1 for normal samples\n",
|
||||
"\n",
|
||||
"# Create test set: both normal and high load from test subjects, only AU columns\n",
|
||||
"X_test_normal = low_all[low_all['subjectID'].isin(test_subjects)][au_columns].copy()\n",
|
||||
"X_test_high = high_all[high_all['subjectID'].isin(test_subjects)][au_columns].copy()\n",
|
||||
"\n",
|
||||
"# Combine test sets\n",
|
||||
"X_test = pd.concat([X_test_normal, X_test_high], ignore_index=True)\n",
|
||||
"\n",
|
||||
"# Create labels for test set\n",
|
||||
"y_test_normal = np.ones(len(X_test_normal)) # 1 for normal\n",
|
||||
"y_test_high = -np.ones(len(X_test_high)) # -1 for anomalies\n",
|
||||
"y_test = np.concatenate([y_test_normal, y_test_high])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f\"Number of AU features: {len(au_columns)}\")\n",
|
||||
"print(f\"AU columns: {au_columns}\")\n",
|
||||
"print(f\"\\nTrain set: {len(X_train)} normal samples\")\n",
|
||||
"print(f\"Test set: {len(X_test_normal)} normal + {len(X_test_high)} high load = {len(X_test)} total samples\")\n"
|
||||
"def fit_normalizer(train_data, au_columns, method='standard', scope='global'):\n",
|
||||
" \"\"\"\n",
|
||||
" Fit normalization scalers on training data.\n",
|
||||
" \n",
|
||||
" Parameters:\n",
|
||||
" -----------\n",
|
||||
" train_data : pd.DataFrame\n",
|
||||
" Training dataframe with AU columns and subjectID\n",
|
||||
" au_columns : list\n",
|
||||
" List of AU column names to normalize\n",
|
||||
" method : str, default='standard'\n",
|
||||
" Normalization method: 'standard' for StandardScaler or 'minmax' for MinMaxScaler\n",
|
||||
" scope : str, default='global'\n",
|
||||
" Normalization scope: 'subject' for per-subject or 'global' for across all subjects\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" --------\n",
|
||||
" dict\n",
|
||||
" Dictionary containing fitted scalers\n",
|
||||
" \"\"\"\n",
|
||||
" # Select scaler based on method\n",
|
||||
" if method == 'standard':\n",
|
||||
" Scaler = StandardScaler\n",
|
||||
" elif method == 'minmax':\n",
|
||||
" Scaler = MinMaxScaler\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"method must be 'standard' or 'minmax'\")\n",
|
||||
" \n",
|
||||
" scalers = {}\n",
|
||||
" \n",
|
||||
" if scope == 'subject':\n",
|
||||
" # Fit one scaler per subject\n",
|
||||
" for subject in train_data['subjectID'].unique():\n",
|
||||
" subject_mask = train_data['subjectID'] == subject\n",
|
||||
" scaler = Scaler()\n",
|
||||
" scaler.fit(train_data.loc[subject_mask, au_columns])\n",
|
||||
" scalers[subject] = scaler\n",
|
||||
" \n",
|
||||
" elif scope == 'global':\n",
|
||||
" # Fit one scaler for all subjects\n",
|
||||
" scaler = Scaler()\n",
|
||||
" scaler.fit(train_data[au_columns])\n",
|
||||
" scalers['global'] = scaler\n",
|
||||
" \n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"scope must be 'subject' or 'global'\")\n",
|
||||
" \n",
|
||||
" return {'scalers': scalers, 'method': method, 'scope': scope}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "642d0017",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def apply_normalizer(data, au_columns, normalizer_dict):\n",
|
||||
" \"\"\"\n",
|
||||
" Apply fitted normalization scalers to data.\n",
|
||||
" \n",
|
||||
" Parameters:\n",
|
||||
" -----------\n",
|
||||
" data : pd.DataFrame\n",
|
||||
" Dataframe with AU columns and subjectID\n",
|
||||
" au_columns : list\n",
|
||||
" List of AU column names to normalize\n",
|
||||
" normalizer_dict : dict\n",
|
||||
" Dictionary containing fitted scalers from fit_normalizer()\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" --------\n",
|
||||
" pd.DataFrame\n",
|
||||
" DataFrame with normalized AU columns\n",
|
||||
" \"\"\"\n",
|
||||
" normalized_data = data.copy()\n",
|
||||
" scalers = normalizer_dict['scalers']\n",
|
||||
" scope = normalizer_dict['scope']\n",
|
||||
" \n",
|
||||
" if scope == 'subject':\n",
|
||||
" # Apply per-subject normalization\n",
|
||||
" for subject in data['subjectID'].unique():\n",
|
||||
" subject_mask = data['subjectID'] == subject\n",
|
||||
" \n",
|
||||
" # Use the subject's scaler if available, otherwise use a fitted scaler from training\n",
|
||||
" if subject in scalers:\n",
|
||||
" scaler = scalers[subject]\n",
|
||||
" else:\n",
|
||||
" # For new subjects not seen in training, use the first available scaler\n",
|
||||
" # (This is a fallback - ideally all test subjects should be in training for subject-level normalization)\n",
|
||||
" print(f\"Warning: Subject {subject} not found in training data. Using fallback scaler.\")\n",
|
||||
" scaler = list(scalers.values())[0]\n",
|
||||
" \n",
|
||||
" normalized_data.loc[subject_mask, au_columns] = scaler.transform(\n",
|
||||
" data.loc[subject_mask, au_columns]\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" elif scope == 'global':\n",
|
||||
" # Apply global normalization\n",
|
||||
" scaler = scalers['global']\n",
|
||||
" normalized_data[au_columns] = scaler.transform(data[au_columns])\n",
|
||||
" \n",
|
||||
" return normalized_data"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -160,54 +238,246 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b5cd4ac6",
|
||||
"id": "bfec0188",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"iforest = IsolationForest(random_state=42)\n",
|
||||
"iforest.fit(X_train)\n",
|
||||
"iforest_scores = iforest.score_samples(X_test)\n",
|
||||
"iforest_predictions = iforest.predict(X_test)"
|
||||
"def supervised_one_class_grid_search(estimator, param_grid, data, labels, seed=None):\n",
|
||||
" np.random.seed(seed)\n",
|
||||
" idx = np.arange(data.shape[0])\n",
|
||||
" anomaly_idx = idx[labels==-1]\n",
|
||||
" normal_idx = idx[labels!=-1]\n",
|
||||
"\n",
|
||||
" np.random.shuffle(normal_idx)\n",
|
||||
"\n",
|
||||
" cv = [(normal_idx[pair[0]], np.concatenate([normal_idx[pair[1]], anomaly_idx], axis=0)) for pair in KFold().split(normal_idx)]\n",
|
||||
" \n",
|
||||
" grid_search = GridSearchCV(estimator=estimator,\n",
|
||||
" param_grid=param_grid,\n",
|
||||
" scoring=lambda est, X, y: roc_auc_score(y_true=y, y_score=est.score_samples(X)),\n",
|
||||
" n_jobs=-2,\n",
|
||||
" cv=cv,\n",
|
||||
" verbose=1,\n",
|
||||
" refit=False)\n",
|
||||
" \n",
|
||||
" grid_search.fit(data, labels)\n",
|
||||
"\n",
|
||||
" return grid_search"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "15c45f66",
|
||||
"id": "91d5f83d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"evaluation_tools.plot_confusion_matrix(true_labels=y_test, predictions=iforest_predictions, label_names=['high load', 'low load'])"
|
||||
"# First split: separate test set\n",
|
||||
"train_val_subjects, test_subjects = train_test_split(\n",
|
||||
" subjects, \n",
|
||||
" train_size=12, \n",
|
||||
" test_size=6, \n",
|
||||
" random_state=42\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Second split: separate train and validation from the remaining subjects\n",
|
||||
"# Adjust these numbers based on your total subject count\n",
|
||||
"train_subjects, val_subjects = train_test_split(\n",
|
||||
" train_val_subjects,\n",
|
||||
" train_size=8,\n",
|
||||
" test_size=4,\n",
|
||||
" random_state=42\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"Train subjects: {len(train_subjects)}\")\n",
|
||||
"print(f\"Validation subjects: {len(val_subjects)}\")\n",
|
||||
"print(f\"Test subjects: {len(test_subjects)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "326fcb47",
|
||||
"id": "2400c15a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"evaluation_tools.plot_roc_curve_IF(y_test, iforest_scores)"
|
||||
"# Cell 2: Get AU columns and prepare datasets\n",
|
||||
"# Get all column names that start with 'AU'\n",
|
||||
"au_columns = [col for col in low_all.columns if col.startswith('AU')]\n",
|
||||
"\n",
|
||||
"# Prepare training data (only normal/low data)\n",
|
||||
"train_data = low_all[low_all['subjectID'].isin(train_subjects)][['subjectID'] + au_columns].copy()\n",
|
||||
"\n",
|
||||
"# Prepare validation data (normal and anomaly)\n",
|
||||
"val_normal_data = low_all[low_all['subjectID'].isin(val_subjects)][['subjectID'] + au_columns].copy()\n",
|
||||
"val_high_data = high_all[high_all['subjectID'].isin(val_subjects)][['subjectID'] + au_columns].copy()\n",
|
||||
"\n",
|
||||
"# Prepare test data (normal and anomaly)\n",
|
||||
"test_normal_data = low_all[low_all['subjectID'].isin(test_subjects)][['subjectID'] + au_columns].copy()\n",
|
||||
"test_high_data = high_all[high_all['subjectID'].isin(test_subjects)][['subjectID'] + au_columns].copy()\n",
|
||||
"\n",
|
||||
"print(f\"Train samples: {len(train_data)}\")\n",
|
||||
"print(f\"Val normal samples: {len(val_normal_data)}, Val high samples: {len(val_high_data)}\")\n",
|
||||
"print(f\"Test normal samples: {len(test_normal_data)}, Test high samples: {len(test_high_data)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "141267e4",
|
||||
"id": "5c24f9d6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"iforest.offset_"
|
||||
"# Cell 3: Fit normalizer on training data\n",
|
||||
"normalizer = fit_normalizer(train_data, au_columns, method='minmax', scope='global')\n",
|
||||
"print(\"Normalizer fitted on training data\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4bf81d7b",
|
||||
"id": "cbe29b06",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(classification_report(y_test, iforest_predictions))"
|
||||
"# Cell 4: Apply normalization to all datasets\n",
|
||||
"train_normalized = apply_normalizer(train_data, au_columns, normalizer)\n",
|
||||
"val_normal_normalized = apply_normalizer(val_normal_data, au_columns, normalizer)\n",
|
||||
"val_high_normalized = apply_normalizer(val_high_data, au_columns, normalizer)\n",
|
||||
"test_normal_normalized = apply_normalizer(test_normal_data, au_columns, normalizer)\n",
|
||||
"test_high_normalized = apply_normalizer(test_high_data, au_columns, normalizer)\n",
|
||||
"\n",
|
||||
"print(\"Normalization applied to all datasets\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e39fd185",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Cell 5: Extract AU columns and create labels for grid search\n",
|
||||
"# Extract only AU columns (drop subjectID)\n",
|
||||
"X_train = train_normalized[au_columns].copy()\n",
|
||||
"X_val_normal = val_normal_normalized[au_columns].copy()\n",
|
||||
"X_val_high = val_high_normalized[au_columns].copy()\n",
|
||||
"\n",
|
||||
"# Combine train and validation sets for grid search\n",
|
||||
"X_grid_search = pd.concat([X_train, X_val_normal, X_val_high], ignore_index=True)\n",
|
||||
"\n",
|
||||
"# Create labels for grid search\n",
|
||||
"y_train = np.ones(len(X_train)) # 1 for normal (training)\n",
|
||||
"y_val_normal = np.ones(len(X_val_normal)) # 1 for normal (validation)\n",
|
||||
"y_val_high = -np.ones(len(X_val_high)) # -1 for anomalies (validation)\n",
|
||||
"y_grid_search = np.concatenate([y_train, y_val_normal, y_val_high])\n",
|
||||
"\n",
|
||||
"print(f\"Grid search data shape: {X_grid_search.shape}\")\n",
|
||||
"print(f\"Labels distribution: Normal={np.sum(y_grid_search==1)}, Anomaly={np.sum(y_grid_search==-1)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2330e817",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define your estimator and parameter grid\n",
|
||||
"estimator = IsolationForest(random_state=42)\n",
|
||||
"iforest_param_grid = {\n",
|
||||
" 'n_estimators': [100, 200, 300], # Number of trees\n",
|
||||
" 'max_samples': [0.5, 0.75, 1.0, 'auto'], # Subsample size for each tree \n",
|
||||
" 'max_features': [0.5, 0.75, 1.0], # Features to draw for each tree\n",
|
||||
" 'bootstrap': [True, False], # Whether to bootstrap samples\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Perform grid search\n",
|
||||
"grid_search = supervised_one_class_grid_search(\n",
|
||||
" estimator=estimator,\n",
|
||||
" param_grid=iforest_param_grid,\n",
|
||||
" data=X_grid_search.values,\n",
|
||||
" labels=y_grid_search,\n",
|
||||
" seed=42\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Get best parameters\n",
|
||||
"best_params = grid_search.best_params_\n",
|
||||
"best_score = grid_search.best_score_\n",
|
||||
"\n",
|
||||
"print(f\"Best parameters: {best_params}\")\n",
|
||||
"print(f\"Best validation AUC: {best_score:.4f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ad31c951",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Cell 7: Train final model with best parameters on training data\n",
|
||||
"final_model = IsolationForest(**best_params, random_state=42)\n",
|
||||
"final_model.fit(X_train.values)\n",
|
||||
"\n",
|
||||
"print(\"Final model trained on training data only\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4a7a3307",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Cell 8: Prepare independent test set\n",
|
||||
"X_test_normal = test_normal_normalized[au_columns].copy()\n",
|
||||
"X_test_high = test_high_normalized[au_columns].copy()\n",
|
||||
"\n",
|
||||
"# Combine test sets\n",
|
||||
"X_test = pd.concat([X_test_normal, X_test_high], ignore_index=True)\n",
|
||||
"\n",
|
||||
"# Create labels for test set\n",
|
||||
"y_test_normal = np.ones(len(X_test_normal)) # 1 for normal\n",
|
||||
"y_test_high = -np.ones(len(X_test_high)) # -1 for anomalies\n",
|
||||
"y_test = np.concatenate([y_test_normal, y_test_high])\n",
|
||||
"\n",
|
||||
"print(f\"Test set shape: {X_test.shape}\")\n",
|
||||
"print(f\"Test labels distribution: Normal={np.sum(y_test==1)}, Anomaly={np.sum(y_test==-1)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8353d431",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Get anomaly scores\n",
|
||||
"y_scores = final_model.score_samples(X_test.values)\n",
|
||||
"# Get predictions (-1 for anomaly, 1 for normal)\n",
|
||||
"y_pred = final_model.predict(X_test.values)\n",
|
||||
"print(classification_report(y_test, y_pred, target_names=['Anomaly', 'Normal']))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "64f753a3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"evaluation_tools.plot_confusion_matrix(y_test, y_pred, label_names=['Anomaly', 'Normal'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a3245f17",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"evaluation_tools.plot_roc_curve_IF(y_test, y_scores)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user