- added comments to the code
- translated all outputs and comments to english
This commit is contained in:
parent
f19dde3f9a
commit
4cb06d0497
@ -1,79 +1,76 @@
|
||||
import cv2
|
||||
import time
|
||||
import os
|
||||
import threading
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from feat import Detector
|
||||
import torch
|
||||
import mediapipe as mp
|
||||
import pandas as pd
|
||||
import db_helper as db
|
||||
|
||||
from pathlib import Path
|
||||
from eyeFeature_new import compute_features_from_parquet
|
||||
|
||||
# Suppress specific Protobuf deprecation warnings from the library
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*SymbolDatabase\.GetPrototype\(\) is deprecated.*",
|
||||
category=UserWarning,
|
||||
module=r"google\.protobuf\.symbol_database"
|
||||
)
|
||||
import cv2
|
||||
import time
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from feat import Detector
|
||||
import torch
|
||||
import mediapipe as mp
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from eyeFeature_new import compute_features_from_parquet
|
||||
|
||||
# Import your helper functions
|
||||
# from db_helper import connect_db, disconnect_db, insert_rows_into_table, create_table
|
||||
import db_helper as db
|
||||
|
||||
# --- Configuration & Hyperparameters ---
DB_PATH = Path("~/MSY_FS/databases/database.sqlite").expanduser()  # SQLite feature database
CAMERA_INDEX = 0                 # OpenCV device index of the capture camera
OUTPUT_DIR = Path("recordings")  # Directory for temporary video / parquet segments
VIDEO_DURATION = 50              # Seconds per recording segment
START_INTERVAL = 5               # Delay between starting overlapping recordings
FPS = 25.0                       # Target frames per second

# Global feature storage, merged into the DB payload by the analysis thread.
# NOTE(review): plain dict shared across threads — confirm writes happen from
# a single thread, or guard with a lock.
eye_tracking_features = {}

# Ensure the output directory exists; exist_ok avoids the check-then-create
# race of the former `if not exists(): makedirs()` pattern.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize the Action Unit detector once at module load: reloading the
# model weights for every video segment would waste significant time and
# VRAM/RAM.
print("[INFO] Initializing Facial Action Unit Detector (XGB)...")
detector = Detector(au_model="xgb")
|
||||
|
||||
# --- MediaPipe FaceMesh Configuration ---
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=False,        # video stream: track faces across frames
    max_num_faces=1,
    refine_landmarks=True,          # mandatory for iris landmarks (474-477 / 469-472)
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5
)
|
||||
|
||||
# --- Landmark indices for oculometrics (MediaPipe FaceMesh topology) ---
LEFT_IRIS = [474, 475, 476, 477]    # four points surrounding the left iris
RIGHT_IRIS = [469, 470, 471, 472]   # four points surrounding the right iris

# (upper lid, lower lid) landmark pairs used to measure eye openness.
LEFT_EYE_LIDS = (159, 145)
RIGHT_EYE_LIDS = (386, 374)

# Minimum lid distance for an eye to count as open — presumably compared
# against the pixel distance returned by eye_openness(); confirm at call site.
EYE_OPEN_THRESHOLD = 6

# Full eye-contour indices, used to compute the per-eye bounding boxes.
LEFT_EYE_ALL = [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]
RIGHT_EYE_ALL = [263, 249, 390, 373, 374, 380, 381, 382, 362, 398, 384, 385, 386, 387, 388, 466]
|
||||
|
||||
def eye_openness(landmarks, top_idx, bottom_idx, img_height):
    """Return the vertical eyelid gap in pixels.

    MediaPipe landmarks carry normalized (0..1) coordinates, so the raw
    ``top.y - bottom.y`` difference is scaled by ``img_height`` to obtain
    a pixel distance (the former docstring claimed the opposite direction
    of normalization).

    Args:
        landmarks: indexable collection of landmark objects exposing ``.y``.
        top_idx: index of the upper-lid landmark.
        bottom_idx: index of the lower-lid landmark.
        img_height: frame height in pixels.

    Returns:
        Absolute lid distance in pixels (float).
    """
    top = landmarks[top_idx]
    bottom = landmarks[bottom_idx]
    return abs(top.y - bottom.y) * img_height
