package ohd.hseb.hefs.utils.dist.types;

import ohd.hseb.hefs.utils.dist.DataFittingDistributionException;
import ohd.hseb.hefs.utils.dist.HMathTools;
import ohd.hseb.hefs.utils.dist.MomentsFittingDistribution;
import ohd.hseb.hefs.utils.xml.vars.XMLDouble;
import ohd.hseb.util.data.DataSet;

/**
 * This class is a subclass of {@link ContinuousDist} corresponding to the Beta distribution. It has four parameters:
 * the two shape parameters (Beta.SHAPE1 and Beta.SHAPE2) and the upper and lower bounds on the interval. The default
 * interval is (0,1). For {@link #fitToData(DataSet)}, if the data passed in is not null, then it is assumed that we
 * will acquire the moments of the {@link DataSet} variable and fit the distribution to those moments. If the passed in
 * data is null, then the fitparms array is assumed to specify two values: the sample mean and the sample variance of
 * the desired beta distribution. (in that order). Method of moments is used to fit the beta. This class also provides
 * static methods to calculate the Gamma function (actually the log-Gamma function, which can be used to acquire the
 * gamma function with a Math.exp call), the beta function, and the incomplete beta function.
 * 
 * @author hank
 */
public class BetaDist extends ContinuousDist implements MomentsFittingDistribution
{
    /**
     * Must be kept consistent with the position of the parameter in the constructors.
     */
    private final static int SHAPE1 = 0;

    /**
     * Must be kept consistent with the position of the parameter in the constructors.
     */
    private final static int SHAPE2 = 1;

    private final static double DEFAULT_SHAPE1 = 1.0;
    private final static double DEFAULT_SHAPE2 = 2.0;
    private final static double DEFAULT_LB = 0.0;
    private final static double DEFAULT_UB = 1.0;

    /**
     * Constant dictates the precision of the inverse CDF, which must do a search using the CDF.
     */
    final static double INVERSE_PRECISION = 0.00001;

    /**
     * Default distribution has domain (0,1) with parameters {@link #DEFAULT_SHAPE1} and {@link #DEFAULT_SHAPE2}.
     */
    public BetaDist()
    {
        this(DEFAULT_SHAPE1, DEFAULT_SHAPE2, DEFAULT_LB, DEFAULT_UB);
    }

    /**
     * @param shape1 First shape parameter.
     * @param shape2 Second shape parameter.
     * @param lb Defines lower bound of domain, set through the first argument to the super constructor
     *            {@link ContinuousDist#ContinuousDist(XMLDouble, ohd.hseb.hefs.utils.xml.vars.XMLNumber...)}.
     * @param ub Defines lower bound of domain, set through the first argument to the super constructor
     *            {@link ContinuousDist#ContinuousDist(XMLDouble, ohd.hseb.hefs.utils.xml.vars.XMLNumber...)}.
     */
    public BetaDist(final double shape1, final double shape2, final double lb, final double ub)
    {
        super(new XMLDouble("domain", 0.0D, lb, ub),
              new XMLDouble("shape1", shape1, 0.0D, null),
              new XMLDouble("shape2", shape2, 0.0D, null));
        if(ub < lb)
        {
            throw new IllegalArgumentException("Beta distribution domain upper bound, " + ub
                + ", cannot be less than the lower bound, " + lb + ".");
        }
    }

    public void setShape1(final double shape1)
    {
        super.setParameter(SHAPE1, shape1);
    }

    public double getShape1()
    {
        return getParameter(SHAPE1).doubleValue();
    }

    public void setShape2(final double shape1)
    {
        super.setParameter(SHAPE2, shape1);
    }

    public double getShape2()
    {
        return getParameter(SHAPE2).doubleValue();
    }

    @Override
    public double functionCDF(final Double value)
    {
        double temp; //Stores the result.  This makes sure the returned prob is 0 or 1 only when outside bounds.

        //Get the bounds
        final double a = getDomainLowerBound();
        final double b = getDomainUpperBound();

        //Make sure the value is between a and b (inclusive).
        if(value < a)
        {
            return 0;
        }
        if(value > b)
        {
            return 1;
        }

        //Okay... the value is within the bounds.  Rescale the value to be on interval [0, 1].
        final double usedval = (value - a) / b;

        //Get the shape1 and shape2 parameters.
        final double shape1 = getShape1();
        final double shape2 = getShape2();

        //Try to perform the computation using a Beta[0,1] CDF.
        try
        {
            temp = HMathTools.incompleteBeta(usedval, shape1, shape2);
        }
        catch(final ArithmeticException except)
        {
            return getMissing();
        }

        //Make sure the number is within reasonable bounds.
        if(temp == 1)
        {
            temp = 0.99999;
        }
        if(temp == 0)
        {
            temp = 0.00001;
        }

        return temp;
    }

