solve_linear.hpp

Back to Matrix algorithms

pastel/math/matrix/

#ifndef PASTELMATH_SOLVE_LINEAR_HPP
#define PASTELMATH_SOLVE_LINEAR_HPP

#include "pastel/math/matrix/solve_linear.h"

namespace Pastel
{

    template <typename Real, integer N, 
        typename Expression_A, typename Expression_B>
    Vector<Real, N> solveLinear(
        const MatrixExpression<Real, Expression_A>& a,
        const VectorExpression<Real, N, Expression_B>& b)
    {
        // The linear system is solved by Gaussian elimination 
        // with and partial pivoting and back-substitution.
        ENSURE_OP(a.width(), ==, a.height());
        ENSURE_OP(b.size(), ==, a.height());

        integer n = a.width();
        integer m = a.height();

        Matrix<Real> left(a);
        Vector<Real, N> right(b);

        // Reduce the system
        // Ax = b
        // to the system
        // A'x = b'
        // where A' is upper triangular
        // (and 1's on the diagonal).
        for (integer k = 0;k < n;++k)
        {
            // From this column, find the element with
            // the maximum absolute value (with i >= k).
            integer maxAbsRow = k;
            Real maxAbsValue = abs(left(k, k));
            for (integer i = k + 1;i < m;++i)
            {
                Real currentAbsValue = abs(left(i, k));
                if (currentAbsValue > maxAbsValue)
                {
                    maxAbsRow = i;
                    maxAbsValue = currentAbsValue;
                }
            }

            // Swap rows (if necessary) so that the maximum
            // absolute value will be at (k, k).
            if (maxAbsRow != k)
            {
                using std::swap;
                for (integer j = k;j < m;++j)
                {
                    swap(left(k, j), left(maxAbsRow, j));
                }
                swap(right[k], right[maxAbsRow]);
            }

            // Use the k:th row to zero the lower
            // part of the k:th column.
            for (integer i = k + 1;i < m;++i)
            {
                Real r = left(i, k) / left(k, k);
                left(i, k) = 0;
                for (integer j = k + 1;j < n;++j)
                {
                    left(i, j) -= left(k, j) * r;
                }
                right[i] -= right[k] * r;
            }
        }

        // Now the system is of the form A'x = b',
        // where A' is upper triangular. Solve it
        // with back-substitution.
        return solveUpperTriangular(left, right);
    }

    template <typename Real, integer N, 
        typename Expression_A, typename Expression_B>
    Vector<Real, N> solveLowerTriangular(
        const MatrixExpression<Real, Expression_A>& a,
        const VectorExpression<Real, N, Expression_B>& b)
    {
        ENSURE_OP(a.width(), ==, a.height());
        ENSURE_OP(b.size(), ==, a.height());

        integer n = a.width();
        integer m = a.height();

        Vector<Real, N> right = b;

        // We want to solve the system
        // Ax = b
        // where A is lower triangular.

        // Use forward-substitution to solve for x.
        for (integer j = 0;j < n - 1;++j)
        {
            Real factor = right[j] / a(j, j);
            for (integer i = j + 1;i < m;++i)
            {
                right[i] -= a(i, j) * factor;
            }
            right[j] /= a(j, j);
        }
        right[n - 1] /= a(n - 1, n - 1);

        return right;
    }

    template <typename Real, integer N, 
        typename Expression_A, typename Expression_B>
    Vector<Real, N> solveUnitLowerTriangular(
        const MatrixExpression<Real, Expression_A>& a,
        const VectorExpression<Real, N, Expression_B>& b)
    {
        ENSURE_OP(a.width(), ==, a.height());
        ENSURE_OP(b.size(), ==, a.height());

        integer n = a.width();
        integer m = a.height();

        Vector<Real, N> right = b;

        // We want to solve the system
        // Ax = b
        // where A is lower unit-triangular
        // (1's on the diagonal).

        // Use forward-substitution to solve for x.
        for (integer j = 0;j < n - 1;++j)
        {
            Real factor = right[j];
            for (integer i = j + 1;i < m;++i)
            {

                right[i] -= a(i, j) * factor;
            }
        }

        return right;
    }

    template <typename Real, integer N, 
        typename Expression_A, typename Expression_B>
    Vector<Real, N> solveUpperTriangular(
        const MatrixExpression<Real, Expression_A>& a,
        const VectorExpression<Real, N, Expression_B>& b)
    {
        ENSURE_OP(a.width(), ==, a.height());
        ENSURE_OP(b.size(), ==, a.height());

        // We want to solve the system
        // Ax = b
        // where A is upper triangular.

        integer n = a.width();

        Vector<Real, N> right = b;
        if (n == 0)
        {
            return right;
        }

        // Use back-substitution to solve for x.
        for (integer j = n - 1;j > 0;--j)
        {
            // As a loop-invariant, the 'right'
            // contains solved x_i elements
            // for i in [j + 1, n[ when entering
            // this loop, and for i in [j, n[
            // when at the end of the loop.

            Real factor = right[j] / a(j, j);
            for (integer i = j - 1;i >= 0;--i)
            {
                right[i] -= a(i, j) * factor;
            }
            right[j] /= a(j, j);
        }
        // The x_1 only needs to be scaled to be solved.
        right[0] /= a(0, 0);

        return right;
    }

    template <typename Real, integer N, 
        typename Expression_A, typename Expression_B>
    Vector<Real, N> solveUnitUpperTriangular(
        const MatrixExpression<Real, Expression_A>& a,
        const VectorExpression<Real, N, Expression_B>& b)
    {
        ENSURE_OP(a.width(), ==, a.height());
        ENSURE_OP(b.size(), ==, a.height());

        // We want to solve the system
        // Ax = b
        // where A is unit upper triangular
        // (1' on the diagonal).

        integer n = a.height();

        // Use back-substitution to solve for x.
        Vector<Real, N> right = b;
        for (integer j = n - 1;j > 0;--j)
        {
            // As a loop-invariant, the 'right'
            // contains solved x_i elements
            // for i in [j, n[ when entering
            // this loop, and for i in [j - 1, n[
            // when at the end of the loop.

            Real factor = right[j];
            for (integer i = j - 1;i >= 0;--i)
            {
                right[i] -= a(i, j) * factor;
            }
        }

        return right;
    }

}

#endif