// coherent_point_drift.h
//
// Back to Coherent point drift
//
// pastel/geometry/pattern_matching/

// Description: Coherent point drift algorithm.
// Documentation: coherent_point_drift.txt

#ifndef PASTELGEOMETRY_COHERENT_POINT_DRIFT_H
#define PASTELGEOMETRY_COHERENT_POINT_DRIFT_H

#include "pastel/geometry/pattern_matching/ls_affine.h"
#include "pastel/geometry/nearestset/nearestset_concept.h"
#include "pastel/sys/output/null_output.h"
#include "pastel/sys/math/constants.h"

namespace Pastel
{

    // The constraint types of coherent point drift are aliases of the 
    // corresponding lsAffine() types; coherentPointDrift() reads its
    // 'matrix', 'scaling', and 'translation' optional arguments as these
    // types and passes them on to lsAffine() unchanged.
    using Cpd_Matrix = LsAffine_Matrix;
    using Cpd_Scaling = LsAffine_Scaling;
    using Cpd_Translation = LsAffine_Translation;

    //! Per-iteration state of the coherent point drift algorithm.
    /*!
    coherentPointDrift() builds one of these at the end of each 
    iteration, by aggregate-initialization in the member order below, 
    and passes it to the optional 'report' argument. The member order 
    is therefore part of the interface.
    */
    template <typename Real>
    struct Cpd_State
    {
        // View to the current estimate of Q.
        MatrixView<Real> Q;
        // View to the current estimate of S.
        MatrixView<Real> S;
        // View to the current estimate of t.
        MatrixView<Real> t;
        // Current estimate of sigma^2.
        Real sigma2;
        // View to the weighting matrix computed in this iteration.
        MatrixView<Real> W;
    };

