package ohd.hseb.hefs.utils.dist.types;

import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.special.Gamma;

import ohd.hseb.hefs.utils.dist.DataFittingDistributionException;
import ohd.hseb.hefs.utils.dist.LMomentsFittingDistribution;
import ohd.hseb.hefs.utils.dist.LMomentsMath;
import ohd.hseb.hefs.utils.dist.MomentsFittingDistribution;
import ohd.hseb.hefs.utils.dist.ShiftOptimizationFittingDistribution;
import ohd.hseb.hefs.utils.xml.vars.XMLDouble;
import ohd.hseb.util.data.DataSet;

/**
 * Three parameter Gamma -- i.e. a Pearson Type III -- with scale, shape, and shift (lower bound) parameters. This wraps
 * {@link GammaDistribution} provided by commons.math, which handles the shape and scale; the shift is applied by
 * translating values before they are passed to the wrapped distribution.<br>
 * <br>
 * Parameter estimation:
 * <ul>
 * <li>Method of moments ({@link #fitToMoments(double, double)} and {@link #estimateMethodOfMoments(DataSet)})
 * estimates the shape and scale only; the current shift is held fixed.</li>
 * <li>L-moments ({@link #fitToLMoments(DataSet)}, which is what {@link #fitToData(DataSet, double[])} uses) estimates
 * the shape and scale for the current shift when {@link #setFitShift(boolean)} is false (the default), or all three
 * parameters when it is true. The three-parameter fit only supports positively skewed samples.</li>
 * <li>{@link #estimateParameters(DataSet, double)} performs a two-parameter fit for a caller-supplied, fixed shift (the
 * {@link ShiftOptimizationFittingDistribution} hook).</li>
 * </ul>
 * 
 * @author hank.herr
 */
