package javaforum;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
/**
 * Small driver that samples a {@link WeightedGenerator} many times and prints the
 * observed frequency of every key that was drawn.
 */
public class WeightedRandomTest
{
    /**
     * Draws 100000 samples from a generator with fixed weights (40, 40, 19.5, 0.5)
     * using a seeded {@link Random}, then prints each drawn key with the fraction
     * of samples in which it appeared.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        WeightedGenerator<Integer> generator = new WeightedGenerator<Integer>();
        generator.setWeight(15, 40.0);
        generator.setWeight(19, 40.0);
        generator.setWeight(14, 19.5);
        generator.setWeight(77, 0.5);

        // Seed 0 keeps the sample sequence, and therefore the printed report, reproducible.
        Random rng = new Random(0);
        Map<Integer, Integer> tally = new HashMap<Integer, Integer>();
        int samples = 100000;

        for (int draw = 0; draw < samples; draw++)
        {
            Integer picked = generator.generate(rng);
            Integer seen = tally.get(picked);
            tally.put(picked, seen == null ? 1 : seen + 1);
        }

        for (Map.Entry<Integer, Integer> entry : tally.entrySet())
        {
            double frequency = (double) entry.getValue() / samples;
            System.out.println("Element " + entry.getKey() + " probability " + frequency);
        }
    }
}
/**
 * Picks keys at random with probability proportional to a per-key weight.
 *
 * <p>Keys are kept in insertion order (backed by a {@link LinkedHashMap}). Each key
 * owns a contiguous slice of the interval [0, totalWeight) in that order, and a
 * sample in [0, 1) is scaled onto that interval to select a key. Keys whose weight
 * is zero own an empty slice and are never returned.
 *
 * <p>No synchronization is performed; callers sharing an instance across threads
 * must provide their own.
 *
 * @param <T> key type
 */
class WeightedGenerator<T>
{
    /** Weight of every known key, in insertion order. Values are never null, NaN or negative. */
    private final Map<T, Double> table = new LinkedHashMap<T, Double>();

    /** Running sum of all weights in {@code table}, updated incrementally by {@link #setWeight}. */
    private double totalWeight = 0.0;

    /**
     * Sets (or replaces) the weight of {@code key}.
     *
     * @param key    the key to weight; inserted at the end if not already present
     * @param weight the new weight; negative values and NaN are stored as 0, which keeps
     *               the key in the table but makes it unselectable
     * @throws NullPointerException     if {@code weight} is null
     * @throws IllegalArgumentException if {@code weight} is positive infinity (it would turn
     *                                  the running total into NaN on the next update)
     */
    public void setWeight(T key, Double weight)
    {
        double w = weight; // unboxing: a null weight throws NullPointerException, as before
        if (w == Double.POSITIVE_INFINITY)
        {
            throw new IllegalArgumentException("weight must be finite, got " + w + " for key " + key);
        }
        // !(w >= 0) is true for negatives AND NaN; a stored NaN would poison totalWeight permanently.
        if (!(w >= 0))
        {
            w = 0.0;
        }
        totalWeight -= getWeight(key);
        table.put(key, w);
        totalWeight += w;
    }

    /**
     * Returns the weight stored for {@code key}.
     *
     * @param key the key to look up
     * @return the stored weight, or 0 if the key has never been given a weight
     */
    public Double getWeight(T key)
    {
        Double weight = table.get(key);
        return weight != null ? weight : Double.valueOf(0.0);
    }

    /**
     * Picks a key using one {@link Random#nextDouble()} draw from {@code generator}.
     *
     * @param generator source of randomness
     * @return the selected key, or null if no key has a positive weight
     */
    public T generate(Random generator)
    {
        return generate(generator.nextDouble());
    }

    /**
     * Picks the key whose weight slice contains {@code randomValue * totalWeight}.
     *
     * <p>Intended for {@code randomValue} in [0, 1]. A value of 1 (or above) selects the
     * last key with a positive weight. A negative or NaN value selects nothing.
     *
     * @param randomValue position in the cumulative weight distribution, normally in [0, 1)
     * @return the selected key, or null if the table is empty, no key has a positive weight,
     *         or {@code randomValue} is negative/NaN
     */
    public T generate(double randomValue)
    {
        double remaining = randomValue * totalWeight;
        if (!(remaining >= 0))
        {
            // Negative or NaN input: nothing is selected (same result as the original loop).
            return null;
        }
        T lastSelectable = null;
        for (Map.Entry<T, Double> entry : table.entrySet())
        {
            double weight = entry.getValue();
            if (weight <= 0)
            {
                continue; // zero-weight keys own no part of the interval
            }
            lastSelectable = entry.getKey();
            remaining -= weight;
            if (remaining < 0)
            {
                return lastSelectable;
            }
        }
        // remaining never went negative: randomValue >= 1, or floating-point drift left
        // totalWeight slightly above the true sum. Fall back to the last key that can
        // legitimately be chosen rather than whichever (possibly zero-weight) key is last.
        return lastSelectable;
    }

    /**
     * @return the number of keys in the table, including keys whose weight is zero
     */
    public int size()
    {
        return table.size();
    }

    /** Removes every key and resets the total weight to zero. */
    public void clear()
    {
        totalWeight = 0.0;
        table.clear();
    }
}