Du verwendest einen veralteten Browser. Es ist möglich, dass diese oder andere Websites nicht korrekt angezeigt werden. Du solltest ein Upgrade durchführen oder ein alternativer Browser verwenden.
For-Schleife in Thread nach jedem Durchlauf pausieren
class MyThread extends Thread
private final Updater updater;
.....
public void run() {
for(int i = 0; i < x; i++) {
// do something
synchronized(updater) {
updater.update();
}
// Und hier soll jetzt wait() stehen
}
}
Und noch die Klasse Updater:
Java:
class Updater {
MyThread[] threads;
int counter = 0;
....
void update() {
// do something
counter++;
if(counter == threads.length) {
// do something
for(MyThread t : threads)
t.notify();
}
}
}
Also zusammengefasst: Ich habe eine Anzahl von Threads. Jeder Thread hat eine for-Schleife mit der gleichen Anzahl an Durchläufen. Nach jedem Durchlauf soll update() vom Updater aufgerufen werden. Bis alle Threads das gemacht haben sollen die anderen jeweils warten und wenn alle Threads update() aufgerufen haben (natürlich bekommen alle die gleiche Instanz übergeben) sollen alle Threads wieder aufgeweckt werden und der Spaß fängt wieder von vorne an. Allerdings blicke ich nicht, wie ich das mache, ich habe da ständig Probleme mit "current thread is not owner". Ich hoffe, ich konnte klar genug ausdrücken, was ich eigentlich will.
Das Gefühl habe ich langsam auch... Ich habe mir jetzt mal etwas ein bisschen anderes überlegt:
Java:
class MyRunnable implements Runnable {
private final MyThread thread;
...
public synchronized run() {
for(int i = 0; i < epochs; i++) {
// do something
thread.update();
try {
wait();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
Java:
class MyThread extends Thread {
private final MyRunnable runnable;
private final Updater updater;
...
public void run() {
...
runnable.run();
}
synchronized void update() {
updater.update();
}
}
Java:
class Updater {
MyThread[] threads;
int counter = 0;
....
void update() {
// do something
counter++;
if(counter == threads.length) {
// do something
// hier müssten jetzt alle Runnables wieder fortgesetzt werden
}
}
}
Ich weiß jetzt allerdings nicht, wie ich in der Methode update alle Runnables wieder fortsetzen kann. Ich habe versucht mir in der Klasse Thread eine Methode zu schreiben, die das erledigt, die ich dann vom Updater aus aufrufen kann, das klappt allerdings nicht
Wenn ich das richtig verstehe soll immer auf dem gleichen Object die Methode update aufgerufen werden - von verschieden Threads - aber nie gleichzeitig. Dann sind die Threads ja völlig überflüssig, man könnte auch einfach in einer Schleife immer wieder update aufrufen. Das hätte genau den selben Effekt aber weniger Overhead.
[CODE lang="java" title="busy wait"] // Wait until all threads have called update
while ( ! updater.allThreadsUpdated() )
{
Thread.sleep( 10 );
}
[/CODE]
Ist nicht die beste Lösung bezüglich Leistung.
Das könntest Du durch ein Object.wait() und Object.notifyAll()-Konstrukt ersetzen.
Ab Java5 gab es da Alternativen, habe ich jetzt nicht zur Hand.
Wenn ich das richtig verstehe soll immer auf dem gleichen Object die Methode update aufgerufen werden - von verschieden Threads - aber nie gleichzeitig. Dann sind die Threads ja völlig überflüssig, man könnte auch einfach in einer Schleife immer wieder update aufrufen. Das hätte genau den selben Effekt aber weniger Overhead.
Also entweder ich bin zu blöd zu kapieren, wie ich das machen soll, oder ich schaffe es nicht zu erklären, was ich eigentlich will. Deshalb versuche ich es jetzt noch mal mit einer aktuellen (abgespeckten) Version meines eigentlichen Projekts.
Ich habe 3 Klassen: BackpropagationMultithread, BackpropagationRunnable, WeightUpdater. Hierbei erstellt BackpropagationMultithread die verschiedenen threads und die zugehörigen Runnables. Die Runnables führen jeweils den Backpropagation Algorithmus für einen Teil der Trainingsmenge aus. Dies machen sie jeweils in einer for-Schleife (epochs oft) wobei in jedem durchlauf einmal über alle Trainingsbeispiele iteriert wird. Nach jeder epoch sollen die verschiedenen Threads ihre Gewichte austauschen, wofür WeightUpdater zuständig ist. Das Problem liegt beim Reaktivieren der Runnables nach dem wait() in WeightUpdater, weitere Erklärung als Kommentar im Code:
[CODE lang="java" title="BackpropagationMultithread"]public class BackpropagationMultithread extends Backpropagation {
@Override
public NeuralNetwork backprop(NeuralNetwork net, Matrix in, Matrix out, double alpha, int epochs) {
int cores = Runtime.getRuntime().availableProcessors();
Thread[] threads = new Thread[cores];
BackpropagationRunnable[] runnables = new BackpropagationRunnable[cores];
WeightUpdater updater = new WeightUpdater(net, runnables);
for (int i = 0; i < cores; i++) {
runnables = new BackpropagationRunnable((NeuralNetwork) net.clone(), splittedIn.get(i),
splittedOut.get(i), alpha, updater);
threads = new Thread(runnables);
}
for (Thread t : threads)
t.start();
try {
for (Thread t : threads)
t.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
private final Matrix samples;
private final Matrix expOut;
private final NeuralNetwork net;
private final double alpha;
private final WeightUpdater updater;
// some getter and setter...
}
[/CODE]
[CODE lang="java" title="WeightUpdater"]public class WeightUpdater {
private final NeuralNetwork net;
private final Matrix[] newWeights;
private final Vector[] newBiases;
private int counter = 0;
private final BackpropagationRunnable[] runnables;
//Diese Methode wird von den Runnables nach jeder epoch aufgerufen, und es wird der
//Mittelwert der Gewichte gebildet.
synchronized void updateWeightsAndBiases(Matrix[] weights, Vector[] biases) {
for (int i = 0; i < net.numberOfLayers; i++) {
newWeights = newWeights.add(weights);
newBiases = newBiases.add(biases);
}
counter++;
// Wenn alle runnables das Ergebnis ihrer epoch geliefert haben, so wird das Netz
// mit dem Mittelwert geupdatet und die Runnables sollen fortgesetzt werden
if (counter == runnables.length) {
for (int i = 0; i < net.numberOfLayers; i++) {
newWeights = newWeights.mul(1.0 / (double) runnables.length);
newBiases = newBiases.mul(1.0 / (double) runnables.length);
// und hier ist das Problem: die Runnables halten an der Stelle mit "wait()"
// zwar an, allerdings sollten sie hier wieder reaktiviert werden, was aber
// nicht passiert
for (BackpropagationRunnable t : runnables)
synchronized (t) {
t.notifyAll();
}
Einen sehr schönen möchte ich hier mal gerade vorstellen, den CyclicBarrier:
Java:
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
public class CyclicBarrierTest {
public static class MyTask implements Runnable {
public static final CyclicBarrier BARRIER = new CyclicBarrier(5);
private String name;
public MyTask(String name) {
this.name = name;
}
@Override
public void run() {
try {
while (true) {
// so something here
System.out.println("Hello from " + name);
// wait, to make it dramatically
Thread.sleep((long) (Math.random() * 5000));
// wait for all other tasks...
BARRIER.await();
}
} catch (InterruptedException | BrokenBarrierException e) {
e.printStackTrace();
}
}
}
public CyclicBarrierTest() {
new Thread(new MyTask("a")).start();
new Thread(new MyTask("b")).start();
new Thread(new MyTask("c")).start();
new Thread(new MyTask("d")).start();
new Thread(new MyTask("e")).start();
}
public static void main(String[] args) {
new CyclicBarrierTest();
}
}
Mir ist eingefallen, dass das wait und notify auf dem einzigen updater-Objekt aufgerufen werden müsste, dieses Objekt wird zwischen den Threads geteilt.
Was bedeutet das "als Lock"?
Und was ich nicht verstehe ist, wieso kann ich nicht jede Runnable pausieren, nachdem sie mit ihrer Schleife fertig ist und dann einfach vom Updater aus, der ja das Verbindungsstück ist, wieder reaktivieren? Also ich war jetzt davon ausgegangen, dass this.wait() immer das aktuelle Objekt pausiert und man genau dieses Objekt dann auch von außerhalb wieder starten könnte. Diese ganzen Beispiele hier, wie CyclicBarrier, sehen, naja, kompliziert aus und ich habe mir das eigentlich als ziemlich einfaches Problem vorgestellt: Jede Runnable verarbeitet ihre Beispiele und gibt die Ergebnisse an den Updater und wartet bis der Updater das Signal zum weitermachen gibt.
Mit dem busy-wait habe ich so meine Performance Bedenken und den ganzen Parallelisierungs-Spaß mache ich ja nur aus Performancegründen.
Ich habe jetzt mal noch ein bisschen versucht mich zum Thema Threads zu informieren und habe dabei einen Ansatz entwickelt, den ich eigentlich ganz einleuchtend finde, hier mal ein kleines Beispiel:
Java:
public class Communicator {
private final int numberThreads;
private volatile int seenThreads = 0;
public Communicator(final int numberThreads) {
this.numberThreads = numberThreads;
}
public synchronized void print(String msg) {
seenThreads++;
System.out.println(msg);
try {
if (seenThreads < numberThreads)
wait();
} catch (InterruptedException e) {
e.printStackTrace();
}
seenThreads = 0;
System.out.println("reset");
notifyAll();
}
Java:
public class MyRunnable implements Runnable {
private String[] msgs;
private Communicator c;
public MyRunnable(Communicator c, String... msgs) {
this.msgs = msgs;
this.c = c;
}
@Override
public void run() {
for(String msg : msgs)
c.print(msg);
}
}
Java:
public static void main(String[] args) {
Communicator c = new Communicator(3);
MyRunnable r1 = new MyRunnable(c, "0", "3", "6");
MyRunnable r2 = new MyRunnable(c, "1", "4", "7");
MyRunnable r3 = new MyRunnable(c, "2", "5", "8");
Thread t1 = new Thread(r1);
Thread t2 = new Thread(r2);
Thread t3 = new Thread(r3);
t1.start();
t2.start();
t3.start();
}
Jetzt würde ich eine solche Ausgabe (bis auf die Reihenfolge der Zahlen zwischen den Resets) erwarten:
Ich verstehe nur nicht so wirklich was da falsch läuft. Dazu kommt dann noch, dass ich am Ende wohl ein Deadlock habe und das Programm nicht terminiert.
Ich verstehe nur nicht so wirklich was da falsch läuft. Dazu kommt dann noch, dass ich am Ende wohl ein Deadlock habe und das Programm nicht terminiert.
Also ich war jetzt davon ausgegangen, dass this.wait() immer das aktuelle Objekt pausiert und man genau dieses Objekt dann auch von außerhalb wieder starten könnte. Diese ganzen Beispiele hier, wie CyclicBarrier, sehen, naja, kompliziert aus und ich habe mir das eigentlich als ziemlich einfaches Problem vorgestellt
Was ich verstanden habe was du machen willst: Du hast n Threads für eine Aufgabe. Alle Threads sollen sich nach jedem Schleifendurchlauf synchronisieren.
Und da frage ich mich: Wofür brauchst du Threads, was erhoffst du dir davon? Ein Geschwindigkeitsvorteil kann es nicht sein, denn du wirst für die Synchronisation wahrscheinlich(!) mehr Zeit benötigen als wenn du deine Aufgabe in einem einzelnen Thread erledigen würdest.
Threads sind dann gut, wenn du eine Aufgabe gut parallelisieren kannst, und die Parallelisierungszweige möglichst unabhängig voneinander arbeiten können.
Folgendes Beispiel:
Du hast einen Haufen Steine, der von einer Ecke in die andere Ecke geschafft werden soll. Und du hast einen Arbeiter, der genau einen Stein tragen kann. Jetzt kann dieser eine arme Arbeiter ewig an dem Haufen schleppen, jeden Stein einzeln, irgendwann wird der Haufen abgearbeitet sein.
Jetzt nimmst du einfach z.B. vier Leute, und läßt diese den Haufen zusammen abtragen. Alle Arbeiter laufen unterschiedlich schnell, manche sind wieselflink, manche sind Lahmärsche, aber jeder kann nur einen Stein gleichzeitig tragen.
Wenn jetzt jeder Arbeiter seinen Stein in seinem Tempo einfach bewegen kann, wird der Haufen schneller abgetragen sein, je mehr Arbeiter du hast.
Wenn jetzt aber alle Arbeiter, nachdem sie einen Stein transportiert haben, erstmal eine Besprechung abhalten und dazu warten bis auch der letzte Trödelheini dazukommt, dann dauert es. Je mehr Arbeiter sich dabei synchronisieren müssen, umso länger dauert es, und unter Umständen wäre ein einzelner Arbeiter schneller fertig als wenn es zwei oder mehr sind.
Bei Threads ist das im Prinzip genauso: die laufen unterschiedlich schnell, und du hast keinen oder kaum Einfluß darauf, wie schnell (das ist betriebssystemabhängig). Du kannst Threads bremsen (mit Thread.sleep), aber nicht beschleunigen. Und selbst relativ einfache Ordnungsmechanismen bremsen Threads sehr schnell sehr stark aus.
Und bevor du dich mit Problemen wie Deadlocks herumschlagen mußt: Wie sicher bist du dir das Multithreading dir weiterhilft?
Edit:
Anstelle von Runnable arbeite ich lieber mit der Threadklasse...aber das ist wohl Geschmackssache.
Java.lang.Object.wait() Method - The java.lang.Object.wait() causes current thread to wait until another thread invokes the notify() method or the notifyAll() method for this object. In other words, this method behaves exactly as if it simply performs the call wait(0).
www.tutorialspoint.com
Nach dem zweiten Link sollte es besser heissen: owner of the object's monitor
Was ich verstanden habe was du machen willst: Du hast n Threads für eine Aufgabe. Alle Threads sollen sich nach jedem Schleifendurchlauf synchronisieren.
Und da frage ich mich: Wofür brauchst du Threads, was erhoffst du dir davon? Ein Geschwindigkeitsvorteil kann es nicht sein, denn du wirst für die Synchronisation wahrscheinlich(!) mehr Zeit benötigen als wenn du deine Aufgabe in einem einzelnen Thread erledigen würdest.
Also ich habe ja eine äußere Schleife for (int epoch = 0; epoch < epochs; epoch++) und eine innere Schleife for(int s = 0; s < samples.length; s++). samples.length ist hier in meinen Beispielen immer etwas um die 1000 gewesen. Synchronisiert wird immer, wenn die innere Schleife abgeschlossen ist. Anstatt dass ein einzelner Thread jetzt epochs * samples.length * numberOfThreads oft rechnen muss muss er immer nur einen Anteil von 1/NumberOfThreads davon machen. Mit dem Beispiel oben (MyRunnable und Communicator) läuft das Multithreadsystem etwa 3 mal so schnell wie die ursprüngliche Variante, die nur einen Thread benutzt.
Ja, aber ich verstehe dennoch nicht so ganz die Problematik des TE.
Es klang so, als ob er Blöcke von Aufgaben hat, die parallel bearbeitet werden können. Aber ein Block muss abgeschlossen sein, ehe der nächste Block bearbeitet werden kann.
Wenn dem so ist, dann kann er seine Aufgaben die zugleich verarbeitet werden sollen, einfach zusammen packen und dann über einen parallelen Stream abarbeiten lassen.
=> Komplexität und Tests gehen direkt massiv runter.
Oder man startet einfach die Threads für alle Aufgaben und macht dann joins. Da er ständig neue Thread Instanzen erzeugen muss, ist es nicht ganz so toll - wirklich problematisch wird das aber nur, wenn die Anzahl hoch ist und die Laufzeit immer sehr kurz.
Dann würde ich das anders aufbauen. Zunächst: Spare dir die äußere Schleife. Was sind epochs bei dir? Ist das eine Art Datenpaket, aus dem die Threadinstanzen ihre Samples bekommen? Wenn ja: schmeiß die epochs in ein HashSet. Oder, falls die Reihenfolge wichtig ist, in ein ArrayList.
Den Zugriff auf das HashSet/ArrayList mußt du synchronisieren, entweder über eine Zugriffsmethode und @synchronized, ich persönlich arbeite aber gerne damit:
Java provides mechanism for the synchronization of blocks of code based on the Lock interface and classes that implement it (such as ReentrantLock). In this tutorial, we will see a basic usage of Lock interface to solve printer queue problem.
howtodoinjava.com
Dann holt ein Thread den Lock, holt sich ein Element aus der Collection, gibt den Lock wieder frei und erst dann kann der nächste Thread.
Die Threads selber holen sich selbständig ein neues Datenpaket, bis die Collection leer ist. Es kann dann sein das ein Thread mehr abarbeitet als ein anderer Thread, aber jeder Thread kann stur sein Probramm abfahren. Ohne deine Aufgabe näher zu kennen hier mal ein Rohentwurf wie ich das angehen würde.
Das ist jetzt unter mangelhafter Konzentration auf die Schnelle hingerotzt (auf dem Foreneditor und nicht in der IDE, es kann durchaus sein daß manche Methoden anders heißen oder noch der ein oder andere Fehler drin schlummert), aber vielleicht bekommst du eine Idee davon.
Java:
interface DataSource{
DataPackage getNewDataPackage(); //DataPackage enthält die Daten, die ein Thread in einem Durchlauf verarbeiten soll
boolean dataRemaining();
}
public class MainTask implements DataSource{
//Klasse für einen Thread
private class TaskRunner extends Thread{
private DataSource dataSource;
private boolean iAmReady;
TaskRunner(DataSource dataSource){
this.dataSource = dataSource;
iAmReady = false;
}
@Override
run(){
while(dataSource.dataRemaining()){
var dataPackage = dataSource.getNewDataPackage();
//Verarbeite Daten...
}
//Keine Daten mehr da
iAmReady = true;
}
boolean isReady(){
return iAmReady;
}
}
private HashSet<TaskRunner> runners;
private ArrayList<DataPackage> datapackages;
private Lock lock;
void executeTaskWithMultipleThreads(){
runners = new HashSet<>();
datapackages = getDataFromAnyWhere();
lock = new ReentrantLock();
for(int i = 0; i < cpuCores; i++){
TaskRunner thread = new TastRunner(this);
runners.add(thread); //Irgendwo mußt du die Threadobjekte ja halten, damit sie vom GC nicht gleich wieder abgeräumt werden. Außerdem willst du über alle Threads iterieren um zu prüfen, ob sie fertig sind.
}
while(threadsStillWorking()){
Thread.sleep(100); //Alle 100ms prüfen, ob alle Threads fertig sind
}
//Alles verarbeitet, weiter gehts...
}
private boolean threadsStillWorking(){
for(TaskRunner runner : runners){
if(!runner.isReady()){return false;}
}
return true;
}
private DataPackage getNewDataPackage(){
DataPackage dp;
lock.lock();
dp = datapackages.remove(0);
lock.unlock
return dp;
}
private boolean dataRemaining(){
boolea b;
lock.lock();
b = datapackages.hasElements();
lock.unlock();
return b;
}
}
Dann würde ich das anders aufbauen. Zunächst: Spare dir die äußere Schleife. Was sind epochs bei dir? Ist das eine Art Datenpaket, aus dem die Threadinstanzen ihre Samples bekommen? Wenn ja: schmeiß die epochs in ein HashSet. Oder, falls die Reihenfolge wichtig ist, in ein ArrayList.
Nein, epochs ist die Zahl, wie oft der jeweilige Thread über seine Beispielmenge iterieren soll. Ich habe das gesamte Konzept hier nochmal versucht darzustellen:
Die Beispielmenge, mit der ein einzelner Thread arbeitet, ändert sich also nicht.
Und warum ist es erforderlich, die Threads dauernd anzuhalten?
Jedenfalls kannst du es so machen:
Java:
class TaskThread extends Thread{
private boolean waitForSynchronisation;
@Override
public void run(){
for(...){
for(...){
//Mach was mit deinen Daten...
}
waitForSynchronisation = true;
while(waitForSynchronisation){
Thread.sleep(10);
}
}
}
void nextRun(){
waitForSynchronisation = false;
}
}
Diese Klasse kannst du dann so oft instanzieren wie du parallele Threads haben willst. Die Threads warten so lange, bis die Methode nextRun() aufgerufen wird. Das Thread.sleep sorgt dafür, daß dein Programm die CPU nicht mir Warten auslastet.
Ich persönlich finde, daß Thread erweitern deutlich besser lesbareren und verständlicheren Code ergibt als Runnalbe zu implementieren. Und Runnable wurde wohl auch von den Entwicklern nicht umsonst um Thread ergänzt.
Ich muss jetzt doch mal widersprechen, nachdem ich jetzt schon zum zweiten Mal die Empfehlung lese, dass Thread zu erweitern besser ist, als Runnable zu implementieren.
Dem ist mMn eindeutig nicht so. Mit run wird nur eine Methode von Runnable nichts von Thread überschrieben, es reicht also komlett aus nur Runnable zu implementieren.
Eine Variable, die von mehreren Threads verändert werden soll und deren Änderung man direkt mitbekommen will, sollte volatile sein.
Bezüglich Inheritance hat @thecain schon sehr schön etwas geschrieben.
Und fehlt da nicht noch irgend eine Erkennung bezüglich: Alle Threads sind durch?
Wenn ich mir #16 betrachte, dann geht es ja um Aufgaben die getan werden sollen:
Erst Aufgaben 0, 1, 2
Dann die 3, 4, 5
Und dann die 6, 7, 8
Ein erster, einfacher Ansatz, der mir da in den Kopf kommt, wäre da direkt (Thread Erstellung hat etwas Overhead, aber muss hier optimiert werden? Sieht ja nach einer kleinen Anzahl Aufgaben aus und es wird ja was mit Daten gemacht - also doch etwas mehr, so dass da die Thread Erstellung hoffentlich nicht ins Gewicht fällt):
- Für jeden Block
--> Für jede Aufgabe des Blockes:
---->Neuen Thread erstellen für Aufgabe und Thread starten
--> Für jeden Thread
----> thread.join()
Damit laufen die Aufgaben parallel und erst wenn alle Threads fertig sind, geht es weiter in der äußeren Schleife.
Aber dank Streams gibt es das ja auch in kurz:
Einfach z.B. eine List<Aufgabe> für jeden Block und da hat man dann ein einfaches aufgaben.parallelStream().forEach(a -> a.execute()); oder so.
Java:
import java.util.ArrayList;
import java.util.List;
public class Aufgabe {
private int number;
public Aufgabe(int number) {
this.number = number;
}
public void execute() {
for (int i = 0; i<10; i++) {
System.out.println("Aufgabe: " + number + ", step: " + i);
waitALittle();
}
}
private void waitALittle() {
try {
Thread.sleep((int)(Math.random() * 100));
} catch (InterruptedException ex) {}
}
public static void main(String[] args) {
List<Aufgabe> block1 = new ArrayList<>();
block1.add(new Aufgabe(0));
block1.add(new Aufgabe(1));
block1.add(new Aufgabe(2));
List<Aufgabe> block2 = new ArrayList<>();
block2.add(new Aufgabe(3));
block2.add(new Aufgabe(4));
block2.add(new Aufgabe(5));
List<List<Aufgabe>> allBlocks= new ArrayList<>();
allBlocks.add(block1);
allBlocks.add(block2);
for (List<Aufgabe> aufgaben : allBlocks) {
System.out.println("Starte Block ...");
aufgaben.parallelStream().forEach(a -> a.execute());
System.out.println("Block done!");
}
}
}
Und warum ist es erforderlich, die Threads dauernd anzuhalten?
Jedenfalls kannst du es so machen:
Java:
class TaskThread extends Thread{
private boolean waitForSynchronisation;
@Override
public void run(){
for(...){
for(...){
//Mach was mit deinen Daten...
}
waitForSynchronisation = true;
while(waitForSynchronisation){
Thread.sleep(10);
}
}
}
void nextRun(){
waitForSynchronisation = false;
}
}
Diese Klasse kannst du dann so oft instanzieren wie du parallele Threads haben willst. Die Threads warten so lange, bis die Methode nextRun() aufgerufen wird. Das Thread.sleep sorgt dafür, daß dein Programm die CPU nicht mir Warten auslastet.
Ich persönlich finde, daß Thread erweitern deutlich besser lesbareren und verständlicheren Code ergibt als Runnalbe zu implementieren. Und Runnable wurde wohl auch von den Entwicklern nicht umsonst um Thread ergänzt.
Also, ich habe das jetzt mal versucht nachzubauen, ich habe aber das Problem, dass der erste Durchlauf (also epoch = 0) problemlos durchläuft, aber die while-Schleife in einem Thread nicht wieder abgebrochen wird. Mein Code:
Java:
@Override
public void run() {
// repeat epochs-times
for (int e = 0; e < epochs; e++) {
// iterate over all samples
for (int s = 0; s < in.numberOfRows(); s++) {
// update weights and biases with sample s
updateWithSample(s);
}
// share the weights and biases of the current net
updater.update(net.getWeights(), net.getBiases());
isSynchronized = false;
while (!isSynchronized) {
try {
Thread.sleep(10);
} catch (InterruptedException ex) {
ex.printStackTrace();
}
}
}
}
public void synchronize(Matrix[] weights, Vector[] biases) {
net.setWeights(weights);
net.setBiases(biases);
isSynchronized = true;
}
Und die update-Methode, die das "Aufwecken" übernimmt:
Java:
private final BackpropagationThread[] threads;
private volatile int seenThreads = 0;
public synchronized void update(Matrix[] weights, Vector[] biases) {
seenThreads++;
...
if (seenThreads == threads.length) {
...
seenThreads = 0;
for (BackpropagationThread t : threads) {
t.synchronize(net.getWeights(), net.getBiases());
}
}
}
Ich habe da mal ein paar Ausgaben eingefügt und folgendes passiert:
Alle (in meinem Fall 8 Threads) rufen update() auf
In update() wird für alle Threads synchronize() aufgerufen
7 von 8 Threads brechen die While-Schleife ab, einer nicht, wieso?
Die sinnvollste Lösung ist für diesen Fall wohl CyclicBarrier. JavaDoc dafür: „CyclicBarriers are useful in programs involving a fixed sized party of threads that must occasionally wait for each other.“
Das selbst zu machen ist eigentlich nur Fehlerquelle und Performance-Falle.
Die sinnvollste Lösung ist für diesen Fall wohl CyclicBarrier. JavaDoc dafür: „CyclicBarriers are useful in programs involving a fixed sized party of threads that must occasionally wait for each other.“
Also hier mein Versuch damit:
[CODE lang="java" title="BackpropagationThread"]public class BackpropagationThread extends Thread {
private final NeuralNetwork net;
private volatile NeuralNetwork sharedNet;
private final Matrix in;
private final Matrix out;
private final double alpha;
private final int epochs;
private final CyclicBarrier barrier;
public BackpropagationThread(final NeuralNetwork net, final CyclicBarrier barrier, final Matrix samples,
final Matrix expectedOutputs,
final double alpha, final int epochs) {
this.sharedNet = net;
this.net = net.clone();
this.barrier = barrier;
this.in = samples;
this.out = expectedOutputs;
this.alpha = alpha;
this.epochs = epochs;
}
@Override
public void run() {
// repeat epochs-times
for (int e = 0; e < epochs; e++) {
// iterate over all samples
for (int s = 0; s < in.numberOfRows(); s++) {
// update weights and biases with sample s on net
updateWithSample(s);
}
// share the weights and biases of the current net
update();
net.setWeights(newWeights);
net.setBiases(newBiases);
}
}[/CODE]
[CODE lang="java" title="CyclicMain"]public class CyclicMain {
private final int NUMBER_OF_THREADS;
public CyclicMain() {
NUMBER_OF_THREADS = 2;// Runtime.getRuntime().availableProcessors();
}
public NeuralNetwork run(NeuralNetwork net, Matrix in, Matrix out, double alpha, int epochs) {
List<Matrix> splittedIn = in.split(NUMBER_OF_THREADS);
List<Matrix> splittedOut = out.split(NUMBER_OF_THREADS);
BackpropagationThread[] threads = new BackpropagationThread[NUMBER_OF_THREADS];
AggregationThread aggregationThread = new AggregationThread(net, NUMBER_OF_THREADS);
CyclicBarrier barrier = new CyclicBarrier(NUMBER_OF_THREADS, aggregationThread);
for (int i = 0; i < NUMBER_OF_THREADS; i++) {
threads = new BackpropagationThread(net, barrier, splittedIn.get(i), splittedOut.get(i), alpha, epochs);
}
for (BackpropagationThread t : threads) {
t.start();
}
return net;
}
}[/CODE]
Und hier noch eine Main mit der ich das teste:
Java:
public static void main(String[] args) {
Matrix input = ...
Matrix output = ...
NeuralNetwork net = new FeedforwardNet(2, new int[] { 3 }, 1, new SigmoidFunction(), -0.5, 0.5);
// Abschnitt 1
for (int i = 0; i < testInput.numberOfRows(); i++) {
System.out.println(testInput.getRow(i) + ", result: " + net.predict(testInput.getRow(i)));
}
// Abschnitt 2
CyclicMain backprop = new CyclicMain();
final long timeStart = System.currentTimeMillis();
// Abschnitt 3
net = backprop.run(net, input, output, 0.01, 4);
// Abschnitt 4
final long timeEnd = System.currentTimeMillis();
System.out.println("\nTraining time: " + (timeEnd - timeStart) + " Millisek.");
System.out.println();
System.out.println(net.getOutputLayer().getInputWeights());
for (int i = 0; i < testInput.numberOfRows(); i++) {
System.out.println(testInput.getRow(i) + ", result: " + net.predict(testInput.getRow(i)));
}
}
Das Problem das noch bleibt ist Folgendes: Abschnitt 4 (In der main oben) wird ausgeführt, bevor Abschnitt 3 (also das eigentliche trainieren) fertig ist, wie verhindere ich das? Brauche ich da irgendwo noch ein join() und wenn ja wo?
Edit:
Eine solche Schleife:
Java:
for (Thread t : threads)
try {
t.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
vor dem return in CyclicMain.run löst das Problem. Ist es korrekt, das so zu lösen?
Ja, ein join in Cyclicmain vor dem return net; auf jedem BackpropagationThread sollte da klappen.
Zwei generelle Anmerkungen:
* extends Thread sollte generell eher vermieden werden, implements Runnable ist das das passendere
* private volatile NeuralNetwork sharedNet; sollte eher private final NeuralNetwork sharedNet; sein, volatile ist da nicht nötig, da die Referenz nicht verändert wird
Also wenn ich nur Attribute überschreibe (wie variable.x = new ....) reicht final und nur wenn ich die Referenz wechsle (also sharedNet = new Net(...)) brauche ich volatile?
Okay also das ganze Parallelisieren funktioniert jetzt, allerdings habe ich noch einen Fehler in der Logik (bei dem mir hier hoffentlich auch geholfen werden kann). Ich habe ja eine ganze Reihe BackpropagationThreads, die jeweils Berechnungen auf den Gewichten ausführen und das Ergebnis auf das SharedNet draufaddieren. Der AggregationThread nimmt jetzt diese Summe und teilt alles durch die Anzahl der BackpropagationThreads. Nun sollen die Gewichte in den jeweiligen Netzen in den Backprop.Threads mit denen aus dem AggregationThread geupdatet werden.
Ich habe das mit nur einem Thread getetset, was bedeutet, dass alle Gewichte durch 1 dividiert werden, was ja bedeutet, dass sie unverändert bleiben. Das updaten macht die Methode synchronize(). Wenn ich diese aber aufrufe ist das Ergebnis um ein vielfaches schlechter, wie wenn ich den Inhalt auskommentiere, obwohl die Gewichte ja eigentlich unverändert bleiben.
[CODE lang="java" title="AggregationThread"]
// net in dier Klasse ist die gleiche Instanz wie sharedNet in BackpropagationThread
@Override
public void run() {
epochCount++;
Matrix[] newWeights = net.getWeights();
Vector[] newBiases = net.getBiases();
// durch Anzahl BackpropagationThreads teilen
for(int i = 0; i < net.numberOfLayers(); i++) {
newWeights = newWeights.mul(1.0 / (double) NUMBER_OF_THREADS);
newBiases = newBiases.mul(1.0 / (double) NUMBER_OF_THREADS);
}
for (BackpropagationThread t : threads)
t.synchronize(newWeights, newBiases);
...
}[/CODE]
[CODE lang="java" title="BackpropagationThread"]@Override
public void run() {
// repeat epochs-times
for (int e = 0; e < epochs; e++) {
// iterate over all samples
for (int s = 0; s < in.numberOfRows(); s++) {
// update weights and biases with sample s
updateWithSample(s);
}
// share the weights and biases of the current net
update();