public class GammaDist extends ContinuousDist
implements MomentsFittingDistribution, LMomentsFittingDistribution,
ShiftOptimizationFittingDistribution
{
    /**
     * Index of the scale parameter. Must be kept consistent with the position of the parameter in the constructors.
     */
    private static final int SCALE = 0;

    /**
     * Index of the shape parameter. Must be kept consistent with the position of the parameter in the constructors.
     */
    private static final int SHAPE = 1;

    /**
     * Index of the shift parameter. Must be kept consistent with the position of the parameter in the constructors.
     */
    private static final int SHIFT = 2;

    private static final double DEFAULT_SCALE = 1.0D;
    private static final double DEFAULT_SHAPE = 2.0D;
    private static final double DEFAULT_SHIFT = 0.0D;

    /**
     * When true, L-moment fitting estimates the shift as well (three-parameter fit). When false, the current shift is
     * held fixed and only the shape and scale are estimated.
     */
    private boolean _fitShift = false;

    /**
     * Commons-math distribution for the current shape and scale; it is evaluated on shift-translated values. Rebuilt
     * whenever a parameter is changed through {@link #setParameter(int, Number)}.
     */
    private GammaDistribution _wrapped;

    /**
     * Constructs the Gamma distribution using the defaults provided by the constants: {@link #DEFAULT_SCALE},
     * {@link #DEFAULT_SHAPE}, and {@link #DEFAULT_SHIFT}.
     */
    public GammaDist()
    {
        // The delegated constructor already builds the wrapped distribution.
        this(DEFAULT_SCALE, DEFAULT_SHAPE, DEFAULT_SHIFT);
    }

    /**
     * Constructs the distribution given the specified parameters. The domain lower bound is set to the shift.
     * 
     * @param scale Scale parameter.
     * @param shape Shape parameter.
     * @param shift Shift parameter.
     */
    public GammaDist(final double scale, final double shape, final double shift)
    {
        super(new XMLDouble("domain", shift, shift, null),
              new XMLDouble("scale", scale, 0.0D, null),
              new XMLDouble("shape", shape, 0.0D, null),
              new XMLDouble("shift", shift, null, null));
        rebuildWrappedDistribution();
    }

    /**
     * @param b True to estimate the shift along with shape and scale when fitting L-moments; false to hold the current
     *            shift fixed.
     */
    public void setFitShift(final boolean b)
    {
        _fitShift = b;
    }

    /**
     * Sets the scale parameter; the wrapped distribution is rebuilt by {@link #setParameter(int, Number)}.
     */
    public void setScale(final double scale)
    {
        setParameter(SCALE, scale);
    }

    public double getScale()
    {
        return getParameter(SCALE).doubleValue();
    }

    /**
     * Sets the shape parameter; the wrapped distribution is rebuilt by {@link #setParameter(int, Number)}.
     */
    public void setShape(final double shape)
    {
        setParameter(SHAPE, shape);
    }

    public double getShape()
    {
        return getParameter(SHAPE).doubleValue();
    }

    /**
     * Sets the shift parameter. Ensures {@link #setDomainLowerBound(Double)} is called as well, so the domain starts at
     * the shift.
     */
    public void setShift(final double shift)
    {
        setParameter(SHIFT, shift);
        setDomainLowerBound(shift);
    }

    public double getShift()
    {
        return getParameter(SHIFT).doubleValue();
    }

    /**
     * Sets the parameter and rebuilds the wrapped commons-math distribution from the (possibly updated) shape and scale.
     */
    @Override
    public void setParameter(final int index, final Number value)
    {
        super.setParameter(index, value);
        rebuildWrappedDistribution();
    }

    /**
     * Rebuilds {@link #_wrapped} from the current shape and scale parameters.
     */
    private void rebuildWrappedDistribution()
    {
        _wrapped = new GammaDistribution(getShape(), getScale());
    }

    /**
     * @return 0 for values below the shift; otherwise the wrapped CDF of (value - shift), or the missing value if the
     *         commons-math computation fails.
     */
    @Override
    public double functionCDF(final Double value)
    {
        if(value < getShift())
        {
            return 0;
        }
        try
        {
            return _wrapped.cumulativeProbability(value - getShift());
        }
        catch(final RuntimeException e)
        {
            // Commons-math signals failures (e.g. non-convergence of the regularized gamma) with runtime exceptions.
            return getMissing();
        }
    }

    /**
     * @return 0 for values below the shift; otherwise the wrapped density of (value - shift).
     */
    @Override
    public double functionPDF(final Double value)
    {
        if(value < getShift())
        {
            return 0;
        }

        // Since the PDF is the derivative of the CDF, you can check this at Wolfram Alpha using:
        // derivative Gamma [b, x / a] / Gamma [ b ]
        return _wrapped.density(value - getShift());
    }

    /**
     * @return The quantile for the probability, translated by the shift; the missing value if prob is not strictly
     *         inside (0, 1) or the commons-math computation fails.
     */
    @Override
    public double functionInverseCDF(final double prob)
    {
        if(prob <= 0D || prob >= 1.0D)
        {
            return getMissing();
        }
        try
        {
            return _wrapped.inverseCumulativeProbability(prob) + getShift();
        }
        catch(final RuntimeException e)
        {
            // Commons-math signals failures (e.g. solver non-convergence) with runtime exceptions.
            return getMissing();
        }
    }

    /**
     * Estimates the shape and scale by the method of moments, holding the current shift fixed. The shift is subtracted
     * from a copy of the data (the caller's data is not modified), and the sample mean and coefficient of variation of
     * the shifted values are passed to {@link #fitToMoments(double, double)}.
     * 
     * @param data Be sure to call {@link DataSet#setFitSampleVariable(int)} and {@link DataSet#setFitCDFVariable(int)}
     *            prior to calling this routine.
     * @throws DataFittingDistributionException As declared; not thrown directly by this method's own code.
     * @throws IllegalArgumentException If the shifted mean is not positive or the coefficient of variation is zero or
     *             not finite.
     */
    public void estimateMethodOfMoments(final DataSet data) throws DataFittingDistributionException
    {
        //Apply the shift -- subtract it from all values in a copy of the data
        final DataSet usedData = new DataSet(data);
        usedData.applyShiftTransform(data.getFitSampleVariable(), -1D * getShift());

        //Compute the first two moments and call fitToMoments.
        final double mean = usedData.mean(usedData.getFitSampleVariable());
        final double variance = usedData.sampleVariance(usedData.getFitSampleVariable());
        fitToMoments(mean, Math.sqrt(variance) / mean);
    }

    /**
     * Fits this distribution to the data using L-moments via {@link #fitToLMoments(DataSet)}. That is a
     * two-parameter fit (shape and scale, current shift held fixed) when {@link #setFitShift(boolean)} is false, and a
     * three-parameter fit otherwise. NOTE(review): overridden so the superclass's default fitting procedure (presumably
     * an optimization over {@link #estimateParameters(DataSet, double)}) is not used -- confirm against ContinuousDist.
     * 
     * @param data Data to fit; the fit sample variable must be set.
     * @param fitParms Currently ignored.
     * @throws DataFittingDistributionException If thrown by the L-moments fit.
     */
    @Override
    public void fitToData(final DataSet data,
                          final double[] fitParms) throws DataFittingDistributionException
    {
        // Shift optimization (DistributionTools.optimizeShiftFitForBoundedBelowDistribution(this, data, fitParms),
        // with fitParms defaulting to {0.0}, i.e. a minimum shift of 0) used to be applied when _fitShift was true. It
        // is disabled; the three-parameter L-moments fit estimates the shift instead.
        fitToLMoments(data);
    }

    /**
     * Sets the shape and scale from the mean and coefficient of variation: shape = 1 / cv^2, scale = mean * cv^2.
     * Assumes a shift of zero (i.e. the moments are of already-shifted values); the shift is not changed.
     * 
     * @throws IllegalArgumentException If the mean is not positive (or NaN), or the coefficient of variation is zero,
     *             NaN, or infinite. The parameters are not modified in that case.
     */
    @Override
    public void fitToMoments(final double mean,
                             final double coefficientOfVariation)
    {
        // Validate before touching any parameter; !(mean > 0) also rejects NaN.
        if(!(mean > 0.0D) || (coefficientOfVariation == 0.0D) || Double.isNaN(coefficientOfVariation)
            || Double.isInfinite(coefficientOfVariation))
        {
            throw new IllegalArgumentException("Either the mean, " + mean
                + ", is not positive, or the coefficient of variation, "
                + coefficientOfVariation + ", is zero or not finite.");
        }

        final double cofvsq = coefficientOfVariation * coefficientOfVariation;
        setShape(1.0 / cofvsq);
        setScale(mean * cofvsq);
    }

    /**
     * Computes the first three L-moments of the data and fits them via {@link #fitToLMoments(double[])}: the
     * two-parameter Gamma (PELGAM) when the shift is held fixed, or the three-parameter Pearson III (PELPE3) when
     * {@link #setFitShift(boolean)} is true.
     */
    @Override
    public void fitToLMoments(final DataSet data) throws DataFittingDistributionException
    {
        // L-moments are computed from a copy of the data so that the caller's data is never shifted.
        final DataSet usedData = new DataSet(data);
        if(!_fitShift)
        {
            // Two-parameter fit: the shift is held fixed, so remove it before computing the L-moments.
            usedData.applyShiftTransform(usedData.getFitSampleVariable(),
                                         -1D * getShift());
        }

        // NOTE(review): assumed to be {lambda1, lambda2, tau3}, as Hosking's PELGAM/PELPE3 expect -- confirm against
        // LMomentsMath. The two-parameter fit only uses the first two entries.
        final double[] lmoments = LMomentsMath.dataToLMoments(usedData, 3);

        fitToLMoments(lmoments);
    }

    /**
     * Fits the L-moments to a 2 parameter gamma distribution (shape and scale); the current shift is not changed. The
     * code is based on the FORTRAN routine (in lmoments.f) PELGAM(XMOM,PARA) from the:<br>
     * <br>
     * IBM RESEARCH REPORT RC20525<br>
     * 'FORTRAN ROUTINES FOR USE WITH THE METHOD OF L-MOMENTS, VERSION 3'<br>
     * J. R. M. HOSKING <br>
     * IBM RESEARCH DIVISION <br>
     * VERSION 3.04 JULY 2005 <br>
     * <br>
     * 
     * @param lmoments L-moments of the (already shifted) data; lmoments[0] is the first L-moment and lmoments[1] the
     *            second. Further entries are ignored.
     * @throws IllegalArgumentException If lmoments[0] <= lmoments[1] or lmoments[1] <= 0.
     */
    private void fitToLMomentTwoParameter(final double[] lmoments)
    {
        // Minimax approximation constants from PELGAM.
        final double a1 = -0.3080;
        final double a2 = -0.05812;
        final double a3 = 0.01765;

        final double b1 = 0.7213;
        final double b2 = -0.5947;
        final double b3 = -2.1817;
        final double b4 = 1.2113;

        final double cv, t, alpha;

        if((lmoments[0] <= lmoments[1]) || (lmoments[1] <= 0.0))
        {
            throw new IllegalArgumentException("Either " + "lmoments[0] "
                + lmoments[0] + " <= lmoments[1] " + lmoments[1]
                + ", or lmoments[1] <= 0.0.");
        }

        // L-CV: ratio of the second to the first L-moment.
        cv = lmoments[1] / lmoments[0];

        if(cv >= 0.5)
        {
            t = 1.0 - cv;
            alpha = t * (b1 + t * b2) / (1.0 + t * (b3 + t * b4));
        }
        else
        {
            t = Math.PI * cv * cv;
            alpha = (1.0 + a1 * t) / (t * (1.0 + t * (a2 + t * a3)));
        }

        setShape(alpha);
        setScale(lmoments[0] / alpha);
    }

    /**
     * Fits the L-moments to a 3 parameter Pearson Type III (shape, scale, and shift). Based on the FORTRAN routine (in
     * lmoments.f) PELPE3(XMOM,PARA) from IBM RESEARCH REPORT RC20525, 'FORTRAN ROUTINES FOR USE WITH THE METHOD OF
     * L-MOMENTS, VERSION 3', J. R. M. HOSKING, IBM RESEARCH DIVISION, VERSION 3.04 JULY 2005.
     * 
     * @param lmoments lmoments[0] and lmoments[1] are the first two L-moments; lmoments[2] is the L-skewness.
     * @throws IllegalArgumentException If lmoments[1] <= 0, |lmoments[2]| >= 1, the L-skewness is effectively zero, or
     *             the L-skewness is negative (a negatively skewed Pearson III would need a negative scale, which
     *             {@link GammaDistribution} cannot represent). The parameters are not modified in those cases.
     */
    private void fitToLMomentThreeParameter(final double[] lmoments)
    {
        // SMALL IS USED TO TEST WHETHER SKEWNESS IS EFFECTIVELY ZERO
        final double small = 1.0e-6;

        // CONSTANTS USED IN MINIMAX APPROXIMATIONS
        final double c1 = 0.2906;
        final double c2 = 0.1882;
        final double c3 = 0.0442;

        final double d1 = 0.36067;
        final double d2 = -0.59567;
        final double d3 = 0.25361;
        final double d4 = -2.78861;
        final double d5 = 2.56096;
        final double d6 = -0.77045;

        final double pi3 = Math.PI * 3.0;
        final double rootPi = Math.sqrt(Math.PI);

        final double t3 = Math.abs(lmoments[2]);

        if((lmoments[1] <= 0.0) || (t3 >= 1.0))
        {
            throw new IllegalArgumentException("Either " + "lmoments[1] "
                + lmoments[1] + " <= 0.0 " + ", or lmoments[2] " + t3
                + " >= 1.0.");
        }

        if(t3 <= small)
        {
            // The FORTRAN code returns mean = lmoments[0], sd = lmoments[1] * sqrt(pi), and skewness = 0 here. A
            // skewness of 0 implies an infinite shape, so it cannot be converted. Substituting skewness = small would
            // give a shape of 4.0e12.
            throw new IllegalArgumentException("Skewness == 0.0, cannot calculate shape");
        }

        if(lmoments[2] < 0.0)
        {
            // PELPE3 would return a negative skewness, which converts to a negative scale below; reject it before any
            // parameter is modified.
            throw new IllegalArgumentException("L-skewness " + lmoments[2]
                + " < 0.0; a negatively skewed Pearson Type III cannot be represented (scale would be negative).");
        }

        final double t;
        final double alpha;
        if(t3 >= 1.0 / 3.0)
        {
            t = 1.0 - t3;
            alpha = t * (d1 + t * (d2 + t * d3))
                / (1.0 + t * (d4 + t * (d5 + t * d6)));
        }
        else
        {
            t = pi3 * t3 * t3;
            alpha = (1.0 + c1 * t) / (t * (1.0 + t * (c2 + t * c3)));
        }

        final double rootAlpha = Math.sqrt(alpha);
        final double beta = rootPi * lmoments[1] * Math.exp(Gamma.logGamma(alpha)
            - Gamma.logGamma(alpha + 1.0 / 2.0));

        final double mean = lmoments[0];
        final double sd = beta * rootAlpha;
        // Positive: negative L-skewness was rejected above.
        final double skewness = 2.0 / rootAlpha;

        // Convert mean, sd, skewness -> scale, shape, shift
        // See http://mathworld.wolfram.com/PearsonTypeIIIDistribution.html
        //
        // In Mathworld parameters:
        //
        // beta  <-> scale
        // p     <-> shape
        // alpha <-> shift
        //
        // skewness = 2.0 / sqrt(shape)
        //
        // so shape = (2.0 / skewness)^2 = 4.0 / (skewness^2)
        //
        // sd^2     = p * beta^2
        //          = shape * scale^2
        //
        // so scale = sd / sqrt(shape)
        //          = sd / (2.0 / skewness)
        //          = (sd * skewness) / 2.0
        //
        // mean     = alpha + p * beta = shift + (shape * scale)
        //
        // so shift = mean - (shape * scale)
        //          = mean - ((4.0 / (skewness^2)) * (sd * skewness) / 2.0)
        //          = mean - (2.0 * sd) / skewness

        setShape(4.0 / (skewness * skewness));
        setScale((sd * skewness) / 2.0);
        setShift(mean - ((2.0 * sd) / skewness));
    }

    /**
     * Fits the given L-moments: two-parameter (shape and scale, shift fixed) when {@link #setFitShift(boolean)} is
     * false, otherwise three-parameter (shape, scale, and shift).
     * 
     * @throws IllegalArgumentException If the L-moments are invalid for the selected fit.
     */
    @Override
    public void fitToLMoments(final double[] lmoments) throws DataFittingDistributionException
    {
        if(!_fitShift)
        {
            fitToLMomentTwoParameter(lmoments);
        }
        else
        {
            fitToLMomentThreeParameter(lmoments);
        }
    }

    /**
     * Fits the shape and scale for a fixed shift using a two-parameter fit on a temporary distribution, then copies
     * the result (and the shift) into this instance.
     * 
     * @param data Data to fit; the fit sample variable must be set.
     * @param shift The fixed shift to use.
     * @throws DataFittingDistributionException If thrown by the fit.
     */
    @Override
    public void estimateParameters(final DataSet data,
                                   final double shift) throws DataFittingDistributionException
    {
        final GammaDist fitDist = new GammaDist();
        fitDist.setFitShift(false); //Force a two-param fit
        fitDist.setShift(shift);
        fitDist.fitToData(data);

        //Copy the results
        setScale(fitDist.getScale());
        setShape(fitDist.getShape());
        setShift(shift);
    }
}
