Least Squares Fit to a Polynomial
JavaTech

In the previous section we discussed the least squares fit algorithm for finding the optimum straight line through a set of data points. Fitting to a polynomial follows the same basic approach: the sum of the squared residuals, each weighted by the inverse square of its measurement error, is minimized with respect to each coefficient of the polynomial. For a polynomial with k coefficients this leads to a set of k linear equations in the k unknown coefficients. The equations can be expressed as a single matrix equation, which is then solved with standard matrix techniques. These techniques are covered in many numerical analysis texts (see references below), so we won't go into the details here.
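In the notation of the FitPoly code below (a_k for the n_k coefficients, sigma_i for the y error on point i, N for the number of points), the quantity minimized and the resulting linear equations can be sketched as follows. The symbols alpha and beta mirror the arrays of the same names in fit(); when no error is supplied for a point, fit() uses a weight of 1 in place of 1/sigma_i^2.

```latex
\chi^2 = \sum_{i=0}^{N-1} \frac{1}{\sigma_i^2}\Big(y_i - \sum_{j=0}^{n_k-1} a_j x_i^{\,j}\Big)^2

\frac{\partial \chi^2}{\partial a_k} = 0
\;\Longrightarrow\;
\sum_{j=0}^{n_k-1} \alpha_{kj}\, a_j = \beta_k ,
\qquad
\alpha_{kj} = \sum_{i=0}^{N-1} \frac{x_i^{\,k+j}}{\sigma_i^2},
\qquad
\beta_k = \sum_{i=0}^{N-1} \frac{y_i\, x_i^{\,k}}{\sigma_i^2}

C = \alpha^{-1}, \qquad \sigma(a_k) = \sqrt{C_{kk}}
```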

Rather than writing a matrix package from scratch, we can take advantage of the many Java math packages now available (see references below). Here we use JAMA: A Java Matrix Package, a free, open source set of mathematical classes. It provides a Matrix class, LU and QR decomposition classes, and several other useful classes. The decomposition classes provide solve() methods; the QR version returns a least squares solution.

In the FitPoly.java class shown below, the static method fit() receives in its argument list the array for the coefficients and their errors, the arrays of x and y point coordinates, the corresponding x and y errors, and the number of points to fit:

  public static void fit(double [] parameters, double [] x, double [] y,
                       double [] sigmaX, double [] sigmaY, int numPoints){

For example, a quadratic polynomial would correspond to

   f(x[i]) = parameters[0] + parameters[1]*x[i] + parameters[2]*x[i]*x[i]
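For a polynomial with any number of coefficients, f(x) can be evaluated without computing each power separately by using Horner's rule. A minimal sketch, assuming (as fit() does) that the nk coefficients occupy the first nk entries of the parameters array, lowest order first; the class PolyEval is a hypothetical name, not part of the JavaTech code:

```java
/** Sketch: evaluate f(x) = p[0] + p[1]*x + ... + p[nk-1]*x^(nk-1) by Horner's rule. */
final class PolyEval {

    /**
     * Evaluate the polynomial whose nk coefficients occupy the first nk
     * entries of params (lowest order first). Any entries after index
     * nk-1 (for example, the fit errors) are ignored.
     */
    static double eval(double[] params, int nk, double x) {
        double y = 0.0;
        // Work from the highest-order coefficient down:
        // y = (...(p[nk-1] * x + p[nk-2]) * x + ...) * x + p[0]
        for (int k = nk - 1; k >= 0; k--) {
            y = y * x + params[k];
        }
        return y;
    }

    public static void main(String[] args) {
        // Quadratic 1 + 2x + 3x^2; the last three slots would hold the errors.
        double[] params = {1.0, 2.0, 3.0, 0.0, 0.0, 0.0};
        System.out.println(eval(params, params.length / 2, 2.0)); // prints 17.0
    }
}
```

Horner's rule needs only nk - 1 multiplications and additions per point, whatever the degree.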

The least squares method minimizes

   chi2 = sum over i of (f(x[i]) - y[i])^2 / (sigmaY[i]*sigmaY[i])
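The quantity being minimized can be computed directly, which is handy for checking the quality of a fit. This sketch assumes the same conventions as fit(): lowest-order coefficient first, and a null sigmaY array or a zero error gives a point unit weight. ChiSquare and its methods are hypothetical names.

```java
/** Sketch: weighted sum of squared residuals for a polynomial fit. */
final class ChiSquare {

    /** Evaluate p[0] + p[1]*x + ... + p[nk-1]*x^(nk-1) by Horner's rule. */
    static double poly(double[] p, int nk, double x) {
        double y = 0.0;
        for (int k = nk - 1; k >= 0; k--) y = y * x + p[k];
        return y;
    }

    /**
     * chi2 = sum over i of (f(x[i]) - y[i])^2 / sigmaY[i]^2.
     * As in FitPoly.fit(), a null sigmaY array or a zero error
     * gives that point a weight of 1.
     */
    static double chiSquare(double[] p, int nk, double[] x, double[] y,
                            double[] sigmaY, int numPoints) {
        double sum = 0.0;
        for (int i = 0; i < numPoints; i++) {
            double r = poly(p, nk, x[i]) - y[i];
            double term = r * r;
            if (sigmaY != null && sigmaY[i] != 0.0) term /= sigmaY[i] * sigmaY[i];
            sum += term;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] p = {1.0, 0.0, 1.0};          // f(x) = 1 + x^2
        double[] x = {0.0, 1.0, 2.0};
        double[] y = {1.0, 2.5, 5.0};          // middle point is off by 0.5
        double[] s = {0.5, 0.5, 0.5};
        System.out.println(chiSquare(p, 3, x, y, s, x.length)); // prints 1.0
    }
}
```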

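The errors on the coefficients are returned in the upper half of the parameters array. For a quadratic, parameters[3] is the error on the coefficient parameters[0], parameters[4] is the error on parameters[1], and parameters[5] is the error on parameters[2]. In general, with nk coefficients, parameters[k + nk] holds the error on parameters[k].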
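A small helper makes this layout explicit. The sketch below is illustrative only: FitResult and its methods are hypothetical names, and the numbers in main() are made-up values, not fit results.

```java
import java.util.Locale;

/** Sketch: split fit()'s output array into coefficients and their errors. */
final class FitResult {

    /** Number of coefficients nk; the array holds nk values followed by nk errors. */
    static int numCoefficients(double[] parameters) {
        if (parameters.length % 2 != 0) {
            throw new IllegalArgumentException("length must be even: nk values + nk errors");
        }
        return parameters.length / 2;
    }

    /** Coefficient a_k, stored in the lower half. */
    static double value(double[] parameters, int k) {
        return parameters[k];
    }

    /** Error on a_k, stored nk slots later in the upper half. */
    static double error(double[] parameters, int k) {
        return parameters[k + numCoefficients(parameters)];
    }

    /** Format a_k as "a<k> = value +/- error". */
    static String describe(double[] parameters, int k) {
        return String.format(Locale.ROOT, "a%d = %.3f +/- %.3f",
                             k, value(parameters, k), error(parameters, k));
    }

    public static void main(String[] args) {
        // Quadratic: three coefficients, then their three errors.
        double[] parameters = {1.5, -0.25, 0.01, 0.2, 0.05, 0.001};
        for (int k = 0; k < numCoefficients(parameters); k++) {
            System.out.println(describe(parameters, k));
        }
    }
}
```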

The first part of the fit() method accumulates the sums that arise from the least squares minimization. These are stored in arrays, from which instances of the JAMA Matrix class are made. The QR decomposition class is then used to solve for the coefficients as follows:

    Matrix alphaMatrix = new Matrix(alpha);
    QRDecomposition alphaQRD = new QRDecomposition(alphaMatrix);
    Matrix betaMatrix = new Matrix(beta,nk);
    Matrix paramMatrix;
    try{
       paramMatrix = alphaQRD.solve(betaMatrix);
    }catch( Exception e){
      System.out.println("QRD solve failed: "+ e);
      return;
    }
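Readers without JAMA can see the same steps in a self-contained form. The sketch below builds the alpha and beta arrays as fit() does, but solves the equations by Gaussian elimination with partial pivoting rather than with JAMA's QR decomposition, so it is a swapped-in technique, not the code used in FitPoly. NormalEquationsFit is a hypothetical name, and solve() overwrites its arguments.

```java
/**
 * Sketch: least squares polynomial fit via the normal equations,
 * solved with Gaussian elimination and partial pivoting
 * (a stand-in for JAMA's QRDecomposition.solve()).
 */
final class NormalEquationsFit {

    /** Returns the nk coefficients a[0..nk-1] that minimize chi2. */
    static double[] fit(int nk, double[] x, double[] y, double[] sigmaY, int numPoints) {
        double[][] alpha = new double[nk][nk];
        double[] beta = new double[nk];
        for (int i = 0; i < numPoints; i++) {
            double w = 1.0;                       // 1/sigma^2, or 1 if no error given
            if (sigmaY != null && sigmaY[i] != 0.0) w = 1.0 / (sigmaY[i] * sigmaY[i]);
            double xk = 1.0;                      // running power x[i]^k
            for (int k = 0; k < nk; k++) {
                double xj = 1.0;                  // running power x[i]^j
                for (int j = 0; j < nk; j++) {
                    alpha[k][j] += w * xk * xj;
                    xj *= x[i];
                }
                beta[k] += w * y[i] * xk;
                xk *= x[i];
            }
        }
        return solve(alpha, beta);
    }

    /** Solve A a = b by Gaussian elimination with partial pivoting (modifies a and b). */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            // Pick the row with the largest pivot to limit round-off.
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            }
            if (a[pivot][col] == 0.0) throw new ArithmeticException("singular matrix");
            double[] tmpRow = a[col]; a[col] = a[pivot]; a[pivot] = tmpRow;
            double tmp = b[col]; b[col] = b[pivot]; b[pivot] = tmp;
            // Eliminate the entries below the pivot.
            for (int r = col + 1; r < n; r++) {
                double f = a[r][col] / a[col][col];
                for (int c = col; c < n; c++) a[r][c] -= f * a[col][c];
                b[r] -= f * b[col];
            }
        }
        // Back substitution.
        double[] sol = new double[n];
        for (int r = n - 1; r >= 0; r--) {
            double s = b[r];
            for (int c = r + 1; c < n; c++) s -= a[r][c] * sol[c];
            sol[r] = s / a[r][r];
        }
        return sol;
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3, 4};
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = 2.0 - 0.5 * x[i] + 0.25 * x[i] * x[i];
        double[] a = fit(3, x, y, null, x.length);
        System.out.printf("a0=%.6f a1=%.6f a2=%.6f%n", a[0], a[1], a[2]);
    }
}
```

With exact quadratic data such as the points in main(), the recovered coefficients should match 2.0, -0.5 and 0.25 to within round-off. For high-degree polynomials or large x values the alpha matrix becomes poorly conditioned whichever solver is used.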

The errors come from the covariance matrix, which is the inverse of alphaMatrix: the error on each coefficient is the square root of the corresponding diagonal element.

    // The inverse provides the covariance matrix.
    Matrix c = alphaMatrix.inverse();

    for(int k=0; k < nk; k++){

      parameters[k] = paramMatrix.get(k,0);

      // The diagonal elements of the covariance matrix are
      // the squares of the parameter errors. Store the errors
      // in the upper half of the parameters array.
      parameters[k+nk] = Math.sqrt(c.get(k,k));
    }
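The covariance step can also be sketched without JAMA. The following uses a plain Gauss-Jordan inversion in place of Matrix.inverse(); CovarianceErrors is a hypothetical class name. The alpha matrix in main() is worked out by hand for a straight-line fit to three points at x = -1, 0, 1, each with sigma = 0.5.

```java
/** Sketch: parameter errors from the inverse of the alpha matrix. */
final class CovarianceErrors {

    /** Invert a square matrix by Gauss-Jordan elimination with partial pivoting. */
    static double[][] invert(double[][] m) {
        int n = m.length;
        // Augment [m | I] on a copy so the caller's matrix is untouched.
        double[][] a = new double[n][2 * n];
        for (int r = 0; r < n; r++) {
            System.arraycopy(m[r], 0, a[r], 0, n);
            a[r][n + r] = 1.0;
        }
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            }
            if (a[pivot][col] == 0.0) throw new ArithmeticException("singular matrix");
            double[] tmp = a[col]; a[col] = a[pivot]; a[pivot] = tmp;
            // Scale the pivot row so the pivot becomes 1.
            double p = a[col][col];
            for (int c = 0; c < 2 * n; c++) a[col][c] /= p;
            // Clear this column in every other row.
            for (int r = 0; r < n; r++) {
                if (r == col) continue;
                double f = a[r][col];
                for (int c = 0; c < 2 * n; c++) a[r][c] -= f * a[col][c];
            }
        }
        double[][] inv = new double[n][n];
        for (int r = 0; r < n; r++) System.arraycopy(a[r], n, inv[r], 0, n);
        return inv;
    }

    /** Errors on the coefficients: square roots of the covariance diagonal. */
    static double[] errors(double[][] alpha) {
        double[][] c = invert(alpha);
        double[] err = new double[c.length];
        for (int k = 0; k < c.length; k++) err[k] = Math.sqrt(c[k][k]);
        return err;
    }

    public static void main(String[] args) {
        // Straight line a0 + a1*x, points x = -1, 0, 1 with sigma = 0.5 (weight 4):
        // alpha = {{sum w, sum w x}, {sum w x, sum w x^2}} = {{12, 0}, {0, 8}}.
        double[][] alpha = {{12.0, 0.0}, {0.0, 8.0}};
        double[] err = errors(alpha);
        System.out.println(err[0] + " " + err[1]);
    }
}
```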

The following applet creates quadratic polynomials with random coefficient values. Points along these curves, smeared with Gaussian noise and given dummy error values, are passed to the FitPoly.fit() method for fitting. The DrawFunction subclass DrawPoly then overlays the fitted curve on the DrawPanel, which also displays the points with the DrawPoints class discussed previously.
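The test data come from genRanCurve() in the listing below. The sketch here strips that idea down to a single method: evaluate a quadratic at each x, add Gaussian noise with standard deviation smear, and assign a dummy error between smear and 2*smear. SmearedQuadratic is a hypothetical name, and the fixed seed is an assumption added so runs are reproducible; the applet itself uses an unseeded java.util.Random.

```java
import java.util.Random;

/** Sketch: points along a quadratic, smeared by Gaussian noise. */
final class SmearedQuadratic {

    /**
     * Fill y[] with a0 + a1*x + a2*x^2 plus Gaussian noise of standard
     * deviation smear, and yErr[] with dummy errors between smear and
     * 2*smear, in the spirit of genRanCurve().
     */
    static void generate(double[] coeffs, double[] x, double[] y, double[] yErr,
                         double smear, Random ran) {
        for (int i = 0; i < x.length; i++) {
            double yTrue = coeffs[0] + coeffs[1] * x[i] + coeffs[2] * x[i] * x[i];
            y[i] = yTrue + smear * ran.nextGaussian();
            yErr[i] = (1.0 + ran.nextDouble()) * smear;
        }
    }

    public static void main(String[] args) {
        double[] x = new double[20];
        for (int i = 1; i < x.length; i++) x[i] = x[i - 1] + 5.0;   // 0, 5, ..., 95
        double[] y = new double[x.length];
        double[] yErr = new double[x.length];
        // Fixed seed so the run is reproducible.
        generate(new double[]{5.0, 0.01, -0.001}, x, y, yErr, 0.5, new Random(42));
        for (int i = 0; i < x.length; i++) {
            System.out.printf("%5.1f %8.3f +/- %.3f%n", x[i], y[i], yErr[i]);
        }
    }
}
```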

The JAMA files reside in the subdirectories Jama/ and Jama/util below the directory holding the applet classes, so they belong to the packages Jama and Jama.util.

PolyFitApplet.java - This program generates points along a quadratic curve and then fits a polynomial to them. A histogram displays the residuals.

+ New classes:
FitPoly.java - Fits points to a polynomial. It uses the open source JAMA package of matrix math classes to do the least squares fit with QR decomposition.

DrawPoly.java - Subclass of DrawFunction that plots a polynomial curve on an instance of DrawPanel.

+ Previous classes:
Ch. 8:Physics: Fit.java
Ch. 6:Tech: DrawFunction.java, DrawPoints.java, DrawPanel.java
Ch. 6:Tech: Histogram.java, HistPanel.java
Ch. 6:Tech: PlotPanel.java, PlotFormat.java

PolyFitApplet.java

import javax.swing.*;
import java.awt.*;
import java.awt.event.*;

/**
  *
  *  This program generates points along a quadratic curve and then
  *  fits a polynomial to them. This simulates track fitting
  *  in a detector.
  *
  *  The number of curves and the SD of the Gaussian smearing
  *  of the track measurements are taken from entries in two
  *  text fields. A histogram holds the residuals.
  *
  *  This program will run either as an applet in a browser or
  *  inside an application frame.
  *
  *  The "Go" button starts the track generation and fitting in a
  *  thread. The "Clear" button clears the histogram.
  *  In standalone mode, the "Exit" button closes the program.
  *  
 **/
public class PolyFitApplet extends JApplet
             implements ActionListener, Runnable
{
  // Use the HistPanel JPanel subclass here
  HistPanel fResidualsPanel;

  // Use a DrawPanel to display the points to fit
  DrawPanel fDrawPanel;

  // The histogram records the differences between the
  // measured points and the fitted curves.
  Histogram fResidualsHist;

  // Use DrawFunction subclasses to plot on the DrawPanel
  DrawFunction [] fDrawFunctions;

  // Set values for the tracks including the default number
  // of tracks to generate, the track area, the SD smearing
  // of the data points, and the x values where the track
  // y coordinates are measured.
  int fNumCurves = 1;
  double fYMin   =   0.0;
  double fYMax   =  10.0;
  double fXMin   =   0.0;
  double fXMax   = 100.0;

  double fCurveSmear = 0.5;

  double [] fX    = new double[20];
  double [] fY    = new double[20];
  double [] fYErr = new double[20];

  // Data array used to pass track points to DrawPoints
  double [][]fData = new double[4][];

  // Random number generator
  java.util.Random fRan;

  // Inputs for the number of tracks to generate
  JTextField fNumCurvesField;
  // and the smearing of the tracking points.
  JTextField fSmearField;

  // Flag for whether the applet is in a browser
  // or running via the main () below.
  boolean fInBrowser=true;

  //Buttons
  JButton fGoButton;
  JButton fClearButton;
  JButton fExitButton;

  // Use thread reference as flag. It is volatile so that the
  // running thread sees the change made by stop ().
  volatile Thread fThread;

  /**
    * Create a User Interface with histograms and buttons to
    * control the program. Two text fields hold the number of tracks
    * to be generated and the measurement smearing.
   **/
  public void init () {

    // Will need random number generator for generating tracks
    // and for smearing the measurement points
    fRan = new java.util.Random ();

    // Create instances of DrawFunction for use in DrawPanel
    // to plot the tracks and the measured points along them.
    fDrawFunctions    = new DrawFunction[2];
    fDrawFunctions[0] = new DrawPoly ();
    fDrawFunctions[1] = new DrawPoints ();

    // Start building the GUI.
    JPanel panel = new JPanel (new GridLayout (2,1));

    // Will plot the tracks on an instance of DrawPanel.
    fDrawPanel =
      new DrawPanel (fYMin,fYMax, fXMin, fXMax,
                    fDrawFunctions);

    fDrawPanel.setTitle ("Fit Points");
    fDrawPanel.setXLabel ("Y vs X");

    // Create the x axis values for the curves.
    double dx = 5.0;
    fX[0] = 0.0;
    for (int i=1; i < 20; i++){
        fX[i] = fX[i-1] + dx;
    }

    panel.add (fDrawPanel);

    // Create histogram to show the quality of the fits.
    fResidualsHist  = new Histogram ("Ydata - Yfit","Residuals", 20, -2,2.);

    // Use another panel to hold the histogram and controls panels.
    JPanel hist_crls_panel = new JPanel (new BorderLayout ());

    // A panel to hold residuals histogram
    fResidualsPanel=new HistPanel (fResidualsHist);

    // Add the panel of histograms to the main panel
    hist_crls_panel.add ("Center",fResidualsPanel);

    // Use a textfield for an input parameter.
    fNumCurvesField =
      new JTextField (Integer.toString (fNumCurves), 10);

    // Use a textfield for an input parameter.
    fSmearField =
      new JTextField (Double.toString (fCurveSmear), 10);

    // If return hit after entering text, the
    // actionPerformed will be invoked.
    fNumCurvesField.addActionListener (this);
    fSmearField.addActionListener (this);

    fGoButton = new JButton ("Go");
    fGoButton.addActionListener (this);

    fClearButton = new JButton ("Clear");
    fClearButton.addActionListener (this);

    fExitButton = new JButton ("Exit");
    fExitButton.addActionListener (this);

    JPanel control_panel = new JPanel (new GridLayout (1,5));

    control_panel.add (fNumCurvesField);
    control_panel.add (fSmearField);
    control_panel.add (fGoButton);
    control_panel.add (fClearButton);
    control_panel.add (fExitButton);

    if (fInBrowser) fExitButton.setEnabled (false);

    hist_crls_panel.add (control_panel,"South");

    panel.add (hist_crls_panel);

    // Add the main panel to the applet
    add (panel);

  } // init

  public void actionPerformed (ActionEvent e) {
    Object source = e.getSource ();
    if ( source == fGoButton || source == fNumCurvesField
                           || source == fSmearField) {
      String strNumDataPoints = fNumCurvesField.getText ();
      String strCurveSmear = fSmearField.getText ();
      try{
          fNumCurves = Integer.parseInt (strNumDataPoints);
          fCurveSmear = Double.parseDouble (strCurveSmear);
      }
      catch (NumberFormatException ex) {
        // Could open an error dialog here but just
        // display a message on the browser status line.
        showStatus ("Bad input value");
        return;
      }

      fGoButton.setEnabled (false);
      fClearButton.setEnabled (false);
      if (fThread != null) stop ();
      fThread = new Thread (this);
      fThread.start ();

    }
    else if ( source == fClearButton) {
        fResidualsHist.clear ();
        repaint ();
    } else if (!fInBrowser)
        System.exit (0);

  } // actionPerformed

  public void stop (){
    // If thread is still running, setting this
    // flag will kill it.
    fThread = null;
  } // stop

  /**  Generate the tracks in a thread. */
  public void run () {

    for (int i=0; i < fNumCurves; i++){
      // Stop if stop () cleared the flag or a newer thread replaced this one.
      if (fThread != Thread.currentThread ()) return;

      // Generate a random track.
      double [] genParams = genRanCurve (fXMax-fXMin,
                  fYMax-fYMin,
                  fX, fY, fYErr, fCurveSmear);

      // Fit points to a quadratic, using the per-point errors in fYErr.
      double [] fitParams = new double[6];
      FitPoly.fit (fitParams, fX, fY, null, fYErr, fX.length);

      // Pass the fitted parameters to the polynomial drawing function.
      fDrawFunctions[0].setParameters (fitParams,null);

      // Pass the data points to the DrawPoints object via
      // the 2-D array.
      fDrawFunctions[1].setParameters (null, fData);

      // Redrawing the panel will cause the paintContents (Graphics g)
      // method in DrawPanel to invoke the draw () method for the line
      // and points drawing functions.
      fDrawPanel.repaint ();

      // Include residuals == difference between the measured value
      // and the fitted value at the points at each x position
      for (int j=0; j < fX.length; j++){
          double yFit = fitParams[0] + fitParams[1]*fX[j]
                                  + fitParams[2]*fX[j]*fX[j];

          fResidualsHist.add (fY[j] - yFit);
      }

      // Pause briefly to let users see the track.
      try{
          Thread.sleep (30);
      }catch (InterruptedException e){}
    }

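    repaint ();

    // Swing components should be updated on the event dispatch thread.
    SwingUtilities.invokeLater (new Runnable () {
      public void run () {
        fGoButton.setEnabled (true);
        fClearButton.setEnabled (true);
      }
    });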
  } // run

  /**
    *  Generate a quadratic plot and obtain points along the curve.
    *  Smear the vertical coordinate with a Gaussian.
   **/
  double [] genRanCurve (double x_range, double y_range,
                   double [] x_curve, double [] y_curve,
                   double [] y_curve_err,
                   double smear){

    // Parameters for a quadratic line.
    double [] quadParam = new double[3];

    // Simulated quadratic
    double y0 = y_range* (0.5 + 0.25 * fRan.nextDouble ());
    double y1 = y_range * fRan.nextDouble ();

    // Choose some dummy parameters for the polynomial
    quadParam[0] = y0;
    quadParam[1] =  (y1-y0)/ (8.0*x_range);
    quadParam[2] =  (fRan.nextDouble () - 0.5)/100.0;

    // Make the points and errors along a quadratic line
    for (int i=0; i < x_curve.length; i++) {
      y_curve[i] = y0 + quadParam[1]*x_curve[i]
                     + quadParam[2]*x_curve[i]*x_curve[i];

      double curve_err = smear*fRan.nextGaussian ();

      // Add smear factor for this point
      y_curve[i] += curve_err;

      // Create a dummy average std.dev. error on the y value
      // for this x position.
      y_curve_err[i] =  (1.0 + fRan.nextDouble () ) * smear;

    }

    // Set up the parameters in the drawing function.

    fDrawFunctions[0].setParameters (quadParam,null);

    // The DrawPoints function needs this data via
    // a 2-D array.
    fData[0] = y_curve;
    fData[1] = x_curve;
    fData[2] = y_curve_err;
    fData[3] = null;

    // Return the track parameters.
    return quadParam;

  } // genRanCurve


  /**
    *  Allow for option of running the program in standalone mode.
    *  Create the applet and add to a frame.
   **/
  public static void main (String[] args) {
    // Frame size in pixels
    int frame_width=450;
    int frame_height=450;

    // Create the applet and initialize it in standalone mode
    PolyFitApplet applet = new PolyFitApplet ();
    applet.fInBrowser = false;
    applet.init ();

    // Close the window and exit the program when the frame is closed
    JFrame f = new JFrame ("Demo");
    f.setDefaultCloseOperation (JFrame.EXIT_ON_CLOSE);

    // Add applet to the frame
    f.getContentPane ().add ( applet);
    f.setSize (new Dimension (frame_width,frame_height));
    f.setVisible (true);
  } // main

} // PolyFitApplet

FitPoly.java

import Jama.*;
import Jama.util.*;

/**
  *  Fit polynomial line to a set of data points.
  *  Extends the Fit class.
 **/
public class FitPoly extends Fit
{
  /**
    *  Use the Least Squares fit method for fitting a
    *  polynomial to 2-D data for measurements
    *  y[i] vs. the independent variable x[i]. This fit assumes
    *  there are errors only on the y measurements, as
    *  given by the sigma_y array.
    *
    *  See, e.g. Press et al., "Numerical Recipes..." for details
    *  of the algorithm.
    *
    *  The solution to the LSQ fit uses the open source JAMA -
    *  "A Java Matrix Package" classes. See http://math.nist.gov/javanumerics/jama/
    *  for description.
    *
    *  @param parameters - first half of the array holds the coefficients for
    *  the polynomial.
    *  The second half holds the errors on the coefficients.
    *  @param x - independent variable
    *  @param y - vertical dependent variable
    *  @param sigma_x - std. dev. error on each x value (not used by this fit)
    *  @param sigma_y - std. dev. error on each y value
    *  @param num_points - number of points to fit. Less than or equal to the
    *  dimension of the x array.
   **/
  public static void fit (double [] parameters, double [] x, double [] y,
                       double [] sigma_x, double [] sigma_y, int num_points){

    // parameters holds nk coefficients followed by the nk errors on them.
    int nk = parameters.length/2;

    double [][] alpha  = new double[nk][nk];
    double [] beta = new double[nk];
    double term = 0;

    for (int k=0; k < nk; k++) {

        // Only need to calculate diagonal and upper half
        // of symmetric matrix.
        for (int j=k; j < nk; j++) {

            // Calc terms over the data points
            term = 0.0;
            alpha[k][j] = 0.0;
            for (int i=0; i < num_points; i++) {

                double prod1 = 1.0;
                // Calculate x^k
                if ( k > 0) for (int m=0; m < k; m++) prod1 *= x[i];

                double prod2 = 1.0;
                // Calculate x^j
                if ( j > 0) for (int m=0; m < j; m++) prod2 *= x[i];

                // Calculate x^k * x^j
                term =  (prod1*prod2);

                if (sigma_y != null && sigma_y[i] != 0.0)
                    term /=  (sigma_y[i]*sigma_y[i]);
                alpha[k][j] += term;
            }
            alpha[j][k] = alpha[k][j]; // alpha is symmetric; its inverse is the covariance matrix.
        }

        for (int i=0; i < num_points; i++) {
            double prod1 = 1.0;
            if (k > 0) for ( int m=0; m < k; m++) prod1 *= x[i];
            term =  (y[i] * prod1);
            if (sigma_y != null  && sigma_y[i] != 0.0)
                term /=  (sigma_y[i]*sigma_y[i]);
            beta[k] +=term;
        }
    }

    // Use the Jama QR Decomposition classes to solve for
    // the parameters.
    Matrix alpha_matrix = new Matrix (alpha);
    QRDecomposition alpha_QRD = new QRDecomposition (alpha_matrix);
    Matrix beta_matrix = new Matrix (beta,nk);
    Matrix param_matrix;
    try {
       param_matrix = alpha_QRD.solve (beta_matrix);
    }
    catch (Exception e) {
      System.out.println ("QRD solve failed: "+ e);
      return;
    }

    // The inverse provides the covariance matrix.
    Matrix c = alpha_matrix.inverse ();

    for (int k=0; k < nk; k++) {

      parameters[k] = param_matrix.get (k,0);

      // The diagonal elements of the covariance matrix are
      // the squares of the parameter errors. Store the errors
      // in the upper half of the parameters array.
      parameters[k+nk] = Math.sqrt (c.get (k,k));
    }

  } // fit

} // FitPoly
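In fit() each alpha[k][j] is built by recomputing x^k and x^j for every point. Since alpha[k][j] depends only on k + j, an alternative (not used in the JavaTech code) is to accumulate the 2nk - 1 weighted power sums S_m = sum of x^m/sigma^2 once and fill the matrix from them. A minimal sketch with the hypothetical class name PowerSums:

```java
import java.util.Arrays;

/** Sketch: build the normal-equation arrays from power sums S_m = sum x^m / sigma^2. */
final class PowerSums {

    /** Fill alpha (nk x nk) and beta (nk) for a polynomial with nk coefficients. */
    static void build(int nk, double[] x, double[] y, double[] sigmaY, int numPoints,
                      double[][] alpha, double[] beta) {
        double[] s = new double[2 * nk - 1];     // S_m for m = 0 .. 2(nk-1)
        Arrays.fill(beta, 0.0);
        for (int i = 0; i < numPoints; i++) {
            double w = 1.0;                      // 1/sigma^2, or 1 if no error given
            if (sigmaY != null && sigmaY[i] != 0.0) w = 1.0 / (sigmaY[i] * sigmaY[i]);
            double xm = 1.0;                     // running power x[i]^m
            for (int m = 0; m < s.length; m++) {
                s[m] += w * xm;
                if (m < nk) beta[m] += w * y[i] * xm;
                xm *= x[i];
            }
        }
        // Each alpha entry depends only on k + j.
        for (int k = 0; k < nk; k++) {
            for (int j = 0; j < nk; j++) alpha[k][j] = s[k + j];
        }
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0};
        double[] y = {2.0, 4.0, 6.0};
        double[][] alpha = new double[2][2];
        double[] beta = new double[2];
        build(2, x, y, null, x.length, alpha, beta);
        // Unit weights: alpha = {{3, 6}, {6, 14}}, beta = {12, 28}.
        System.out.println(Arrays.deepToString(alpha) + " " + Arrays.toString(beta));
    }
}
```

The resulting alpha and beta agree with those built in fit() up to round-off, and could be passed to the same JAMA solve step.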

DrawPoly.java

import java.awt.*;

/**
  *  Draw a polynomial curve onto the PlotPanel. Extends the
  *  DrawFunction class and overrides the draw method.
  *
 **/
public class DrawPoly extends DrawFunction
{

  int [] fXFrame;
  int [] fYFrame;

  /**
    *  Draw the polynomial a0 + a1*x + ... + a5*x^5 (up to degree 5) onto the PlotPanel.
    *
    *  @param g graphics context
    *  @param frame_width display area width in pixels.
    *  @param frame_height display area height in pixels.
    *  @param frame_start_x horizontal point on display where
    *    drawing starts in pixel number.
    *  @param frame_start_y vertical point on display where
    *    drawing starts in pixel number.
    *  @param x_scale array whose first and last elements hold the
    *    lower and upper values of the function input scale range.
    *  @param y_scale array whose first and last elements hold the
    *    lower and upper values of the function output scale range.
   **/
  public void draw (Graphics g,
                   int frame_start_x, int frame_start_y,
                   int frame_width, int frame_height,
                   double [] x_scale, double [] y_scale) {

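    // Check if ready to draw the line
    if (fParameters == null) return;

    // The upper half of fParameters holds the errors.
    int num_params = fParameters.length/2;

    // Limit to polynomials of degree 5 (at most 6 coefficients)
    if (num_params < 1 || num_params > 6) return;

    // Save the color only once we know we will draw, so it is always restored.
    Color save_color = g.getColor ();
    g.setColor (fColor);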

    // Get the number of horizontal pixels.
    int num_points = frame_width;

    // Get conversion factor from data scale to frame pixels
    double y_scale_factor = frame_height/(y_scale[y_scale.length-1] - y_scale[0]);
    double x_scale_factor = frame_width/(x_scale[x_scale.length-1] - x_scale[0]);

    // Create arrays of points for each
    // point of the curve. Recreate if width changes.
    if (fYFrame == null || fYFrame.length != frame_width){
        fYFrame = new int[num_points];
        fXFrame = new int[num_points];
    }

    // Build the polynomial curve from a sequence
    // of short line segments, one per horizontal pixel
    double prod,prod2,prod3,y;
    double x = x_scale[0];
    double del_x = 1/x_scale_factor;

    // Calculate the func = a0 + a1 * x + ...
    for (int i=0; i < num_points; i++) {

        // a0
        y = fParameters[0];
        x += del_x;

        prod2 = x*x;
        prod3 = prod2 * x;

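        // Add the higher-order terms. The cases intentionally
        // fall through so every lower-order term is included.
        switch  (num_params){
           case 6:
              // a5 * x^5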
              y += fParameters[5] * prod2 * prod3;
           case 5:
              // a4 * x^4
              y += fParameters[4] * prod2 * prod2;
           case 4:
              // a3 * x^3
              y += fParameters[3] * prod2 * x;
           case 3:
              // a2 * x^2
              y += fParameters[2] * prod2 ;
           case 2:
              // a1 * x^1
              y += fParameters[1] * x;
        }

        // Convert to pixel coords
        fYFrame[i] = frame_height - 
          (int)((y - y_scale[0]) * y_scale_factor) + frame_start_y;
        fXFrame[i] = frame_start_x + (int)((x - x_scale[0]) * x_scale_factor);
    }

    // Draw the curve as a connected polyline
    g.drawPolyline (fXFrame,fYFrame,num_points);

    g.setColor (save_color);

  } // draw

} // DrawPoly
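The data-to-pixel conversion in draw() can be pulled out into small helpers for testing. This sketch copies the arithmetic from DrawPoly (the vertical axis is flipped because pixel rows increase downward); PixelMap and its methods are hypothetical names, and the minimum and maximum correspond to the first and last elements of x_scale and y_scale.

```java
/** Sketch: DrawPoly's mapping from data coordinates to frame pixels. */
final class PixelMap {

    /** Horizontal pixel for data value x. */
    static int toPixelX(double x, double xMin, double xMax, int frameStartX, int frameWidth) {
        double scale = frameWidth / (xMax - xMin);
        return frameStartX + (int) ((x - xMin) * scale);
    }

    /** Vertical pixel for data value y; pixel rows grow downward, so y is flipped. */
    static int toPixelY(double y, double yMin, double yMax, int frameStartY, int frameHeight) {
        double scale = frameHeight / (yMax - yMin);
        return frameHeight - (int) ((y - yMin) * scale) + frameStartY;
    }

    public static void main(String[] args) {
        // Data range 0..10 vertically in a 200-pixel-high frame starting at row 20.
        System.out.println(toPixelY(0.0, 0.0, 10.0, 20, 200));  // bottom edge: 220
        System.out.println(toPixelY(10.0, 0.0, 10.0, 20, 200)); // top edge: 20
    }
}
```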

 

References & Web Resources

 

Most recent update: Oct. 27, 2005

Java is a trademark of Sun Microsystems, Inc.