diff --git a/matrix.c b/matrix.c index bcdf3e8..86f3a30 100644 --- a/matrix.c +++ b/matrix.c @@ -52,11 +52,10 @@ MatrixType getMatrixAt(const Matrix matrix, unsigned int rowIdx, unsigned int co Matrix add(const Matrix matrix1, const Matrix matrix2) { + Matrix result_add = createMatrix(matrix1.rows, matrix1.cols); // check, equal rows if (matrix1.rows == matrix2.rows) { - Matrix result_add = createMatrix(matrix1.rows, matrix1.cols); - // "Elementweise Addition": test, if two matrix has exact size if (matrix1.rows == matrix2.rows && matrix1.cols == matrix2.cols) { @@ -94,12 +93,12 @@ Matrix add(const Matrix matrix1, const Matrix matrix2) } } } - - return result_add; } - else if (matrix1.rows == matrix2.rows && matrix1.cols == matrix2.cols) + else if (matrix1.rows != matrix2.rows || matrix1.cols != 1 || matrix2.cols != 1) return createMatrix(0, 0); + + return result_add; } Matrix multiply(const Matrix matrix1, const Matrix matrix2)