From 3169c293198a9ce8916a4516e231b4401b6d273c Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 4 Mar 2026 17:01:43 +0100 Subject: [PATCH] mini changes in predict pipeline --- predict_pipeline/config.yaml | 12 +++--- predict_pipeline/predict_sample.py | 60 ++++++++++++++---------------- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/predict_pipeline/config.yaml b/predict_pipeline/config.yaml index f87bb92..eb64637 100644 --- a/predict_pipeline/config.yaml +++ b/predict_pipeline/config.yaml @@ -1,20 +1,20 @@ database: - path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\database.sqlite" + path: "/home/edgekit/MSY_FS/databases/database.sqlite" table: feature_table key: _Id model: - path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\files_for_testing\\xgb_model_3_groupK.joblib" + path: "/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/xgb_model_3_groupK.joblib" scaler: - use_scaling: True - path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\normalizer_min_max_global.pkl" + use_scaling: true + path: "/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/normalizer_min_max_global.pkl" mqtt: enabled: true host: "141.75.215.233" port: 1883 - topic: "PREDICTIONS" + topic: "PREDICTION" client_id: "jetson-board" qos: 0 retain: false @@ -107,4 +107,4 @@ fallback: Blink_mean_dur: 0.38857142857142857 Blink_median_dur: 0.2 Pupil_mean: 3.2823675201416016 - Pupil_IPA: 0.0036347377340156025 + Pupil_IPA: 0.0036347377340156025 \ No newline at end of file diff --git a/predict_pipeline/predict_sample.py b/predict_pipeline/predict_sample.py index 8502f68..04c5176 100644 --- a/predict_pipeline/predict_sample.py +++ b/predict_pipeline/predict_sample.py @@ -7,9 +7,10 @@ import sys import yaml import pickle sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools') -# sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools") + import db_helpers import joblib +import paho.mqtt.client as mqtt def _load_serialized(path: Path): suffix = path.suffix.lower() @@ -52,11 +53,11 @@ def callModel(sample, model_path): suffix = model_path.suffix.lower() if suffix in {".pkl", ".joblib"}: model = _load_serialized(model_path) - # elif suffix == ".keras": - # import tensorflow as tf - # model = tf.keras.models.load_model(model_path) - # else: - # raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.") + elif suffix == ".keras": + import tensorflow as tf + model = tf.keras.models.load_model(model_path) + else: + raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.") x = np.asarray(sample, dtype=np.float32) if x.ndim == 1: @@ -125,44 +126,37 @@ def sendMessage(config_file_path, message): # Serialize the message to JSON payload = json.dumps(message, ensure_ascii=False) - print(payload) + print(payload) # for debugging purposes - # Later: publish via MQTT using config parameters above. - # Example (kept commented intentionally): - # import paho.mqtt.client as mqtt - # client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01")) - # if "username" in mqtt_cfg and mqtt_cfg.get("username"): - # client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password")) - # client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60) - # client.publish( - # topic=topic, - # payload=payload, - # qos=int(mqtt_cfg.get("qos", 1)), - # retain=bool(mqtt_cfg.get("retain", False)), - # ) - # client.disconnect() + + client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01")) + if "username" in mqtt_cfg and mqtt_cfg.get("username"): + client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password")) + client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60) + client.publish( + topic=topic, + payload=payload, + qos=int(mqtt_cfg.get("qos", 1)), + retain=bool(mqtt_cfg.get("retain", False)), + ) + client.disconnect() return def replace_nan(sample, config_file_path: Path): with config_file_path.open("r", encoding="utf-8") as f: cfg = yaml.safe_load(f) - - fallback_list = cfg.get("fallback", []) - fallback_map = {} - for item in fallback_list: - if isinstance(item, dict): - fallback_map.update(item) - + + fallback_map = cfg.get("fallback", {}) + if sample.empty: return False, sample - + nan_ratio = sample.isna().mean() valid = nan_ratio <= 0.5 - + if valid and fallback_map: sample = sample.fillna(value=fallback_map) - - + return valid, sample def sample_to_numpy(sample, drop_cols=("_Id", "start_time")): @@ -213,7 +207,7 @@ def scale_sample(sample, use_scaling=False, scaler_path=None): return df.iloc[0] if isinstance(sample, pd.Series) else df def main(): - pd.set_option('future.no_silent_downcasting', True) # kann ggf raus + pd.set_option('future.no_silent_downcasting', True) config_file_path = Path("/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/config.yaml") with config_file_path.open("r", encoding="utf-8") as f: