#include #include #include #include "mnistVisualization.h" /* Raylib-Fix */ #define Matrix RaylibMatrix #include "raylib.h" #undef Matrix #define MAX_TEXT_LEN 100 typedef struct { Vector2 position; /* Galerie */ const GrayScaleImageSeries *series; const unsigned char *predictions; unsigned int currentIdx; /* Malen */ GrayScaleImage canvas; unsigned char canvasPrediction; int isDrawingMode; Vector2 pixelSize; NeuralNetwork model; } MnistVisualization; typedef struct { char text[MAX_TEXT_LEN]; Vector2 position; unsigned int fontSize; Color color; } TextLabel; static TextLabel *createTextLabel(const char *text, unsigned int fontSize, Color color) { TextLabel *label = NULL; if(text != NULL) { label = (TextLabel *)malloc(sizeof(TextLabel)); if(label != NULL) { strncpy(label->text, text, MAX_TEXT_LEN); label->text[MAX_TEXT_LEN-1] = '\0'; label->position.x = 0; label->position.y = 0; label->fontSize = fontSize; label->color = color; } } return label; } static MnistVisualization *createVisualizationContainer(const GrayScaleImageSeries *series, const unsigned char predictions[], Vector2 size, NeuralNetwork model) { MnistVisualization *container = NULL; if(size.x > 0 && size.y > 0 && series != NULL && series->images != NULL && series->count > 0 && predictions != NULL) { container = (MnistVisualization *)calloc(1, sizeof(MnistVisualization)); if(container != NULL) { unsigned int imgW = series->images[0].width; unsigned int imgH = series->images[0].height; Vector2 pixelSize = {(int)(size.x / imgW), (int)(size.y / imgH)}; container->pixelSize = pixelSize; container->series = series; container->predictions = predictions; container->model = model; container->isDrawingMode = 0; container->canvas.width = imgW; container->canvas.height = imgH; container->canvas.buffer = (GrayScalePixelType*)calloc(imgW * imgH, sizeof(GrayScalePixelType)); } } return container; } static void drawTextLabel(const TextLabel *label) { DrawText(label->text, label->position.x, label->position.y, label->fontSize, label->color); } static void drawDigit(const GrayScaleImage image, Vector2 position, Vector2 pixelSize) { float xOffset = 0; float yOffset = position.y - pixelSize.y; Color color = {0}; color.a = 255; for(int i = 0; i < image.width * image.height; i++) { if(i % image.width == 0) { xOffset = position.x; yOffset += pixelSize.y; } color.r = image.buffer[i]; color.b = image.buffer[i]; color.g = image.buffer[i]; DrawRectangle(xOffset, yOffset, pixelSize.x, pixelSize.y, color); xOffset += pixelSize.x; } } static void drawAll(const MnistVisualization *container, const TextLabel *navLabel, const TextLabel *predLabel) { BeginDrawing(); ClearBackground(BLACK); if (container->isDrawingMode) { drawDigit(container->canvas, container->position, container->pixelSize); DrawText("DRAWING MODE", 10, 10, 20, RED); DrawText("Left Click: Draw | 'C': Clear | TAB: Gallery", 10, 35, 20, LIGHTGRAY); /* Text nach unten versetzt (Y=70), damit sichtbar */ DrawText(TextFormat("Predicted: %u", container->canvasPrediction), 10, 70, 50, GREEN); } else { drawDigit(container->series->images[container->currentIdx], container->position, container->pixelSize); DrawText("GALLERY MODE", 10, 10, 20, BLUE); DrawText("TAB: Drawing Mode", 10, 35, 20, LIGHTGRAY); drawTextLabel(navLabel); drawTextLabel(predLabel); } EndDrawing(); } static void clearCanvas(MnistVisualization *container) { if(container->canvas.buffer != NULL) memset(container->canvas.buffer, 0, container->canvas.width * container->canvas.height * sizeof(GrayScalePixelType)); } static void paintOnCanvas(MnistVisualization *container) { Vector2 mouse = GetMousePosition(); int gridX = (int)(mouse.x / container->pixelSize.x); int gridY = (int)(mouse.y / container->pixelSize.y); int brushSize = 1; for(int dy = -brushSize; dy <= brushSize; dy++) { for(int dx = -brushSize; dx <= brushSize; dx++) { int px = gridX + dx; int py = gridY + dy; if(px >= 0 && px < container->canvas.width && py >= 0 && py < container->canvas.height) { int idx = py * container->canvas.width + px; /* Zeichnen mit 255 (Weiß) */ container->canvas.buffer[idx] = 255; } } } } static void runLivePrediction(MnistVisualization *container) { /* HIER: Wir benutzen jetzt die neue Funktion predictLive für Zeichnungen */ unsigned char *result = predictLive(container->model, container->canvas); if(result != NULL) { container->canvasPrediction = result[0]; free(result); } } static int checkUserInput(MnistVisualization *container) { if (IsKeyReleased(KEY_TAB)) { container->isDrawingMode = !container->isDrawingMode; return 0; } if (container->isDrawingMode) { if (IsMouseButtonDown(MOUSE_LEFT_BUTTON)) { paintOnCanvas(container); runLivePrediction(container); } if (IsKeyPressed(KEY_C)) { clearCanvas(container); container->canvasPrediction = 0; } return 0; } else { if(IsKeyReleased(KEY_LEFT)) return -1; else if(IsKeyReleased(KEY_RIGHT)) return 1; } return 0; } static void updateDisplayContainer(MnistVisualization *container, int updateDirection) { int newIndex = (int)container->currentIdx + updateDirection; if(newIndex < 0) newIndex = 0; else if(newIndex >= container->series->count) newIndex = container->series->count - 1; container->currentIdx = newIndex; } static void updatePredictionLabel(TextLabel *predictionLabel, unsigned char trueLabel, unsigned char predictedLabel) { snprintf(predictionLabel->text, MAX_TEXT_LEN, "True label: %u\nPredicted label: %u", trueLabel, predictedLabel); } static void update(MnistVisualization *container, TextLabel *predictionLabel, int updateDirection) { if (!container->isDrawingMode) { updateDisplayContainer(container, updateDirection); updatePredictionLabel(predictionLabel, container->series->labels[container->currentIdx], container->predictions[container->currentIdx]); } } void showMnist(unsigned int windowWidth, unsigned int windowHeight, const GrayScaleImageSeries *series, const unsigned char predictions[], NeuralNetwork model) { const Vector2 windowSize = {windowWidth, windowHeight}; MnistVisualization *container = createVisualizationContainer(series, predictions, windowSize, model); TextLabel *navigationLabel = createTextLabel("Use left and right key to navigate ...", 20, WHITE); TextLabel *predictionLabel = createTextLabel("", 20, WHITE); predictionLabel->position.x = 10; predictionLabel->position.y = windowSize.y - 60; navigationLabel->position.y = windowSize.y - 90; navigationLabel->position.x = 10; if(container != NULL && navigationLabel != NULL && predictionLabel != NULL) { InitWindow(windowSize.x, windowSize.y, "MNIST Browser & Painter"); SetTargetFPS(60); while (!WindowShouldClose()) { int updateDirection = checkUserInput(container); update(container, predictionLabel, updateDirection); drawAll(container, navigationLabel, predictionLabel); } } CloseWindow(); if (container != NULL) { if (container->canvas.buffer != NULL) free(container->canvas.buffer); free(container); } free(navigationLabel); free(predictionLabel); }