Q - Learning Algorithmus Bug

Diskutiere Q - Learning Algorithmus Bug im Allgemeine Java-Themen Forum; Hey, nachdem ich versucht habe der Stackoverflowcommunity eine bessere Antwort herauszulocken, versuche ich es mal hier. Ich versuche eine AI zu...

  1. Feeder
    Feeder Mitglied
    Hey,

    nachdem ich versucht habe der Stackoverflowcommunity eine bessere Antwort herauszulocken, versuche ich es mal hier.
    Ich versuche eine AI zu entwickeln, die zum späteren Teil einmal Astroids spielt. Nun soll aber das Q - Learning zu nächst allgemein funktionieren.
    Die Brain Klasse sieht wie folgt aus:


    Code (Java):
    package rlgame;

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;

    import org.encog.engine.network.activation.ActivationLOG;
    import org.encog.engine.network.activation.ActivationLinear;
    import org.encog.engine.network.activation.ActivationSigmoid;
    import org.encog.engine.network.activation.ActivationSoftMax;
    import org.encog.ml.data.MLData;
    import org.encog.ml.data.MLDataSet;
    import org.encog.ml.data.basic.BasicMLData;
    import org.encog.ml.data.basic.BasicMLDataSet;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.networks.layers.BasicLayer;
    import org.encog.neural.networks.training.propagation.back.Backpropagation;

    public class Brain {
        private ArrayList<ArrayList<Tuple>> biglist = new ArrayList<ArrayList<Tuple>>();
        BasicNetwork nn;
        BasicNetwork oldnn;
        private int index = 0;
        MLDataSet set = new BasicMLDataSet();

        public Brain() {
            nn = new BasicNetwork();
            nn.addLayer(new BasicLayer(new ActivationLinear(),true,29));
            nn.addLayer(new BasicLayer(new ActivationSigmoid(),true,20));
            nn.addLayer(new BasicLayer(new ActivationSigmoid(),true,20));
            nn.addLayer(new BasicLayer(new ActivationLinear(),false,5));
            nn.getStructure().finalizeStructure();
            nn.reset();
            oldnn = (BasicNetwork) nn.clone();
         
        }
     

        public void rlearn(ArrayList<Tuple> tupels, double learningrate, double discountfactor, boolean rememberTuples) {
            if(rememberTuples)biglist.add(tupels);
         
            //newQ = sum of all rewards you have got through
            for(int i = tupels.size()-1; i > 0; i--) {
                MLData in = new BasicMLData(29);
                MLData out = new BasicMLData(5);
             
                //Add State as in
                int index = 0;
                for(double w : tupels.get(i).statefirst.elements) {
                    in.add(index++, w);
                }
             
                //Now start updating Q - Values
                double qnew = 0;
                if(i <= tupels.size()-2){
                    qnew = tupels.get(i).rewardafter + discountfactor*qMax(tupels.get(i).stateafter);
                } else {
                    qnew = tupels.get(i).rewardafter;
                }
             
                tupels.get(i).qactions.elements[tupels.get(i).actionTaken] = qnew;
                //Add Q Values as out
                index = 0;
                for(double w : tupels.get(i).qactions.elements) {
                    out.add(index++, w);
                }
               
             
                set.add(in, out);          
            }
         
         
        }
         
        private double qMax(VectorND stateafter) {
            double[] qactions = oldnn.compute(new BasicMLData(stateafter.elements)).getData();
            double max = Double.MIN_VALUE;
            for(double w : qactions) {
                if(w > max) {
                    max = w;
                }
            }
            return max;
        }


        public double[] getOutput(MLData input) {
            return nn.compute(input).getData();
        }
     
        public void experienceReplay(double learningRate, double discountFactor) {
     
            for(int i = 0; i < 10; i++) {
                Collections.shuffle(biglist);
                List<ArrayList<Tuple>> list = biglist.subList(0, (int)(biglist.size()*0.3));
                for(ArrayList<Tuple> tuples : list) {
                    rlearn(tuples,learningRate, discountFactor, false);
                }
                Backpropagation prop = new Backpropagation(nn, set);
                prop.setLearningRate(learningRate);
                prop.iteration(10);
                System.out.println(prop.getError());
            }
         
            oldnn = (BasicNetwork) nn.clone();
            if(biglist.size() > 10000) {
                System.out.println("List trimmed.");
                while(biglist.size() > 10000) {
                    biglist.remove(biglist.size()-1);
                }
            }
            set = new BasicMLDataSet();

        }
        public void addTuples(ArrayList<Tuple> tuples) {
            biglist.add(tuples);
        }


    }
    Die rLearn Methode evaluiert den Fehler bzw. den neuen Q - value, während die experienceReplay Methode versucht dem Netz etwas beizubringen. Unglücklicherweise ist in dieser Klasse ein ziemlich idiotischer Bug, ich weiß nur nicht wo... Ich benutze ein älteres Netz um den maximalen Q - Value des jeweils nächsten States zu berechnen, das soll Stabilität gewährleisten.

    Das Tuplecode findet sich hier:

    Code (Text):
    package rlgame;

    import java.util.ArrayList;

    public class Tuple {
        VectorND statefirst = new VectorND();
        VectorND stateafter = new VectorND();
        VectorND qactions = new VectorND();
        double rewardafter;
        int actionTaken;

    }
     


    package rlgame;
    import java.util.ArrayList;
    public class Tuple {
    VectorND statefirst = new VectorND();
    VectorND stateafter = new VectorND();
    VectorND qactions = new VectorND();
    double rewardafter;
    int actionTaken;
    }

    Den gesamten Code findet ihr hier:

    https://github.com/SuchtyTV/RLearningBird
     
  2. Vielleicht hilft dir dieses Buch hier weiter.
  3. mihe7
    mihe7 Bekanntes Mitglied
    Wenn Du da auch keine Frage gestellt hast... ;)
     
  4. Feeder
    Feeder Mitglied
    Was läuft so falsch?
     
  5. httpdigest
    httpdigest Bekanntes Mitglied
    Ja, was läuft denn falsch? Aktuell hast du nur Code gepostet, gesagt, `Die rLearn Methode evaluiert den Fehler bzw. den neuen Q - value, während die experienceReplay Methode versucht dem Netz etwas beizubringen.`, wozu man nur sagen kann "aha" und dann erwähnst du, dass irgendwo in dem Code ein Fehler ist, ohne darauf hinzuweisen, wo denn die Diskrepanz zwischen dem, was der Code tut und dem, was du erwartest, liegt.
    Dann wäre meine erste Frage also erstmal: Woher weißt du denn, dass in dem Code ein Fehler ist? Worin äußert sich das?
     
    JuKu gefällt das.
  6. Feeder
    Feeder Mitglied
    Der Punkt geht an dich.
    Naja zum einen sind die Fehler nach der Backpropagation viel zu groß.
    Ich habe nun die Sigmoidfunktionen mit Logarithmen ausgetauscht.
    Nachdem sind die Fehler im Bereich 0 bis 20; besser...
    Die Q-Funktion wird dennoch nicht korrekt approximiert. (vielleicht ein Overfit, was ich aber nicht annehme)
    Außerdem erhalte ich immer ein OutOfMemoryError (den ich zwar fixen kann, dennoch trotzdem nicht sicher bin wo der her rührt.)
     
    Zuletzt bearbeitet: 6. Jan. 2019
  7. Wenn du Java lernen möchtest, empfehlen wir dir dieses Buch hier
Passende Stellenanzeigen aus deiner Region:





Die Seite wird geladen...

Q - Learning Algorithmus Bug - Ähnliche Themen

Python Deep Learning Problem
Python Deep Learning Problem im Forum Hausaufgaben
Text mining / deep learning
Text mining / deep learning im Forum Allgemeine Java-Themen
Machine-learning Framework
Machine-learning Framework im Forum Allgemeine Java-Themen
Machine Learning - LibSVM und DeepLearning4j
Machine Learning - LibSVM und DeepLearning4j im Forum Private Stellangebote und Stellensuche von Usern
Datenbank oder Filesystem? (E-Learning)
Datenbank oder Filesystem? (E-Learning) im Forum Allgemeine Java-Themen
Thema: Q - Learning Algorithmus Bug