kdtree_nearestset.h

Back to Nearest set

pastel/geometry/nearestset/

// Description: Nearest neighbors searching in a kd-tree
// Documentation: nearestset.txt

#ifndef PASTELGEOMETRY_KDTREE_NEARESTSET_H
#define PASTELGEOMETRY_KDTREE_NEARESTSET_H

// Template concepts

#include "pastel/sys/indicator/indicator_concept.h"
#include "pastel/sys/output/output_concept.h"
#include "pastel/sys/point/point_concept.h"
#include "pastel/math/normbijection/normbijection_concept.h"
#include "pastel/geometry/nearestset/nearestset_concept.h"

// Template defaults

#include "pastel/math/normbijection/euclidean_normbijection.h"
#include "pastel/sys/indicator/all_indicator.h"
#include "pastel/sys/output/null_output.h"
#include "pastel/geometry/depthfirst_pointkdtree_searchalgorithm.h"

// Template requirements

#include "pastel/sys/type_traits/is_template_instance.h"
#include "pastel/geometry/pointkdtree/pointkdtree_fwd.h"
#include "pastel/geometry/tdtree/tdtree_fwd.h"

// Implementation requirements

#include "pastel/geometry/distance/distance_point_point.h"
#include "pastel/geometry/distance/distance_alignedbox_point.h"

#include "pastel/sys/rankedset/rankedset.h"
#include "pastel/sys/set/range_set.h"

namespace Pastel
{

    template <
        typename KdTree,
        typename IntervalSequence_,
        typename SearchAlgorithm>
    class KdTree_NearestSet
    {
    public:
        using Fwd = KdTree;
        PASTEL_FWD(Real);
        PASTEL_FWD(Cursor);
        PASTEL_FWD(Point_ConstIterator);
        using ConstIterator = Point_ConstIterator;

        KdTree_NearestSet(
            const KdTree& kdTree_,
            const Real& maxRelativeError_,
            integer nBruteForce_,
            IntervalSequence_ timeIntervalSequence_)
        : kdTree(kdTree_)
        , maxRelativeError(maxRelativeError_)
        , nBruteForce(nBruteForce_)
        , timeIntervalSequence(std::move(timeIntervalSequence_))
        {
        }

        auto pointSet() const
        {
            return intervalSet(kdTree.begin(), kdTree.end());
        }

        auto begin() const
        {
            using std::begin;
            return begin(pointSet());
        }

        auto end() const
        {
            using std::end;
            return end(pointSet());
        }

        const auto& pointSetLocator() const
        {
            return kdTree.locator();
        }

        decltype(auto) asPoint(Point_ConstIterator point) const
        {
            return location(point->point(), kdTree.locator());
        }

