package ohd.hseb.hefs.mefp.models;

import java.util.List;

import ohd.hseb.hefs.mefp.models.precipitation.MEFPPrecipitationModelControlOptions;
import ohd.hseb.hefs.mefp.sources.MEFPForecastSource;
import ohd.hseb.hefs.mefp.sources.MEFPSourceControlOptions;
import ohd.hseb.hefs.mefp.tools.canonical.CanonicalEventValuesGatherer;
import ohd.hseb.hefs.mefp.tools.canonical.SourceCanonicalEventValues;
import ohd.hseb.hefs.mefp.tools.canonical.StandardCanonicalEventValuesGatherer;

/**
 * General tools that can be accessed by both {@link MEFPParameterEstimationModel}s and
 * {@link MEFPEnsembleGeneratorModel}s.
 * 
 * @author hankherr
 */
public abstract class MEFPModelTools
{
    /**
     * The number of decimal places to which to round all forecast ensemble values output. It should be enough that the
     * precipitation zero-adjusted values show up as 0 in the results. Note that this field is not final, so it can be
     * reassigned at run time.
     */
    public static int FORECAST_ENSEMBLE_NUMBER_OF_DECIMAL_PLACES = 3;

    /**
     * Populates the provided lists given the forecastEventValues, observedEventValues, and noRainThreshold. The two
     * input lists are walked in parallel over the indices of forecastEventValues, and each pair is classified by
     * comparing both values against the threshold.
     * 
     * @param forecastEventValues Forecast canonical event values.
     * @param observedEventValues Observed canonical event values. Must have at least as many elements as
     *            forecastEventValues, since it is read at every index of that list.
     * @param noRainThreshold Threshold defining no rain: if event >= noRainThreshold, then rain has occurred.
     * @param fcstNoRainObsRain Output: receives the observed value of every pair where the forecast is below the
     *            threshold and the observation is not.
     * @param fcstRainObsNoRain Output: receives the forecast value of every pair where the observation is below the
     *            threshold and the forecast is not.
     * @param fcstValuesBothRain Output: receives the forecast value of every pair where neither value is below the
     *            threshold.
     * @param obsValuesBothRain Output: receives the observed value of every pair where neither value is below the
     *            threshold; it is filled in step with fcstValuesBothRain.
     * @return The number of times neither the forecast nor the observed value indicated rain (>= threshold).
     */
    public static int constructPrecipitationRainNoRainLists(final List<Float> forecastEventValues,
                                                            final List<Float> observedEventValues,
                                                            final double noRainThreshold,
                                                            final List<Float> fcstNoRainObsRain,
                                                            final List<Float> fcstRainObsNoRain,
                                                            final List<Float> fcstValuesBothRain,
                                                            final List<Float> obsValuesBothRain)
    {
        int count0 = 0;
        for(int i = 0; i < forecastEventValues.size(); i++)
        {
            //Fetch each element once; the same boxed objects are handed to the output lists.
            final Float forecast = forecastEventValues.get(i);
            final Float observed = observedEventValues.get(i);
            final boolean forecastDry = forecast < noRainThreshold;
            final boolean observedDry = observed < noRainThreshold;

            if(forecastDry)
            {
                if(observedDry)
                {
                    count0++;
                }
                else
                {
                    fcstNoRainObsRain.add(observed);
                }
            }
            else
            {
                if(observedDry)
                {
                    fcstRainObsNoRain.add(forecast);
                }
                else
                {
                    fcstValuesBothRain.add(forecast);
                    obsValuesBothRain.add(observed);
                }
            }
        }
        return count0;
    }

