diff --git a/matrix.c b/matrix.c index cf75a94..81416fd 100644 --- a/matrix.c +++ b/matrix.c @@ -109,7 +109,6 @@ Matrix add(const Matrix matrix1, const Matrix matrix2) { for (size_t n = 0; n < matrix1.cols; n++) { - // this is unnecessarily complicated because at this point we already know that the matrices are compatible setMatrixAt(getMatrixAt(matrix1, m, n) + getMatrixAt(matrix2, m, n), resMat, m, n); } } @@ -117,32 +116,37 @@ Matrix add(const Matrix matrix1, const Matrix matrix2) return resMat; } -Matrix multiply(const Matrix matrix1, const Matrix matrix2) +Matrix multiply(const Matrix A, const Matrix B) { - if (matrix1.cols != matrix2.rows || matrix1.buffer == NULL || matrix2.buffer == NULL) + if (A.cols != B.rows || A.buffer == NULL || B.buffer == NULL) { return createMatrix(0, 0); } - int rows = matrix1.rows, cols = matrix2.cols; - Matrix resMat = createMatrix(rows, cols); + int rows = A.rows, cols = B.cols; + Matrix C = createMatrix(rows, cols); - if (resMat.buffer == NULL) + if (C.buffer == NULL) { return createMatrix(0, 0); } - for (size_t rowIdx = 0; rowIdx < rows; rowIdx++) + // M = Rows, K = Common Dim, N = Cols + size_t M = A.rows, K = A.cols, N = B.cols; + + for (size_t i = 0; i < M; i++) { - for (size_t colIdx = 0; colIdx < cols; colIdx++) + for (size_t k = 0; k < K; k++) + { - int curCellVal = 0; - for (size_t k = 0; k < matrix1.cols; k++) + MatrixType valA = A.buffer[i * K + k]; + for (size_t j = 0; j < N; j++) { - curCellVal += getMatrixAt(matrix1, rowIdx, k) * getMatrixAt(matrix2, k, colIdx); + // C[i, j] += A[i, k] * B[k, j]; + // M x N, M x K, K x N + C.buffer[i * N + j] += valA * B.buffer[k * N + j]; } - setMatrixAt(curCellVal, resMat, rowIdx, colIdx); } } - return resMat; + return C; } \ No newline at end of file