All Tests Passed

This commit is contained in:
Nicolas Reckert 2025-11-16 20:43:25 +01:00
parent b7e44e9670
commit 339c3e81b1

View File

@ -46,21 +46,49 @@ MatrixType getMatrixAt(const Matrix matrix, unsigned int rowIdx, unsigned int co
return matrix.buffer[rowIdx * matrix.cols + colIdx]; return matrix.buffer[rowIdx * matrix.cols + colIdx];
} }
// Addition // Addition (with broadcasting support for bias)
Matrix add(const Matrix matrix1, const Matrix matrix2) Matrix add(const Matrix matrix1, const Matrix matrix2)
{ {
Matrix result; Matrix result;
if (matrix1.rows != matrix2.rows || matrix1.cols != matrix2.cols) {
result.rows = 0; // Case 1: Exact same dimensions
result.cols = 0; if (matrix1.rows == matrix2.rows && matrix1.cols == matrix2.cols) {
result.buffer = NULL; result = createMatrix(matrix1.rows, matrix1.cols);
for (unsigned int i = 0; i < matrix1.rows * matrix1.cols; i++)
result.buffer[i] = matrix1.buffer[i] + matrix2.buffer[i];
return result; return result;
} }
result = createMatrix(matrix1.rows, matrix1.cols); // Case 2: matrix1 is (rows x 1) column vector, matrix2 is (rows x cols) - broadcast bias
for (unsigned int i = 0; i < matrix1.rows * matrix1.cols; i++) if (matrix1.rows == matrix2.rows && matrix1.cols == 1) {
result.buffer[i] = matrix1.buffer[i] + matrix2.buffer[i]; result = createMatrix(matrix2.rows, matrix2.cols);
for (unsigned int col = 0; col < matrix2.cols; col++) {
for (unsigned int row = 0; row < matrix2.rows; row++) {
MatrixType val1 = matrix1.buffer[row * matrix1.cols + 0];
MatrixType val2 = matrix2.buffer[row * matrix2.cols + col];
result.buffer[row * result.cols + col] = val1 + val2;
}
}
return result;
}
// Case 3: matrix2 is (rows x 1) column vector, matrix1 is (rows x cols) - broadcast bias
if (matrix2.rows == matrix1.rows && matrix2.cols == 1) {
result = createMatrix(matrix1.rows, matrix1.cols);
for (unsigned int col = 0; col < matrix1.cols; col++) {
for (unsigned int row = 0; row < matrix1.rows; row++) {
MatrixType val1 = matrix1.buffer[row * matrix1.cols + col];
MatrixType val2 = matrix2.buffer[row * matrix2.cols + 0];
result.buffer[row * result.cols + col] = val1 + val2;
}
}
return result;
}
// No valid case - return empty matrix
result.rows = 0;
result.cols = 0;
result.buffer = NULL;
return result; return result;
} }