Repository: commons-math Updated Branches: refs/heads/master 753f278d1 -> 491786ce4
MATH-1172: Simple curve fitter Provides boiler-plate code so that users can readily fit any parametric function. Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/491786ce Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/491786ce Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/491786ce Branch: refs/heads/master Commit: 491786ce4182daa57ea61deaf701f756f57f6f1e Parents: 753f278 Author: Gilles <er...@apache.org> Authored: Sun Dec 14 18:48:01 2014 +0100 Committer: Gilles <er...@apache.org> Committed: Sun Dec 14 18:48:01 2014 +0100 ---------------------------------------------------------------------- src/changes/changes.xml | 4 + .../math3/fitting/SimpleCurveFitter.java | 126 +++++++++++++++++++ .../math3/fitting/SimpleCurveFitterTest.java | 60 +++++++++ 3 files changed, 190 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/491786ce/src/changes/changes.xml ---------------------------------------------------------------------- diff --git a/src/changes/changes.xml b/src/changes/changes.xml index e54d71e..cbf1eb8 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -79,6 +79,10 @@ Users are encouraged to upgrade to this version as this release not 2. A few methods in the FastMath class are in fact slower that their counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901). "> + <action dev="erans" type="add" issue="MATH-1172"> + New class "SimpleCurveFitter": Boiler-plate code to allow fitting of + a user-defined parametric function. + </action> <action dev="erans" type="add" issue="MATH-1173"> New classes "TricubicInterpolatingFunction" and "TricubicInterpolator" to replace "TricubicSplineInterpolatingFunction" and "TricubicSplineInterpolator". 
http://git-wip-us.apache.org/repos/asf/commons-math/blob/491786ce/src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java b/src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java
new file mode 100644
index 0000000..0307e42
--- /dev/null
+++ b/src/main/java/org/apache/commons/math3/fitting/SimpleCurveFitter.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.fitting;
+
+import java.util.Collection;
+
+import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
+import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
+import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
+import org.apache.commons.math3.linear.DiagonalMatrix;
+
+/**
+ * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
+ *
+ * @since 3.4
+ */
+public class SimpleCurveFitter extends AbstractCurveFitter {
+    /** Function to fit. */
+    private final ParametricUnivariateFunction function;
+    /** Initial guess for the parameters. */
+    private final double[] initialGuess;
+    /** Maximum number of iterations of the optimization algorithm. */
+    private final int maxIter;
+
+    /**
+     * Constructor used by the factory methods.
+     *
+     * @param function Function to fit.
+     * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
+     * be consistent with the number of parameters of the {@code function} to fit.
+     * @param maxIter Maximum number of iterations of the optimization algorithm.
+     */
+    private SimpleCurveFitter(ParametricUnivariateFunction function,
+                              double[] initialGuess,
+                              int maxIter) {
+        this.function = function;
+        this.initialGuess = initialGuess;
+        this.maxIter = maxIter;
+    }
+
+    /**
+     * Creates a curve fitter.
+     * The {@code start} array is copied, so that later changes made to it by
+     * the caller do not affect the fitter. The maximum number of iterations of
+     * the optimization algorithm is set to {@link Integer#MAX_VALUE}.
+     *
+     * @param f Function to fit.
+     * @param start Initial guess for the parameters.  Cannot be {@code null}.
+     * Its length must be consistent with the number of parameters of the
+     * function to fit.
+     * @return a curve fitter.
+     *
+     * @see #withStartPoint(double[])
+     * @see #withMaxIterations(int)
+     */
+    public static SimpleCurveFitter create(ParametricUnivariateFunction f,
+                                           double[] start) {
+        return new SimpleCurveFitter(f, start.clone(), Integer.MAX_VALUE);
+    }
+
+    /**
+     * Configure the start point (initial guess).
+     * @param newStart new start point (initial guess)
+     * @return a new instance.
+     */
+    public SimpleCurveFitter withStartPoint(double[] newStart) {
+        return new SimpleCurveFitter(function,
+                                     newStart.clone(),
+                                     maxIter);
+    }
+
+    /**
+     * Configure the maximum number of iterations.
+     * @param newMaxIter maximum number of iterations
+     * @return a new instance.
+     */
+    public SimpleCurveFitter withMaxIterations(int newMaxIter) {
+        return new SimpleCurveFitter(function,
+                                     initialGuess,
+                                     newMaxIter);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
+        // Prepare least-squares problem.
+        final int len = observations.size();
+        final double[] target  = new double[len];
+        final double[] weights = new double[len];
+
+        int count = 0;
+        for (WeightedObservedPoint obs : observations) {
+            target[count]  = obs.getY();
+            weights[count] = obs.getWeight();
+            ++count;
+        }
+
+        final AbstractCurveFitter.TheoreticalValuesFunction model
+            = new AbstractCurveFitter.TheoreticalValuesFunction(function,
+                                                                observations);
+
+        // Create an optimizer for fitting the curve to the observed points.
+        return new LeastSquaresBuilder().
+                maxEvaluations(Integer.MAX_VALUE).
+                maxIterations(maxIter).
+                start(initialGuess).
+                target(target).
+                weight(new DiagonalMatrix(weights)).
+                model(model.getModelFunction(), model.getModelFunctionJacobian()).
+                build();
+    }
+}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/491786ce/src/test/java/org/apache/commons/math3/fitting/SimpleCurveFitterTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/fitting/SimpleCurveFitterTest.java b/src/test/java/org/apache/commons/math3/fitting/SimpleCurveFitterTest.java
new file mode 100644
index 0000000..d411d4a
--- /dev/null
+++ b/src/test/java/org/apache/commons/math3/fitting/SimpleCurveFitterTest.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.fitting;
+
+import java.util.Random;
+
+import org.apache.commons.math3.TestUtils;
+import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
+import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.distribution.UniformRealDistribution;
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.util.FastMath;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test for class {@link SimpleCurveFitter}.
+ */
+public class SimpleCurveFitterTest {
+    @Test
+    public void testPolynomialFit() {
+        final Random randomizer = new Random(53882150042L);
+        final RealDistribution rng = new UniformRealDistribution(-100, 100);
+        rng.reseedRandomGenerator(64925784252L);
+
+        final double[] coeff = { 12.9, -3.4, 2.1 }; // 12.9 - 3.4 x + 2.1 x^2
+        final PolynomialFunction f = new PolynomialFunction(coeff);
+
+        // Collect data from a known polynomial.
+        final WeightedObservedPoints obs = new WeightedObservedPoints();
+        for (int i = 0; i < 100; i++) {
+            final double x = rng.sample();
+            obs.add(x, f.value(x) + 0.1 * randomizer.nextGaussian());
+        }
+
+        final ParametricUnivariateFunction function = new PolynomialFunction.Parametric();
+        // Start fit from initial guesses that are far from the optimal values.
+        final SimpleCurveFitter fitter
+            = SimpleCurveFitter.create(function,
+                                       new double[] { -1e20, 3e15, -5e25 });
+        final double[] best = fitter.fit(obs.toList());
+
+        TestUtils.assertEquals("best != coeff", coeff, best, 2e-2);
+    }
+}