Q - Learning Algorithmus Bug

Bitte aktiviere JavaScript!
Hey,

nachdem ich versucht habe der Stackoverflowcommunity eine bessere Antwort herauszulocken, versuche ich es mal hier.
Ich versuche eine AI zu entwickeln, die zum späteren Teil einmal Astroids spielt. Nun soll aber das Q - Learning zu nächst allgemein funktionieren.
Die Brain Klasse sieht wie folgt aus:


Java:
package rlgame;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.encog.engine.network.activation.ActivationLOG;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.engine.network.activation.ActivationSoftMax;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.back.Backpropagation;

public class Brain {
    private ArrayList<ArrayList<Tuple>> biglist = new ArrayList<ArrayList<Tuple>>();
    BasicNetwork nn;
    BasicNetwork oldnn;
    private int index = 0;
    MLDataSet set = new BasicMLDataSet();

    public Brain() {
        nn = new BasicNetwork();
        nn.addLayer(new BasicLayer(new ActivationLinear(),true,29));
        nn.addLayer(new BasicLayer(new ActivationSigmoid(),true,20));
        nn.addLayer(new BasicLayer(new ActivationSigmoid(),true,20));
        nn.addLayer(new BasicLayer(new ActivationLinear(),false,5));
        nn.getStructure().finalizeStructure();
        nn.reset();
        oldnn = (BasicNetwork) nn.clone();
      
    }
  

    public void rlearn(ArrayList<Tuple> tupels, double learningrate, double discountfactor, boolean rememberTuples) {
        if(rememberTuples)biglist.add(tupels);
      
        //newQ = sum of all rewards you have got through
        for(int i = tupels.size()-1; i > 0; i--) {
            MLData in = new BasicMLData(29);
            MLData out = new BasicMLData(5);
          
            //Add State as in
            int index = 0;
            for(double w : tupels.get(i).statefirst.elements) {
                in.add(index++, w);
            }
          
            //Now start updating Q - Values
            double qnew = 0;
            if(i <= tupels.size()-2){
                qnew = tupels.get(i).rewardafter + discountfactor*qMax(tupels.get(i).stateafter);
            } else {
                qnew = tupels.get(i).rewardafter;
            }
          
            tupels.get(i).qactions.elements[tupels.get(i).actionTaken] = qnew;
            //Add Q Values as out
            index = 0;
            for(double w : tupels.get(i).qactions.elements) {
                out.add(index++, w);
            }
           
          
            set.add(in, out);          
        }
      
      
    }
      
    private double qMax(VectorND stateafter) {
        double[] qactions = oldnn.compute(new BasicMLData(stateafter.elements)).getData();
        double max = Double.MIN_VALUE;
        for(double w : qactions) {
            if(w > max) {
                max = w;
            }
        }
        return max;
    }


    public double[] getOutput(MLData input) {
        return nn.compute(input).getData();
    }
  
    public void experienceReplay(double learningRate, double discountFactor) {
  
        for(int i = 0; i < 10; i++) {
            Collections.shuffle(biglist);
            List<ArrayList<Tuple>> list = biglist.subList(0, (int)(biglist.size()*0.3));
            for(ArrayList<Tuple> tuples : list) {
                rlearn(tuples,learningRate, discountFactor, false);
            }
            Backpropagation prop = new Backpropagation(nn, set);
            prop.setLearningRate(learningRate);
            prop.iteration(10);
            System.out.println(prop.getError());
        }
      
        oldnn = (BasicNetwork) nn.clone();
        if(biglist.size() > 10000) {
            System.out.println("List trimmed.");
            while(biglist.size() > 10000) {
                biglist.remove(biglist.size()-1);
            }
        }
        set = new BasicMLDataSet();

    }
    public void addTuples(ArrayList<Tuple> tuples) {
        biglist.add(tuples);
    }


}
Die rLearn Methode evaluiert den Fehler bzw. den neuen Q - value, während die experienceReplay Methode versucht dem Netz etwas beizubringen. Unglücklicherweise ist in dieser Klasse ein ziemlich idiotischer Bug, ich weiß nur nicht wo... Ich benutze ein älteres Netz um den maximalen Q - Value des jeweils nächsten States zu berechnen, das soll Stabilität gewährleisten.

Das Tuplecode findet sich hier:

Code:
package rlgame;

import java.util.ArrayList;

public class Tuple {
    VectorND statefirst = new VectorND();
    VectorND stateafter = new VectorND();
    VectorND qactions = new VectorND();
    double rewardafter;
    int actionTaken;

}


package rlgame;
import java.util.ArrayList;
public class Tuple {
VectorND statefirst = new VectorND();
VectorND stateafter = new VectorND();
VectorND qactions = new VectorND();
double rewardafter;
int actionTaken;
}

Den gesamten Code findet ihr hier:

https://github.com/SuchtyTV/RLearningBird
 
A

Anzeige


Vielleicht hilft dir dieser Kurs hier weiter: (hier klicken)
Ja, was läuft denn falsch? Aktuell hast du nur Code gepostet, gesagt, `Die rLearn Methode evaluiert den Fehler bzw. den neuen Q - value, während die experienceReplay Methode versucht dem Netz etwas beizubringen.`, wozu man nur sagen kann "aha" und dann erwähnst du, dass irgendwo in dem Code ein Fehler ist, ohne darauf hinzuweisen, wo denn die Diskrepanz zwischen dem, was der Code tut und dem, was du erwartest, liegt.
Dann wäre meine erste Frage also erstmal: Woher weißt du denn, dass in dem Code ein Fehler ist? Worin äußert sich das?
 
Der Punkt geht an dich.
Naja zum einen sind die Fehler nach der Backpropagation viel zu groß.
Ich habe nun die Sigmoidfunktionen mit Logarithmen ausgetauscht.
Nachdem sind die Fehler im Bereich 0 bis 20; besser...
Die Q-Funktion wird dennoch nicht korrekt approximiert. (vielleicht ein Overfit, was ich aber nicht annehme)
Außerdem erhalte ich immer ein OutOfMemoryError (den ich zwar fixen kann, dennoch trotzdem nicht sicher bin wo der her rührt.)
 
Zuletzt bearbeitet:
Passende Stellenanzeigen aus deiner Region:

Neue Themen

Oben