Author: luc
Date: Wed Sep 12 11:22:56 2012
New Revision: 1383885

URL: http://svn.apache.org/viewvc?rev=1383885&view=rev
Log:
Added a wrapper class to compute the gradient of a differentiable function.

Added:
    
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/analysis/differentiation/GradientFunction.java
   (with props)
    
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/analysis/differentiation/GradientFunctionTest.java
   (with props)

Added: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/analysis/differentiation/GradientFunction.java
URL: 
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/analysis/differentiation/GradientFunction.java?rev=1383885&view=auto
==============================================================================
--- 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/analysis/differentiation/GradientFunction.java
 (added)
+++ 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/analysis/differentiation/GradientFunction.java
 Wed Sep 12 11:22:56 2012
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.analysis.differentiation;
+
+import org.apache.commons.math3.analysis.MultivariateVectorFunction;
+
+/** Class representing the gradient of a multivariate function.
+ * <p>
+ * The vectorial components of the function represent the derivatives
+ * with respect to each function parameters.
+ * </p>
+ * @version $Id$
+ * @since 3.1
+ */
+public class GradientFunction implements MultivariateVectorFunction {
+
+    /** Underlying real-valued function. */
+    private final MultivariateDifferentiableFunction f;
+
+    /** Simple constructor.
+     * @param f underlying real-valued function
+     */
+    public GradientFunction(final MultivariateDifferentiableFunction f) {
+        this.f = f;
+    }
+
+    /** {@inheritDoc} */
+    public double[] value(double[] point)
+        throws IllegalArgumentException {
+
+        // set up parameters
+        final DerivativeStructure[] dsX = new 
DerivativeStructure[point.length];
+        for (int i = 0; i < point.length; ++i) {
+            dsX[i] = new DerivativeStructure(point.length, 1, i, point[i]);
+        }
+
+        // compute the derivatives
+        final DerivativeStructure dsY = f.value(dsX);
+
+        // extract the gradient
+        final double[] y = new double[point.length];
+        final int[] orders = new int[point.length];
+        for (int i = 0; i < point.length; ++i) {
+            orders[i] = 1;
+            y[i] = dsY.getPartialDerivative(orders);
+            orders[i] = 0;
+        }
+
+        return y;
+
+    }
+
+}

Propchange: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/analysis/differentiation/GradientFunction.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/analysis/differentiation/GradientFunction.java
------------------------------------------------------------------------------
    svn:keywords = "Author Date Id Revision"

Added: 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/analysis/differentiation/GradientFunctionTest.java
URL: 
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/analysis/differentiation/GradientFunctionTest.java?rev=1383885&view=auto
==============================================================================
--- 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/analysis/differentiation/GradientFunctionTest.java
 (added)
+++ 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/analysis/differentiation/GradientFunctionTest.java
 Wed Sep 12 11:22:56 2012
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.math3.analysis.differentiation;
+
+import org.apache.commons.math3.TestUtils;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.MathIllegalArgumentException;
+import org.apache.commons.math3.util.FastMath;
+import org.junit.Test;
+
+
+/**
+ * Test for class {@link GradientFunction}.
+ */
+public class GradientFunctionTest {
+
+    @Test
+    public void test2DDistance() {
+        EuclideanDistance f = new EuclideanDistance();
+        GradientFunction g = new GradientFunction(f);
+        for (double x = -10; x < 10; x += 0.5) {
+            for (double y = -10; y < 10; y += 0.5) {
+                double[] point = new double[] { x, y };
+                TestUtils.assertEquals(f.gradient(point), g.value(point), 
1.0e-15);
+            }
+        }
+    }
+
+    @Test
+    public void test3DDistance() {
+        EuclideanDistance f = new EuclideanDistance();
+        GradientFunction g = new GradientFunction(f);
+        for (double x = -10; x < 10; x += 0.5) {
+            for (double y = -10; y < 10; y += 0.5) {
+                for (double z = -10; z < 10; z += 0.5) {
+                    double[] point = new double[] { x, y, z };
+                    TestUtils.assertEquals(f.gradient(point), g.value(point), 
1.0e-15);
+                }
+            }
+        }
+    }
+
+    private static class EuclideanDistance implements 
MultivariateDifferentiableFunction {
+        
+        public double value(double[] point) {
+            double d2 = 0;
+            for (double x : point) {
+                d2 += x * x;
+            }
+            return FastMath.sqrt(d2);
+        }
+        
+        public DerivativeStructure value(DerivativeStructure[] point)
+            throws DimensionMismatchException, MathIllegalArgumentException {
+            DerivativeStructure d2 = point[0].getField().getZero();
+            for (DerivativeStructure x : point) {
+                d2 = d2.add(x.multiply(x));
+            }
+            return d2.sqrt();
+        }
+
+        public double[] gradient(double[] point) {
+            double[] gradient = new double[point.length];
+            double d = value(point);
+            for (int i = 0; i < point.length; ++i) {
+                gradient[i] = point[i] / d;
+            }
+            return gradient;
+        }
+
+    }
+
+}

Propchange: 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/analysis/differentiation/GradientFunctionTest.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/analysis/differentiation/GradientFunctionTest.java
------------------------------------------------------------------------------
    svn:keywords = "Author Date Id Revision"


Reply via email to