draw_mutual_information.m


matlab/+tim/+example/

% DRAW_MUTUAL_INFORMATION
% Draws the distribution of mutual_information -estimates.
%
% draw_mutual_information('key', value, ...)
%
% Optional input arguments in 'key'-value pairs:
%
% K ('k') is a positive integer which denotes the number of nearest 
% neighbors to be used by the estimator. A greater number decreases 
% variance, but increases bias.
% Default: 1
%
% N ('n') is a non-negative integer which denotes the number 
% of points to generate.
% Default: 1000
%
% M ('m') is a non-negative integer which denotes the number of estimates 
% to compute. The histogram is formed from these estimates.
% Default: 1000
%
% COV ('cov') is a positive-definite (D x D) real matrix which
% denotes the covariance matrix of the random variable (X, Y).
% Its size D also determines the dimension of the random variable (X, Y).
% Default: eye(4, 4)
%
% DX ('dx') is a positive integer which denotes the dimension of 
% the random variable X.
% Default: floor(D / 2)
%
% As a trial, we generate a point-set (X, Y) from a multi-variate normal 
% distribution with covariance COV, split it into X and Y, and compute a 
% mutual information estimate via mutual_information. We then accumulate 
% many such trials to numerically reveal the distribution of the 
% estimator for the chosen arguments. A usage sketch follows below.
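%
% Example (a minimal usage sketch; it assumes the +tim/+example packaging 
% shown above, so that the function is reached as 
% tim.example.draw_mutual_information):
%
%     % A 2 + 2 dimensional normal distribution in which the X and Y 
%     % blocks are correlated through the (1, 3) entry.
%     C = eye(4);
%     C(1, 3) = 0.5;
%     C(3, 1) = 0.5;
%     tim.example.draw_mutual_information(...
%         'k', 4, 'n', 2000, 'm', 500, 'cov', C, 'dx', 2);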

% Description: Draws the distribution of mutual_information -estimates

function draw_mutual_information(varargin)

% Optional input arguments
k = 1;
n = 1000;
m = 1000;
cov = eye(4, 4);
dx = {};
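% Parse the 'key'-value pairs; tim.process_options is assumed to return 
% assignment code which eval() runs to override the defaults above.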
eval(tim.process_options({'k', 'n', 'm', 'cov', 'dx'}, varargin));

d = size(cov, 1);
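% If DX was not given, default it to half of the dimension (rounded down).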
if iscell(dx)
    dx = floor(d / 2);
end

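% Check that the arguments are of the expected type and range.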
pastelmatlab.concept_check(...
    k, 'integer', ...
    k, 'positive', ...
    n, 'integer', ...
    n, 'non_negative', ...
    m, 'integer', ...
    m, 'non_negative', ...
    dx, 'integer', ...
    dx, 'positive', ...
    cov, 'real_matrix', ...
    cov, 'square_matrix');

% Accumulate the estimates for each trial.
hSet = zeros(1, m);
for i = 1 : m
    % Generate a point-set randomly from the multi-variate
    % normal distribution.
    XY = tim.random_normal(d, n, 'cov', cov);
    X = XY(1 : dx, :);
    Y = XY((dx + 1) : end, :);
    % Compute the mutual information estimate.
    hSet(i) = tim.mutual_information(X, Y, 'k', k);
end

% Compute the sample mean of the estimates.
hMean = mean(hSet);

% Compute the standard deviation of the estimates.
hDeviation = std(hSet);

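% Extract the marginal covariance matrices of X and Y.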
xCov = cov(1 : dx, 1 : dx);
yCov = cov((dx + 1) : end, (dx + 1) : end);

% Compute the correct mutual information; for the multi-variate
% normal distribution it can be computed analytically.
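% For a jointly normal (X, Y), I(X; Y) = (1 / 2) log(det(C_X) det(C_Y) / det(C))
% in nats, where C_X and C_Y are the marginal covariances and C is the 
% joint covariance; mutual_information_normal is assumed to evaluate 
% this from the two determinants given below.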
hCorrect = tim.mutual_information_normal(det(xCov) * det(yCov), det(cov));

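% Show a fixed-width window around the correct value on the x-axis.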
xMin = hCorrect - 0.25;
xMax = hCorrect + 0.25;

% Draw a histogram of the estimates.
hist(hSet, 100);
hold on
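% The height of the vertical marker lines, and of the visible y-axis.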
yLimits = [0, 50];

% Draw the sample mean as a vertical line.
line([hMean, hMean], yLimits, 'Color', [1, 0, 0]);
% Draw the correct value as a vertical line.
line([hCorrect, hCorrect], yLimits, 'Color', [0, 1, 0]);

% Name the axes and the figure.
title(['Distribution of the mutual information -estimate', ...
    ', k = ', num2str(k), ...
    ', d = ', num2str(d), ...
    ', n = ', num2str(n)]);
xlabel(['mutual information -estimate', ...
    ', mean = ', num2str(hMean), ...
    ', deviation = ', num2str(hDeviation)]);
ylabel('Samples');
legend('Samples', 'Sample mean', 'Correct value');
axis([xMin, xMax, yLimits(1), yLimits(2)]);
hold off;