// Back to Wang-Kulkarni-Verdu estimator
// Description: Kullback-Leibler divergence estimation
// Detail: Wang-Kulkarni-Verdu nearest neighbor estimator
#ifndef TIM_DIVERGENCE_WKV_H
#define TIM_DIVERGENCE_WKV_H
#include "tim/core/mytypes.h"
#include "tim/core/signal.h"
#include "tim/core/signalpointset.h"
#include <pastel/sys/range.h>
#include <pastel/sys/indicator/predicate_indicator.h>
#include <pastel/geometry/search_nearest.h>
#include <pastel/geometry/nearestset/kdtree_nearestset.h>
#include <tbb/parallel_reduce.h>
#include <tbb/blocked_range.h>
namespace Tim
{
//! Computes Kullback-Leibler divergence between two signals.
/*!
This is a convenience function that calls the
more general divergenceWkv(), which takes sets of
signals (trials) instead of single signals.
See the documentation for that function.

xSignal:
The signal for X.

ySignal:
The signal for Y.

returns:
The divergence estimate as a dreal.

NOTE(review): only the declaration is visible in this header;
presumably each signal is passed on as a one-element signal set,
so the return-value conventions of the general function apply.
Confirm against the definition.
*/
TIM dreal divergenceWkv(
const Signal& xSignal,
const Signal& ySignal);
//! Computes Kullback-Leibler divergence between signals.
/*!
Uses the nearest-neighbor estimator of Wang, Kulkarni, and Verdu.
The samples of X are processed in parallel with
tbb::parallel_reduce.

xSignalSet:
A set of signals representing trials for X.

ySignalSet:
A set of signals representing trials for Y.

returns:
The Kullback-Leibler divergence between the signals.
If either signal set is empty, zero is returned.
If no sample of X produces a usable pair of nearest-neighbor
distances, the estimate is undefined and a NaN is returned.
*/
template <
    typename X_Signal_Range,
    typename Y_Signal_Range>
dreal divergenceWkv(
    const X_Signal_Range& xSignalSet,
    const Y_Signal_Range& ySignalSet)
{
    // Reference:
    // "A Nearest-Neighbor Approach to Estimating
    // Divergence between Continuous Random Vectors"
    // Qing Wang, Sanjeev R. Kulkarni, Sergio Verdu,
    // IEEE International Symposium on Information Theory (ISIT),
    // 2006.

    if (xSignalSet.empty() || ySignalSet.empty())
    {
        // There is nothing to compare.
        return 0;
    }

    // The dimensions are read from the first signal of each set
    // and must agree. The variable names are kept as they are
    // because ENSURE_OP is a macro and may report them.
    integer xDimension = std::begin(xSignalSet)->dimension();
    integer yDimension = std::begin(ySignalSet)->dimension();
    ENSURE_OP(xDimension, ==, yDimension);

    using PointIterator = SignalPointSet::Point_ConstIterator;
    using IndexRange = tbb::blocked_range<integer>;
    // (sum of log distance-ratios, number of contributing samples)
    using Partial = std::pair<dreal, integer>;

    // Gather the samples of each signal set into a point-set.
    SignalPointSet xPointSet(xSignalSet);
    SignalPointSet yPointSet(ySignalSet);

    const integer xSamples = xPointSet.samples();
    const integer ySamples = yPointSet.samples();

    // Processes the X-samples in 'range', continuing from
    // the partial result 'seed'.
    auto accumulate = [&](
        const IndexRange& range,
        const Partial& seed)
    {
        dreal logRatioSum = seed.first;
        integer contributing = seed.second;

        for (integer i = range.begin(); i < range.end(); ++i)
        {
            PointIterator query = *(xPointSet.begin() + i);

            // The vector aliases the coordinates of the query point.
            // NOTE(review): the C-style pointer cast is retained
            // because the constness of point() is not visible here.
            Vector<dreal> queryPoint(
                ofDimension(xDimension),
                withAliasing((dreal*)(query->point())));

            // Squared distance from the query to its nearest
            // neighbor in X; only points that compare not-equal
            // to the query are accepted.
            const dreal xxDistance2 = static_cast<dreal>(
                searchNearest(
                    kdTreeNearestSet(xPointSet.kdTree()),
                    queryPoint,
                    PASTEL_TAG(accept), predicateIndicator(query, NotEqualTo())
                ).first);

            // Written as a negation so that a NaN distance
            // is skipped as well.
            if (!(xxDistance2 > 0 && xxDistance2 < infinity<dreal>()))
            {
                continue;
            }

            // Squared distance from the query to its nearest
            // neighbor in Y.
            const dreal xyDistance2 = static_cast<dreal>(
                searchNearest(
                    kdTreeNearestSet(yPointSet.kdTree()),
                    queryPoint
                ).first);

            if (!(xyDistance2 > 0 && xyDistance2 < infinity<dreal>()))
            {
                continue;
            }

            logRatioSum += std::log(xyDistance2 / xxDistance2);
            ++contributing;
        }

        return Partial(logRatioSum, contributing);
    };

    // Partial results are combined component-wise.
    const Partial total =
        tbb::parallel_reduce(
            IndexRange(0, xSamples),
            Partial(0, 0),
            accumulate,
            [](const Partial& left, const Partial& right)
            {
                return Partial(
                    left.first + right.first,
                    left.second + right.second);
            });

    const integer acceptedSamples = total.second;
    if (acceptedSamples <= 0)
    {
        // No sample contributed; the estimate is undefined.
        return static_cast<dreal>(Nan());
    }

    // The distances above are squared. Taking the square root
    // inside the logarithm is the same as dividing the logarithm
    // by two, hence the factor 2 in the denominator.
    dreal estimate = total.first;
    estimate *= static_cast<dreal>(xDimension) / (2 * acceptedSamples);
    estimate += std::log(static_cast<dreal>(ySamples) / (xSamples - 1));

    return estimate;
}
}
#endif