coherent_point_drift.h

Back to Coherent point drift

pastel/geometry/pattern_matching/

// Description: Coherent point drift algorithm.
// Documentation: coherent_point_drift.txt

#ifndef PASTELGEOMETRY_COHERENT_POINT_DRIFT_H
#define PASTELGEOMETRY_COHERENT_POINT_DRIFT_H

#include "pastel/geometry/pattern_matching/ls_affine.h"
#include "pastel/geometry/nearestset/nearestset_concept.h"
#include "pastel/sys/output/null_output.h"
#include "pastel/sys/math/constants.h"

namespace Pastel
{

    // The constraint types of the coherent point drift algorithm
    // are aliases of the corresponding least-squares affine (lsAffine)
    // constraint types; coherentPointDrift() forwards them to lsAffine()
    // unchanged.
    using Cpd_Matrix = LsAffine_Matrix;
    using Cpd_Scaling = LsAffine_Scaling;
    using Cpd_Translation = LsAffine_Translation;

    //! Result of coherentPointDrift().
    /*!
    Holds the estimated transformation x |-> Q * S * x + t,
    together with the estimated variance.

    The member order matters: coherentPointDrift() creates
    this object by positional brace-initialization.
    */
    template <typename Real>
    struct Cpd_Return
    {
        // The estimated rotation/reflection; a (d x d) matrix.
        arma::Mat<Real> Q;
        // The estimated scaling; a (d x d) matrix.
        arma::Mat<Real> S;
        // The estimated translation; a (d x 1) vector.
        arma::Col<Real> t;
        // The estimated variance is eye(d, d) * sigma2.
        Real sigma2;
    };

    //! Intermediate state of coherentPointDrift().
    /*!
    An object of this type is passed to the 'report'
    optional argument of coherentPointDrift() after each
    iteration. 

    The member order matters: coherentPointDrift() creates
    this object by positional brace-initialization.
    */
    template <typename Real>
    struct Cpd_State
    {
        // The current estimate of the rotation/reflection.
        arma::Mat<Real> Q;
        // The current estimate of the scaling.
        arma::Mat<Real> S;
        // The current estimate of the translation.
        arma::Col<Real> t;
        // The current estimate of the variance factor.
        Real sigma2;
        // The weighting matrix computed in the current iteration;
        // one row per point of fromSet, one column per point of toSet.
        arma::Mat<Real> W;
    };

