package ohd.hseb.hefs.mefp.models;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import nl.wldelft.util.timeseries.TimeSeriesArray;
import ohd.hseb.hefs.mefp.tools.canonical.CanonicalEvent;
import ohd.hseb.hefs.utils.Dyad;
import ohd.hseb.hefs.utils.dist.types.NormalDist;
import ohd.hseb.hefs.utils.tools.ThreadedRandomTools;
import ohd.hseb.hefs.utils.tsarrays.TimeSeriesEnsemble;
import ohd.hseb.util.misc.HCalendar;

/**
 * Modifies a {@link TimeSeriesEnsemble} based on forecast {@link CanonicalEvent} values using the Schaake Shuffle. In
 * precipitation mode, the application takes the form of a multiplication of the base event value, unless the forecast
 * event value is positive but the base event value is 0, in which case it is a straight set value. It is also a
 * straight set value if the canonical event duration is only one period. In temperature mode, the application is an
 * addition of the difference between the forecast and base event values.
 * 
 * @author hankherr
 */
public class SchaakeShuffleApplier
{
    /** Ensemble whose member time series are modified in place when the shuffle is applied. */
    final private TimeSeriesEnsemble _baseEnsemble;
    /** Time step, in hours, of {@link #_baseEnsemble}; read once from the ensemble at construction. */
    final private int _periodStepInHours;
    /** True to adjust members by multiplying/setting (precipitation); false to adjust by adding (temperature). */
    final private boolean _precipitationMode;

    /**
     * Threshold used by the precipitation adjustments: written values are intended to be either 0 or no less than this
     * value. It stays NaN unless supplied via the three-argument constructor or {@link #setNoRainThreshold(double)};
     * every {@code <} comparison against NaN is false, so threshold-dependent branches are then never taken.
     */
    private double _noRainThreshold = Double.NaN; // LW

    /**
     * Builds an applier via {@link #SchaakeShuffleApplier(TimeSeriesEnsemble, boolean)} and then records the no-rain
     * threshold to use for the precipitation adjustments.
     * 
     * @param baseEnsemble The ensemble to modify via the shuffle.
     * @param precipitationMode True to apply the Schaake shuffle for precipitation data (a multiplication
     *            modification), false for temperature (addition).
     * @param noRainThreshold The threshold for the EPT algorithm; it is reused here as the overall threshold of the
     *            Schaake Shuffle.
     */
    public SchaakeShuffleApplier(final TimeSeriesEnsemble baseEnsemble,
                                 final boolean precipitationMode,
                                 final double noRainThreshold)
    {
        this(baseEnsemble, precipitationMode);

        // The EPT threshold doubles as the overall Schaake Shuffle threshold. LW
        this._noRainThreshold = noRainThreshold;
    }

    /**
     * Builds an applier without a no-rain threshold; the threshold remains NaN until
     * {@link #setNoRainThreshold(double)} is called. The period step is taken from the time step of the ensemble.
     * 
     * @param baseEnsemble The ensemble to modify via the shuffle.
     * @param precipitationMode True to apply the Schaake shuffle for precipitation data (a multiplication
     *            modification), false for temperature (addition).
     */
    public SchaakeShuffleApplier(final TimeSeriesEnsemble baseEnsemble, final boolean precipitationMode)
    {
        _baseEnsemble = baseEnsemble;
        _precipitationMode = precipitationMode;
        _periodStepInHours = baseEnsemble.getTimeStepHours();
    }

    /**
     * Overrides the no-rain threshold. The threshold is usually supplied through the three-argument constructor; this
     * setter is intended for tests that build the applier with the two-argument constructor.
     * 
     * @param noRainThreshold The threshold to use from now on.
     */
    public void setNoRainThreshold(final double noRainThreshold)
    {
        this._noRainThreshold = noRainThreshold;
    }

    /**
     * Draws a small random number near zero, used to break ties when ranking many (near) zero values. A uniform draw
     * is pushed through the inverse CDF of the standard normal distribution and then scaled, so the results are
     * distributed N(0.0001, 0.00001) (mean, standard deviation).
     * 
     * @return The random near-zero value.
     */
    private float zero()
    {
        final double uniformDraw = ThreadedRandomTools.getRandomNumGen().nextDouble();
        final float standardNormalDraw = (float)NormalDist.STD_NORM_DIST.functionInverseCDF(uniformDraw);

        final float mean = 0.0001f;
        final float standardDeviation = 0.00001f;
        return mean + standardDeviation * standardNormalDraw;
    }