    @Override
    public double functionPDF(final Double value)
    {
        double temp; //Stores the result.  This makes sure the returned prob is 0 or 1 only when outside bounds.

        //Get the bounds
        final double a = getDomainLowerBound();
        final double b = getDomainUpperBound();

        //Make sure the value is between a and b (inclusive).
        if(value < a)
        {
            return 0;
        }
        if(value > b)
        {
            return 1;
        }

        //Okay... the value is within the bounds.  Rescale the value to be on interval [0, 1].
//        double usedval = (value - a) / b;

        //Get the shape1 and shape2 parameters.
        final double shape1 = getShape1();
        final double shape2 = getShape2();

        //Try to perform the computation using a Beta[0,1] CDF.        
        // Since the PDF is the derivative of the CDF, you can check this at Wolfram Alpha using:
        // derivative BetaRegularized[x, a, b]

        try
        {
            temp = Math.pow(value, shape1 - 1) * Math.pow((1 - value), shape2 - 1) / HMathTools.beta(shape1, shape2);
        }
        catch(final ArithmeticException except)
        {
            return getMissing();
        }

        return temp / (b - a);
    }

    /**
     * Does a simple binary search (i.e. halfs the interval width until I've bounded the inverse value within a certain
     * precision.
     */
    @Override
    public double functionInverseCDF(final double prob)
    {
        //Check Prob
        if((prob <= 0.0) || (prob >= 1.0))
        {
            return getMissing();
        }

        //temp will store the quantile on the [0, 1] interval.
        double temp;

        //these values store the current lower and upper bound on the search interval.
        double lb, ub;

        //Initialize the bounds.  If either is null, this will bomb out, which is a good thing (though the messaging may not be sufficient).
        lb = getDomainLowerBound();
        ub = getDomainUpperBound();

        //Do the loop until the distance from ub to lb is less than the desired precision.
        while((ub - lb) > INVERSE_PRECISION)
        {
            temp = (ub + lb) / 2.0;

            //Evaluate the midpoint CDF and if it is larger than the prob passed in, set ub equal to it.
            //Otherwise, set lb equal to it.
            if(functionCDF(temp) > prob)
            {
                ub = temp;
            }
            else
            {
                lb = temp;
            }
        }

        //I now have an interval, [lb, ub], in which I know the correct quantile exists.  Return its midpoint
        //as my approximation. 
        temp = (ub + lb) / 2.0;
        return temp;
    }

    /**
     * The current lower and upper bounds are assumed to be the desired lower and upper bounds of the fitted
     * distribution.<br>
     * <br>
     * If data is NOT null, then it is assumed that we will acquire the rescaled moments of the data set variable and
     * fit the distribution to those moments. If data is null, then fitparms is assumed to specify two values: the mean
     * and the variance of the desired beta distribution. By rescaling, I mean to transform the data to reduce or expand
     * the bounds to be [0, 1].<br>
     * <br>
     * Method of moments is used. The first item of fitparms must be the sample mean, and the second item must be the
     * sample variance. The data object need not have any CDF data: this is a MOM estimation routine. Thus, you only
     * need to make sure _fitsample is set within the DataSet (via setFitSampleVariable).
     */
    @Override
    public void fitToData(final DataSet data, final double[] fitparms) throws DataFittingDistributionException
    {
        //store the two sample statistics.
        double smean, svar;

        //Check for null data.  
        if(data == null)
        {
            //fitparms better not be null.
            if(fitparms == null)
            {
                throw new DataFittingDistributionException("Since data is not provided, the fit parameters must be defined, but are not.");
            }

            //Fit parms must have a length of at least two.
            if(fitparms.length < 2)
            {
                throw new DataFittingDistributionException("Since data is not provided, the fit parameters must include two values specifying the sample mean and sample variance, in that order.");
            }

            //Copy sample moments.
            smean = fitparms[0]; //First item is the sample mean.
            svar = fitparms[1]; //Second item is the sample var.
        }
        //If data is not null, then I must acquire the moments manually.
        else
        {
            //Rescale the data values by, first, copying the dataset and then transforming the data.
            final DataSet datacopy = new DataSet(data);

            //Get the fit sample.
            final int index = data.getFitSampleVariable();
            if(index == DataSet.MISSING)
            {
                throw new DataFittingDistributionException("The fit-sample index was not provided in the data set.");
            }

            //Transform the data in place using two steps: shift by lb and scale by width of (ub - lb).
            datacopy.shiftVariable(index, (-1 * getDomainLowerBound()));
            datacopy.scaleVariable(index, (1 / (getDomainUpperBound() - getDomainLowerBound())));

            //Try to acquire the mean of the data.
            smean = data.mean(index);
            if(smean == DataSet.MISSING)
            {
                throw new DataFittingDistributionException("The sample mean computation returned missing.");
            }

            //Try to acquire hte sample variance of the data.
            svar = data.sampleVariance(index);
            if(svar == DataSet.MISSING)
            {
                throw new DataFittingDistributionException("The sample variance computation returned missing.");
            }
        }

        //I now have the sample moments, so lets estimate the SHAPE1 and SHAPE2 parameters using the moments.
        fitToMoments(smean, svar);
    }

