-
Notifications
You must be signed in to change notification settings - Fork 0
/
NeuralNetworkT.java
122 lines (84 loc) · 3.62 KB
/
NeuralNetworkT.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package neuralnetwork;
public class NeuralNetworkT {
int inputsNodes; //Number of inputs for the Network
int hiddenNodes; // Number of hidden layers
int outputsNodes; // Number of outputs in the netwoek
Matrix weightsIH; // Weight matrix for each layers hidden -> Output
Matrix weightsHO; // Weight matrix for each layers hidden -> Output
Matrix biasH; // Bias of each layer
Matrix biasO; // Bias of each layer
double learningRate = 0.1; // learning rate
class Sigmoid implements MatrixMethod{
//Sigmoid function for making the matrix between 0 and 1
@Override
public double change(double n) {
return (1/( 1 + Math.pow(Math.E,(-1*n))));
}
}
class DSigmoid implements MatrixMethod{
//Derivative of sigmoid for already sigmoided inputs
@Override
public double change(double n) {
return n * (1-n);
}
}
MatrixMethod sigmoid = new Sigmoid();
MatrixMethod dSigmoid = new DSigmoid();
//Contructor for creating Neural Network
public NeuralNetworkT(int inputs, int hiddenNodes, int outputs){
this.inputsNodes = inputs;
this.hiddenNodes = hiddenNodes;
this.outputsNodes = outputs;
weightsIH = new Matrix(hiddenNodes, inputsNodes);
weightsHO = new Matrix(outputsNodes, hiddenNodes);
biasH = new Matrix(hiddenNodes, 1);
biasO = new Matrix(outputsNodes, 1);
weightsIH.randomize(); // Weight matrix for each layers hidden -> Output
weightsHO.randomize(); // Weight matrix for each layers hidden -> Output
biasH.randomize();
biasO.randomize();
}
//Train the Network -> 1 input at a time
public void train(double[] input, double[] answers) {
//To check inputs and answers are according to the networks requirement
Matrix inputs = Matrix.fromArray(input);
Matrix hiddenOut = Matrix.multiply(weightsIH, inputs);
hiddenOut.add(biasH);
hiddenOut.changeElements(sigmoid);
Matrix outputOut = Matrix.multiply(weightsHO, hiddenOut);
outputOut.add(biasO);
outputOut.changeElements(sigmoid);
Matrix target = Matrix.fromArray(answers);
Matrix error = Matrix.subtract(target, outputOut);
Matrix gradient = Matrix.changeElements(outputOut, dSigmoid);
gradient.multiply(learningRate);
gradient.multiply(error);
Matrix hiddenOutT = Matrix.transpose(hiddenOut);
Matrix deltaWeightHO = Matrix.multiply(gradient, hiddenOutT);
weightsHO.add(deltaWeightHO);
biasO.add(gradient);
Matrix weightHOT = Matrix.transpose(weightsHO);
Matrix errorH = Matrix.multiply(weightHOT, error);
Matrix gradientH = Matrix.changeElements(hiddenOut, dSigmoid);
gradientH.multiply(learningRate);
gradientH.multiply(errorH);
Matrix inputT = Matrix.transpose(inputs);
Matrix deltaWeightIH = Matrix.multiply(gradientH, inputT);
// weightsIH.printmatrix();
// deltaWeightIH.printmatrix();
weightsIH.add(deltaWeightIH);
biasH.add(gradientH);
}
//Predict the output of and Input
public double[] predict(double[] input){
//Input into a matrix object
Matrix inputs = Matrix.fromArray(input);
Matrix hiddenOut = Matrix.multiply(weightsIH, inputs);
hiddenOut.add(biasH);
hiddenOut.changeElements(sigmoid);
Matrix outputOut = Matrix.multiply(weightsHO, hiddenOut);
outputOut.add(biasO);
outputOut.changeElements(sigmoid);
return outputOut.toArray();
}
}