commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From er...@apache.org
Subject [2/2] [math] MATH-1263
Date Sun, 30 Aug 2015 16:21:32 GMT
MATH-1263

Accessor to get neighbouring neurons (in a square grid).


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/46e97d9e
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/46e97d9e
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/46e97d9e

Branch: refs/heads/master
Commit: 46e97d9e04591e8198ff38bd2949fd5514da0d3e
Parents: 47fa07e
Author: Gilles <erans@apache.org>
Authored: Sun Aug 30 18:20:19 2015 +0200
Committer: Gilles <erans@apache.org>
Committed: Sun Aug 30 18:20:19 2015 +0200

----------------------------------------------------------------------
 src/changes/changes.xml                         |   3 +
 .../ml/neuralnet/twod/NeuronSquareMesh2D.java   | 132 +++++++++++++++
 .../neuralnet/twod/NeuronSquareMesh2DTest.java  | 163 +++++++++++++++++++
 3 files changed, 298 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/46e97d9e/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 52297c0..2dc0628 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
     </release>
 
     <release version="4.0" date="XXXX-XX-XX" description="">
+      <action dev="erans" type="add" issue="MATH-1263">
+        Accessor to get neighbouring neurons (class "o.a.c.m.ml.neuralnet.twod.NeuronSquareMesh2D").
+      </action>
       <action dev="erans" type="add" issue="MATH-1259">
         New "IntegerSequence" class (in package "o.a.c.m.util") with "Incrementor" inner class.
       </action>

