From d6f1596c0f508eafa32f6b4efc452af846f364e6 Mon Sep 17 00:00:00 2001 From: Nils Date: Tue, 2 Dec 2025 14:52:51 +0100 Subject: [PATCH] Optionale Aufgabe Eigene Ziffern malen --- mnistVisualization.c | 435 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 416 insertions(+), 19 deletions(-) diff --git a/mnistVisualization.c b/mnistVisualization.c index b0a29f5..9dc7148 100644 --- a/mnistVisualization.c +++ b/mnistVisualization.c @@ -2,10 +2,22 @@ #include #include #include "mnistVisualization.h" +#include "imageInput.h" +#include "neuralNetwork.h" + +// Matrix-Namenskonflikt mit Raylib lösen: Raylib's Matrix umbenennen +#define Matrix RaylibMatrix #include "raylib.h" +#undef Matrix #define MAX_TEXT_LEN 100 +// Enum für die verschiedenen Modi +typedef enum { + MODE_BROWSE, + MODE_DRAW +} AppMode; + typedef struct { Vector2 position; @@ -13,6 +25,10 @@ typedef struct const unsigned char *predictions; unsigned int currentIdx; Vector2 pixelSize; + AppMode mode; + GrayScaleImage drawingCanvas; + unsigned char canvasPrediction; + const NeuralNetwork *model; } MnistVisualization; typedef struct @@ -45,20 +61,86 @@ static TextLabel *createTextLabel(const char *text, unsigned int fontSize, Color return label; } -static MnistVisualization *createVisualizationContainer(const GrayScaleImageSeries *series, const unsigned char predictions[], Vector2 size) +static GrayScaleImage createCanvas(unsigned int width, unsigned int height) +{ + GrayScaleImage canvas; + canvas.width = width; + canvas.height = height; + canvas.buffer = (GrayScalePixelType *)calloc(width * height, sizeof(GrayScalePixelType)); + return canvas; +} + +static void clearCanvas(GrayScaleImage *canvas) +{ + if(canvas != NULL && canvas->buffer != NULL) + { + memset(canvas->buffer, 0, canvas->width * canvas->height * sizeof(GrayScalePixelType)); + } +} + +static GrayScaleImage downsampleCanvas(const GrayScaleImage *largeCanvas, unsigned int targetWidth, unsigned int targetHeight) +{ + GrayScaleImage smallImage = createCanvas(targetWidth, targetHeight); + + if(smallImage.buffer != NULL && largeCanvas != NULL && largeCanvas->buffer != NULL) + { + unsigned int scaleX = largeCanvas->width / targetWidth; + unsigned int scaleY = largeCanvas->height / targetHeight; + + for(unsigned int y = 0; y < targetHeight; y++) + { + for(unsigned int x = 0; x < targetWidth; x++) + { + // Durchschnitt über den entsprechenden Bereich berechnen + unsigned int sum = 0; + unsigned int count = 0; + + for(unsigned int sy = 0; sy < scaleY; sy++) + { + for(unsigned int sx = 0; sx < scaleX; sx++) + { + unsigned int srcX = x * scaleX + sx; + unsigned int srcY = y * scaleY + sy; + + if(srcX < largeCanvas->width && srcY < largeCanvas->height) + { + sum += largeCanvas->buffer[srcY * largeCanvas->width + srcX]; + count++; + } + } + } + + smallImage.buffer[y * targetWidth + x] = (unsigned char)(sum / count); + } + } + } + + return smallImage; +} + +static MnistVisualization *createVisualizationContainer(const GrayScaleImageSeries *series, const unsigned char predictions[], Vector2 size, const NeuralNetwork *model) { MnistVisualization *container = NULL; - + if(size.x > 0 && size.y > 0 && series != NULL && series->images != NULL && series->count > 0 && predictions != NULL) { container = (MnistVisualization *)calloc(1, sizeof(MnistVisualization)); if(container != NULL) { - Vector2 pixelSize = {(int)(size.x / series->images[0].width), (int)(size.y / series->images[0].height)}; + // Canvas ist 4x größer (112x112), also pixelSize anpassen + Vector2 pixelSize = {(int)(size.x / (series->images[0].width * 4)), + (int)(size.y / (series->images[0].height * 4))}; container->pixelSize = pixelSize; container->series = series; container->predictions = predictions; + container->mode = MODE_BROWSE; + container->model = model; + + // Canvas erstellen (4x größer als MNIST für feineres Zeichnen: 112x112 statt 28x28) + unsigned int canvasSize = series->images[0].width * 4; + container->drawingCanvas = createCanvas(canvasSize, canvasSize); + container->canvasPrediction = 0; } } @@ -94,15 +176,28 @@ static void drawDigit(const GrayScaleImage image, Vector2 position, Vector2 pixe } } -static void drawAll(const MnistVisualization *container, const TextLabel *navigationLabel, const TextLabel *predictionLabel) +static void drawAll(const MnistVisualization *container, const TextLabel *navigationLabel, const TextLabel *predictionLabel, const TextLabel *modeLabel) { BeginDrawing(); ClearBackground(BLACK); - - drawDigit(container->series->images[container->currentIdx], container->position, container->pixelSize); + + if(container->mode == MODE_BROWSE) + { + // Im Browse-Modus: MNIST Bilder sind 28x28, aber pixelSize ist für 112x112 + // Also pixelSize * 4 verwenden + Vector2 browsePixelSize = {container->pixelSize.x * 4, container->pixelSize.y * 4}; + drawDigit(container->series->images[container->currentIdx], container->position, browsePixelSize); + } + else // MODE_DRAW + { + // Im Draw-Modus: Canvas ist 112x112, pixelSize passt + drawDigit(container->drawingCanvas, container->position, container->pixelSize); + } + drawTextLabel(navigationLabel); drawTextLabel(predictionLabel); - + drawTextLabel(modeLabel); + EndDrawing(); } @@ -114,10 +209,232 @@ static int checkUserInput() inputResult = -1; else if(IsKeyReleased(KEY_RIGHT)) inputResult = 1; - + return inputResult; } +static void handleDrawing(MnistVisualization *container) +{ + static Vector2 lastMousePos = {-1, -1}; + + if(IsMouseButtonDown(MOUSE_LEFT_BUTTON)) + { + Vector2 mousePos = GetMousePosition(); + + // Berechne welches Pixel geklickt wurde (pixelSize ist bereits für 112x112) + int pixelX = (int)((mousePos.x - container->position.x) / container->pixelSize.x); + int pixelY = (int)((mousePos.y - container->position.y) / container->pixelSize.y); + + // Wenn wir eine vorherige Position haben, zeichne eine Linie + if(lastMousePos.x >= 0 && lastMousePos.y >= 0) + { + int lastPixelX = (int)((lastMousePos.x - container->position.x) / container->pixelSize.x); + int lastPixelY = (int)((lastMousePos.y - container->position.y) / container->pixelSize.y); + + // Bresenham Linien-Algorithmus (vereinfacht) + int dx = abs(pixelX - lastPixelX); + int dy = abs(pixelY - lastPixelY); + int sx = (lastPixelX < pixelX) ? 1 : -1; + int sy = (lastPixelY < pixelY) ? 1 : -1; + int err = dx - dy; + + int currentX = lastPixelX; + int currentY = lastPixelY; + + while(1) + { + // Zeichne dünnen Pinsel für 112x112 Canvas + if(currentX >= 0 && currentX < container->drawingCanvas.width && + currentY >= 0 && currentY < container->drawingCanvas.height) + { + // Zentrum: 2x2 Pixel weiß (entspricht 0.5x0.5 auf 28x28) + for(int dy = 0; dy <= 1; dy++) + { + for(int dx = 0; dx <= 1; dx++) + { + int nx = currentX + dx; + int ny = currentY + dy; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + container->drawingCanvas.buffer[nidx] = 255; + } + } + } + + // Ring 1: Direkte Nachbarn (sehr hell) + int ring1[][2] = {{-1,0}, {-1,1}, {0,-1}, {2,0}, {2,1}, {0,2}, {1,2}, {1,-1}}; + for(int i = 0; i < 8; i++) + { + int nx = currentX + ring1[i][0]; + int ny = currentY + ring1[i][1]; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + if(container->drawingCanvas.buffer[nidx] < 200) { + container->drawingCanvas.buffer[nidx] = 200; + } + } + } + + // Ring 2: Weitere Nachbarn (mittel) + int ring2[][2] = {{-2,0}, {-2,1}, {-1,-1}, {-1,2}, {0,-2}, {0,3}, + {1,-2}, {1,3}, {2,-1}, {2,2}, {3,0}, {3,1}}; + for(int i = 0; i < 12; i++) + { + int nx = currentX + ring2[i][0]; + int ny = currentY + ring2[i][1]; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + if(container->drawingCanvas.buffer[nidx] < 140) { + container->drawingCanvas.buffer[nidx] = 140; + } + } + } + + // Ring 3: Äußere Nachbarn (dunkel) + int ring3[][2] = {{-3,0}, {-3,1}, {-2,-1}, {-2,2}, {-1,-2}, {-1,3}, + {0,-3}, {0,4}, {1,-3}, {1,4}, {2,-2}, {2,3}, + {3,-1}, {3,2}, {4,0}, {4,1}}; + for(int i = 0; i < 16; i++) + { + int nx = currentX + ring3[i][0]; + int ny = currentY + ring3[i][1]; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + if(container->drawingCanvas.buffer[nidx] < 80) { + container->drawingCanvas.buffer[nidx] = 80; + } + } + } + } + + if(currentX == pixelX && currentY == pixelY) break; + + int e2 = 2 * err; + if(e2 > -dy) + { + err -= dy; + currentX += sx; + } + if(e2 < dx) + { + err += dx; + currentY += sy; + } + } + } + else + { + // Erstes Pixel (kein Vorgänger) + if(pixelX >= 0 && pixelX < container->drawingCanvas.width && + pixelY >= 0 && pixelY < container->drawingCanvas.height) + { + // Gleiche Logik wie oben + for(int dy = 0; dy <= 1; dy++) + { + for(int dx = 0; dx <= 1; dx++) + { + int nx = pixelX + dx; + int ny = pixelY + dy; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + container->drawingCanvas.buffer[nidx] = 255; + } + } + } + + int ring1[][2] = {{-1,0}, {-1,1}, {0,-1}, {2,0}, {2,1}, {0,2}, {1,2}, {1,-1}}; + for(int i = 0; i < 8; i++) + { + int nx = pixelX + ring1[i][0]; + int ny = pixelY + ring1[i][1]; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + if(container->drawingCanvas.buffer[nidx] < 200) { + container->drawingCanvas.buffer[nidx] = 200; + } + } + } + + int ring2[][2] = {{-2,0}, {-2,1}, {-1,-1}, {-1,2}, {0,-2}, {0,3}, + {1,-2}, {1,3}, {2,-1}, {2,2}, {3,0}, {3,1}}; + for(int i = 0; i < 12; i++) + { + int nx = pixelX + ring2[i][0]; + int ny = pixelY + ring2[i][1]; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + if(container->drawingCanvas.buffer[nidx] < 140) { + container->drawingCanvas.buffer[nidx] = 140; + } + } + } + + int ring3[][2] = {{-3,0}, {-3,1}, {-2,-1}, {-2,2}, {-1,-2}, {-1,3}, + {0,-3}, {0,4}, {1,-3}, {1,4}, {2,-2}, {2,3}, + {3,-1}, {3,2}, {4,0}, {4,1}}; + for(int i = 0; i < 16; i++) + { + int nx = pixelX + ring3[i][0]; + int ny = pixelY + ring3[i][1]; + if(nx >= 0 && nx < container->drawingCanvas.width && + ny >= 0 && ny < container->drawingCanvas.height) + { + int nidx = ny * container->drawingCanvas.width + nx; + if(container->drawingCanvas.buffer[nidx] < 80) { + container->drawingCanvas.buffer[nidx] = 80; + } + } + } + } + } + + lastMousePos = mousePos; + } + else + { + // Maustaste losgelassen - Reset der letzten Position + lastMousePos.x = -1; + lastMousePos.y = -1; + } +} + +static void updatePredictionForCanvas(MnistVisualization *container) +{ + if(container->model != NULL && container->series != NULL && container->series->images != NULL) + { + // Canvas von 112x112 auf 28x28 runterskalieren + unsigned int targetSize = container->series->images[0].width; + GrayScaleImage downsampled = downsampleCanvas(&container->drawingCanvas, targetSize, targetSize); + + if(downsampled.buffer != NULL) + { + unsigned char *prediction = predict(*container->model, &downsampled, 1); + if(prediction != NULL) + { + container->canvasPrediction = prediction[0]; + free(prediction); + } + + // Downsampled Image aufräumen + free(downsampled.buffer); + } + } +} + static void updateDisplayContainer(MnistVisualization *container, int updateDirection) { int newIndex = (int)container->currentIdx + updateDirection; @@ -130,44 +447,124 @@ static void updateDisplayContainer(MnistVisualization *container, int updateDire container->currentIdx = newIndex; } -static void updatePredictionLabel(TextLabel *predictionLabel, unsigned char trueLabel, unsigned char predictedLabel) +static void updatePredictionLabel(TextLabel *predictionLabel, unsigned char trueLabel, unsigned char predictedLabel, AppMode mode) { - snprintf(predictionLabel->text, MAX_TEXT_LEN, "True label: %u\nPredicted label: %u", trueLabel, predictedLabel); + if(mode == MODE_BROWSE) + { + snprintf(predictionLabel->text, MAX_TEXT_LEN, "True label: %u\nPredicted label: %u", trueLabel, predictedLabel); + } + else // MODE_DRAW + { + snprintf(predictionLabel->text, MAX_TEXT_LEN, "Predicted label: %u", predictedLabel); + } } -static void update(MnistVisualization *container, TextLabel *predictionLabel, int updateDirection) +static void updateModeLabel(TextLabel *modeLabel, AppMode mode) { - updateDisplayContainer(container, updateDirection); - updatePredictionLabel(predictionLabel, container->series->labels[container->currentIdx], container->predictions[container->currentIdx]); + if(mode == MODE_BROWSE) + { + snprintf(modeLabel->text, MAX_TEXT_LEN, "Mode: BROWSE | Press 'D' to draw"); + } + else // MODE_DRAW + { + snprintf(modeLabel->text, MAX_TEXT_LEN, "Mode: DRAW | Press 'B' to browse | Press 'C' to clear"); + } +} + +static void update(MnistVisualization *container, TextLabel *predictionLabel, TextLabel *modeLabel, int updateDirection) +{ + // Mode-Wechsel + if(IsKeyPressed(KEY_D)) + { + container->mode = MODE_DRAW; + clearCanvas(&container->drawingCanvas); + } + else if(IsKeyPressed(KEY_B)) + { + container->mode = MODE_BROWSE; + } + + // Canvas löschen im Draw-Mode + if(container->mode == MODE_DRAW && IsKeyPressed(KEY_C)) + { + clearCanvas(&container->drawingCanvas); + container->canvasPrediction = 0; + } + + if(container->mode == MODE_BROWSE) + { + updateDisplayContainer(container, updateDirection); + updatePredictionLabel(predictionLabel, + container->series->labels[container->currentIdx], + container->predictions[container->currentIdx], + MODE_BROWSE); + } + else // MODE_DRAW + { + handleDrawing(container); + + // Prediction alle paar Frames aktualisieren (nicht bei jedem Frame für Performance) + static int frameCounter = 0; + frameCounter++; + if(frameCounter % 10 == 0) + { + updatePredictionForCanvas(container); + } + + updatePredictionLabel(predictionLabel, 0, container->canvasPrediction, MODE_DRAW); + } + + updateModeLabel(modeLabel, container->mode); } void showMnist(unsigned int windowWidth, unsigned int windowHeight, const GrayScaleImageSeries *series, const unsigned char predictions[]) { + // Model laden (für Draw-Modus) + NeuralNetwork model = loadModel("mnist_model.info2"); + const Vector2 windowSize = {windowWidth, windowHeight}; - MnistVisualization *container = createVisualizationContainer(series, predictions, windowSize); + MnistVisualization *container = createVisualizationContainer(series, predictions, windowSize, &model); TextLabel *navigationLabel = createTextLabel("Use left and right key to navigate ...", 20, WHITE); TextLabel *predictionLabel = createTextLabel("", 20, WHITE); + TextLabel *modeLabel = createTextLabel("", 20, WHITE); + + navigationLabel->position.x = windowSize.x - 400; // Rechts (mit Abstand) + navigationLabel->position.y = windowSize.y - 30; // Ganz unten + predictionLabel->position.x = 10; predictionLabel->position.y = windowSize.y - 50; - if(container != NULL && navigationLabel != NULL && predictionLabel != NULL) + modeLabel->position.x = 10; + modeLabel->position.y = 10; + + if(container != NULL && navigationLabel != NULL && predictionLabel != NULL && modeLabel != NULL) { - InitWindow(windowSize.x, windowSize.y, "MNIST Browser"); + InitWindow(windowSize.x, windowSize.y, "MNIST Browser & Drawer"); SetTargetFPS(60); while (!WindowShouldClose()) { int updateDirection = checkUserInput(); - update(container, predictionLabel, updateDirection); - drawAll(container, navigationLabel, predictionLabel); + update(container, predictionLabel, modeLabel, updateDirection); + drawAll(container, navigationLabel, predictionLabel, modeLabel); } } CloseWindow(); - free(container); + // Cleanup + if(container != NULL) + { + if(container->drawingCanvas.buffer != NULL) + { + free(container->drawingCanvas.buffer); + } + free(container); + } free(navigationLabel); free(predictionLabel); + free(modeLabel); + clearModel(&model); } \ No newline at end of file