Neuronetz_V2/mnistVisualization.c
Efe Kaan Turhan a8fbf37709 Neuronetz_Fertig
Please enter the commit message for your changes. Lines starting
2025-11-25 19:52:51 +01:00

238 lines
7.9 KiB
C

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "mnistVisualization.h"
/* Raylib-Fix */
#define Matrix RaylibMatrix
#include "raylib.h"
#undef Matrix
/* Capacity of TextLabel::text in bytes, including the terminating NUL. */
#define MAX_TEXT_LEN 100
/*
 * State of the MNIST viewer. It has two modes:
 *  - gallery mode: browse a series of images together with their precomputed
 *    predictions;
 *  - drawing mode: paint on a free-hand canvas that is classified by `model`.
 */
typedef struct
{
/* Screen position at which the current digit (gallery image or canvas) is drawn.
   createVisualizationContainer leaves it at (0, 0) via calloc. */
Vector2 position;
/* Gallery */
/* Images to browse; borrowed from the caller (showMnist does not free it). */
const GrayScaleImageSeries *series;
/* Precomputed predicted label per image, indexed like series->images; borrowed. */
const unsigned char *predictions;
/* Index of the image shown in gallery mode; clamped to [0, series->count - 1]. */
unsigned int currentIdx;
/* Drawing */
/* Free-hand canvas with the same width/height as series->images[0]; its buffer
   is owned by this struct and freed in showMnist. */
GrayScaleImage canvas;
/* Latest predictLive result for the canvas; reset to 0 when the canvas is cleared. */
unsigned char canvasPrediction;
/* Non-zero while in drawing mode; toggled with TAB. */
int isDrawingMode;
/* On-screen size of one image pixel (whole screen pixels). */
Vector2 pixelSize;
/* Network passed to predictLive for canvas classification (held by value). */
NeuralNetwork model;
} MnistVisualization;
/* A piece of text rendered with raylib's DrawText. */
typedef struct
{
/* Label contents; always NUL-terminated. Longer input is truncated to
   MAX_TEXT_LEN - 1 characters. May contain '\n' (see updatePredictionLabel). */
char text[MAX_TEXT_LEN];
/* Screen position passed to DrawText as posX/posY. */
Vector2 position;
/* Font size passed to DrawText. */
unsigned int fontSize;
/* Text colour passed to DrawText. */
Color color;
} TextLabel;
/*
 * Allocates a text label positioned at (0, 0).
 *
 * text     - initial contents; copied and truncated to MAX_TEXT_LEN - 1 chars.
 * fontSize - font size used when drawing.
 * color    - text colour used when drawing.
 *
 * Returns the new label (caller frees it with free()), or NULL if `text` is
 * NULL or the allocation fails.
 */
static TextLabel *createTextLabel(const char *text, unsigned int fontSize, Color color)
{
    if (text == NULL)
        return NULL;

    TextLabel *label = malloc(sizeof *label);
    if (label == NULL)
        return NULL;

    /* snprintf always NUL-terminates, unlike strncpy, so no manual terminator is needed. */
    snprintf(label->text, sizeof label->text, "%s", text);
    label->position.x = 0;
    label->position.y = 0;
    label->fontSize = fontSize;
    label->color = color;
    return label;
}
/*
 * Allocates and initialises the viewer state used by showMnist.
 *
 * The on-screen cell size is derived from the window size and the dimensions
 * of the first image in the series. The drawing canvas gets those same
 * dimensions and starts all black.
 *
 * series      - images to browse; borrowed, not copied. Must be non-empty.
 * predictions - predicted label per image, indexed like series->images; borrowed.
 * size        - window size in screen pixels; both components must be > 0.
 * model       - network used for live canvas predictions (stored by value).
 *
 * Returns a zero-initialised container (gallery mode, index 0, position (0,0)).
 * Returns NULL if an argument is invalid, the first image has a zero dimension,
 * or an allocation fails. The caller frees canvas.buffer and then the container.
 */
