// matlab_pointkdtree.cpp
//
// Back to PointKdTree (Matlab)
//
// pastel/geometrymatlab/

// Description: PointKdTree Matlab interface
// Documentation: pointkdtree_matlab.txt

#include "pastel/geometry/pointkdtree/pointkdtree.h"
#include "pastel/geometry/count_nearest.h"
#include "pastel/geometry/search_nearest.h"
#include "pastel/geometry/nearestset/kdtree_nearestset.h"
#include "pastel/math/normbijection.h"

#include "pastel/sys/allocator/pool_allocator.h"
#include "pastel/sys/indicator/predicate_indicator.h"

#include <boost/range/algorithm/fill.hpp>

#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>

#include <unordered_map>

#include "pastel/matlab/pastelmatlab.h"

// Intentionally empty function with external linkage.
// NOTE(review): presumably referenced from another translation unit so
// that the linker retains this object file and its static registration
// object -- confirm against the caller.
void force_linking_pointkdtree()
{
}

namespace Pastel
{

    namespace
    {

        struct TreePoint
        {
            TreePoint()
                : data(0)
                , id(0)
            {
            }

            TreePoint(real* data_, integer id_)
                : data(data_)
                , id(id_)
            {
            }

            real* data;
            integer id;
        };

        // Locator for TreePoint: tells how to read the coordinates of
        // a TreePoint, and how many coordinates there are. The dimension
        // is a run-time value given at construction.
        class TreePoint_Locator
        {
        public:
            using Real = real;
            using Point = TreePoint;
            static constexpr integer N = Dynamic;

            // Creates a locator for points with `dimension` coordinates.
            explicit TreePoint_Locator(integer dimension)
                : dimension_(dimension)
            {
            }

            // Returns the i:th coordinate of the point, read directly
            // from the point's data pointer. No bounds checking is done.
            const Real& operator()(const Point& point, integer i) const
            {
                return point.data[i];
            }

            // Returns the dimension given at construction.
            integer n() const
            {
                return dimension_;
            }

            // Returns the dimension of a specific point; every point
            // has the same dimension, so the point is not inspected.
            integer n(const Point&) const
            {
                return dimension_;
            }

        private:
            integer dimension_;
        };

        // Kd-tree settings which read point coordinates through
        // TreePoint_Locator.
        using Settings = PointKdTree_Settings<TreePoint_Locator>;

        // The kd-tree type used by this interface, together with its
        // point iterator and point range types.
        using Tree = PointKdTree<Settings>;
        using Point_ConstIterator = Tree::Point_ConstIterator;
        using Point_ConstRange = Tree::Point_ConstRange;
        // Maps an integer point identifier to the tree iterator of
        // the point with that identifier.
        using IndexMap = std::unordered_map<integer, Point_ConstIterator>;
        using ConstIterator = IndexMap::const_iterator;

        struct KdState
        {
            explicit KdState(integer dimension)
                : tree(TreePoint_Locator(dimension))
                , indexMap()
                , pointAllocator(dimension * sizeof(real))
                , index(1)
            {
            }

            explicit KdState(const KdState& that)
                : tree(that.tree)
                , indexMap()
                , pointAllocator(that.tree.n() * sizeof(real))
                , index(that.index)
            {
                // Remove points, but retain structure.
                tree.erase(true);

                const Tree& thatTree = that.tree;
                integer dimension = thatTree.n();

                // Copy point data.
                {
                    Point_ConstRange range = thatTree.range();
                    while(!range.empty())
                    {

                        const TreePoint& treePoint =
                            range.front().point();

                        real* data = (real*)pointAllocator.allocate();
                        real* thatData = treePoint.data;

                        std::copy(thatData, thatData + dimension, data);

                        integer index = treePoint.id;

                        Point_ConstIterator iter = tree.insert(TreePoint(data, index));
                        indexMap.insert(std::make_pair(index, iter));

                        range.pop_front();
                    }
                }

                // Copy hidden data.
                {
                    Point_ConstRange range = thatTree.hiddenRange();
                    while(!range.empty())
                    {

                        const TreePoint& treePoint =
                            range.front().point();

                        real* data = (real*)pointAllocator.allocate();
                        real* thatData = treePoint.data;

                        std::copy(thatData, thatData + dimension, data);

                        integer index = treePoint.id;

                        Point_ConstIterator iter = tree.insert(
                            TreePoint(data, index), true);
                        indexMap.insert(std::make_pair(index, iter));

                        range.pop_front();
                    }
                }
            }

