forked from freudenreichan/info2Praktikum-NeuronalesNetz
make matrix multiplication faster
This commit is contained in:
parent
00ac1aa04a
commit
adf75b66d2
30
matrix.c
30
matrix.c
@ -109,7 +109,6 @@ Matrix add(const Matrix matrix1, const Matrix matrix2)
|
|||||||
{
|
{
|
||||||
for (size_t n = 0; n < matrix1.cols; n++)
|
for (size_t n = 0; n < matrix1.cols; n++)
|
||||||
{
|
{
|
||||||
// this is unnecessarily complicated because at this point we already know that the matrices are compatible
|
|
||||||
setMatrixAt(getMatrixAt(matrix1, m, n) + getMatrixAt(matrix2, m, n), resMat, m, n);
|
setMatrixAt(getMatrixAt(matrix1, m, n) + getMatrixAt(matrix2, m, n), resMat, m, n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -117,32 +116,37 @@ Matrix add(const Matrix matrix1, const Matrix matrix2)
|
|||||||
return resMat;
|
return resMat;
|
||||||
}
|
}
|
||||||
|
|
||||||
Matrix multiply(const Matrix matrix1, const Matrix matrix2)
|
Matrix multiply(const Matrix A, const Matrix B)
|
||||||
{
|
{
|
||||||
if (matrix1.cols != matrix2.rows || matrix1.buffer == NULL || matrix2.buffer == NULL)
|
if (A.cols != B.rows || A.buffer == NULL || B.buffer == NULL)
|
||||||
{
|
{
|
||||||
return createMatrix(0, 0);
|
return createMatrix(0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
int rows = matrix1.rows, cols = matrix2.cols;
|
int rows = A.rows, cols = B.cols;
|
||||||
Matrix resMat = createMatrix(rows, cols);
|
Matrix C = createMatrix(rows, cols);
|
||||||
|
|
||||||
if (resMat.buffer == NULL)
|
if (C.buffer == NULL)
|
||||||
{
|
{
|
||||||
return createMatrix(0, 0);
|
return createMatrix(0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t rowIdx = 0; rowIdx < rows; rowIdx++)
|
// M = Rows, K = Common Dim, N = Cols
|
||||||
|
size_t M = A.rows, K = A.cols, N = B.cols;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < M; i++)
|
||||||
{
|
{
|
||||||
for (size_t colIdx = 0; colIdx < cols; colIdx++)
|
for (size_t k = 0; k < K; k++)
|
||||||
|
|
||||||
{
|
{
|
||||||
int curCellVal = 0;
|
MatrixType valA = A.buffer[i * K + k];
|
||||||
for (size_t k = 0; k < matrix1.cols; k++)
|
for (size_t j = 0; j < N; j++)
|
||||||
{
|
{
|
||||||
curCellVal += getMatrixAt(matrix1, rowIdx, k) * getMatrixAt(matrix2, k, colIdx);
|
// C[i, j] += A[i, k] * B[k, j];
|
||||||
|
// M x N, M x K, K x N
|
||||||
|
C.buffer[i * N + j] += valA * B.buffer[k * N + j];
|
||||||
}
|
}
|
||||||
setMatrixAt(curCellVal, resMat, rowIdx, colIdx);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return resMat;
|
return C;
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user