Parallel Reduction: Workload

Ich bin grad dabei etwas mit Parallelverabeitung rum zu spielen.
Das Beispiel ist recht simpel, es soll von einem Startwert bis zu einem Endwert einfach aufsummiert werden. Das Programm erhält als Übergabeparameter die Anzahl der Threads, den Startwert und den Endwert.
Die gesamte Range wird brüderlich aufgeteilt und genau hier habe ich ein Problem. Ist die Range gerade funktioniert alles wunderbar, aber ist sie ungerade dann wills nicht so recht.
Evtl. bin ich einfach schon betriebsblind und es ist schon zu spät, auf jeden Fall hänge ich hier fest ...

Java:
public class ParallelReduction {

	private int nThreads;
	private int startRange;
	private int endRange;
	private int stepLength;
	private int result;
	private Collection<Thread> threads;
	private Map<String, Integer> chunks; // lower bound for each Thread
	private Map<String, Integer> results;

	public ParallelReduction(String[] args) {
		this.result = 0;
		this.nThreads = Integer.parseInt(args[0]);
		this.startRange = Integer.parseInt(args[1]);
		this.endRange = Integer.parseInt(args[2]);
		this.threads = new Vector<Thread>();
		this.chunks = new Hashtable<String, Integer>();
		this.results = new Hashtable<String, Integer>();

		for (int i = 0; i < this.nThreads; ++i) {
			this.threads.add(new Thread(new Compute(), String.valueOf(i)));
		}
	}

	public int getStepLength() {
		return this.stepLength;
	}

	public void go() {
		this.workload();
		this.startThreads();
		this.sumResults();

		// print result + check
		System.out.println();
		System.out
				.println("Result (multithreaded computation): " + this.result);
		System.out.println();

		System.out.println("check: sequential computation");
		int res = 0;
		for (int i = this.startRange; i <= this.endRange; ++i) {
			res += i;
		}
		System.out.println("res: " + res);
	}

	private void workload() {
		// **ToDo: Handle uneven range of work
		if ((this.endRange - this.startRange) % this.nThreads != 0) {
			throw new IllegalArgumentException("Wrong Workload");
		}

		this.stepLength = (this.endRange - this.startRange) / this.nThreads;

		for (int i = this.startRange, j = 0; i < this.endRange; i += this.stepLength, ++j) {
			this.chunks.put(String.valueOf(j), i);
		}

	}

	private void startThreads() {
		for (Thread t : this.threads) {
			t.start();
		}
	}

	private void sumResults() {
		for (Thread t : this.threads) {
			try {
				t.join();
			} catch (InterruptedException e) {
				e.printStackTrace();
			}
			this.result += this.results.get(t.getName());
		}
	}

	public static void main(String[] args) {
		if (args.length != 3) {
			System.out.println("usage:\t1. Number of Threads\n "
					+ "\t2. Start number of the range\n"
					+ "\t3. End number of the range");
			System.exit(-1);
		}

		ParallelReduction objP = new ParallelReduction(args);
		objP.go();
	}

	private class Compute implements Runnable {

		private int startRange;
		private int endRange;
		private int result = 0;

		@Override
		public void run() {
			this.startRange = chunks.get(Thread.currentThread().getName());

			if (this.startRange + stepLength < ParallelReduction.this.endRange) {
				this.endRange = this.startRange + stepLength - 1;
			}

			// it is the "last" Thread!
			else if (this.startRange + stepLength == ParallelReduction.this.endRange) {
				this.endRange = this.startRange + stepLength;
			}

			for (int i = this.startRange; i <= this.endRange; ++i) {
				this.result += i;
			}

			results.put(Thread.currentThread().getName(), this.result);

			System.out.println("Thread: " + Thread.currentThread().getName()
					+ " start: " + this.startRange + " end: " + this.endRange
					+ " result: " + this.result);
		}
	}
}
 

knilch

Bekanntes Mitglied
Hi,
Das Problem liegt dabei in der Methode workload:
Java:
if ((this.endRange - this.startRange) % this.nThreads != 0) {
    throw new IllegalArgumentException("Wrong Workload");
}
Wenn der Range ungerade ist, dann wirfst du eine Exception.
Hier mal ein Beispiel das funktionieren sollte, egal ob der Range gerade oder ungerade ist
Java:
private void workload() {
	int mod = (this.endRange - this.startRange) % this.nThreads;
	this.stepLength = (this.endRange - this.startRange) / this.nThreads;
	if (mod == 0) {
		for (int i = this.startRange, j = 0; i < this.endRange; i += this.stepLength, ++j) {
			this.chunks.put(String.valueOf(j), i);
		}
	}
	else {
		for (int i = this.startRange, j = 0; i < this.endRange; i += this.stepLength, ++j) {
			this.chunks.put(String.valueOf(j), i + mod);
		}
	}
}
 
