In the previous section we discussed
the least squares fit algorithm for finding the optimum straight
line through a set of data points. Fitting to a polynomial follows the
same basic approach: the sum of the squared residuals, each weighted by
the inverse square of its measurement error, is minimized with respect
to each coefficient in the polynomial. Setting each of these k partial
derivatives to zero leads to a set of k linear equations for the k coefficients.
These equations can be expressed as a single matrix equation, which matrix
techniques then solve for the coefficients. The techniques for
solving such systems are described in many numerical analysis
texts (see references below), so we won't go
into the details here.
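For reference, setting the derivative of the weighted sum of squared residuals with respect to each coefficient a_k to zero gives the normal equations below. Here n_k is the number of coefficients (the k of the text, nk in the code), N is the number of points, and σ_{y,i} is the error on the i-th y value; the arrays alpha and beta in the FitPoly listing below accumulate these sums (falling back to unit weights when no error is given). In matrix form this is α a = β, the system that the QR decomposition solves.

$$\sum_{j=0}^{n_k-1} \alpha_{kj}\, a_j = \beta_k, \qquad
\alpha_{kj} = \sum_{i=0}^{N-1} \frac{x_i^{k}\, x_i^{j}}{\sigma_{y,i}^{2}}, \qquad
\beta_k = \sum_{i=0}^{N-1} \frac{y_i\, x_i^{k}}{\sigma_{y,i}^{2}}, \qquad k = 0, \dots, n_k-1$$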
Rather than writing a matrix package from scratch, we can take
advantage of the many Java math packages now available (see references
below). Here we use JAMA: A Java Matrix Package, a free, open source set
of mathematical classes. It provides a matrix class, LU and QR decompositions,
and several other useful classes. The decomposition classes provide
solve() methods; the QR version returns the least squares solution.
In the FitPoly.java
class shown below, the static method fit()
receives in its argument list the array for the coefficients, the
x and y point coordinates, the corresponding errors, and the number
of points to fit:
public static void fit(double [] parameters, double [] x, double [] y,
double [] sigmaX, double [] sigmaY, int numPoints){
For example, a quadratic polynomial would correspond to
f(x[i]) = parameters[0] + parameters[1]*x[i] + parameters[2]*x[i]*x[i]
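To make the array layout concrete, here is a small hypothetical helper (PolyEval is not one of the chapter's classes) that evaluates the polynomial stored in the first half of a FitPoly-style parameters array. It uses Horner's rule rather than explicit powers; that choice is ours, not the original code's.

```java
/** Hypothetical helper: evaluates a polynomial stored FitPoly-style. */
final class PolyEval {

    /**
     * Evaluates a0 + a1*x + ... using the coefficients in the first
     * half of parameters (the second half holds errors and is ignored).
     */
    static double evaluate(double[] parameters, double x) {
        int nk = parameters.length / 2;   // number of coefficients
        double y = 0.0;
        // Horner's rule: (...(a[nk-1]*x + a[nk-2])*x + ...)*x + a[0]
        for (int k = nk - 1; k >= 0; k--) {
            y = y * x + parameters[k];
        }
        return y;
    }

    public static void main(String[] args) {
        // Quadratic 1 + 2x + 3x^2; the last three slots would hold the errors.
        double[] parameters = {1.0, 2.0, 3.0, 0.1, 0.2, 0.3};
        System.out.println(evaluate(parameters, 2.0)); // prints 17.0
    }
}
```

Horner's rule needs one multiply and one add per coefficient and never forms x^k explicitly.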
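The least squares method minimizes the weighted sum of squared residuals
$$\chi^2 = \sum_{i=0}^{N-1} \frac{\left(f(x_i) - y_i\right)^2}{\sigma_{y,i}^2}$$
where N is numPoints and $\sigma_{y,i}$ is the error sigmaY[i] on the i-th y value.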
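A minimal sketch of that sum in Java, assuming the coefficients sit in the first half of the parameters array as described in the text. ChiSquare and chiSquare() are hypothetical names; like the FitPoly listing below, the sketch falls back to a unit weight when sigmaY is null or an entry is zero.

```java
/** Hypothetical helper: the weighted chi-square that a least squares fit minimizes. */
final class ChiSquare {

    static double chiSquare(double[] parameters, double[] x, double[] y,
                            double[] sigmaY, int numPoints) {
        int nk = parameters.length / 2;   // coefficients are in the first half
        double sum = 0.0;
        for (int i = 0; i < numPoints; i++) {
            // f(x[i]) by Horner's rule
            double f = 0.0;
            for (int k = nk - 1; k >= 0; k--) f = f * x[i] + parameters[k];
            double residual = f - y[i];
            // Weight 1/sigma^2, or 1 if no usable error is given.
            double w = (sigmaY != null && sigmaY[i] != 0.0)
                    ? 1.0 / (sigmaY[i] * sigmaY[i]) : 1.0;
            sum += w * residual * residual;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] line = {0.0, 1.0, 0.0, 0.0};   // f(x) = x, plus two error slots
        double[] x = {0.0, 1.0, 2.0};
        double[] y = {0.0, 1.0, 3.0};           // last point is 1 unit off
        double[] sigmaY = {1.0, 1.0, 0.5};
        System.out.println(chiSquare(line, x, y, sigmaY, x.length)); // prints 4.0
    }
}
```

The small error on the last point (0.5) gives its unit residual a weight of 4, which is why the total is 4.0 rather than 1.0.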
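The errors on the coefficients are returned in the upper half of
the parameters array, which must therefore be twice as long as the
number of coefficients (fit() takes nk = parameters.length/2).
For a quadratic, parameters has six elements: parameters[3]
is the error on the coefficient parameters[0],
parameters[4]
is the error on parameters[1],
and so forth.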
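A hypothetical wrapper (FitResult is not part of the chapter's code) that splits such an array into its coefficient and error halves, using nk = parameters.length/2 as fit() does:

```java
import java.util.Arrays;

/** Hypothetical wrapper: splits a FitPoly-style parameters array in two. */
final class FitResult {
    final double[] coefficients;  // parameters[0 .. nk-1]
    final double[] errors;        // parameters[nk .. 2*nk-1]

    FitResult(double[] parameters) {
        int nk = parameters.length / 2;   // same convention as FitPoly.fit()
        coefficients = Arrays.copyOfRange(parameters, 0, nk);
        errors = Arrays.copyOfRange(parameters, nk, 2 * nk);
    }

    public static void main(String[] args) {
        FitResult r = new FitResult(new double[] {1.0, 2.0, 3.0, 0.1, 0.2, 0.3});
        for (int k = 0; k < r.coefficients.length; k++) {
            System.out.println("a" + k + " = " + r.coefficients[k] + " +/- " + r.errors[k]);
        }
    }
}
```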
The first part of the fit()
method computes the sums produced by the least squares method (the
elements of the normal equations). These are held in arrays from which
instances of the JAMA Matrix class are made. The QR decomposition class
is then used to solve for the coefficients:
Matrix alphaMatrix = new Matrix(alpha);
QRDecomposition alphaQRD = new QRDecomposition(alphaMatrix);
// Column vector (nk rows, 1 column) built from the beta array.
Matrix betaMatrix = new Matrix(beta,nk);
Matrix paramMatrix;
try{
paramMatrix = alphaQRD.solve(betaMatrix);
}catch( Exception e){
System.out.println("QRD solve failed: "+ e);
return;
}
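For readers without the Jama jar on hand, the following is a dependency-free sketch of the same normal-equation approach. NormalEquationsFit is a hypothetical class, and it swaps in Gaussian elimination with partial pivoting for JAMA's QR decomposition, so its numerical behavior on ill-conditioned data may differ from FitPoly.fit().

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free sketch of the normal-equation fit.
 * Gaussian elimination with partial pivoting stands in for JAMA's QR solve.
 */
final class NormalEquationsFit {

    /** Returns the nk polynomial coefficients a0..a(nk-1). */
    static double[] fitCoefficients(int nk, double[] x, double[] y,
                                    double[] sigmaY, int numPoints) {
        double[][] alpha = new double[nk][nk];
        double[] beta = new double[nk];
        for (int i = 0; i < numPoints; i++) {
            // Weight 1/sigma^2, or 1 when no usable error is given.
            double w = (sigmaY != null && sigmaY[i] != 0.0)
                    ? 1.0 / (sigmaY[i] * sigmaY[i]) : 1.0;
            double xk = 1.0;                      // x[i]^k
            for (int k = 0; k < nk; k++) {
                double xj = 1.0;                  // x[i]^j
                for (int j = 0; j < nk; j++) {
                    alpha[k][j] += w * xk * xj;
                    xj *= x[i];
                }
                beta[k] += w * y[i] * xk;
                xk *= x[i];
            }
        }
        return solve(alpha, beta);
    }

    /** Solves a * sol = b; a and b are overwritten. */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            // Partial pivoting: bring the largest remaining entry to the diagonal.
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            }
            if (a[pivot][col] == 0.0) throw new ArithmeticException("singular matrix");
            double[] rowTmp = a[col]; a[col] = a[pivot]; a[pivot] = rowTmp;
            double bTmp = b[col]; b[col] = b[pivot]; b[pivot] = bTmp;
            // Eliminate the entries below the pivot.
            for (int r = col + 1; r < n; r++) {
                double factor = a[r][col] / a[col][col];
                for (int c = col; c < n; c++) a[r][c] -= factor * a[col][c];
                b[r] -= factor * b[col];
            }
        }
        // Back substitution on the upper-triangular system.
        double[] sol = new double[n];
        for (int r = n - 1; r >= 0; r--) {
            double s = b[r];
            for (int c = r + 1; c < n; c++) s -= a[r][c] * sol[c];
            sol[r] = s / a[r][r];
        }
        return sol;
    }

    public static void main(String[] args) {
        double[] x = {0.0, 1.0, 2.0, 3.0, 4.0};
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = 1.0 + 2.0 * x[i] + 3.0 * x[i] * x[i];
        // Expect approximately [1.0, 2.0, 3.0], up to rounding.
        System.out.println(Arrays.toString(fitCoefficients(3, x, y, null, x.length)));
    }
}
```

Fitting noise-free quadratic data should recover the generating coefficients up to rounding; degenerate data (for example, all x values equal) makes alpha singular, which the sketch reports with an ArithmeticException much as JAMA's solve() throws for a rank-deficient matrix.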
The errors are the square roots of the diagonal elements of the covariance
matrix, which is the inverse of the alpha matrix:
// The inverse provides the covariance matrix.
Matrix c = alphaMatrix.inverse();
for(int k=0; k < nk; k++){
parameters[k] = paramMatrix.get(k,0);
// Diagonal elements of the covariance matrix provide
// the square of the parameter errors. Put the errors in
// the upper half of the parameters array.
parameters[k+nk] = Math.sqrt(c.get(k,k));
}
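The covariance step can be checked by hand on a small case. The sketch below (hypothetical class CovarianceErrors, using Gauss-Jordan elimination in place of JAMA's Matrix.inverse()) inverts an alpha matrix and takes the square roots of its diagonal. For a straight line through x = 0, 1, 2 with unit errors, alpha = [[3, 3], [3, 5]], whose inverse has diagonal 5/6 and 1/2, so the errors are √(5/6) ≈ 0.913 and √(1/2) ≈ 0.707.

```java
/** Hypothetical check of the covariance step: errors = sqrt(diag(alpha^-1)). */
final class CovarianceErrors {

    /** Gauss-Jordan inverse with partial pivoting; the input is left untouched. */
    static double[][] invert(double[][] m) {
        int n = m.length;
        double[][] a = new double[n][2 * n];          // augmented matrix [m | I]
        for (int r = 0; r < n; r++) {
            System.arraycopy(m[r], 0, a[r], 0, n);
            a[r][n + r] = 1.0;
        }
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            }
            if (a[pivot][col] == 0.0) throw new ArithmeticException("singular matrix");
            double[] tmp = a[col]; a[col] = a[pivot]; a[pivot] = tmp;
            // Scale the pivot row so the pivot becomes 1 ...
            double p = a[col][col];
            for (int c = 0; c < 2 * n; c++) a[col][c] /= p;
            // ... then clear this column in every other row.
            for (int r = 0; r < n; r++) {
                if (r == col) continue;
                double f = a[r][col];
                for (int c = 0; c < 2 * n; c++) a[r][c] -= f * a[col][c];
            }
        }
        double[][] inv = new double[n][n];
        for (int r = 0; r < n; r++) System.arraycopy(a[r], n, inv[r], 0, n);
        return inv;
    }

    /** The parameter errors: square roots of the covariance-matrix diagonal. */
    static double[] errors(double[][] alpha) {
        double[][] c = invert(alpha);
        double[] err = new double[c.length];
        for (int k = 0; k < c.length; k++) err[k] = Math.sqrt(c[k][k]);
        return err;
    }

    public static void main(String[] args) {
        // Straight line through x = 0, 1, 2 with unit errors:
        // alpha = [[sum 1, sum x], [sum x, sum x^2]] = [[3, 3], [3, 5]].
        double[] err = errors(new double[][] {{3.0, 3.0}, {3.0, 5.0}});
        System.out.println(err[0] + " " + err[1]);   // about 0.913 and 0.707
    }
}
```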
The following applet creates quadratic polynomials with random
coefficient values. Points along these curves, smeared with Gaussian
noise and accompanied by dummy error values, are passed to the
FitPoly.fit() method for fitting. The DrawFunction
subclass DrawPoly
then overlays the fitted curve on the DrawPanel,
which also displays the points with the DrawPoints
class discussed previously.
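The data generation amounts to evaluating a quadratic and adding Gaussian noise. Below is a standalone sketch that assumes a caller-supplied, seeded java.util.Random so runs are reproducible (the applet itself uses an unseeded generator). SmearedPoints and generate() are hypothetical names; the dummy error (1 + r)·smear, with r uniform in [0, 1), mirrors genRanCurve() in the listing below.

```java
import java.util.Random;

/** Hypothetical sketch of generating smeared points along a quadratic. */
final class SmearedPoints {

    /**
     * Fills y[i] with a0 + a1*x + a2*x^2 plus Gaussian noise of standard
     * deviation smear, and yErr[i] with a dummy error in [smear, 2*smear).
     */
    static void generate(double[] quadParam, double[] x, double[] y,
                         double[] yErr, double smear, Random ran) {
        for (int i = 0; i < x.length; i++) {
            double exact = quadParam[0] + quadParam[1] * x[i]
                    + quadParam[2] * x[i] * x[i];
            y[i] = exact + smear * ran.nextGaussian();   // noise with SD = smear
            yErr[i] = (1.0 + ran.nextDouble()) * smear;  // dummy error bar
        }
    }

    public static void main(String[] args) {
        double[] x = new double[20];
        for (int i = 1; i < x.length; i++) x[i] = x[i - 1] + 5.0;  // 0, 5, ..., 95
        double[] y = new double[x.length];
        double[] yErr = new double[x.length];
        // Fixed seed so repeated runs give the same points.
        generate(new double[] {5.0, 0.02, -0.001}, x, y, yErr, 0.5, new Random(12345L));
        for (int i = 0; i < x.length; i++) {
            System.out.println(x[i] + " " + y[i] + " +/- " + yErr[i]);
        }
    }
}
```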
The JAMA files reside in the subdirectories Jama/
and Jama/util
below the directory holding the classes for the applet, since they
belong to the packages Jama
and Jama.util.
PolyFitApplet.java - This program generates points along a quadratic curve and then
fits a polynomial to them. A histogram displays the residuals.
+ New classes:
FitPoly.java - Fits points to a polynomial. It uses the open source
JAMA package of matrix math classes to do the least squares
fit with QR decomposition.
DrawPoly.java - Subclass of DrawFunction that plots a polynomial curve
on an instance of DrawPanel.
+ Previous classes:
Ch. 8:Physics: Fit.java
Ch. 6:Tech: DrawFunction.java, DrawPoints.java, DrawPanel.java
Ch. 6:Tech: Histogram.java, HistPanel.java
Ch. 6:Tech: PlotPanel.java, PlotFormat.java
import javax.swing.*;
import java.awt.*;
import java.awt.event.*;
/**
*
* This program generates points along a quadratic curve and then fits a
* polynomial to them. This simulates track fitting
* in a detector.
*
* The number of curves and the SD of the smearing of the track
* measurement errors are taken from entries in two text fields.
* A histogram holds the residuals.
*
* This program runs as an applet in a browser or, via main (),
* as an application inside a frame.
*
* The "Go" button starts the track generation and fitting in a
* thread. The "Clear" button clears the histogram.
* In standalone mode, the Exit button closes the program.
*
**/
public class PolyFitApplet extends JApplet
implements ActionListener, Runnable
{
// Use the HistPanel JPanel subclass here
HistPanel fResidualsPanel;
// Use a DrawPanel to display the points to fit
DrawPanel fDrawPanel;
// The histogram records the differences between the
// measured points and the fitted tracks.
Histogram fResidualsHist;
// Use DrawFunction subclasses to plot on the DrawPanel
DrawFunction [] fDrawFunctions;
// Set values for the tracks including the default number
// of tracks to generate, the track area, the SD smearing
// of the data points, and the x values where the track
// y coordinates are measured.
int fNumCurves = 1;
double fYMin = 0.0;
double fYMax = 10.0;
double fXMin = 0.0;
double fXMax = 100.0;
double fCurveSmear = 0.5;
double [] fX = new double[20];
double [] fY = new double[20];
double [] fYErr = new double[20];
// Data array used to pass track points to DrawPoints
double [][]fData = new double[4][];
// Random number generator
java.util.Random fRan;
// Inputs for the number of tracks to generate
JTextField fNumCurvesField;
// and the smearing of the tracking points.
JTextField fSmearField;
// Flag for whether the applet is in a browser
// or running via the main () below.
boolean fInBrowser=true;
//Buttons
JButton fGoButton;
JButton fClearButton;
JButton fExitButton;
// Use thread reference as flag.
Thread fThread;
/**
* Create a user interface with a plot panel, a histogram and
* buttons to control the program. Two text fields hold the number
* of tracks to be generated and the measurement smearing.
**/
public void init () {
// Will need a random number generator for generating tracks
// and for smearing the measurement points.
fRan = new java.util.Random ();
// Create instances of DrawFunction for use in DrawPanel
// to plot the tracks and the measured points along them.
fDrawFunctions = new DrawFunction[2];
fDrawFunctions[0] = new DrawPoly ();
fDrawFunctions[1] = new DrawPoints ();
// Start building the GUI.
JPanel panel = new JPanel (new GridLayout (2,1));
// Will plot the tracks on an instance of DrawPanel.
fDrawPanel = new DrawPanel (fYMin,fYMax, fXMin, fXMax, fDrawFunctions);
fDrawPanel.setTitle ("Fit Points");
fDrawPanel.setXLabel ("Y vs X");
// Create the x axis values for the curves: 0, 5, ..., 95.
double dx = 5.0;
fX[0] = 0.0;
for (int i=1; i < 20; i++){
fX[i] = fX[i-1] + dx;
}
panel.add (fDrawPanel);
// Create a histogram to show the quality of the fits.
fResidualsHist = new Histogram ("Ydata - Yfit","Residuals", 20, -2,2.);
// Use another panel to hold the histogram and controls panels.
JPanel hist_crls_panel = new JPanel (new BorderLayout ());
// A panel to hold the residuals histogram
fResidualsPanel=new HistPanel (fResidualsHist);
// Add the histogram panel to the combined panel
hist_crls_panel.add ("Center",fResidualsPanel);
// Use a textfield for the number of curves.
fNumCurvesField = new JTextField (Integer.toString (fNumCurves), 10);
// Use a textfield for the smearing SD.
fSmearField = new JTextField (Double.toString (fCurveSmear), 10);
// If return is hit after entering text, the
// actionPerformed will be invoked.
fNumCurvesField.addActionListener (this);
fSmearField.addActionListener (this);
fGoButton = new JButton ("Go");
fGoButton.addActionListener (this);
fClearButton = new JButton ("Clear");
fClearButton.addActionListener (this);
fExitButton = new JButton ("Exit");
fExitButton.addActionListener (this);
JPanel control_panel = new JPanel (new GridLayout (1,5));
control_panel.add (fNumCurvesField);
control_panel.add (fSmearField);
control_panel.add (fGoButton);
control_panel.add (fClearButton);
control_panel.add (fExitButton);
if (fInBrowser) fExitButton.setEnabled (false);
hist_crls_panel.add (control_panel,"South");
panel.add (hist_crls_panel);
// Add the main panel to the applet.
add (panel);
} // init
public void actionPerformed (ActionEvent e) {
Object source = e.getSource ();
if ( source == fGoButton || source
== fNumCurvesField
|| source == fSmearField) {
String strNumDataPoints
= fNumCurvesField.getText ();
String strCurveSmear =
fSmearField.getText ();
try{
fNumCurves
= Integer.parseInt (strNumDataPoints);
fCurveSmear
= Double.parseDouble (strCurveSmear);
}
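catch (NumberFormatException ex) {
// Could open an error dialog here, but just display a message
// on the browser status line. In standalone mode there is no
// applet context for showStatus (), so print to the console.
if (fInBrowser) showStatus ("Bad input value");
else System.out.println ("Bad input value");
return;
}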
fGoButton.setEnabled (false);
fClearButton.setEnabled
(false);
if (fThread != null) stop
();
fThread = new Thread (this);
fThread.start ();
}
else if ( source == fClearButton)
{
fResidualsHist.clear
();
repaint ();
} else if (!fInBrowser)
System.exit
(0);
} // actionPerformed
public void stop (){
// If the thread is still running, setting this
// flag will kill it.
fThread = null;
} // stop
/** Generate the tracks in a thread.
*/
public void run () {
for (int i=0; i < fNumCurves; i++){
// Stop the thread if the flag is set.
if (fThread == null) return;
// Generate a random track.
double [] genParams = genRanCurve (fXMax-fXMin, fYMax-fYMin,
fX, fY, fYErr, fCurveSmear);
// Fit points to a quadratic (3 coefficients plus 3 errors)
// using the dummy errors on each point.
double [] fitParams = new double[6];
FitPoly.fit (fitParams, fX, fY, null, fYErr, fX.length);
// Pass the parameters to the polynomial line drawing function.
fDrawFunctions[0].setParameters (fitParams,null);
// Pass the data points to the DrawPoints object via
// the 2-D array.
fDrawFunctions[1].setParameters (null, fData);
// Redrawing the panel will cause the paintContents (Graphics g)
// method in DrawPanel to invoke the draw () method for the line
// and points drawing functions.
fDrawPanel.repaint ();
// Include residuals == difference between the measured value
// and the fitted value at each x position.
for (int j=0; j < fX.length; j++){
double yFit = fitParams[0] + fitParams[1]*fX[j]
+ fitParams[2]*fX[j]*fX[j];
fResidualsHist.add (fY[j] - yFit);
}
// Pause briefly to let users see the track.
try{
Thread.sleep (30);
}catch (InterruptedException e){}
}
repaint ();
fGoButton.setEnabled (true);
fClearButton.setEnabled (true);
} // run
/**
* Generate a quadratic
plot and obtain points along the curve.
* Smear the vertical coordinate
with a Gaussian.
**/
double [] genRanCurve (double x_range, double
y_range,
double [] x_curve, double [] y_curve,
double [] y_curve_err,
double smear){
// Parameters for a quadratic line.
double [] quadParam = new double[3];
// Simulated quadratic
double y0 = y_range* (0.5 + 0.25 *
fRan.nextDouble ());
double y1 = y_range * fRan.nextDouble
();
// Choose some dummy parameters for the polynomial.
quadParam[0] = y0;
quadParam[1] = (y1-y0)/ (8.0*x_range);
quadParam[2] = (fRan.nextDouble () - 0.5)/100.0;
// Make the points and errors along a quadratic line.
for (int i=0; i < x_curve.length; i++) {
y_curve[i] = y0 + quadParam[1]*x_curve[i]
+ quadParam[2]*x_curve[i]*x_curve[i];
double curve_err = smear*fRan.nextGaussian ();
// Add the smear factor for this point.
y_curve[i] += curve_err;
// Create a dummy average std. dev. error on the y value
// for this x position.
y_curve_err[i] = (1.0 + fRan.nextDouble () ) * smear;
}
// Set up the parameters in the drawing function.
fDrawFunctions[0].setParameters (quadParam,null);
// The DrawPoints function will need this data via
// a 2-D array.
fData[0] = fY;
fData[1] = fX;
fData[2] = fYErr;
fData[3] = null;
// Return the track parameters.
return quadParam;
} // genRanCurve
/**
* Allow for option of running
the program in standalone mode.
* Create the applet and add
to a frame.
**/
public static void main (String[] args) {
//
int frame_width=450;
int frame_height=450;
//
PolyFitApplet applet = new PolyFitApplet
();
applet.fInBrowser = false;
applet.init ();
// The default close operation closes the window and exits the program.
JFrame f = new JFrame ("Demo");
f.setDefaultCloseOperation (JFrame.EXIT_ON_CLOSE);
// Add applet to the frame
f.getContentPane ().add ( applet);
f.setSize (new Dimension (frame_width,frame_height));
f.setVisible (true);
} // main
} // PolyFitApplet
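The residuals histogram in this listing has 20 bins spanning -2 to 2. As a hypothetical stand-in for the Ch. 6 Histogram class, whose internals are not shown here, this sketch counts residuals into equal-width bins and simply drops values outside the range; the real class may treat underflows and overflows differently.

```java
/** Hypothetical stand-in for filling a fixed-range residuals histogram. */
final class ResidualBins {

    /** Counts values into nBins equal-width bins on [lo, hi); others are dropped. */
    static int[] bin(double[] values, int nBins, double lo, double hi) {
        int[] counts = new int[nBins];
        double width = (hi - lo) / nBins;
        for (double v : values) {
            if (v < lo || v >= hi) continue;          // under/overflow: ignored here
            int b = (int) ((v - lo) / width);
            if (b >= nBins) b = nBins - 1;            // guard against rounding at hi
            counts[b]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        double[] residuals = {-1.95, -0.05, 0.05, 0.1, 1.99, 2.5, -3.0};
        int[] counts = bin(residuals, 20, -2.0, 2.0);   // same range as "Ydata - Yfit"
        for (int b = 0; b < counts.length; b++) {
            if (counts[b] > 0) System.out.println("bin " + b + ": " + counts[b]);
        }
    }
}
```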
import Jama.*;
import Jama.util.*;
/**
* Fit a polynomial line to a set of data points.
* Extends the Fit class.
**/
public class FitPoly extends Fit
{
/**
* Use the least squares fit method to fit a
* polynomial to 2-D data for measurements
* y[i] vs. the independent variable x[i]. This fit assumes
* there are errors only on the y measurements, as
* given by the sigma_y array.
*
* See, e.g., Press et al., "Numerical Recipes..." for details
* of the algorithm.
*
* The solution to the LSQ fit uses the open source JAMA -
* "A Java Matrix Package" classes. See http://math.nist.gov/javanumerics/jama/
* for a description.
*
* @param parameters - first half of the array holds the coefficients
* for the polynomial. The second half holds the errors on the
* coefficients.
* @param x - independent variable
* @param y - vertical dependent variable
* @param sigma_x - std. dev. error on each x value (not used by this fit)
* @param sigma_y - std. dev. error on each y value
* @param num_points - number of points to fit. Less than or equal to the
* dimension of the x array.
**/
public static void fit (double [] parameters,
double [] x, double [] y,
double [] sigma_x, double [] sigma_y, int num_points){
// parameters.length = number of coefficients + an error on each one.
int nk = parameters.length/2;
double [][] alpha = new double[nk][nk];
double [] beta = new double[nk];
double term = 0;
for (int k=0; k < nk; k++) {
// Only need to calculate the diagonal and upper half
// of the symmetric matrix.
for (int j=k; j < nk; j++) {
// Calculate terms over the data points
term = 0.0;
alpha[k][j] = 0.0;
for (int i=0; i < num_points; i++) {
double prod1 = 1.0;
// Calculate x^k
if ( k > 0) for (int m=0; m < k; m++) prod1 *= x[i];
double prod2 = 1.0;
// Calculate x^j
if ( j > 0) for (int m=0; m < j; m++) prod2 *= x[i];
// Calculate x^k * x^j
term = (prod1*prod2);
if (sigma_y != null && sigma_y[i] != 0.0)
term /= (sigma_y[i]*sigma_y[i]);
alpha[k][j] += term;
}
// Fill in the symmetric lower half.
alpha[j][k] = alpha[k][j];
}
for (int i=0; i < num_points; i++) {
double prod1 = 1.0;
if (k > 0) for ( int m=0; m < k; m++) prod1 *= x[i];
term = (y[i] * prod1);
if (sigma_y != null && sigma_y[i] != 0.0)
term /= (sigma_y[i]*sigma_y[i]);
beta[k] +=term;
}
}
// Use the Jama QR Decomposition classes to solve for
// the parameters.
Matrix alpha_matrix = new Matrix (alpha);
QRDecomposition alpha_QRD = new QRDecomposition (alpha_matrix);
Matrix beta_matrix = new Matrix (beta,nk);
Matrix param_matrix;
try {
param_matrix = alpha_QRD.solve (beta_matrix);
}
catch (Exception e) {
System.out.println ("QRD solve failed: "+ e);
return;
}
// The inverse provides the covariance matrix.
Matrix c = alpha_matrix.inverse ();
for (int k=0; k < nk; k++) {
parameters[k] = param_matrix.get (k,0);
// Diagonal elements of the covariance matrix provide
// the square of the parameter errors. Put them in the
// upper half of the parameters array.
parameters[k+nk] = Math.sqrt (c.get (k,k));
}
} // fit
} // FitPoly
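FitPoly.fit() recomputes x^k and x^j with inner loops for every matrix element. Because alpha[k][j] depends only on k + j, one alternative (sketched here with hypothetical names; this is not how the listing does it) is to accumulate the 2·nk − 1 weighted power sums once, incrementally, and fill the matrix from them:

```java
/** Hypothetical alternative to FitPoly's inner power loops. */
final class PowerSums {

    static double weight(double[] sigmaY, int i) {
        return (sigmaY != null && sigmaY[i] != 0.0) ? 1.0 / (sigmaY[i] * sigmaY[i]) : 1.0;
    }

    /** alpha[k][j] = S[k+j], where S[m] = sum_i w_i * x_i^m (a Hankel matrix). */
    static double[][] alpha(int nk, double[] x, double[] sigmaY, int numPoints) {
        double[] s = new double[2 * nk - 1];
        for (int i = 0; i < numPoints; i++) {
            double w = weight(sigmaY, i);
            double xm = 1.0;                       // x_i^m, built up incrementally
            for (int m = 0; m < s.length; m++) {
                s[m] += w * xm;
                xm *= x[i];
            }
        }
        double[][] a = new double[nk][nk];
        for (int k = 0; k < nk; k++) {
            for (int j = 0; j < nk; j++) a[k][j] = s[k + j];
        }
        return a;
    }

    /** beta[k] = sum_i w_i * y_i * x_i^k. */
    static double[] beta(int nk, double[] x, double[] y, double[] sigmaY, int numPoints) {
        double[] b = new double[nk];
        for (int i = 0; i < numPoints; i++) {
            double wy = weight(sigmaY, i) * y[i];
            double xk = 1.0;
            for (int k = 0; k < nk; k++) {
                b[k] += wy * xk;
                xk *= x[i];
            }
        }
        return b;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0};
        double[][] a = alpha(3, x, null, x.length);
        System.out.println(a[0][0] + " " + a[1][2] + " " + a[2][2]);  // prints 3.0 36.0 98.0
    }
}
```

For x = 1, 2, 3 with no errors the power sums are 3, 6, 14, 36, 98, so alpha[0][0] = 3, alpha[1][2] = 36 and alpha[2][2] = 98. The work per point drops from roughly nk³ multiplications to about 2·nk.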
import java.awt.*;
/**
* Draw a polynomial line onto the PlotPanel. Extend the
* DrawFunction class and override the draw method.
*
**/
public class DrawPoly extends DrawFunction
{
int [] fXFrame;
int [] fYFrame;
/**
* Draw the polynomial a0 + a1*x + ... + a5*x^5 (degree 5 or less)
* onto the PlotPanel. The coefficients occupy the first half of
* the parameters array.
*
* @param g graphics context
* @param frame_start_x horizontal point on display where
* drawing starts in pixel number.
* @param frame_start_y vertical point on display where
* drawing starts in pixel number.
* @param frame_width display area width in pixels.
* @param frame_height display area height in pixels.
* @param x_scale array whose first and last elements hold the lower
* and upper values of the function input scale range.
* @param y_scale array whose first and last elements hold the lower
* and upper values of the function output scale range.
**/
public void draw (Graphics g,
int frame_start_x, int frame_start_y,
int frame_width, int frame_height,
double [] x_scale, double [] y_scale) {
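// Check if ready to draw the line
if (fParameters == null) return;
int num_params = fParameters.length/2;
// Limit to polynomials of degree 5 (at most 6 coefficients)
if (num_params > 6) return;
// Set the color only once we know we will draw, so that the
// early returns above never leave it changed.
Color save_color = g.getColor ();
g.setColor (fColor);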
// Get the number of horizontal pixels.
int num_points = frame_width;
// Get conversion factor from data scale to frame pixels
double y_scale_factor = frame_height/(y_scale[y_scale.length-1]
- y_scale[0]);
double x_scale_factor = frame_width/(x_scale[x_scale.length-1]
- x_scale[0]);
// Create arrays of points for each point of the curve.
// Recreate them if the width changes.
if (fYFrame == null || fYFrame.length != frame_width){
fYFrame = new int[num_points];
fXFrame = new int[num_points];
}
// Create the polynomial curve from a sequence
// of short line segments
double prod2,prod3,y;
double del_x = 1/x_scale_factor;
// Calculate the func = a0 + a1 * x + ...
for (int i=0; i < num_points; i++) {
// Step one pixel at a time, starting at the left edge.
double x = x_scale[0] + i * del_x;
// a0
y = fParameters[0];
prod2 = x*x;
prod3 = prod2 * x;
// The cases intentionally fall through so that every
// term up to the polynomial's degree is added.
switch (num_params){
case 6:
// a5 * x^5
y += fParameters[5] * prod2 * prod3;
case 5:
// a4 * x^4
y += fParameters[4] * prod2 * prod2;
case 4:
// a3 * x^3
y += fParameters[3] * prod3;
case 3:
// a2 * x^2
y += fParameters[2] * prod2;
case 2:
// a1 * x^1
y += fParameters[1] * x;
}
// Convert to pixel coords
fYFrame[i] = frame_height -
(int)((y - y_scale[0]) * y_scale_factor) + frame_start_y;
fXFrame[i] = frame_start_x + (int)((x - x_scale[0]) * x_scale_factor);
}
// Then pass the polyline points for drawing
g.drawPolyline (fXFrame,fYFrame,num_points);
g.setColor (save_color);
} // draw
} // DrawPoly
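The data-to-pixel conversion inside draw() can be pulled out and checked on its own. This is a hypothetical refactoring (PixelMap, toPixelX and toPixelY are not in the original) that repeats the same arithmetic: a linear scale factor, with the y axis flipped because screen coordinates grow downward.

```java
/** Hypothetical refactoring of DrawPoly's data-to-pixel conversion. */
final class PixelMap {

    /** Maps a data x value to a pixel column inside the frame. */
    static int toPixelX(double x, int frameStartX, int frameWidth, double[] xScale) {
        double factor = frameWidth / (xScale[xScale.length - 1] - xScale[0]);
        return frameStartX + (int) ((x - xScale[0]) * factor);
    }

    /** Maps a data y value to a pixel row; screen y grows downward, so flip it. */
    static int toPixelY(double y, int frameStartY, int frameHeight, double[] yScale) {
        double factor = frameHeight / (yScale[yScale.length - 1] - yScale[0]);
        return frameHeight - (int) ((y - yScale[0]) * factor) + frameStartY;
    }

    public static void main(String[] args) {
        double[] xScale = {0.0, 100.0};
        double[] yScale = {0.0, 10.0};
        System.out.println(toPixelX(50.0, 10, 400, xScale));  // prints 210
        System.out.println(toPixelY(5.0, 5, 200, yScale));    // prints 105
    }
}
```

The lowest y value lands on the bottom row of the frame (frame_height + frame_start_y) and the highest on the top row (frame_start_y), matching the conversion in draw().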
References & Web Resources
Most recent update: Oct. 27, 2005