From 0088cef32a4006af0895044ba4ddfc93bbf7a094 Mon Sep 17 00:00:00 2001
From: Michael
Date: Mon, 16 Feb 2026 18:58:18 +0100
Subject: [PATCH] small changes and lazy import of tensorflow

---
 predict_pipeline/config.yaml       |  1 +
 predict_pipeline/predict_sample.py | 14 ++++----------
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/predict_pipeline/config.yaml b/predict_pipeline/config.yaml
index abeabe9..a2d2586 100644
--- a/predict_pipeline/config.yaml
+++ b/predict_pipeline/config.yaml
@@ -8,6 +8,7 @@ model:
 
 scaler:
   use_scaling: False
+  path: "C:"
 
 mqtt:
   enabled: true
diff --git a/predict_pipeline/predict_sample.py b/predict_pipeline/predict_sample.py
index 3826ea3..37411ee 100644
--- a/predict_pipeline/predict_sample.py
+++ b/predict_pipeline/predict_sample.py
@@ -9,7 +9,6 @@ import pickle
 #sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
 sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
 import db_helpers
-import tensorflow as tf
 import joblib
 
 _MODEL_CACHE = {}
@@ -55,6 +54,7 @@ def callModel(sample, model_path):
     elif suffix == ".joblib":
         model = joblib.load(model_path)
     elif suffix == ".keras":
+        import tensorflow as tf
         model = tf.keras.models.load_model(model_path)
     else:
         raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
@@ -83,9 +83,6 @@ def callModel(sample, model_path):
         return prediction.item()
     return prediction.squeeze()
-
-
-
 
 def buildMessage(valid, result: np.int32, config_file_path, sample=None):
     with Path(config_file_path).open("r", encoding="utf-8") as f:
         cfg = yaml.safe_load(f)
@@ -106,7 +103,6 @@ def buildMessage(valid, result: np.int32, config_file_path, sample=None):
     }
     return message
-
 
 def sendMessage(config_file_path, message):
     with Path(config_file_path).open("r", encoding="utf-8") as f:
         cfg = yaml.safe_load(f)
@@ -186,11 +182,13 @@ def main():
 
     sample = getLastEntryFromSQLite(database_path, table_name, row_key)
     valid, sample = replace_nan(sample, config_file_path=config_file_path)
+
     if not valid:
         print("Sample invalid: more than 50% NaN.")
         message = buildMessage(valid, None, config_file_path, sample=sample)
         sendMessage(config_file_path, message)
         return
+
     model_path = cfg["model"]["path"]
     sample_np = sample_to_numpy(sample)
     sample_np = scale_sample(sample_np, use_scaling=use_scaling)
@@ -203,9 +201,5 @@ def main():
 
 
 if __name__ == "__main__":
     main()
-# to do:
-# config file
-# bei sample holen outlier ersetzen
-# mediane abspeichern
-# falls nur nan, dann sample verwerfen
+# https://www.youtube.com/watch?v=Q09tWwz6WoI