    //! Coherent point drift algorithm.
    /*!
    Preconditions:
    fromSet.n_rows == toSet.n_rows
    0 <= minIterations <= maxIterations
    0 < noiseRatio < 1
    n > 0
    m > 0
    d > 0

    Finds matrices Q, S, and t such that

        Q * S * fromSet + t * ones(1, m)

    matches toSet.

    Input
    -----

    fromSet ((d x m) real matrix):
    A set of m points, given as a matrix, where each column
    contains the coordinates of a d-dimensional point. This
    is the set which is transformed.

    toSet ((d x n) real matrix):
    A set of n points, given as a matrix, where each column
    contains the coordinates of a d-dimensional point.

    Returns
    -------

    Q ((d x d) real matrix):
    The estimated rotation/reflection; an orthogonal matrix.
    Initialized with Q0.

    S ((d x d) real matrix):
    The estimated scaling; a symmetric matrix.
    Initialized with S0.

    t ((d x 1) real matrix):
    The estimated translation.
    Initialized with t0.

    sigma2 (Real):
    The estimated variance is eye(d, d) * sigma2.

    Optional input arguments
    ------------------------

    Q0 ((d x d) real matrix : arma::Mat<Real>()):
    Initial guess on Q; an orthogonal matrix. Empty matrix is
    interpreted as a (d x d) identity matrix.

    S0 ((d x d) real matrix : arma::Mat<Real>()):
    Initial guess on S; a symmetric matrix. Empty matrix is
    interpreted as a (d x d) identity matrix.

    t0 ((d x 1) real vector : arma::Col<Real>()): 
    Initial guess on t. Empty matrix is
    interpreted as a (d x 1) zero matrix.

    noiseRatio (Real : 0.2):
    A real number in the open interval (0, 1), which gives the 
    weight for an additive improper uniform distribution component 
    for the Gaussian mixture model. Larger noise-ratio makes the 
    algorithm more tolerant to noise, but slows down convergence 
    when the actual noise level is lower.

    matrix (Cpd_Matrix : Free):
    Specifies constraints for the matrix Q.
        Free: Q^T Q = I
        Identity: Q = I

    scaling (Cpd_Scaling : Free):
    Specifies constraints for the scaling S.
        Free: S^T = S
        Diagonal: S is diagonal
        Conformal: S = sI
        Rigid: S = I

    translation (Cpd_Translation : Free):
    Specifies constraints for the translation t. 
        Free: no constraint
        Identity: t = 0

    orientation (integer : 1): 
    Specifies constraints for the determinant of A.
    Forwarded as such to lsAffine().
        <0: det(A) < 0,
         0: no constraint
        >0: det(A) > 0.

    minIterations (integer : 0):
    The minimum number of iterations for the algorithm to take.
    Note that the iteration is stopped regardless of this if 
    sigma2 becomes exactly zero.

    maxIterations (integer : std::max(minIterations, 100)):
    The maximum number of iterations for the algorithm to take. 

    minError (Real : see below):
    The error threshold under which to accept the transformation 
    and stop iteration. The error is the maximum of the infinity-norms
    of the changes in Q, S, and t between successive iterations.
    For float 1e-4; otherwise 1e-11.

    report (callable : nullOutput()):
    Called after each iteration with a Cpd_State<Real> which
    contains the current Q, S, t, sigma2, and the weighting 
    matrix W.
    */
    template <
        typename Real,
        typename... ArgumentSet
    >
    Cpd_Return<Real> coherentPointDrift(
        const arma::Mat<Real>& fromSet, 
        const arma::Mat<Real>& toSet,
        ArgumentSet&&... argumentSet)
    {
        // Point Set Registration: Coherent Point Drift,
        // Andriy Myronenko, Xubo Song,
        // IEEE Transactions on Pattern Analysis and Machine Intelligence,
        // Volume 32, Number 12, December 2010.

        ENSURE_OP(fromSet.n_rows, ==, toSet.n_rows);

        // d: dimension of the points
        // m: number of points in fromSet (the transformed set)
        // n: number of points in toSet
        integer d = toSet.n_rows;
        integer m = fromSet.n_cols;
        integer n = toSet.n_cols;

        ENSURE_OP(d, >, 0);
        ENSURE_OP(m, >, 0);
        ENSURE_OP(n, >, 0);

        constexpr Real defaultMinError = 
            std::is_same<Real, float>::value ? 1e-4 : 1e-11;

        // Optional input arguments
        Real noiseRatio = 
            PASTEL_ARG_S(noiseRatio, 0.2);
        integer minIterations = 
            PASTEL_ARG_S(minIterations, 0);
        integer maxIterations = 
            PASTEL_ARG_S(maxIterations, std::max(minIterations, (integer)100));
        Real minError = 
            PASTEL_ARG_S(minError, defaultMinError);
        Cpd_Matrix matrix = 
            PASTEL_ARG_ENUM(matrix, Cpd_Matrix::Free);
        Cpd_Scaling scaling = 
            PASTEL_ARG_ENUM(scaling, Cpd_Scaling::Free);
        Cpd_Translation translation = 
            PASTEL_ARG_ENUM(translation, Cpd_Translation::Free);
        integer orientation = 
            PASTEL_ARG_S(orientation, (integer)1);
        arma::Mat<Real> Q = 
            PASTEL_ARG_S(Q0, arma::Mat<Real>());
        arma::Mat<Real> S = 
            PASTEL_ARG_S(S0, arma::Mat<Real>());
        arma::Col<Real> t = 
            PASTEL_ARG_S(t0, arma::Col<Real>());
        auto&& report =
            PASTEL_ARG_S(report, nullOutput());

        ENSURE(noiseRatio > 0);
        ENSURE(noiseRatio < 1);
        ENSURE_OP(minIterations, >=, 0);
        ENSURE_OP(minIterations, <=, maxIterations);

        if (Q.is_empty())
        {
            // The initial Q was not specified.

            // Reset the matrix, to clear a possible
            // replicated strict flag.
            Q.reset();

            // Use the identity matrix.
            Q.eye(d, d);
        }

        ENSURE_OP(Q.n_rows, ==, d);
        ENSURE_OP(Q.n_cols, ==, d);

        if (S.is_empty())
        {
            // The initial S was not specified.

            // Reset the matrix, to clear a possible
            // replicated strict flag.
            S.reset();

            // Use the identity matrix.
            S.eye(d, d);
        }

        ENSURE_OP(S.n_rows, ==, d);
        ENSURE_OP(S.n_cols, ==, d);

        if (t.is_empty())
        {
            // The initial t was not specified.

            // Reset the matrix, to clear a possible
            // replicated strict flag.
            t.reset();

            // Use the zero matrix.
            t.zeros(d);
        }

        ENSURE_OP(t.n_rows, ==, d);
        ENSURE_OP(t.n_cols, ==, 1);

        // We wish to preserve the memory storage
        // of Q, S, and t. Store the memory addresses
        // to check the preservation later.
        const Real* qPointer = Q.memptr();
        const Real* sPointer = S.memptr();
        const Real* tPointer = t.memptr();

        // Compute the transformed model-set according
        // to the initial guess.
        arma::Mat<Real> transformedSet = 
            Q * S * fromSet + t * arma::ones<arma::Mat<Real>>(1, m);

        // Returns the transformed set centered on toSet.col(j).
        // Note that it is important that we return by decltype(auto),
        // to capture the expression template.
        // NOTE(review): the returned expression is built from 
        // temporaries created inside this lambda (the column view 
        // and the row of ones); whether it stays valid after the 
        // return depends on how Armadillo stores its operands -- 
        // confirm.
        auto deltaSet = [&](integer j) -> decltype(auto)
        {
            return transformedSet - toSet.col(j) * arma::ones<arma::Mat<Real>>(1, m);
        };

        // Compute a constant to be used later.
        Real c = (noiseRatio / (1 - noiseRatio)) * ((Real)m / n);

        // Compute an initial estimate for sigma^2; this
        // is the mean squared coordinate difference over 
        // all pairs of a transformed point and a target point.
        Real sigma2 = 0;
        for (integer j = 0; j < n;++j)
        {
            sigma2 += arma::accu(arma::square(deltaSet(j)));
        }
        sigma2 = sigma2 / (d * m * n);

        // The weighting matrix will be computed here.
        arma::Mat<Real> W(m, n);

        // These will be used as temporary space for
        // computing the weighting matrix.
        arma::Row<Real> expSet(m);

        // These will store the previous estimate.
        arma::Mat<Real> qPrev(d, d);
        arma::Mat<Real> sPrev(d, d);
        arma::Col<Real> tPrev(d);

        for (integer iteration = 0; iteration < maxIterations; ++iteration)
        {
            if (sigma2 == 0)
            {
                // Having zero sigma^2 can happen at least with easy
                // cases where there is an exact solution; it is related
                // to fast convergence.
                break;
            }

            // Compute a constant for the improper uniform distribution.
            // Note that this is dependent on 'sigma2', which is being
            // updated at each iteration; this cannot be moved out of
            // the loop.
            Real f = std::pow(2 * constantPi<Real>() * sigma2, (Real)d / 2) * c;

            // Compute the weighting matrix. Each column is 
            // normalized by the sum of its Gaussian terms plus 
            // the uniform-noise term f.
            for (integer j = 0;j < n;++j)
            {
                expSet = 
                    arma::exp(
                        -arma::sum(arma::square(deltaSet(j))) / (2 * sigma2)
                    );
                W.col(j) = expSet.t() / (arma::accu(expSet) + f);
            }

            // Store the previous transformation for comparison.
            qPrev = Q;
            sPrev = S;
            tPrev = t;

            // Compute a new estimate for the optimal transformation.
            auto lsMatch = lsAffine(
                fromSet, toSet,
                matrix,
                scaling,
                translation,
                PASTEL_TAG(orientation), orientation,
                PASTEL_TAG(W), W,
                // This avoids the reallocation of 
                // Q, S, and t.
                PASTEL_TAG(Q0), std::move(Q),
                PASTEL_TAG(S0), std::move(S),
                PASTEL_TAG(t0), std::move(t)
                );

            Q = std::move(lsMatch.Q);
            S = std::move(lsMatch.S);
            t = std::move(lsMatch.t);

            // Compute the transformed model-set.
            transformedSet = 
                Q * S * fromSet + t * arma::ones<arma::Mat<Real>>(1, m);

            // Compute a new estimate for sigma^2; this is
            // the W-weighted mean of the squared distances,
            // divided by the dimension.
            sigma2 = 0;
            for (integer j = 0;j < n;++j)
            {
                sigma2 += 
                    arma::accu(
                        W.col(j).t() % arma::sum(arma::square(deltaSet(j))) 
                    );
            }
            sigma2 /= arma::accu(W) * d;

            // Report the current estimate. The matrices are
            // moved into the state, and back again after the 
            // report, to avoid copying them.
            Cpd_State<Real> state = 
            {
                std::move(Q),
                std::move(S),
                std::move(t),
                sigma2,
                std::move(W)
            };

            report(addConst(state));

            Q = std::move(state.Q);
            S = std::move(state.S);
            t = std::move(state.t);
            W = std::move(state.W);

            // Measure the change in the transformation in the
            // precision of the computation type 'Real'.
            Real qError = arma::norm(qPrev - Q, "inf");
            Real sError = arma::norm(sPrev - S, "inf");
            Real tError = arma::norm(tPrev - t, "inf");

            if (std::max(std::max(qError, sError), tError) <= minError && 
                iteration + 1 >= minIterations)
            {
                // When the change to the previous transformation 
                // falls below the given error threshold, we will 
                // stop, provided that a minimum number of iterations
                // has been performed.
                break;
            }
        }

        // Make sure that memory was not reallocated.
        ASSERT(Q.memptr() == qPointer);
        unused(qPointer);

        ASSERT(S.memptr() == sPointer);
        unused(sPointer);

        ASSERT(t.memptr() == tPointer);
        unused(tPointer);

        return {std::move(Q), std::move(S), std::move(t), sigma2};
    }

}

#endif