        template <
            typename Search_Point,
            typename NormBijection,
            typename Real,
            typename Output,
            Requires<
                Models<Search_Point, Point_Concept>,
                Models<NormBijection, NormBijection_Concept>
            > = 0
        >
        void findNearbyPointsets(
            const Search_Point& searchPoint,
            const NormBijection& normBijection,
            const Real& maxDistance2,
            const Output& report) const
        {
            Real cullDistance2 = maxDistance2;

            const Real errorFactor = 
                inverse(normBijection.scalingFactor(1 + maxRelativeError));
            Real nodeCullDistance2 = 
                cullDistance2 * errorFactor;

            if (kdTree.empty())
            {
                // There is nothing to search for.
                // Note that we consider the search-ball open.
                return;
            }

            // Compute the distance from the search-point to the
            // bounding-box of the kd-tree.
            Real rootDistance2 = 
                distance2(kdTree.bound(), searchPoint, normBijection);
            if (rootDistance2 >= maxDistance2)
            {
                // The bounding box for the points does not
                // intersect the search ball.
                return;
            }

            using TimeSequence = 
                RemoveCvRef<decltype(timeIntervalSequence)>;
            using IntervalSequence = 
                Vector<integer, Point_N<TimeSequence>::value>;

            // The temporal restriction is given as a union
            // of time-intervals. Convert the time-points to
            // indices in the point-set.
            IntervalSequence indexSequence(
                ofDimension(timeIntervalSequence.n()));
            for (integer i = 0;i < timeIntervalSequence.size();i += 2)
            {
                indexSequence[i] = kdTree.timeToIndex(
                    timeIntervalSequence[i]);

                if (i + 1 < timeIntervalSequence.size())
                {
                    indexSequence[i + 1] = kdTree.timeToIndex(
                        timeIntervalSequence[i + 1]);
                }
            }

            struct State
            : boost::less_than_comparable<State>
            {
                State(
                    const Cursor& cursor_,
                    const Real& distance_,
                    const IntervalSequence& indexSequence_)
                : cursor(cursor_)
                , distance(distance_)
                , indexSequence(indexSequence_)
                {
                }

                bool operator<(const State& that) const
                {
                    return distance < that.distance;
                }

                Cursor cursor;
                Real distance;
                IntervalSequence indexSequence;
            };

            using SearchAlgorithm_ = 
                typename SearchAlgorithm::template Instance<State>;

            SearchAlgorithm_ searchAlgorithm;

            // Start from the root node.
            searchAlgorithm.insertNode(
                State(kdTree.root(), rootDistance2, indexSequence));

            auto intervalDistance = [&](
                const Real& x,
                const Real& min, 
                const Real& max)
            {
                Real distance = 0;

                if (x < min)
                {
                    distance = 
                        normBijection.axis(min - x);
                }
                else if (x > max)
                {
                    distance =
                        normBijection.axis(x - max);
                }

                return distance;
            };

            auto queueNode = [&](
                const State& parent,
                const Real& searchPosition,
                bool right)
            {
                const Cursor node = right ? parent.cursor.right() : parent.cursor.left();

                const Real axisDistance = 
                    intervalDistance(
                        searchPosition, node.min(), node.max());

                const Real axisDistancePrev = 
                    intervalDistance(
                        searchPosition, node.prevMin(), node.prevMax());

                const Real newDistance2 = 
                    normBijection.replaceAxis(
                        parent.distance,
                        axisDistancePrev,
                        axisDistance);

                IntervalSequence sequence(parent.indexSequence);
                for (integer i = 0;i < parent.indexSequence.size();++i)
                {
                    sequence[i] = parent.cursor.cascade(parent.indexSequence[i], right);
                }

                State state(node, newDistance2, sequence);

                if (newDistance2 > nodeCullDistance2 ||
                    searchAlgorithm.skipNode(state))
                {
                    return;
                }

                searchAlgorithm.insertNode(
                    std::move(state));
            };

            while (searchAlgorithm.nodesLeft())
            {
                State state = searchAlgorithm.nextNode();

                const Real& distance = state.distance;
                const Cursor& cursor = state.cursor;
                const IntervalSequence& intervalSequence = state.indexSequence;

                if (distance > nodeCullDistance2)
                {
                    if (searchAlgorithm.breakOnCulling())
                    {
                        break;
                    }

                    continue;
                }

                // Search a node with brute-force if it is a leaf node, or the 
                // search-algorithm says so. Usually the latter is when 
                // cursor.points() <= nBruteForce.
                if (cursor.leaf() ||
                    searchAlgorithm.shouldSearchSplitNode(state, nBruteForce))
                {
                    // Search the node using brute-force.

                    Cursor cursor = state.cursor;

                    // Having const IntervalSequence& here triggers a
                    // bug in gcc 4.9.2. But const auto& works!
                    const auto& indexSequence = state.indexSequence;

                    for (integer i = 0; i < indexSequence.size(); i += 2)
                    {
                        // For each pair of integers in the index-sequence...

                        // The index-sequence is a sequence of integer 
                        // pairs (i, j). If the index-sequence has an odd
                        // number of elements, then the last index is
                        // implicitly taken to be infinity.
                        integer indexMin = indexSequence[i];
                        integer indexMax = (i + 1) < indexSequence.size() ?
                            indexSequence[i + 1] : cursor.points();

                        const Real cullSuggestion2 = 
                            report(
                                cursor.pointSet(indexMin, indexMax), 
                                cullDistance2);
                        if (cullSuggestion2 < cullDistance2)
                        {
                            cullDistance2 = cullSuggestion2;
                            nodeCullDistance2 = cullDistance2 * errorFactor;
                        }
                    }

                    continue;
                }

                // For an intermediate node our task is to
                // recurse to child nodes while updating
                // incrementally the distance 
                // to the current node. 

                integer splitAxis = cursor.splitAxis();
                Real searchPosition = 
                    pointAxis(searchPoint, splitAxis);

                // Queue non-culled child nodes for 
                // future handling.
                queueNode(state, searchPosition, false);
                queueNode(state, searchPosition, true);
            }

        }

