package fit.util;

import fit.framework.Comparison;
import fit.framework.FitCalculator;
import fit.framework.Interpolator;
import fit.framework.MetadataItem;
import fit.framework.MetadataTable;
import fit.framework.ObservationSet;
import fit.framework.ResultSink;
import fit.framework.SampleComparator;
import fit.framework.Theory;
import fit.framework.TheorySet;
import fit.util.ArraysTheorySet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import uk.ac.starlink.table.ColumnData;
import uk.ac.starlink.table.ColumnStarTable;
import uk.ac.starlink.table.ConcatStarTable;
import uk.ac.starlink.table.DefaultValueInfo;
import uk.ac.starlink.table.JoinStarTable;
import uk.ac.starlink.table.RowSequence;
import uk.ac.starlink.table.StarTable;
import uk.ac.starlink.table.ValueInfo;

/**
 * Utility methods for use in fitting packages.
 *
 * @author   Mark Taylor
 * @since    13 Oct 2006
 */
public class FitUtils {

    private static final Logger logger_ = Logger.getLogger( "fit.util" );

    /**
     * Default interpolator factory.
     * Points with a strictly positive X error get an interpolator from
     * {@code SquareSmoother.getFactory()}; all others get one from
     * {@code LinearInterpolator.getFactory()}.
     */
    private static final InterpolatorFactory DEFAULT_INTERPOLATOR_FACTORY =
            new InterpolatorFactory() {
        public Interpolator getInterpolator( double x, double xerr ) {
            return ((InterpolatorFactory)
                    ( xerr > 0.0 ? SquareSmoother.getFactory()
                                 : LinearInterpolator.getFactory() ))
                  .getInterpolator( x, xerr );
        }
    };

    /** Speed of light in metres per second. */
    private static final double C = 299792458;

    /**
     * Private sole constructor prevents instantiation.
     */
    private FitUtils() {
    }

    /**
     * Returns the array of X positions contained in an observation set.
     *
     * @param   obsSet  observation set
     * @return   array of X positions, one per point
     */
    public static double[] getXs( ObservationSet obsSet ) {
        int npoint = obsSet.getPointCount();
        double[] xs = new double[ npoint ];
        for ( int ip = 0; ip < npoint; ip++ ) {
            xs[ ip ] = obsSet.getX( ip );
        }
        return xs;
    }

    /**
     * Returns the array of X position errors contained in an observation set.
     *
     * @param   obsSet  observation set
     * @return   array of X errors, one per point
     */
    public static double[] getXErrors( ObservationSet obsSet ) {
        int npoint = obsSet.getPointCount();
        double[] xerrs = new double[ npoint ];
        for ( int ip = 0; ip < npoint; ip++ ) {
            xerrs[ ip ] = obsSet.getXError( ip );
        }
        return xerrs;
    }

    /**
     * Returns the array of X factors contained in an observation set.
     *
     * @param   obsSet  observation set
     * @return  array of X axis stretch factors, one per observation
     */
    public static double[] getXFactors( ObservationSet obsSet ) {
        int nobs = obsSet.getObsCount();
        double[] factors = new double[ nobs ];
        for ( int iobs = 0; iobs < nobs; iobs++ ) {
            factors[ iobs ] = obsSet.getXFactor( iobs );
        }
        return factors;
    }

    /**
     * Returns the range of X values covered by an observation set.
     * Points whose X value is NaN are ignored.  If the set has no points,
     * or no points with a non-NaN X value, both bounds are NaN.
     *
     * @param  obsSet  observation set
     * @return   2-element (low,high) array giving range in X
     */
    public static double[] getXRange( ObservationSet obsSet ) {
        int npoint = obsSet.getPointCount();
        double xlo = Double.NaN;
        double xhi = Double.NaN;
        for ( int ip = 0; ip < npoint; ip++ ) {
            double x = obsSet.getX( ip );

            /* Skip NaNs: otherwise a NaN would satisfy both negated
             * comparisons below and reset the running extremes. */
            if ( ! Double.isNaN( x ) ) {

                /* The negated comparisons are also true while xlo/xhi are
                 * still NaN, which seeds them from the first valid point. */
                if ( ! ( xlo <= x ) ) {
                    xlo = x;
                }
                if ( ! ( xhi >= x ) ) {
                    xhi = x;
                }
            }
        }
        return new double[] { xlo, xhi };
    }

