home | career | drupal | java | mac | mysql | perl | scala | uml | unix  

Java example source code file (LevenbergMarquardtOptimizerTest.java)

This example Java source code file (LevenbergMarquardtOptimizerTest.java) is included in the alvinalexander.com "Java Source Code Warehouse" project. The intent of this project is to help you "Learn Java by Example" TM.

Learn more about this Java project at its project page.

Java - Java tags/keywords

arraylist, bevingtonproblem, circleproblem, circlevectorial, leastsquaresbuilder, leastsquaresproblem, levenbergmarquardtoptimizer, linearproblem, optimum, override, parametervalidator, randomcirclepointgenerator, realvector, test, util

The LevenbergMarquardtOptimizerTest.java Java example source code

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.commons.math3.fitting.leastsquares;

import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
import org.apache.commons.math3.geometry.euclidean.twod.Vector2D;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Precision;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

import static org.hamcrest.CoreMatchers.is;

/**
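 * <p>Some of the unit tests are re-implementations of the MINPACK <a
 * href="http://www.netlib.org/minpack/ex/file17">file17</a> and <a
 * href="http://www.netlib.org/minpack/ex/file22">file22</a> test files.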
 * The redistribution policy for MINPACK is available <a
 * href="http://www.netlib.org/minpack/disclaimer">here</a>.
 *
 */
public class LevenbergMarquardtOptimizerTest
    extends AbstractLeastSquaresOptimizerAbstractTest{

    /**
     * Creates a builder whose model is taken from the given Bevington problem.
     *
     * @param problem provider of the model function and of its Jacobian.
     * @return the builder returned by {@code base()}, with the model set.
     */
    public LeastSquaresBuilder builder(BevingtonProblem problem){
        // base() is invoked first, exactly as in the fluent form.
        final LeastSquaresBuilder configured = base();
        final MultivariateVectorFunction model = problem.getModelFunction();
        final MultivariateMatrixFunction jacobian = problem.getModelFunctionJacobian();
        return configured.model(model, jacobian);
    }

    /**
     * Creates a builder whose model, target and weights are taken from the
     * given circle problem.
     *
     * @param problem provider of the model function, its Jacobian, the target
     * values and the weights (the latter are placed on a diagonal matrix).
     * @return the builder returned by {@code base()}, fully configured for
     * {@code problem}.
     */
    public LeastSquaresBuilder builder(CircleProblem problem){
        final LeastSquaresBuilder withModel
            = base().model(problem.getModelFunction(), problem.getModelFunctionJacobian());
        final LeastSquaresBuilder withTarget = withModel.target(problem.target());
        return withTarget.weight(new DiagonalMatrix(problem.weight()));
    }

    /**
     * Iteration limit reported by this test class.
     * NOTE(review): presumably consumed by the inherited tests -- confirm in
     * the parent class.
     *
     * @return always 25.
     */
    @Override
    public int getMaxIterations() {
        final int iterationCap = 25;
        return iterationCap;
    }

    /**
     * Supplies the optimizer under test.
     *
     * @return a new {@link LevenbergMarquardtOptimizer} built with its
     * no-argument constructor.
     */
    @Override
    public LeastSquaresOptimizer getOptimizer() {
        final LeastSquaresOptimizer levenbergMarquardt = new LevenbergMarquardtOptimizer();
        return levenbergMarquardt;
    }

    /**
     * Checks that asking for the covariances of a singular linear problem
     * raises a {@link SingularMatrixException}.
     * <p>
     * Overrides the method from parent class, since the default singularity
     * threshold (1e-14) does not trigger the expected exception; a threshold
     * of 1.5e-14 is used here instead.
     */
    @Override
    @Test
    public void testNonInvertible() {
        // The third column equals 3 * (first column) - 3 * (second column),
        // hence the matrix is singular.
        final double[][] coefficients = {
            {  1, 2, -3 },
            {  2, 1,  3 },
            { -3, 0, -9 }
        };
        final double[] rhs = { 1, 1, 1 };

        try {
            final LinearProblem singular = new LinearProblem(coefficients, rhs);
            final LeastSquaresProblem lsp = singular.getBuilder().maxIterations(20).build();
            final Optimum result = optimizer.optimize(lsp);

            //TODO check that it is a bad fit? Why the extra conditions?
            final double scaledRms = FastMath.sqrt(singular.getTarget().length) * result.getRMS();
            Assert.assertTrue(scaledRms > 0.6);

            // Expected to throw.
            result.getCovariances(1.5e-14);

            // Reached only when no exception was raised.
            fail(optimizer);
        } catch (SingularMatrixException e) {
            // expected
        }
    }

    /**
     * Runs the circle fit under several combinations of evaluation budget and
     * tolerances, stating for each one whether the optimization is expected
     * to fail (see {@code checkEstimate}).
     */
    @Test
    public void testControlParameters() {
        final double[][] initialPoints = {
            {  30.0,  68.0 },
            {  50.0,  -6.0 },
            { 110.0, -20.0 },
            {  35.0,  15.0 },
            {  45.0,  97.0 }
        };
        final CircleVectorial circle = new CircleVectorial();
        for (final double[] xy : initialPoints) {
            circle.addPoint(xy[0], xy[1]);
        }

        // Arguments: step bound factor, max evaluations, cost tolerance,
        // parameter tolerance, orthogonality tolerance, expected failure.
        checkEstimate(circle, 0.1, 10, 1.0e-14, 1.0e-16, 1.0e-10, false);
        checkEstimate(circle, 0.1, 10, 1.0e-15, 1.0e-17, 1.0e-10, true);
        checkEstimate(circle, 0.1,  5, 1.0e-15, 1.0e-16, 1.0e-10, true);

        // A sixth point is added before the last scenario.
        circle.addPoint(300, -300);
        //wardev I changed true => false
        //TODO why should this fail? It uses 15 evaluations.
        checkEstimate(circle, 0.1, 20, 1.0e-18, 1.0e-16, 1.0e-10, false);
    }

    /**
     * Fits {@code circle} with a specifically configured optimizer and checks
     * that the run succeeds or fails as announced.
     * <p>
     * "Failure" means that a {@link DimensionMismatchException} or a
     * {@link TooManyEvaluationsException} is thrown; any other exception
     * propagates to the caller.
     *
     * @param circle problem to fit.
     * @param initialStepBoundFactor initial step bound factor of the optimizer.
     * @param maxCostEval maximum number of evaluations allowed.
     * @param costRelativeTolerance relative tolerance on the cost.
     * @param parRelativeTolerance relative tolerance on the parameters.
     * @param orthoTolerance orthogonality tolerance.
     * @param shouldFail whether one of the two exceptions above is expected.
     */
    private void checkEstimate(CircleVectorial circle,
                               double initialStepBoundFactor, int maxCostEval,
                               double costRelativeTolerance, double parRelativeTolerance,
                               double orthoTolerance, boolean shouldFail) {
        try {
            // Each "withXxx" call returns an optimizer; the result is chained step by step.
            LevenbergMarquardtOptimizer lm = new LevenbergMarquardtOptimizer();
            lm = lm.withInitialStepBoundFactor(initialStepBoundFactor);
            lm = lm.withCostRelativeTolerance(costRelativeTolerance);
            lm = lm.withParameterRelativeTolerance(parRelativeTolerance);
            lm = lm.withOrthoTolerance(orthoTolerance);
            lm = lm.withRankingThreshold(Precision.SAFE_MIN);

            final double[] startPoint = { 98.680, 47.345 };
            final LeastSquaresProblem lsp = builder(circle)
                    .maxEvaluations(maxCostEval)
                    .maxIterations(100)
                    .start(startPoint)
                    .build();

            lm.optimize(lsp);

            // Reaching this point means that no exception was thrown.
            Assert.assertFalse(shouldFail);
            //TODO check it got the right answer

        } catch (DimensionMismatchException mismatch) {
            Assert.assertTrue(shouldFail);
        } catch (TooManyEvaluationsException exhausted) {
            Assert.assertTrue(shouldFail);
        }
    }

    /**
     * Non-linear test case: fitting of decay curve (from Chapter 8 of
     * Bevington's textbook, "Data reduction and analysis for the physical sciences").
     * <p>
     * Each observation is weighted by the inverse of its measured count. Every
     * fitted parameter must lie within the square root of the corresponding
     * reference variance from its reference value, and every covariance entry
     * within 10% of its reference entry.
     * <p>
     * XXX The expected ("reference") values may not be accurate and the tolerance too
     * relaxed for this test to be currently really useful (the issue is under
     * investigation).
     */
    @Test
    public void testBevington() {
        // Observation times.
        final double[] times = {
            15, 30, 45, 60, 75, 90, 105, 120, 135, 150,
            165, 180, 195, 210, 225, 240, 255, 270, 285, 300,
            315, 330, 345, 360, 375, 390, 405, 420, 435, 450,
            465, 480, 495, 510, 525, 540, 555, 570, 585, 600,
            615, 630, 645, 660, 675, 690, 705, 720, 735, 750,
            765, 780, 795, 810, 825, 840, 855, 870, 885,
        };
        // Measured counts, one per observation time.
        final double[] counts = {
            775, 479, 380, 302, 185, 157, 137, 119, 110, 89,
            74, 61, 66, 68, 48, 54, 51, 46, 55, 29,
            28, 37, 49, 26, 35, 29, 31, 24, 25, 35,
            24, 30, 26, 28, 21, 18, 20, 27, 17, 17,
            14, 17, 24, 11, 22, 17, 12, 10, 13, 16,
            9, 9, 14, 21, 17, 13, 12, 18, 10,
        };
        final double[] initialGuess = {10, 900, 80, 27, 225};

        // Feed the observations to the model and derive the weights.
        final BevingtonProblem decay = new BevingtonProblem();
        final int numObservations = times.length;
        final double[] inverseCounts = new double[numObservations];
        for (int k = 0; k < numObservations; k++) {
            decay.addPoint(times[k], counts[k]);
            inverseCounts[k] = 1 / counts[k];
        }

        final LeastSquaresProblem lsp = builder(decay)
                .target(counts)
                .weight(new DiagonalMatrix(inverseCounts))
                .start(initialGuess)
                .maxIterations(20)
                .build();
        final Optimum fit = optimizer.optimize(lsp);

        final RealVector fitted = fit.getPoint();
        final double[] referenceSolution = { 10.4, 958.3, 131.4, 33.9, 205.0 };

        final RealMatrix covariance = fit.getCovariances(1e-14);
        final double[][] referenceCovariance = {
            { 3.38, -3.69, 27.98, -2.34, -49.24 },
            { -3.69, 2492.26, 81.89, -69.21, -8.9 },
            { 27.98, 81.89, 468.99, -44.22, -615.44 },
            { -2.34, -69.21, -44.22, 6.39, 53.80 },
            { -49.24, -8.9, -615.44, 53.8, 929.45 }
        };

        final int dimension = referenceSolution.length;

        // Check that the computed solution is within the reference error range.
        for (int p = 0; p < dimension; p++) {
            final double referenceSigma = FastMath.sqrt(referenceCovariance[p][p]);
            Assert.assertEquals("Parameter " + p,
                                referenceSolution[p],
                                fitted.getEntry(p),
                                referenceSigma);
        }

        // Check that each entry of the computed covariance matrix is within 10%
        // of the reference matrix entry.
        for (int row = 0; row < dimension; row++) {
            for (int col = 0; col < dimension; col++) {
                final double reference = referenceCovariance[row][col];
                final double tolerance = FastMath.abs(0.1 * reference);
                Assert.assertEquals("Covariance matrix [" + row + "][" + col + "]",
                                    reference,
                                    covariance.getEntry(row, col),
                                    tolerance);
            }
        }
    }

    /**
     * Fits a circle to 10 randomly generated points and checks that the center
     * coordinates and the radius used by the generator are recovered within the
     * standard errors reported by the optimum.
     */
    @Test
    public void testCircleFitting2() {
        final double xCenter = 123.456;
        final double yCenter = 654.321;
        final double radius = 111.111;
        final double xSigma = 10;
        final double ySigma = 15;
        // The test is extremely sensitive to the seed.
        final long seed = 59421061L;

        final RandomCirclePointGenerator generator
            = new RandomCirclePointGenerator(xCenter, yCenter, radius,
                                             xSigma, ySigma,
                                             seed);
        final CircleProblem circle = new CircleProblem(xSigma, ySigma);

        final int numPoints = 10;
        for (Vector2D point : generator.generate(numPoints)) {
            circle.addPoint(point.getX(), point.getY());
        }

        // First guess for the center's coordinates and radius.
        final double[] firstGuess = { 90, 659, 115 };

        final LeastSquaresProblem lsp
            = builder(circle).maxIterations(50).start(firstGuess).build();
        final Optimum fit = optimizer.optimize(lsp);

        final double[] found = fit.getPoint().toArray();

        // Retrieve errors estimation.
        final double[] standardErrors = fit.getSigma(1e-14).toArray();

        // Check that the parameters are found within the assumed error bars.
        final double[] expected = { xCenter, yCenter, radius };
        for (int i = 0; i < expected.length; i++) {
            Assert.assertEquals(expected[i], found[i], standardErrors[i]);
        }
    }

    /**
     * Checks that a parameter validator is taken into account: a validator that
     * always returns a previously found optimum must lead to fewer evaluations
     * than the run that found this optimum.
     */
    @Test
    public void testParameterValidator() {
        // Setup.
        final double xCenter = 123.456;
        final double yCenter = 654.321;
        final double radius = 111.111;
        final double xSigma = 10;
        final double ySigma = 15;
        final long seed = 3456789L;

        final RandomCirclePointGenerator generator
            = new RandomCirclePointGenerator(xCenter, yCenter, radius,
                                             xSigma, ySigma,
                                             seed);
        final CircleProblem circle = new CircleProblem(xSigma, ySigma);

        final int numPoints = 10;
        for (Vector2D point : generator.generate(numPoints)) {
            circle.addPoint(point.getX(), point.getY());
        }

        // First guess for the center's coordinates and radius.
        final double[] firstGuess = { 90, 659, 115 };

        // Reference run, without validator.
        final LeastSquaresProblem plainProblem
            = builder(circle).maxIterations(50).start(firstGuess).build();
        final Optimum optimum = optimizer.optimize(plainProblem);
        final int plainEvaluations = optimum.getEvaluations();
        Assert.assertTrue(plainEvaluations > 1);

        // Build a new problem with a validator that amounts to cheating.
        final ParameterValidator cheatValidator
            = new ParameterValidator() {
                    public RealVector validate(RealVector params) {
                        // Cheat: return the optimum found previously.
                        return optimum.getPoint();
                    }
                };

        final LeastSquaresProblem cheatProblem
            = builder(circle)
                .maxIterations(50)
                .start(firstGuess)
                .parameterValidator(cheatValidator)
                .build();
        final Optimum cheatOptimum = optimizer.optimize(cheatProblem);
        final int cheatEvaluations = cheatOptimum.getEvaluations();
        Assert.assertTrue(cheatEvaluations < plainEvaluations);
    }

    /**
     * Checks that the iteration and evaluation counters of the optimum are not
     * switched: with a checker that always reports convergence, 1 iteration
     * and 2 evaluations are expected.
     */
    @Test
    public void testEvaluationCount() {
        //setup
        final LeastSquaresBuilder lspBuilder
            = new LinearProblem(new double[][] {{1}}, new double[] {1}).getBuilder();
        // Checker that declares convergence unconditionally.
        final ConvergenceChecker<Evaluation> alwaysConverged
            = new ConvergenceChecker<Evaluation>() {
                    public boolean converged(int iteration, Evaluation previous, Evaluation current) {
                        return true;
                    }
                };
        final LeastSquaresProblem lsp = lspBuilder.checker(alwaysConverged).build();

        //action
        final Optimum result = optimizer.optimize(lsp);

        //verify
        //check iterations and evaluations are not switched.
        final int iterations = result.getIterations();
        Assert.assertThat(iterations, is(1));
        final int evaluations = result.getEvaluations();
        Assert.assertThat(evaluations, is(2));
    }

    /**
     * Sum of a constant and two decaying exponentials, sampled at the times
     * registered through {@link #addPoint(double, double)}:
     * <pre>
     *   f(t) = p[0] + p[1] * exp(-t / p[3]) + p[2] * exp(-t / p[4])
     * </pre>
     */
    private static class BevingtonProblem {
        /** Sampling times, in insertion order. */
        private final List<Double> time;
        /** Measured counts; stored alongside the times but not read by the model. */
        private final List<Double> count;

        /** Creates a problem without any data point. */
        public BevingtonProblem() {
            time = new ArrayList<Double>();
            count = new ArrayList<Double>();
        }

        /**
         * Registers one observation.
         *
         * @param t sampling time.
         * @param c measured count at time {@code t}.
         */
        public void addPoint(double t, double c) {
            time.add(t);
            count.add(c);
        }

        /**
         * @return a function that maps the 5 model parameters to the model
         * values at every registered time.
         */
        public MultivariateVectorFunction getModelFunction() {
            return new MultivariateVectorFunction() {
                public double[] value(double[] params) {
                    final double[] values = new double[time.size()];
                    for (int i = 0; i < values.length; ++i) {
                        final double t = time.get(i);
                        values[i] = params[0] +
                            params[1] * FastMath.exp(-t / params[3]) +
                            params[2] * FastMath.exp(-t / params[4]);
                    }
                    return values;
                }
            };
        }

        /**
         * @return a function that maps the 5 model parameters to the Jacobian
         * of the model: one row per registered time, one column per parameter.
         */
        public MultivariateMatrixFunction getModelFunctionJacobian() {
            return new MultivariateMatrixFunction() {
                public double[][] value(double[] params) {
                    final double[][] jacobian = new double[time.size()][5];

                    for (int i = 0; i < jacobian.length; ++i) {
                        final double t = time.get(i);
                        // Derivative with respect to the constant term.
                        jacobian[i][0] = 1;

                        final double p3 = params[3];
                        final double p4 = params[4];
                        final double tOp3 = t / p3;
                        final double tOp4 = t / p4;
                        // Each exponential is evaluated once and shared by the
                        // amplitude and the decay-constant derivatives.
                        final double exp3 = FastMath.exp(-tOp3);
                        final double exp4 = FastMath.exp(-tOp4);
                        jacobian[i][1] = exp3;
                        jacobian[i][2] = exp4;
                        // d/dp3 [p1 * exp(-t / p3)] = p1 * exp(-t / p3) * t / p3^2
                        jacobian[i][3] = params[1] * exp3 * tOp3 / p3;
                        jacobian[i][4] = params[2] * exp4 * tOp4 / p4;
                    }
                    return jacobian;
                }
            };
        }
    }
}


my book on functional programming

 

new blog posts

 

Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.

A percentage of advertising revenue from
pages under the /java/jwarehouse URI on this website is
paid back to open source projects.