Moved scaler into a separate file
This commit is contained in:
parent
e3d9020032
commit
bc1d6d4cac
98
model_training/tools/scaler.py
Normal file
98
model_training/tools/scaler.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from sklearn.preprocessing import MinMaxScaler, StandardScaler
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
def fit_normalizer(train_data, au_columns, method='standard', scope='global'):
    """
    Fit normalization scalers on training data.

    Parameters:
    -----------
    train_data : pd.DataFrame
        Training dataframe with AU columns and subjectID
    au_columns : list
        List of AU column names to normalize
    method : str, default='standard'
        Normalization method: 'standard' for StandardScaler or 'minmax' for MinMaxScaler
    scope : str, default='global'
        Normalization scope: 'subject' for per-subject or 'global' for across all subjects

    Returns:
    --------
    dict
        Dictionary containing fitted scalers

    Raises:
    -------
    ValueError
        If method or scope is not one of the supported values.
    """
    # Validate inputs up front with guard clauses so nothing is fitted
    # on a bad configuration.
    if method not in ('standard', 'minmax'):
        raise ValueError("method must be 'standard' or 'minmax'")
    if scope not in ('subject', 'global'):
        raise ValueError("scope must be 'subject' or 'global'")

    scaler_cls = StandardScaler if method == 'standard' else MinMaxScaler
    fitted = {}

    if scope == 'subject':
        # One scaler per subject, fitted only on that subject's rows.
        for subj in train_data['subjectID'].unique():
            rows = train_data['subjectID'] == subj
            # sklearn's fit() returns the scaler itself, so we can
            # construct and fit in a single expression.
            fitted[subj] = scaler_cls().fit(train_data.loc[rows, au_columns])
    else:
        # scope == 'global': a single scaler fitted across all subjects.
        fitted['global'] = scaler_cls().fit(train_data[au_columns])

    return {'scalers': fitted, 'method': method, 'scope': scope}
|
||||||
|
|
||||||
|
def apply_normalizer(data, au_columns, normalizer_dict):
    """
    Apply fitted normalization scalers to data.

    Parameters:
    -----------
    data : pd.DataFrame
        Dataframe with AU columns and subjectID
    au_columns : list
        List of AU column names to normalize
    normalizer_dict : dict
        Dictionary containing fitted scalers from fit_normalizer()

    Returns:
    --------
    pd.DataFrame
        New DataFrame with normalized AU columns; the input is not modified.

    Raises:
    -------
    ValueError
        If normalizer_dict['scope'] is neither 'subject' nor 'global'.
    """
    normalized_data = data.copy()
    scalers = normalizer_dict['scalers']
    scope = normalizer_dict['scope']

    if scope == 'subject':
        # Apply per-subject normalization
        for subject in data['subjectID'].unique():
            subject_mask = data['subjectID'] == subject

            # Use the subject's scaler if available, otherwise fall back to a
            # fitted scaler from training. (Ideally all test subjects should be
            # present in training for subject-level normalization.)
            if subject in scalers:
                scaler = scalers[subject]
            else:
                print(f"Warning: Subject {subject} not found in training data. Using fallback scaler.")
                scaler = list(scalers.values())[0]

            normalized_data.loc[subject_mask, au_columns] = scaler.transform(
                data.loc[subject_mask, au_columns]
            )

    elif scope == 'global':
        # Apply the single shared scaler across all subjects
        scaler = scalers['global']
        normalized_data[au_columns] = scaler.transform(data[au_columns])

    else:
        # BUGFIX: an unknown scope previously returned the data silently
        # unnormalized; fail loudly instead, mirroring fit_normalizer.
        raise ValueError("scope must be 'subject' or 'global'")

    return normalized_data
|
||||||
Loading…
x
Reference in New Issue
Block a user