diff --git a/matrix.c b/matrix.c index 999f713..f9ea09a 100644 --- a/matrix.c +++ b/matrix.c @@ -56,7 +56,11 @@ Matrix add(const Matrix matrix1, const Matrix matrix2) { Matrix result = {0}; - if (matrix1.rows != matrix2.rows || matrix1.cols != matrix2.cols) + int broadcast_case = + (matrix1.cols == 1 && matrix1.rows == matrix2.rows) || + (matrix2.cols == 1 && matrix1.rows == matrix2.rows); + + if (!broadcast_case && (matrix1.rows != matrix2.rows || matrix1.cols != matrix2.cols)) { return result; } @@ -75,6 +79,8 @@ Matrix add(const Matrix matrix1, const Matrix matrix2) if (matrix1.cols == 1 && matrix1.rows == matrix2.rows) // Broadcasting { + result.rows = matrix2.rows; + result.cols = matrix2.cols; for (unsigned int i = 0; i < matrix1.rows; i++) { @@ -89,6 +95,9 @@ Matrix add(const Matrix matrix1, const Matrix matrix2) else if (matrix2.cols == 1 && matrix1.rows == matrix2.rows) { + result.rows = matrix1.rows; + result.cols = matrix1.cols; + for (unsigned int i = 0; i < matrix2.rows; i++) { for (unsigned int j = 0; j < result.cols; j++) diff --git a/runMatrixTests b/runMatrixTests deleted file mode 100755 index d05b8a8..0000000 Binary files a/runMatrixTests and /dev/null differ