|
Commons Math example source code file (OLSMultipleLinearRegressionTest.java)
The Commons Math OLSMultipleLinearRegressionTest.java source code/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.commons.math.stat.regression; import static org.junit.Assert.assertEquals; import org.apache.commons.math.TestUtils; import org.apache.commons.math.linear.DefaultRealMatrixChangingVisitor; import org.apache.commons.math.linear.MatrixUtils; import org.apache.commons.math.linear.MatrixVisitorException; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.Array2DRowRealMatrix; import org.junit.Before; import org.junit.Test; public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest { private double[] y; private double[][] x; @Before @Override public void setUp(){ y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0}; x = new double[6][]; x[0] = new double[]{1.0, 0, 0, 0, 0, 0}; x[1] = new double[]{1.0, 2.0, 0, 0, 0, 0}; x[2] = new double[]{1.0, 0, 3.0, 0, 0, 0}; x[3] = new double[]{1.0, 0, 0, 4.0, 0, 0}; x[4] = new double[]{1.0, 0, 0, 0, 5.0, 0}; x[5] = new double[]{1.0, 0, 0, 0, 0, 6.0}; super.setUp(); } @Override protected OLSMultipleLinearRegression createRegression() { OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); regression.newSampleData(y, x); return regression; } @Override protected int getNumberOfRegressors() { return x[0].length; } @Override protected int getSampleSize() { return y.length; } @Test(expected=IllegalArgumentException.class) public void cannotAddXSampleData() { createRegression().newSampleData(new double[]{}, null); } @Test(expected=IllegalArgumentException.class) public void cannotAddNullYSampleData() { createRegression().newSampleData(null, new double[][]{}); } @Test(expected=IllegalArgumentException.class) public void cannotAddSampleDataWithSizeMismatch() { double[] y = new double[]{1.0, 2.0}; double[][] x = new double[1][]; x[0] = new double[]{1.0, 0}; createRegression().newSampleData(y, x); } @Test public void testPerfectFit() { double[] betaHat = regression.estimateRegressionParameters(); TestUtils.assertEquals(betaHat, new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 }, 1e-14); double[] residuals = regression.estimateResiduals(); TestUtils.assertEquals(residuals, new double[]{0d,0d,0d,0d,0d,0d}, 1e-14); RealMatrix errors = new Array2DRowRealMatrix(regression.estimateRegressionParametersVariance(), false); final double[] s = { 1.0, -1.0 / 2.0, -1.0 / 3.0, -1.0 / 4.0, -1.0 / 5.0, -1.0 / 6.0 }; RealMatrix referenceVariance = new Array2DRowRealMatrix(s.length, s.length); referenceVariance.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() { @Override public double visit(int row, int column, double value) throws MatrixVisitorException { if (row == 0) { return s[column]; } double x = s[row] * s[column]; return (row == column) ? 2 * x : x; } }); assertEquals(0.0, errors.subtract(referenceVariance).getNorm(), 5.0e-16 * referenceVariance.getNorm()); } /** * Test Longley dataset against certified values provided by NIST. * Data Source: J. Longley (1967) "An Appraisal of Least Squares * Programs for the Electronic Computer from the Point of View of the User" * Journal of the American Statistical Association, vol. 62. September, * pp. 819-841. * * Certified values (and data) are from NIST: * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat */ @Test public void testLongly() { // Y values are first, then independent vars // Each row is one observation double[] design = new double[] { 60323,83.0,234289,2356,1590,107608,1947, 61122,88.5,259426,2325,1456,108632,1948, 60171,88.2,258054,3682,1616,109773,1949, 61187,89.5,284599,3351,1650,110929,1950, 63221,96.2,328975,2099,3099,112075,1951, 63639,98.1,346999,1932,3594,113270,1952, 64989,99.0,365385,1870,3547,115094,1953, 63761,100.0,363112,3578,3350,116219,1954, 66019,101.2,397469,2904,3048,117388,1955, 67857,104.6,419180,2822,2857,118734,1956, 68169,108.4,442769,2936,2798,120445,1957, 66513,110.8,444546,4681,2637,121950,1958, 68655,112.6,482704,3813,2552,123366,1959, 69564,114.2,502601,3931,2514,125368,1960, 69331,115.7,518173,4806,2572,127852,1961, 70551,116.9,554894,4007,2827,130081,1962 }; // Transform to Y and X required by interface int nobs = 16; int nvars = 6; // Estimate the model OLSMultipleLinearRegression model = new OLSMultipleLinearRegression(); model.newSampleData(design, nobs, nvars); // Check expected beta values from NIST double[] betaHat = model.estimateRegressionParameters(); TestUtils.assertEquals(betaHat, new double[]{-3482258.63459582, 15.0618722713733, -0.358191792925910E-01,-2.02022980381683, -1.03322686717359,-0.511041056535807E-01, 1829.15146461355}, 2E-8); // // Check expected residuals from R double[] residuals = model.estimateResiduals(); TestUtils.assertEquals(residuals, new double[]{ 267.340029759711,-94.0139423988359,46.28716775752924, -410.114621930906,309.7145907602313,-249.3112153297231, -164.0489563956039,-13.18035686637081,14.30477260005235, 455.394094551857,-17.26892711483297,-39.0550425226967, -155.5499735953195,-85.6713080421283,341.9315139607727, -206.7578251937366}, 1E-8); // Check standard errors from NIST double[] errors = model.estimateRegressionParametersStandardErrors(); TestUtils.assertEquals(new double[] {890420.383607373, 84.9149257747669, 0.334910077722432E-01, 0.488399681651699, 0.214274163161675, 0.226073200069370, 455.478499142212}, errors, 1E-6); } /** * Test R Swiss fertility dataset against R. * Data Source: R datasets package */ @Test public void testSwissFertility() { double[] design = new double[] { 80.2,17.0,15,12,9.96, 83.1,45.1,6,9,84.84, 92.5,39.7,5,5,93.40, 85.8,36.5,12,7,33.77, 76.9,43.5,17,15,5.16, 76.1,35.3,9,7,90.57, 83.8,70.2,16,7,92.85, 92.4,67.8,14,8,97.16, 82.4,53.3,12,7,97.67, 82.9,45.2,16,13,91.38, 87.1,64.5,14,6,98.61, 64.1,62.0,21,12,8.52, 66.9,67.5,14,7,2.27, 68.9,60.7,19,12,4.43, 61.7,69.3,22,5,2.82, 68.3,72.6,18,2,24.20, 71.7,34.0,17,8,3.30, 55.7,19.4,26,28,12.11, 54.3,15.2,31,20,2.15, 65.1,73.0,19,9,2.84, 65.5,59.8,22,10,5.23, 65.0,55.1,14,3,4.52, 56.6,50.9,22,12,15.14, 57.4,54.1,20,6,4.20, 72.5,71.2,12,1,2.40, 74.2,58.1,14,8,5.23, 72.0,63.5,6,3,2.56, 60.5,60.8,16,10,7.72, 58.3,26.8,25,19,18.46, 65.4,49.5,15,8,6.10, 75.5,85.9,3,2,99.71, 69.3,84.9,7,6,99.68, 77.3,89.7,5,2,100.00, 70.5,78.2,12,6,98.96, 79.4,64.9,7,3,98.22, 65.0,75.9,9,9,99.06, 92.2,84.6,3,3,99.46, 79.3,63.1,13,13,96.83, 70.4,38.4,26,12,5.62, 65.7,7.7,29,11,13.79, 72.7,16.7,22,13,11.22, 64.4,17.6,35,32,16.92, 77.6,37.6,15,7,4.97, 67.6,18.7,25,7,8.65, 35.0,1.2,37,53,42.34, 44.7,46.6,16,29,50.43, 42.8,27.7,22,29,58.33 }; // Transform to Y and X required by interface int nobs = 47; int nvars = 4; // Estimate the model OLSMultipleLinearRegression model = new OLSMultipleLinearRegression(); model.newSampleData(design, nobs, nvars); // Check expected beta values from R double[] betaHat = model.estimateRegressionParameters(); TestUtils.assertEquals(betaHat, new double[]{91.05542390271397, -0.22064551045715, -0.26058239824328, -0.96161238456030, 0.12441843147162}, 1E-12); // Check expected residuals from R double[] residuals = model.estimateResiduals(); TestUtils.assertEquals(residuals, new double[]{ 7.1044267859730512,1.6580347433531366, 4.6944952770029644,8.4548022690166160,13.6547432343186212, -9.3586864458500774,7.5822446330520386,15.5568995563859289, 0.8113090736598980,7.1186762732484308,7.4251378771228724, 2.6761316873234109,0.8351584810309354,7.1769991119615177, -3.8746753206299553,-3.1337779476387251,-0.1412575244091504, 1.1186809170469780,-6.3588097346816594,3.4039270429434074, 2.3374058329820175,-7.9272368576900503,-7.8361010968497959, -11.2597369269357070,0.9445333697827101,6.6544245101380328, -0.9146136301118665,-4.3152449403848570,-4.3536932047009183, -3.8907885169304661,-6.3027643926302188,-7.8308982189289091, -3.1792280015332750,-6.7167298771158226,-4.8469946718041754, -10.6335664353633685,11.1031134362036958,6.0084032641811733, 5.4326230830188482,-7.2375578629692230,2.1671550814448222, 15.0147574652763112,4.8625103516321015,-7.1597256413907706, -0.4515205619767598,-10.2916870903837587,-15.7812984571900063}, 1E-12); // Check standard errors from R double[] errors = model.estimateRegressionParametersStandardErrors(); TestUtils.assertEquals(new double[] {6.94881329475087, 0.07360008972340, 0.27410957467466, 0.19454551679325, 0.03726654773803}, errors, 1E-10); } /** * Test hat matrix computation * * @throws Exception */ @Test public void testHat() throws Exception { /* * This example is from "The Hat Matrix in Regression and ANOVA", * David C. Hoaglin and Roy E. Welsch, * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22. * */ double[] design = new double[] { 11.14, .499, 11.1, 12.74, .558, 8.9, 13.13, .604, 8.8, 11.51, .441, 8.9, 12.38, .550, 8.8, 12.60, .528, 9.9, 11.13, .418, 10.7, 11.7, .480, 10.5, 11.02, .406, 10.5, 11.41, .467, 10.7 }; int nobs = 10; int nvars = 2; // Estimate the model OLSMultipleLinearRegression model = new OLSMultipleLinearRegression(); model.newSampleData(design, nobs, nvars); RealMatrix hat = model.calculateHat(); // Reference data is upper half of symmetric hat matrix double[] referenceData = new double[] { .418, -.002, .079, -.274, -.046, .181, .128, .222, .050, .242, .242, .292, .136, .243, .128, -.041, .033, -.035, .004, .417, -.019, .273, .187, -.126, .044, -.153, .004, .604, .197, -.038, .168, -.022, .275, -.028, .252, .111, -.030, .019, -.010, -.010, .148, .042, .117, .012, .111, .262, .145, .277, .174, .154, .120, .168, .315, .148, .187 }; // Check against reference data and verify symmetry int k = 0; for (int i = 0; i < 10; i++) { for (int j = i; j < 10; j++) { assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3); assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12); k++; } } /* * Verify that residuals computed using the hat matrix are close to * what we get from direct computation, i.e. r = (I - H) y */ double[] residuals = model.estimateResiduals(); RealMatrix I = MatrixUtils.createRealIdentityMatrix(10); double[] hatResiduals = I.subtract(hat).operate(model.Y).getData(); TestUtils.assertEquals(residuals, hatResiduals, 10e-12); } } Other Commons Math examples (source code examples)Here is a short list of links related to this Commons Math OLSMultipleLinearRegressionTest.java source code file: |
... this post is sponsored by my books ... | |
#1 New Release! |
FP Best Seller |
Copyright 1998-2021 Alvin Alexander, alvinalexander.com
All Rights Reserved.
A percentage of advertising revenue from
pages under the /java/jwarehouse
URI on this website is
paid back to open source projects.