diff --git a/model_training/tools/scaler.py b/model_training/tools/scaler.py
new file mode 100644
index 0000000..7449c9e
--- /dev/null
+++ b/model_training/tools/scaler.py
@@ -0,0 +1,101 @@
+from sklearn.preprocessing import MinMaxScaler, StandardScaler
+import pandas as pd
+
+def fit_normalizer(train_data, au_columns, method='standard', scope='global'):
+    """
+    Fit normalization scalers on training data.
+
+    Parameters:
+    -----------
+    train_data : pd.DataFrame
+        Training dataframe with AU columns and subjectID
+    au_columns : list
+        List of AU column names to normalize
+    method : str, default='standard'
+        Normalization method: 'standard' for StandardScaler or 'minmax' for MinMaxScaler
+    scope : str, default='global'
+        Normalization scope: 'subject' for per-subject or 'global' for across all subjects
+
+    Returns:
+    --------
+    dict
+        Dictionary with keys 'scalers', 'method', and 'scope'
+    """
+    # Select the scaler class based on the requested method
+    if method == 'standard':
+        Scaler = StandardScaler
+    elif method == 'minmax':
+        Scaler = MinMaxScaler
+    else:
+        raise ValueError("method must be 'standard' or 'minmax'")
+
+    scalers = {}
+
+    if scope == 'subject':
+        # Fit one scaler per subject
+        for subject in train_data['subjectID'].unique():
+            subject_mask = train_data['subjectID'] == subject
+            scaler = Scaler()
+            scaler.fit(train_data.loc[subject_mask, au_columns])
+            scalers[subject] = scaler
+
+    elif scope == 'global':
+        # Fit one scaler across all subjects
+        scaler = Scaler()
+        scaler.fit(train_data[au_columns])
+        scalers['global'] = scaler
+
+    else:
+        raise ValueError("scope must be 'subject' or 'global'")
+
+    return {'scalers': scalers, 'method': method, 'scope': scope}
+
+def apply_normalizer(data, au_columns, normalizer_dict):
+    """
+    Apply fitted normalization scalers to data.
+
+    Parameters:
+    -----------
+    data : pd.DataFrame
+        Dataframe with AU columns and subjectID
+    au_columns : list
+        List of AU column names to normalize
+    normalizer_dict : dict
+        Dictionary containing fitted scalers from fit_normalizer()
+
+    Returns:
+    --------
+    pd.DataFrame
+        Copy of the input with normalized AU columns
+    """
+    normalized_data = data.copy()
+    scalers = normalizer_dict['scalers']
+    scope = normalizer_dict['scope']
+
+    if scope == 'subject':
+        # Apply per-subject normalization
+        for subject in data['subjectID'].unique():
+            subject_mask = data['subjectID'] == subject
+
+            # Use the subject's own scaler if it was seen during training
+            if subject in scalers:
+                scaler = scalers[subject]
+            else:
+                # Fallback for subjects not seen in training: reuse the first fitted scaler.
+                # Ideally all test subjects appear in training when using subject-level normalization.
+                print(f"Warning: Subject {subject} not found in training data. Using fallback scaler.")
+                scaler = list(scalers.values())[0]
+
+            normalized_data.loc[subject_mask, au_columns] = scaler.transform(
+                data.loc[subject_mask, au_columns]
+            )
+
+    elif scope == 'global':
+        # Apply global normalization
+        scaler = scalers['global']
+        normalized_data[au_columns] = scaler.transform(data[au_columns])
+
+    else:
+        raise ValueError("scope must be 'subject' or 'global'")
+
+    return normalized_data
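
A minimal usage sketch of the two functions above, for reviewers: fit scalers on the training split only, then reuse them on held-out data so no test statistics leak into the normalization. The dataframes, AU column names (AU01_r, AU04_r), and subject IDs below are hypothetical, and the import path assumes the module is importable as model_training.tools.scaler.

import pandas as pd

from model_training.tools.scaler import apply_normalizer, fit_normalizer

# Hypothetical AU data; real column names and splits come from the training pipeline.
train_df = pd.DataFrame({
    'subjectID': ['s1', 's1', 's2', 's2'],
    'AU01_r': [0.1, 0.4, 1.2, 0.9],
    'AU04_r': [0.0, 0.3, 0.8, 1.1],
})
test_df = pd.DataFrame({
    'subjectID': ['s1', 's3'],
    'AU01_r': [0.2, 0.7],
    'AU04_r': [0.5, 0.6],
})
au_cols = ['AU01_r', 'AU04_r']

# Fit on the training split only.
normalizer = fit_normalizer(train_df, au_cols, method='standard', scope='global')

# Apply the same fitted scalers to both splits.
train_norm = apply_normalizer(train_df, au_cols, normalizer)
test_norm = apply_normalizer(test_df, au_cols, normalizer)

With scope='subject', subject s3 (absent from the training split) would hit the fallback branch and print the warning before being scaled with another subject's scaler.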