|
||||
|
||||
|
||||
def compute_gaze(landmarks, iris_center, eye_indices, w, h):
|
||||
"""
|
||||
Computes normalized gaze coordinates (0.0 to 1.0) relative to the eye's
|
||||
internal bounding box.
|
||||
"""
|
||||
iris_x, iris_y = iris_center
|
||||
|
||||
eye_points = []
|
||||
@ -101,42 +98,37 @@ def compute_gaze(landmarks, iris_center, eye_indices, w, h):
|
||||
return gaze_x, gaze_y
|
||||
|
||||
def extract_aus(path, skip_frames):
    """Infer facial Action Units (AUs) from a recorded video segment.

    Runs py-feat's detector under ``torch.no_grad()``: gradients are never
    needed for inference, and disabling them both avoids the
    "Can't call numpy() on Tensor that requires grad" error and prevents
    unnecessary memory growth.

    Args:
        path: filesystem path of the video file to analyze.
        skip_frames: analyze only every ``skip_frames``-th frame.

    Returns:
        Temporal mean of the detected Action Units across all analyzed
        frames, or ``None`` if detection or aggregation fails.
    """
    with torch.no_grad():
        try:
            video_prediction = detector.detect_video(
                path,
                skip_frames=skip_frames,
                face_detection_threshold=0.95
            )
            # Temporal mean of the Action Units across the segment.
            return video_prediction.aus.mean()
        except Exception as e:
            # Best-effort: the caller treats None as "no data for this segment".
            print(f"[ERROR] AU Extraction failed: {e}")
            return None
