diff --git a/predict_pipeline/predict_sample.py b/predict_pipeline/predict_sample.py
index 7951c39..3826ea3 100644
--- a/predict_pipeline/predict_sample.py
+++ b/predict_pipeline/predict_sample.py
@@ -3,41 +3,209 @@
 import pandas as pd
 import json
 from pathlib import Path
 import numpy as np
+import sys
+import yaml
+import pickle
+#sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
+sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
+import db_helpers
+import tensorflow as tf
+import joblib
 
-def getLastEntryFromSQLite():
+_MODEL_CACHE = {}
 
-    return
+def getLastEntryFromSQLite(path, table_name, key="_Id"):
+    conn, cursor = db_helpers.connect_db(path)
+    try:
+        row_df = db_helpers.get_data_from_table(
+            conn=conn,
+            table_name=table_name,
+            order_by={key: "DESC"},
+            limit=1,
+        )
+    finally:
+        db_helpers.disconnect_db(conn, cursor, commit=False)
 
-def callModel(sample):
-    prediction: np.int32 = sample # noch unklar ob jedes mal ein load oder z.B. mit Flask API
-    return prediction
+    if row_df.empty:
+        return pd.Series(dtype="object")
 
-def getMessageConfig( config_file_path):
+    return row_df.iloc[0]
 
-    return dict()
+def callModel(sample, model_path):
+    if callable(sample):
+        raise TypeError(
+            f"Invalid sample type: got callable `{getattr(sample, '__name__', type(sample).__name__)}`. "
+            "Expected numpy array / pandas row."
+        )
+
+    model_path = Path(model_path)
+    if not model_path.is_absolute():
+        model_path = Path.cwd() / model_path
+    model_path = model_path.resolve()
+
+    suffix = model_path.suffix.lower()
+    cache_key = str(model_path)
+
+    if cache_key in _MODEL_CACHE:
+        model = _MODEL_CACHE[cache_key]
+    else:
+        if suffix == ".pkl":
+            with model_path.open("rb") as f:
+                model = pickle.load(f)
+        elif suffix == ".joblib":
+            model = joblib.load(model_path)
+        elif suffix == ".keras":
+            model = tf.keras.models.load_model(model_path)
+        else:
+            raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
+        _MODEL_CACHE[cache_key] = model
+
+    x = np.asarray(sample, dtype=np.float32)
+    if x.ndim == 1:
+        x = x.reshape(1, -1)
+
+    if suffix == ".keras":
+        x_full = x
+        # Future model (35 features): keep this call when your new model is active.
+        # prediction = model.predict(x_full[:, :35], verbose=0)
+        prediction = model.predict(x_full[:, :20], verbose=0)
+
+    else:
+        if hasattr(model, "predict"):
+            prediction = model.predict(x[:,:20])
+        elif callable(model):
+            prediction = model(x[:,:20])
+        else:
+            raise TypeError("Loaded model has no .predict(...) and is not callable.")
+
+    prediction = np.asarray(prediction)
+    if prediction.size == 1:
+        return prediction.item()
+    return prediction.squeeze()
 
-def buildMessage(result: np.int32, config: dict):
-    # message =json...
-    message = 5
+
+
+def buildMessage(valid, result: np.int32, config_file_path, sample=None):
+    with Path(config_file_path).open("r", encoding="utf-8") as f:
+        cfg = yaml.safe_load(f)
+
+    mqtt_cfg = cfg.get("mqtt", {})
+    result_key = mqtt_cfg.get("publish_format", {}).get("result_key", "prediction")
+
+    sample_id = None
+    if isinstance(sample, pd.Series):
+        sample_id = sample.get("_Id", sample.get("_id"))
+    elif isinstance(sample, dict):
+        sample_id = sample.get("_Id", sample.get("_id"))
+
+    message = {
+        "valid": bool(valid),
+        "_id": sample_id,
+        result_key: np.asarray(result).tolist() if isinstance(result, np.ndarray) else result,
+    }
     return message
 
-def sendMessage(destination, message):
-    return 2
+def sendMessage(config_file_path, message):
+    with Path(config_file_path).open("r", encoding="utf-8") as f:
+        cfg = yaml.safe_load(f)
+
+    mqtt_cfg = cfg.get("mqtt", {})
+    topic = mqtt_cfg.get("topic", "ml/predictions")
+
+    payload = json.dumps(message, ensure_ascii=False)
+    print(payload)
+
+    # Later: publish via MQTT using config parameters above.
+ # Example (kept commented intentionally): + # import paho.mqtt.client as mqtt + # client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01")) + # if "username" in mqtt_cfg and mqtt_cfg.get("username"): + # client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password")) + # client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60) + # client.publish( + # topic=topic, + # payload=payload, + # qos=int(mqtt_cfg.get("qos", 1)), + # retain=bool(mqtt_cfg.get("retain", False)), + # ) + # client.disconnect() + return + +def replace_nan(sample, config_file_path: Path): + with config_file_path.open("r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + + fallback_list = cfg.get("fallback", []) + fallback_map = {} + for item in fallback_list: + if isinstance(item, dict): + fallback_map.update(item) + + if sample.empty: + return False, sample + + nan_ratio = sample.isna().mean() + valid = nan_ratio <= 0.5 + + if valid and fallback_map: + sample = sample.fillna(value=fallback_map) + + + return valid, sample + +def sample_to_numpy(sample, drop_cols=("_Id", "start_time")): + if isinstance(sample, pd.Series): + sample = sample.drop(labels=list(drop_cols), errors="ignore") + return sample.to_numpy() + + if isinstance(sample, pd.DataFrame): + sample = sample.drop(columns=list(drop_cols), errors="ignore") + return sample.to_numpy() + + return np.asarray(sample) + +def scale_sample(sample, use_scaling=False): + if use_scaling: + # load scaler + # normalize + return sample + else: + return sample def main(): - config_file_path = Path("") - config = getMessageConfig(config_file_path=config_file_path) + config_file_path = Path("predict_pipeline/config.yaml") + with config_file_path.open("r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) - sample = getLastEntryFromSQLite() + database_path = cfg["database"]["path"] + table_name = cfg["database"]["table"] + row_key = cfg["database"]["key"] + use_scaling = cfg.get("scaler", 
{}).get("use_scaling", cfg.get("scaler", {}).get("use_scaler", False))
 
-    prediction = callModel(sample=sample)
+    sample = getLastEntryFromSQLite(database_path, table_name, row_key)
+    valid, sample = replace_nan(sample, config_file_path=config_file_path)
+    if not valid:
+        print("Sample invalid: more than 50% NaN.")
+        message = buildMessage(valid, None, config_file_path, sample=sample)
+        sendMessage(config_file_path, message)
+        return
+    model_path = cfg["model"]["path"]
+    sample_np = sample_to_numpy(sample)
+    sample_np = scale_sample(sample_np, use_scaling=use_scaling)
+    prediction = callModel(model_path=model_path, sample=sample_np)
 
-    message = buildMessage(result=prediction, config=config)
-
-    sendMessage(config, message)
+    message = buildMessage(valid, prediction, config_file_path, sample=sample)
+    sendMessage(config_file_path, message)
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
+
+# to do:
+# config file
+# bei sample holen outlier ersetzen
+# mediane abspeichern
+# falls nur nan, dann sample verwerfen
+# https://www.youtube.com/watch?v=Q09tWwz6WoI