scaler v2
This commit is contained in:
parent
0088cef32a
commit
2b01085a9e
@ -7,8 +7,8 @@ model:
|
|||||||
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\files_for_testing\\xgb_model_3_groupK.joblib"
|
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\files_for_testing\\xgb_model_3_groupK.joblib"
|
||||||
|
|
||||||
scaler:
|
scaler:
|
||||||
use_scaling: False
|
use_scaling: True
|
||||||
path = "C:"
|
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\normalizer_min_max_global.pkl"
|
||||||
|
|
||||||
mqtt:
|
mqtt:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|||||||
@ -11,7 +11,14 @@ sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
|||||||
import db_helpers
|
import db_helpers
|
||||||
import joblib
|
import joblib
|
||||||
|
|
||||||
_MODEL_CACHE = {}
|
def _load_serialized(path: Path):
|
||||||
|
suffix = path.suffix.lower()
|
||||||
|
if suffix == ".pkl":
|
||||||
|
with path.open("rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
if suffix == ".joblib":
|
||||||
|
return joblib.load(path)
|
||||||
|
raise ValueError(f"Unsupported file format: {suffix}. Use .pkl or .joblib.")
|
||||||
|
|
||||||
def getLastEntryFromSQLite(path, table_name, key="_Id"):
|
def getLastEntryFromSQLite(path, table_name, key="_Id"):
|
||||||
conn, cursor = db_helpers.connect_db(path)
|
conn, cursor = db_helpers.connect_db(path)
|
||||||
@ -43,22 +50,13 @@ def callModel(sample, model_path):
|
|||||||
model_path = model_path.resolve()
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
suffix = model_path.suffix.lower()
|
suffix = model_path.suffix.lower()
|
||||||
cache_key = str(model_path)
|
if suffix in {".pkl", ".joblib"}:
|
||||||
|
model = _load_serialized(model_path)
|
||||||
if cache_key in _MODEL_CACHE:
|
elif suffix == ".keras":
|
||||||
model = _MODEL_CACHE[cache_key]
|
import tensorflow as tf
|
||||||
|
model = tf.keras.models.load_model(model_path)
|
||||||
else:
|
else:
|
||||||
if suffix == ".pkl":
|
raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
||||||
with model_path.open("rb") as f:
|
|
||||||
model = pickle.load(f)
|
|
||||||
elif suffix == ".joblib":
|
|
||||||
model = joblib.load(model_path)
|
|
||||||
elif suffix == ".keras":
|
|
||||||
import tensorflow as tf
|
|
||||||
model = tf.keras.models.load_model(model_path)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
|
||||||
_MODEL_CACHE[cache_key] = model
|
|
||||||
|
|
||||||
x = np.asarray(sample, dtype=np.float32)
|
x = np.asarray(sample, dtype=np.float32)
|
||||||
if x.ndim == 1:
|
if x.ndim == 1:
|
||||||
@ -162,14 +160,42 @@ def sample_to_numpy(sample, drop_cols=("_Id", "start_time")):
|
|||||||
|
|
||||||
return np.asarray(sample)
|
return np.asarray(sample)
|
||||||
|
|
||||||
def scale_sample(sample, use_scaling=False):
|
def scale_sample(sample, use_scaling=False, scaler_path=None):
|
||||||
if use_scaling:
|
if not use_scaling or scaler_path is None:
|
||||||
# load scaler
|
|
||||||
# normalize
|
|
||||||
return sample
|
return sample
|
||||||
|
scaler_path = Path(scaler_path)
|
||||||
|
if not scaler_path.is_absolute():
|
||||||
|
scaler_path = Path.cwd() / scaler_path
|
||||||
|
scaler_path = scaler_path.resolve()
|
||||||
|
normalizer = _load_serialized(scaler_path)
|
||||||
|
|
||||||
|
# normalizer format from model_training/tools/scaler.py:
|
||||||
|
# {"scalers": {...}, "method": "...", "scope": "..."}
|
||||||
|
scalers = normalizer.get("scalers", {}) if isinstance(normalizer, dict) else {}
|
||||||
|
scope = normalizer.get("scope", "global") if isinstance(normalizer, dict) else "global"
|
||||||
|
if scope == "global":
|
||||||
|
scaler = scalers.get("global")
|
||||||
else:
|
else:
|
||||||
|
scaler = scalers.get("global", next(iter(scalers.values()), None))
|
||||||
|
|
||||||
|
# Optional fallback if the stored object is already a raw scaler.
|
||||||
|
if scaler is None and hasattr(normalizer, "transform"):
|
||||||
|
scaler = normalizer
|
||||||
|
if scaler is None or not hasattr(scaler, "transform"):
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
df = sample.to_frame().T if isinstance(sample, pd.Series) else sample.copy()
|
||||||
|
feature_names = getattr(scaler, "feature_names_in_", None)
|
||||||
|
if feature_names is None:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# Keep columns not in the normalizer unchanged.
|
||||||
|
cols_to_scale = [c for c in df.columns if c in set(feature_names)]
|
||||||
|
if cols_to_scale:
|
||||||
|
df.loc[:, cols_to_scale] = scaler.transform(df.loc[:, cols_to_scale])
|
||||||
|
|
||||||
|
return df.iloc[0] if isinstance(sample, pd.Series) else df
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
config_file_path = Path("predict_pipeline/config.yaml")
|
config_file_path = Path("predict_pipeline/config.yaml")
|
||||||
with config_file_path.open("r", encoding="utf-8") as f:
|
with config_file_path.open("r", encoding="utf-8") as f:
|
||||||
@ -178,7 +204,7 @@ def main():
|
|||||||
database_path = cfg["database"]["path"]
|
database_path = cfg["database"]["path"]
|
||||||
table_name = cfg["database"]["table"]
|
table_name = cfg["database"]["table"]
|
||||||
row_key = cfg["database"]["key"]
|
row_key = cfg["database"]["key"]
|
||||||
use_scaling = cfg.get("scaler", {}).get("use_scaling", cfg.get("scaler", {}).get("use_scaler", False))
|
|
||||||
|
|
||||||
sample = getLastEntryFromSQLite(database_path, table_name, row_key)
|
sample = getLastEntryFromSQLite(database_path, table_name, row_key)
|
||||||
valid, sample = replace_nan(sample, config_file_path=config_file_path)
|
valid, sample = replace_nan(sample, config_file_path=config_file_path)
|
||||||
@ -190,8 +216,12 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
model_path = cfg["model"]["path"]
|
model_path = cfg["model"]["path"]
|
||||||
|
scaler_path = cfg["scaler"]["path"]
|
||||||
|
use_scaling = cfg["scaler"]["use_scaling"]
|
||||||
|
|
||||||
|
sample = scale_sample(sample, use_scaling=use_scaling, scaler_path=scaler_path)
|
||||||
sample_np = sample_to_numpy(sample)
|
sample_np = sample_to_numpy(sample)
|
||||||
sample_np = scale_sample(sample_np, use_scaling=use_scaling)
|
|
||||||
prediction = callModel(model_path=model_path, sample=sample_np)
|
prediction = callModel(model_path=model_path, sample=sample_np)
|
||||||
|
|
||||||
message = buildMessage(valid, prediction, config_file_path, sample=sample)
|
message = buildMessage(valid, prediction, config_file_path, sample=sample)
|
||||||
@ -202,4 +232,4 @@ if __name__ == "__main__":
|
|||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
# https://www.youtube.com/watch?v=Q09tWwz6WoI
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user