Backpropagation parallelisieren: Kommunikation zwischen den Threads

CptK

Bekanntes Mitglied
Hallo, ich habe versucht, meinen Backpropagation-Algorithmus zu parallelisieren. Dabei stellt die Klasse BackpropagationRunnable den eigentlichen Algorithmus dar: Dieser arbeitet mit einer Kopie des Netzes und einer Teilmenge der Beispiele, wobei immer über alle Beispiele iteriert wird und jeweils die Gewichte angepasst werden.
Java:
public class BackpropagationRunnable implements Runnable {

    private Matrix samples;
    private Matrix expOut;
    private NeuralNetwork net;
    private double alpha;

    public BackpropagationRunnable(NeuralNetwork net, Matrix in, Matrix out, double alpha) {
        this.net = net;
        this.samples = in;
        this.expOut = out;
        this.alpha = alpha;
    }

    @Override
    public void run() {
        for (int i = 0; i < samples.numberOfRows(); i++) {
            // the current sample
            Vector sample = samples.getRow(i);

            // the expected output for the current sample
            Vector expectedOutput = expOut.getRow(i);

            // the actual output for the current sample
            Vector output = net.predict(sample);

            // the error of the current sample
            Vector error = expectedOutput.sub(output);

            // create variable layer and set its value to the output layer
            Layer layer = net.getOutputLayer();

            // update the error-vector of the output layer: delta = error x g'(input)
            layer.setError(error.elementProduct(layer.getInput().map(x -> net.fct.applyDerivation(x))));

            // calculate new weights and biases for the output layer
            calculateWeightsInOutputLayer(layer, alpha, sample);

            // move backwards in the network
            layer = layer.getPreviousLayer();

            // iterate over all hidden layer
            while (layer != null) {
                // calculate new weights and biases for the hidden layer
                calculateWeightsInHiddenLayer(layer, alpha, sample);
                layer = layer.getPreviousLayer();
            }

            // update weights and biases in the output layer
            layer = net.getOutputLayer();
            layer.updateWeightsAndBiases();

            layer = layer.getPreviousLayer();

            // iterate over all hidden layers and update weights and biases
            while (layer != null) {
                layer.updateWeightsAndBiases();
                layer = layer.getPreviousLayer();
            }
        }
    }

    /**
     * Calculates the new weights and biases in the output layer, for a single
     * weight:<br>
     * w_{j,k} <- w_{j,k} x alpha x a_j x error_k <br>
     * For all weights: <br>
     * W' = W + alpha x error x a^T<br>
     * W is the {@link Matrix} of weights, a is the vector of outputs of the
     * previous layer, error is the vector of errors of the current layer<br>
     * For the biases it is the same but a is not used in the formula
     *
     * @param l      the {@link Layer} the updates are calculated for
     * @param alpha  learning rate
     * @param sample the example that is currently seen
     */
    protected void calculateWeightsInOutputLayer(Layer l, double alpha, Vector sample) {
        Matrix newWeights = Matrix.clone(l.getInputWeights());
        Vector newBiases = Vector.clone(l.getBiases());

        // error of the layer
        Matrix delta = l.getError().toMatrix();

        // outputs of the previous layer / input to the net if previousLyer is null
        Matrix inputs = (l.getPreviousLayer() == null ? sample : l.getPreviousLayer().getOutput()).toMatrix();

        // calculate new weights
        newWeights = l.getInputWeights().add(delta.mul(inputs.transpose()).mul(alpha));

        // calculate new biases
        newBiases = l.getBiases().add(l.getError().mul(alpha));

        // update new weights
        l.setNewInpWeights(newWeights);

        // update new biases
        l.setNewBiases(newBiases);
    }

    /**
     * Calculates the error for the current hidden layer, for a single neuron j it
     * is:<br>
     * error_j = g'(input_j) x sum_k{w_{j,k} * delta_k}<br>
     * g' is the derivative of the activation function, delta_k is the error of the
     * next layer
     * <p>
     * Then calculates the new weights and biases, for a single weight:<br>
     * w_{j,j} <- w_{j,j} x alpha x a_i x error_j <br>
     * For all weights: <br>
     * W' = W + alpha x error x a^T<br>
     * W is the {@link Matrix} of weights, a is the vector of outputs of the
     * previous layer, error is the vector of errors of the current layer<br>
     * For the biases it is the same but a is not used in the formula
     *
     * @param l      the {@link Layer} the updates are calculated for
     * @param alpha  learning rate
     * @param sample the example that is currently seen
     */
    protected void calculateWeightsInHiddenLayer(Layer l, double alpha, Vector sample) {
        ActivationFunction fct = l.getActivationFunction();
        Layer next = l.getNextLayer();

        // update error-vector
        l.setError(l.getInput().map(x -> fct.applyDerivation(x))
                .elementProduct(next.getInputWeights().transpose().mul(next.getError())));

        Matrix newWeights = Matrix.clone(l.getInputWeights());
        Vector newBiases = Vector.clone(l.getBiases());

        // error of the layer
        Matrix delta = l.getError().toMatrix();

        // outputs of the previous layer / input to the net if previousLyer is null
        Matrix inputs = (l.getPreviousLayer() == null ? sample : l.getPreviousLayer().getOutput()).toMatrix();

        // calculate new weights
        newWeights = l.getInputWeights().add(delta.mul(inputs.transpose()).mul(alpha));

        // calculate new biases
        newBiases = l.getBiases().add(l.getError().mul(alpha));

        // update new weights
        l.setNewInpWeights(newWeights);

        // update new biases
        l.setNewBiases(newBiases);
    }