    /**
     * NOTE: The moments must be computed from the data AFTER linearly rescaling to the range [0, 1]. Hence the
     * attributes are renamed as local variables scaled* within this method before the sets are called at the endg.
     */
    @Override
    public void fitToMoments(final double scaledMean, final double scaledCoefficientOfVariation)
    {
        final double scaledSampleVar = Math.pow(scaledCoefficientOfVariation * scaledMean, 2);

        //Make sure the scaled mean is between 0 and 1, which is must be after scaling.
        //Note that, if the mean is 0 or 1, then one of the two shape parameters will be zero, which
        //is not allowed.
        if((scaledMean <= 0) || (scaledMean >= 1))
        {
            throw new IllegalArgumentException("The rescaled mean, " + scaledMean + ", is outside [0, 1].");
        }

        //Of course, the scaledsvar better be positive.
        if(scaledSampleVar <= 0)
        {
            throw new IllegalArgumentException("The rescaled variance, " + scaledSampleVar
                + ", is negative after computation based on the rescaled mean, " + scaledMean
                + ", and coefficient of variation, " + scaledCoefficientOfVariation + ".");
        }

        double temp;

        //The variable temp is used to store the piece common to estimating both shape parameters.
        //If this piece is 0, problems will occur.  In other words, if (1 - scaledmean = scaledsvar)
        //then problems will occur, since both parameters will be 0.
        temp = ((scaledMean * (1.0 - scaledMean) / scaledSampleVar) - 1.0);
        if(temp == 0)
        {
            throw new IllegalStateException("The value of 1 minus the scalend mean, " + scaledMean
                + " cannot equal the scaled sample variance, " + scaledSampleVar + ".");
        }

        //The SHAPE1 parameter:
        setShape1(scaledMean * temp);

        //the SHAPE2 parameter:
        setShape2((1.0 - scaledMean) * temp);
    }

    ////////////////////////////////////////////////////////////////////////
    //Static Routines
    ////////////////////////////////////////////////////////////////////////

    //TODO Need to move these into a MathTools class or use something provided in commons math, if it exists.

    ////////////////////////////////////////
    public static void main(final String argv[])
    {
        try
        {
            System.out.println("Gamma(2.5) = " + Math.exp(HMathTools.gammaln(2.5)));
            System.out.println("Gamma(4) = " + Math.exp(HMathTools.gammaln(4)));
            System.out.println("Beta(0.5, 2.5) = " + HMathTools.beta(0.5, 2.5));
            System.out.println("Beta(2.5, 5.0) = " + HMathTools.beta(2.5, 5.0));
            System.out.println("BetaDist(0.5, 1.0, 2.0) = " + HMathTools.incompleteBeta(0.5, 1.0, 2.0));
            System.out.println("BetaDist(0.25, 2.0, 3.0) = " + HMathTools.incompleteBeta(0.25, 2.0, 3.0));

            BetaDist beta = new BetaDist(1.0, 1.0, 0.0, 1.0);
            System.out.println("BetaDist 1 = " + beta.functionCDF(0.25));
            System.out.println("BetaDist 2 = " + beta.functionCDF(0.50));
            System.out.println("BetaDist 3 = " + beta.functionPDF(0.25));
            System.out.println("BetaDist 4 = " + beta.functionPDF(0.50));

            beta.fitToMoments(0.5, (1.0 / 8.0));
            System.out.println("Beta params after fit: " + beta.getParameter(BetaDist.SHAPE1) + ", "
                + beta.getParameter(BetaDist.SHAPE2));

            //Inverse checker...
            int i;
            beta = new BetaDist(2.0, 5.0, 0.0, 1.0);
            for(i = 0; i < 10; i++)
            {
                System.out.println("Inverse CDF for " + (i * 0.10) + " is " + beta.functionInverseCDF((i * 0.10)));
            }

        }
        catch(final ArithmeticException e)
        {
            System.out.println("ERROR");
        }

    }

}