    /**
     * Constructs a {@link CanonicalEventValuesGatherer} for use with EPT. This is used both during parameter estimation
     * and ensemble generation. The gatherer is built by the source and then adjusted: the minimum-required-both-positive
     * check is set to {@link StandardCanonicalEventValuesGatherer#DO_NOT_CHECK} and the threshold is overridden with the
     * EPT precipitation threshold taken from the estimation control options.
     * 
     * @param source The forecast source. The gatherer it constructs is cast to
     *            {@link StandardCanonicalEventValuesGatherer}, so a source returning any other type results in a
     *            {@link ClassCastException}.
     * @param values The {@link SourceCanonicalEventValues} from which to gather values.
     * @param estimationModelControlOptions Control options for data gathering.
     * @param sourceControlOptions Source-specific control options, passed through unchanged to the source when it
     *            constructs the gatherer.
     * @return A {@link StandardCanonicalEventValuesGatherer} that can be used to gather data.
     */
    public static StandardCanonicalEventValuesGatherer constructEPTCanonicalEventValuesGatherer(final MEFPForecastSource source,
                                                                                                final SourceCanonicalEventValues values,
                                                                                                final MEFPPrecipitationModelControlOptions estimationModelControlOptions,
                                                                                                final MEFPSourceControlOptions sourceControlOptions)
    {
        final StandardCanonicalEventValuesGatherer eptGatherer = (StandardCanonicalEventValuesGatherer)source.constructCanonicalEventValuesGatherer(values,
                                                                                                                                                    estimationModelControlOptions,
                                                                                                                                                    sourceControlOptions);
        eptGatherer.setMinimumRequiredBothPositive(StandardCanonicalEventValuesGatherer.DO_NOT_CHECK);
        eptGatherer.setThresholdOverride(estimationModelControlOptions.getEPT().getEPTPrecipThreshold());
        return eptGatherer;
    }