|
||||
|
||||
def startAU_creation(video_path, db_path):
|
||||
"""Diese Funktion läuft nun in einem eigenen Thread."""
|
||||
def process_and_store_analysis(video_path, db_path):
|
||||
"""
|
||||
Worker function: Handles AU extraction, data merging, and SQL persistence.
|
||||
Designed to run in a background thread.
|
||||
"""
|
||||
try:
|
||||
print(f"\n[THREAD START] Analyse läuft für: {video_path}")
|
||||
# skip_frames berechnen (z.B. alle 5 Sekunden bei 25 FPS = 125)
|
||||
output = extract_aus(video_path, skip_frames=int(FPS*5))
|
||||
|
||||
print(f"\n--- Ergebnis für {os.path.basename(video_path)} ---")
|
||||
print(output)
|
||||
print("--------------------------------------------------\n")
|
||||
print(f"[THREAD] Analyzing segment: {video_path}")
|
||||
# Analysis sampling: one frame every 5 seconds
|
||||
output = extract_aus(video_path, skip_frames=int(FPS * 5))
|
||||
if output is not None:
|
||||
# Verbindung für diesen Thread öffnen (SQLite Sicherheit)
|
||||
conn, cursor = db.connect_db(db_path)
|
||||
|
||||
# Daten vorbereiten: Timestamp + AU Ergebnisse
|
||||
# Wir wandeln die Series/Dataframe in ein Dictionary um
|
||||
# Prepare payload: Prefix keys to distinguish facial AUs
|
||||
data_to_insert = output.to_dict()
|
||||
|
||||
data_to_insert = {
|
||||
@ -149,25 +141,27 @@ def startAU_creation(video_path, db_path):
|
||||
data_to_insert['start_time'] = [ticks]
|
||||
data_to_insert = data_to_insert | eye_tracking_features
|
||||
|
||||
#data_to_insert['start_time'] = [datetime.now().strftime("%Y-%m-%d %H:%M:%S")]
|
||||
|
||||
# Da die AU-Spaltennamen dynamisch sind, stellen wir sicher, dass sie Listen sind
|
||||
# (insert_rows_into_table erwartet Listen für jeden Key)
|
||||
# making sure that dynamic AU-columns are lists
|
||||
# (insert_rows_into_table expects lists for every key)
|
||||
final_payload = {k: [v] if not isinstance(v, list) else v for k, v in data_to_insert.items()}
|
||||
|
||||
|
||||
db.insert_rows_into_table(conn, cursor, "feature_table", final_payload)
|
||||
|
||||
db.disconnect_db(conn, cursor)
|
||||
print(f"--- Ergebnis für {os.path.basename(video_path)} in DB gespeichert ---")
|
||||
print(f"[SUCCESS] Data persisted for {os.path.basename(video_path)}")
|
||||
|
||||
# Cleanup temporary files to save disk space
|
||||
os.remove(video_path)
|
||||
os.remove(video_path.replace(".avi", "_gaze.parquet"))
|
||||
print(f"Löschen der Datei: {video_path}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Fehler bei der Analyse von {video_path}: {e}")
|
||||
print(f"[ERROR] Threaded analysis failed for {video_path}: {e}")
|
||||
|
||||
|
||||
class VideoRecorder:
|
||||
"""Manages the asynchronous writing of video frames to disk."""
|
||||
def __init__(self, filename, width, height, db_path):
|
||||
self.gaze_data = []
|
||||
self.filename = filename
|
||||
@ -190,15 +184,16 @@ class VideoRecorder:
|
||||
self.out.release()
|
||||
self.is_finished = True
|
||||
abs_path = os.path.abspath(self.filename)
|
||||
print(f"Video fertig gespeichert: {self.filename}")
|
||||
print(f"Video saved: {self.filename}")
|
||||
|
||||
# --- MULTITHREADING HIER ---
|
||||
# Wir starten die Analyse in einem neuen Thread, damit main() sofort weiter frames lesen kann
|
||||
analysis_thread = threading.Thread(target=startAU_creation, args=(abs_path, self.db_path))
|
||||
analysis_thread.daemon = True # Beendet sich, wenn das Hauptprogramm schließt
|
||||
# Trigger background analysis thread
|
||||
# Passing a snapshot of eye_tracking_features to avoid race conditions
|
||||
analysis_thread = threading.Thread(target=process_and_store_analysis, args=(abs_path, self.db_path))
|
||||
analysis_thread.daemon = True # ends when the program ends
|
||||
analysis_thread.start()
|
||||
|
||||
class GazeRecorder:
|
||||
"""Handles the collection and Parquet serialization of oculometric data."""
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.frames_to_record = int(VIDEO_DURATION * FPS)
|
||||
@ -217,7 +212,9 @@ class GazeRecorder:
|
||||
if not self.is_finished:
|
||||
df = pd.DataFrame(self.gaze_data)
|
||||
df.to_parquet(self.filename, engine="pyarrow", index=False)
|
||||
print(f"Gaze-Parquet gespeichert: {self.filename}")
|
||||
|
||||
# Extract high-level features from raw gaze points
|
||||
print(f"Gaze-Parquet saved: {self.filename}")
|
||||
features = compute_features_from_parquet(self.filename)
|
||||
print("Features:", features)
|
||||
self.is_finished = True
|
||||
@ -226,7 +223,7 @@ class GazeRecorder:
|
||||
def main():
|
||||
cap = cv2.VideoCapture(CAMERA_INDEX)
|
||||
if not cap.isOpened():
|
||||
print("Fehler: Kamera konnte nicht geöffnet werden.")
|
||||
print("[CRITICAL] Could not access camera.")
|
||||
return
|
||||
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
@ -236,7 +233,7 @@ def main():
|
||||
active_gaze_recorders = []
|
||||
last_start_time = 0
|
||||
|
||||
print("Aufnahme läuft. Drücke 'q' zum Beenden.")
|
||||
print("[INFO] Recording started. Press 'q' to terminate.")
|
||||
|
||||
try:
|
||||
while True:
|
||||
@ -244,10 +241,12 @@ def main():
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Pre-processing for MediaPipe
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
h, w, _ = frame.shape
|
||||
results = face_mesh.process(rgb)
|
||||
|
||||
# Default feature values
|
||||
left_valid = 0
|
||||
right_valid = 0
|
||||
left_diameter = None
|
||||
@ -366,7 +365,7 @@ def main():
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
print("Programm beendet. Warte ggf. auf laufende Analysen...")
|
||||
print("[INFO] Stream closed. Waiting for background analysis to complete...")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user