Hi
Ich hatte heute mit jemanden zusammengesessen, um unser erstes neuronales Netzwerk zu programmieren. Das scheint doch ziemlich komplex zu sein. In einer allerersten ungetesteten Version sind wir zu dem folgenden Ergebnis gekommen:
ThresholdValueFunc und andere mögliche Func-Klassen gibts hier:
Wir haben zuvor noch nie ein objektorientiert programmiertes neuronales Netzwerk gesehen. Darum würde mich interssieren, ob wir auf dem richtigen Weg sind. Ist die Programmierung soweit in Ordnung?
Ich hatte heute mit jemanden zusammengesessen, um unser erstes neuronales Netzwerk zu programmieren. Das scheint doch ziemlich komplex zu sein. In einer allerersten ungetesteten Version sind wir zu dem folgenden Ergebnis gekommen:
Java:
class Neuron
{
private double o;
private Neuron[] neurons;
private double[] weights;
private Func propagationFunc;
private Func activationFunc;
private Func outputFunc;
/**
* Input-Neuron
*/
public Neuron(Func propagationFunc, Func activationFunc, Func outputFunc)
{
this.propagationFunc = propagationFunc;
this.activationFunc = activationFunc;
this.outputFunc = outputFunc;
}
/**
* Output-Neuron
* @param Neuron[] neurons Eingehende Neuronen
* @param double[] weights Gewichte der eigehenden Verbindungen
*/
public Neuron(Neuron[] neurons, double[] weights, Func propagationFunc, Func activationFunc, Func outputFunc)
{
this.neurons = neurons;
this.weights = weights;
this.propagationFunc = propagationFunc;
this.activationFunc = activationFunc;
this.outputFunc = outputFunc;
}
/**
* Wert für ein Input-Neuron.
* @param double value
*/
public void in(double value)
{
this.o = value;
}
/**
* Ausgabe des Neurons.
* @return double
*/
public double out()
{
if(neurons != null)
{
// Propagierungsfunktion => net
double net = 0.0;
for(int i = 0; i < neurons.length; ++i)
{
net += weights[i] * neurons[i].out();
}
if(propagationFunc != null)
{
net = propagationFunc.calc(net);
}
// Aktivierungsfunktion => a
double a = (activationFunc != null)
? activationFunc.calc(net)
: net;
// Outputfunktion => o
double o = (outputFunc != null)
? outputFunc.calc(a)
: a;
}
return o;
}
}
class Main
{
public static void main(String[] args)
{
double[] values = {1.0, 0.0, 1.0};
double[] weights = {0.6, 0.6, 0.6};
double threshold = 0.5;
double value = 1.0;
Func f = new ThresholdValueFunc(threshold, value); // 1
// Schicht 0:
Neuron n0 = new Neuron(f, null, null);
Neuron n1 = new Neuron(f, null, null);
// Schicht 1:
Neuron[] level0 = {n0, n1};
Neuron n2 = new Neuron(level0, new double[]{0.02, 0.12}, f, null, null);
Neuron n3 = new Neuron(level0, new double[]{0.03, 0.13}, f, null, null);
// Schicht 2:
Neuron n4 = new Neuron(new Neuron[]{n0, n1, n2, n3}, new double[]{0.04, 0.14, 0.24, 0.34}, f, null, null);
n0.in(1.0);
n1.in(1.0);
System.out.println(n4.out());
}
}
Java:
interface Func
{
public double calc(double value);
}
class ThresholdValueFunc implements Func
{
private double threshold;
private double thresholdValue;
/**
* @param double value Der Wert, der ausgegeben wird, wenn calc einen Wert > threshold berechnet.
*/
public ThresholdValueFunc(double threshold, double thresholdValue)
{
this.threshold = threshold;
this.thresholdValue = thresholdValue;
}
/**
* Gleiche Anzahl Elemente!
* @return double
*/
public double calc(double value)
{
if(value > threshold)
{
return thresholdValue;
}
return 0.0;
}
}
class ThresholdValuesFunc implements Func
{
private double threshold;
private double upperThresholdValue;
private double lowerThresholdValue;
/**
* @param double value Der Wert, der ausgegeben wird, wenn calc einen Wert > threshold berechnet.
*/
public ThresholdValuesFunc(double threshold, double upperThresholdValue, double lowerThresholdValue)
{
this.threshold = threshold;
this.upperThresholdValue = upperThresholdValue;
this.lowerThresholdValue = lowerThresholdValue;
}
/**
* Gleiche Anzahl Elemente!
* @return double
*/
public double calc(double value)
{
if(value > threshold)
{
return upperThresholdValue;
}
return lowerThresholdValue;
}
}
class SignumFunc implements Func
{
private double threshold;
public SignumFunc(double threshold)
{
this.threshold = threshold;
}
/**
* @return double
*/
public double calc(double value)
{
if(value > threshold)
{
return 1.0;
}
return -1.0;
}
}
class LinearFunc implements Func
{
private double upperThreshold;
private double lowerThreshold;
private double upperThresholdValue;
private double lowerThresholdValue;
/**
* Garantiert, dass es keine Sprünge in der Kurve gibt.
*/
public LinearFunc(double upperThreshold, double lowerThreshold)
{
this.upperThresholdValue = this.upperThreshold = upperThreshold;
this.lowerThresholdValue = this.lowerThreshold = lowerThreshold;
}
/**
* Sprünge in der Kurve möglich.
*/
public LinearFunc(double upperThreshold, double lowerThreshold, double upperThresholdValue, double lowerThresholdValue)
{
this.upperThreshold = upperThreshold;
this.lowerThreshold = lowerThreshold;
this.upperThresholdValue = upperThresholdValue;
this.lowerThresholdValue = lowerThresholdValue;
}
/**
* @return double
*/
public double calc(double value)
{
if(value >= upperThreshold)
{
return upperThresholdValue;
}
if(value <= lowerThreshold)
{
return lowerThresholdValue;
}
return value;
}
}
class SigmoidFunc implements Func
{
private double delta;
public SigmoidFunc(double delta)
{
this.delta = delta;
}
/**
* @return double
*/
public double calc(double value)
{
return 1.0 / (1.0 + Math.exp(-value) / delta); // e^(-sum)
}
}
class TanHFunc implements Func
{
/**
* @return double
*/
public double calc(double value)
{
return Math.tanh(value);
}
}