mini changes in predict pipeline
This commit is contained in:
parent
13bd76631f
commit
3169c29319
@ -1,20 +1,20 @@
|
|||||||
database:
|
database:
|
||||||
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\database.sqlite"
|
path: "/home/edgekit/MSY_FS/databases/database.sqlite"
|
||||||
table: feature_table
|
table: feature_table
|
||||||
key: _Id
|
key: _Id
|
||||||
|
|
||||||
model:
|
model:
|
||||||
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\files_for_testing\\xgb_model_3_groupK.joblib"
|
path: "/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/xgb_model_3_groupK.joblib"
|
||||||
|
|
||||||
scaler:
|
scaler:
|
||||||
use_scaling: True
|
use_scaling: true
|
||||||
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\normalizer_min_max_global.pkl"
|
path: "/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/normalizer_min_max_global.pkl"
|
||||||
|
|
||||||
mqtt:
|
mqtt:
|
||||||
enabled: true
|
enabled: true
|
||||||
host: "141.75.215.233"
|
host: "141.75.215.233"
|
||||||
port: 1883
|
port: 1883
|
||||||
topic: "PREDICTIONS"
|
topic: "PREDICTION"
|
||||||
client_id: "jetson-board"
|
client_id: "jetson-board"
|
||||||
qos: 0
|
qos: 0
|
||||||
retain: false
|
retain: false
|
||||||
@ -107,4 +107,4 @@ fallback:
|
|||||||
Blink_mean_dur: 0.38857142857142857
|
Blink_mean_dur: 0.38857142857142857
|
||||||
Blink_median_dur: 0.2
|
Blink_median_dur: 0.2
|
||||||
Pupil_mean: 3.2823675201416016
|
Pupil_mean: 3.2823675201416016
|
||||||
Pupil_IPA: 0.0036347377340156025
|
Pupil_IPA: 0.0036347377340156025
|
||||||
@ -7,9 +7,10 @@ import sys
|
|||||||
import yaml
|
import yaml
|
||||||
import pickle
|
import pickle
|
||||||
sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
||||||
# sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
|
||||||
import db_helpers
|
import db_helpers
|
||||||
import joblib
|
import joblib
|
||||||
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
def _load_serialized(path: Path):
|
def _load_serialized(path: Path):
|
||||||
suffix = path.suffix.lower()
|
suffix = path.suffix.lower()
|
||||||
@ -52,11 +53,11 @@ def callModel(sample, model_path):
|
|||||||
suffix = model_path.suffix.lower()
|
suffix = model_path.suffix.lower()
|
||||||
if suffix in {".pkl", ".joblib"}:
|
if suffix in {".pkl", ".joblib"}:
|
||||||
model = _load_serialized(model_path)
|
model = _load_serialized(model_path)
|
||||||
# elif suffix == ".keras":
|
elif suffix == ".keras":
|
||||||
# import tensorflow as tf
|
import tensorflow as tf
|
||||||
# model = tf.keras.models.load_model(model_path)
|
model = tf.keras.models.load_model(model_path)
|
||||||
# else:
|
else:
|
||||||
# raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
||||||
|
|
||||||
x = np.asarray(sample, dtype=np.float32)
|
x = np.asarray(sample, dtype=np.float32)
|
||||||
if x.ndim == 1:
|
if x.ndim == 1:
|
||||||
@ -125,44 +126,37 @@ def sendMessage(config_file_path, message):
|
|||||||
|
|
||||||
# Serialize the message to JSON
|
# Serialize the message to JSON
|
||||||
payload = json.dumps(message, ensure_ascii=False)
|
payload = json.dumps(message, ensure_ascii=False)
|
||||||
print(payload)
|
print(payload) # for debugging purposes
|
||||||
|
|
||||||
# Later: publish via MQTT using config parameters above.
|
|
||||||
# Example (kept commented intentionally):
|
client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01"))
|
||||||
# import paho.mqtt.client as mqtt
|
if "username" in mqtt_cfg and mqtt_cfg.get("username"):
|
||||||
# client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01"))
|
client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
|
||||||
# if "username" in mqtt_cfg and mqtt_cfg.get("username"):
|
client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60)
|
||||||
# client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
|
client.publish(
|
||||||
# client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60)
|
topic=topic,
|
||||||
# client.publish(
|
payload=payload,
|
||||||
# topic=topic,
|
qos=int(mqtt_cfg.get("qos", 1)),
|
||||||
# payload=payload,
|
retain=bool(mqtt_cfg.get("retain", False)),
|
||||||
# qos=int(mqtt_cfg.get("qos", 1)),
|
)
|
||||||
# retain=bool(mqtt_cfg.get("retain", False)),
|
client.disconnect()
|
||||||
# )
|
|
||||||
# client.disconnect()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def replace_nan(sample, config_file_path: Path):
|
def replace_nan(sample, config_file_path: Path):
|
||||||
with config_file_path.open("r", encoding="utf-8") as f:
|
with config_file_path.open("r", encoding="utf-8") as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
fallback_list = cfg.get("fallback", [])
|
fallback_map = cfg.get("fallback", {})
|
||||||
fallback_map = {}
|
|
||||||
for item in fallback_list:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
fallback_map.update(item)
|
|
||||||
|
|
||||||
if sample.empty:
|
if sample.empty:
|
||||||
return False, sample
|
return False, sample
|
||||||
|
|
||||||
nan_ratio = sample.isna().mean()
|
nan_ratio = sample.isna().mean()
|
||||||
valid = nan_ratio <= 0.5
|
valid = nan_ratio <= 0.5
|
||||||
|
|
||||||
if valid and fallback_map:
|
if valid and fallback_map:
|
||||||
sample = sample.fillna(value=fallback_map)
|
sample = sample.fillna(value=fallback_map)
|
||||||
|
|
||||||
|
|
||||||
return valid, sample
|
return valid, sample
|
||||||
|
|
||||||
def sample_to_numpy(sample, drop_cols=("_Id", "start_time")):
|
def sample_to_numpy(sample, drop_cols=("_Id", "start_time")):
|
||||||
@ -213,7 +207,7 @@ def scale_sample(sample, use_scaling=False, scaler_path=None):
|
|||||||
return df.iloc[0] if isinstance(sample, pd.Series) else df
|
return df.iloc[0] if isinstance(sample, pd.Series) else df
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
pd.set_option('future.no_silent_downcasting', True) # kann ggf raus
|
pd.set_option('future.no_silent_downcasting', True)
|
||||||
|
|
||||||
config_file_path = Path("/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/config.yaml")
|
config_file_path = Path("/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/config.yaml")
|
||||||
with config_file_path.open("r", encoding="utf-8") as f:
|
with config_file_path.open("r", encoding="utf-8") as f:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user