Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions src/main/java/com/thealgorithms/maths/LUDecomposition.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package com.thealgorithms.maths;

/**
* @brief Implementation of LU Decomposition using the Doolittle algorithm
* @details Decomposes a square matrix A into a lower triangular matrix L and
* an upper triangular matrix U such that A = L * U. The diagonal of L contains
* all ones (Doolittle convention). This decomposition is useful for solving
* systems of linear equations, computing determinants, and finding inverses.
* @see <a href="https://en.wikipedia.org/wiki/LU_decomposition">LU Decomposition</a>
*/
public final class LUDecomposition {

private LUDecomposition() {
}

/**
* @brief Performs LU decomposition on a square matrix using the Doolittle algorithm
* @param matrix a square matrix
* @return a 2D array where the lower triangle (excluding diagonal) contains L
* elements (with implicit 1s on the diagonal) and the upper triangle
* (including diagonal) contains U elements
* @throws IllegalArgumentException if the matrix is not square
* @throws ArithmeticException if a zero pivot is encountered
*/
public static double[][] decompose(double[][] matrix) {
int n = matrix.length;
for (double[] row : matrix) {
if (row.length != n) {
throw new IllegalArgumentException("Matrix must be square.");
}
}

double[][] lu = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
lu[i][j] = matrix[i][j];
}
}

for (int k = 0; k < n; k++) {
for (int j = k; j < n; j++) {
double sum = 0;
for (int s = 0; s < k; s++) {
sum += lu[k][s] * lu[s][j];
}
lu[k][j] -= sum;
}

if (lu[k][k] == 0) {
throw new ArithmeticException("Zero pivot encountered. Matrix may be singular.");
}

for (int i = k + 1; i < n; i++) {
double sum = 0;
for (int s = 0; s < k; s++) {
sum += lu[i][s] * lu[s][k];
}
lu[i][k] = (lu[i][k] - sum) / lu[k][k];
}
}

return lu;
}

/**
* @brief Extracts the lower triangular matrix L from the combined LU matrix
* @param lu the combined LU matrix from decompose()
* @return the lower triangular matrix L with 1s on the diagonal
*/
public static double[][] getLowerMatrix(double[][] lu) {
int n = lu.length;
double[][] lower = new double[n][n];
for (int i = 0; i < n; i++) {
lower[i][i] = 1.0;
for (int j = 0; j < i; j++) {
lower[i][j] = lu[i][j];
}
}
return lower;
}

/**
* @brief Extracts the upper triangular matrix U from the combined LU matrix
* @param lu the combined LU matrix from decompose()
* @return the upper triangular matrix U
*/
public static double[][] getUpperMatrix(double[][] lu) {
int n = lu.length;
double[][] upper = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = i; j < n; j++) {
upper[i][j] = lu[i][j];
}
}
return upper;
}

/**
* @brief Solves a system of linear equations Ax = b using LU decomposition
* @param lu the combined LU matrix from decompose()
* @param b the right-hand side vector
* @return the solution vector x
*/
public static double[] solve(double[][] lu, double[] b) {
int n = lu.length;
double[] y = new double[n];
double[] x = new double[n];

// Forward substitution: solve Ly = b
for (int i = 0; i < n; i++) {
double sum = 0;
for (int j = 0; j < i; j++) {
sum += lu[i][j] * y[j];
}
y[i] = b[i] - sum;
}

// Back substitution: solve Ux = y
for (int i = n - 1; i >= 0; i--) {
double sum = 0;
for (int j = i + 1; j < n; j++) {
sum += lu[i][j] * x[j];
}
x[i] = (y[i] - sum) / lu[i][i];
}

return x;
}
}
106 changes: 106 additions & 0 deletions src/test/java/com/thealgorithms/maths/LUDecompositionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package com.thealgorithms.maths;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;

public class LUDecompositionTest {

private static final double DELTA = 1e-9;

@Test
public void testDecomposeSimpleMatrix() {
double[][] matrix = {{2, 1, 1}, {4, 3, 3}, {8, 7, 9}};
double[][] lu = LUDecomposition.decompose(matrix);

double[][] lower = LUDecomposition.getLowerMatrix(lu);
double[][] upper = LUDecomposition.getUpperMatrix(lu);

assertArrayEquals(new double[] {1, 1, 1}, new double[] {lower[0][0], lower[1][1], lower[2][2]}, DELTA);

double[][] product = multiply(lower, upper);
assertArrayEquals(new double[] {2, 1, 1, 4, 3, 3, 8, 7, 9}, flatten(product), DELTA);
}

@Test
public void testDecomposeTwoByTwo() {
double[][] matrix = {{1, 2}, {3, 4}};
double[][] lu = LUDecomposition.decompose(matrix);

double[][] lower = LUDecomposition.getLowerMatrix(lu);
double[][] upper = LUDecomposition.getUpperMatrix(lu);

double[][] product = multiply(lower, upper);
assertArrayEquals(new double[] {1, 2, 3, 4}, flatten(product), DELTA);
}

@Test
public void testDecomposeIdentityMatrix() {
double[][] matrix = {{1, 0}, {0, 1}};
double[][] lu = LUDecomposition.decompose(matrix);

double[][] lower = LUDecomposition.getLowerMatrix(lu);
double[][] upper = LUDecomposition.getUpperMatrix(lu);

assertArrayEquals(new double[] {1, 0, 0, 1}, flatten(lower), DELTA);
assertArrayEquals(new double[] {1, 0, 0, 1}, flatten(upper), DELTA);
}

@Test
public void testDecomposeNonSquareMatrixThrows() {
double[][] matrix = {{1, 2, 3}, {4, 5, 6}};
assertThrows(IllegalArgumentException.class, () -> LUDecomposition.decompose(matrix));
}

@Test
public void testDecomposeSingularMatrixThrows() {
double[][] matrix = {{0, 1}, {1, 0}};
assertThrows(ArithmeticException.class, () -> LUDecomposition.decompose(matrix));
}

@Test
public void testSolveLinearSystem() {
double[][] matrix = {{2, 1, 1}, {4, 3, 3}, {8, 7, 9}};
double[] b = {8, 20, 46};
double[][] lu = LUDecomposition.decompose(matrix);
double[] solution = LUDecomposition.solve(lu, b);

assertArrayEquals(new double[] {1, 3, 3}, solution, DELTA);
}

@Test
public void testSolveTwoByTwoSystem() {
double[][] matrix = {{2, 1}, {1, 3}};
double[] b = {5, 7};
double[][] lu = LUDecomposition.decompose(matrix);
double[] solution = LUDecomposition.solve(lu, b);

assertArrayEquals(new double[] {1.6, 1.8}, solution, DELTA);
}

private static double[][] multiply(double[][] a, double[][] b) {
int n = a.length;
double[][] result = new double[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
for (int k = 0; k < n; k++) {
result[i][j] += a[i][k] * b[k][j];
}
}
}
return result;
}

private static double[] flatten(double[][] matrix) {
int n = matrix.length;
double[] result = new double[n * n];
int idx = 0;
for (double[] row : matrix) {
for (double val : row) {
result[idx++] = val;
}
}
return result;
}
}
Loading