Minor changes in predict pipeline

This commit is contained in:
Michael Weig 2026-03-04 17:01:43 +01:00
parent 13bd76631f
commit 3169c29319
2 changed files with 33 additions and 39 deletions

View File

@@ -1,20 +1,20 @@
database: database:
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\database.sqlite" path: "/home/edgekit/MSY_FS/databases/database.sqlite"
table: feature_table table: feature_table
key: _Id key: _Id
model: model:
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\files_for_testing\\xgb_model_3_groupK.joblib" path: "/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/xgb_model_3_groupK.joblib"
scaler: scaler:
use_scaling: True use_scaling: true
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\normalizer_min_max_global.pkl" path: "/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/normalizer_min_max_global.pkl"
mqtt: mqtt:
enabled: true enabled: true
host: "141.75.215.233" host: "141.75.215.233"
port: 1883 port: 1883
topic: "PREDICTIONS" topic: "PREDICTION"
client_id: "jetson-board" client_id: "jetson-board"
qos: 0 qos: 0
retain: false retain: false
@@ -107,4 +107,4 @@ fallback:
Blink_mean_dur: 0.38857142857142857 Blink_mean_dur: 0.38857142857142857
Blink_median_dur: 0.2 Blink_median_dur: 0.2
Pupil_mean: 3.2823675201416016 Pupil_mean: 3.2823675201416016
Pupil_IPA: 0.0036347377340156025 Pupil_IPA: 0.0036347377340156025

View File

@@ -7,9 +7,10 @@ import sys
import yaml import yaml
import pickle import pickle
sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools') sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
# sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
import db_helpers import db_helpers
import joblib import joblib
import paho.mqtt.client as mqtt
def _load_serialized(path: Path): def _load_serialized(path: Path):
suffix = path.suffix.lower() suffix = path.suffix.lower()
@@ -52,11 +53,11 @@ def callModel(sample, model_path):
suffix = model_path.suffix.lower() suffix = model_path.suffix.lower()
if suffix in {".pkl", ".joblib"}: if suffix in {".pkl", ".joblib"}:
model = _load_serialized(model_path) model = _load_serialized(model_path)
# elif suffix == ".keras": elif suffix == ".keras":
# import tensorflow as tf import tensorflow as tf
# model = tf.keras.models.load_model(model_path) model = tf.keras.models.load_model(model_path)
# else: else:
# raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.") raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
x = np.asarray(sample, dtype=np.float32) x = np.asarray(sample, dtype=np.float32)
if x.ndim == 1: if x.ndim == 1:
@@ -125,44 +126,37 @@ def sendMessage(config_file_path, message):
# Serialize the message to JSON # Serialize the message to JSON
payload = json.dumps(message, ensure_ascii=False) payload = json.dumps(message, ensure_ascii=False)
print(payload) print(payload) # for debugging purposes
# Later: publish via MQTT using config parameters above.
# Example (kept commented intentionally): client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01"))
# import paho.mqtt.client as mqtt if "username" in mqtt_cfg and mqtt_cfg.get("username"):
# client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01")) client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
# if "username" in mqtt_cfg and mqtt_cfg.get("username"): client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60)
# client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password")) client.publish(
# client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60) topic=topic,
# client.publish( payload=payload,
# topic=topic, qos=int(mqtt_cfg.get("qos", 1)),
# payload=payload, retain=bool(mqtt_cfg.get("retain", False)),
# qos=int(mqtt_cfg.get("qos", 1)), )
# retain=bool(mqtt_cfg.get("retain", False)), client.disconnect()
# )
# client.disconnect()
return return
def replace_nan(sample, config_file_path: Path): def replace_nan(sample, config_file_path: Path):
with config_file_path.open("r", encoding="utf-8") as f: with config_file_path.open("r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
fallback_list = cfg.get("fallback", []) fallback_map = cfg.get("fallback", {})
fallback_map = {}
for item in fallback_list:
if isinstance(item, dict):
fallback_map.update(item)
if sample.empty: if sample.empty:
return False, sample return False, sample
nan_ratio = sample.isna().mean() nan_ratio = sample.isna().mean()
valid = nan_ratio <= 0.5 valid = nan_ratio <= 0.5
if valid and fallback_map: if valid and fallback_map:
sample = sample.fillna(value=fallback_map) sample = sample.fillna(value=fallback_map)
return valid, sample return valid, sample
def sample_to_numpy(sample, drop_cols=("_Id", "start_time")): def sample_to_numpy(sample, drop_cols=("_Id", "start_time")):
@@ -213,7 +207,7 @@ def scale_sample(sample, use_scaling=False, scaler_path=None):
return df.iloc[0] if isinstance(sample, pd.Series) else df return df.iloc[0] if isinstance(sample, pd.Series) else df
def main(): def main():
pd.set_option('future.no_silent_downcasting', True) # kann ggf raus pd.set_option('future.no_silent_downcasting', True)
config_file_path = Path("/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/config.yaml") config_file_path = Path("/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/config.yaml")
with config_file_path.open("r", encoding="utf-8") as f: with config_file_path.open("r", encoding="utf-8") as f: