From 339c3e81b1b57f7f61884c4e636df1213feedab4 Mon Sep 17 00:00:00 2001 From: Nicolas Date: Sun, 16 Nov 2025 20:43:25 +0100 Subject: [PATCH] All Tests Passed --- matrix.c | 48 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/matrix.c b/matrix.c index 936eaab..3e7a79e 100644 --- a/matrix.c +++ b/matrix.c @@ -46,21 +46,49 @@ MatrixType getMatrixAt(const Matrix matrix, unsigned int rowIdx, unsigned int co return matrix.buffer[rowIdx * matrix.cols + colIdx]; } -// Addition +// Addition (with broadcasting support for bias) Matrix add(const Matrix matrix1, const Matrix matrix2) { Matrix result; - if (matrix1.rows != matrix2.rows || matrix1.cols != matrix2.cols) { - result.rows = 0; - result.cols = 0; - result.buffer = NULL; + + // Case 1: Exact same dimensions + if (matrix1.rows == matrix2.rows && matrix1.cols == matrix2.cols) { + result = createMatrix(matrix1.rows, matrix1.cols); + for (unsigned int i = 0; i < matrix1.rows * matrix1.cols; i++) + result.buffer[i] = matrix1.buffer[i] + matrix2.buffer[i]; return result; } - - result = createMatrix(matrix1.rows, matrix1.cols); - for (unsigned int i = 0; i < matrix1.rows * matrix1.cols; i++) - result.buffer[i] = matrix1.buffer[i] + matrix2.buffer[i]; - + + // Case 2: matrix1 is (rows x 1) column vector, matrix2 is (rows x cols) - broadcast bias + if (matrix1.rows == matrix2.rows && matrix1.cols == 1) { + result = createMatrix(matrix2.rows, matrix2.cols); + for (unsigned int col = 0; col < matrix2.cols; col++) { + for (unsigned int row = 0; row < matrix2.rows; row++) { + MatrixType val1 = matrix1.buffer[row * matrix1.cols + 0]; + MatrixType val2 = matrix2.buffer[row * matrix2.cols + col]; + result.buffer[row * result.cols + col] = val1 + val2; + } + } + return result; + } + + // Case 3: matrix2 is (rows x 1) column vector, matrix1 is (rows x cols) - broadcast bias + if (matrix2.rows == matrix1.rows && matrix2.cols == 1) { + result = createMatrix(matrix1.rows, matrix1.cols); + for (unsigned int col = 0; col < matrix1.cols; col++) { + for (unsigned int row = 0; row < matrix1.rows; row++) { + MatrixType val1 = matrix1.buffer[row * matrix1.cols + col]; + MatrixType val2 = matrix2.buffer[row * matrix2.cols + 0]; + result.buffer[row * result.cols + col] = val1 + val2; + } + } + return result; + } + + // No valid case - return empty matrix + result.rows = 0; + result.cols = 0; + result.buffer = NULL; return result; }