static MnistVisualization *createVisualizationContainer(const GrayScaleImageSeries *series, const unsigned char predictions[], Vector2 size, NeuralNetwork model)
{
    const int argsValid = size.x > 0 && size.y > 0 && series != NULL && series->images != NULL
                          && series->count > 0 && predictions != NULL;
    if (!argsValid)
        return NULL;

    const unsigned int imgW = series->images[0].width;
    const unsigned int imgH = series->images[0].height;
    /* A zero-sized image would make the cell-size division below divide by zero. */
    if (imgW == 0 || imgH == 0)
        return NULL;

    MnistVisualization *container = calloc(1, sizeof *container);
    if (container == NULL)
        return NULL;

    /* Whole-pixel cells, clamped to at least 1. A window smaller than the image
       still renders, and paintOnCanvas never divides by a zero cell size. */
    const int cellW = (int)(size.x / imgW);
    const int cellH = (int)(size.y / imgH);
    container->pixelSize.x = (float)(cellW > 0 ? cellW : 1);
    container->pixelSize.y = (float)(cellH > 0 ? cellH : 1);

    container->series = series;
    container->predictions = predictions;
    container->model = model;
    container->isDrawingMode = 0;

    container->canvas.width = imgW;
    container->canvas.height = imgH;
    /* Widen before multiplying so the element count cannot wrap in unsigned int. */
    container->canvas.buffer = calloc((size_t)imgW * imgH, sizeof(GrayScalePixelType));
    if (container->canvas.buffer == NULL)
    {
        /* Without a canvas, drawing mode would read and write through NULL. */
        free(container);
        return NULL;
    }
    return container;
}
/* Renders a label at its stored position with its font size and colour. */
static void drawTextLabel(const TextLabel *label)
{
    const int x = (int)label->position.x;
    const int y = (int)label->position.y;
    DrawText(label->text, x, y, (int)label->fontSize, label->color);
}
/*
 * Draws a grayscale image as a grid of filled rectangles.
 *
 * image     - image to draw; buffer is read row by row (width values per row).
 * position  - screen coordinates of the top-left cell.
 * pixelSize - on-screen width/height of one image pixel.
 *
 * Each pixel value is used for the R, G and B channels (opaque alpha).
 */
static void drawDigit(const GrayScaleImage image, Vector2 position, Vector2 pixelSize)
{
    Color shade = {0};
    shade.a = 255;

    int idx = 0;
    /* Start one row above so the per-row step lands the first row on position.y. */
    float cellY = position.y - pixelSize.y;
    for (int row = 0; row < image.height; row++)
    {
        cellY += pixelSize.y;
        float cellX = position.x;
        for (int col = 0; col < image.width; col++, idx++)
        {
            shade.r = shade.g = shade.b = image.buffer[idx];
            DrawRectangle(cellX, cellY, pixelSize.x, pixelSize.y, shade);
            cellX += pixelSize.x;
        }
    }
}
static void drawAll(const MnistVisualization *container, const TextLabel *navLabel, const TextLabel *predLabel)
{
BeginDrawing();
ClearBackground(BLACK);
if (container->isDrawingMode)
{
drawDigit(container->canvas, container->position, container->pixelSize);
DrawText("DRAWING MODE", 10, 10, 20, RED);
DrawText("Left Click: Draw | 'C': Clear | TAB: Gallery", 10, 35, 20, LIGHTGRAY);
/* Text nach unten versetzt (Y=70), damit sichtbar */
DrawText(TextFormat("Predicted: %u", container->canvasPrediction), 10, 70, 50, GREEN);
}
else
{
drawDigit(container->series->images[container->currentIdx], container->position, container->pixelSize);
DrawText("GALLERY MODE", 10, 10, 20, BLUE);
DrawText("TAB: Drawing Mode", 10, 35, 20, LIGHTGRAY);
drawTextLabel(navLabel);
drawTextLabel(predLabel);
}
EndDrawing();
}
/* Resets every canvas pixel to 0 (black). Does nothing if the canvas has no buffer. */
static void clearCanvas(MnistVisualization *container)
{
    if (container->canvas.buffer == NULL)
        return;

    const size_t byteCount = container->canvas.width * container->canvas.height * sizeof(GrayScalePixelType);
    memset(container->canvas.buffer, 0, byteCount);
}
/*
 * Paints white (255) into the canvas under the mouse cursor. The brush is a
 * square of radius 1, i.e. a 3x3 block of image pixels, clipped to the canvas.
 *
 * The mouse position is mapped to canvas cells relative to container->position.
 * drawDigit uses that same origin when rendering the canvas, so painting stays
 * aligned with what is shown.
 */
static void paintOnCanvas(MnistVisualization *container)
{
    const int brushRadius = 1;

    /* A missing canvas cannot be painted. A zero/invalid cell size would turn the
       division below into inf/NaN, and converting that to int is undefined. */
    if (container->canvas.buffer == NULL || !(container->pixelSize.x > 0) || !(container->pixelSize.y > 0))
        return;

    const Vector2 mouse = GetMousePosition();
    const int gridX = (int)((mouse.x - container->position.x) / container->pixelSize.x);
    const int gridY = (int)((mouse.y - container->position.y) / container->pixelSize.y);

    for (int dy = -brushRadius; dy <= brushRadius; dy++)
    {
        for (int dx = -brushRadius; dx <= brushRadius; dx++)
        {
            const int px = gridX + dx;
            const int py = gridY + dy;
            if (px >= 0 && px < container->canvas.width && py >= 0 && py < container->canvas.height)
            {
                /* 255 = white stroke on the black canvas. */
                container->canvas.buffer[py * container->canvas.width + px] = 255;
            }
        }
    }
}
/*
 * Classifies the current canvas with predictLive and stores the result in
 * container->canvasPrediction. predictLive is used for hand-drawn input here,
 * as opposed to the precomputed gallery predictions.
 *
 * The returned array is heap-allocated and freed here. Only its first element
 * is used (presumably one label per input image — confirm against predictLive).
 * If predictLive returns NULL, the previous prediction is kept.
 */
