diff --git a/matrix.c b/matrix.c index 917a8b2..4d42efd 100644 --- a/matrix.c +++ b/matrix.c @@ -58,13 +58,44 @@ MatrixType getMatrixAt(const Matrix matrix, unsigned int rowIdx, unsigned int co Matrix add(const Matrix matrix1, const Matrix matrix2) { - Matrix resMat = createMatrix(matrix1.rows, matrix1.cols); + Matrix resMat = (matrix1.cols > matrix2.cols) ? createMatrix(matrix1.rows, matrix1.cols) : createMatrix(matrix2.rows, matrix2.cols); - // clear matrix and return if the dimensions of the input matrices differ from each other - if (matrix1.rows != matrix2.rows || matrix1.cols != matrix2.cols) + if (matrix1.cols != matrix2.cols) { - clearMatrix(&resMat); - return resMat; + if (matrix1.rows != matrix2.rows) + { + clearMatrix(&resMat); + return resMat; + } + else if (matrix1.cols == 1) + { + // broadcast vector + for (size_t m = 0; m < matrix2.rows; m++) + { + for (size_t n = 0; n < matrix2.cols; n++) + { + setMatrixAt(getMatrixAt(matrix2, m, n) + getMatrixAt(matrix1, m, 0), resMat, m, n); + } + } + return resMat; + } + else if (matrix2.cols == 1) + { + // broadcast vector + for (size_t m = 0; m < matrix1.rows; m++) + { + for (size_t n = 0; n < matrix1.cols; n++) + { + setMatrixAt(getMatrixAt(matrix1, m, n) + getMatrixAt(matrix2, m, 0), resMat, m, n); + } + } + return resMat; + } + else + { + clearMatrix(&resMat); + return resMat; + } } for (size_t m = 0; m < matrix1.rows; m++)