    /**
     * Returns the array of Interpolators contained in an observation set.
     *
     * @param  obsSet   observation set
     * @return  array of interpolators, one per point
     */
    public static Interpolator[] getInterpolators( ObservationSet obsSet ) {
        int npoint = obsSet.getPointCount();
        Interpolator[] interps = new Interpolator[ npoint ];
        for ( int ip = 0; ip < npoint; ip++ ) {
            interps[ ip ] = obsSet.getInterpolator( ip );
        }
        return interps;
    }

    /**
     * Returns the array of Y values for a given observation in an
     * observation set.
     *
     * @param  obsSet  observation set
     * @param  iobs    observation index
     * @return   Y values, one per point
     */
    public static double[] getYs( ObservationSet obsSet, int iobs ) {
        int npoint = obsSet.getPointCount();
        double[] ys = new double[ npoint ];
        for ( int ip = 0; ip < npoint; ip++ ) {
            ys[ ip ] = obsSet.getY( ip, iobs );
        }
        return ys;
    }

    /**
     * Returns the array of Y errors for a given observation in an
     * observation set.
     *
     * @param  obsSet  observation set
     * @param  iobs    observation index
     * @return  Y errors, one per point
     */
    public static double[] getYErrors( ObservationSet obsSet, int iobs ) {
        int npoint = obsSet.getPointCount();
        double[] yerrs = new double[ npoint ];
        for ( int ip = 0; ip < npoint; ip++ ) {
            yerrs[ ip ] = obsSet.getYError( ip, iobs );
        }
        return yerrs;
    }

    /**
     * Returns the array of X values contained in a theory object.
     *
     * @param   theory  theory
     * @return  X values
     */
    public static double[] getXs( Theory theory ) {
        int np = theory.getCount();
        double[] xs = new double[ np ];
        for ( int i = 0; i < np; i++ ) {
            xs[ i ] = theory.getX( i );
        }
        return xs;
    }

    /**
     * Returns the array of Y values contained in a theory object.
     *
     * @param  theory  theory
     * @return Y values
     */
    public static double[] getYs( Theory theory ) {
        int np = theory.getCount();
        double[] ys = new double[ np ];
        for ( int i = 0; i < np; i++ ) {
            ys[ i ] = theory.getY( i );
        }
        return ys;
    }

    /**
     * Returns the index of the point with the smallest X coordinate which
     * is greater than or equal to a given value.
     * Therefore if <code>x</code> is smaller than the lowest X value
     * defined by <code>theory</code>, the result is 0.
     * As a special case, if <code>x</code> is greater than the highest
     * X value, the result is <code>theory.getCount()</code>.
     *
     * @param   theory   theoretical curve description
     * @param   x  X coordinate to locate
     * @return   smallest index <code>i</code>
     *           for which <code>theory.getX(i)&gt;=x</code>,
     *           or <code>theory.getCount()</code> if there is none
     * @throws  IllegalArgumentException  if <code>x</code> is NaN
     * @see   fit.framework.Theory#getFloorIndex
     */
    public static int getCeilingIndex( Theory theory, double x ) {
        if ( Double.isNaN( x ) ) {
            throw new IllegalArgumentException( "x NaN" );
        }
        int ifloor = theory.getFloorIndex( x );
        if ( ifloor == -1 ) {
            return 0;
        }
        else if ( theory.getX( ifloor ) == x ) {
            return ifloor;
        }
        else {
            return ifloor + 1;
        }
    }

    /**
     * Creates a new Theory object from a table.
     * The columns containing X and Y values must be specified.
     * Non-numeric cell values are converted to NaN.
     *
     * @param  name   theory name
     * @param  table  table containing data
     * @param  xCol   index of column in <code>table</code> containing X values
     * @param  yCol   index of column in <code>table</code> containing Y values
     * @return  new theory built from the table rows
     * @throws  IllegalArgumentException  if either column's content class
     *          is not a subclass of <code>Number</code>
     * @throws  IOException  if there is an error reading the table
     */
    public static Theory createTheory( String name, StarTable table,
                                       int xCol, int yCol )
            throws IOException {
        if ( ! Number.class.isAssignableFrom( table.getColumnInfo( xCol )
                                             .getContentClass() ) ||
             ! Number.class.isAssignableFrom( table.getColumnInfo( yCol )
                                             .getContentClass() ) ) {
            throw new IllegalArgumentException( "Not numeric columns" );
        }
        TheoryFactory factory = new TheoryFactory();
        RowSequence rseq = table.getRowSequence();
        try {
            while ( rseq.next() ) {
                factory.addPoint( getNumber( rseq.getCell( xCol ) ),
                                  getNumber( rseq.getCell( yCol ) ) );
            }
        }
        finally {
            rseq.close();
        }
        return factory.createTheory( name );
    }

