127 lines
4.0 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "matrix.h"
#include <limits.h>
#include <stdlib.h>
/*
 * Create a rows x cols matrix whose elements are all zero-initialised.
 *
 * Returns an empty matrix (rows == 0, cols == 0, buffer == NULL) when:
 *   - either dimension is 0,
 *   - rows * cols does not fit in an unsigned int, or
 *   - the allocation fails.
 * The caller owns the returned buffer and releases it with clearMatrix().
 */
Matrix createMatrix(unsigned int rows, unsigned int cols)
{
    Matrix m;
    m.rows = 0;
    m.cols = 0;
    m.buffer = NULL;

    if (rows == 0 || cols == 0) {
        return m;
    }

    /* The accessors and arithmetic in this file compute element indices as
       row * cols + col in unsigned int arithmetic. Refusing element counts
       above UINT_MAX keeps every such index from wrapping. It also stops the
       count passed to calloc from wrapping into a too-small buffer. */
    if (cols > UINT_MAX / rows) {
        return m;
    }

    m.buffer = calloc((size_t)rows * cols, sizeof *m.buffer);
    if (m.buffer != NULL) {
        m.rows = rows;
        m.cols = cols;
    }
    return m;
}
/*
 * Free the element buffer of *matrix and reset it to the empty matrix
 * (buffer == NULL, rows == 0, cols == 0).
 * Does nothing when matrix is NULL or its buffer is already NULL, so calling
 * it twice on the same matrix is harmless.
 */
void clearMatrix(Matrix *matrix)
{
    if (matrix != NULL && matrix->buffer != NULL) {
        free(matrix->buffer);
        matrix->buffer = NULL;
        matrix->cols = 0;
        matrix->rows = 0;
    }
}
/*
 * Store value at position (rowIdx, colIdx).
 * The matrix is passed by value, but the copy shares the caller's buffer
 * pointer, so the write is visible through the caller's matrix.
 * Out-of-range indices are silently ignored.
 */
void setMatrixAt(MatrixType value, Matrix matrix, unsigned int rowIdx, unsigned int colIdx)
{
    const int inBounds = rowIdx < matrix.rows && colIdx < matrix.cols;
    if (inBounds) {
        /* Row-major storage: row rowIdx starts rowIdx * cols elements in. */
        MatrixType *rowStart = matrix.buffer + rowIdx * matrix.cols;
        rowStart[colIdx] = value;
    }
}
/*
 * Read the element at position (rowIdx, colIdx) of a row-major matrix.
 * Returns UNDEFINED_MATRIX_VALUE when either index is out of range.
 */
MatrixType getMatrixAt(const Matrix matrix, unsigned int rowIdx, unsigned int colIdx)
{
    if (rowIdx < matrix.rows && colIdx < matrix.cols) {
        return matrix.buffer[rowIdx * matrix.cols + colIdx];
    }
    return UNDEFINED_MATRIX_VALUE;
}
// Addition (mit Broadcasting-Unterstützung für Bias)
Matrix add(const Matrix matrix1, const Matrix matrix2)
{
Matrix result;
// Fall 1: Exakte Dimensionen (Element-weise Addition)
if (matrix1.rows == matrix2.rows && matrix1.cols == matrix2.cols) {
result = createMatrix(matrix1.rows, matrix1.cols);
for (unsigned int i = 0; i < matrix1.rows * matrix1.cols; i++)
result.buffer[i] = matrix1.buffer[i] + matrix2.buffer[i];
return result;
}
// Fall 2: matrix1 ist (zeilen x 1) Spaltenvektor, matrix2 ist (zeilen x spalten)
// Broadcasting: matrix1's Spalte wird zu jeder Spalte von matrix2 addiert
if (matrix1.rows == matrix2.rows && matrix1.cols == 1) {
result = createMatrix(matrix2.rows, matrix2.cols);
for (unsigned int col = 0; col < matrix2.cols; col++) {
for (unsigned int row = 0; row < matrix2.rows; row++) {
MatrixType val1 = matrix1.buffer[row * matrix1.cols + 0];
MatrixType val2 = matrix2.buffer[row * matrix2.cols + col];
result.buffer[row * result.cols + col] = val1 + val2;
}
}
return result;
}
// Fall 3: matrix2 ist (zeilen x 1) Spaltenvektor, matrix1 ist (zeilen x spalten)
// Broadcasting: matrix2's Spalte wird zu jeder Spalte von matrix1 addiert
if (matrix2.rows == matrix1.rows && matrix2.cols == 1) {
result = createMatrix(matrix1.rows, matrix1.cols);
for (unsigned int col = 0; col < matrix1.cols; col++) {
for (unsigned int row = 0; row < matrix1.rows; row++) {
MatrixType val1 = matrix1.buffer[row * matrix1.cols + col];
MatrixType val2 = matrix2.buffer[row * matrix2.cols + 0];
result.buffer[row * result.cols + col] = val1 + val2;
}
}
return result;
}
// Ungültige Dimensionen - leere Matrix zurückgeben
result.rows = 0;
result.cols = 0;
result.buffer = NULL;
return result;
}
/*
 * Matrix product matrix1 * matrix2.
 *
 * Returns a (matrix1.rows x matrix2.cols) matrix. Its element (i, j) is the
 * dot product of row i of matrix1 and column j of matrix2.
 *
 * Returns an empty matrix (rows == 0, cols == 0, buffer == NULL) when:
 *   - matrix1.cols != matrix2.rows, or
 *   - the result cannot be allocated.
 * The caller owns the result and frees it with clearMatrix().
 */
Matrix multiply(const Matrix matrix1, const Matrix matrix2)
{
    /* The product is only defined when the inner dimensions agree. */
    if (matrix1.cols != matrix2.rows) {
        return createMatrix(0, 0);
    }

    Matrix result = createMatrix(matrix1.rows, matrix2.cols);
    /* NULL covers both a zero-sized result and allocation failure. In either
       case there is nothing to compute, and writing would dereference NULL. */
    if (result.buffer == NULL) {
        return result;
    }

    for (unsigned int i = 0; i < matrix1.rows; i++) {
        for (unsigned int j = 0; j < matrix2.cols; j++) {
            /* Dot product of row i of matrix1 and column j of matrix2. */
            MatrixType sum = 0;
            for (unsigned int k = 0; k < matrix1.cols; k++) {
                sum += matrix1.buffer[i * matrix1.cols + k] * matrix2.buffer[k * matrix2.cols + j];
            }
            result.buffer[i * result.cols + j] = sum;
        }
    }
    return result;
}