first version
This commit is contained in:
parent
3d86bfe6d0
commit
2a014e1e4e
@ -3,41 +3,209 @@ import pandas as pd
|
||||
import json
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import sys
|
||||
import yaml
|
||||
import pickle
|
||||
#sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
||||
sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
||||
import db_helpers
|
||||
import tensorflow as tf
|
||||
import joblib
|
||||
|
||||
# Process-wide cache of loaded models, keyed by resolved model path.
_MODEL_CACHE = {}


def getLastEntryFromSQLite(path, table_name, key="_Id"):
    """Fetch the most recent row from an SQLite table as a pandas Series.

    Opens a connection via ``db_helpers``, selects the single row with the
    highest *key* value, and always closes the connection again.

    Parameters
    ----------
    path : str | Path
        Path to the SQLite database file.
    table_name : str
        Table to read from.
    key : str, optional
        Column used for ordering; the row with the largest value is
        returned (default ``"_Id"``).

    Returns
    -------
    pd.Series
        The newest row, or an empty object-dtype Series when the table
        has no rows.
    """
    conn, cursor = db_helpers.connect_db(path)
    try:
        row_df = db_helpers.get_data_from_table(
            conn=conn,
            table_name=table_name,
            order_by={key: "DESC"},
            limit=1,
        )
    finally:
        # Read-only access: no commit needed, but the connection must
        # always be released, even if the query raises.
        db_helpers.disconnect_db(conn, cursor, commit=False)

    if row_df.empty:
        return pd.Series(dtype="object")

    return row_df.iloc[0]
|
||||
def callModel(sample, model_path, n_features=20):
    """Load (and cache) a model from *model_path* and predict on *sample*.

    Supported formats: ``.pkl`` (pickle), ``.joblib`` (joblib) and
    ``.keras`` (TensorFlow/Keras). Loaded models are memoised in a
    module-level cache keyed by the resolved absolute path, so repeated
    calls do not reload from disk.

    Parameters
    ----------
    sample : array-like
        Feature vector (1-D) or batch (2-D); converted to float32.
    model_path : str | Path
        Path to the serialised model; relative paths resolve against CWD.
    n_features : int, optional
        Number of leading feature columns fed to the model. Default 20
        for the current model; pass 35 when the future 35-feature model
        becomes active.

    Returns
    -------
    A Python scalar for a single prediction, otherwise a squeezed ndarray.

    Raises
    ------
    TypeError
        If *sample* is a callable, or the loaded model has no
        ``.predict`` and is not callable.
    ValueError
        For unsupported model file extensions.
    """
    if callable(sample):
        raise TypeError(
            f"Invalid sample type: got callable `{getattr(sample, '__name__', type(sample).__name__)}`. "
            "Expected numpy array / pandas row."
        )

    # Resolve to an absolute path so the cache key is unambiguous.
    model_path = Path(model_path)
    if not model_path.is_absolute():
        model_path = Path.cwd() / model_path
    model_path = model_path.resolve()

    suffix = model_path.suffix.lower()
    cache_key = str(model_path)

    # Lazily create the module-level cache so this function also works
    # if the module-level `_MODEL_CACHE = {}` has not been defined yet.
    cache = globals().setdefault("_MODEL_CACHE", {})
    if cache_key in cache:
        model = cache[cache_key]
    else:
        if suffix == ".pkl":
            # SECURITY NOTE: pickle.load executes arbitrary code on
            # load — only ever load trusted model files.
            with model_path.open("rb") as f:
                model = pickle.load(f)
        elif suffix == ".joblib":
            model = joblib.load(model_path)
        elif suffix == ".keras":
            model = tf.keras.models.load_model(model_path)
        else:
            raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
        cache[cache_key] = model

    x = np.asarray(sample, dtype=np.float32)
    if x.ndim == 1:
        # Single sample -> batch of one.
        x = x.reshape(1, -1)

    # The model consumes only the leading n_features columns.
    x = x[:, :n_features]

    if suffix == ".keras":
        prediction = model.predict(x, verbose=0)
    else:
        if hasattr(model, "predict"):
            prediction = model.predict(x)
        elif callable(model):
            prediction = model(x)
        else:
            raise TypeError("Loaded model has no .predict(...) and is not callable.")

    prediction = np.asarray(prediction)
    if prediction.size == 1:
        return prediction.item()
    return prediction.squeeze()
