entropy_combination.h

Back to Entropy combinations

tim/core/

// Description: Estimation of entropy combinations

#ifndef TIM_ENTROPY_COMBINATION_H
#define TIM_ENTROPY_COMBINATION_H

#include "tim/core/mytypes.h"
#include "tim/core/signal.h"
#include "tim/core/signal_tools.h"
#include "tim/core/signalpointset.h"
#include "tim/core/reconstruction.h"

#include <pastel/geometry/pointkdtree/pointkdtree.h>
#include <pastel/geometry/count_nearest.h>
#include <pastel/geometry/search_nearest.h>
#include <pastel/geometry/nearestset/kdtree_nearestset.h>

#include <pastel/math/normbijection/maximum_normbijection.h>

#include <pastel/sys/array/array.h>
#include <pastel/sys/range.h>
#include <pastel/sys/math/eps.h>
#include <pastel/sys/sequence/copy_n.h>
#include <pastel/sys/indicator/predicate_indicator.h>

#include <numeric>
#include <iterator>
#include <vector>

#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <tbb/parallel_reduce.h>

namespace Tim
{

    //! Computes an entropy combination of signals.
    /*!
   Preconditions:
   kNearest > 0

   signalSet:
   An ensemble of joint signals representing trials
   of the same experiment. Note: all the marginal signals
   share the memory with these joint signals.

   rangeSet:
   A sequence of m triples T_i = (a_i, b_i, s_i), 
   where [a_i, b_i] is an interval such that picking those 
   dimensions from the joint signal X gives the marginal 
   signal X_i. The s_i is the factor by which the differential 
   entropy of such a marginal signal is multiplied before summing
   to the end-result.

   timeWindowRadius:
   The radius of the time-window in samples to use.
   Smaller values give more temporal adaptivity,
   but increase errors.

   kNearest:
   The k:th nearest neighbor that is used to
   estimate entropy combination.

   Returns:
   An estimate of the entropy combination of the signals.
   */
    template <
        ranges::forward_range Integer3_Range,
        ranges::forward_range Lag_Range>
    dreal entropyCombination(
        const Array<Signal>& signalSet,
        const Integer3_Range& rangeSet,
        const Lag_Range& lagSet,
        integer kNearest = 1)
    {
        ENSURE_OP(kNearest, >, 0);
        ENSURE_OP(ranges::size(lagSet), ==, signalSet.height());

        typedef typename SignalPointSet::Point_ConstIterator
            Point_ConstIterator;

        if (ranges::empty(signalSet) || rangeSet.empty())
        {
            return 0;
        }

        // Construct the joint signal.

        integer trials = signalSet.width();

        std::vector<SignalData> jointSignalSet;
        jointSignalSet.reserve(trials);
        merge(signalSet, 
            std::back_inserter(jointSignalSet), lagSet);

        integer samples = std::begin(jointSignalSet)->samples();
        if (samples == 0)
        {
            return 0;
        }

        integer signals = signalSet.height();

        // Find out the dimension ranges of the marginal
        // signals.

        std::vector<integer> offsetSet;
        offsetSet.reserve(signals + 1);
        offsetSet.push_back(0);
        for (integer i = 1;i < signals + 1;++i)
        {
            offsetSet.push_back(offsetSet[i - 1] + signalSet(0, i - 1).dimension());
        }

        const integer n = samples * trials;
        integer marginals = ranges::size(rangeSet);

        // Construct point sets

        SignalPointSet jointPointSet(jointSignalSet);

        std::vector<integer> weightSet;
        weightSet.reserve(marginals);

        auto iter = std::begin(rangeSet);
        std::vector<SignalPointSet> pointSet;
        pointSet.reserve(marginals);
        for (integer i = 0;i < marginals;++i)
        {
            const Integer3& range = *iter;
            pointSet.emplace_back(
                jointSignalSet,
                offsetSet[range[0]], offsetSet[range[1]]);
            weightSet.push_back(range[2]);
            ++iter;
        }

        // It is essential that the used norm is the
        // maximum norm.

        Maximum_Norm<dreal> norm;

        // Find the distances to the k:th nearest neighbors.

        Array<dreal> distanceArray(Vector2i(1, n));

        using Block = tbb::blocked_range<integer>;

        auto search = [&](const Block& block)
        {
            for (integer i = block.begin(); i < block.end(); ++i)
            {
                auto query = *(jointPointSet.begin() + i);

                Vector<dreal> queryPoint(
                    ofDimension(jointPointSet.dimension()),
                    withAliasing((dreal*)(query->point())));

                distanceArray(i) =
                    (dreal)searchNearest(
                        kdTreeNearestSet(jointPointSet.kdTree()),
                        queryPoint,
                        PASTEL_TAG(accept), predicateIndicator(query, NotEqualTo()),
                        PASTEL_TAG(norm), norm,
                        PASTEL_TAG(kNearest), kNearest
                    ).first;
            }
        };

        tbb::parallel_for(Block(0, n), search);

        const dreal signalWeightSum = 
            std::accumulate(weightSet.begin(), weightSet.end(), (dreal)0);

        dreal estimate = 0;
        for (integer i = 0;i < marginals;++i)
        {
            using Block = tbb::blocked_range<integer>;
            using Pair = std::pair<dreal, integer>;

            auto compute = [&](
                const Block& block,
                const Pair& start)
            {
                dreal signalEstimate = start.first;
                integer acceptedSamples = start.second;
                for (integer j = block.begin();j < block.end();++j) 
                {
                    auto query = *(pointSet[i].begin() + j);

                    Vector<dreal> queryPoint(
                        ofDimension(pointSet[i].dimension()),
                        withAliasing((dreal*)(query->point())));

                    integer k = countNearest(
                        kdTreeNearestSet(pointSet[i].kdTree()),
                        queryPoint,
                        PASTEL_TAG(norm), norm,
                        PASTEL_TAG(maxDistance2), norm(distanceArray(j))
                    );

                    // A neighbor count of zero can happen when the distance
                    // to the k:th neighbor is zero because of using an
                    // open search ball. These points are ignored.
                    if (k > 0)
                    {
                        signalEstimate += digamma<dreal>(k);
                        ++acceptedSamples;
                    }
                }

                return Pair(signalEstimate, acceptedSamples);
            };

            auto reduce = [](const Pair& left, const Pair& right)
            {
                return Pair(
                    left.first + right.first, 
                    left.second + right.second);
            };

            dreal signalEstimate = 0;
            integer acceptedSamples = 0;

            std::tie(signalEstimate, acceptedSamples) = 
                tbb::parallel_reduce(
                    Block(0, n),
                    Pair(0, 0),
                    compute,
                    reduce);

            if (acceptedSamples > 0)
            {
                signalEstimate /= acceptedSamples;
            }

            estimate -= signalEstimate * weightSet[i];
        }

        estimate += digamma<dreal>(kNearest);
        estimate += (signalWeightSum - 1) * digamma<dreal>(n);

        return estimate;
    }

    //! Computes an entropy combination of signals.
    /*!
   This is a convenience function that calls:

   entropyCombination(
       signalSet,
       rangeSet,
       constantRange(0, signalSet.height()));

   See the documentation for that function.
   */
    template <ranges::forward_range Integer3_Range>
    dreal entropyCombination(
        const Array<Signal>& signalSet,
        const Integer3_Range& rangeSet)
    {
        return Tim::entropyCombination(
            signalSet,
            rangeSet,
            constantRange(0, signalSet.height()));
    }

}

#endif