            ~KdState()
            {
                tree.clear();
                pointAllocator.clear();
            }

            Tree tree;
            // Mapping from integer identifiers to point iterators.
            IndexMap indexMap;
            // Memory region for point data.
            PoolAllocator pointAllocator;
            // The current allocation index for identifiers.
            integer index;
        };

        // Returns the KdState pointer stored in the data of the given
        // Matlab array. The array is not validated in any way; it is
        // assumed to hold a pointer written by kdConstruct or kdCopy.
        KdState* asState(const mxArray* matlabArray)
        {
            KdState** handle =
                static_cast<KdState**>(mxGetData(matlabArray));
            return *handle;
        }

        void kdAsPoints(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                IdSet,
                Inputs
            };

            enum
            {
                PointSet,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);
            Array<integer> idSet = matlabAsLinearizedArray<integer>(inputSet[IdSet]);

            integer d = state->tree.n();
            integer n = idSet.size();

            Array<real> pointSet = matlabCreateArray<real>(
                n, d, outputSet[PointSet]);
            for (integer i = 0;i < n;++i)
            {
                integer id = idSet(i);
                ConstIterator iter = state->indexMap.find(id);

                if (iter != state->indexMap.end())
                {
                    Point_ConstIterator pointIter =
                        iter->second;

                    std::copy(
                        pointIter->point().data,
                        pointIter->point().data + d,
                        pointSet.columnBegin(i));
                }                    
                else
                {
                    std::fill(
                        pointSet.columnBegin(i),
                        pointSet.columnEnd(i),
                        (real)Nan());
                }
            }
        }

        void kdConstruct(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                Dimension,
                Inputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            integer dimension = matlabAsScalar<integer>(inputSet[Dimension]);

            // Get enough memory to hold a KdState pointer.

            KdState** rawResult = matlabCreateScalar<KdState*>(outputSet[0]);

            // Create a new KdState and store the pointer to
            // the returned array.
            KdState* state = new KdState(dimension);
            *rawResult = state;
        }

        // Destroys a kd-tree state created by kdConstruct or kdCopy.
        //
        // Input: the kd-tree state. The handle must not be used after
        // this call.
        void kdDestruct(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            delete asState(inputSet[State]);
        }

        void kdCopy(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);

            // Get enough memory to hold a KdState pointer.
            KdState** rawResult = matlabCreateScalar<KdState*>(outputSet[0]);

