package fit.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import uk.ac.starlink.table.ColumnInfo;
import uk.ac.starlink.table.RowSequence;
import uk.ac.starlink.table.StarTable;
import uk.ac.starlink.table.Tables;
import uk.ac.starlink.util.DoubleList;
import fit.framework.Interpolator;
import fit.framework.MetadataItem;
import fit.framework.ObservationSet;

/**
 * Produces an ObservationSet from data in a StarTable.
 * The table must be arranged so that different columns represent the
 * Y values (and perhaps errors) of observations at different X values,
 * and each row represents a different observation.
 * An instance of this class can be used to produce a single ObservationSet
 * from a single table.  To use it, make repeated calls to
 * {@link #definePoint} to describe which columns of the table represent
 * which X values, then call {@link #createObservationSet}.
 *
 * <p>The Y values are typically photometric magnitudes, so for a table
 * which looks like this (2MASS):
 * <pre>
 *                ID       J  J_ERR      H  H_ERR      K  K_ERR
 *                --       -  -----      -  -----      -  -----
 *     115918+294553   15.10   0.07  14.43   0.09  13.94   0.10
 *     083121-542124   15.38   0.14  14.50   0.14  13.79   0.13
 *     113713+152557   12.91   0.04  12.22   0.06  12.02   0.08
 * </pre>
 * You would do something like:
 * <pre>
 *    TableObservationSetFactory fact =
 *        new TableObservationSetFactory( table, metaX, metaY );
 *    fact.definePoint( ... );
 *    fact.definePoint( ... );
 *    ObservationSet obsSet = fact.createObservationSet();
 * </pre>
 *
 * @author   Mark Taylor
 * @since    23 Oct 2006
 */
public class TableObservationSetFactory {

    /** Value stored in the column map for names shared by several columns. */
    private static final int AMBIGUOUS_COLUMN = -1;

    private final StarTable table_;
    private final MetadataItem metaX_;
    private final MetadataItem metaY_;
    private final Map colMap_;
    private final List pointDefList_;
    private int zCol_ = -1;
    private InterpolatorFactory interpFact_;

    /**
     * Constructor.
     *
     * <p>Column names are matched case-insensitively.  If several columns
     * share the same name under this matching, that name is recorded as
     * ambiguous, and any later attempt to refer to it will fail with an
     * <code>IllegalArgumentException</code>.
     *
     * @param   table   table containing observation data
     * @param   metaX   metadata for the X values, passed on to the
     *                  observation set produced by this factory
     * @param   metaY   metadata for the Y values, passed on to the
     *                  observation set produced by this factory
     */
    public TableObservationSetFactory( StarTable table, MetadataItem metaX,
                                       MetadataItem metaY ) {
        table_ = table;
        metaX_ = metaX;
        metaY_ = metaY;
        pointDefList_ = new ArrayList();
        colMap_ = new HashMap();
        int ncol = table_.getColumnCount();
        for ( int icol = 0; icol < ncol; icol++ ) {
            ColumnInfo info = table_.getColumnInfo( icol );
            String key = normalise( info.getName() );

            /* A null key denotes "no column" (see normalise), so a column
             * whose name normalises to null cannot be referenced by name. */
            if ( key != null ) {
                int value = colMap_.containsKey( key ) ? AMBIGUOUS_COLUMN
                                                       : icol;
                colMap_.put( key, new Integer( value ) );
            }
        }
    }

    /**
     * Defines one point (X value) in the output ObservationSet by
     * giving its nominal X value, X error, interpolator, and the 
     * names of the columns giving the Y value and Y error for that point.
     * The <code>yerrColName</code> parameter may be <code>null</code>
     * (or the string "null") if no Y error is available.
     *
     * @param   x   X value
     * @param   xerr  X error
     * @param   yColName  name of the table column giving Y values
     * @param   yerrColName  name of the table column giving Y errors, or null
     * @param   interp  interpolator for interpolating in theoretical functions
     * @throws  IllegalArgumentException  if a named column does not exist
     *          or its name is shared by more than one column
     */
    public void definePoint( double x, double xerr, String yColName,
                             String yerrColName, Interpolator interp ) {
        int yCol = getColumnIndex( yColName );
        int yerrCol = normalise( yerrColName ) == null
                    ? -1
                    : getColumnIndex( yerrColName );
        pointDefList_.add( new PointDef( x, xerr, yCol, yerrCol, interp ) );
    }

    /**
     * Defines one point (X value) in the output ObservationSet by
     * giving its nominal X value, X error, and the names of the
     * columns giving the Y value and Y error for that point.
     * The interpolation scheme is obtained using the currently
     * installed {@link InterpolatorFactory}.
     *
     * @param   x   X value
     * @param   xerr  X error
     * @param   yColName  name of the table column giving Y values
     * @param   yerrColName  name of the table column giving Y errors, or null
     * @throws  IllegalStateException  if no interpolator factory has been set
     * @throws  IllegalArgumentException  if a named column does not exist
     *          or its name is shared by more than one column
     */
    public void definePoint( double x, double xerr, String yColName,
                             String yerrColName ) {
        if ( interpFact_ == null ) {
            throw new IllegalStateException( "No interpolator factory set" );
        }
        definePoint( x, xerr, yColName, yerrColName,
                     interpFact_.getInterpolator( x, xerr ) );
    }

    /**
     * Returns the interpolator factory used by this object.
     *
     * @return  interpolator factory, or null if none has been set
     */
    public InterpolatorFactory getInterpolatorFactory() {
        return interpFact_;
    }

    /**
     * Sets the interpolator factory used by this object.
     * It is used by the four-argument
     * {@link #definePoint(double,double,String,String)} method.
     *
     * @param  interpFact  new interpolator factory
     */
    public void setInterpolatorFactory( InterpolatorFactory interpFact ) {
        interpFact_ = interpFact;
    }

