...
| Parameter | Description |
|---|---|
| model path | The path specifying where to store the model. |
| learningRate | Controls how aggressively the model learns. A large learning rate can accelerate training, but it may also prevent the model from converging. |
| regularization | Controls the complexity of the model. A large regularization value keeps the weights small, which helps prevent overfitting. |
| momentum | Controls the speed of training. A large momentum can accelerate training, but it may also overshoot the optimum. |
| squashing function | The activation function used by the MLP. Candidate squashing functions: sigmoid, tanh. |
| cost function | Evaluates the error made during training. Candidate cost functions: squared error, cross entropy (logistic). |
| layer size array | An array specifying the number of neurons (excluding bias neurons) in each layer (including the input and output layers). |
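The candidate squashing and cost functions in the table can be sketched in plain Java. This is an illustrative stand-alone sketch, not Hama's `FunctionFactory` implementation; the class and method names here are hypothetical, and the squared error uses the common 1/2 scaling convention.

```java
// Illustrative sketch of the candidate squashing and cost functions
// (hypothetical class; not part of the Hama API).
public class FunctionSketch {

    // Sigmoid squashing function: maps any real input into (0, 1).
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Tanh squashing function: maps any real input into (-1, 1).
    static double tanh(double x) {
        return Math.tanh(x);
    }

    // Squared error between expected label t and actual output o
    // (using the common 1/2 scaling convention).
    static double squaredError(double t, double o) {
        double d = t - o;
        return 0.5 * d * d;
    }

    // Cross-entropy (logistic) cost; assumes o lies strictly in (0, 1).
    static double crossEntropy(double t, double o) {
        return -(t * Math.log(o) + (1 - t) * Math.log(1 - o));
    }

    public static void main(String[] args) {
        System.out.println("sigmoid(0)          = " + sigmoid(0.0));           // 0.5
        System.out.println("tanh(0)             = " + tanh(0.0));              // 0.0
        System.out.println("squaredError(1,0.8) = " + squaredError(1.0, 0.8)); // 0.02
        System.out.println("crossEntropy(1,0.8) = " + crossEntropy(1.0, 0.8));
    }
}
```

Note that cross entropy pairs naturally with a sigmoid output layer, since it assumes outputs in (0, 1).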
The following sample code shows how to initialize a model.
```java
// Initialize a model with the constructor-based SmallMultiLayerPerceptron API.
String modelPath = "/tmp/xorModel-training-by-xor.data";
double learningRate = 0.6;
double regularization = 0.02;
double momentum = 0.3;
String squashingFunctionName = "Tanh";
String costFunctionName = "SquaredError";
int[] layerSizeArray = new int[] { 2, 5, 1 };
SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
    regularization, momentum, squashingFunctionName, costFunctionName, layerSizeArray);
```

```java
// Alternatively, build the model with the SmallLayeredNeuralNetwork API.
SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
ann.setLearningRate(0.1); // set the learning rate
ann.setMomemtumWeight(0.1); // set the momentum weight
// initialize the topology of the model; a three-layer model is created in this example
ann.addLayer(featureDimension, false, FunctionFactory.createDoubleFunction("Sigmoid"));
ann.addLayer(featureDimension, false, FunctionFactory.createDoubleFunction("Sigmoid"));
ann.addLayer(labelDimension, true, FunctionFactory.createDoubleFunction("Sigmoid"));
// set the cost function to evaluate the error
ann.setCostFunction(FunctionFactory.createDoubleDoubleFunction("CrossEntropy"));
String trainedModelPath = ...;
ann.setModelPath(trainedModelPath); // set the path to store the trained model
// add training parameters
Map<String, String> trainingParameters = new HashMap<String, String>();
trainingParameters.put("tasks", "5"); // the number of concurrent tasks
trainingParameters.put("training.max.iterations", "" + iteration); // the maximum number of iterations
trainingParameters.put("training.batch.size", "300"); // the number of training instances read per update
ann.train(new Path(trainingDataPath), trainingParameters);
```
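To make the layer size array concrete, the following stand-alone sketch runs one feed-forward pass through the `{2, 5, 1}` topology used above. It is a hypothetical illustration, not Hama code: the class, method names, and random weight initialization are assumptions, and it shows why each weight matrix needs one extra column for the bias neuron that the layer size array deliberately excludes.

```java
import java.util.Random;

// Hypothetical stand-alone sketch (not the Hama implementation): a single
// feed-forward pass through the three-layer topology {2, 5, 1}.
public class ForwardPassSketch {
    static final int[] LAYER_SIZES = {2, 5, 1};

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Propagate an input vector through all layers; weights[l][j][0] is the
    // bias weight of neuron j in layer l + 1, and weights[l][j][i + 1]
    // connects neuron i of layer l to neuron j of layer l + 1.
    static double[] forward(double[] input, double[][][] weights) {
        double[] activation = input;
        for (double[][] layer : weights) {
            double[] next = new double[layer.length];
            for (int j = 0; j < layer.length; j++) {
                double sum = layer[j][0]; // bias contribution
                for (int i = 0; i < activation.length; i++) {
                    sum += layer[j][i + 1] * activation[i];
                }
                next[j] = sigmoid(sum);
            }
            activation = next;
        }
        return activation;
    }

    // Small random weights; each row has LAYER_SIZES[l] + 1 entries (+1 for bias).
    static double[][][] randomWeights(long seed) {
        Random rnd = new Random(seed);
        double[][][] weights = new double[LAYER_SIZES.length - 1][][];
        for (int l = 0; l < weights.length; l++) {
            weights[l] = new double[LAYER_SIZES[l + 1]][LAYER_SIZES[l] + 1];
            for (double[] row : weights[l]) {
                for (int i = 0; i < row.length; i++) {
                    row[i] = rnd.nextDouble() - 0.5;
                }
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        double[] output = forward(new double[] {0.0, 1.0}, randomWeights(42L));
        System.out.println("output size = " + output.length); // matches the output layer size, 1
        System.out.println("output[0]   = " + output[0]);     // sigmoid keeps it in (0, 1)
    }
}
```

The input vector's length must match the first entry of the layer size array, and the output vector's length matches the last entry.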
Two-class learning problem
...