    public NeuralNetwork getNet() {
        return net;
    }

    public Matrix[] getWeights() {
        return net.getWeights();
    }

    public void setWeights(Matrix[] weights) {
        net.setWeights(weights);
    }

    public Vector[] getBiases() {
        return net.getBiases();
    }

    public void setBiases(Vector[] biases) {
        net.setBiases(biases);
    }
}

Die Klasse BackpropagationThread bekommt das Netz übergeben und eine Anzhal an Iterationen, epochs, wie oft der Thread die zugehörige Runnable ausführen soll. Vor dem Ausführen von run() werden allerdings die Gewichte aus dem Netz, das alle Threads nutzen, geladen und nach run() werden die neuen Gewichte gespeichert:
Java:
class BackpropagationThread extends Thread {

    private final BackpropagationRunnable backprop;

    private final NeuralNetwork sharedNet;

    private final int epochs;

    public BackpropagationThread(NeuralNetwork net, Matrix in, Matrix out, double alpha, int epochs) {
        backprop = new BackpropagationRunnable((NeuralNetwork) net.clone(), in, out, alpha);
        this.sharedNet = net;
        this.epochs = epochs;
    }

    @Override
    public void run() {

        for (int a = 0; a < epochs; a++) {

            synchronized (sharedNet) {
                backprop.setWeights(sharedNet.getWeights());
                backprop.setBiases(sharedNet.getBiases());
            }

            backprop.run();

            synchronized (sharedNet) {
                Matrix weights[] = backprop.getWeights();
                Vector biases[] = backprop.getBiases();

                Matrix[] newWeights = sharedNet.getWeights();
                Vector[] newBiases = sharedNet.getBiases();
                for (int i = 0; i < newWeights.length; i++) {
                    newWeights[i] = weights[i];
                    newBiases[i] = biases[i];
                }

                sharedNet.setWeights(newWeights);
                sharedNet.setBiases(newBiases);

                backprop.setWeights(newWeights);
                backprop.setBiases(newBiases);
            }
        }

    }

    public NeuralNetwork getNet() {
        return backprop.getNet();
    }

    public Matrix[] getWeights() {
        return backprop.getWeights();
    }

    public void setWeights(Matrix[] weights) {
        backprop.setWeights(weights);
    }

    public Vector[] getBiases() {
        return backprop.getBiases();
    }

    public void setBiases(Vector[] biases) {
        backprop.setBiases(biases);
    }
}

Die Klasse BackpropagationParallel teilt die Beispielmenge auf und erzeugt die Threads
Java:
public class BackpropagationParallel extends Backpropagation {

    @Override
    public NeuralNetwork backprop(NeuralNetwork net, Matrix in, Matrix out, double alpha, int epochs) {
        Objects.requireNonNull(net, "The specified NeuralNetwork may not be null");
        Objects.requireNonNull(in, "The specified input Matrix may not be null");
        Objects.requireNonNull(out, "The specified output Matrix may no be null");

        if (in.numberOfRows() != out.numberOfRows())
            throw new IllegalArgumentException(
                    "The number of rows of the input matrix must equal the number of rows of the output matrix");

        int cores = Runtime.getRuntime().availableProcessors();
        System.out.println("available cores: " + cores);

        BackpropagationThread[] threads = new BackpropagationThread[cores];
        List<Matrix> splittedIn = in.split(cores);
        List<Matrix> splittedOut = out.split(cores);

        for (int i = 0; i < threads.length; i++) {
            threads[i] = new BackpropagationThread(net, splittedIn.get(i), splittedOut.get(i),
                    alpha, epochs);
        }

        for (BackpropagationThread t : threads)
            t.start();

        try {
            for (BackpropagationThread t : threads)
                t.join();
        } catch (InterruptedException e) {
            throw new UnsupportedOperationException(e.getMessage());
        }
        return net;
    }
}

Problem, die Kommunikation zwischen den Threads, bzw. das Austauschen der Gewichte funktioniert auf diese Weise überhaupt nicht und ich habe auch keine Idee, wie man das verbessern kann.
 

