small changes and lazy import of tensorflow
This commit is contained in:
parent
cf88f88814
commit
0088cef32a
@ -8,6 +8,7 @@ model:
|
|||||||
|
|
||||||
scaler:
|
scaler:
|
||||||
use_scaling: False
|
use_scaling: False
|
||||||
|
path = "C:"
|
||||||
|
|
||||||
mqtt:
|
mqtt:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|||||||
@ -9,7 +9,6 @@ import pickle
|
|||||||
#sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
#sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
||||||
sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
||||||
import db_helpers
|
import db_helpers
|
||||||
import tensorflow as tf
|
|
||||||
import joblib
|
import joblib
|
||||||
|
|
||||||
_MODEL_CACHE = {}
|
_MODEL_CACHE = {}
|
||||||
@ -55,6 +54,7 @@ def callModel(sample, model_path):
|
|||||||
elif suffix == ".joblib":
|
elif suffix == ".joblib":
|
||||||
model = joblib.load(model_path)
|
model = joblib.load(model_path)
|
||||||
elif suffix == ".keras":
|
elif suffix == ".keras":
|
||||||
|
import tensorflow as tf
|
||||||
model = tf.keras.models.load_model(model_path)
|
model = tf.keras.models.load_model(model_path)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
||||||
@ -83,9 +83,6 @@ def callModel(sample, model_path):
|
|||||||
return prediction.item()
|
return prediction.item()
|
||||||
return prediction.squeeze()
|
return prediction.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def buildMessage(valid, result: np.int32, config_file_path, sample=None):
|
def buildMessage(valid, result: np.int32, config_file_path, sample=None):
|
||||||
with Path(config_file_path).open("r", encoding="utf-8") as f:
|
with Path(config_file_path).open("r", encoding="utf-8") as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
@ -106,7 +103,6 @@ def buildMessage(valid, result: np.int32, config_file_path, sample=None):
|
|||||||
}
|
}
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def sendMessage(config_file_path, message):
|
def sendMessage(config_file_path, message):
|
||||||
with Path(config_file_path).open("r", encoding="utf-8") as f:
|
with Path(config_file_path).open("r", encoding="utf-8") as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
@ -186,11 +182,13 @@ def main():
|
|||||||
|
|
||||||
sample = getLastEntryFromSQLite(database_path, table_name, row_key)
|
sample = getLastEntryFromSQLite(database_path, table_name, row_key)
|
||||||
valid, sample = replace_nan(sample, config_file_path=config_file_path)
|
valid, sample = replace_nan(sample, config_file_path=config_file_path)
|
||||||
|
|
||||||
if not valid:
|
if not valid:
|
||||||
print("Sample invalid: more than 50% NaN.")
|
print("Sample invalid: more than 50% NaN.")
|
||||||
message = buildMessage(valid, None, config_file_path, sample=sample)
|
message = buildMessage(valid, None, config_file_path, sample=sample)
|
||||||
sendMessage(config_file_path, message)
|
sendMessage(config_file_path, message)
|
||||||
return
|
return
|
||||||
|
|
||||||
model_path = cfg["model"]["path"]
|
model_path = cfg["model"]["path"]
|
||||||
sample_np = sample_to_numpy(sample)
|
sample_np = sample_to_numpy(sample)
|
||||||
sample_np = scale_sample(sample_np, use_scaling=use_scaling)
|
sample_np = scale_sample(sample_np, use_scaling=use_scaling)
|
||||||
@ -203,9 +201,5 @@ def main():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
# to do:
|
|
||||||
# config file
|
|
||||||
# bei sample holen outlier ersetzen
|
|
||||||
# mediane abspeichern
|
|
||||||
# falls nur nan, dann sample verwerfen
|
|
||||||
# https://www.youtube.com/watch?v=Q09tWwz6WoI
|
# https://www.youtube.com/watch?v=Q09tWwz6WoI
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user