    /**
     * Returns a new SampleComparator based on a given set of theories and
     * of observations.
     * The theories are optionally [red]shifted by a given factor to
     * match the observations.
     *
     * @param   theorySet   array of theoretical function descriptions
     * @param   obsSet  set of observations
     * @param   xfactor  factor by which the X coordinates of the observations
     *          are multiplied relative to the supplied theories - 1+z if
     *          x is wavelength and observations are redshifted
     * @return  comparator for the (possibly X-shifted) theories sampled
     *          at the X positions of <code>obsSet</code>
     */
    public static SampleComparator createComparator( TheorySet theorySet,
                                                     ObservationSet obsSet,
                                                     double xfactor ) {
        double[] xs = getXs( obsSet );
        Interpolator[] interps = getInterpolators( obsSet );

        /* Only wrap the theories if a non-trivial shift is required. */
        if ( xfactor != 1.0 ) {
            int ntheory = theorySet.getTheoryCount();
            Theory[] shiftedTheories = new Theory[ ntheory ];
            for ( int ith = 0; ith < ntheory; ith++ ) {
                shiftedTheories[ ith ] =
                    new XFactorTheory( theorySet.getTheory( ith ), xfactor );
            }
            ArraysTheorySet shiftedTheorySet =
                new ArraysTheorySet( shiftedTheories,
                                     theorySet.getMetadataX(),
                                     theorySet.getMetadataY(),
                                     theorySet.getMetadataTable() );

            /* Register each original (unshifted) theory against its index
             * in the shifted set. */
            for ( int ith = 0; ith < ntheory; ith++ ) {
                shiftedTheorySet.registerTheory( theorySet.getTheory( ith ),
                                                 ith );
            }
            theorySet = shiftedTheorySet;
        }
        return new SampleComparator( theorySet, xs, interps );
    }

    /**
     * Concatenates an array of TheorySet objects to produce a single
     * TheorySet.
     * If the X or Y metadata items of the inputs are not mutually
     * compatible, a warning is logged and the corresponding output item
     * retains only the name of the first input's item.
     * If the inputs' metadata tables are not mutually compatible
     * (see {@link #isCompatible(MetadataTable,MetadataTable)}), the
     * output has no metadata table.
     *
     * @param   theorySets  array of input sets
     * @return  one big theory set
     * @throws  IllegalArgumentException  if <code>theorySets</code> is empty
     */
    public static TheorySet concatTheorySets( TheorySet[] theorySets ) {
        int nset = theorySets.length;
        if ( nset == 0 ) {
            throw new IllegalArgumentException( "No theory sets to concatenate" );
        }

        /* Concatenate theories into one array. */
        List theoryList = new ArrayList();
        for ( int iset = 0; iset < nset; iset++ ) {
            TheorySet theorySet = theorySets[ iset ];
            int ntheory = theorySet.getTheoryCount();
            for ( int ith = 0; ith < ntheory; ith++ ) {
                theoryList.add( theorySet.getTheory( ith ) );
            }
        }
        Theory[] theories = (Theory[]) theoryList.toArray( new Theory[ 0 ] );

        /* Work out if the X and Y metadata from the theory sets are
         * compatible. */
        MetadataItem xMeta = theorySets[ 0 ].getMetadataX();
        MetadataItem yMeta = theorySets[ 0 ].getMetadataY();
        boolean xCompatible = true;
        boolean yCompatible = true;
        for ( int iset = 1; iset < nset; iset++ ) {
            TheorySet tset = theorySets[ iset ];
            xCompatible = xCompatible
                       && xMeta.isCompatible( tset.getMetadataX() );
            yCompatible = yCompatible
                       && yMeta.isCompatible( tset.getMetadataY() );
        }
        if ( ! xCompatible ) {
            logger_.warning( "Incompatible quantities on X axis?" );
            xMeta = new MetadataItem( xMeta.getName(), null );
        }
        if ( ! yCompatible ) {
            logger_.warning( "Incompatible quantities on Y axis?" );
            yMeta = new MetadataItem( yMeta.getName(), null );
        }

        /* Work out if the metadata tables from the theory sets are compatible
         * with each other (same column names and types).  The first table
         * must be present; isCompatible rejects any later null table. */
        MetadataTable meta0 = theorySets[ 0 ].getMetadataTable();
        boolean tableCompatible = meta0 != null;
        for ( int iset = 1; tableCompatible && iset < nset; iset++ ) {
            tableCompatible =
                isCompatible( meta0, theorySets[ iset ].getMetadataTable() );
        }

        /* Come up with a metadata table for the concatenated theory set.
         * If the input metadata tables are not compatible with each other,
         * just use null. */
        MetadataTable metaTable;
        if ( tableCompatible ) {
            StarTable[] starTables = new StarTable[ nset ];
            for ( int iset = 0; iset < nset; iset++ ) {
                starTables[ iset ] =
                    getStarTable( theorySets[ iset ].getMetadataTable() );
            }
            try {
                metaTable =
                    new StarMetadataTable( new ConcatStarTable( starTables[ 0 ],
                                                                starTables ) );
            }
            catch ( IOException e ) {
                logger_.log( Level.WARNING,
                             "Error constructing combined metadata", e );
                metaTable = null;
            }
        }
        else {
            metaTable = null;
        }

        /* Return new theory set. */
        return new ArraysTheorySet( theories, xMeta, yMeta, metaTable );
    }