Jw456

Top Contributor
in einem Thread einen anderen Thread über die Run Methode aufrufen geht gar nicht.

Java:
public BackpropagationThread(NeuralNetwork net, Matrix in, Matrix out, double alpha, int epochs) {
        backprop = new BackpropagationRunnable((NeuralNetwork) net.clone(), in, out, alpha);
      ...
    }

    @Override
    public void run() {

       ....
            backprop.run();
 

CptK

Bekanntes Mitglied
Also kleines Update:
Mein Code macht jetzt genau das, was er machen soll (nur das das immernoch nicht optimal ist, dazu später mehr):
CyclicMain.java:
public class CyclicMain extends Backpropagation {

    private final int NUMBER_OF_THREADS;

    public CyclicMain() {
        NUMBER_OF_THREADS = Runtime.getRuntime().availableProcessors();
    }

    public NeuralNetwork backprop(NeuralNetwork net, Matrix in, Matrix out, double alpha, final int EPOCHS) {
        List<Matrix> splittedIn = in.split(NUMBER_OF_THREADS);
        List<Matrix> splittedOut = out.split(NUMBER_OF_THREADS);

        BackpropagationThread[] threads = new BackpropagationThread[NUMBER_OF_THREADS];

        AggregationThread aggregationThread = new AggregationThread(net, NUMBER_OF_THREADS);

        CyclicBarrier barrier = new CyclicBarrier(NUMBER_OF_THREADS, aggregationThread);

        for (int i = 0; i < NUMBER_OF_THREADS; i++) {
            threads[i] = new BackpropagationThread(net, barrier, splittedIn.get(i), splittedOut.get(i), alpha, EPOCHS);
        }

        aggregationThread.setThreads(threads);

        for (BackpropagationThread t : threads) {
            t.start();
        }

        for (Thread t : threads)
            try {
                t.join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }

        return net;
    }
}
BackpropagationThread.java:
public class BackpropagationThread extends Thread {

    private final NeuralNetwork net;
    private final Matrix in;
    private final Matrix out;
    private final double alpha;
    private final int epochs;
    private final CyclicBarrier barrier;
    
    ...

    @Override
    public void run() {
        // repeat epochs-times
        for (int e = 0; e < epochs; e++) {
            // iterate over all samples
            for (int s = 0; s < in.numberOfRows(); s++) {
                // update weights and biases with sample s
                updateWithSample(s);
            }

            try {
                barrier.await();
            } catch (InterruptedException | BrokenBarrierException exc) {
                exc.printStackTrace();
            }
        }
    }

    public Matrix[] getWeights() {
        return net.getWeights();
    }

    public Vector[] getBiases() {
        return net.getBiases();
    }

    public void synchronize(Matrix[] newWeights, Vector[] newBiases) {
        net.setWeights(newWeights);
        net.setBiases(newBiases);
    }
    ...
}
AggregationThread.java:
public class AggregationThread implements Runnable {

    private volatile NeuralNetwork net;
    private final int NUMBER_OF_THREADS;
    private BackpropagationThread[] threads;

    public AggregationThread(NeuralNetwork net, final int NUMBER_OF_THREADS) {
        this.net = net;
        this.NUMBER_OF_THREADS = NUMBER_OF_THREADS;
    }

    @Override
    public void run() {
        Matrix[] newWeights = null;
        Vector[] newBiases = null;

        // iterate over all threads and add the weights and biases
        for (BackpropagationThread t : threads) {
            if (newWeights == null) {
                newWeights = t.getWeights();
                newBiases = t.getBiases();
            } else {
                newWeights = Matrix.addMatrixArrays(newWeights, t.getWeights());
                newBiases = Vector.addVectorArrays(newBiases, t.getBiases());
            }
        }
        
        // divide each weight by the number of threads
        for(int i = 0; i < net.numberOfLayers(); i++) {
            newWeights[i] = newWeights[i].mul(1.0 / (double) NUMBER_OF_THREADS);
            newBiases[i] = newBiases[i].mul(1.0 / (double) NUMBER_OF_THREADS);
        }

        // update the weights and biases of the return-net
        net.setWeights(newWeights);
        net.setBiases(newBiases);

        // synchronize the nets in the threads with the new weights and biases
        for (BackpropagationThread t : threads)
            t.synchronize(newWeights, newBiases);

    }
    ...
}
Das Problem ist jetzt, dass das Ergebnis schlechter wird, je mehr Threads ich habe. Das bedeutet, dass die Methode, einfach die Gewichte der verschiedenen Threads aufzusummieren und dann durch die Anzahl der Threads zu teilen, nicht besonders gut ist. Allerdings habe ich keine Idee, wie man das besser machen könnte und im Internet habe ich auch nicht wirklich was brauchbares gefunden.
 

CptK

Bekanntes Mitglied
So direkt glaube ich auch nicht, da müsste ich mich höchstens mal ein bissi an der Uni umhören obs da irgendwo was gibt. Das Ding ist ich bin erst im zweiten Semester, habe also noch nichts mit diesem Thema (und dank Corona mit den Leuten schon gar nichts) zu tun gehabt, was bedeutet, dass mir da auch bissi die Connections fehlen...
 

Barista

Top Contributor
Du müsstest analysieren, welche Ergebnisse von anderen Ergebnissen abhängen.

Das Rumpfummeln im Code mit der Hoffnung, es bringt irgendwas, bringt nichts.

Nun hilft die Programmiersprache (der Compiler) bei Arbeit mit Seiteneffekten nicht.

Du müsstest den Code also konsequent auf unveränderliche Daten umstellen.

Wenn es sich um grosse Objekte handelt, schau mal die Idee Lenses an.

Wenn es sich um Collections handelt, schau mal Vavr an.

Durch die Umstellung auf unveränderliche Daten werden die Abhängigkeiten explizit und dann ergeben sich Möglichkeiten zur Parallelisierung.
 
Ähnliche Java Themen
  Titel Forum Antworten Datum
CptK Backpropagation Algorithmus Allgemeine Java-Themen 6
A Methoden parallelisieren? Allgemeine Java-Themen 2
D JGAP parallelisieren? Allgemeine Java-Themen 3
L Kommunikation zwischen C# und Java? Allgemeine Java-Themen 5
R PIPE Kommunikation mit Prozess blockiert Allgemeine Java-Themen 0
M Checksummenprüfung bei Client Server kommunikation Allgemeine Java-Themen 3
D Interthread-Kommunikation Allgemeine Java-Themen 6
S Threads Kommunikation zwischen SocketThread und WorkerThread Allgemeine Java-Themen 11
0 Lösungsweg Client Server Kommunikation Fehlermeldung ausgeben Allgemeine Java-Themen 12
L Software-Design: Kommunikation mit SerialPort (RXTX) Allgemeine Java-Themen 2
F Serielle Kommunikation Allgemeine Java-Themen 20
A Kommunikation zwischen 2 Jar-dateien Allgemeine Java-Themen 16
S Kommunikation von Core und GUI über Schnittstellen Allgemeine Java-Themen 2
P Komponenten kommunikation Allgemeine Java-Themen 7
L Serielle Kommunikation Allgemeine Java-Themen 6
G Implementierung einer Kommunikation Allgemeine Java-Themen 7
B SerialPort Kommunikation - Pufferüberlauf Allgemeine Java-Themen 6
0 Sandbox und Applet Kommunikation Allgemeine Java-Themen 9
E kommunikation zwischen Fenstern Allgemeine Java-Themen 3
H Java C++ Interprozess Kommunikation Allgemeine Java-Themen 7
D Klassen Kommunikation Allgemeine Java-Themen 11
M Applet Servlet Kommunikation ein Problem bei externem Server Allgemeine Java-Themen 3
F Kommunikation zw 2 Java-Programmen mit unterschiedl Compiler Allgemeine Java-Themen 13
G Kommunikation mit Remotegeräten Allgemeine Java-Themen 2
A Kommunikation zwischen C++ und Java-Programm Allgemeine Java-Themen 4
J Kommunikation mit USB Gerät, das Midi-Signale sendet Allgemeine Java-Themen 4
G Kommunikation mit der seriellen Schnittstelle Allgemeine Java-Themen 6
H Kommunikation mit einem c-Prozess funzt nicht richtig Allgemeine Java-Themen 5
C Pattern für Kommunikation gesucht Allgemeine Java-Themen 3
B Kommunikation mit entferntem Thread Allgemeine Java-Themen 6
O kommunikation zw. objekten Allgemeine Java-Themen 7
G Kommunikation von zwei Java-Programmen Allgemeine Java-Themen 3
N Inter-Thread-Kommunikation Allgemeine Java-Themen 3
M Kommunikation mit externem Programm ueber Linux-Pipes Allgemeine Java-Themen 4
M Kommunikation zwischen 2 Programmen Allgemeine Java-Themen 7
G Kommunikation zwischen Threads und Gui Allgemeine Java-Themen 2
T Kommunikation mit batch datei Allgemeine Java-Themen 3
P kommunikation zwischen dialog und threads Allgemeine Java-Themen 4
K MVC - Kommunikation Controller <> Gui Allgemeine Java-Themen 5
P Kommunikation von Threads Allgemeine Java-Themen 5
V Kommunikation zwischen Programmen und Threads Allgemeine Java-Themen 7
H Servlet/Applet Kommunikation Allgemeine Java-Themen 2
C Kommunikation mit USB Port Allgemeine Java-Themen 1

Ähnliche Java Themen


Oben