#include #include #include #include "mnistVisualization.h" #include "imageInput.h" #include "neuralNetwork.h" // Matrix-Namenskonflikt mit Raylib lösen: Raylib's Matrix umbenennen #define Matrix RaylibMatrix #include "raylib.h" #undef Matrix #define MAX_TEXT_LEN 100 // Enum für die verschiedenen Modi typedef enum { MODE_BROWSE, MODE_DRAW } AppMode; typedef struct { Vector2 position; const GrayScaleImageSeries *series; const unsigned char *predictions; unsigned int currentIdx; Vector2 pixelSize; AppMode mode; GrayScaleImage drawingCanvas; unsigned char canvasPrediction; const NeuralNetwork *model; } MnistVisualization; typedef struct { char text[MAX_TEXT_LEN]; Vector2 position; unsigned int fontSize; Color color; } TextLabel; static TextLabel *createTextLabel(const char *text, unsigned int fontSize, Color color) { TextLabel *label = NULL; if(text != NULL) { label = (TextLabel *)malloc(sizeof(TextLabel)); if(label != NULL) { strncpy(label->text, text, MAX_TEXT_LEN); label->text[MAX_TEXT_LEN-1] = '\0'; label->position.x = 0; label->position.y = 0; label->fontSize = fontSize; label->color = color; } } return label; } static GrayScaleImage createCanvas(unsigned int width, unsigned int height) { GrayScaleImage canvas; canvas.width = width; canvas.height = height; canvas.buffer = (GrayScalePixelType *)calloc(width * height, sizeof(GrayScalePixelType)); return canvas; } static void clearCanvas(GrayScaleImage *canvas) { if(canvas != NULL && canvas->buffer != NULL) { memset(canvas->buffer, 0, canvas->width * canvas->height * sizeof(GrayScalePixelType)); } } static GrayScaleImage downsampleCanvas(const GrayScaleImage *largeCanvas, unsigned int targetWidth, unsigned int targetHeight) { GrayScaleImage smallImage = createCanvas(targetWidth, targetHeight); if(smallImage.buffer != NULL && largeCanvas != NULL && largeCanvas->buffer != NULL) { unsigned int scaleX = largeCanvas->width / targetWidth; unsigned int scaleY = largeCanvas->height / targetHeight; for(unsigned int y = 0; y < targetHeight; y++) { for(unsigned int x = 0; x < targetWidth; x++) { // Durchschnitt über den entsprechenden Bereich berechnen unsigned int sum = 0; unsigned int count = 0; for(unsigned int sy = 0; sy < scaleY; sy++) { for(unsigned int sx = 0; sx < scaleX; sx++) { unsigned int srcX = x * scaleX + sx; unsigned int srcY = y * scaleY + sy; if(srcX < largeCanvas->width && srcY < largeCanvas->height) { sum += largeCanvas->buffer[srcY * largeCanvas->width + srcX]; count++; } } } smallImage.buffer[y * targetWidth + x] = (unsigned char)(sum / count); } } } return smallImage; } static MnistVisualization *createVisualizationContainer(const GrayScaleImageSeries *series, const unsigned char predictions[], Vector2 size, const NeuralNetwork *model) { MnistVisualization *container = NULL; if(size.x > 0 && size.y > 0 && series != NULL && series->images != NULL && series->count > 0 && predictions != NULL) { container = (MnistVisualization *)calloc(1, sizeof(MnistVisualization)); if(container != NULL) { // Canvas ist 4x größer (112x112), also pixelSize anpassen Vector2 pixelSize = {(int)(size.x / (series->images[0].width * 4)), (int)(size.y / (series->images[0].height * 4))}; container->pixelSize = pixelSize; container->series = series; container->predictions = predictions; container->mode = MODE_BROWSE; container->model = model; // Canvas erstellen (4x größer als MNIST für feineres Zeichnen: 112x112 statt 28x28) unsigned int canvasSize = series->images[0].width * 4; container->drawingCanvas = createCanvas(canvasSize, canvasSize); container->canvasPrediction = 0; } } return container; } static void drawTextLabel(const TextLabel *label) { DrawText(label->text, label->position.x, label->position.y, label->fontSize, label->color); } static void drawDigit(const GrayScaleImage image, Vector2 position, Vector2 pixelSize) { float xOffset = 0; float yOffset = position.y - pixelSize.y; Color color = {0}; color.a = 255; for(int i = 0; i < image.width * image.height; i++) { if(i % image.width == 0) { xOffset = position.x; yOffset += pixelSize.y; } color.r = image.buffer[i]; color.b = image.buffer[i]; color.g = image.buffer[i]; DrawRectangle(xOffset, yOffset, pixelSize.x, pixelSize.y, color); xOffset += pixelSize.x; } } static void drawAll(const MnistVisualization *container, const TextLabel *navigationLabel, const TextLabel *predictionLabel, const TextLabel *modeLabel) { BeginDrawing(); ClearBackground(BLACK); if(container->mode == MODE_BROWSE) { // Im Browse-Modus: MNIST Bilder sind 28x28, aber pixelSize ist für 112x112 // Also pixelSize * 4 verwenden Vector2 browsePixelSize = {container->pixelSize.x * 4, container->pixelSize.y * 4}; drawDigit(container->series->images[container->currentIdx], container->position, browsePixelSize); } else // MODE_DRAW { // Im Draw-Modus: Canvas ist 112x112, pixelSize passt drawDigit(container->drawingCanvas, container->position, container->pixelSize); } drawTextLabel(navigationLabel); drawTextLabel(predictionLabel); drawTextLabel(modeLabel); EndDrawing(); } static int checkUserInput() { int inputResult = 0; if(IsKeyReleased(KEY_LEFT)) inputResult = -1; else if(IsKeyReleased(KEY_RIGHT)) inputResult = 1; return inputResult; } static void handleDrawing(MnistVisualization *container) { static Vector2 lastMousePos = {-1, -1}; if(IsMouseButtonDown(MOUSE_LEFT_BUTTON)) { Vector2 mousePos = GetMousePosition(); // Berechne welches Pixel geklickt wurde (pixelSize ist bereits für 112x112) int pixelX = (int)((mousePos.x - container->position.x) / container->pixelSize.x); int pixelY = (int)((mousePos.y - container->position.y) / container->pixelSize.y); // Wenn wir eine vorherige Position haben, zeichne eine Linie if(lastMousePos.x >= 0 && lastMousePos.y >= 0) { int lastPixelX = (int)((lastMousePos.x - container->position.x) / container->pixelSize.x); int lastPixelY = (int)((lastMousePos.y - container->position.y) / container->pixelSize.y); // Bresenham Linien-Algorithmus (vereinfacht) int dx = abs(pixelX - lastPixelX); int dy = abs(pixelY - lastPixelY); int sx = (lastPixelX < pixelX) ? 1 : -1; int sy = (lastPixelY < pixelY) ? 1 : -1; int err = dx - dy; int currentX = lastPixelX; int currentY = lastPixelY; while(1) { // Zeichne dünnen Pinsel für 112x112 Canvas if(currentX >= 0 && currentX < container->drawingCanvas.width && currentY >= 0 && currentY < container->drawingCanvas.height) { // Zentrum: 2x2 Pixel weiß (entspricht 0.5x0.5 auf 28x28) for(int dy = 0; dy <= 1; dy++) { for(int dx = 0; dx <= 1; dx++) { int nx = currentX + dx; int ny = currentY + dy; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; container->drawingCanvas.buffer[nidx] = 255; } } } // Ring 1: Direkte Nachbarn (sehr hell) int ring1[][2] = {{-1,0}, {-1,1}, {0,-1}, {2,0}, {2,1}, {0,2}, {1,2}, {1,-1}}; for(int i = 0; i < 8; i++) { int nx = currentX + ring1[i][0]; int ny = currentY + ring1[i][1]; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; if(container->drawingCanvas.buffer[nidx] < 200) { container->drawingCanvas.buffer[nidx] = 200; } } } // Ring 2: Weitere Nachbarn (mittel) int ring2[][2] = {{-2,0}, {-2,1}, {-1,-1}, {-1,2}, {0,-2}, {0,3}, {1,-2}, {1,3}, {2,-1}, {2,2}, {3,0}, {3,1}}; for(int i = 0; i < 12; i++) { int nx = currentX + ring2[i][0]; int ny = currentY + ring2[i][1]; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; if(container->drawingCanvas.buffer[nidx] < 140) { container->drawingCanvas.buffer[nidx] = 140; } } } // Ring 3: Äußere Nachbarn (dunkel) int ring3[][2] = {{-3,0}, {-3,1}, {-2,-1}, {-2,2}, {-1,-2}, {-1,3}, {0,-3}, {0,4}, {1,-3}, {1,4}, {2,-2}, {2,3}, {3,-1}, {3,2}, {4,0}, {4,1}}; for(int i = 0; i < 16; i++) { int nx = currentX + ring3[i][0]; int ny = currentY + ring3[i][1]; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; if(container->drawingCanvas.buffer[nidx] < 80) { container->drawingCanvas.buffer[nidx] = 80; } } } } if(currentX == pixelX && currentY == pixelY) break; int e2 = 2 * err; if(e2 > -dy) { err -= dy; currentX += sx; } if(e2 < dx) { err += dx; currentY += sy; } } } else { // Erstes Pixel (kein Vorgänger) if(pixelX >= 0 && pixelX < container->drawingCanvas.width && pixelY >= 0 && pixelY < container->drawingCanvas.height) { // Gleiche Logik wie oben for(int dy = 0; dy <= 1; dy++) { for(int dx = 0; dx <= 1; dx++) { int nx = pixelX + dx; int ny = pixelY + dy; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; container->drawingCanvas.buffer[nidx] = 255; } } } int ring1[][2] = {{-1,0}, {-1,1}, {0,-1}, {2,0}, {2,1}, {0,2}, {1,2}, {1,-1}}; for(int i = 0; i < 8; i++) { int nx = pixelX + ring1[i][0]; int ny = pixelY + ring1[i][1]; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; if(container->drawingCanvas.buffer[nidx] < 200) { container->drawingCanvas.buffer[nidx] = 200; } } } int ring2[][2] = {{-2,0}, {-2,1}, {-1,-1}, {-1,2}, {0,-2}, {0,3}, {1,-2}, {1,3}, {2,-1}, {2,2}, {3,0}, {3,1}}; for(int i = 0; i < 12; i++) { int nx = pixelX + ring2[i][0]; int ny = pixelY + ring2[i][1]; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; if(container->drawingCanvas.buffer[nidx] < 140) { container->drawingCanvas.buffer[nidx] = 140; } } } int ring3[][2] = {{-3,0}, {-3,1}, {-2,-1}, {-2,2}, {-1,-2}, {-1,3}, {0,-3}, {0,4}, {1,-3}, {1,4}, {2,-2}, {2,3}, {3,-1}, {3,2}, {4,0}, {4,1}}; for(int i = 0; i < 16; i++) { int nx = pixelX + ring3[i][0]; int ny = pixelY + ring3[i][1]; if(nx >= 0 && nx < container->drawingCanvas.width && ny >= 0 && ny < container->drawingCanvas.height) { int nidx = ny * container->drawingCanvas.width + nx; if(container->drawingCanvas.buffer[nidx] < 80) { container->drawingCanvas.buffer[nidx] = 80; } } } } } lastMousePos = mousePos; } else { // Maustaste losgelassen - Reset der letzten Position lastMousePos.x = -1; lastMousePos.y = -1; } } static void updatePredictionForCanvas(MnistVisualization *container) { if(container->model != NULL && container->series != NULL && container->series->images != NULL) { // Canvas von 112x112 auf 28x28 runterskalieren unsigned int targetSize = container->series->images[0].width; GrayScaleImage downsampled = downsampleCanvas(&container->drawingCanvas, targetSize, targetSize); if(downsampled.buffer != NULL) { unsigned char *prediction = predict(*container->model, &downsampled, 1); if(prediction != NULL) { container->canvasPrediction = prediction[0]; free(prediction); } // Downsampled Image aufräumen free(downsampled.buffer); } } } static void updateDisplayContainer(MnistVisualization *container, int updateDirection) { int newIndex = (int)container->currentIdx + updateDirection; if(newIndex < 0) newIndex = 0; else if(newIndex >= container->series->count) newIndex = container->series->count - 1; container->currentIdx = newIndex; } static void updatePredictionLabel(TextLabel *predictionLabel, unsigned char trueLabel, unsigned char predictedLabel, AppMode mode) { if(mode == MODE_BROWSE) { snprintf(predictionLabel->text, MAX_TEXT_LEN, "True label: %u\nPredicted label: %u", trueLabel, predictedLabel); } else // MODE_DRAW { snprintf(predictionLabel->text, MAX_TEXT_LEN, "Predicted label: %u", predictedLabel); } } static void updateModeLabel(TextLabel *modeLabel, AppMode mode) { if(mode == MODE_BROWSE) { snprintf(modeLabel->text, MAX_TEXT_LEN, "Mode: BROWSE | Press 'D' to draw"); } else // MODE_DRAW { snprintf(modeLabel->text, MAX_TEXT_LEN, "Mode: DRAW | Press 'B' to browse | Press 'C' to clear"); } } static void update(MnistVisualization *container, TextLabel *predictionLabel, TextLabel *modeLabel, int updateDirection) { // Mode-Wechsel if(IsKeyPressed(KEY_D)) { container->mode = MODE_DRAW; clearCanvas(&container->drawingCanvas); } else if(IsKeyPressed(KEY_B)) { container->mode = MODE_BROWSE; } // Canvas löschen im Draw-Mode if(container->mode == MODE_DRAW && IsKeyPressed(KEY_C)) { clearCanvas(&container->drawingCanvas); container->canvasPrediction = 0; } if(container->mode == MODE_BROWSE) { updateDisplayContainer(container, updateDirection); updatePredictionLabel(predictionLabel, container->series->labels[container->currentIdx], container->predictions[container->currentIdx], MODE_BROWSE); } else // MODE_DRAW { handleDrawing(container); // Prediction alle paar Frames aktualisieren (nicht bei jedem Frame für Performance) static int frameCounter = 0; frameCounter++; if(frameCounter % 10 == 0) { updatePredictionForCanvas(container); } updatePredictionLabel(predictionLabel, 0, container->canvasPrediction, MODE_DRAW); } updateModeLabel(modeLabel, container->mode); } void showMnist(unsigned int windowWidth, unsigned int windowHeight, const GrayScaleImageSeries *series, const unsigned char predictions[]) { // Model laden (für Draw-Modus) NeuralNetwork model = loadModel("mnist_model.info2"); const Vector2 windowSize = {windowWidth, windowHeight}; MnistVisualization *container = createVisualizationContainer(series, predictions, windowSize, &model); TextLabel *navigationLabel = createTextLabel("Use left and right key to navigate ...", 20, WHITE); TextLabel *predictionLabel = createTextLabel("", 20, WHITE); TextLabel *modeLabel = createTextLabel("", 20, WHITE); navigationLabel->position.x = windowSize.x - 400; // Rechts (mit Abstand) navigationLabel->position.y = windowSize.y - 30; // Ganz unten predictionLabel->position.x = 10; predictionLabel->position.y = windowSize.y - 50; modeLabel->position.x = 10; modeLabel->position.y = 10; if(container != NULL && navigationLabel != NULL && predictionLabel != NULL && modeLabel != NULL) { InitWindow(windowSize.x, windowSize.y, "MNIST Browser & Drawer"); SetTargetFPS(60); while (!WindowShouldClose()) { int updateDirection = checkUserInput(); update(container, predictionLabel, modeLabel, updateDirection); drawAll(container, navigationLabel, predictionLabel, modeLabel); } } CloseWindow(); // Cleanup if(container != NULL) { if(container->drawingCanvas.buffer != NULL) { free(container->drawingCanvas.buffer); } free(container); } free(navigationLabel); free(predictionLabel); free(modeLabel); clearModel(&model); }