clean up and improve allocation error handling

This commit is contained in:
Simon Wiesend 2025-11-21 09:09:55 +01:00
parent 6137e45bdb
commit 92ad1e1c31
Signed by: wiesendsi102436
GPG Key ID: C18A833054142CF0

View File

@ -1,6 +1,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include "matrix.h" #include "matrix.h"
#include <stdio.h>
Matrix createMatrix(unsigned int rows, unsigned int cols) Matrix createMatrix(unsigned int rows, unsigned int cols)
{ {
@ -20,6 +21,7 @@ Matrix createMatrix(unsigned int rows, unsigned int cols)
if (mat.buffer == NULL) if (mat.buffer == NULL)
{ {
clearMatrix(&mat); clearMatrix(&mat);
perror("could not allocate memory");
} }
return mat; return mat;
@ -60,6 +62,11 @@ Matrix add(const Matrix matrix1, const Matrix matrix2)
{ {
Matrix resMat = (matrix1.cols > matrix2.cols) ? createMatrix(matrix1.rows, matrix1.cols) : createMatrix(matrix2.rows, matrix2.cols); Matrix resMat = (matrix1.cols > matrix2.cols) ? createMatrix(matrix1.rows, matrix1.cols) : createMatrix(matrix2.rows, matrix2.cols);
if (resMat.buffer == NULL)
{
return createMatrix(0, 0);
}
if (matrix1.cols != matrix2.cols) if (matrix1.cols != matrix2.cols)
{ {
if (matrix1.rows != matrix2.rows) if (matrix1.rows != matrix2.rows)
@ -110,7 +117,6 @@ Matrix add(const Matrix matrix1, const Matrix matrix2)
return resMat; return resMat;
} }
// TODO implement
Matrix multiply(const Matrix matrix1, const Matrix matrix2) Matrix multiply(const Matrix matrix1, const Matrix matrix2)
{ {
if (matrix1.cols != matrix2.rows || matrix1.buffer == NULL || matrix2.buffer == NULL) if (matrix1.cols != matrix2.rows || matrix1.buffer == NULL || matrix2.buffer == NULL)
@ -121,6 +127,11 @@ Matrix multiply(const Matrix matrix1, const Matrix matrix2)
int rows = matrix1.rows, cols = matrix2.cols; int rows = matrix1.rows, cols = matrix2.cols;
Matrix resMat = createMatrix(rows, cols); Matrix resMat = createMatrix(rows, cols);
if (resMat.buffer == NULL)
{
return createMatrix(0, 0);
}
for (size_t rowIdx = 0; rowIdx < rows; rowIdx++) for (size_t rowIdx = 0; rowIdx < rows; rowIdx++)
{ {
for (size_t colIdx = 0; colIdx < cols; colIdx++) for (size_t colIdx = 0; colIdx < cols; colIdx++)