Zuletzt bearbeitet:
Moin

Hatte leider bisher keine Zeit die Sache mal zu testen.
Die Exception werfe ich ja mit Absicht. ;)

Dein Lösungsansatz scheint schon mal in die richtige Richtung zu gehen, nur das
Code:
+ mod
dürfte wohl nicht jedem Thread mit gegeben werden, sondern nur einem.

Java:
private void workload() {
   int mod = (this.endRange - this.startRange) % this.nThreads;
   this.stepLength = (this.endRange - this.startRange) / this.nThreads;

   for (int i = this.startRange, j = 0; i < this.endRange; i += this.stepLength, ++j) {
      if (j == 0) {
         this.chunks.put(String.valueOf(j), i + mod);
      }
      else {
         this.chunks.put(String.valueOf(j), i);
      }
   }
}

Ich hoffe ich komme zeitig mal dazu es auszuprobieren. Wahrscheinlich bin ich nicht drauf gekommen, weil es meinem Sinn für Symetrie widerspricht. :D
 
Ganz vergessen :oops:... das funktioniert so! :toll:

Zu früh gefreut :autsch:

Aber ich bin nicht doof und hab etwas rum probiert.
Die Methode workload() hat sich vereinfacht. In der run() Methode der Threads packt sich der "letzte" Thread (, der mit der höchsten Nummer,) zusätzliche Schleifendurchläufe in endRange.
[Java]
public class ParallelReduction {

private int numberThreads;
private int startRange;
private int endRange;
private int stepLength;
private int result;
private int mod;
private Collection<Thread> threads;
private Map<String, Integer> chunks; // lower bound for each Thread
private Map<String, Integer> results;

public ParallelReduction(String[] args) {
this.result = 0;
this.numberThreads = Integer.parseInt(args[0]);
this.startRange = Integer.parseInt(args[1]);
this.endRange = Integer.parseInt(args[2]);
this.threads = new Vector<Thread>();
this.chunks = new Hashtable<String, Integer>();
this.results = new Hashtable<String, Integer>();

for (int i = 0; i < this.numberThreads; ++i) {
this.threads.add(new Thread(new Compute(), String.valueOf(i)));
}

this.mod = (this.endRange - this.startRange) % this.numberThreads;
this.stepLength = (this.endRange - this.startRange)
/ this.numberThreads;
}

public int getStepLength() {
return this.stepLength;
}

public int getNumberThreads() {
return this.numberThreads;
}

public void go() {
System.out.println("MOD: " + this.mod);
this.workload();
this.startThreads();
this.sumResults();

// print result + check
System.out.println();
System.out
.println("Result (multithreaded computation): " + this.result);
System.out.println();

System.out.println("check: sequential computation");
int res = 0;
for (int i = this.startRange; i <= this.endRange; ++i) {
res += i;
}
System.out.println("res: " + res);
}

private void workload() {
for (int i = this.startRange, j = 0; i < this.endRange; i += this.stepLength, ++j) {
this.chunks.put(String.valueOf(j), i);
}
}

private void startThreads() {
for (Thread t : this.threads) {
t.start();
}
}

private void sumResults() {
for (Thread t : this.threads) {
try {
t.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
this.result += this.results.get(t.getName());
}
}

public static void main(String[] args) {
if (args.length != 3) {
System.out.println("usage:\t1. Number of Threads\n "
+ "\t2. Start number of the range\n"
+ "\t3. End number of the range");
System.exit(-1);
}

ParallelReduction objP = new ParallelReduction(args);
objP.go();
}

/*
* This Class implements the Thread
*/
private class Compute implements Runnable {

private int startRange;
private int endRange;
private int result = 0;

@Override
public void run() {
this.startRange = chunks.get(Thread.currentThread().getName());

this.endRange = this.startRange + ParallelReduction.this.stepLength
- 1;

// it is the "last" Thread!
if (Integer.valueOf(Thread.currentThread().getName()) + 1 == ParallelReduction.this.numberThreads) {
this.endRange += ParallelReduction.this.mod + 1;
}

for (int i = this.startRange; i <= this.endRange; ++i) {
this.result += i;
}

ParallelReduction.this.results.put(
Thread.currentThread().getName(), this.result);

System.out.println("Thread: " + Thread.currentThread().getName()
+ " start: " + this.startRange + " end: " + this.endRange
+ " result: " + this.result);
}
}
}
[/Java]
 

Ähnliche Java Themen

Neue Themen


Oben