// normal_distribution.h
//
// Back to Orphans
//
// pastel/sys/random/

// Description: Normal distribution

#ifndef PASTELSYS_NORMAL_DISTRIBUTION_H
#define PASTELSYS_NORMAL_DISTRIBUTION_H

#include "pastel/sys/real/real_concept.h"

#include <armadillo>

namespace Pastel
{

    //! A normal distribution over n-dimensional points.
    /*!
    The distribution is stored as a mean vector, a scale vector,
    and a rotation matrix, together with a cached determinant
    value computed from the scale vector.
    */
    template <typename Real>
    class Normal_Distribution
    {
    public:
        using Point = arma::Col<Real>;

        //! Constructs an n-dimensional distribution with default parameters.
        /*!
        The mean is filled with zeros, the scale with ones, the
        rotation is the n x n identity matrix, and the stored
        determinant is 1.
        */
        explicit Normal_Distribution(integer n = 1)
        : mean_(n, arma::fill::zeros)
        , scale_(n, arma::fill::ones)
        , rotation_(n, n, arma::fill::eye)
        , detCovariance_(1)
        {
        }

        Normal_Distribution(const Normal_Distribution&) = default;
        Normal_Distribution(Normal_Distribution&&) = default;
        Normal_Distribution& operator=(const Normal_Distribution&) = default;
        Normal_Distribution& operator=(Normal_Distribution&&) = default;

        //! Constructs a distribution from explicit parameters.
        /*!
        mean:
        The mean vector; its element count defines the dimension n.

        scale:
        The scale vector; must have n elements.

        rotation:
        The rotation matrix; must be n x n.

        The arguments are taken by value and moved into the object.
        The stored determinant is set to the product of the scale
        entries.
        */
        Normal_Distribution(
            arma::Col<Real> mean,
            arma::Col<Real> scale,
            arma::Mat<Real> rotation)
        : mean_(std::move(mean))
        , scale_(std::move(scale))
        , rotation_(std::move(rotation))
        , detCovariance_(0)
        {
            // Validate the members, not the parameters: the parameters
            // were moved from in the initializer list above, so their
            // sizes no longer reliably describe what the caller passed.
            ENSURE_OP(mean_.n_elem, ==, rotation_.n_rows);
            ENSURE_OP(mean_.n_elem, ==, rotation_.n_cols);
            // The scale must cover the same number of dimensions as the mean.
            ENSURE_OP(scale_.n_elem, ==, mean_.n_elem);

            detCovariance_ = arma::prod(scale_);
        }

        //! Returns the dimension, i.e. the number of elements in the mean.
        integer n() const
        {
            return static_cast<integer>(mean_.n_elem);
        }

        // Accessors for the stored parameters; their exact form is
        // determined by the PASTEL_GETTER macro (defined elsewhere).
        PASTEL_GETTER(mean);
        PASTEL_GETTER(scale);
        PASTEL_GETTER(rotation);
        PASTEL_GETTER(detCovariance);

    private:
        // Mean vector (n elements).
        arma::Col<Real> mean_;
        // Scale vector (n elements).
        arma::Col<Real> scale_;
        // Rotation matrix (n x n).
        arma::Mat<Real> rotation_;
        // Cached product of the entries of scale_.
        Real detCovariance_;
    };

}

#endif