matrix_inverse.hpp

Back to Matrix algorithms

pastel/math/matrix/

#ifndef PASTELMATH_MATRIX_INVERSE_HPP
#define PASTELMATH_MATRIX_INVERSE_HPP

#include "pastel/math/matrix/matrix_inverse.h"
#include "pastel/math/matrix/matrix_adjugate.h"
#include "pastel/math/matrix/matrix_determinant.h"

namespace Pastel
{

    //! Returns the inverse of a square, non-singular matrix.
    /*!
    Preconditions:
    matrix.width() == matrix.height() (checked by ENSURE_OP).

    For n <= 3 the inverse is computed in closed form as
    adjugate(A) / determinant(A). For larger matrices it is computed
    by Gauss-Jordan elimination with partial pivoting.

    Throws SingularMatrix_Exception if the determinant is exactly
    zero (n <= 3), or if an exactly zero pivot is met (n > 3).
    */
    template <typename Real, typename Expression>
    Matrix<Real> inverse(
        const MatrixExpression<Real, Expression>& matrix)
    {
        // Let A in RR^{n x n} be non-singular. Then by
        // non-singularity of A the linear systems
        //
        //     A x_i = e_i, for all i in [1, n],
        //
        // where e_i is the i:th column of the nxn identity matrix,
        // each have a unique solution x_i. These linear systems
        // can be arranged into a combined linear system
        //
        //     AX = I,
        //
        // where X = [x_1, ..., x_n]. Therefore X is the
        // right-inverse of A, which for a square matrix is also
        // its (two-sided) inverse. To solve X, we start from the
        // augmented matrix [A | I], where the extent of the
        // identity matrix I is the same as A, and then multiply
        // from the left by elementary matrices such that we
        // end up with [I | X]. This is Gauss-Jordan elimination.

        integer n = matrix.width();
        integer m = matrix.height();

        ENSURE_OP(m, ==, n);

        // The original matrix, left part of [A | I].
        Matrix<Real> left(matrix);

        if (n <= 3)
        {
            // Specialization for small matrices:
            // A^{-1} = adj(A) / det(A).
            Real det = determinant(left);
            if (det == 0)
            {
                throw SingularMatrix_Exception();
            }

            Matrix<Real> result(adjugate(left) / det);
            return result;
        }

        // The identity matrix, right part of [A | I]. It is only
        // needed by the elimination below, so it is not constructed
        // on the small-matrix path above.
        // NOTE(review): relies on Matrix<Real>(m, n) constructing an
        // identity matrix, as the original comment stated; confirm
        // against the Matrix class.
        Matrix<Real> right(m, n);

        for (integer k = 0;k < n;++k)
        {
            // The strategy in Gauss-Jordan elimination is to modify
            // [A | I] by elementary row-operations such that
            // at the end of the (k + 1):th iteration of this loop
            // the (k + 1) first columns of left are [e_1, ..., e_(k + 1)].
            //
            // Consequently, at the start of the k:th iteration,
            // left(i, j) = 0 for all j < k and i >= k.

            // To do this, given the k:th column,
            // 1) Pick a row i such that left(i, k) != 0 and i >= k.
            // 2) If k != i, swap rows i and k.
            // 3) Subtract the row k multiplied by left(j, k) / left(k, k)
            //    from each row j != k.
            // 4) Divide the k:th row by left(k, k).

            // While we could pick any non-zero element from the
            // column (with i >= k), for numerical stability it
            // is better to pick the one with the largest absolute
            // value. This is called partial pivoting.

            // From the k:th column, find the element with
            // the maximum absolute value (with i >= k).
            integer maxAbsRow = k;
            Real maxAbsValue = abs(left(k, k));
            for (integer i = k + 1;i < m;++i)
            {
                Real currentAbsValue = abs(left(i, k));
                if (currentAbsValue > maxAbsValue)
                {
                    maxAbsRow = i;
                    maxAbsValue = currentAbsValue;
                }
            }

            // Swap (if necessary) so that the maximum
            // absolute value will be at (k, k).
            if (maxAbsRow != k)
            {
                using std::swap;
                for (integer j = 0;j < k;++j)
                {
                    // By the loop invariant, left(i, j) = 0 for
                    // j < k and i >= k; both swapped rows are >= k,
                    // so only right needs swapping in these columns.
                    swap(right(k, j), right(maxAbsRow, j));
                }
                for (integer j = k;j < n;++j)
                {
                    swap(left(k, j), left(maxAbsRow, j));
                    swap(right(k, j), right(maxAbsRow, j));
                }
            }

            if (left(k, k) == 0)
            {
                // The whole sub-column (i >= k) is zero.
                throw SingularMatrix_Exception();
            }

            // Row k is skipped below, so the pivot stays
            // unchanged during the elimination.
            const Real pivot = left(k, k);

            // Use the k:th row to clear out the k:th column
            // except for the k:th row.
            for (integer i = 0;i < m;++i)
            {
                if (i == k)
                {
                    // Skip the k:th row.
                    continue;
                }

                Real value = left(i, k) / pivot;
                for (integer j = 0;j < k;++j)
                {
                    // By the loop invariant, left(k, j) = 0 for
                    // j < k, so only right changes in these columns.
                    right(i, j) -= right(k, j) * value;
                }
                for (integer j = k;j < n;++j)
                {
                    left(i, j) -= left(k, j) * value;
                    right(i, j) -= right(k, j) * value;
                }
            }

            // Scale the k:th row such that left(k, k) = 1.
            Real a = inverse(pivot);
            for (integer j = 0;j < k;++j)
            {
                // By the loop invariant, left(k, j) = 0 for
                // j < k, so only right changes in these columns.
                right(k, j) *= a;
            }
            for (integer j = k;j < n;++j)
            {
                left(k, j) *= a;
                right(k, j) *= a;
            }
        }

        return right;
    }

}

#endif