    /**
     * Computes, for each ensemble member, the average observation in original space over that member's probability
     * interval (f1, f2], where the intervals split [0, 1] into numberOfEnsembleMembers equal pieces. It uses the
     * fullRealizationsInOriginalSpace and fullRealizationCDFValues arguments. Its results are returned.<br>
     * <br>
     * This is a direct translation of the routines in Fortran's xextractp.f, which is present in both epp3_precip_epp
     * and epp3_temp_epp (the two files are exactly identical except for whitespace). It is an ugly routine that needs
     * to be cleaned up in the future if it is to be maintained.<br>
     * <br>
     * This is basically computing E[Obs | Fcst, P(Obs) in (f1, f2)], where (Obs|Fcst) is distributed as a conditional
     * normal after converting both via NQT where the means/vars of obs and fcst are known. As such, I think the
     * algorithm below is essentially doing a numerical integration making use of the 1000 points in the realization
     * arrays. We might be able to do this more efficiently by using Simpsons rule for three points per each of the
     * desired intervals (a much smaller number than 1000). However, I'm not positive and the conversion below exactly
     * matches the results from Fortran to less than 0.0001, so let's stick with this until changes are needed.<br>
     * <br>
     * The scan over the realizations starts at index 1; element 0 of both arrays is read once up front but never
     * enters the averaging. NOTE(review): confirm this is intended given the 1-based Fortran origin.
     * 
     * @param obspz The observed probability of zero.
     * @param numberOfEnsembleMembers The number of members to create.
     * @param fullRealizationsInOriginalSpace A bunch of realizations, typically 1000, that will be used to estimate the
     *            conditional expectation noted above. Must not be empty.
     * @param fullRealizationCDFValues The corresponding CDF values. Must not be empty and is indexed in step with
     *            fullRealizationsInOriginalSpace.
     * @return An array of length numberOfEnsembleMembers holding one value per probability interval. A member whose
     *         interval upper bound is at or below obspz is 0. A member is also left at 0 if the realizations are
     *         exhausted before its interval is closed.
     * @throws IllegalArgumentException If a NaN CDF value is reached while the scan position is still before the
     *             start of the current interval. The algorithm cannot advance past such a value and would otherwise
     *             loop forever.
     */
    public static double[] computeEnsembleMembers(final double obspz,
                                                  final int numberOfEnsembleMembers,
                                                  final double[] fullRealizationsInOriginalSpace,
                                                  final double[] fullRealizationCDFValues)
    {
        double f1 = 0;
        final int nv = fullRealizationsInOriginalSpace.length;

        final double[] forecastEnsemble = new double[numberOfEnsembleMembers];
        final int npp = forecastEnsemble.length;
        boolean inside;
        int jlast = 0;
        double flast = 0;
        int jnext = 0;
        //Both values are overwritten before their first use in the loop; reading element 0 here makes empty arrays
        //fail immediately.
        double pnext = fullRealizationsInOriginalSpace[0];
        double fnext = fullRealizationCDFValues[0];
        final double df = 1. / npp;

        //avg and sumdf carry over between members: when a realization straddles f2, the part beyond f2 is
        //pre-loaded into them for the next member.
        double avg = 0;
        double sumdf = 0;

        //Used in loop, do not need to initialize
        double f2;
        double dfv;
        double dff;

        for(int i = 0; i < npp; i++)
        {
            f2 = (i + 1) * df;
            if(f2 <= obspz)
            {
                forecastEnsemble[i] = 0;
            }
            else
            //f2 > obspz
            {
                inside = true;
                while(inside)
                {
                    jnext = jlast + 1;
                    if(flast > f2 || jnext >= nv)
                    {
                        //Out of realizations (or already past f2): stop scanning.  On this path nothing is written to
                        //forecastEnsemble[i], so it keeps its default of 0.
                        //NOTE(review): confirm that is the desired result when the CDF values end below f2.
                        inside = false;
                    }
                    else
                    {
                        //jnext is in bounds here because jnext >= nv was handled above.
                        pnext = fullRealizationsInOriginalSpace[jnext];
                        fnext = fullRealizationCDFValues[jnext]; //end of _p[jnext] interval
                        if(fnext > f2) //_p[next] ends after f2
                        {
                            if(flast >= f1)
                            {
                                dfv = f2 - flast;
                            }
                            else
                            {
                                dfv = f2 - f1;
                            }
                            dff = Math.max(0d, Math.min(dfv, f2 - obspz));
                            sumdf = sumdf + dfv;
                            avg = avg + dff * pnext;
                            if(sumdf > 0)
                            {
                                avg = avg / sumdf;
                            }
                            forecastEnsemble[i] = avg;

                            //Seed the accumulators for the next member with the part of this realization past f2.
                            dfv = fnext - f2;
                            sumdf = dfv;
                            dff = Math.max(0d, Math.min(dfv, f2 - obspz));
                            avg = dff * pnext;
                            inside = false;
                        }
                        else if(fnext <= f1) //_p[jnext] is before <f1, f2> interval
                        {
                            flast = fnext;
                            jlast = jnext;
                        }
                        else if(flast < f1) //_p[jnext] interval starts before f1
                        {
                            //Neither fnext > f2 nor fnext <= f1 held, so a comparable fnext satisfies
                            //f1 < fnext <= f2.  The only other way to get here is a NaN CDF value, which previously
                            //fell into branches that never advanced jlast and never left the loop.
                            if(Double.isNaN(fnext))
                            {
                                throw new IllegalArgumentException("fullRealizationCDFValues[" + jnext
                                    + "] is NaN; cannot compute ensemble member " + i + ".");
                            }

                            if(f1 <= obspz)
                            {
                                dfv = fnext - f1;
                                dff = Math.max(0d, fnext - obspz);
                                sumdf = sumdf + dfv;
                                avg = avg + dff * pnext;
                            }
                            else
                            //p[jnext] starts after f1
                            {
                                dfv = fnext - f1;
                                sumdf = sumdf + dfv;
                                if(fnext < obspz)
                                {
                                    dff = 0;
                                }
                                else
                                {
                                    if(f1 < obspz)
                                    {
                                        dff = fnext - obspz;
                                    }
                                    else
                                    {
                                        dff = fnext - f1;
                                    }
                                }
                                avg = avg + dff * pnext;
                            }

                            //_p[jnext] ends at or before f2, update jlast
                            jlast = jnext;
                            flast = fnext;
                        }
                        else
                        //flast is inside of the <f1, f2> interval
                        {
                            if(fnext < f2) //all of _p[j] is in the interval, update jlast
                            {
                                dfv = fnext - flast;
                                if(flast <= obspz)
                                {
                                    dff = Math.max(0, fnext - obspz);
                                }
                                else
                                {
                                    dff = fnext - flast;
                                }
                                sumdf = sumdf + dfv;
                                avg = avg + dff * pnext;
                                jlast = jnext;
                                flast = fnext;
                            }
                            else
                            //fnext is not before f2, so this member's interval closes here; don't update jlast
                            {
                                inside = false;
                                if(flast >= f1)
                                {
                                    if(flast < obspz)
                                    {
                                        dff = Math.max(0, f2 - obspz);
                                    }
                                    else
                                    {
                                        dff = f2 - flast;
                                    }
                                    dfv = f2 - flast;
                                }
                                else
                                //flast < f1
                                {
                                    if(f1 < obspz)
                                    {
                                        dff = Math.max(0, f2 - obspz);
                                    }
                                    else
                                    {
                                        dff = f2 - f1;
                                    }
                                    dfv = f2 - f1;
                                }
                                sumdf = sumdf + dfv;
                                avg = avg + pnext * dff;
                                if(sumdf > 0)
                                {
                                    avg = avg / sumdf;
                                }
                                forecastEnsemble[i] = avg;
                                sumdf = 0;
                                avg = 0;
                            }
                        }
                    }
                }
            }
            f1 = f2; //set start of next interval.
        }

        return forecastEnsemble;
    }
}
