adjusted paths (this is the deployment setting)
This commit is contained in:
parent
2b01085a9e
commit
4eab3c9876
@ -6,8 +6,8 @@ import numpy as np
|
||||
import sys
|
||||
import yaml
|
||||
import pickle
|
||||
#sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
||||
sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
||||
sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
||||
# sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
||||
import db_helpers
|
||||
import joblib
|
||||
|
||||
@ -52,11 +52,11 @@ def callModel(sample, model_path):
|
||||
suffix = model_path.suffix.lower()
|
||||
if suffix in {".pkl", ".joblib"}:
|
||||
model = _load_serialized(model_path)
|
||||
elif suffix == ".keras":
|
||||
import tensorflow as tf
|
||||
model = tf.keras.models.load_model(model_path)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
||||
# elif suffix == ".keras":
|
||||
# import tensorflow as tf
|
||||
# model = tf.keras.models.load_model(model_path)
|
||||
# else:
|
||||
# raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
||||
|
||||
x = np.asarray(sample, dtype=np.float32)
|
||||
if x.ndim == 1:
|
||||
@ -101,13 +101,29 @@ def buildMessage(valid, result: np.int32, config_file_path, sample=None):
|
||||
}
|
||||
return message
|
||||
|
||||
def convert_int64(obj):
|
||||
if isinstance(obj, np.int64):
|
||||
return int(obj)
|
||||
# If the object is a dictionary or list, recursively convert its values
|
||||
elif isinstance(obj, dict):
|
||||
return {key: convert_int64(value) for key, value in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_int64(item) for item in obj]
|
||||
return obj
|
||||
|
||||
def sendMessage(config_file_path, message):
|
||||
# Load the configuration
|
||||
with Path(config_file_path).open("r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
# Get MQTT configuration
|
||||
mqtt_cfg = cfg.get("mqtt", {})
|
||||
topic = mqtt_cfg.get("topic", "ml/predictions")
|
||||
|
||||
# Convert message to ensure no np.int64 values remain
|
||||
message = convert_int64(message)
|
||||
|
||||
# Serialize the message to JSON
|
||||
payload = json.dumps(message, ensure_ascii=False)
|
||||
print(payload)
|
||||
|
||||
@ -197,7 +213,9 @@ def scale_sample(sample, use_scaling=False, scaler_path=None):
|
||||
return df.iloc[0] if isinstance(sample, pd.Series) else df
|
||||
|
||||
def main():
|
||||
config_file_path = Path("predict_pipeline/config.yaml")
|
||||
pd.set_option('future.no_silent_downcasting', True) # kann ggf raus
|
||||
|
||||
config_file_path = Path("/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/config.yaml")
|
||||
with config_file_path.open("r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user