static void runLivePrediction(MnistVisualization *container)
{
    unsigned char *predicted = predictLive(container->model, container->canvas);
    if (predicted == NULL)
        return;

    container->canvasPrediction = *predicted;
    free(predicted);
}
/*
 * Polls keyboard and mouse input for the current frame.
 *
 * TAB toggles between gallery and drawing mode.
 * In drawing mode:
 *  - holding the left mouse button paints and re-runs the live prediction;
 *  - 'C' clears the canvas and resets the prediction to 0.
 * In gallery mode, LEFT/RIGHT request navigation.
 *
 * Returns -1 (previous image), +1 (next image) or 0 (no gallery navigation).
 */
static int checkUserInput(MnistVisualization *container)
{
    if (IsKeyReleased(KEY_TAB))
    {
        container->isDrawingMode = !container->isDrawingMode;
        return 0;
    }

    if (!container->isDrawingMode)
    {
        if (IsKeyReleased(KEY_LEFT))
            return -1;
        if (IsKeyReleased(KEY_RIGHT))
            return 1;
        return 0;
    }

    if (IsMouseButtonDown(MOUSE_LEFT_BUTTON))
    {
        paintOnCanvas(container);
        runLivePrediction(container);
    }
    if (IsKeyPressed(KEY_C))
    {
        clearCanvas(container);
        container->canvasPrediction = 0;
    }
    return 0;
}
/*
 * Moves the gallery index by updateDirection. The result is clamped to the
 * valid range [0, series->count - 1], with no wrap-around.
 */
static void updateDisplayContainer(MnistVisualization *container, int updateDirection)
{
    const int target = (int)container->currentIdx + updateDirection;

    if (target >= 0 && target < container->series->count)
    {
        container->currentIdx = target;
        return;
    }

    /* Out of range: stick to the nearest end of the series. */
    container->currentIdx = (target < 0) ? 0 : container->series->count - 1;
}
/* Rewrites the label text as two lines: the true label and the predicted label. */
static void updatePredictionLabel(TextLabel *predictionLabel, unsigned char trueLabel, unsigned char predictedLabel)
{
    const unsigned int expected = trueLabel;
    const unsigned int predicted = predictedLabel;
    snprintf(predictionLabel->text, sizeof predictionLabel->text,
             "True label: %u\nPredicted label: %u", expected, predicted);
}
/*
 * Per-frame state update. In gallery mode it applies the navigation direction
 * and refreshes the prediction label for the selected image. Drawing mode has
 * nothing to update here.
 */
static void update(MnistVisualization *container, TextLabel *predictionLabel, int updateDirection)
{
    if (container->isDrawingMode)
        return;

    updateDisplayContainer(container, updateDirection);

    const unsigned int idx = container->currentIdx;
    updatePredictionLabel(predictionLabel, container->series->labels[idx], container->predictions[idx]);
}
/*
 * Opens a raylib window and runs the interactive MNIST viewer until the window
 * is closed.
 *
 * Gallery mode (the default): LEFT/RIGHT step through `series`. The true label
 * and the matching entry of `predictions` are shown.
 * Drawing mode (toggled with TAB): paint with the left mouse button and press
 * 'C' to clear. The canvas is classified with `model` while painting.
 *
 * `series` and `predictions` are borrowed and are not freed here. If the
 * arguments are rejected or an allocation fails, no window is opened. The
 * function then only releases what it allocated.
 */
void showMnist(unsigned int windowWidth, unsigned int windowHeight, const GrayScaleImageSeries *series, const unsigned char predictions[], NeuralNetwork model)
{
    const Vector2 windowSize = {(float)windowWidth, (float)windowHeight};
    MnistVisualization *container = createVisualizationContainer(series, predictions, windowSize, model);
    TextLabel *navigationLabel = createTextLabel("Use left and right key to navigate ...", 20, WHITE);
    TextLabel *predictionLabel = createTextLabel("", 20, WHITE);

    /* Only touch the labels (and the window) once every allocation is known to
       have succeeded. */
    if (container != NULL && navigationLabel != NULL && predictionLabel != NULL)
    {
        navigationLabel->position.x = 10;
        navigationLabel->position.y = windowSize.y - 90;
        predictionLabel->position.x = 10;
        predictionLabel->position.y = windowSize.y - 60;

        InitWindow((int)windowSize.x, (int)windowSize.y, "MNIST Browser & Painter");
        SetTargetFPS(60);
        while (!WindowShouldClose())
        {
            const int updateDirection = checkUserInput(container);
            update(container, predictionLabel, updateDirection);
            drawAll(container, navigationLabel, predictionLabel);
        }
        /* Close only a window that was actually opened. */
        CloseWindow();
    }

    if (container != NULL)
    {
        free(container->canvas.buffer);
        free(container);
    }
    free(navigationLabel);
    free(predictionLabel);
}