    /**
     * Returns the index of the first column in a metadata table whose
     * name is as given.  If no column has that name, -1 is returned.
     *
     * @param   metaTable  table
     * @param   colName    column name
     * @return  index of table column with name <code>colName</code>, or -1
     */
    public static int getColumnIndex( MetadataTable metaTable,
                                      String colName ) {
        for ( int icol = 0; icol < metaTable.getColumnCount(); icol++ ) {
            if ( colName.equals( metaTable.getColumnName( icol ) ) ) {
                return icol;
            }
        }
        return -1;
    }

    /**
     * Determines whether two metadata tables are compatible with each other.
     * They are judged compatible if they have the same number of columns
     * and each of their columns has the same name and class as the
     * corresponding column in the other.
     *
     * @param   meta1  first object to test
     * @param   meta2  second object to test
     * @return   true iff <code>meta1</code> and <code>meta2</code> are both
     *           non-null and have compatible columns
     */
    public static boolean isCompatible( MetadataTable meta1,
                                        MetadataTable meta2 ) {
        if ( meta1 == null || meta2 == null ) {
            return false;
        }
        int ncol = meta1.getColumnCount();
        if ( meta2.getColumnCount() != ncol ) {
            return false;
        }
        for ( int icol = 0; icol < ncol; icol++ ) {
            if ( ! meta1.getColumnName( icol )
                        .equals( meta2.getColumnName( icol ) ) ) {
                return false;
            }
            if ( ! meta1.getColumnClass( icol )
                        .equals( meta2.getColumnClass( icol ) ) ) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns a StarTable representing the data from an ObservationSet.
     * The table has one row per observation, and for each point
     * (numbered from 1) a <code>POINT_n</code> column holding the Y value
     * and an <code>ERR_n</code> column holding the Y error.
     * If the observation set has a metadata table, its columns are
     * joined on to the right.
     *
     * @param   obsSet  observation set
     * @return  StarTable
     * @throws  IOException  if the table cannot be constructed
     */
    public static StarTable getObservationSetTable(
            final ObservationSet obsSet ) throws IOException {
        int nobs = obsSet.getObsCount();
        int npoint = obsSet.getPointCount();
        ColumnStarTable st = ColumnStarTable.makeTableWithRows( nobs );
        for ( int ip = 0; ip < npoint; ip++ ) {
            final int ip0 = ip;
            int ip1 = ip + 1;
            ValueInfo yInfo =
                new DefaultValueInfo( "POINT_" + ip1, Double.class );
            ValueInfo yerrInfo =
                new DefaultValueInfo( "ERR_" + ip1, Double.class );
            st.addColumn( new ColumnData( yInfo ) {
                public Object readValue( long irow ) {
                    return new Double( obsSet.getY( ip0, (int) irow ) );
                }
            } );
            st.addColumn( new ColumnData( yerrInfo ) {
                public Object readValue( long irow ) {
                    return new Double( obsSet.getYError( ip0, (int) irow ) );
                }
            } );
        }
        MetadataTable meta = obsSet.getMetadataTable();
        return meta == null
             ? (StarTable) st
             : (StarTable) new JoinStarTable( new StarTable[]
                                              { st, getStarTable( meta ), } );
    }

    /**
     * Returns a <code>StarTable</code> representation of a given MetadataTable.
     * If <code>metaTable</code> is built on top of a <code>StarTable</code>
     * then the underlying StarTable will be returned, otherwise a new
     * adaptor will be returned.
     *
     * @param  metaTable  metadata table
     * @return  StarTable  with the same content
     */
    public static StarTable getStarTable( MetadataTable metaTable ) {
        return metaTable instanceof StarMetadataTable
             ? ((StarMetadataTable) metaTable).getStarTable()
             : new MetaStarTable( metaTable );
    }

    /**
     * Returns a double precision value based on an object.
     *
     * @param   value  should normally be a Number object
     * @return  numeric value of <code>value</code> if it is a
     *          <code>Number</code>, otherwise <code>NaN</code>
     */
    public static double getNumber( Object value ) {
        return value instanceof Number ? ((Number) value).doubleValue()
                                       : Double.NaN;
    }

    /**
     * Returns a general purpose interpolator factory.
     *
     * @return  interpolator factory
     */
    public static InterpolatorFactory defaultInterpolatorFactory() {
        return DEFAULT_INTERPOLATOR_FACTORY;
    }

    /**
     * Converts from flux per Angstrom to flux per Hz, using
     * f_nu = f_lambda * lambda^2 / c with lambda in metres.
     *
     * @param   perA  flux per Angstrom
     * @param   lambdaA  wavelength in Angstrom
     * @return  flux per Hz
     */
    public static double perA2perHz( double perA, double lambdaA ) {

        /* perA * 1e10 is flux per metre; lambdaA * 1e-10 is lambda in m. */
        return perA * 1e10 / C * ( lambdaA * 1e-10 ) * ( lambdaA * 1e-10 );
    }

    /**
     * Performs the actual fit and dumps the results into the supplied
     * result sinks.  The sinks are not closed by this routine.
     * A comparator is only rebuilt when the X factor differs from that of
     * the previous observation.
     *
     * @param   theorySet  model data set
     * @param   obsSet    observational data set
     * @param   preScale  whether to prescale the model data to get the best
     *           fit
     * @param   fitCalc   calculator which measures how well curves match
     * @param   sinks    array of result destination handlers
     * @throws  IOException  presumably propagated from a result sink
     */
    public static void doFitting( TheorySet theorySet, ObservationSet obsSet,
                                  boolean preScale, FitCalculator fitCalc,
                                  ResultSink[] sinks )
            throws IOException {
        int nobs = obsSet.getObsCount();

        /* NaN never compares equal, so the first observation always
         * triggers comparator creation. */
        double xFactor = Double.NaN;
        SampleComparator comparator = null;

        /* Loop over each observation in the input set. */
        for ( int iobs = 0; iobs < nobs; iobs++ ) {

            /* Ensure that we have a suitable comparator. */
            if ( obsSet.getXFactor( iobs ) != xFactor ) {
                xFactor = obsSet.getXFactor( iobs );
                comparator = createComparator( theorySet, obsSet, xFactor );
            }

            /* Perform comparisons of the observation with each of the
             * theories in the TheorySet. */
            Comparison[] comparisons =
                comparator.fitData( getYs( obsSet, iobs ),
                                    getYErrors( obsSet, iobs ),
                                    fitCalc, preScale );

            /* Pass the resulting comparison set to each of the result
             * handlers.  Use the TheorySet from the comparator here
             * rather than the original, since it may be altered. */
            TheorySet tset = comparator.getTheorySet();
            for ( int is = 0; is < sinks.length; is++ ) {
                sinks[ is ].addResult( obsSet, iobs, comparisons, tset );
            }
        }
    }
}
