Compare commits
No commits in common. "46601b302072d8b57562dd064e922ffbb29a793c" and "b7e44e96703a982b32610cb0c8f459b8cf8fd4f0" have entirely different histories.
46601b3020
...
b7e44e9670
46
matrix.c
46
matrix.c
@ -50,45 +50,17 @@ MatrixType getMatrixAt(const Matrix matrix, unsigned int rowIdx, unsigned int co
|
|||||||
Matrix add(const Matrix matrix1, const Matrix matrix2)
|
Matrix add(const Matrix matrix1, const Matrix matrix2)
|
||||||
{
|
{
|
||||||
Matrix result;
|
Matrix result;
|
||||||
|
if (matrix1.rows != matrix2.rows || matrix1.cols != matrix2.cols) {
|
||||||
// Case 1: Exact same dimensions
|
result.rows = 0;
|
||||||
if (matrix1.rows == matrix2.rows && matrix1.cols == matrix2.cols) {
|
result.cols = 0;
|
||||||
result = createMatrix(matrix1.rows, matrix1.cols);
|
result.buffer = NULL;
|
||||||
for (unsigned int i = 0; i < matrix1.rows * matrix1.cols; i++)
|
|
||||||
result.buffer[i] = matrix1.buffer[i] + matrix2.buffer[i];
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 2: matrix1 is (rows x 1) column vector, matrix2 is (rows x cols) - broadcast bias
|
result = createMatrix(matrix1.rows, matrix1.cols);
|
||||||
if (matrix1.rows == matrix2.rows && matrix1.cols == 1) {
|
for (unsigned int i = 0; i < matrix1.rows * matrix1.cols; i++)
|
||||||
result = createMatrix(matrix2.rows, matrix2.cols);
|
result.buffer[i] = matrix1.buffer[i] + matrix2.buffer[i];
|
||||||
for (unsigned int col = 0; col < matrix2.cols; col++) {
|
|
||||||
for (unsigned int row = 0; row < matrix2.rows; row++) {
|
|
||||||
MatrixType val1 = matrix1.buffer[row * matrix1.cols + 0];
|
|
||||||
MatrixType val2 = matrix2.buffer[row * matrix2.cols + col];
|
|
||||||
result.buffer[row * result.cols + col] = val1 + val2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Case 3: matrix2 is (rows x 1) column vector, matrix1 is (rows x cols) - broadcast bias
|
|
||||||
if (matrix2.rows == matrix1.rows && matrix2.cols == 1) {
|
|
||||||
result = createMatrix(matrix1.rows, matrix1.cols);
|
|
||||||
for (unsigned int col = 0; col < matrix1.cols; col++) {
|
|
||||||
for (unsigned int row = 0; row < matrix1.rows; row++) {
|
|
||||||
MatrixType val1 = matrix1.buffer[row * matrix1.cols + col];
|
|
||||||
MatrixType val2 = matrix2.buffer[row * matrix2.cols + 0];
|
|
||||||
result.buffer[row * result.cols + col] = val1 + val2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// No valid case - return empty matrix
|
|
||||||
result.rows = 0;
|
|
||||||
result.cols = 0;
|
|
||||||
result.buffer = NULL;
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user