    /**
     * Creates the index array necessary for the Schaake shuffle code: element k of the result is the position within
     * the given values of the k-th smallest value (ascending order). Ties keep their original relative order because
     * the sort used is stable.
     * 
     * @param values Values to index
     * @return The index array.
     */
    private int[] indexArray(final double[] values)
    {
        //Start from the identity ordering 0, 1, ..., n-1.
        final List<Integer> order = new ArrayList<Integer>(values.length);
        for(int position = 0; position < values.length; position++)
        {
            order.add(position);
        }

        //Reorder the positions so that the values they point at are ascending.
        Collections.sort(order, new Comparator<Integer>()
        {
            @Override
            public int compare(final Integer left, final Integer right)
            {
                return Double.compare(values[left], values[right]);
            }
        });

        //Unbox into the primitive array the callers expect.
        final int[] indices = new int[order.size()];
        int slot = 0;
        for(final Integer originalPosition: order)
        {
            indices[slot++] = originalPosition;
        }
        return indices;
    }

    /**
     * Multiplies all values in the time series affected by the event by the given multiplier, while keeping every
     * written value either 0 or no less than {@link #_noRainThreshold}. Products that fall below the threshold are not
     * written directly; they are accumulated and released at a later step of the event once the accumulated amount (or
     * a later product) reaches the threshold. Any amount still pending at the end of the event is added to the last
     * step of the event whose value is at or above the threshold, if there is one; otherwise it is dropped.
     * 
     * @param ts Time series to modify.
     * @param canonicalEvent Event.
     * @param t0 The forecast time to assume (time series need not use this forecast time).
     * @param multiplier The value to multiply through.
     */
    private void multiplyTimeSeriesValues(final TimeSeriesArray ts,
                                          final CanonicalEvent canonicalEvent,
                                          final long t0,
                                          final double multiplier)
    {
        final long startTime = canonicalEvent.computeStartTime(t0, _periodStepInHours);
        final int startIndex = ts.firstIndexAfterOrAtTime(startTime);

        //The t0Index is the index BEFORE the index for canonical event period 1.  Values at indices before t0Index
        //are never modified.
        final int t0Index = ts.firstIndexAfterOrAtTime(t0 + 1 * _periodStepInHours * HCalendar.MILLIS_IN_HR) - 1;
        final int numberOfPeriods = canonicalEvent.computeNumberOfPeriods();

        //Amount of sub-threshold precipitation that has been withheld so far and still needs to be placed.
        float tmpValue = 0.0f;

        for(int i = 0; i < numberOfPeriods; i++)
        {
            //A time series value is changed only if its index is at or after t0Index.
            if(startIndex + i >= t0Index)
            {
                // The ensemble values coming out of the Schaake Shuffle are required to be either 0 or no less than _noRainThreshold.
                // Therefore, in conducting this operation, small positive values less than _noRainThreshold are accumulated in a temporal 
                // variable so that it can be applied later in the sequence rather than get dropped off, in order to maintain water balance 
                // with respect to the threshold used -- in other words, the product of the multiplier and the sum of the 6-hour base values
                // in the canonical event should be maintained equal to the forecast value for the event. The hydrologic effect of doing so 
                // may be small. LW
                final float newValue = (float)multiplier * ts.getValue(startIndex + i);
                if(newValue < _noRainThreshold)
                {
                    //Too small to stand alone: withhold it, and release the pending amount only once it is large enough.
                    tmpValue = tmpValue + newValue;
                    if(tmpValue < _noRainThreshold)
                    {
                        ts.setValue(startIndex + i, 0.0f);
                    }
                    else
                    {
                        ts.setValue(startIndex + i, tmpValue);
                        tmpValue = 0.0f;
                    }
                }
                else
                {
                    //Large enough on its own (or the threshold is NaN): write it together with anything pending.
                    ts.setValue(startIndex + i, newValue + tmpValue);
                    tmpValue = 0.0f;
                }
            }
        }

        // The remaining amount of precipitation, if any, should be conserved, if possible.  Walk backward through the
        // whole event (down to and including its first period) looking for the last value at or above the threshold.
        if(tmpValue > 0.0f)
        {
            for(int i = numberOfPeriods - 1; i >= 0; i--)
            {
                if(startIndex + i >= t0Index)
                {
                    if(ts.getValue(startIndex + i) >= _noRainThreshold)
                    {
                        ts.setValue(startIndex + i, tmpValue + ts.getValue(startIndex + i));
                        break;
                    }
                }
            }
        }
    }

