Fix matrix multiply

This commit is contained in:
D2A62006 2025-11-26 17:43:41 +01:00
parent 2783663ab9
commit 57bf46bfc5

View File

@ -107,7 +107,7 @@ Matrix add(const Matrix matrix1, const Matrix matrix2)
if (matrix1.cols == matrix2.cols && matrix1.rows == matrix2.rows){ if (matrix1.cols == matrix2.cols && matrix1.rows == matrix2.rows){
//size of the matrices should be the same, if the addition is supposed to happen //size of the matrices should be the same, if the addition is supposed to happen
// Matrix outputMatrix = createMatrix(matrix1.rows, matrix1.cols); // Matrix outputMatrix = createMatrix(matrix1.rows, matrix1.cols);
Matrix outputMatrix = createMatrix(matrix1.cols, matrix1.rows); Matrix outputMatrix = createMatrix(matrix1.rows, matrix1.cols);
for (int i = 0; i < matrix1.rows;i++) { for (int i = 0; i < matrix1.rows;i++) {
for (int j = 0; j < matrix1.cols; j++) { for (int j = 0; j < matrix1.cols; j++) {
// how this should work in normal Matrix version: // how this should work in normal Matrix version:
@ -127,18 +127,21 @@ Matrix add(const Matrix matrix1, const Matrix matrix2)
} }
/*
Matrix multiply(const Matrix matrix1, const Matrix matrix2) Matrix multiply(const Matrix matrix1, const Matrix matrix2)
{ {
//check, if the matrices can be multiplied //check, if the matrices can be multiplied
if (matrix1.rows == matrix2.cols) { if (matrix1.rows == matrix2.cols) {
// Matrix outputMatrix = createMatrix(matrix1.rows, matrix2.cols); Matrix outputMatrix = createMatrix(matrix1.rows, matrix2.cols);
Matrix outputMatrix = createMatrix(matrix2.cols, matrix1.rows); //Matrix outputMatrix = createMatrix(matrix2.cols, matrix1.rows);
for(int i = 0; i < matrix1.rows; i++) { for(int i = 0; i < matrix1.rows; i++) {
for (int j = 0; j < matrix2.cols; j++) { for (int j = 0; j < matrix2.cols; j++) {
for (int k = 0; k < matrix2.rows; k++) { for (int k = 0; k < matrix2.rows; k++) {
// how this should work in normal Matrix version: // how this should work in normal Matrix version:
// outputMatrix.buffer[i][j] = matrix1.buffer[i][k] * matrix2.buffer[k][j]; // outputMatrix.buffer[i][j] = matrix1.buffer[i][k] * matrix2.buffer[k][j];
outputMatrix.buffer[i + outputMatrix.rows * j] = matrix1.buffer[i + matrix1.rows * k] * matrix2.buffer[k + matrix2.rows * j]; outputMatrix.buffer[i + outputMatrix.rows * j] += matrix1.buffer[i + matrix1.rows * k] * matrix2.buffer[k + matrix2.rows * j];
} }
} }
} }
@ -151,4 +154,37 @@ Matrix multiply(const Matrix matrix1, const Matrix matrix2)
m.buffer = NULL; m.buffer = NULL;
return m; return m;
} }
}
*/
Matrix multiply(const Matrix matrix1, const Matrix matrix2)
{
if(matrix1.cols != matrix2.rows){
Matrix m = {NULL, 0, 0};
return m;
}
Matrix outputMatrix = createMatrix(matrix1.rows, matrix2.cols);
if(!outputMatrix.buffer){
Matrix m = {NULL, 0, 0};
return m;
}
//Matrix outputMatrix = createMatrix(matrix2.cols, matrix1.rows);
for(int i = 0; i < matrix1.rows; i++) {
for (int j = 0; j < matrix2.cols; j++) {
MatrixType sum = 0;
for (int k = 0; k < matrix1.cols; k++) {
// how this should work in normal Matrix version:
// outputMatrix.buffer[i][j] = matrix1.buffer[i][k] * matrix2.buffer[k][j];
//outputMatrix.buffer[i + outputMatrix.rows * j] += matrix1.buffer[i + matrix1.rows * k] * matrix2.buffer[k + matrix2.rows * j];
sum += matrix1.buffer[i * matrix1.cols + k] * matrix2.buffer[j + matrix2.cols * k];
}
outputMatrix.buffer[i * outputMatrix.cols + j] = sum;
}
}
return outputMatrix;
} }