570 lines
20 KiB
C
570 lines
20 KiB
C
#include <stdlib.h>
|
|
#include <stdio.h>
|
|
#include <string.h>
|
|
#include "mnistVisualization.h"
|
|
#include "imageInput.h"
|
|
#include "neuralNetwork.h"
|
|
|
|
// Resolve the Matrix name clash with raylib: rename raylib's Matrix while its header is included
|
|
#define Matrix RaylibMatrix
|
|
#include "raylib.h"
|
|
#undef Matrix
|
|
|
|
#define MAX_TEXT_LEN 100
|
|
|
|
// Interaction modes of the viewer. MODE_BROWSE is the zero value, so a
// zero-initialised container starts in browse mode.
typedef enum {
    MODE_BROWSE = 0, // step through the loaded image series
    MODE_DRAW = 1    // draw on the canvas with the mouse
} AppMode;
|
|
|
|
typedef struct
|
|
{
|
|
Vector2 position;
|
|
const GrayScaleImageSeries *series;
|
|
const unsigned char *predictions;
|
|
unsigned int currentIdx;
|
|
Vector2 pixelSize;
|
|
AppMode mode;
|
|
GrayScaleImage drawingCanvas;
|
|
unsigned char canvasPrediction;
|
|
const NeuralNetwork *model;
|
|
} MnistVisualization;
|
|
|
|
typedef struct
|
|
{
|
|
char text[MAX_TEXT_LEN];
|
|
Vector2 position;
|
|
unsigned int fontSize;
|
|
Color color;
|
|
} TextLabel;
|
|
|
|
// Allocates a label holding a copy of `text` (truncated to MAX_TEXT_LEN - 1
// characters), placed at position (0, 0).
//
// Returns NULL when `text` is NULL or the allocation fails. The label is a
// single malloc'ed object; the caller releases it with free().
static TextLabel *createTextLabel(const char *text, unsigned int fontSize, Color color)
{
    if (text == NULL) {
        return NULL;
    }

    TextLabel *created = (TextLabel *)malloc(sizeof *created);
    if (created == NULL) {
        return NULL;
    }

    // strncpy does not terminate on truncation, so terminate explicitly
    strncpy(created->text, text, MAX_TEXT_LEN);
    created->text[MAX_TEXT_LEN - 1] = '\0';

    created->position.x = 0;
    created->position.y = 0;
    created->fontSize = fontSize;
    created->color = color;

    return created;
}
|
|
|
|
static GrayScaleImage createCanvas(unsigned int width, unsigned int height)
|
|
{
|
|
GrayScaleImage canvas;
|
|
canvas.width = width;
|
|
canvas.height = height;
|
|
canvas.buffer = (GrayScalePixelType *)calloc(width * height, sizeof(GrayScalePixelType));
|
|
return canvas;
|
|
}
|
|
|
|
// Resets every pixel of the canvas to 0. Does nothing when `canvas` is NULL
// or has no buffer.
static void clearCanvas(GrayScaleImage *canvas)
{
    if (canvas == NULL || canvas->buffer == NULL) {
        return;
    }

    const size_t byteCount = canvas->width * canvas->height * sizeof(GrayScalePixelType);
    memset(canvas->buffer, 0, byteCount);
}
|
|
|
|
static GrayScaleImage downsampleCanvas(const GrayScaleImage *largeCanvas, unsigned int targetWidth, unsigned int targetHeight)
|
|
{
|
|
GrayScaleImage smallImage = createCanvas(targetWidth, targetHeight);
|
|
|
|
if(smallImage.buffer != NULL && largeCanvas != NULL && largeCanvas->buffer != NULL)
|
|
{
|
|
unsigned int scaleX = largeCanvas->width / targetWidth;
|
|
unsigned int scaleY = largeCanvas->height / targetHeight;
|
|
|
|
for(unsigned int y = 0; y < targetHeight; y++)
|
|
{
|
|
for(unsigned int x = 0; x < targetWidth; x++)
|
|
{
|
|
// Durchschnitt über den entsprechenden Bereich berechnen
|
|
unsigned int sum = 0;
|
|
unsigned int count = 0;
|
|
|
|
for(unsigned int sy = 0; sy < scaleY; sy++)
|
|
{
|
|
for(unsigned int sx = 0; sx < scaleX; sx++)
|
|
{
|
|
unsigned int srcX = x * scaleX + sx;
|
|
unsigned int srcY = y * scaleY + sy;
|
|
|
|
if(srcX < largeCanvas->width && srcY < largeCanvas->height)
|
|
{
|
|
sum += largeCanvas->buffer[srcY * largeCanvas->width + srcX];
|
|
count++;
|
|
}
|
|
}
|
|
}
|
|
|
|
smallImage.buffer[y * targetWidth + x] = (unsigned char)(sum / count);
|
|
}
|
|
}
|
|
}
|
|
|
|
return smallImage;
|
|
}
|
|
|
|
// Builds the viewer state for a window of `size` pixels.
//
// The drawing canvas is CANVAS_SCALE times larger than the first series image
// (square, based on the image width) and `pixelSize` is the on-screen size of
// one canvas pixel. `series`, `predictions` and `model` are stored as borrowed
// pointers; `model` may be NULL.
//
// Returns NULL when an argument is invalid, when the first image has a zero
// dimension, or when an allocation fails. The caller frees
// drawingCanvas.buffer and then the container itself.
static MnistVisualization *createVisualizationContainer(const GrayScaleImageSeries *series, const unsigned char predictions[], Vector2 size, const NeuralNetwork *model)
{
    enum { CANVAS_SCALE = 4 };

    const int inputsValid = size.x > 0 && size.y > 0 &&
                            series != NULL && series->images != NULL && series->count > 0 &&
                            predictions != NULL;
    if (!inputsValid) {
        return NULL;
    }

    const unsigned int imageWidth = series->images[0].width;
    const unsigned int imageHeight = series->images[0].height;
    if (imageWidth == 0 || imageHeight == 0) {
        // Would divide by zero when computing the pixel size below
        return NULL;
    }

    MnistVisualization *container = (MnistVisualization *)calloc(1, sizeof *container);
    if (container == NULL) {
        return NULL;
    }

    // Canvas is CANVAS_SCALE times larger than a series image for finer drawing
    const unsigned int canvasSize = imageWidth * CANVAS_SCALE;
    container->drawingCanvas = createCanvas(canvasSize, canvasSize);
    if (container->drawingCanvas.buffer == NULL) {
        // Without a canvas, draw mode would dereference a NULL buffer
        free(container);
        return NULL;
    }

    // Whole screen pixels per canvas pixel (fraction truncated)
    container->pixelSize.x = (int)(size.x / (imageWidth * CANVAS_SCALE));
    container->pixelSize.y = (int)(size.y / (imageHeight * CANVAS_SCALE));

    container->series = series;
    container->predictions = predictions;
    container->mode = MODE_BROWSE;
    container->model = model;
    container->canvasPrediction = 0;

    return container;
}
|
|
|
|
static void drawTextLabel(const TextLabel *label)
|
|
{
|
|
DrawText(label->text, label->position.x, label->position.y, label->fontSize, label->color);
|
|
}
|
|
|
|
// Draws `image` as a grid of opaque grey rectangles, one per pixel, starting
// at `position`; each rectangle is `pixelSize` large. Pixels are read
// row by row from image.buffer.
static void drawDigit(const GrayScaleImage image, Vector2 position, Vector2 pixelSize)
{
    Color shade = {0};
    shade.a = 255;

    // Starts one row above so the per-row increment lands on position.y first
    float rowY = position.y - pixelSize.y;
    unsigned int pixelIdx = 0;

    for (unsigned int row = 0; row < image.height; row++) {
        float cellX = position.x;
        rowY += pixelSize.y;

        for (unsigned int col = 0; col < image.width; col++) {
            // Same value on all three channels gives a grey level
            shade.r = image.buffer[pixelIdx];
            shade.b = image.buffer[pixelIdx];
            shade.g = image.buffer[pixelIdx];

            DrawRectangle(cellX, rowY, pixelSize.x, pixelSize.y, shade);

            cellX += pixelSize.x;
            pixelIdx++;
        }
    }
}
|
|
|
|
static void drawAll(const MnistVisualization *container, const TextLabel *navigationLabel, const TextLabel *predictionLabel, const TextLabel *modeLabel)
|
|
{
|
|
BeginDrawing();
|
|
ClearBackground(BLACK);
|
|
|
|
if(container->mode == MODE_BROWSE)
|
|
{
|
|
// Im Browse-Modus: MNIST Bilder sind 28x28, aber pixelSize ist für 112x112
|
|
// Also pixelSize * 4 verwenden
|
|
Vector2 browsePixelSize = {container->pixelSize.x * 4, container->pixelSize.y * 4};
|
|
drawDigit(container->series->images[container->currentIdx], container->position, browsePixelSize);
|
|
}
|
|
else // MODE_DRAW
|
|
{
|
|
// Im Draw-Modus: Canvas ist 112x112, pixelSize passt
|
|
drawDigit(container->drawingCanvas, container->position, container->pixelSize);
|
|
}
|
|
|
|
drawTextLabel(navigationLabel);
|
|
drawTextLabel(predictionLabel);
|
|
drawTextLabel(modeLabel);
|
|
|
|
EndDrawing();
|
|
}
|
|
|
|
static int checkUserInput()
|
|
{
|
|
int inputResult = 0;
|
|
|
|
if(IsKeyReleased(KEY_LEFT))
|
|
inputResult = -1;
|
|
else if(IsKeyReleased(KEY_RIGHT))
|
|
inputResult = 1;
|
|
|
|
return inputResult;
|
|
}
|
|
|
|
static void handleDrawing(MnistVisualization *container)
|
|
{
|
|
static Vector2 lastMousePos = {-1, -1};
|
|
|
|
if(IsMouseButtonDown(MOUSE_LEFT_BUTTON))
|
|
{
|
|
Vector2 mousePos = GetMousePosition();
|
|
|
|
// Berechne welches Pixel geklickt wurde (pixelSize ist bereits für 112x112)
|
|
int pixelX = (int)((mousePos.x - container->position.x) / container->pixelSize.x);
|
|
int pixelY = (int)((mousePos.y - container->position.y) / container->pixelSize.y);
|
|
|
|
// Wenn wir eine vorherige Position haben, zeichne eine Linie
|
|
if(lastMousePos.x >= 0 && lastMousePos.y >= 0)
|
|
{
|
|
int lastPixelX = (int)((lastMousePos.x - container->position.x) / container->pixelSize.x);
|
|
int lastPixelY = (int)((lastMousePos.y - container->position.y) / container->pixelSize.y);
|
|
|
|
// Bresenham Linien-Algorithmus (vereinfacht)
|
|
int dx = abs(pixelX - lastPixelX);
|
|
int dy = abs(pixelY - lastPixelY);
|
|
int sx = (lastPixelX < pixelX) ? 1 : -1;
|
|
int sy = (lastPixelY < pixelY) ? 1 : -1;
|
|
int err = dx - dy;
|
|
|
|
int currentX = lastPixelX;
|
|
int currentY = lastPixelY;
|
|
|
|
while(1)
|
|
{
|
|
// Zeichne dünnen Pinsel für 112x112 Canvas
|
|
if(currentX >= 0 && currentX < container->drawingCanvas.width &&
|
|
currentY >= 0 && currentY < container->drawingCanvas.height)
|
|
{
|
|
// Zentrum: 2x2 Pixel weiß (entspricht 0.5x0.5 auf 28x28)
|
|
for(int dy = 0; dy <= 1; dy++)
|
|
{
|
|
for(int dx = 0; dx <= 1; dx++)
|
|
{
|
|
int nx = currentX + dx;
|
|
int ny = currentY + dy;
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
container->drawingCanvas.buffer[nidx] = 255;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ring 1: Direkte Nachbarn (sehr hell)
|
|
int ring1[][2] = {{-1,0}, {-1,1}, {0,-1}, {2,0}, {2,1}, {0,2}, {1,2}, {1,-1}};
|
|
for(int i = 0; i < 8; i++)
|
|
{
|
|
int nx = currentX + ring1[i][0];
|
|
int ny = currentY + ring1[i][1];
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
if(container->drawingCanvas.buffer[nidx] < 200) {
|
|
container->drawingCanvas.buffer[nidx] = 200;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ring 2: Weitere Nachbarn (mittel)
|
|
int ring2[][2] = {{-2,0}, {-2,1}, {-1,-1}, {-1,2}, {0,-2}, {0,3},
|
|
{1,-2}, {1,3}, {2,-1}, {2,2}, {3,0}, {3,1}};
|
|
for(int i = 0; i < 12; i++)
|
|
{
|
|
int nx = currentX + ring2[i][0];
|
|
int ny = currentY + ring2[i][1];
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
if(container->drawingCanvas.buffer[nidx] < 140) {
|
|
container->drawingCanvas.buffer[nidx] = 140;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ring 3: Äußere Nachbarn (dunkel)
|
|
int ring3[][2] = {{-3,0}, {-3,1}, {-2,-1}, {-2,2}, {-1,-2}, {-1,3},
|
|
{0,-3}, {0,4}, {1,-3}, {1,4}, {2,-2}, {2,3},
|
|
{3,-1}, {3,2}, {4,0}, {4,1}};
|
|
for(int i = 0; i < 16; i++)
|
|
{
|
|
int nx = currentX + ring3[i][0];
|
|
int ny = currentY + ring3[i][1];
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
if(container->drawingCanvas.buffer[nidx] < 80) {
|
|
container->drawingCanvas.buffer[nidx] = 80;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if(currentX == pixelX && currentY == pixelY) break;
|
|
|
|
int e2 = 2 * err;
|
|
if(e2 > -dy)
|
|
{
|
|
err -= dy;
|
|
currentX += sx;
|
|
}
|
|
if(e2 < dx)
|
|
{
|
|
err += dx;
|
|
currentY += sy;
|
|
}
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// Erstes Pixel (kein Vorgänger)
|
|
if(pixelX >= 0 && pixelX < container->drawingCanvas.width &&
|
|
pixelY >= 0 && pixelY < container->drawingCanvas.height)
|
|
{
|
|
// Gleiche Logik wie oben
|
|
for(int dy = 0; dy <= 1; dy++)
|
|
{
|
|
for(int dx = 0; dx <= 1; dx++)
|
|
{
|
|
int nx = pixelX + dx;
|
|
int ny = pixelY + dy;
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
container->drawingCanvas.buffer[nidx] = 255;
|
|
}
|
|
}
|
|
}
|
|
|
|
int ring1[][2] = {{-1,0}, {-1,1}, {0,-1}, {2,0}, {2,1}, {0,2}, {1,2}, {1,-1}};
|
|
for(int i = 0; i < 8; i++)
|
|
{
|
|
int nx = pixelX + ring1[i][0];
|
|
int ny = pixelY + ring1[i][1];
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
if(container->drawingCanvas.buffer[nidx] < 200) {
|
|
container->drawingCanvas.buffer[nidx] = 200;
|
|
}
|
|
}
|
|
}
|
|
|
|
int ring2[][2] = {{-2,0}, {-2,1}, {-1,-1}, {-1,2}, {0,-2}, {0,3},
|
|
{1,-2}, {1,3}, {2,-1}, {2,2}, {3,0}, {3,1}};
|
|
for(int i = 0; i < 12; i++)
|
|
{
|
|
int nx = pixelX + ring2[i][0];
|
|
int ny = pixelY + ring2[i][1];
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
if(container->drawingCanvas.buffer[nidx] < 140) {
|
|
container->drawingCanvas.buffer[nidx] = 140;
|
|
}
|
|
}
|
|
}
|
|
|
|
int ring3[][2] = {{-3,0}, {-3,1}, {-2,-1}, {-2,2}, {-1,-2}, {-1,3},
|
|
{0,-3}, {0,4}, {1,-3}, {1,4}, {2,-2}, {2,3},
|
|
{3,-1}, {3,2}, {4,0}, {4,1}};
|
|
for(int i = 0; i < 16; i++)
|
|
{
|
|
int nx = pixelX + ring3[i][0];
|
|
int ny = pixelY + ring3[i][1];
|
|
if(nx >= 0 && nx < container->drawingCanvas.width &&
|
|
ny >= 0 && ny < container->drawingCanvas.height)
|
|
{
|
|
int nidx = ny * container->drawingCanvas.width + nx;
|
|
if(container->drawingCanvas.buffer[nidx] < 80) {
|
|
container->drawingCanvas.buffer[nidx] = 80;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
lastMousePos = mousePos;
|
|
}
|
|
else
|
|
{
|
|
// Maustaste losgelassen - Reset der letzten Position
|
|
lastMousePos.x = -1;
|
|
lastMousePos.y = -1;
|
|
}
|
|
}
|
|
|
|
// Runs the model on the current drawing and stores the predicted label in
// container->canvasPrediction.
//
// The canvas is first shrunk to a square whose edge equals the width of the
// first series image. Does nothing when the model or the series is missing,
// or when the shrunken image could not be allocated.
static void updatePredictionForCanvas(MnistVisualization *container)
{
    if (container->model == NULL || container->series == NULL || container->series->images == NULL) {
        return;
    }

    const unsigned int edge = container->series->images[0].width;
    GrayScaleImage reduced = downsampleCanvas(&container->drawingCanvas, edge, edge);
    if (reduced.buffer == NULL) {
        return;
    }

    // The array returned by predict() is released with free() below
    unsigned char *labels = predict(*container->model, &reduced, 1);
    if (labels != NULL) {
        container->canvasPrediction = labels[0];
        free(labels);
    }

    free(reduced.buffer);
}
|
|
|
|
// Moves the browse index by `updateDirection`, clamped to the valid range
// [0, series->count - 1].
static void updateDisplayContainer(MnistVisualization *container, int updateDirection)
{
    int requested = (int)container->currentIdx + updateDirection;

    if (requested < 0) {
        requested = 0;
    } else if (requested >= container->series->count) {
        // Stepped past the end: stay on the last image
        requested = container->series->count - 1;
    }

    container->currentIdx = requested;
}
|
|
|
|
// Writes the label text for the current prediction. Browse mode shows the true
// and the predicted label on two lines; any other mode shows only the
// predicted label and ignores `trueLabel`.
static void updatePredictionLabel(TextLabel *predictionLabel, unsigned char trueLabel, unsigned char predictedLabel, AppMode mode)
{
    const unsigned int predicted = (unsigned int)predictedLabel;

    if (mode != MODE_BROWSE) {
        snprintf(predictionLabel->text, MAX_TEXT_LEN, "Predicted label: %u", predicted);
        return;
    }

    snprintf(predictionLabel->text, MAX_TEXT_LEN, "True label: %u\nPredicted label: %u",
             (unsigned int)trueLabel, predicted);
}
|
|
|
|
// Writes the mode banner, including the key hints that apply in that mode.
static void updateModeLabel(TextLabel *modeLabel, AppMode mode)
{
    const char *banner = (mode == MODE_BROWSE)
        ? "Mode: BROWSE | Press 'D' to draw"
        : "Mode: DRAW | Press 'B' to browse | Press 'C' to clear";

    snprintf(modeLabel->text, MAX_TEXT_LEN, "%s", banner);
}
|
|
|
|
// Per-frame state update: handles mode switches ('D' = draw, 'B' = browse),
// clearing the canvas ('C' in draw mode), navigation or drawing depending on
// the mode, and refreshes the prediction and mode labels.
//
// `updateDirection` is the navigation step for browse mode (ignored in draw
// mode). In draw mode the model is only run every PREDICTION_INTERVAL_FRAMES
// draw-mode frames to keep the frame rate up.
static void update(MnistVisualization *container, TextLabel *predictionLabel, TextLabel *modeLabel, int updateDirection)
{
    enum { PREDICTION_INTERVAL_FRAMES = 10 };
    // Reset whenever it reaches the interval, so it never overflows
    static unsigned int framesSincePrediction = 0;

    // Mode switch
    if (IsKeyPressed(KEY_D)) {
        container->mode = MODE_DRAW;
        clearCanvas(&container->drawingCanvas);
        // The canvas is empty again; do not keep showing the old drawing's result
        container->canvasPrediction = 0;
    } else if (IsKeyPressed(KEY_B)) {
        container->mode = MODE_BROWSE;
    }

    // Clear the canvas in draw mode
    if (container->mode == MODE_DRAW && IsKeyPressed(KEY_C)) {
        clearCanvas(&container->drawingCanvas);
        container->canvasPrediction = 0;
    }

    if (container->mode == MODE_BROWSE) {
        updateDisplayContainer(container, updateDirection);
        updatePredictionLabel(predictionLabel,
                              container->series->labels[container->currentIdx],
                              container->predictions[container->currentIdx],
                              MODE_BROWSE);
    } else { // MODE_DRAW
        handleDrawing(container);

        framesSincePrediction++;
        if (framesSincePrediction >= PREDICTION_INTERVAL_FRAMES) {
            framesSincePrediction = 0;
            updatePredictionForCanvas(container);
        }

        updatePredictionLabel(predictionLabel, 0, container->canvasPrediction, MODE_DRAW);
    }

    updateModeLabel(modeLabel, container->mode);
}
|
|
|
|
void showMnist(unsigned int windowWidth, unsigned int windowHeight, const GrayScaleImageSeries *series, const unsigned char predictions[])
|
|
{
|
|
// Model laden (für Draw-Modus)
|
|
NeuralNetwork model = loadModel("mnist_model.info2");
|
|
|
|
const Vector2 windowSize = {windowWidth, windowHeight};
|
|
|
|
MnistVisualization *container = createVisualizationContainer(series, predictions, windowSize, &model);
|
|
TextLabel *navigationLabel = createTextLabel("Use left and right key to navigate ...", 20, WHITE);
|
|
TextLabel *predictionLabel = createTextLabel("", 20, WHITE);
|
|
TextLabel *modeLabel = createTextLabel("", 20, WHITE);
|
|
|
|
navigationLabel->position.x = windowSize.x - 400; // Rechts (mit Abstand)
|
|
navigationLabel->position.y = windowSize.y - 30; // Ganz unten
|
|
|
|
predictionLabel->position.x = 10;
|
|
predictionLabel->position.y = windowSize.y - 50;
|
|
|
|
modeLabel->position.x = 10;
|
|
modeLabel->position.y = 10;
|
|
|
|
if(container != NULL && navigationLabel != NULL && predictionLabel != NULL && modeLabel != NULL)
|
|
{
|
|
InitWindow(windowSize.x, windowSize.y, "MNIST Browser & Drawer");
|
|
|
|
SetTargetFPS(60);
|
|
|
|
while (!WindowShouldClose())
|
|
{
|
|
int updateDirection = checkUserInput();
|
|
update(container, predictionLabel, modeLabel, updateDirection);
|
|
drawAll(container, navigationLabel, predictionLabel, modeLabel);
|
|
}
|
|
}
|
|
|
|
CloseWindow();
|
|
|
|
// Cleanup
|
|
if(container != NULL)
|
|
{
|
|
if(container->drawingCanvas.buffer != NULL)
|
|
{
|
|
free(container->drawingCanvas.buffer);
|
|
}
|
|
free(container);
|
|
}
|
|
free(navigationLabel);
|
|
free(predictionLabel);
|
|
free(modeLabel);
|
|
clearModel(&model);
|
|
} |