    //! Coherent point drift algorithm.
    /*!
    Preconditions:
    0 <= minIterations <= maxIterations
    0 < noiseRatio < 1
    m > 0
    n > 0
    d > 0
    fromSet.rows() == toSet.rows()
    Qs and Ss are (d x d)
    ts is (d x 1)

    Finds matrices Q, S, and t such that

        (Q * S * fromSet).colwise() + t

    matches toSet.

    Input
    -----

    fromSet ((d x m) real matrix):
    A set of m points, given as a matrix, where each column
    contains the coordinates of a d-dimensional point. 

    toSet ((d x n) real matrix):
    A set of n points, given as a matrix, where each column
    contains the coordinates of a d-dimensional point.

    Qs ((d x d) real matrix):
    Initial guess on Q and storage for the solution; an orthogonal matrix. 
    Use identity matrix if no better initial guess.

    Ss ((d x d) real matrix):
    Initial guess on S and storage for the solution; a symmetric matrix. 
    Use identity matrix if no better initial guess.
    
    ts ((d x 1) real vector): 
    Initial guess on t and storage for the solution; a column matrix. 
    Use zero vector if no better initial guess.

    Returns
    -------

    sigma2 (Real):
    The estimated variance is Identity(d, d) * sigma2.

    Optional input arguments
    ------------------------

    initialize (bool : true):
    Whether to initialize the matrices to defaults:
    Q = identity
    S = identity 
    t = zero
    If true, the contents of Qs, Ss, and ts are overwritten with 
    these defaults before iterating. If false, the passed matrices 
    are used as the initial guess instead.

    noiseRatio (Real : 0.2):
    A real number in the open interval (0, 1), which gives the weight 
    for an additive improper uniform distribution component for the 
    Gaussian mixture model. Larger noise-ratio makes the algorithm more 
    tolerant to noise, but declines convergence rate when the actual 
    noise level is lower.

    matrix (Cpd_Matrix : Free):
    Specifies constraints for the matrix Q.
        Free: Q^T Q = I
        Identity: Q = I

    scaling (Cpd_Scaling : Free):
    Specifies constraints for the scaling S.
        Free: S^T = S
        Diagonal: S is diagonal
        Conformal: S = sI
        Rigid: S = I

    translation (Cpd_Translation : Free):
    Specifies constraints for the translation t. 
        Free: no constraint
        Identity: t = 0

    orientation (integer : 1): 
    Specifies constraints for the determinant of A, the matrix part
    of the transformation (presumably A = Q * S; the argument is
    passed to lsAffine() unchanged, see its documentation).
        <0: det(A) < 0,
         0: no constraint
        >0: det(A) > 0.

    minIterations (integer : 0):
    The minimum number of iterations for the algorithm to take.
    Note that the iteration still stops early if sigma2 becomes
    exactly zero.

    maxIterations (integer : std::max(minIterations, 100)):
    The maximum number of iterations for the algorithm to take. 

    minError (Real : see below):
    The minimum error under which to accept the transformation 
    and stop iteration. The error is the largest of the maxNorm()s
    of the changes in Q, S, and t over one iteration.
    For float 1e-4; for double 1e-11.

    report (callable : nullOutput()):
    Called at the end of each iteration as report(addConst(state)), 
    where state is a Cpd_State<Real> holding Qs, Ss, ts, the updated 
    sigma2, and a view to the (m x n) weighting matrix of that
    iteration.
    */
    template <
        typename Real_from, int M_from, int N_from,
        typename Real_to, int M_to, int N_to,
        typename Real_Q, int M_Q, int N_Q,
        typename Real_S, int M_S, int N_S,
        typename Real_t, int M_t,
        typename... ArgumentSet
    >
    requires 
        IsPlain<Real_Q> &&
        IsPlain<Real_S> && 
        IsPlain<Real_t> && 
        IsSameObject<Real_from, Real_to, Real_Q, Real_S, Real_t>
    Real_from coherentPointDrift(
        const MatrixView<Real_from, M_from, N_from>& fromSet,
        const MatrixView<Real_to, M_to, N_to>& toSet,
        const MatrixView<Real_Q, M_Q, N_Q>& Qs,
        const MatrixView<Real_S, M_S, N_S>& Ss,
        const ColMatrixView<Real_t, M_t>& ts,
        ArgumentSet&&... argumentSet)
    {
        // Point Set Registration: Coherent Point Drift,
        // Andriy Myronenko, Xubo Song,
        // IEEE Transactions on Pattern Analysis and Machine Intelligence,
        // Volume 32, Number 12, December 2010.

        ENSURE_OP(fromSet.rows(), ==, toSet.rows());

        using Real = Real_from;

        constexpr int D = Common_Dimension<M_from, M_to, M_Q, M_S, M_t, N_Q, N_S>;

        MapMatrix<Real, D, N_from> P = asMatrix(fromSet);
        MapMatrix<Real, D, N_to> R = asMatrix(toSet);

        // d = dimension, m = number of points in fromSet (the model),
        // n = number of points in toSet (the data). This follows the
        // M / N convention of the paper.
        integer d = P.rows();
        integer m = P.cols();
        integer n = R.cols();

        ENSURE_OP(d, >, 0);
        ENSURE_OP(m, >, 0);
        ENSURE_OP(n, >, 0);

        ENSURE_OP(Qs.rows(), ==, d);
        ENSURE_OP(Qs.cols(), ==, d);
        ENSURE_OP(Ss.rows(), ==, d);
        ENSURE_OP(Ss.cols(), ==, d);
        ENSURE_OP(ts.rows(), ==, d);
        ENSURE_OP(ts.cols(), ==, 1);

        constexpr Real defaultMinError = 
            std::is_same<Real, float>::value ? 1e-4 : 1e-11;

        // Optional input arguments
        bool initialize =
            PASTEL_ARG_S(initialize, true);
        Real noiseRatio = 
            PASTEL_ARG_S(noiseRatio, 0.2);
        integer minIterations = 
            PASTEL_ARG_S(minIterations, 0);
        integer maxIterations = 
            PASTEL_ARG_S(maxIterations, std::max(minIterations, (integer)100));
        Real minError = 
            PASTEL_ARG_S(minError, defaultMinError);
        Cpd_Matrix matrix = 
            PASTEL_ARG_ENUM(matrix, Cpd_Matrix::Free);
        Cpd_Scaling scaling = 
            PASTEL_ARG_ENUM(scaling, Cpd_Scaling::Free);
        Cpd_Translation translation = 
            PASTEL_ARG_ENUM(translation, Cpd_Translation::Free);
        integer orientation = 
            PASTEL_ARG_S(orientation, (integer)1);
        auto&& report =
            PASTEL_ARG_S(report, nullOutput());

        ENSURE(noiseRatio > 0);
        ENSURE(noiseRatio < 1);
        ENSURE_OP(minIterations, >=, 0);
        ENSURE_OP(minIterations, <=, maxIterations);

        // Q, S, and t are matrix maps of the caller's storage;
        // lsAffine() below is handed the same storage through
        // Qs, Ss, and ts.
        MapMatrix<Real, D, D> Q = asMatrix(Qs);
        MapMatrix<Real, D, D> S = asMatrix(Ss);
        MapColMatrix<Real, D> t = asMatrix(ts);

        if (initialize) {
            // Initialize Q, S, and t.
            Q = Matrix<Real, D, D>::Identity(d, d);
            S = Matrix<Real, D, D>::Identity(d, d);
            t = ColMatrix<Real, D>::Zero(d, 1);
        }

        // Compute the transformed model-set according
        // to the initial guess.
        Matrix<Real, D, N_from> transformedSet = 
            (Q * S * P).colwise() + t;

        // Returns the transformed set centered on toSet.col(j).
        // Note that it is important that we return by decltype(auto),
        // to capture the expression template.
        auto deltaSet = [&](integer j) -> decltype(auto)
        {
            return transformedSet.colwise() - R.col(j);
        };

        // Compute a constant to be used later.
        Real c = (noiseRatio / (1 - noiseRatio)) * (static_cast<Real>(m) / n);

        // Compute an initial estimate for sigma^2; the mean
        // squared coordinate difference over all point pairs.
        Real sigma2 = 0;
        for (integer j = 0; j < n; ++j)
        {
            sigma2 += deltaSet(j).array().square().sum();
        }
        sigma2 = sigma2 / (d * m * n);

        // The weighting matrix will be computed here.
        Matrix<Real, N_from, N_to> W(m, n);
        MatrixView<Real, N_from, N_to> Ws = view(W);

        // This will be used as temporary space for computing the 
        // weighting matrix. It holds one value per fromSet point, 
        // so its extent is N_from (not D); it is assigned to a 
        // column of W below.
        RowMatrix<Real, N_from> expSet(m);

        // These will store the previous estimate.
        Matrix<Real, D, D> qPrev(d, d);
        Matrix<Real, D, D> sPrev(d, d);
        ColMatrix<Real, D> tPrev(d, 1);

        for (integer iteration = 0; iteration < maxIterations; ++iteration)
        {
            if (sigma2 == 0)
            {
                // Having zero sigma^2 can happen at least with easy
                // cases where there is an exact solution; it is related
                // to fast convergence. We must stop here, since the
                // weights below divide by sigma^2.
                break;
            }

            // Compute a constant for the improper uniform distribution.
            // Note that this is dependent on 'sigma2', which is being
            // updated at each iteration; this cannot be moved out of
            // the loop.
            Real f = std::pow(2 * constantPi<Real>() * sigma2, static_cast<Real>(d) / 2) * c;

            // Compute the weighting matrix. Column j holds the
            // Gaussian weights of toSet.col(j) with respect to each 
            // transformed fromSet point, normalized by their sum plus
            // the uniform-component constant.
            for (integer j = 0; j < n; ++j)
            {
                expSet = (deltaSet(j).array().square().colwise().sum() / (-2 * sigma2)).exp();
                W.col(j) = expSet.transpose() / (expSet.sum() + f);
            }

            // Store the previous transformation for comparison.
            qPrev = Q;
            sPrev = S;
            tPrev = t;

            // Compute a new estimate for the optimal transformation.
            lsAffine(
                fromSet, toSet,
                Qs, Ss, ts,
                PASTEL_TAG(matrix), matrix,
                PASTEL_TAG(scaling), scaling,
                PASTEL_TAG(translation), translation,
                PASTEL_TAG(orientation), orientation,
                PASTEL_TAG(W), Ws
                );

            // Compute the transformed model-set.
            transformedSet = 
                (Q * S * P).colwise() + t;

            // Compute a new estimate for sigma^2; the W-weighted
            // mean of the squared distances, divided by the dimension.
            sigma2 = 0;
            for (integer j = 0; j < n; ++j)
            {
                sigma2 += 
                    (W.col(j).transpose().array() * deltaSet(j).array().square().colwise().sum()).sum();
            }
            sigma2 /= W.sum() * d;

            // Report the current estimate.
            Cpd_State<Real> state = 
            {
                Qs, Ss, ts, sigma2, Ws
            };

            report(addConst(state));

            // Measure how much the transformation changed
            // during this iteration.
            Real qError = maxNorm(qPrev - Q);
            Real sError = maxNorm(sPrev - S);
            Real tError = maxNorm(tPrev - t);

            if (std::max(std::max(qError, sError), tError) <= minError && 
                iteration + 1 >= minIterations)
            {
                // When the change to the previous transformation 
                // falls below the given error threshold, we will 
                // stop, provided that a minimum number of iterations
                // has been performed.
                break;
            }
        }

        return sigma2;
    }

}

#endif