238 lines
7.9 KiB
C
238 lines
7.9 KiB
C
#include <stdlib.h>
|
|
#include <stdio.h>
|
|
#include <string.h>
|
|
#include "mnistVisualization.h"
|
|
|
|
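/* Raylib fix: raylib.h defines a type named Matrix; rename it to RaylibMatrix while the header is included (presumably to avoid a clash with a project type of the same name). */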
|
|
#define Matrix RaylibMatrix
|
|
#include "raylib.h"
|
|
#undef Matrix
|
|
|
|
#define MAX_TEXT_LEN 100
|
|
|
|
typedef struct
|
|
{
|
|
Vector2 position;
|
|
/* Galerie */
|
|
const GrayScaleImageSeries *series;
|
|
const unsigned char *predictions;
|
|
unsigned int currentIdx;
|
|
/* Malen */
|
|
GrayScaleImage canvas;
|
|
unsigned char canvasPrediction;
|
|
int isDrawingMode;
|
|
|
|
Vector2 pixelSize;
|
|
NeuralNetwork model;
|
|
} MnistVisualization;
|
|
|
|
typedef struct
|
|
{
|
|
char text[MAX_TEXT_LEN];
|
|
Vector2 position;
|
|
unsigned int fontSize;
|
|
Color color;
|
|
} TextLabel;
|
|
|
|
/*
 * Allocates a text label holding a copy of `text` (cut off after
 * MAX_TEXT_LEN - 1 characters) and positioned at (0, 0).
 *
 * Returns NULL when `text` is NULL or the allocation fails. A non-NULL
 * result is released by the caller with free().
 */
static TextLabel *createTextLabel(const char *text, unsigned int fontSize, Color color)
{
    TextLabel *created;

    if(text == NULL)
        return NULL;

    created = malloc(sizeof *created);
    if(created == NULL)
        return NULL;

    /* strncpy does not terminate on truncation, so the last byte is forced. */
    strncpy(created->text, text, MAX_TEXT_LEN);
    created->text[MAX_TEXT_LEN - 1] = '\0';

    created->position.x = 0;
    created->position.y = 0;
    created->fontSize = fontSize;
    created->color = color;

    return created;
}
|
|
|
|
/*
 * Builds the viewer state for `series` and allocates a blank drawing canvas
 * with the dimensions of the first image of the series.
 *
 * size        - area available for drawing; both components must be > 0
 * series      - images to browse; must hold at least one image
 * predictions - predicted labels, read with the same index as the images
 * model       - network stored for live predictions on the canvas
 *
 * Returns NULL when an argument is invalid, when the first image has a zero
 * width or height, or when an allocation fails. A non-NULL container always
 * carries a non-NULL canvas buffer; the caller frees the canvas buffer and
 * then the container.
 */
static MnistVisualization *createVisualizationContainer(const GrayScaleImageSeries *series, const unsigned char predictions[], Vector2 size, NeuralNetwork model)
{
    MnistVisualization *container = NULL;

    if(size.x > 0 && size.y > 0 && series != NULL && series->images != NULL && series->count > 0 && predictions != NULL)
    {
        const unsigned int imgW = series->images[0].width;
        const unsigned int imgH = series->images[0].height;

        /* A zero dimension would turn the pixel size computation into a division by zero. */
        if(imgW == 0 || imgH == 0)
            return NULL;

        /* calloc leaves position, currentIdx, canvasPrediction and isDrawingMode at zero. */
        container = calloc(1, sizeof *container);
        if(container != NULL)
        {
            /* Whole screen pixels per image pixel. */
            Vector2 pixelSize = {(int)(size.x / imgW), (int)(size.y / imgH)};

            container->pixelSize = pixelSize;
            container->series = series;
            container->predictions = predictions;
            container->model = model;
            container->isDrawingMode = 0;

            container->canvas.width = imgW;
            container->canvas.height = imgH;
            /* The count is widened first so the product cannot wrap in unsigned int. */
            container->canvas.buffer = calloc((size_t)imgW * imgH, sizeof(GrayScalePixelType));

            /* Without a canvas the drawing mode cannot work, so creation fails as a whole. */
            if(container->canvas.buffer == NULL)
            {
                free(container);
                container = NULL;
            }
        }
    }

    return container;
}
|
|
|
|
static void drawTextLabel(const TextLabel *label)
|
|
{
|
|
DrawText(label->text, label->position.x, label->position.y, label->fontSize, label->color);
|
|
}
|
|
|
|
/*
 * Draws `image` as a grid of gray rectangles.
 *
 * image     - image to draw; nothing is drawn when its buffer is NULL
 * position  - top-left corner of the grid on screen
 * pixelSize - width and height of the rectangle drawn per image pixel
 */
static void drawDigit(const GrayScaleImage image, Vector2 position, Vector2 pixelSize)
{
    Color color = {0};
    float yOffset = position.y;

    /* The canvas buffer may be missing after a failed allocation. */
    if(image.buffer == NULL)
        return;

    color.a = 255;

    /* Rows and columns are walked separately, so no width * height product can wrap. */
    for(unsigned int row = 0; row < image.height; row++)
    {
        float xOffset = position.x;

        for(unsigned int col = 0; col < image.width; col++)
        {
            const unsigned char level = image.buffer[(size_t)row * image.width + col];

            /* The same value on all three channels yields a shade of gray. */
            color.r = level;
            color.g = level;
            color.b = level;
            DrawRectangle(xOffset, yOffset, pixelSize.x, pixelSize.y, color);
            xOffset += pixelSize.x;
        }
        yOffset += pixelSize.y;
    }
}
|
|
|
|
static void drawAll(const MnistVisualization *container, const TextLabel *navLabel, const TextLabel *predLabel)
|
|
{
|
|
BeginDrawing();
|
|
ClearBackground(BLACK);
|
|
|
|
if (container->isDrawingMode)
|
|
{
|
|
drawDigit(container->canvas, container->position, container->pixelSize);
|
|
DrawText("DRAWING MODE", 10, 10, 20, RED);
|
|
DrawText("Left Click: Draw | 'C': Clear | TAB: Gallery", 10, 35, 20, LIGHTGRAY);
|
|
/* Text nach unten versetzt (Y=70), damit sichtbar */
|
|
DrawText(TextFormat("Predicted: %u", container->canvasPrediction), 10, 70, 50, GREEN);
|
|
}
|
|
else
|
|
{
|
|
drawDigit(container->series->images[container->currentIdx], container->position, container->pixelSize);
|
|
DrawText("GALLERY MODE", 10, 10, 20, BLUE);
|
|
DrawText("TAB: Drawing Mode", 10, 35, 20, LIGHTGRAY);
|
|
drawTextLabel(navLabel);
|
|
drawTextLabel(predLabel);
|
|
}
|
|
EndDrawing();
|
|
}
|
|
|
|
/* Sets every canvas pixel to 0; does nothing when the canvas has no buffer. */
static void clearCanvas(MnistVisualization *container)
{
    size_t byteCount;

    if(container->canvas.buffer == NULL)
        return;

    byteCount = container->canvas.width * container->canvas.height * sizeof(GrayScalePixelType);
    memset(container->canvas.buffer, 0, byteCount);
}
|
|
|
|
/*
 * Paints white (255) onto the canvas around the pixel under the mouse cursor.
 *
 * The mouse position is taken relative to container->position, the origin
 * drawDigit uses for the canvas. Nothing is painted when the canvas has no
 * buffer or when the on-screen pixel size is not positive, because the
 * mapping from screen to grid would then divide by zero.
 */
static void paintOnCanvas(MnistVisualization *container)
{
    const int brushSize = 1; /* 1 paints a 3x3 block of canvas pixels */

    if(container->canvas.buffer == NULL)
        return;
    if(!(container->pixelSize.x > 0) || !(container->pixelSize.y > 0))
        return;

    const Vector2 mouse = GetMousePosition();
    const int gridX = (int)((mouse.x - container->position.x) / container->pixelSize.x);
    const int gridY = (int)((mouse.y - container->position.y) / container->pixelSize.y);

    for(int dy = -brushSize; dy <= brushSize; dy++)
    {
        for(int dx = -brushSize; dx <= brushSize; dx++)
        {
            const int px = gridX + dx;
            const int py = gridY + dy;

            /* Brush cells that fall outside the canvas are skipped. */
            if(px >= 0 && px < container->canvas.width && py >= 0 && py < container->canvas.height)
            {
                const size_t idx = (size_t)py * container->canvas.width + (size_t)px;

                /* Draw with 255 (white). */
                container->canvas.buffer[idx] = 255;
            }
        }
    }
}
|
|
|
|
/*
 * Classifies the current canvas with predictLive and stores the first entry
 * of the result as canvasPrediction. The stored prediction is left unchanged
 * when predictLive returns NULL.
 */
static void runLivePrediction(MnistVisualization *container)
{
    /* Drawings go through predictLive. */
    unsigned char *prediction = predictLive(container->model, container->canvas);

    if(prediction == NULL)
        return;

    container->canvasPrediction = prediction[0];
    free(prediction);
}
|
|
|
|
/*
 * Processes the input of one frame.
 *
 * TAB toggles between gallery and drawing mode and ends the handling of the
 * frame. In drawing mode the left mouse button paints and triggers a new
 * prediction, and 'C' clears the canvas and resets the prediction to 0.
 *
 * Returns the gallery navigation step: -1 when the left arrow key was
 * released, 1 for the right arrow key, 0 in every other case.
 */
static int checkUserInput(MnistVisualization *container)
{
    int direction = 0;

    if(IsKeyReleased(KEY_TAB))
    {
        container->isDrawingMode = !container->isDrawingMode;
    }
    else if(container->isDrawingMode)
    {
        if(IsMouseButtonDown(MOUSE_LEFT_BUTTON))
        {
            paintOnCanvas(container);
            runLivePrediction(container);
        }

        if(IsKeyPressed(KEY_C))
        {
            clearCanvas(container);
            container->canvasPrediction = 0;
        }
    }
    else if(IsKeyReleased(KEY_LEFT))
    {
        direction = -1;
    }
    else if(IsKeyReleased(KEY_RIGHT))
    {
        direction = 1;
    }

    return direction;
}
|
|
|
|
/*
 * Moves the gallery index by `updateDirection` and clamps the result to the
 * valid range 0 .. series->count - 1.
 */
static void updateDisplayContainer(MnistVisualization *container, int updateDirection)
{
    const int requested = (int)container->currentIdx + updateDirection;
    int clamped = requested;

    if(requested < 0)
    {
        clamped = 0;
    }
    else if(requested >= container->series->count)
    {
        clamped = container->series->count - 1;
    }

    container->currentIdx = clamped;
}
|
|
|
|
/* Rewrites the label text so that it shows the true and the predicted label on two lines. */
static void updatePredictionLabel(TextLabel *predictionLabel, unsigned char trueLabel, unsigned char predictedLabel)
{
    /* The text member is an array of MAX_TEXT_LEN chars, so sizeof yields that capacity. */
    snprintf(predictionLabel->text,
             sizeof predictionLabel->text,
             "True label: %u\nPredicted label: %u",
             trueLabel,
             predictedLabel);
}
|
|
|
|
/*
 * Applies a navigation step in gallery mode and refreshes the prediction
 * label for the image that is selected afterwards. Drawing mode is left
 * untouched.
 */
static void update(MnistVisualization *container, TextLabel *predictionLabel, int updateDirection)
{
    unsigned int shown;

    if(container->isDrawingMode)
        return;

    updateDisplayContainer(container, updateDirection);

    shown = container->currentIdx;
    updatePredictionLabel(predictionLabel, container->series->labels[shown], container->predictions[shown]);
}
|
|
|
|
/*
 * Opens a window that lets the user browse `series` together with
 * `predictions` and draw digits that are classified live with `model`.
 * The call returns once the window has been closed.
 *
 * windowWidth, windowHeight - window size in pixels
 * series                    - images to browse, with their true labels
 * predictions               - predicted label per image of the series
 * model                     - network used for the drawing mode
 *
 * No window is opened when the viewer state or one of the labels cannot be
 * created. Everything allocated here is released before returning.
 */
void showMnist(unsigned int windowWidth, unsigned int windowHeight, const GrayScaleImageSeries *series, const unsigned char predictions[], NeuralNetwork model)
{
    const Vector2 windowSize = {windowWidth, windowHeight};
    MnistVisualization *container = createVisualizationContainer(series, predictions, windowSize, model);
    TextLabel *navigationLabel = createTextLabel("Use left and right key to navigate ...", 20, WHITE);
    TextLabel *predictionLabel = createTextLabel("", 20, WHITE);

    if(container != NULL && navigationLabel != NULL && predictionLabel != NULL)
    {
        /* The labels are only touched once they are known to exist. */
        predictionLabel->position.x = 10;
        predictionLabel->position.y = windowSize.y - 60;
        navigationLabel->position.x = 10;
        navigationLabel->position.y = windowSize.y - 90;

        InitWindow(windowSize.x, windowSize.y, "MNIST Browser & Painter");
        SetTargetFPS(60);

        while(!WindowShouldClose())
        {
            int updateDirection = checkUserInput(container);
            update(container, predictionLabel, updateDirection);
            drawAll(container, navigationLabel, predictionLabel);
        }

        /* Paired with InitWindow above; never called when no window was opened. */
        CloseWindow();
    }

    if(container != NULL)
    {
        free(container->canvas.buffer);
        free(container);
    }
    free(navigationLabel);
    free(predictionLabel);
}