package ohd.hseb.hefs.utils.dist.types;

import ohd.hseb.util.data.DataSet;

/**
 * Conditional distribution of X1 given a conditioned (observed) value x2 of variable X2, where X1 ~ Norm(mean1,
 * stddev1^2), X2 ~ Norm(mean2, stddev2^2), and (X1, X2) is bivariate normal with Pearson correlation coefficient rho.
 * The resulting distribution is normal with<br>
 * <br>
 * mean = mean1 + rho * (stddev1 / stddev2) * (x2 - mean2)<br>
 * standard deviation = stddev1 * sqrt(1 - rho^2)
 * 
 * @author hankherr
 */
public class BivariateConditionalNormalDist extends NormalDist
{

    /**
     * Builds the conditional distribution from explicitly provided moments and correlation.
     * 
     * @param conditionedValue The value of variable 2 to use in the conditioning
     * @param mean1 Mean of variable 1
     * @param stddev1 Standard deviation of variable 1; must not be negative
     * @param mean2 Mean of variable 2
     * @param stddev2 Standard deviation of variable 2; must be strictly positive
     * @param correlation Pearson correlation coefficient between the two variables; must lie within [-1, 1]
     * @throws IllegalArgumentException if any of the above constraints is violated
     */
    public BivariateConditionalNormalDist(final double conditionedValue,
                                          final double mean1,
                                          final double stddev1,
                                          final double mean2,
                                          final double stddev2,
                                          final double correlation)
    {
        initialize(conditionedValue, mean1, stddev1, mean2, stddev2, correlation);
    }

    /**
     * The means, sample standard deviations, and correlation are computed from the provided data.
     * 
     * @param conditionedValue The value of variable 2 to use in the conditioning
     * @param data Data used to compute the moments and correlation.
     * @param firstVariable The variable in the data set to use as X1.
     * @param secondVariable The variable to use as X2.
     * @throws IllegalArgumentException if the computed statistics violate the constraints of
     *             {@link #BivariateConditionalNormalDist(double, double, double, double, double, double)} (e.g. the
     *             sample standard deviation of the second variable is zero).
     */
    public BivariateConditionalNormalDist(final double conditionedValue,
                                          final DataSet data,
                                          final int firstVariable,
                                          final int secondVariable)
    {
        final double mean1 = data.mean(firstVariable);
        final double stddev1 = data.sampleStandardDeviation(firstVariable);
        final double mean2 = data.mean(secondVariable);
        final double stddev2 = data.sampleStandardDeviation(secondVariable);
        final double correlation = data.correlation(firstVariable, secondVariable);
        initialize(conditionedValue, mean1, stddev1, mean2, stddev2, correlation);
    }

    /**
     * The means and standard deviations are taken from the provided marginal normal distributions.
     * 
     * @param conditionedValue The value of variable 2 to use in the conditioning
     * @param firstVarDist Marginal normal distribution of variable 1 (X1)
     * @param secondVarDist Marginal normal distribution of variable 2 (X2)
     * @param correlation Pearson correlation coefficient between the two variables; must lie within [-1, 1]
     * @throws IllegalArgumentException if the moments or correlation violate the constraints of
     *             {@link #BivariateConditionalNormalDist(double, double, double, double, double, double)}.
     */
    public BivariateConditionalNormalDist(final double conditionedValue,
                                          final NormalDist firstVarDist,
                                          final NormalDist secondVarDist,
                                          final double correlation)
    {
        initialize(conditionedValue,
                   firstVarDist.getMean(),
                   firstVarDist.getStandardDeviation(),
                   secondVarDist.getMean(),
                   secondVarDist.getStandardDeviation(),
                   correlation);
    }

    /**
     * Validates the inputs and sets the mean and standard deviation of this distribution to those of X1 | X2 =
     * conditionedValue.
     * 
     * @param conditionedValue The value of variable 2 to use in the conditioning
     * @param mean1 Mean of variable 1
     * @param stddev1 Standard deviation of variable 1
     * @param mean2 Mean of variable 2
     * @param stddev2 Standard deviation of variable 2
     * @param correlation Pearson correlation coefficient between the two variables.
     * @throws IllegalArgumentException if stddev1 is negative, stddev2 is not strictly positive, or correlation is
     *             NaN or outside [-1, 1].
     */
    private void initialize(final double conditionedValue,
                            final double mean1,
                            final double stddev1,
                            final double mean2,
                            final double stddev2,
                            final double correlation)
    {
        if(stddev1 < 0)
        {
            throw new IllegalArgumentException("Standard deviation of variable 1, " + stddev1 + ", cannot be negative.");
        }
        // stddev2 is a divisor below: zero (or NaN) would yield an infinite or NaN conditional mean.
        if(!(stddev2 > 0))
        {
            throw new IllegalArgumentException("Standard deviation of variable 2, " + stddev2 + ", must be positive.");
        }
        // |rho| > 1 (or NaN) would make 1 - rho^2 negative and the conditional standard deviation NaN.
        if(!(Math.abs(correlation) <= 1.0d))
        {
            throw new IllegalArgumentException("Correlation coefficient, " + correlation
                + ", must lie within [-1, 1].");
        }
        super.setMean(mean1 + stddev1 / stddev2 * correlation * (conditionedValue - mean2));
        // Equivalent to sqrt((1 - rho^2) * stddev1^2) since stddev1 >= 0 is enforced above.
        super.setStandardDeviation(stddev1 * Math.sqrt(1.0d - correlation * correlation));
    }
}