            // Create a new KdState and store the pointer to
            // the returned array.
            KdState* newState = new KdState(*state);
            *rawResult = newState;
        }

        // Checks the invariants of the kd-tree; reports an error through
        // ENSURE if testInvariants() returns false.
        //
        // Input: the kd-tree state.
        void kdCheck(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            Tree& tree = asState(inputSet[State])->tree;

            ENSURE(testInvariants(tree));
        }

        // Calls merge() on the kd-tree.
        // NOTE(review): presumably this collapses the subdivision while
        // keeping the points -- confirm against the PointKdTree
        // documentation.
        //
        // Input: the kd-tree state.
        void kdMerge(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            asState(inputSet[State])->tree.merge();
        }

        // Clears the kd-tree, together with the bookkeeping that
        // refers to its points.
        //
        // Input: the kd-tree state.
        //
        // The identifier counter is deliberately not reset, so that
        // identifiers handed out before the clear are never reused.
        void kdClear(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);

            // The identifier map holds iterators to the tree's points.
            // Forget them, so that later look-ups by an old identifier
            // do not dereference an iterator to a destroyed point.
            state->indexMap.clear();

            // Clear the tree before the pool, since the tree's points
            // refer to memory in the pool (same order as ~KdState).
            state->tree.clear();
            state->pointAllocator.clear();
        }

        void kdInsert(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                PointSet,
                Inputs
            };

            enum
            {
                IdSet,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);
            Tree& tree = state->tree;

            Array<real> pointSet = matlabAsArray<real>(inputSet[PointSet]);
            integer points = pointSet.width();

            Array<integer> result =
                matlabCreateArray<integer>(points, 1, outputSet[IdSet]);
            for (integer i = 0;i < points;++i)
            {

                real* data = (real*)state->pointAllocator.allocate();
                std::copy(pointSet.cColumnBegin(i), pointSet.cColumnEnd(i), data);

                integer index = state->index;
                Point_ConstIterator iter = tree.insert(TreePoint(data, index));
                state->indexMap.insert(std::make_pair(index, iter));
                ++state->index;

                result(i) = index;
            }
        }

        void kdErase(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                IdSet,
                Inputs
            };

            enum
            {
                Outputs
            };

            ENSURE_OP(inputs, >=, IdSet);

            KdState* state = asState(inputSet[State]);
            Tree& tree = state->tree;

            if (inputs <= IdSet)
            {
                // Erase all;
                tree.erase();
                return;
            }

            Array<integer> idSet = matlabAsLinearizedArray<integer>(inputSet[IdSet]);

            integer points = idSet.size();
            for (integer i = 0;i < points;++i)
            {
                integer id = idSet(i);
                ConstIterator iter = 
                    state->indexMap.find(id);
                if (iter != state->indexMap.end())
                {
                    state->tree.erase(iter->second);
                    state->indexMap.erase(id);
                }
            }
        }

        void kdHide(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                IdSet,
                Inputs
            };

            enum
            {
                Outputs
            };

            ENSURE_OP(inputs, >=, IdSet);

            KdState* state = asState(inputSet[State]);
            Tree& tree = state->tree;

            if (inputs <= IdSet)
            {
                // Hide all;
                tree.hide();
                return;
            }

            Array<integer> idSet = matlabAsLinearizedArray<integer>(inputSet[IdSet]);

            integer points = idSet.size();
            for (integer i = 0;i < points;++i)
            {
                ConstIterator iter = 
                    state->indexMap.find(idSet(i));
                if (iter != state->indexMap.end())
                {
                    state->tree.hide(iter->second);
                }
            }
        }

        void kdShow(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                IdSet,
                Inputs
            };

            enum
            {
                Outputs
            };

            ENSURE_OP(inputs, >=, IdSet);

            KdState* state = asState(inputSet[State]);
            Tree& tree = state->tree;

            if (inputs <= IdSet)
            {
                // Show all;
                tree.show();
                return;
            }

            Array<integer> idSet = matlabAsLinearizedArray<integer>(inputSet[IdSet]);

            integer points = idSet.size();
            for (integer i = 0;i < points;++i)
            {
                ConstIterator iter = 
                    state->indexMap.find(idSet(i));
                if (iter != state->indexMap.end())
                {
                    state->tree.show(iter->second);
                }
            }
        }

        void kdRefine(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                BucketSize,
                Inputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);
            integer bucketSize = matlabAsScalar<integer>(inputSet[BucketSize]);

            state->tree.refine(
                SlidingMidpoint_SplitRule(), bucketSize);
        }

        void kdNodes(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            enum
            {
                Nodes,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);

            integer *outNodes = matlabCreateScalar<integer>(outputSet[Nodes]);
            *outNodes = state->tree.nodes();
        }

        void kdLeaves(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            enum
            {
                Leaves,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);

            integer* outLeaves = matlabCreateScalar<integer>(outputSet[Leaves]);
            *outLeaves = state->tree.leaves();
        }

        void kdDimension(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            enum
            {
                Dimension,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);

            integer* outDimension = matlabCreateScalar<integer>(
                outputSet[Dimension]);
            *outDimension = state->tree.n();
        }

        void kdPoints(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                Inputs
            };

            enum
            {
                Points,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);

            KdState* state = asState(inputSet[State]);

            integer* outPoints = matlabCreateScalar<integer>(
                outputSet[Points]);
            *outPoints = state->tree.points();
        }

        template <typename NormBijection>
        void kdSearchNearest_(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                QuerySet,
                MaxDistanceSet,
                KNearest,
                Inputs
            };

            enum
            {
                IdSet,
                DistanceSet,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);
            ENSURE_OP(outputs, >=, 1);
            ENSURE_OP(outputs, <=, Outputs);

            // nearestIdSet = matlab_pointkdtree('pointkdtree_search_nearest', ...
            //     kdtree, querySet, maxDistanceSet, kNearest);

            KdState* state = asState(inputSet[State]);
            const IndexMap& indexMap = state->indexMap;
            Array<real> maxDistanceSet = matlabAsLinearizedArray<real>(inputSet[MaxDistanceSet]);
            integer kNearest = matlabAsScalar<integer>(inputSet[KNearest]);

            mxClassID id = mxGetClassID(inputSet[QuerySet]);

            bool queriesAreCoordinates = 
                id == mxSINGLE_CLASS ||
                id == mxDOUBLE_CLASS;

            bool wantDistance = (outputs >= 2);

            integer queries = 0;
            if (queriesAreCoordinates)
            {
                queries = mxGetN(inputSet[QuerySet]);
            }
            else
            {
                queries = mxGetNumberOfElements(inputSet[QuerySet]);
            }

            Array<integer> nearestArray =
                matlabCreateArray<integer>(kNearest, queries, outputSet[IdSet]);

            Array<real> distanceArray;
            if (wantDistance)
            {
                // Having the matlabCreateArray<real>() call directly
                // inside the swap() function triggers an
                // internal compiler error in Clang.
                auto copyArray = matlabCreateArray<real>(kNearest, queries, outputSet[DistanceSet]);
                distanceArray.swap(copyArray);
                boost::fill(distanceArray.range(), (real)Infinity());
            }

            using Block = tbb::blocked_range<integer>;

            if (queriesAreCoordinates)
            {
                // The queries are a set of points given explicitly
                // by their coordinates. Each column is a query point.
                Array<real> querySet = matlabAsArray<real>(inputSet[QuerySet]);
                integer n = state->tree.n();

                // Find the k-nearest-neighbors for each point.

                auto search = [&](const Block& block)
                {
                    Vector<real> query(ofDimension(n));
                    for (integer i = block.begin(); i < block.end(); ++i)
                    {
                        std::copy(
                            querySet.cColumnBegin(i),
                            querySet.cColumnEnd(i),
                            query.begin());

                        integer j = 0;
                        auto report = [&](
                            real distance,
                            Point_ConstIterator point)
                        {
                            if (point != state->tree.end())
                            {
                                nearestArray(j, i) = point->point().id;
                                if (wantDistance)
                                {
                                    distanceArray(j, i) = distance;
                                }
                            }
                            ++j;
                        };

                        searchNearest(
                            kdTreeNearestSet(state->tree), 
                            query, 
                            PASTEL_TAG(report), report,
                            NormBijection(),
                            PASTEL_TAG(kNearest), kNearest,
                            PASTEL_TAG(maxDistance2), maxDistanceSet(i));
                    }
                };

                tbb::parallel_for(Block(0, queries), search);
            }
            else
            {
                // The queries are over the points in the kd-tree,
                // given by their ids.
                Array<integer> querySet = matlabAsLinearizedArray<integer>(inputSet[QuerySet]);

                // Find the k-nearest-neighbors for each point.
                auto search = [&](const Block& block)
                {
                    for (integer i = block.begin(); i < block.end(); ++i)
                    {
                        integer j = 0;
                        auto report = [&](
                            real distance,
                            Point_ConstIterator point)
                        {
                            if (point != state->tree.end())
                            {
                                nearestArray(j, i) = point->point().id;
                                if (wantDistance)
                                {
                                    distanceArray(j, i) = distance;
                                }
                            }
                            ++j;
                        };

                        auto query_ = indexMap.find(querySet(i));
                        if (query_ == indexMap.end())
                        {
                            // The query index does not exist.
                            continue;
                        }

                        Point_ConstIterator query = query_->second;

                        searchNearest(
                            kdTreeNearestSet(state->tree),
                            location(query->point(), state->tree.locator()),
                            PASTEL_TAG(report), report,
                            PASTEL_TAG(accept), predicateIndicator(query, NotEqualTo()), 
                            NormBijection(),
                            PASTEL_TAG(kNearest), kNearest,
                            PASTEL_TAG(maxDistance2), maxDistanceSet(i));
                    }
                };

                tbb::parallel_for(Block(0, queries), search);
            }
        }

        // Searches for nearest neighbors in the kd-tree; dispatches to
        // kdSearchNearest_ based on the name of the norm.
        //
        // Inputs: the kd-tree state, the query set, the maximum
        // distances, the number of neighbors, and the norm, which is
        // either "euclidean" or "maximum". Any other norm is reported
        // as an error through ENSURE.
        // Outputs: the neighbor identifiers, and optionally the
        // distances.
        void kdSearchNearest(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                QuerySet,
                MaxDistanceSet,
                KNearest,
                Norm,
                Inputs
            };

            enum
            {
                IdSet,
                DistanceSet,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);
            ENSURE_OP(outputs, >=, 1);
            ENSURE_OP(outputs, <=, Outputs);

            // nearestIdSet = matlab_pointkdtree('pointkdtree_search_nearest', ...
            //     kdtree, querySet, maxDistanceSet, kNearest, norm);

            // The norm is the last input; the remaining inputs are
            // forwarded as they are.
            const int forwardedInputs = inputs - 1;

            std::string norm = matlabAsString(inputSet[Norm]);
            if (norm == "euclidean")
            {
                kdSearchNearest_<Euclidean_NormBijection<real>>(
                    outputs, outputSet,
                    forwardedInputs, inputSet);
                return;
            }

            if (norm == "maximum")
            {
                kdSearchNearest_<Maximum_NormBijection<real>>(
                    outputs, outputSet,
                    forwardedInputs, inputSet);
                return;
            }

            // The norm was not recognized.
            bool AllowedNorm = false;
            ENSURE(AllowedNorm);
        }

        // Counts neighbors of points in the kd-tree, under the norm
        // given by the NormBijection template argument.
        //
        // Inputs:
        //   State          - the kd-tree state.
        //   QuerySet       - identifiers of points in the tree.
        //   MaxDistanceSet - one maximum distance per query, passed to
        //                    the count as `maxDistance2`.
        //
        // Output: for each query, the value returned by countNearest().
        // Entries of unknown identifiers are left as created by
        // matlabCreateArray. The queries are processed in parallel;
        // each query writes only to its own entry of the output.
        template <typename NormBijection>
        void kdCountNearest_(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                QuerySet,
                MaxDistanceSet,
                Inputs
            };

            enum
            {
                IdSet,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);
            ENSURE_OP(outputs, >=, 1);
            ENSURE_OP(outputs, <=, Outputs);

            // nearestIdSet = matlab_pointkdtree('pointkdtree_count_nearest', ...
            //     kdtree, querySet, maxDistanceSet, norm);

            KdState* state = asState(inputSet[State]);
            const IndexMap& indexMap = state->indexMap;
            Array<real> maxDistanceSet = matlabAsLinearizedArray<real>(inputSet[MaxDistanceSet]);
            integer queries = mxGetNumberOfElements(inputSet[QuerySet]);
            ENSURE_OP(maxDistanceSet.size(), ==, queries);

            // The queries are identifiers of points in the kd-tree.
            Array<integer> querySet = matlabAsLinearizedArray<integer>(inputSet[QuerySet]);

            // One count per query.
            Array<integer> result =
                matlabCreateArray<integer>(queries, 1, outputSet[IdSet]);

            using Block = tbb::blocked_range<integer>;

            // Counts the neighbors for the queries in a block.
            auto countBlock = [&](const Block& block)
            {
                const integer blockEnd = block.end();
                for (integer i = block.begin(); i < blockEnd; ++i)
                {
                    ConstIterator found = indexMap.find(querySet(i));
                    if (found != indexMap.end())
                    {
                        Point_ConstIterator query = found->second;

                        result(i) = countNearest(
                            kdTreeNearestSet(state->tree),
                            location(query->point(), state->tree.locator()),
                            NormBijection(),
                            PASTEL_TAG(kNearest), state->tree.points(),
                            PASTEL_TAG(maxDistance2), maxDistanceSet(i));
                    }
                }
            };

            tbb::parallel_for(Block(0, queries), countBlock);
        }

        // Counts neighbors of points in the kd-tree; dispatches to
        // kdCountNearest_ based on the name of the norm.
        //
        // Inputs: the kd-tree state, the query set, the maximum
        // distances, and the norm, which is either "euclidean" or
        // "maximum". Any other norm is reported as an error through
        // ENSURE.
        // Output: the count for each query.
        void kdCountNearest(
            int outputs, mxArray *outputSet[],
            int inputs, const mxArray *inputSet[])
        {
            enum
            {
                State,
                QuerySet,
                MaxDistanceSet,
                Norm,
                Inputs
            };

            enum
            {
                IdSet,
                Outputs
            };

            ENSURE_OP(inputs, ==, Inputs);
            ENSURE_OP(outputs, >=, 1);
            ENSURE_OP(outputs, <=, Outputs);

            // nearestIdSet = matlab_pointkdtree('pointkdtree_count_nearest', ...
            //     kdtree, querySet, maxDistanceSet, norm);

            // The norm is the last input; the remaining inputs are
            // forwarded as they are.
            const int forwardedInputs = inputs - 1;

            std::string norm = matlabAsString(inputSet[Norm]);
            if (norm == "euclidean")
            {
                kdCountNearest_<Euclidean_NormBijection<real>>(
                    outputs, outputSet,
                    forwardedInputs, inputSet);
                return;
            }

            if (norm == "maximum")
            {
                kdCountNearest_<Maximum_NormBijection<real>>(
                    outputs, outputSet,
                    forwardedInputs, inputSet);
                return;
            }

            // The norm was not recognized.
            bool AllowedNorm = false;
            ENSURE(AllowedNorm);
        }

        void addStuff()
        {
            matlabAddFunction("pointkdtree_as_points", kdAsPoints);
            matlabAddFunction("pointkdtree_construct", kdConstruct);
            matlabAddFunction("pointkdtree_destruct", kdDestruct);
            matlabAddFunction("pointkdtree_copy", kdCopy);
            matlabAddFunction("pointkdtree_check", kdCheck);
            matlabAddFunction("pointkdtree_merge", kdMerge);
            matlabAddFunction("pointkdtree_clear", kdClear);
            matlabAddFunction("pointkdtree_points", kdPoints);
            matlabAddFunction("pointkdtree_nodes", kdNodes);
            matlabAddFunction("pointkdtree_leaves", kdLeaves);
            matlabAddFunction("pointkdtree_dimension", kdDimension);
            matlabAddFunction("pointkdtree_insert", kdInsert);
            matlabAddFunction("pointkdtree_erase", kdErase);
            matlabAddFunction("pointkdtree_hide", kdHide);
            matlabAddFunction("pointkdtree_show", kdShow);
            matlabAddFunction("pointkdtree_refine", kdRefine);
            matlabAddFunction("pointkdtree_search_nearest", kdSearchNearest);
            matlabAddFunction("pointkdtree_count_nearest", kdCountNearest);
        }

        CallFunction run(addStuff);

    }

}