merge of deployment into main
This commit is contained in:
commit
8f24adbdbd
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,4 +3,5 @@
|
|||||||
!*.py
|
!*.py
|
||||||
!*.ipynb
|
!*.ipynb
|
||||||
!*.md
|
!*.md
|
||||||
|
!*.parquet
|
||||||
!.gitignore
|
!.gitignore
|
||||||
|
|||||||
BIN
files_for_testing/50s_25Hz_dataset.parquet
Normal file
BIN
files_for_testing/50s_25Hz_dataset.parquet
Normal file
Binary file not shown.
11
predict_pipeline/check_python_version.py
Normal file
11
predict_pipeline/check_python_version.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# from tools import db_helpers
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print(sys.version)
|
||||||
|
# db_helpers.add_columns_to_table()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
117
predict_pipeline/config.yaml
Normal file
117
predict_pipeline/config.yaml
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
database:
|
||||||
|
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\database.sqlite"
|
||||||
|
table: feature_table
|
||||||
|
key: _Id
|
||||||
|
|
||||||
|
model:
|
||||||
|
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\files_for_testing\\xgb_model_3_groupK.joblib"
|
||||||
|
|
||||||
|
scaler:
|
||||||
|
use_scaling: True
|
||||||
|
path: "C:\\repo\\Fahrsimulator_MSY2526_AI\\predict_pipeline\\normalizer_min_max_global.pkl"
|
||||||
|
|
||||||
|
mqtt:
|
||||||
|
enabled: true
|
||||||
|
host: "localhost"
|
||||||
|
port: 1883
|
||||||
|
topic: "ml/predictions"
|
||||||
|
client_id: "predictor-01"
|
||||||
|
qos: 1
|
||||||
|
retain: false
|
||||||
|
# username: ""
|
||||||
|
# password: ""
|
||||||
|
tls:
|
||||||
|
enabled: false
|
||||||
|
# ca_cert: ""
|
||||||
|
# client_cert: ""
|
||||||
|
# client_key: ""
|
||||||
|
publish_format:
|
||||||
|
result_key: prediction # where to store the predicted value in payload
|
||||||
|
include_metadata: true # e.g., timestamps, rowid, etc.
|
||||||
|
|
||||||
|
sample:
|
||||||
|
columns:
|
||||||
|
- _Id
|
||||||
|
- start_time
|
||||||
|
- FACE_AU01_mean
|
||||||
|
- FACE_AU02_mean
|
||||||
|
- FACE_AU04_mean
|
||||||
|
- FACE_AU05_mean
|
||||||
|
- FACE_AU06_mean
|
||||||
|
- FACE_AU07_mean
|
||||||
|
- FACE_AU09_mean
|
||||||
|
- FACE_AU10_mean
|
||||||
|
- FACE_AU11_mean
|
||||||
|
- FACE_AU12_mean
|
||||||
|
- FACE_AU14_mean
|
||||||
|
- FACE_AU15_mean
|
||||||
|
- FACE_AU17_mean
|
||||||
|
- FACE_AU20_mean
|
||||||
|
- FACE_AU23_mean
|
||||||
|
- FACE_AU24_mean
|
||||||
|
- FACE_AU25_mean
|
||||||
|
- FACE_AU26_mean
|
||||||
|
- FACE_AU28_mean
|
||||||
|
- FACE_AU43_mean
|
||||||
|
- Fix_count_short_66_150
|
||||||
|
- Fix_count_medium_300_500
|
||||||
|
- Fix_count_long_gt_1000
|
||||||
|
- Fix_count_100
|
||||||
|
- Fix_mean_duration
|
||||||
|
- Fix_median_duration
|
||||||
|
- Sac_count
|
||||||
|
- Sac_mean_amp
|
||||||
|
- Sac_mean_dur
|
||||||
|
- Sac_median_dur
|
||||||
|
- Blink_count
|
||||||
|
- Blink_mean_dur
|
||||||
|
- Blink_median_dur
|
||||||
|
- Pupil_mean
|
||||||
|
- Pupil_IPA
|
||||||
|
|
||||||
|
fill_nan_with_median: true
|
||||||
|
discard_if_all_nan: true
|
||||||
|
|
||||||
|
fallback:
|
||||||
|
- start_time: 0
|
||||||
|
- FACE_AU01_mean: 0.5
|
||||||
|
- FACE_AU02_mean: 0.5
|
||||||
|
- FACE_AU04_mean: 0.5
|
||||||
|
- FACE_AU05_mean: 0.5
|
||||||
|
- FACE_AU06_mean: 0.5
|
||||||
|
- FACE_AU07_mean: 0.5
|
||||||
|
- FACE_AU09_mean: 0.5
|
||||||
|
- FACE_AU10_mean: 0.5
|
||||||
|
- FACE_AU11_mean: 0.5
|
||||||
|
- FACE_AU12_mean: 0.5
|
||||||
|
- FACE_AU14_mean: 0.5
|
||||||
|
- FACE_AU15_mean: 0.5
|
||||||
|
- FACE_AU17_mean: 0.5
|
||||||
|
- FACE_AU20_mean: 0.5
|
||||||
|
- FACE_AU23_mean: 0.5
|
||||||
|
- FACE_AU24_mean: 0.5
|
||||||
|
- FACE_AU25_mean: 0.5
|
||||||
|
- FACE_AU26_mean: 0.5
|
||||||
|
- FACE_AU28_mean: 0.5
|
||||||
|
- FACE_AU43_mean: 0.5
|
||||||
|
- Fix_count_short_66_150: 2
|
||||||
|
- Fix_count_medium_300_500: 2
|
||||||
|
- Fix_count_long_gt_1000: 2
|
||||||
|
- Fix_count_100: 2
|
||||||
|
- Fix_mean_duration: 100
|
||||||
|
- Fix_median_duration: 100
|
||||||
|
- Sac_count: 2
|
||||||
|
- Sac_mean_amp: 2
|
||||||
|
- Sac_mean_dur: 100
|
||||||
|
- Sac_median_dur: 100
|
||||||
|
- Blink_count: 2
|
||||||
|
- Blink_mean_dur: 2
|
||||||
|
- Blink_median_dur: 2
|
||||||
|
- Pupil_mean: 2
|
||||||
|
- Pupil_IPA: 2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
9
predict_pipeline/feature_extraction.py
Normal file
9
predict_pipeline/feature_extraction.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
211
predict_pipeline/fill_db.ipynb
Normal file
211
predict_pipeline/fill_db.ipynb
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0d70a13f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import sys\n",
|
||||||
|
"sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"import db_helpers"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ce696366",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"database_path = Path(r\"/home/edgekit/MSY_FS/databases/rawdata.sqlite\")\n",
|
||||||
|
"parquet_path = Path(r\"/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/files_for_testing/both_mod_0000.parquet\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b1aa9398",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset = pd.read_parquet(parquet_path)\n",
|
||||||
|
"dataset.head()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b183746e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dataset.dtypes"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "24ed769d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"con, cursor = db_helpers.connect_db(database_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e604ed30",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"df_clean = dataset.drop(columns=['subjectID','rowID', 'STUDY', 'LEVEL', 'PHASE'])\n",
|
||||||
|
"df_first_100 = df_clean.head(200)\n",
|
||||||
|
"df_first_100 = df_first_100.reset_index(drop=True)\n",
|
||||||
|
"df_first_100.insert(0, '_Id', df_first_100.index + 1)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e77a812e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def pandas_to_sqlite_dtype(dtype):\n",
|
||||||
|
" if pd.api.types.is_integer_dtype(dtype):\n",
|
||||||
|
" return \"INTEGER\"\n",
|
||||||
|
" if pd.api.types.is_float_dtype(dtype):\n",
|
||||||
|
" return \"REAL\"\n",
|
||||||
|
" if pd.api.types.is_bool_dtype(dtype):\n",
|
||||||
|
" return \"INTEGER\"\n",
|
||||||
|
" if pd.api.types.is_datetime64_any_dtype(dtype):\n",
|
||||||
|
" return \"TEXT\"\n",
|
||||||
|
" return \"TEXT\"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0e8897b2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"columns = {\n",
|
||||||
|
" col: pandas_to_sqlite_dtype(dtype)\n",
|
||||||
|
" for col, dtype in df_first_100.dtypes.items()\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"constraints = {\n",
|
||||||
|
" \"_Id\": [\"NOT NULL\"]\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"primary_key = {\n",
|
||||||
|
" \"pk_df_first_100\": [\"_Id\"]\n",
|
||||||
|
"}\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "4ab57624",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"sql = db_helpers.create_table(\n",
|
||||||
|
" conn=con,\n",
|
||||||
|
" cursor=cursor,\n",
|
||||||
|
" table_name=\"rawdata\",\n",
|
||||||
|
" columns=columns,\n",
|
||||||
|
" constraints=constraints,\n",
|
||||||
|
" primary_key=primary_key,\n",
|
||||||
|
" commit=True\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "25096a7f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"columns_to_insert = {\n",
|
||||||
|
" col: df_first_100[col].tolist()\n",
|
||||||
|
" for col in df_first_100.columns\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7a5a3aa8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"db_helpers.insert_rows_into_table(\n",
|
||||||
|
" conn=con,\n",
|
||||||
|
" cursor=cursor,\n",
|
||||||
|
" table_name=\"rawdata\",\n",
|
||||||
|
" columns=columns_to_insert,\n",
|
||||||
|
" commit=True\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "b56beae2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"a = db_helpers.get_data_from_table(conn=con, table_name='rawdata',columns_list=['*'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a4a74a9d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"a.head()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "da0f8737",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"db_helpers.disconnect_db(con, cursor)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "MSY_FS_env",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
253
predict_pipeline/predict_sample.py
Normal file
253
predict_pipeline/predict_sample.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
# Imports
|
||||||
|
import pandas as pd
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
import yaml
|
||||||
|
import pickle
|
||||||
|
sys.path.append('/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/tools')
|
||||||
|
# sys.path.append(r"c:\\repo\\Fahrsimulator_MSY2526_AI\\tools")
|
||||||
|
import db_helpers
|
||||||
|
import joblib
|
||||||
|
|
||||||
|
def _load_serialized(path: Path):
|
||||||
|
suffix = path.suffix.lower()
|
||||||
|
if suffix == ".pkl":
|
||||||
|
with path.open("rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
if suffix == ".joblib":
|
||||||
|
return joblib.load(path)
|
||||||
|
raise ValueError(f"Unsupported file format: {suffix}. Use .pkl or .joblib.")
|
||||||
|
|
||||||
|
def getLastEntryFromSQLite(path, table_name, key="_Id"):
|
||||||
|
conn, cursor = db_helpers.connect_db(path)
|
||||||
|
try:
|
||||||
|
row_df = db_helpers.get_data_from_table(
|
||||||
|
conn=conn,
|
||||||
|
table_name=table_name,
|
||||||
|
order_by={key: "DESC"},
|
||||||
|
limit=1,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db_helpers.disconnect_db(conn, cursor, commit=False)
|
||||||
|
|
||||||
|
if row_df.empty:
|
||||||
|
return pd.Series(dtype="object")
|
||||||
|
|
||||||
|
return row_df.iloc[0]
|
||||||
|
|
||||||
|
def callModel(sample, model_path):
|
||||||
|
if callable(sample):
|
||||||
|
raise TypeError(
|
||||||
|
f"Invalid sample type: got callable `{getattr(sample, '__name__', type(sample).__name__)}`. "
|
||||||
|
"Expected numpy array / pandas row."
|
||||||
|
)
|
||||||
|
|
||||||
|
model_path = Path(model_path)
|
||||||
|
if not model_path.is_absolute():
|
||||||
|
model_path = Path.cwd() / model_path
|
||||||
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
|
suffix = model_path.suffix.lower()
|
||||||
|
if suffix in {".pkl", ".joblib"}:
|
||||||
|
model = _load_serialized(model_path)
|
||||||
|
# elif suffix == ".keras":
|
||||||
|
# import tensorflow as tf
|
||||||
|
# model = tf.keras.models.load_model(model_path)
|
||||||
|
# else:
|
||||||
|
# raise ValueError(f"Unsupported model format: {suffix}. Use .pkl, .joblib, or .keras.")
|
||||||
|
|
||||||
|
x = np.asarray(sample, dtype=np.float32)
|
||||||
|
if x.ndim == 1:
|
||||||
|
x = x.reshape(1, -1)
|
||||||
|
|
||||||
|
if suffix == ".keras":
|
||||||
|
x_full = x
|
||||||
|
# Future model (35 features): keep this call when your new model is active.
|
||||||
|
# prediction = model.predict(x_full[:, :35], verbose=0)
|
||||||
|
prediction = model.predict(x_full[:, :20], verbose=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if hasattr(model, "predict"):
|
||||||
|
prediction = model.predict(x[:,:20])
|
||||||
|
elif callable(model):
|
||||||
|
prediction = model(x[:,:20])
|
||||||
|
else:
|
||||||
|
raise TypeError("Loaded model has no .predict(...) and is not callable.")
|
||||||
|
|
||||||
|
prediction = np.asarray(prediction)
|
||||||
|
if prediction.size == 1:
|
||||||
|
return prediction.item()
|
||||||
|
return prediction.squeeze()
|
||||||
|
|
||||||
|
def buildMessage(valid, result: np.int32, config_file_path, sample=None):
|
||||||
|
with Path(config_file_path).open("r", encoding="utf-8") as f:
|
||||||
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
|
mqtt_cfg = cfg.get("mqtt", {})
|
||||||
|
result_key = mqtt_cfg.get("publish_format", {}).get("result_key", "prediction")
|
||||||
|
|
||||||
|
sample_id = None
|
||||||
|
if isinstance(sample, pd.Series):
|
||||||
|
sample_id = sample.get("_Id", sample.get("_id"))
|
||||||
|
elif isinstance(sample, dict):
|
||||||
|
sample_id = sample.get("_Id", sample.get("_id"))
|
||||||
|
|
||||||
|
message = {
|
||||||
|
"valid": bool(valid),
|
||||||
|
"_id": sample_id,
|
||||||
|
result_key: np.asarray(result).tolist() if isinstance(result, np.ndarray) else result,
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
|
||||||
|
def convert_int64(obj):
|
||||||
|
if isinstance(obj, np.int64):
|
||||||
|
return int(obj)
|
||||||
|
# If the object is a dictionary or list, recursively convert its values
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {key: convert_int64(value) for key, value in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [convert_int64(item) for item in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def sendMessage(config_file_path, message):
|
||||||
|
# Load the configuration
|
||||||
|
with Path(config_file_path).open("r", encoding="utf-8") as f:
|
||||||
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# Get MQTT configuration
|
||||||
|
mqtt_cfg = cfg.get("mqtt", {})
|
||||||
|
topic = mqtt_cfg.get("topic", "ml/predictions")
|
||||||
|
|
||||||
|
# Convert message to ensure no np.int64 values remain
|
||||||
|
message = convert_int64(message)
|
||||||
|
|
||||||
|
# Serialize the message to JSON
|
||||||
|
payload = json.dumps(message, ensure_ascii=False)
|
||||||
|
print(payload)
|
||||||
|
|
||||||
|
# Later: publish via MQTT using config parameters above.
|
||||||
|
# Example (kept commented intentionally):
|
||||||
|
# import paho.mqtt.client as mqtt
|
||||||
|
# client = mqtt.Client(client_id=mqtt_cfg.get("client_id", "predictor-01"))
|
||||||
|
# if "username" in mqtt_cfg and mqtt_cfg.get("username"):
|
||||||
|
# client.username_pw_set(mqtt_cfg["username"], mqtt_cfg.get("password"))
|
||||||
|
# client.connect(mqtt_cfg.get("host", "localhost"), int(mqtt_cfg.get("port", 1883)), 60)
|
||||||
|
# client.publish(
|
||||||
|
# topic=topic,
|
||||||
|
# payload=payload,
|
||||||
|
# qos=int(mqtt_cfg.get("qos", 1)),
|
||||||
|
# retain=bool(mqtt_cfg.get("retain", False)),
|
||||||
|
# )
|
||||||
|
# client.disconnect()
|
||||||
|
return
|
||||||
|
|
||||||
|
def replace_nan(sample, config_file_path: Path):
|
||||||
|
with config_file_path.open("r", encoding="utf-8") as f:
|
||||||
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
|
fallback_list = cfg.get("fallback", [])
|
||||||
|
fallback_map = {}
|
||||||
|
for item in fallback_list:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
fallback_map.update(item)
|
||||||
|
|
||||||
|
if sample.empty:
|
||||||
|
return False, sample
|
||||||
|
|
||||||
|
nan_ratio = sample.isna().mean()
|
||||||
|
valid = nan_ratio <= 0.5
|
||||||
|
|
||||||
|
if valid and fallback_map:
|
||||||
|
sample = sample.fillna(value=fallback_map)
|
||||||
|
|
||||||
|
|
||||||
|
return valid, sample
|
||||||
|
|
||||||
|
def sample_to_numpy(sample, drop_cols=("_Id", "start_time")):
|
||||||
|
if isinstance(sample, pd.Series):
|
||||||
|
sample = sample.drop(labels=list(drop_cols), errors="ignore")
|
||||||
|
return sample.to_numpy()
|
||||||
|
|
||||||
|
if isinstance(sample, pd.DataFrame):
|
||||||
|
sample = sample.drop(columns=list(drop_cols), errors="ignore")
|
||||||
|
return sample.to_numpy()
|
||||||
|
|
||||||
|
return np.asarray(sample)
|
||||||
|
|
||||||
|
def scale_sample(sample, use_scaling=False, scaler_path=None):
|
||||||
|
if not use_scaling or scaler_path is None:
|
||||||
|
return sample
|
||||||
|
scaler_path = Path(scaler_path)
|
||||||
|
if not scaler_path.is_absolute():
|
||||||
|
scaler_path = Path.cwd() / scaler_path
|
||||||
|
scaler_path = scaler_path.resolve()
|
||||||
|
normalizer = _load_serialized(scaler_path)
|
||||||
|
|
||||||
|
# normalizer format from model_training/tools/scaler.py:
|
||||||
|
# {"scalers": {...}, "method": "...", "scope": "..."}
|
||||||
|
scalers = normalizer.get("scalers", {}) if isinstance(normalizer, dict) else {}
|
||||||
|
scope = normalizer.get("scope", "global") if isinstance(normalizer, dict) else "global"
|
||||||
|
if scope == "global":
|
||||||
|
scaler = scalers.get("global")
|
||||||
|
else:
|
||||||
|
scaler = scalers.get("global", next(iter(scalers.values()), None))
|
||||||
|
|
||||||
|
# Optional fallback if the stored object is already a raw scaler.
|
||||||
|
if scaler is None and hasattr(normalizer, "transform"):
|
||||||
|
scaler = normalizer
|
||||||
|
if scaler is None or not hasattr(scaler, "transform"):
|
||||||
|
return sample
|
||||||
|
|
||||||
|
df = sample.to_frame().T if isinstance(sample, pd.Series) else sample.copy()
|
||||||
|
feature_names = getattr(scaler, "feature_names_in_", None)
|
||||||
|
if feature_names is None:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# Keep columns not in the normalizer unchanged.
|
||||||
|
cols_to_scale = [c for c in df.columns if c in set(feature_names)]
|
||||||
|
if cols_to_scale:
|
||||||
|
df.loc[:, cols_to_scale] = scaler.transform(df.loc[:, cols_to_scale])
|
||||||
|
|
||||||
|
return df.iloc[0] if isinstance(sample, pd.Series) else df
|
||||||
|
|
||||||
|
def main():
|
||||||
|
pd.set_option('future.no_silent_downcasting', True) # kann ggf raus
|
||||||
|
|
||||||
|
config_file_path = Path("/home/edgekit/MSY_FS/fahrsimulator_msy2526_ai/predict_pipeline/config.yaml")
|
||||||
|
with config_file_path.open("r", encoding="utf-8") as f:
|
||||||
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
|
database_path = cfg["database"]["path"]
|
||||||
|
table_name = cfg["database"]["table"]
|
||||||
|
row_key = cfg["database"]["key"]
|
||||||
|
|
||||||
|
|
||||||
|
sample = getLastEntryFromSQLite(database_path, table_name, row_key)
|
||||||
|
valid, sample = replace_nan(sample, config_file_path=config_file_path)
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
print("Sample invalid: more than 50% NaN.")
|
||||||
|
message = buildMessage(valid, None, config_file_path, sample=sample)
|
||||||
|
sendMessage(config_file_path, message)
|
||||||
|
return
|
||||||
|
|
||||||
|
model_path = cfg["model"]["path"]
|
||||||
|
scaler_path = cfg["scaler"]["path"]
|
||||||
|
use_scaling = cfg["scaler"]["use_scaling"]
|
||||||
|
|
||||||
|
sample = scale_sample(sample, use_scaling=use_scaling, scaler_path=scaler_path)
|
||||||
|
sample_np = sample_to_numpy(sample)
|
||||||
|
|
||||||
|
prediction = callModel(model_path=model_path, sample=sample_np)
|
||||||
|
|
||||||
|
message = buildMessage(valid, prediction, config_file_path, sample=sample)
|
||||||
|
sendMessage(config_file_path, message)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
166
tools/db_helpers.py
Normal file
166
tools/db_helpers.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def connect_db(path_to_file: os.PathLike) -> tuple[sqlite3.Connection, sqlite3.Cursor]:
|
||||||
|
''' Establishes a connection with a sqlite3 database. '''
|
||||||
|
conn = sqlite3.connect(path_to_file)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
return conn, cursor
|
||||||
|
|
||||||
|
def disconnect_db(conn: sqlite3.Connection, cursor: sqlite3.Cursor, commit: bool = True) -> None:
|
||||||
|
''' Commits all remaining changes and closes the connection with an sqlite3 database. '''
|
||||||
|
cursor.close()
|
||||||
|
if commit: conn.commit() # commit all pending changes made to the sqlite3 database before closing
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def create_table(
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
cursor: sqlite3.Cursor,
|
||||||
|
table_name: str,
|
||||||
|
columns: dict,
|
||||||
|
constraints: dict,
|
||||||
|
primary_key: dict,
|
||||||
|
commit: bool = True
|
||||||
|
) -> str:
|
||||||
|
'''
|
||||||
|
Creates a new empty table with the given columns, constraints and primary key.
|
||||||
|
|
||||||
|
:param columns: dict with column names (=keys) and dtypes (=values) (e.g. BIGINT, INT, ...)
|
||||||
|
:param constraints: dict with column names (=keys) and list of constraints (=values) (like [\'NOT NULL\'(,...)])
|
||||||
|
:param primary_key: dict with primary key name (=key) and list of attributes which combined define the table's primary key (=values, like [\'att1\'(,...)])
|
||||||
|
'''
|
||||||
|
assert len(primary_key.keys()) == 1
|
||||||
|
sql = f'CREATE TABLE {table_name} (\n '
|
||||||
|
for column,dtype in columns.items():
|
||||||
|
sql += f'{column} {dtype}{" "+" ".join(constraints[column]) if column in constraints.keys() else ""},\n '
|
||||||
|
if list(primary_key.keys())[0]: sql += f'CONSTRAINT {list(primary_key.keys())[0]} '
|
||||||
|
sql += f'PRIMARY KEY ({", ".join(list(primary_key.values())[0])})\n)'
|
||||||
|
cursor.execute(sql)
|
||||||
|
if commit: conn.commit()
|
||||||
|
return sql
|
||||||
|
|
||||||
|
def add_columns_to_table(
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
cursor: sqlite3.Cursor,
|
||||||
|
table_name: str,
|
||||||
|
columns: dict,
|
||||||
|
constraints: dict = dict(),
|
||||||
|
commit: bool = True
|
||||||
|
) -> str:
|
||||||
|
''' Adds one/multiple columns (each with a list of constraints) to the given table. '''
|
||||||
|
sql_total = ''
|
||||||
|
for column,dtype in columns.items(): # sqlite can only add one column per query
|
||||||
|
sql = f'ALTER TABLE {table_name}\n '
|
||||||
|
sql += f'ADD "{column}" {dtype}{" "+" ".join(constraints[column]) if column in constraints.keys() else ""}'
|
||||||
|
sql_total += sql + '\n'
|
||||||
|
cursor.execute(sql)
|
||||||
|
if commit: conn.commit()
|
||||||
|
return sql_total
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def insert_rows_into_table(
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
cursor: sqlite3.Cursor,
|
||||||
|
table_name: str,
|
||||||
|
columns: dict,
|
||||||
|
commit: bool = True
|
||||||
|
) -> str:
|
||||||
|
'''
|
||||||
|
Inserts values as multiple rows into the given table.
|
||||||
|
|
||||||
|
:param columns: dict with column names (=keys) and values to insert as lists with at least one element (=values)
|
||||||
|
|
||||||
|
Note: The number of given values per attribute must match the number of rows to insert!
|
||||||
|
Note: The values for the rows must be of normal python types (e.g. list, str, int, ...) instead of e.g. numpy arrays!
|
||||||
|
'''
|
||||||
|
assert len(set(map(len, columns.values()))) == 1, 'ERROR: Provide equal number of values for each column!'
|
||||||
|
assert len(set(list(map(type,columns.values())))) == 1 and isinstance(list(columns.values())[0], list), 'ERROR: Provide values as Python lists!'
|
||||||
|
assert set([type(a) for b in list(columns.values()) for a in b]).issubset({str,int,float,bool}), 'ERROR: Provide values as basic Python data types!'
|
||||||
|
|
||||||
|
values = list(zip(*columns.values()))
|
||||||
|
sql = f'INSERT INTO {table_name} ({", ".join(columns.keys())})\n VALUES ({("?,"*len(values[0]))[:-1]})'
|
||||||
|
cursor.executemany(sql, values)
|
||||||
|
if commit: conn.commit()
|
||||||
|
return sql
|
||||||
|
|
||||||
|
def update_multiple_rows_in_table(
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
cursor: sqlite3.Cursor,
|
||||||
|
table_name: str,
|
||||||
|
new_vals: dict,
|
||||||
|
conditions: str,
|
||||||
|
commit: bool = True
|
||||||
|
) -> str:
|
||||||
|
'''
|
||||||
|
Updates attribute values of some rows in the given table.
|
||||||
|
|
||||||
|
:param new_vals: dict with column names (=keys) and the new values to set (=values)
|
||||||
|
:param conditions: string which defines all concatenated conditions (e.g. \'cond1 AND (cond2 OR cond3)\' with cond1: att1=5, ...)
|
||||||
|
'''
|
||||||
|
assignments = ', '.join([f'{k}={v}' for k,v in zip(new_vals.keys(), new_vals.values())])
|
||||||
|
sql = f'UPDATE {table_name}\n SET {assignments}\n WHERE {conditions}'
|
||||||
|
cursor.execute(sql)
|
||||||
|
if commit: conn.commit()
|
||||||
|
return sql
|
||||||
|
|
||||||
|
def delete_rows_from_table(
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
cursor: sqlite3.Cursor,
|
||||||
|
table_name: str,
|
||||||
|
conditions: str,
|
||||||
|
commit: bool = True
|
||||||
|
) -> str:
|
||||||
|
'''
|
||||||
|
Deletes rows from the given table.
|
||||||
|
|
||||||
|
:param conditions: string which defines all concatenated conditions (e.g. \'cond1 AND (cond2 OR cond3)\' with cond1: att1=5, ...)
|
||||||
|
'''
|
||||||
|
sql = f'DELETE FROM {table_name} WHERE {conditions}'
|
||||||
|
cursor.execute(sql)
|
||||||
|
if commit: conn.commit()
|
||||||
|
return sql
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_from_table(
|
||||||
|
conn: sqlite3.Connection,
|
||||||
|
table_name: str,
|
||||||
|
columns_list: list = ['*'],
|
||||||
|
aggregations: [None,dict] = None,
|
||||||
|
where_conditions: [None,str] = None,
|
||||||
|
order_by: [None, dict] = None,
|
||||||
|
limit: [None, int] = None,
|
||||||
|
offset: [None, int] = None
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
'''
|
||||||
|
Helper function which returns (if desired: aggregated) contents from the given table as a pandas DataFrame. The rows can be filtered by providing the condition as a string.
|
||||||
|
|
||||||
|
:param columns_list: use if no aggregation is needed to select which columns to get from the table
|
||||||
|
:param (optional) aggregations: use to apply aggregations on the data from the table; dictionary with column(s) as key(s) and aggregation(s) as corresponding value(s) (e.g. {'col1': 'MIN', 'col2': 'AVG', ...} or {'*': 'COUNT'})
|
||||||
|
:param (optional) where_conditions: string which defines all concatenated conditions (e.g. \'cond1 AND (cond2 OR cond3)\' with cond1: att1=5, ...) applied on table.
|
||||||
|
:param (optional) order_by: dict defining the ordering of the outputs with column(s) as key(s) and ordering as corresponding value(s) (e.g. {'col1': 'ASC'})
|
||||||
|
:param (optional) limit: use to limit the number of returned rows
|
||||||
|
:param (optional) offset: use to skip the first n rows before displaying
|
||||||
|
|
||||||
|
Note: If aggregations is set, the columns_list is ignored.
|
||||||
|
Note: Get all data as a DataFrame with get_data_from_table(conn, table_name).
|
||||||
|
Note: If one output is wanted (e.g. count(*) or similar), get it with get_data_from_table(...).iloc[0,0] from the DataFrame.
|
||||||
|
'''
|
||||||
|
assert columns_list or aggregations
|
||||||
|
|
||||||
|
if aggregations:
|
||||||
|
selection = [f'{agg}({col})' for col,agg in aggregations.items()]
|
||||||
|
else:
|
||||||
|
selection = columns_list
|
||||||
|
selection = ", ".join(selection)
|
||||||
|
where_conditions = 'WHERE ' + where_conditions if where_conditions else ''
|
||||||
|
order_by = 'ORDER BY ' + ', '.join([f'{k} {v}' for k,v in order_by.items()]) if order_by else ''
|
||||||
|
limit = f'LIMIT {limit}' if limit else ''
|
||||||
|
offset = f'OFFSET {offset}' if offset else ''
|
||||||
|
|
||||||
|
sql = f'SELECT {selection} FROM {table_name} {where_conditions} {order_by} {limit} {offset}'
|
||||||
|
return pd.read_sql_query(sql, conn)
|
||||||
Loading…
x
Reference in New Issue
Block a user