|
||||
|
||||
|
||||
def buildMessage(valid, result, config_file_path, sample=None):
    """Assemble the outgoing prediction message as a JSON-serialisable dict.

    Reads the MQTT publish format from the YAML config file, extracts
    the sample id (``_Id``/``_id``) when *sample* is a Series or dict,
    and packages validity flag, id and prediction result together.

    Parameters
    ----------
    valid : bool-like
        Whether the sample passed NaN validation.
    result : np.ndarray | np.generic | object | None
        The prediction value; None when the sample was invalid.
    config_file_path : str | Path
        Path to the YAML config file (``mqtt.publish_format.result_key``).
    sample : pd.Series | dict, optional
        The sample the prediction was made from, used only for its id.

    Returns
    -------
    dict
        Keys ``"valid"``, ``"_id"`` and the configured result key.
    """
    with Path(config_file_path).open("r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    mqtt_cfg = cfg.get("mqtt", {})
    result_key = mqtt_cfg.get("publish_format", {}).get("result_key", "prediction")

    sample_id = None
    if isinstance(sample, (pd.Series, dict)):
        sample_id = sample.get("_Id", sample.get("_id"))

    # Convert numpy arrays AND numpy scalars (e.g. np.int32) to plain
    # Python values; otherwise json.dumps() in sendMessage() would raise
    # TypeError on the numpy type.
    if isinstance(result, (np.ndarray, np.generic)):
        result_value = np.asarray(result).tolist()
    else:
        result_value = result

    message = {
        "valid": bool(valid),
        "_id": sample_id,
        result_key: result_value,
    }
    return message
|
||||
|
||||
|
||||
def sendMessage(config_file_path, message):
    """Serialise *message* as JSON and emit it.

    Currently the payload is only printed; the MQTT publish path is kept
    below as commented-out reference code until the broker integration
    goes live. The topic is already resolved from the config so the
    commented example is complete.

    Parameters
    ----------
    config_file_path : str | Path
        Path to the YAML config file (``mqtt`` section).
    message : dict
        JSON-serialisable message built by buildMessage().
    """
    with Path(config_file_path).open("r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    mqtt_cfg = cfg.get("mqtt", {})
    topic = mqtt_cfg.get("topic", "ml/predictions")

    payload = json.dumps(message, ensure_ascii=False)
    print(payload)

    # Later: publish via MQTT using config parameters above.
    # Example (kept commented intentionally):
    # import paho.mqtt.client as mqtt
    # client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01"))
    # if "username" in mqtt_cfg and mqtt_cfg.get("username"):
    #     client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
    # client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60)
    # client.publish(
    #     topic=topic,
    #     payload=payload,
    #     qos=int(mqtt_cfg.get("qos", 1)),
    #     retain=bool(mqtt_cfg.get("retain", False)),
    # )
    # client.disconnect()
    return
|
||||
|
||||
def replace_nan(sample, config_file_path: Path):
    """Validate *sample* and fill its NaNs from the configured fallbacks.

    Loads the ``fallback`` list from the YAML config (a list of one-entry
    dicts mapping column name to fallback value), then rejects samples
    that are empty or more than 50% NaN. Valid samples get their NaNs
    replaced with the fallback values.

    Parameters
    ----------
    sample : pd.Series
        The sample row to validate and repair.
    config_file_path : Path
        Path to the YAML config file.

    Returns
    -------
    tuple[bool-like, pd.Series]
        ``(valid, sample)`` — the (possibly NaN-filled) sample and
        whether it is usable.
    """
    with config_file_path.open("r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Flatten the list-of-dicts config shape into one lookup map.
    fallback_map = {}
    for entry in cfg.get("fallback", []):
        if isinstance(entry, dict):
            fallback_map.update(entry)

    if sample.empty:
        return False, sample

    # Reject the sample when more than half of its values are missing.
    valid = sample.isna().mean() <= 0.5

    if valid and fallback_map:
        sample = sample.fillna(value=fallback_map)

    return valid, sample
|
||||
|
||||
def sample_to_numpy(sample, drop_cols=("_Id", "start_time")):
    """Convert a sample to a plain numpy array, dropping metadata columns.

    Parameters
    ----------
    sample : pd.Series | pd.DataFrame | array-like
        The sample to convert.
    drop_cols : tuple[str, ...], optional
        Metadata labels/columns removed before conversion; missing
        labels are ignored (default ``("_Id", "start_time")``).

    Returns
    -------
    np.ndarray
        Feature values only.
    """
    labels = list(drop_cols)
    if isinstance(sample, pd.Series):
        return sample.drop(labels=labels, errors="ignore").to_numpy()
    if isinstance(sample, pd.DataFrame):
        return sample.drop(columns=labels, errors="ignore").to_numpy()
    # Anything else (list, tuple, ndarray, ...) passes straight through.
    return np.asarray(sample)
|
||||
|
||||
def scale_sample(sample, use_scaling=False):
    """Optionally scale the sample features.

    Currently a pass-through placeholder: the scaling branch is not yet
    implemented, so the sample is returned unchanged either way.

    Parameters
    ----------
    sample : np.ndarray
        Feature array to (eventually) normalise.
    use_scaling : bool, optional
        Whether scaling should be applied (default False).

    Returns
    -------
    The sample, unmodified.
    """
    if use_scaling:
        # TODO: load the fitted scaler and transform the sample here.
        return sample
    return sample
|
||||
|
||||
def main():
    """Run one prediction cycle: fetch sample, validate, predict, publish.

    Flow: load YAML config -> fetch the newest DB row -> validate/fill
    NaNs -> on invalid samples publish a "valid: false" message and stop;
    otherwise convert to numpy, (optionally) scale, predict, and publish
    the result message.
    """
    config_file_path = Path("predict_pipeline/config.yaml")
    with config_file_path.open("r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # Database access parameters.
    database_path = cfg["database"]["path"]
    table_name = cfg["database"]["table"]
    row_key = cfg["database"]["key"]
    # Accept both config spellings ("use_scaling" preferred, "use_scaler"
    # as legacy fallback).
    scaler_cfg = cfg.get("scaler", {})
    use_scaling = scaler_cfg.get("use_scaling", scaler_cfg.get("use_scaler", False))

    sample = getLastEntryFromSQLite(database_path, table_name, row_key)
    valid, sample = replace_nan(sample, config_file_path=config_file_path)
    if not valid:
        # Publish a "valid: false" message without a prediction and stop.
        print("Sample invalid: more than 50% NaN.")
        message = buildMessage(valid, None, config_file_path, sample=sample)
        sendMessage(config_file_path, message)
        return

    model_path = cfg["model"]["path"]
    sample_np = sample_to_numpy(sample)
    sample_np = scale_sample(sample_np, use_scaling=use_scaling)
    prediction = callModel(model_path=model_path, sample=sample_np)

    message = buildMessage(valid, prediction, config_file_path, sample=sample)
    sendMessage(config_file_path, message)


if __name__ == "__main__":
    main()
|
||||
|
||||
# TODO:
# - config file
# - replace outliers when fetching the sample
# - persist the medians
# - discard the sample if it contains only NaN values
# - https://www.youtube.com/watch?v=Q09tWwz6WoI
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user