    /**
     * Adds the given addend to every value in the time series affected by the event. Values at indices before the
     * T0 index are left untouched.
     * 
     * @param ts Time series to modify.
     * @param canonicalEvent Event.
     * @param t0 The forecast time to assume (time series need not use this forecast time).
     * @param addend The value to add to each affected value.
     */
    private void addToTimeSeriesValues(final TimeSeriesArray ts,
                                       final CanonicalEvent canonicalEvent,
                                       final long t0,
                                       final double addend)
    {
        final int startIndex = ts.firstIndexAfterOrAtTime(canonicalEvent.computeStartTime(t0, _periodStepInHours));

        //The t0Index is the index BEFORE the index for canonical event period 1; nothing before it may be changed.
        final int t0Index = ts.firstIndexAfterOrAtTime(t0 + 1 * _periodStepInHours * HCalendar.MILLIS_IN_HR) - 1;

        int offset = 0;
        while(offset < canonicalEvent.computeNumberOfPeriods())
        {
            final int index = startIndex + offset;
            if(index >= t0Index)
            {
                ts.setValue(index, (float)(ts.getValue(index) + addend));
            }
            offset++;
        }
    }

    /**
     * Distributes the given event total over the time series values affected by the event. The total is split evenly
     * over the periods of the event. If the resulting per-step amount is positive but below
     * {@link #_noRainThreshold}, the amount is concentrated onto fewer, evenly strided steps (every 2nd, 4th, ...
     * step) until the per-step amount reaches the threshold or only one step remains; all other steps of the event
     * are set to 0. The per-step amount is always the total divided by the number of steps that receive it, so the
     * event total is conserved. A non-positive total is written as is to every step of the event. Values at indices
     * before the T0 index are never modified.
     * 
     * @param ts Time series to modify.
     * @param canonicalEvent Event.
     * @param t0 The forecast time to assume (time series need not use this forecast time).
     * @param value The event total to distribute over the affected values.
     */
    private void setTimeSeriesValues(final TimeSeriesArray ts,
                                     final CanonicalEvent canonicalEvent,
                                     final long t0,
                                     final double value)
    {
        final long startTime = canonicalEvent.computeStartTime(t0, _periodStepInHours);
        final int startIndex = ts.firstIndexAfterOrAtTime(startTime);

        //The t0Index is the index BEFORE the index for canonical event period 1.  Values at indices before t0Index
        //are never modified.
        final int t0Index = ts.firstIndexAfterOrAtTime(t0 + 1 * _periodStepInHours * HCalendar.MILLIS_IN_HR) - 1;

        // Note that the forecast value is a total value for the canonical event. To obtain values for the individual
        // 6-hour steps in the canonical event, the total is divided by the number of 6-hour steps to get the average. 
        // When assigning values to the 6-hour steps in the canonical event from the averaged forecast value, make sure 
        // 1) the assigned values are either 0 or no less than _noRainThreshold, and 2) the average of the assigned values 
        // is equal to the averaged forecast value. LW
        final int numberOfPeriods = canonicalEvent.computeNumberOfPeriods();
        int k = 1;
        int numberOfPositiveValues = numberOfPeriods;
        double forecastValue = value / numberOfPeriods;
        // Note at this point, numberOfPeriods*forecastValue should be no less than _noRainThreshold if forecastValue 
        // is positive. LW
        if(forecastValue > 0.0d)
        {
            while(forecastValue < _noRainThreshold && numberOfPositiveValues > 1)
            {
                k = 2 * k;
                numberOfPositiveValues = numberOfPositiveValues / 2;

                // Recompute from the total rather than doubling: the integer halving above rounds down when the
                // count is odd, and doubling would then place less than the full event total.
                forecastValue = value / numberOfPositiveValues;
            }

            if(numberOfPositiveValues < numberOfPeriods)
            {
                // Set the 6-hour steps to 0 initially.
                for(int i = 0; i < numberOfPeriods; i++)
                {
                    if(startIndex + i >= t0Index)
                    {
                        ts.setValue(startIndex + i, 0.0f);
                    }
                }
            }

            // Place the per-step amount on every k-th step of the event.
            for(int i = 0; i < numberOfPositiveValues; i++)
            {
                if(startIndex + i * k >= t0Index)
                {
                    ts.setValue(startIndex + i * k, (float)forecastValue);
                }
            }
        }
        else
        {
            // Non-positive (or NaN) total: write it unchanged to every step, still honoring the T0 guard used by
            // every other write in this class.
            for(int i = 0; i < numberOfPeriods; i++)
            {
                if(startIndex + i >= t0Index)
                {
                    ts.setValue(startIndex + i, (float)value);
                }
            }
        }
    }