    /**
     * Marks a table column as containing a redshift value.
     *
     * @param  zColName  redshift column name
     * @throws  IllegalArgumentException  if the column does not exist
     *          or its name is shared by more than one column
     */
    public void defineRedshift( String zColName ) {
        zCol_ = getColumnIndex( zColName );
    }

    /**
     * Returns a new ObservationSet based on this factory's table and
     * the {@link #definePoint} calls to date.
     * Points are sorted into XValue order before use.
     *
     * @return   new observation set
     * @throws   IOException  if there is an error reading the table data
     */
    public ObservationSet createObservationSet() throws IOException {
        Collections.sort( pointDefList_ );
        PointDef[] pointDefs =
            (PointDef[]) pointDefList_.toArray( new PointDef[ 0 ] );
        int npoint = pointDefs.length;
        boolean hasRedshift = zCol_ >= 0;

        /* Per-point constants, plus growable per-point value lists.
         * The row count may be unknown (negative), so only presize
         * the lists when it is known. */
        double[] xs = new double[ npoint ];
        double[] xerrs = new double[ npoint ];
        Interpolator[] interps = new Interpolator[ npoint ];
        DoubleList[] yLists = new DoubleList[ npoint ];
        DoubleList[] yerrLists = new DoubleList[ npoint ];
        DoubleList xfactorList = new DoubleList();
        int nrow = Tables.checkedLongToInt( table_.getRowCount() );
        for ( int ip = 0; ip < npoint; ip++ ) {
            PointDef pointDef = pointDefs[ ip ];
            xs[ ip ] = pointDef.x_;
            xerrs[ ip ] = pointDef.xerr_;
            interps[ ip ] = pointDef.interp_;
            yLists[ ip ] = nrow > 0 ? new DoubleList( nrow )
                                    : new DoubleList();
            yerrLists[ ip ] = nrow > 0 ? new DoubleList( nrow )
                                       : new DoubleList();
        }

        /* Read the table once, one observation per row. */
        RowSequence rseq = table_.getRowSequence();
        int nobs = 0;
        try {
            while ( rseq.next() ) {
                nobs++;
                Object[] row = rseq.getRow();
                for ( int ip = 0; ip < npoint; ip++ ) {
                    PointDef pointDef = pointDefs[ ip ];
                    yLists[ ip ]
                        .add( FitUtils.getNumber( row[ pointDef.yCol_ ] ) );
                    yerrLists[ ip ]
                        .add( pointDef.yerrCol_ >= 0
                              ? FitUtils.getNumber( row[ pointDef.yerrCol_ ] )
                              : Double.NaN );
                }
                if ( hasRedshift ) {

                    /* X factor is 1 + z; a factor that is not positive
                     * (including NaN) is replaced by 1. */
                    double xfactor = 1.0 + FitUtils.getNumber( row[ zCol_ ] );
                    xfactorList.add( xfactor > 0.0 ? xfactor : 1.0 );
                }
            }
        }
        finally {
            rseq.close();
        }

        /* Each row array is replaced wholesale, so only allocate the
         * outer arrays here. */
        double[][] ys = new double[ npoint ][];
        double[][] yerrs = new double[ npoint ][];
        for ( int ip = 0; ip < npoint; ip++ ) {
            ys[ ip ] = yLists[ ip ].toDoubleArray();
            yerrs[ ip ] = yerrLists[ ip ].toDoubleArray();
        }
        double[] xfactors = hasRedshift ? xfactorList.toDoubleArray() : null;
        return new ArraysObservationSet( npoint, nobs, xs, xerrs, interps,
                                         ys, yerrs, xfactors, metaX_, metaY_,
                                         new StarMetadataTable( table_ ) );
    }

    /**
     * Returns the index of the unique table column with a given name.
     *
     * @param  colName  column name, matched after normalisation
     * @return  column index (non-negative)
     * @throws  IllegalArgumentException  if no column has the name, or
     *          if more than one column has it
     */
    private int getColumnIndex( String colName ) {
        String key = normalise( colName );
        Integer icol = key == null ? null : (Integer) colMap_.get( key );
        if ( icol == null ) {
            throw new IllegalArgumentException( "No column " + colName );
        }
        if ( icol.intValue() == AMBIGUOUS_COLUMN ) {
            throw new IllegalArgumentException( "Ambiguous column name "
                                              + colName );
        }
        return icol.intValue();
    }

    /**
     * Normalises a column name so that it matches other normalised ones
     * correctly (using <code>equals</code>).  Currently folds case.
     * Case folding uses a fixed locale so that matching does not depend
     * on the default locale (e.g. Turkish dotless i).
     * A null name, or the string "null" in any case, normalises to null.
     *
     * @param  colName   raw key value
     * @return   normalised key value, or null
     */
    private String normalise( String colName ) {
        if ( colName == null || colName.equalsIgnoreCase( "null" ) ) {
            return null;
        }
        else {
            return colName.toLowerCase( Locale.ENGLISH );
        }
    }

    /**
     * Private helper struct-type class which aggregates the items needed
     * to define a point for this factory.
     * Column indices are table column indices; a negative
     * <code>yerrCol_</code> means there is no Y error column.
     */
    private static class PointDef extends XValue {
        final double x_;
        final double xerr_;
        final int yCol_;
        final int yerrCol_;
        final Interpolator interp_;
        PointDef( double x, double xerr, int yCol, int yerrCol,
                  Interpolator interp ) {
            super( x );
            x_ = x;
            xerr_ = xerr;
            yCol_ = yCol;
            yerrCol_ = yerrCol;
            interp_ = interp;
        }
    }
}