        const KdTree& kdTree;
        Real maxRelativeError;
        integer nBruteForce;
        IntervalSequence_ timeIntervalSequence;
    };

    //! Constructs a kd-tree nearest-set.
    /*!
    Returns a KdTree_NearestSet which refers to 'kdTree'
    by reference; 'kdTree' must outlive the returned object.

    kdTree:
    The kd-tree to search neighbors in. 
    Either a PointKdTree or a TdTree.

    Optional arguments
    ------------------

    maxRelativeError (Real >= 0):
    Maximum allowed relative error in the distance of the 
    result point to the true nearest neighbor. Allowing error
    increases performance. Use 0 for exact matches. 
    Default: 0

    nBruteForce (integer >= 0):
    Passed to the search-algorithm, which uses it to decide
    when to search a split node by brute-force (typically when
    the node has at most this many points). Leaf nodes are
    always searched by brute-force.
    Default: 16

    searchAlgorithm:
    The search-algorithm to use for searching the 'kdTree'.
    See 'pointkdtree_searchalgorithm.txt'.
    Default: DepthFirst_SearchAlgorithm_PointKdTree()

    intervalSequence (Point):
    An interval sequence in time. A sequence 
    (t_1, t_2, t_3, t_4, ...) corresponds to the
    time-intervals [t_1, t_2), [t_3, t_4), ...
    If the number of time-instants is odd, then
    the sequence is implicitly appended 
    (Real)Infinity().
    Default: (-(Real)Infinity(), (Real)Infinity())
    */
    template <
        typename KdTree,
        typename... ArgumentSet>
    decltype(auto) kdTreeNearestSet(
        const KdTree& kdTree,
        ArgumentSet&&... argumentSet)
    {
        using Fwd = KdTree;
        PASTEL_FWD(Real);

        Real maxRelativeError = PASTEL_ARG_S(maxRelativeError, 0);
        ENSURE_OP(maxRelativeError, >=, 0);

        integer nBruteForce = PASTEL_ARG_S(nBruteForce, 16);
        ENSURE_OP(nBruteForce, >=, 0);

        // The time-interval restriction; by default all of time.
        auto&& timeIntervalSequence = 
            PASTEL_ARG(
                intervalSequence, 
                []() {return Vector<Real, 2>({-(Real)Infinity(), (Real)Infinity()});},
                [](auto input) {return Models<decltype(input), Point_Concept>();}
            );

        // Store the sequence by value, as a vector.
        auto timeIntervalSequence_ = evaluate(pointAsVector(timeIntervalSequence));
        using IntervalSequence = decltype(timeIntervalSequence_);

        // Only the type of the search-algorithm is used; it is
        // instantiated anew in each search.
        auto&& searchAlgorithmObject =
            PASTEL_ARG(
                searchAlgorithm, 
                []() {return DepthFirst_SearchAlgorithm_PointKdTree();},
                [](auto input) {return std::true_type();}
            );
        using SearchAlgorithm = RemoveCvRef<decltype(searchAlgorithmObject)>;

        return KdTree_NearestSet<KdTree, IntervalSequence, SearchAlgorithm>(
            kdTree,
            maxRelativeError,
            nBruteForce,
            std::move(timeIntervalSequence_));
    }

}

#endif