    /**
     * Convenience wrapper around {@link #multiplyTimeSeriesValues(TimeSeriesArray, CanonicalEvent, long, double)} that
     * takes the member time series and the forecast time from the base ensemble.
     */
    private void multiplyBaseEnsembleValue(final int memberIndex,
                                           final CanonicalEvent canonicalEvent,
                                           final double multiplier)
    {
        final TimeSeriesArray member = _baseEnsemble.get(memberIndex);
        final long forecastTime = _baseEnsemble.getForecastTime();
        multiplyTimeSeriesValues(member, canonicalEvent, forecastTime, multiplier);
    }

    /**
     * Convenience wrapper around {@link #addToTimeSeriesValues(TimeSeriesArray, CanonicalEvent, long, double)} that
     * takes the member time series and the forecast time from the base ensemble. Despite its name, the
     * {@code multiplier} argument is passed on as the addend.
     */
    private void addToBaseEnsembleValue(final int memberIndex,
                                        final CanonicalEvent canonicalEvent,
                                        final double multiplier)
    {
        final TimeSeriesArray member = _baseEnsemble.get(memberIndex);
        final long forecastTime = _baseEnsemble.getForecastTime();
        addToTimeSeriesValues(member, canonicalEvent, forecastTime, multiplier);
    }

    /**
     * Convenience wrapper around {@link #setTimeSeriesValues(TimeSeriesArray, CanonicalEvent, long, double)} that takes
     * the member time series and the forecast time from the base ensemble.
     */
    private void setBaseEnsembleValue(final int memberIndex, final CanonicalEvent canonicalEvent, final double value)
    {
        final TimeSeriesArray member = _baseEnsemble.get(memberIndex);
        final long forecastTime = _baseEnsemble.getForecastTime();
        setTimeSeriesValues(member, canonicalEvent, forecastTime, value);
    }

