hama-commits mailing list archives

From Apache Wiki <wikidi...@apache.org>
Subject [Hama Wiki] Trivial Update of "MultiLayerPerceptron" by YexiJiang
Date Sun, 20 Oct 2013 15:44:55 GMT
Dear Wiki user,

You have subscribed to a wiki page or wiki category on "Hama Wiki" for change notification.

The "MultiLayerPerceptron" page has been changed by YexiJiang:
https://wiki.apache.org/hama/MultiLayerPerceptron?action=diff&rev1=26&rev2=27

  
  The following sample code shows how to initialize a model and start training it.
  {{{
-     String modelPath = "/tmp/xorModel-training-by-xor.data";
-     double learningRate = 0.6;
-     double regularization = 0.02; // no regularization
-     double momentum = 0.3; // no momentum
-     String squashingFunctionName = "Tanh";
-     String costFunctionName = "SquaredError";
-     int[] layerSizeArray = new int[] { 2, 5, 1 };
-     SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
-         regularization, momentum, squashingFunctionName, costFunctionName,
-         layerSizeArray);
+   SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
+ 
+   ann.setLearningRate(0.1); // set the learning rate
+   ann.setMomemtumWeight(0.1); // set the momentum weight
+ 
+   // initialize the topology of the model; this example creates a three-layer model
+   ann.addLayer(featureDimension, false, FunctionFactory.createDoubleFunction("Sigmoid"));
+   ann.addLayer(featureDimension, false, FunctionFactory.createDoubleFunction("Sigmoid"));
+   ann.addLayer(labelDimension, true, FunctionFactory.createDoubleFunction("Sigmoid"));
+ 
+   // set the cost function to evaluate the error
+   ann.setCostFunction(FunctionFactory.createDoubleDoubleFunction("CrossEntropy")); 
+   String trainedModelPath = ...;
+   ann.setModelPath(trainedModelPath); // set the path to store the trained model
+ 
+   // add training parameters
+   Map<String, String> trainingParameters = new HashMap<String, String>();
+   trainingParameters.put("tasks", "5"); // the number of concurrent tasks
+   trainingParameters.put("training.max.iterations", "" + iteration); // the maximum number of iterations
+   trainingParameters.put("training.batch.size", "300"); // the number of training instances read per update
+   ann.train(new Path(trainingDataPath), trainingParameters);
  }}}
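
The snippet sets a learning rate of 0.1 and a momentum weight of 0.1, uses Sigmoid as the squashing function, and uses CrossEntropy as the cost function. As a rough illustration of what these settings usually mean, the following self-contained Java sketch trains a single sigmoid unit on logical OR with the classical momentum rule, delta = -learningRate * gradient + momentum * previousDelta. It is a hypothetical example, not Hama's code: the class MomentumSketch and all of its methods are made up for illustration, and Hama's own update rule may differ in detail.

```java
// Hypothetical sketch: one sigmoid unit trained with cross-entropy and momentum.
// It illustrates the settings used above; it is NOT Hama's implementation.
final class MomentumSketch {

    // Logistic sigmoid: 1 / (1 + e^(-z)).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Binary cross-entropy of a prediction y in (0, 1) against a target t in {0, 1}.
    static double crossEntropy(double t, double y) {
        return -(t * Math.log(y) + (1.0 - t) * Math.log(1.0 - y));
    }

    // Output of the unit; the last input is a constant 1 that acts as the bias input.
    static double predict(double[] w, double[] x) {
        double z = 0.0;
        for (int i = 0; i < w.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    // One momentum step: delta = -learningRate * grad + momentum * previousDelta.
    // Updates w and previousDelta in place.
    static void momentumStep(double[] w, double[] grad, double[] previousDelta,
                             double learningRate, double momentum) {
        for (int i = 0; i < w.length; i++) {
            double delta = -learningRate * grad[i] + momentum * previousDelta[i];
            w[i] += delta;
            previousDelta[i] = delta;
        }
    }

    // Logical OR; the third column is the constant bias input.
    static final double[][] OR_INPUTS = {{0, 0, 1}, {0, 1, 1}, {1, 0, 1}, {1, 1, 1}};
    static final double[] OR_TARGETS = {0, 1, 1, 1};

    // Mean cross-entropy over the four OR instances.
    static double meanLoss(double[] w) {
        double sum = 0.0;
        for (int n = 0; n < OR_TARGETS.length; n++) {
            sum += crossEntropy(OR_TARGETS[n], predict(w, OR_INPUTS[n]));
        }
        return sum / OR_TARGETS.length;
    }

    // Per-instance training with learning rate 0.1 and momentum 0.1.
    static double[] trainOr(int epochs) {
        double[] w = new double[3];
        double[] previousDelta = new double[3];
        double[] grad = new double[3];
        for (int e = 0; e < epochs; e++) {
            for (int n = 0; n < OR_TARGETS.length; n++) {
                double y = predict(w, OR_INPUTS[n]);
                // For sigmoid + cross-entropy, dE/dw_i simplifies to (y - t) * x_i.
                for (int i = 0; i < w.length; i++) {
                    grad[i] = (y - OR_TARGETS[n]) * OR_INPUTS[n][i];
                }
                momentumStep(w, grad, previousDelta, 0.1, 0.1);
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] w = trainOr(2000);
        System.out.println("loss before: " + meanLoss(new double[3]));
        System.out.println("loss after:  " + meanLoss(w));
    }
}
```

Running main prints the mean cross-entropy for the all-zero initial weights (ln 2, about 0.693) and for the trained weights; the second value should be noticeably smaller. The momentum term reuses a fraction of the previous step, which smooths successive updates.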
  
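The training parameters control how the data is consumed: "tasks" sets the number of concurrent tasks, "training.max.iterations" caps the number of iterations, and "training.batch.size" sets how many training instances are read per update. The sketch below shows one plausible reading of these settings as synchronous data-parallel mini-batch gradient descent: in each iteration every task reads batchSize instances from its own partition, the gradients are averaged over all instances read, and one shared update is applied. This is an assumption made for illustration. The class MiniBatchSketch, the modulo partitioning, and applying the batch size per task are hypothetical, and Hama's distributed trainer may split the work differently. A one-parameter linear model with squared error is swapped in for the neural network to keep the example short.

```java
// Hypothetical sketch of synchronous data-parallel mini-batch gradient descent.
// Not Hama's trainer; names and the partitioning scheme are illustrative assumptions.
final class MiniBatchSketch {

    // Fits y ~ w * x by minimizing the mean of 0.5 * (w * x - y)^2.
    // Task t owns the instances whose index i satisfies i % tasks == t and, in each
    // iteration, reads batchSize of them, cycling through its partition.
    static double train(double[] xs, double[] ys, int tasks, int maxIterations,
                        int batchSize, double learningRate) {
        double w = 0.0;
        int[] cursor = new int[tasks]; // next read position inside each task's partition
        for (int iter = 0; iter < maxIterations; iter++) {
            double gradientSum = 0.0;
            int instancesRead = 0;
            for (int task = 0; task < tasks; task++) {
                // Number of indices i in [0, xs.length) with i % tasks == task.
                int partitionSize = (xs.length - task + tasks - 1) / tasks;
                for (int k = 0; k < batchSize && partitionSize > 0; k++) {
                    int i = task + tasks * (cursor[task] % partitionSize);
                    cursor[task]++;
                    gradientSum += (w * xs[i] - ys[i]) * xs[i]; // d/dw of 0.5 * (w*x - y)^2
                    instancesRead++;
                }
            }
            if (instancesRead == 0) {
                break; // no training data at all
            }
            // One shared update per iteration, using the gradient averaged over all tasks.
            w -= learningRate * gradientSum / instancesRead;
        }
        return w;
    }

    public static void main(String[] args) {
        int n = 1000;
        double[] xs = new double[n];
        double[] ys = new double[n];
        for (int i = 0; i < n; i++) {
            xs[i] = i / (double) n; // inputs in [0, 1)
            ys[i] = 2.0 * xs[i];    // targets on the line y = 2x
        }
        // 5 tasks, 100 iterations, 300 instances per task per iteration, learning rate 0.5
        double w = train(xs, ys, 5, 100, 300, 0.5);
        System.out.println("learned w = " + w);
    }
}
```

With the synthetic points generated in main (y = 2x), the printed weight should be very close to 2.0. Averaging over every instance read in an iteration keeps the step size independent of how many tasks are used.
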
  === Two class learning problem ===
