Machine-learning Framework

Hi everyone, over the past few weeks I have been working with neural networks and have already built a neural network that reads in a CSV file. Now I would like to work with backpropagation. I know how it works in theory, but not how to implement it.
The assignment I have to solve for university is the following:

The resulting software package shall not contain any problem-specific part (e.g. how to recognize a digit). But it shall work with an ANN that supports multiple hidden layers (as specified with your API or XML document from assignment 11). This requires adapting the backpropagation algorithm.

Use this framework to implement the digit recognition example with a 5 layer ANN.

Avoid making the same mistakes as in the pouring-problem assignment (particularly, be sure not to mix problem-specific code with problem-independent one). You'll notice that there is a tradeoff between the amount of code that is required to use the framework and the degree of flexibility that your framework supports. Which assumptions can you make (e.g. input-format, input data type, ...) without restricting the system too much?
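
One possible reading of that tradeoff, as a rough sketch only: the framework could assume that every example arrives as a pair of double vectors, and leave the conversion from the raw CSV line to a small hook. The names Sample, SampleParser and DigitSampleParser are hypothetical and do not come from the assignment; the scaling to 0.01..1.0 and the 0.01/0.99 targets are the ones used in the code from the previous example.

```java
import java.util.Arrays;

/** Hypothetical problem-independent view of one training example. */
final class Sample {
    final double[] inputs;
    final double[] targets;

    Sample(double[] inputs, double[] targets) {
        this.inputs = inputs;
        this.targets = targets;
    }
}

/** Hypothetical hook: the only place where problem-specific knowledge lives. */
interface SampleParser {
    Sample parse(String csvLine);
}

/** Digit-specific part: label in the first column, pixel values 0..255 after it. */
final class DigitSampleParser implements SampleParser {
    private static final int CLASSES = 10;

    @Override
    public Sample parse(String csvLine) {
        String[] parts = csvLine.split(",");

        // all targets are 0.01, except the desired label, which is 0.99
        double[] targets = new double[CLASSES];
        Arrays.fill(targets, 0.01);
        targets[Integer.parseInt(parts[0].trim())] = 0.99;

        // scale and shift the pixel values into the range 0.01..1.0
        double[] inputs = new double[parts.length - 1];
        for (int i = 1; i < parts.length; i++)
            inputs[i - 1] = Double.parseDouble(parts[i].trim()) / 255.0 * 0.99 + 0.01;

        return new Sample(inputs, targets);
    }
}
```

With a hook like this, the network code would only ever see double arrays, and a different problem would need a different SampleParser but no change to the framework itself.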

So I am supposed to build a framework that recognizes a digit using the 5 layers. The CSV file already exists, and I already have the following code from the previous example, which should be adapted.

Java:
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class App {
    public static void main(String[] args) throws Exception {
        DigitRecognizer recognizer = new DigitRecognizer();
        recognizer.init(new File("mnist_train_100.csv"));
        recognizer.train();
        // scorecard for how well the network performs, initially empty
        int attemptsOk = 0;
        int attemptsFailed = 0;
        try (Stream<String> stream = Files
                .lines(Paths.get("mnist_test_10.csv"))) {
            List<String> testDataList = stream.collect(Collectors.toList());

            System.out.println("correct | recognized");

            // go through all the records in the test data set
            for (String record : testDataList) {
                // split the record at the first comma: label, then the remaining pixel values
                String[] allValues = record.split(",", 2);

                // correct answer is first value
                int correctDigit = Integer.parseInt(allValues[0]);
                int recognizedDigit = recognizer.recognize(allValues[1]);

                if (correctDigit == recognizedDigit)
                    attemptsOk++;
                else
                    attemptsFailed++;

                System.out.println("  " + correctDigit + "     |     " + recognizedDigit);
            }
        }
        // calculate the performance score, the fraction of correct answers
        System.out.println("performance = " + (double) attemptsOk / (attemptsOk + attemptsFailed) * 100 + "%");
    }
}
Java:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;


public class DigitRecognizer implements Assignment9 {

    private List<Integer> trainingDataDigit;
    private List<List<Double>> trainingData;
    private static final int amountOfTargets = 10;
    private double learningRate;
    private Function<Double, Double> activation;
    private Random random = new Random();
    private Matrix wih;
    private Matrix who;

    /**
     * Loads the .csv file with the training data or throws an Exception if anything goes wrong;
     * returns true iff the initialization completed successfully.
     *
     * @param csvTrainingData
     *            the data used to train the neural network
     * @return true if the initialization was successful
     */
    @Override
    public boolean init(File csvTrainingData) throws Exception {
        try (BufferedReader br = new BufferedReader(new FileReader(csvTrainingData))) {
            trainingData = new ArrayList<>();
            trainingDataDigit = new ArrayList<>();

            br.lines().forEach(s -> {
                String[] allValues = s.split(",", 2);

                trainingDataDigit.add(Integer.parseInt(allValues[0]));

                // scale and shift the inputs
                trainingData.add(Arrays.stream(allValues[1].split(","))
                        .mapToDouble(x -> Double.parseDouble(x) / 255.0 * 0.99 + 0.01)
                        .boxed()
                        .collect(Collectors.toList()));
            });
        }

        if (trainingData.size() > 0) {
            init(trainingData.get(0).size(), 200, amountOfTargets, 0.1, (Double x) -> 1.0 / (1.0 + Math.exp(-x)));

            return true;
        } else
            return false;
    }

    /**
     * Trains the neural network used for digit recognition.
     *
     * @return true iff the training of the neural network was successful.
     * @throws Exception
     */
    @Override
    public boolean train() throws Exception {
        // create the target output values (all 0.01, except the desired label which is 0.99)
        double[] targets = DoubleStream.generate(() -> 0.01).limit(amountOfTargets).toArray();

        for (int epochs = 0; epochs < 5; epochs++) {
            for (int i = 0; i < trainingData.size(); i++) {
                targets[trainingDataDigit.get(i)] = 0.99;

                double[] inputs = trainingData.get(i).stream().mapToDouble(Double::doubleValue).toArray();

                train(inputs, targets);

                targets[trainingDataDigit.get(i)] = 0.01;
            }
        }

        return true;
    }

    private int indexFromMax(double[] data) {
        int max = 0;

        for (int i = 0; i < data.length; i++)
            if (data[i] > data[max])
                max = i;

        return max;
    }

    /**
     * Tries to recognize the digit represented by csvString.
     *
     * @param csvString
     *            the digit pattern as CSV string.
     * @return the recognized digit
     */
    @Override
    public int recognize(String csvString) throws Exception {
        // scale and shift the inputs
        double[] inputs = Arrays.stream(csvString.split(","))
                .mapToDouble(s -> Double.parseDouble(s) / 255.0 * 0.99 + 0.01)
                .toArray();

        double[] outputs = query(inputs);

        // the index of the highest value corresponds to the label
        return indexFromMax(outputs);
    }

    /**
     * @param inputNodes
     *            the amount of input nodes
     * @param hiddenNodes
     *            the amount of hidden nodes
     * @param outputNodes
     *            the amount of output nodes
     * @param learningRate
     *            the learning rate
     * @param activation
     *            the activation function
     */
    public void init(int inputNodes, int hiddenNodes, int outputNodes, double learningRate, Function<Double, Double> activation) {
        this.learningRate = learningRate;
        this.activation = activation;

        /*link weight matrices, wih and who
         weights inside the arrays are w_i_j, where link is from node i to node j in the next layer
         w11 w21
         w12 w22 etc*/
        wih = new Matrix(hiddenNodes, inputNodes);
        who = new Matrix(outputNodes, hiddenNodes);

        fillWithRandomValues(wih, 0, Math.pow(inputNodes, -0.5));
        fillWithRandomValues(who, 0, Math.pow(hiddenNodes, -0.5));
    }

    // note: the third argument is used as a standard deviation, not as a variance
    private void fillWithRandomValues(Matrix matrix, double mean, double stdDev) {
        for (int row = 0; row < matrix.getRowDimension(); row++)
            for (int col = 0; col < matrix.getColumnDimension(); col++)
                matrix.set(row, col, nextRandom(mean, stdDev));
    }

    private double nextRandom(double mean, double stdDev) {
        return mean + random.nextGaussian() * stdDev;
    }

    /**
     * Trains the neural network with the given input and target values.
     *
     * @param inputsList
     *            the input values to be used
     * @param targetsList
     *            the target values for the given input values
     */
    public void train(double[] inputsList, double[] targetsList) {
        Matrix inputs = toMatrix(inputsList);
        Matrix targets = toMatrix(targetsList);

        // calculate signals into hidden layer
        Matrix hiddenInputs = wih.matrixMultiplication(inputs);
        // calculate the signals emerging from hidden layer
        Matrix hiddenOutputs = hiddenInputs.applyFuntion(activation);

        // calculate signals into final output layer
        Matrix finalInputs = who.matrixMultiplication(hiddenOutputs);
        // calculate the signals emerging from final output layer
        Matrix finalOutputs = finalInputs.applyFuntion(activation);

        // output layer error is the (target - actual)
        Matrix outputErrors = targets.matrixSubstraction(finalOutputs);
        // hidden layer error is the output_errors, split by weights, recombined at hidden nodes
        Matrix hiddenErrors = who.transposeMatrix().matrixMultiplication(outputErrors);

        // update the weights for the links between the hidden and output layers
        who = who.matrixAddition(outputErrors.multByElement(finalOutputs)
                .multByElement(finalOutputs.applyFuntion(d -> 1.0 - d))
                .matrixMultiplication(hiddenOutputs.transposeMatrix())
                .scalarMultiplication(learningRate));

        // update the weights for the links between the input and hidden layers
        wih = wih.matrixAddition(hiddenErrors.multByElement(hiddenOutputs)
                .multByElement(hiddenOutputs.applyFuntion(d -> 1.0 - d))
                .matrixMultiplication(inputs.transposeMatrix())
                .scalarMultiplication(learningRate));
    }

    /**
     * Queries the output of the neural network for a given input.
     *
     * @param inputsList
     *            the input to query for.
     * @return the output from the network.
     */
    public double[] query(double[] inputsList) {
        Matrix inputs = toMatrix(inputsList);

        // calculate signals into hidden layer
        Matrix hiddenInputs = wih.matrixMultiplication(inputs);
        // calculate the signals emerging from hidden layer
        Matrix hiddenOutputs = hiddenInputs.applyFuntion(activation);

        // calculate signals into final output layer
        Matrix finalInputs = who.matrixMultiplication(hiddenOutputs);
        // calculate the signals emerging from final output layer
        Matrix finalOutputs = finalInputs.applyFuntion(activation);

        return toArray(finalOutputs);
    }

    private Matrix toMatrix(double[] data) {
        Matrix result = new Matrix(data.length, 1);

        for (int i = 0; i < data.length; i++)
            result.set(i, 0, data[i]);

        return result;
    }

    private double[] toArray(Matrix matrix) {
        double[] result = new double[matrix.getRowDimension()];

        for (int i = 0; i < result.length; i++)
            result[i] = matrix.get(i, 0);

        return result;
    }
}
Java:
import java.lang.StringBuilder;
import java.util.function.Function;


public class Matrix {

    private final int row;
    private final int col;

    private double[][] elements;

    public Matrix(int row, int col) {
        this.row = row;
        this.col = col;
        this.elements = new double[row][col];
    }

    public int getRowDimension() {
        return row;
    }

    public int getColumnDimension() {
        return col;
    }

    public Matrix transposeMatrix() {
        Matrix B = new Matrix(this.col, this.row);
        for (int row = 0; row < this.col; row++)
            for (int col = 0; col < this.row; col++)
                B.set(row, col, this.get(col, row));
        return B;
    }

    public void set(int row, int col, double e) {
        elements[row][col] = e;
    }

    public double get(int row, int col) {
        return elements[row][col];
    }

    public Matrix scalarAddition(double a) {
        Matrix B = new Matrix(this.row, this.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < this.col; col++)
                B.set(row, col, a + this.get(row, col));
        return B;
    }

    // note: computes (a - element) for every element, not (element - a)
    public Matrix scalarSubstraction(double a) {
        Matrix B = new Matrix(this.row, this.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < this.col; col++)
                B.set(row, col, a - this.get(row, col));
        return B;
    }

    public Matrix scalarMultiplication(double a) {
        Matrix B = new Matrix(this.row, this.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < this.col; col++)
                B.set(row, col, this.get(row, col) * a);
        return B;
    }

    public Matrix applyFuntion(Function<Double, Double> f) {
        Matrix B = new Matrix(this.row, this.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < this.col; col++)
                B.set(row, col, f.apply(this.get(row, col)));
        return B;
    }

    public Matrix multByElement(Matrix B) {
        Matrix C = new Matrix(this.row, this.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < this.col; col++)
                C.set(row, col, this.get(row, col) * B.get(row, col));
        return C;
    }

    public Matrix matrixAddition(Matrix B) {
        if (!(this.row == B.row && this.col == B.col))
            throw new RuntimeException("the matrices dimensions do not add up");

        Matrix C = new Matrix(this.row, this.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < this.col; col++)
                C.set(row, col, this.get(row, col) + B.get(row, col));
        return C;
    }

    public Matrix matrixSubstraction(Matrix B) {
        if (!(this.row == B.row && this.col == B.col))
            throw new RuntimeException("the matrices dimensions do not add up");

        Matrix C = new Matrix(this.row, this.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < this.col; col++)
                C.set(row, col, this.get(row, col) - B.get(row, col));
        return C;
    }

    public Matrix matrixMultiplication(Matrix B) {
        if (!(this.col == B.row))
            throw new RuntimeException("the matrices dimensions do not add up");

        Matrix C = new Matrix(this.row, B.col);
        for (int row = 0; row < this.row; row++)
            for (int col = 0; col < B.col; col++) {
                double sum = 0;
                for (int k = 0; k < this.col; k++)
                    sum += this.get(row, k) * B.get(k, col);
                C.set(row, col, sum);
            }
        return C;
    }

    @Override
    public String toString() {
        StringBuilder str = new StringBuilder();
        for (double[] row : elements) {
            for (double d : row)
                str.append(d).append(" ");
            str.append("\n");
        }
        return str.toString();
    }
}


I hope you can show me how to rewrite the program so that it works with backpropagation and the 5 layers!
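
For orientation, a minimal sketch of how the two fixed matrices wih and who could be generalized to any number of layers. It is not a reference solution for the assignment. The class name LayeredNetwork, the layerSizes array and the seed parameter are assumptions made for illustration, and plain double arrays stand in for the Matrix class so that the block is self-contained. One deliberate difference from the train method above: the sketch passes the delta (error times sigmoid derivative) back through the weights, which is what textbook backpropagation does for a squared-error loss, whereas the posted code passes the raw output error back.

```java
import java.util.Random;

/** Sketch: fully connected sigmoid network with any number of layers, no bias terms. */
class LayeredNetwork {
    // weights[l][j][i]: link from node i in layer l to node j in layer l + 1
    private final double[][][] weights;
    private final double learningRate;

    LayeredNetwork(int[] layerSizes, double learningRate, long seed) {
        this.learningRate = learningRate;
        Random random = new Random(seed);
        weights = new double[layerSizes.length - 1][][];
        for (int l = 0; l < weights.length; l++) {
            int in = layerSizes[l];
            int out = layerSizes[l + 1];
            weights[l] = new double[out][in];
            double stdDev = Math.pow(in, -0.5); // same initialization as in the post
            for (double[] row : weights[l])
                for (int i = 0; i < in; i++)
                    row[i] = random.nextGaussian() * stdDev;
        }
    }

    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Forward pass; returns the activations of every layer, index 0 being the inputs. */
    private double[][] forward(double[] inputs) {
        double[][] a = new double[weights.length + 1][];
        a[0] = inputs;
        for (int l = 0; l < weights.length; l++) {
            a[l + 1] = new double[weights[l].length];
            for (int j = 0; j < weights[l].length; j++) {
                double sum = 0;
                for (int i = 0; i < a[l].length; i++)
                    sum += weights[l][j][i] * a[l][i];
                a[l + 1][j] = sigmoid(sum);
            }
        }
        return a;
    }

    double[] query(double[] inputs) {
        double[][] a = forward(inputs);
        return a[a.length - 1].clone();
    }

    void train(double[] inputs, double[] targets) {
        double[][] a = forward(inputs);
        int last = weights.length;

        // delta of the output layer: (target - output) * sigmoid derivative
        double[] delta = new double[targets.length];
        for (int j = 0; j < delta.length; j++)
            delta[j] = (targets[j] - a[last][j]) * a[last][j] * (1.0 - a[last][j]);

        // walk backwards through the layers
        for (int l = last - 1; l >= 0; l--) {
            double[] prevDelta = new double[a[l].length];
            for (int j = 0; j < weights[l].length; j++)
                for (int i = 0; i < a[l].length; i++) {
                    // accumulate with the old weight, then update that weight
                    prevDelta[i] += weights[l][j][i] * delta[j];
                    weights[l][j][i] += learningRate * delta[j] * a[l][i];
                }
            // multiply by the sigmoid derivative of layer l (not needed for the input layer)
            if (l > 0)
                for (int i = 0; i < prevDelta.length; i++)
                    prevDelta[i] *= a[l][i] * (1.0 - a[l][i]);
            delta = prevDelta;
        }
    }
}
```

A 5-layer digit network could then be created with something like new LayeredNetwork(new int[] {784, 200, 100, 50, 10}, 0.1, 42L). The hidden layer sizes are arbitrary placeholders, and 784 assumes 28x28 pixel inputs.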

Best regards!!
 