    /**
     * Applies the Schaake Shuffle for one canonical event, modifying the members of {@link #_baseEnsemble} in place.
     * The event value of every base member is computed and ranked in ascending order (in precipitation mode a small
     * random amount from {@link #zero()} is added for ranking only, to break ties among zero values). Forecast value
     * {@code k} is then applied to the base member whose event value ranks {@code k}-th.<br>
     * <br>
     * In precipitation mode, negative forecast values are skipped. Otherwise the forecast value is set directly if
     * the event has a single period, if the base event value is not positive, or if either the base or the forecast
     * event value is below one tenth of the no-rain threshold; in all other cases the member is scaled by
     * forecast/base. In temperature mode, the difference forecast - base is added to the member.
     * 
     * @param forecastCanonicalEventEnsemble An ensemble of canonical event values for the event passed in. There should
     *            be one element in this array per member in {@link #_baseEnsemble}. NOTE(review): since element k is
     *            paired with the k-th ranked base member, callers presumably pass these sorted ascending -- confirm.
     * @param canonicalEvent The canonical events.
     * @throws IllegalArgumentException If the number of forecast values differs from the number of base members.
     */
    public void applySchaakeShuffle(final double[] forecastCanonicalEventEnsemble, final CanonicalEvent canonicalEvent)
    {

        if(forecastCanonicalEventEnsemble.length != _baseEnsemble.size())
        {
            throw new IllegalArgumentException("The number of members in the forecast canonical event ensemble, "
                + forecastCanonicalEventEnsemble.length + ", does not match the number in the base ensemble, "
                + _baseEnsemble.size() + ".");
        }

        //Compute the canonical events for the base/historical time series.
        final double[] baseEventValues = new double[_baseEnsemble.size()];
        final double[] baseEventValuesForRanking = new double[_baseEnsemble.size()];
        for(int i = 0; i < baseEventValues.length; i++)
        {
            final TimeSeriesArray ts = _baseEnsemble.get(i);
            baseEventValues[i] = canonicalEvent.computeEvent(ts, _baseEnsemble.getForecastTime(), _periodStepInHours);
            baseEventValuesForRanking[i] = baseEventValues[i];
        }

        // Add small positive random numbers to historical MAP values (base event values), in order to make 
        // the Schaake Shuffle work properly in random ranking. LW
        if(_precipitationMode)
        {
            for(int i = 0; i < baseEventValues.length; i++)
            {
                baseEventValuesForRanking[i] = baseEventValuesForRanking[i] + zero();
            }
        }

        //Compute the index array.
        final int[] indx = indexArray(baseEventValuesForRanking);

        //Compute the number of canonical periods.
        final int numberOfCanonicalPeriods = canonicalEvent.computeNumberOfPeriods();

        //Loop through the members in decreasing order (not sure why decreasing; the legacy Fortran did the same).
        for(int memberIndex = _baseEnsemble.size() - 1; memberIndex >= 0; memberIndex--)
        {
            final int orderIndexOfBaseEventValues = indx[memberIndex];
            final double baseEventValue = baseEventValues[orderIndexOfBaseEventValues];
            final double forecastCanonicalEventValue = forecastCanonicalEventEnsemble[memberIndex];

            //Temperature does not have to worry about zeros and calls the add method.
            if(!_precipitationMode)
            {
                final double addend = forecastCanonicalEventValue - baseEventValue;
                addToBaseEnsembleValue(orderIndexOfBaseEventValues, canonicalEvent, addend);
            }
            //Precipitation includes zero checking and calling the multiply method.  Negative forecasts are skipped.
            else if(forecastCanonicalEventValue >= 0)
            {
                // Legacy behavior (Fortran and the first Java port): multiply by forecast/base when that ratio was
                // at most 5 and otherwise set the forecast value directly.  LW replaced the ratio test with the
                // threshold tests below.
                //
                // LW: Modify the base values depending on whether the base values or forecast values are 
                // positive or zero.
                //
                // The explicit baseEventValue <= 0 test keeps a zero base event out of the division even when the
                // threshold is NaN (never set) or not positive; without it the multiplier would be infinite or NaN
                // and NaN values would be written into the member.
                final boolean useForecastAsIs = numberOfCanonicalPeriods == 1 || baseEventValue <= 0.0d
                    || baseEventValue < 0.1 * _noRainThreshold
                    || forecastCanonicalEventValue < 0.1 * _noRainThreshold;
                if(useForecastAsIs)
                {
                    setBaseEnsembleValue(orderIndexOfBaseEventValues, canonicalEvent, forecastCanonicalEventValue);
                }
                else
                {
                    final double multiplier = forecastCanonicalEventValue / baseEventValue;
                    multiplyBaseEnsembleValue(orderIndexOfBaseEventValues, canonicalEvent, multiplier);
                }
            }
        }

        //DO NOT FORGET THAT THIS IS DONE FOR THE SOURCE RFC!!! We need source-specific function to be called that
        //can be used to overwrite ensemble with observed data at the beginning.  This may not be necessary unless MEFP
        //is run for a non-12Z time!  If so, we need to somehow handle observed data or perhaps do that via CHPS merging???
//      c
//      c                 fix fmap = observed values for first nobs_rfc values
//      c
//                        do jper=1,nobs_rfc
//                          do imem=1,nmem
//                            epp_rfc(imem,jper) = fcstx_rfc(jper)
//                          enddo
//                        enddo
//                        write (ilog,*) 'array epp_rfc now contains ',
//           x                           'the rfc ensemble forecasts'
//                        n = min(nrfcpers_out,8)
//                        write (ilog,*) 'imem,(epp_rfc(imem,j),j=1,n)'
//                        write (ilog,*) 'n = ',n
//                        do imem=1,nmem
//                          write (ilog,'(i4,10f7.1)') imem,
//           x                           (epp_rfc(imem,jper),jper=1,n)
//                        enddo
//      c
//      c                 ...array epp_rfc now contains the rfc ensemble forecasts for
//      c                    the current forecast date
//      c
//                        endif
//                      endif
    }

//    c                   *************create ensemble members*************
//    c
//    c                   ...get distribution of epp_rfc values for SS
//    c
//                        j1 = ipd(ii)
//                        j2 = lpd(ii)
//                        n = j2 - j1 + 1
//                        do imem=1,nmem
//                          sum = 0
//                          do j=j1,j2
//                            sum = sum + epp_rfc(imem,j)
//                          enddo
//                          avg = sum/n
//                          x(imem) = avg
//                        enddo
//                        call indexx (nmem,x,indx)
//    c
//    c                   ...apply Schaake Shuffle
//    c
//                        nout = 0
//                        do imem=nmem,1,-1
//                          ipp = indx(imem)
//    c                      if (pp(ipp).ge.0.) then
//                          if (pp(imem).ge.0.) then
//                            if (n.eq.1) then
//                              epp_rfc(ipp,ipd(ii)) = pp(imem) +
//         x                                                zero(ix)
//                              ix = ix+1
//                              n = n
//                            else
//                              cmult = pp(imem)/x(ipp)
//                              if (pp(imem).gt.0..and.cmult.le.5.) then
//                                nout = nout + 1
//    c                            if (nout.lt.6) then
//    c                              write (ilog,*)
//    c                              write (ilog,*) 'rfc'
//    c                              write (ilog,*) 'imem = ',imem
//    c                              write (ilog,*) 'ipp = ',ipp
//    c                              write (ilog,*) 'cmult = ',cmult
//    c                              write (ilog,*) 'pp(imem) = ',pp(imem)
//    c                              write (ilog,*) 'x(ipp) = ',x(ipp)
//    c                              write (ilog,*) 'ipd(ii) = ',ipd(ii)
//    c                              write (ilog,*) 'lpd(ii) = ',lpd(ii)
//    c                              write (ilog,*) 'epp_rfc(ipp,k) = ',
//    c     x                                        (epp_rfc(ipp,k),
//    c     x                                        k=ipd(ii),lpd(ii))
//    c                            endif
//                              endif
//                              if (cmult.lt.5.) then
//                                kx = 0
//                                do k=ipd(ii),lpd(ii)
//                                  kx = kx + 1
//                                  p1(kx) = epp_rfc(ipp,k)
//                                  epp_rfc(ipp,k) = cmult*epp_rfc(ipp,k) +
//         x                                         zero(ix)
//                                  ix = ix+1
//                                  p2(kx) = epp_rfc(ipp,k)
//                                  n = n
//                                enddo
//                                if (cmult.gt.0.) then
//                                  n = n
//                                endif
//                              else
//                                nk = lpd(ii) - ipd(ii) + 1
//                                do k=ipd(ii),lpd(ii)
//                                  kx = kx + 1
//                                  p1(kx) = epp_rfc(ipp,k)
//                                  epp_rfc(ipp,k) = pp(imem) + zero(ix)
//                                  ix = ix+1
//                                  p2(kx) = epp_rfc(ipp,k)
//                                  n = n
//                                enddo
//                              endif
//    c                          if (nout.lt.6.and.pp(imem).gt.0.) then
//    c                            write (ilog,*) 'epp_rfc(ipp,k) = ',
//    c     x                                      (epp_rfc(ipp,k),
//    c     x                                      k=ipd(ii),lpd(ii))
//    c                            nk = lpd(ii) - ipd(ii) + 1
//    c                            write (ilog,*) 'p1 = ',(p1(k),k=1,nk)
//    c                            write (ilog,*) 'p2 = ',(p2(k),k=1,nk)
//    c                            n = n
//    c                          endif
//                            endif
//                          endif
//                          n = n
//                        enddo
//                      enddo
//    c
//    c                 fix fmap = observed values for first nobs_rfc values
//    c
//                      do jper=1,nobs_rfc
//                        do imem=1,nmem
//                          epp_rfc(imem,jper) = fcstx_rfc(jper)
//                        enddo
//                      enddo
//                      write (ilog,*) 'array epp_rfc now contains ',
//         x                           'the rfc ensemble forecasts'
//                      n = min(nrfcpers_out,8)
//                      write (ilog,*) 'imem,(epp_rfc(imem,j),j=1,n)'
//                      write (ilog,*) 'n = ',n
//                      do imem=1,nmem
//                        write (ilog,'(i4,10f7.1)') imem,
//         x                           (epp_rfc(imem,jper),jper=1,n)
//                      enddo
//    c
//    c                 ...array epp_rfc now contains the rfc ensemble forecasts for
//    c                    the current forecast date
//    c
//                      endif
//                    endif
}