http://git-wip-us.apache.org/repos/asf/commons-math/blob/46e97d9e/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
index 3463b75..dd6cf6a 100644
--- a/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
+++ b/src/main/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2D.java
@@ -64,6 +64,29 @@ public class NeuronSquareMesh2D implements Serializable {
     private final long[][] identifiers;
 
     /**
+     * Horizontal (along row) direction.
+     */
+    public enum HorizontalDirection {
+        /** Column at the right of the current column (column index + 1). */
+        RIGHT,
+        /** Current column (column index unchanged). */
+        CENTER,
+        /** Column at the left of the current column (column index - 1). */
+        LEFT,
+    }
+
+    /**
+     * Vertical (along column) direction.
+     */
+    public enum VerticalDirection {
+        /** Row above the current row (row index - 1). */
+        UP,
+        /** Current row (row index unchanged). */
+        CENTER,
+        /** Row below the current row (row index + 1). */
+        DOWN,
+    }
+
+    /**
      * Constructor with restricted access, solely used for deserialization.
      *
      * @param wrapRowDim Whether to wrap the first dimension (i.e the first
@@ -205,12 +228,16 @@ public class NeuronSquareMesh2D implements Serializable {
 
     /**
      * Retrieves the neuron at location {@code (i, j)} in the map.
+     * The neuron at position {@code (0, 0)} is located at the upper-left
+     * corner of the map.
      *
      * @param i Row index.
      * @param j Column index.
      * @return the neuron at {@code (i, j)}.
      * @throws OutOfRangeException if {@code i} or {@code j} is
      * out of range.
+     *
+     * @see #getNeuron(int,int,HorizontalDirection,VerticalDirection)
      */
     public Neuron getNeuron(int i,
                             int j) {
@@ -227,6 +254,111 @@ public class NeuronSquareMesh2D implements Serializable {
     }
 
     /**
+     * Retrieves the requested neuron relative to the given {@code (row, col)}
+     * position.
+     * The neuron at position {@code (0, 0)} is located at the upper-left
+     * corner of the map.
+     *
+     * @param row Row index.
+     * @param col Column index.
+     * @param alongRowDir Direction along the given {@code row} (i.e. an
+     * offset will be added to the given <em>column</em> index).
+     * @param alongColDir Direction along the given {@code col} (i.e. an
+     * offset will be added to the given <em>row</em> index).
+     * @return the neuron at the requested location, or {@code null} if
+     * the location is not on the map.
+     * @throws OutOfRangeException if {@code row} or {@code col} is
+     * out of range (i.e. the starting position is not on the map).
+     *
+     * @see #getNeuron(int,int)
+     */
+    public Neuron getNeuron(int row,
+                            int col,
+                            HorizontalDirection alongRowDir,
+                            VerticalDirection alongColDir) {
+        // Reject an off-map starting position (same check and exception as
+        // "getNeuron(int,int)"). Otherwise it could be silently mapped onto
+        // the map: e.g. row -1 moved DOWN yields row 0, and along a wrapped
+        // dimension an overflowing index is folded back by "%".
+        getNeuron(row, col);
+
+        final int[] location = getLocation(row, col, alongRowDir, alongColDir);
+
+        return location == null ? null : getNeuron(location[0], location[1]);
+    }
+
+    /**
+     * Computes the location of a neighbouring neuron.
+     * Returns {@code null} if the resulting location is not part
+     * of the map.
+     * Position {@code (0, 0)} is at the upper-left corner of the map.
+     * Along a wrapped dimension, stepping past one edge leads to the
+     * opposite edge (provided the starting index is itself on the map).
+     *
+     * @param row Row index.
+     * @param col Column index.
+     * @param alongRowDir Direction along the given {@code row} (i.e. an
+     * offset will be added to the given <em>column</em> index).
+     * @param alongColDir Direction along the given {@code col} (i.e. an
+     * offset will be added to the given <em>row</em> index).
+     * @return an array of length 2 containing the indices of the requested
+     * location, or {@code null} if that location is not part of the map.
+     *
+     * @see #getNeuron(int,int)
+     */
+    private int[] getLocation(int row,
+                              int col,
+                              HorizontalDirection alongRowDir,
+                              VerticalDirection alongColDir) {
+        final int colOffset;
+        switch (alongRowDir) {
+        case LEFT:
+            colOffset = -1;
+            break;
+        case RIGHT:
+            colOffset = 1;
+            break;
+        case CENTER:
+            colOffset = 0;
+            break;
+        default:
+            // Should never happen.
+            throw new MathInternalError();
+        }
+        int colIndex = col + colOffset;
+        if (wrapColumns) {
+            // The offset is at most one step, so a single correction brings
+            // an index that started on the map back onto the map.
+            if (colIndex < 0) {
+                colIndex += numberOfColumns;
+            } else {
+                colIndex %= numberOfColumns;
+            }
+        }
+
+        final int rowOffset;
+        switch (alongColDir) {
+        case UP:
+            rowOffset = -1;
+            break;
+        case DOWN:
+            rowOffset = 1;
+            break;
+        case CENTER:
+            rowOffset = 0;
+            break;
+        default:
+            // Should never happen.
+            throw new MathInternalError();
+        }
+        int rowIndex = row + rowOffset;
+        if (wrapRows) {
+            // Same single-step correction as for columns.
+            if (rowIndex < 0) {
+                rowIndex += numberOfRows;
+            } else {
+                rowIndex %= numberOfRows;
+            }
+        }
+
+        if (rowIndex < 0 ||
+            rowIndex >= numberOfRows ||
+            colIndex < 0 ||
+            colIndex >= numberOfColumns) {
+            return null;
+        } else {
+            return new int[] { rowIndex, colIndex };
+        }
+    }
+
+    /**
      * Creates the neighbour relationships between neurons.
      */
     private void createLinks() {

http://git-wip-us.apache.org/repos/asf/commons-math/blob/46e97d9e/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2DTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2DTest.java b/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2DTest.java
index f5395b4..a48a1ce 100644
--- a/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2DTest.java
+++ b/src/test/java/org/apache/commons/math4/ml/neuralnet/twod/NeuronSquareMesh2DTest.java
@@ -26,6 +26,7 @@ import java.util.Collection;
 import java.util.HashSet;
 
 import org.apache.commons.math4.exception.NumberIsTooSmallException;
+import org.apache.commons.math4.exception.OutOfRangeException;
 import org.apache.commons.math4.ml.neuralnet.FeatureInitializer;
 import org.apache.commons.math4.ml.neuralnet.FeatureInitializerFactory;
 import org.apache.commons.math4.ml.neuralnet.Network;
@@ -682,4 +683,166 @@ public class NeuronSquareMesh2DTest {
             }
         }
     }
+
+    /*
+     * Expected layout of the 2x2 network built below (neuron identifiers):
+     *
+     *  0-----1
+     *  |     |
+     *  |     |
+     *  2-----3
+     */
+    @Test
+    public void testGetNeuron() {
+        final FeatureInitializer[] initArray = { init };
+        final NeuronSquareMesh2D net = new NeuronSquareMesh2D(2, false,
+                                                              2, true,
+                                                              SquareNeighbourhood.VON_NEUMANN,
+                                                              initArray);
+        // Identifiers are expected to be assigned row by row.
+        for (int r = 0; r < 2; r++) {
+            for (int c = 0; c < 2; c++) {
+                Assert.assertEquals(2 * r + c, net.getNeuron(r, c).getIdentifier());
+            }
+        }
+
+        // Indices outside the map must be rejected.
+        final int[][] outside = { { 2, 0 }, { 0, 2 }, { -1, 0 }, { 0, -1 } };
+        for (final int[] location : outside) {
+            try {
+                net.getNeuron(location[0], location[1]);
+                Assert.fail("exception expected");
+            } catch (OutOfRangeException e) {
+                // Expected.
+            }
+        }
+    }
+
+    /*
+     * Test assumes that the network is
+     *
+     *  0-----1-----2
+     *  |     |     |
+     *  |     |     |
+     *  3-----4-----5
+     *  |     |     |
+     *  |     |     |
+     *  6-----7-----8
+     */
+    @Test
+    public void testGetNeuronAlongDirection() {
+        final FeatureInitializer[] initArray = { init };
+        final NeuronSquareMesh2D net = new NeuronSquareMesh2D(3, false,
+                                                              3, false,
+                                                              SquareNeighbourhood.VON_NEUMANN,
+                                                              initArray);
+        Assert.assertEquals(0, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.LEFT,
+                                             NeuronSquareMesh2D.VerticalDirection.UP).getIdentifier());
+        Assert.assertEquals(1, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.CENTER,
+                                             NeuronSquareMesh2D.VerticalDirection.UP).getIdentifier());
+        Assert.assertEquals(2, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.RIGHT,
+                                             NeuronSquareMesh2D.VerticalDirection.UP).getIdentifier());
+        Assert.assertEquals(3, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.LEFT,
+                                             NeuronSquareMesh2D.VerticalDirection.CENTER).getIdentifier());
+        Assert.assertEquals(4, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.CENTER,
+                                             NeuronSquareMesh2D.VerticalDirection.CENTER).getIdentifier());
+        Assert.assertEquals(5, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.RIGHT,
+                                             NeuronSquareMesh2D.VerticalDirection.CENTER).getIdentifier());
+        Assert.assertEquals(6, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.LEFT,
+                                             NeuronSquareMesh2D.VerticalDirection.DOWN).getIdentifier());
+        Assert.assertEquals(7, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.CENTER,
+                                             NeuronSquareMesh2D.VerticalDirection.DOWN).getIdentifier());
+        Assert.assertEquals(8, net.getNeuron(1, 1,
+                                             NeuronSquareMesh2D.HorizontalDirection.RIGHT,
+                                             NeuronSquareMesh2D.VerticalDirection.DOWN).getIdentifier());
+
+        // Locations not in map.
+        Assert.assertNull(net.getNeuron(0, 1,
+                                        NeuronSquareMesh2D.HorizontalDirection.CENTER,
+                                        NeuronSquareMesh2D.VerticalDirection.UP));
+        Assert.assertNull(net.getNeuron(1, 0,
+                                        NeuronSquareMesh2D.HorizontalDirection.LEFT,
+                                        NeuronSquareMesh2D.VerticalDirection.CENTER));
+        Assert.assertNull(net.getNeuron(2, 1,
+                                        NeuronSquareMesh2D.HorizontalDirection.CENTER,
+                                        NeuronSquareMesh2D.VerticalDirection.DOWN));
+        Assert.assertNull(net.getNeuron(1, 2,
+                                        NeuronSquareMesh2D.HorizontalDirection.RIGHT,
+                                        NeuronSquareMesh2D.VerticalDirection.CENTER));
+    }
+
+    /*
+     * Expected layout of the 3x3 (wrapped) network built below:
+     *
+     *  0-----1-----2
+     *  |     |     |
+     *  |     |     |
+     *  3-----4-----5
+     *  |     |     |
+     *  |     |     |
+     *  6-----7-----8
+     */
+    @Test
+    public void testGetNeuronAlongDirectionWrappedMap() {
+        final FeatureInitializer[] initArray = { init };
+        final NeuronSquareMesh2D net = new NeuronSquareMesh2D(3, true,
+                                                              3, true,
+                                                              SquareNeighbourhood.VON_NEUMANN,
+                                                              initArray);
+        final NeuronSquareMesh2D.HorizontalDirection left = NeuronSquareMesh2D.HorizontalDirection.LEFT;
+        final NeuronSquareMesh2D.HorizontalDirection hCenter = NeuronSquareMesh2D.HorizontalDirection.CENTER;
+        final NeuronSquareMesh2D.HorizontalDirection right = NeuronSquareMesh2D.HorizontalDirection.RIGHT;
+        final NeuronSquareMesh2D.VerticalDirection up = NeuronSquareMesh2D.VerticalDirection.UP;
+        final NeuronSquareMesh2D.VerticalDirection vCenter = NeuronSquareMesh2D.VerticalDirection.CENTER;
+        final NeuronSquareMesh2D.VerticalDirection down = NeuronSquareMesh2D.VerticalDirection.DOWN;
+
+        // Starting from the upper-left corner (0, 0).
+        // No wrapping.
+        Assert.assertEquals(3, net.getNeuron(0, 0, hCenter, down).getIdentifier());
+        // With wrapping.
+        Assert.assertEquals(2, net.getNeuron(0, 0, left, vCenter).getIdentifier());
+        Assert.assertEquals(7, net.getNeuron(0, 0, right, up).getIdentifier());
+        Assert.assertEquals(8, net.getNeuron(0, 0, left, up).getIdentifier());
+        Assert.assertEquals(6, net.getNeuron(0, 0, hCenter, up).getIdentifier());
+        Assert.assertEquals(5, net.getNeuron(0, 0, left, down).getIdentifier());
+
+        // Starting from the middle of the right edge (1, 2).
+        // No wrapping.
+        Assert.assertEquals(1, net.getNeuron(1, 2, left, up).getIdentifier());
+        // With wrapping.
+        Assert.assertEquals(0, net.getNeuron(1, 2, right, up).getIdentifier());
+        Assert.assertEquals(3, net.getNeuron(1, 2, right, vCenter).getIdentifier());
+        Assert.assertEquals(6, net.getNeuron(1, 2, right, down).getIdentifier());
+    }
 }


Mime
View raw message