KNN Erkennt Zahl Nicht

HerrInfo

Aktives Mitglied
Hallo zusammen,

ich habe damit begonnen, mich in neuronale Netze einzuarbeiten und dachte, ich starte mit etwas "Simplerem" – einem Netzwerk zur Handschrifterkennung. Als Framework nutzte ich dafür Deeplearning4j und ND4J. Das Netzwerk wurde über 30 Generationen mit den Daten von MNIST trainiert.

Wenn ich jetzt jedoch ein Bild selbst auswähle und es bewerten lasse, wird als Antwort immer die Zahl 8 mit einer sehr hohen Sicherheit ausgegeben. Habt ihr eine Idee, woran es liegen kann? Nebenbei bemerkt, wenn ich die Zahl ändere, bleibt die Antwort weiterhin 8 und das bei genau gleicher Sicherheit.

Java:
package net.tim;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

public class DigitRecognizer {
    private MultiLayerNetwork model;
    private DataSetIterator trainIter;
    private DataSetIterator testIter;
    private final int outputNum;
    private final int batchSize;

    public DigitRecognizer(int outputNum, int batchSize) {
        this.outputNum = outputNum;
        this.batchSize = batchSize;
    }

    public MultiLayerNetwork generateModel(int numInput, int numHidden, double learningRate, int outputNum) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .updater(new Adam(learningRate))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInput).nOut(numHidden)
                        .activation(Activation.RELU)
                        .build())
                .layer(new DenseLayer.Builder().nIn(numHidden).nOut(numHidden)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(numHidden).nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(10));

        return model;
    }

    public void setModel(MultiLayerNetwork model) {
        this.model = model;
    }

    public void fetchData(String trainDataPath, String testDataPath) throws IOException, InterruptedException {
        System.out.println("Loading Training data...");
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File(trainDataPath)));
        this.trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 10);

        System.out.println("Loading Test data...");
        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File(testDataPath)));
        this.testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 10);

        if (!trainIter.hasNext() || !testIter.hasNext()) {
            throw new InterruptedException("No data found in the CSV files");
        } else {
            System.out.println("Data loaded successfully");
        }
    }

    public void train(int nEpochs) {
        System.out.println("Training and evaluating model for " + nEpochs + " epochs...");
        for (int i = 0; i < nEpochs; i++) {
            System.out.println("Epoch " + i + " of " + nEpochs);
            while (trainIter.hasNext()) {
                DataSet next = trainIter.next();
                next.shuffle();
                model.fit(next);
            }
            trainIter.reset();
        }
        System.out.println("Training completed");
    }

    public void test() {
        testIter.reset();
        Evaluation eval = new Evaluation(outputNum);
        while (testIter.hasNext()) {
            DataSet next = testIter.next();
            next.shuffle();
            eval.eval(next.getLabels(), model.output(next.getFeatures()));
        }
        System.out.println(eval.stats());
    }

    public void saveModel(String filePath) throws IOException {
        File file = new File(filePath);
        ModelSerializer.writeModel(model, file, true);
        System.out.println("Model saved to " + filePath);
    }

    public void loadModel(String filePath) throws IOException {
        File file = new File(filePath);
        this.model = ModelSerializer.restoreMultiLayerNetwork(file);

        System.out.println("Model loaded from " + filePath);
    }

    public int predictDigitFromImage(String imagePath) throws IOException {
        BufferedImage img = ImageIO.read(new File(imagePath));
        if (img.getHeight() != 28 || img.getWidth() != 28) {
            throw new IllegalArgumentException("Image dimensions must be 28x28 pixels");
        }

        float[] pixels = new float[28 * 28];
        for (int y = 0; y < 28; y++) {
            for (int x = 0; x < 28; x++) {
                int rgb = img.getRGB(x, y);
                float gray = 0.299f * ((rgb >> 16) & 0xFF) + 0.587f * ((rgb >> 8) & 0xFF) + 0.114f * (rgb & 0xFF);
                // Normalize grayscale value to be between 0 and 1 and invert colors
                pixels[y * 28 + x] = 1 - (gray / 255f);
            }
        }

        INDArray features = Nd4j.create(pixels, new int[]{1, 784});
        INDArray output = model.output(features);

        // Get the digit with the highest probability
        int predictedDigit = Nd4j.argMax(output, 1).getInt(0);

        //System out raw output values for all digits
        System.out.println(output);
        return predictedDigit;
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        int batchSize = 64;
        int outputNum = 10; // Number of possible outcomes (0-9)


        int numInput = 28 * 28; // Data input size
        int numHidden = 500; // Number of hidden nodes
        double learningRate = 0.001;

        DigitRecognizer recognizer = new DigitRecognizer(outputNum, batchSize);
        recognizer.setModel(recognizer.generateModel(numInput, numHidden, learningRate, outputNum));
        recognizer.fetchData("src/main/resources/mnist_train.csv", "src/main/resources/mnist_test.csv");

        recognizer.loadModel("src/main/resources/model2.zip");
        recognizer.test();
        //recognizer.saveModel("src/main/resources/model2.zip");

        System.out.println("Predicted digit: " + recognizer.predictDigitFromImage("src/main/resources/2.png"));
    }
}

Konsole:
Code:
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
Loading Training data...
Loading Test data...
Data loaded successfully
Model loaded from src/main/resources/model2.zip


========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0,9684
 Precision:       0,9681
 Recall:          0,9681
 F1 Score:        0,9680
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  972    1    1    0    1    1    2    1    1    0 | 0 = 0
    0 1123    3    0    0    1    2    2    3    1 | 1 = 1
   13    5  987    8    0    1    4    6    7    1 | 2 = 2
    0    0    2  976    0   15    0    5    6    6 | 3 = 3
    4    3    1    1  945    0    4    2    4   18 | 4 = 4
    3    0    0    6    0  866    2    1    1   13 | 5 = 5
   11    2    0    0    6   26  910    0    3    0 | 6 = 6
    0    4   12    2    1    0    0 1001    3    5 | 7 = 7
    3    3    2    4    3   16    0    1  940    2 | 8 = 8
    1    2    0    2    6    5    0   21    8  964 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
[[    0.0803,    0.0614,    0.0966,    0.1142,    0.0882,    0.0981,    0.0905,    0.0893,    0.1646,    0.1168]]
Predicted digit: 8

Process finished with exit code 0
In diesem Fall wurde eine 2 erwartet. Vielen Dank im Voraus
 

HerrInfo

Aktives Mitglied
Klappt damit leider auch nicht:
Code:
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
Loading Training data...
Loading Test data...
Data loaded successfully
Model loaded from src/main/resources/model.zip


========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0,9668
 Precision:       0,9663
 Recall:          0,9665
 F1 Score:        0,9663
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  964    0    2    1    0    1    5    2    5    0 | 0 = 0
    0 1123    2    1    0    1    1    0    6    1 | 1 = 1
    3    1 1005    3    0    0    4    7    9    0 | 2 = 2
    0    0    5  975    0   17    0    4    7    2 | 3 = 3
    2    1    1    1  949    0    8    1    7   12 | 4 = 4
    3    0    1   16    0  853    9    1    6    3 | 5 = 5
    5    2    0    0    1    9  938    0    3    0 | 6 = 6
    0    7   12    2    8    1    0  988    2    8 | 7 = 7
   10    1    6    6    4    8    6    2  929    2 | 8 = 8
    3    5    0   10   22   10    0    9    6  944 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
[[         0,         0,         0,         0,         0,    1.0000,         0,         0,         0,         0]]
Predicted digit: 5

Process finished with exit code 0
 

mihe7

Top Contributor
Hm... ich habe das mal kurz ausprobiert. Die MNIST-Dateien habe ich von https://pjreddie.com/projects/mnist-in-csv/, die 2 habe ich mir im GIMP gemalt (scharze Schrift auf transparentem Hintergrund. Den Alphakanal habe ich einfach mal mit >10 berücksichtigt:
Java:
        float[] pixels = new float[28 * 28];
        for (int y = 0; y < 28; y++) {
            for (int x = 0; x < 28; x++) {
                int rgb = img.getRGB(x, y);
                if (rgb >>> 24 > 10) {
                    float gray = 0.299f * ((rgb >> 16) & 0xFF) + 0.587f * ((rgb >> 8) & 0xFF) + 0.114f * (rgb & 0xFF);
                    // Normalize grayscale value to be between 0 and 1 and invert colors
                    pixels[y * 28 + x] = 255f * (1 - (gray / 255f));
                }
            }
        }

Ausgabe:
Code:
[[1.1238e-42,4.9575e-11,    1.0000,9.1025e-10,         0,1.0401e-23,         0,8.8838e-15,9.9152e-21,4.9864e-31]]
Predicted digit: 2
 

HerrInfo

Aktives Mitglied
Tatsache, vielen Dank! Ich habe keine Ahnung, warum die Methode bei mir nicht funktioniert hat. Wahrscheinlich habe ich die if-Abfrage vergessen. Nochmals vielen Dank!
 

HerrSchlaui

Neues Mitglied
Vermutlich hast du nicht an die Normalisierung der Outputs gedacht. Eine "Streckung" wäre natürlich auch möglich.

Was heißt das? Es wird das richtige Ergebnis berechnet aber zum Schluss falsch interpretiert.

Gleiches ist mir auch schon mit Schachfiguren passiert, jede Figur war ein Turm, da dieser höchsten raw Wert hatte.^^
 
Ähnliche Java Themen
  Titel Forum Antworten Datum
B Scanner erkennt keinen Text in Textdatei, obwohl welcher drinsteht Allgemeine Java-Themen 10
M Programm erkennt String aus .txt Datei nicht Allgemeine Java-Themen 3
T Swing TrayIcon erkennt links klick nicht (Mac) Allgemeine Java-Themen 2
V VisualVM Was erkennt ihr hier? Allgemeine Java-Themen 9
K BufferedReader.readLine erkennt Zeilenende nicht Allgemeine Java-Themen 11
G RegEx erkennt nicht Allgemeine Java-Themen 2
L Tomcat erkennt Share nicht Allgemeine Java-Themen 6
O Dateinamen mit Zahl um eins erhöhen Allgemeine Java-Themen 16
B Millionen bit lange zahl bauen? Allgemeine Java-Themen 7
J Zerlegen einer Zahl Allgemeine Java-Themen 6
J Die Letzte Zahl aus einer Text datei lesen Allgemeine Java-Themen 8
Tronert Alphabetische Aufzählung aus Zahl? Allgemeine Java-Themen 5
E String in Zahl umwandeln, ohne Befehl Integer.parseInt Allgemeine Java-Themen 3
E Swing andere schreibart für jButtoni (i = Zahl des Buttons) Allgemeine Java-Themen 6
J Eine bestimmte Zahl im Integer ändern Allgemeine Java-Themen 9
J While Schleife ausführen bis Zahl = X Allgemeine Java-Themen 19
J Repräsentation in Java - 32bit Zahl Allgemeine Java-Themen 8
T Quadrieren einer Zahl nur durch Addition Allgemeine Java-Themen 5
Z Zahl raten Allgemeine Java-Themen 2
Chr1s ergebnis = Zahl? Allgemeine Java-Themen 3
A Zahl abgerundet obwohl Double Allgemeine Java-Themen 9
K Interpreter-Fehler Java Zahl Raten Spiel- Fehlermeldung mir unbekannt Allgemeine Java-Themen 12
J Die Menge einer Zahl im Binärbaum zählen Allgemeine Java-Themen 7
P Input/Output java.util.Scanner in einer Schleife und Exception-Behandlung: Einlesen einer Zahl Allgemeine Java-Themen 4
A Zahl zu lang für Long Allgemeine Java-Themen 3
L Leerzeichen zu string hinzufügen, um eine gerade zahl zu erhalten Allgemeine Java-Themen 9
O Prüfen ob String eine Zahl mit maximal 2 Nachkommastellen ist Allgemeine Java-Themen 4
N Zahl mit bestimmter Länge und nur bestimmten Zahlen generieren lassen Allgemeine Java-Themen 7
J Bestimmter Buchstabe = bestimmte Zahl Allgemeine Java-Themen 10
H Eclipse x Stellen einer Zahl in array speichern Allgemeine Java-Themen 3
S Antlr Grammatik übersetzt ohne Fehler, dennoch wird Zahl nicht als Eingabe erkannt Allgemeine Java-Themen 4
C Zahl im Textarea anzeigen lassen Allgemeine Java-Themen 8
C Regex: Zahl ohne führende Null Allgemeine Java-Themen 13
cedi int Zahl in ein ASCII zeichen umwandeln und dieses in ein externes Textfenster schreiben Allgemeine Java-Themen 6
Rudolf Aus Collection<Integer> eine Zahl machen Allgemeine Java-Themen 2
M Zahl aktiver Threads einer Gruppe verlässlich abfragen Allgemeine Java-Themen 3
C Prüfen auf Zahl und 6 stellig fehlerhaft? warum? Allgemeine Java-Themen 7
S Zahl konvertieren [Internationalisierung l10n, l18n] Allgemeine Java-Themen 4
T Zufallszahlen generieren und dabei eine Zahl weglassen Allgemeine Java-Themen 4
Z Zahl einer spanne zuordnen Allgemeine Java-Themen 2
FoolMoon Elegante Möglichkeit die kleinste Zahl zu ermitteln. Allgemeine Java-Themen 7
E Konstante Zahl Threads parallel rechnen lassen Allgemeine Java-Themen 6
L Berechnung mit Module bis bes.timme Zahl erreicht. Allgemeine Java-Themen 4
Ark O-Notation und Zahl versus String-Repräsentation Allgemeine Java-Themen 7
N int[] eindeutig durch eine Zahl repräsentieren Allgemeine Java-Themen 12
D Regular Expression Mit Punkt und Zahl Allgemeine Java-Themen 4
X Substring aus Zahl Allgemeine Java-Themen 8
G Auf eine ganze Zahl aufrunden Allgemeine Java-Themen 30
G Zahl aus dem String Allgemeine Java-Themen 6
K Double-Zahl runden Allgemeine Java-Themen 4
L Partitionen der Länge x einer natürlichen Zahl n Allgemeine Java-Themen 21
G Prüfen ob Ziffern einer Zahl pandigital sind? Allgemeine Java-Themen 15
J Große Zahl (double) as text ausgeben? Allgemeine Java-Themen 2
0 Alle Teiler einer Zahl performant berechnen? Allgemeine Java-Themen 9
G Double Zahl quadrieren Allgemeine Java-Themen 8
G String in Zahl umwandeln Allgemeine Java-Themen 9
C Server-Zahl von google.com Allgemeine Java-Themen 11
B Umwandeln von Bytes in float Zahl (DataInputStream) Allgemeine Java-Themen 3
H ganze zahl true / false Allgemeine Java-Themen 3
M Umwandeln String (mit Zahl zur Basis 36) in Dezimalzahl Allgemeine Java-Themen 2
N Float zahl auf eine Stelle nach dem Komma runden Allgemeine Java-Themen 3
G Double Zahl auf 4 Stellen hinter Komma kuerzen Allgemeine Java-Themen 4
S addAtPosition - Zahl an einer bestimmten Position einfügen Allgemeine Java-Themen 8
G String als Zahl erkennen Allgemeine Java-Themen 19
N Zahl mit DecimalFormat formattieren Allgemeine Java-Themen 2
R Zahl eingeben! Allgemeine Java-Themen 9
W Warum funktioniert das nicht? Allgemeine Java-Themen 9
H HashMap in HashMap, klappt nicht, Verständnisproblem Allgemeine Java-Themen 2
kodela ArrayList wird nicht komplett gespeichert Allgemeine Java-Themen 3
kodela HelpSet wird nicht gefunden Allgemeine Java-Themen 8
D Compiler-Fehler Compilierung in VM nicht erfolgreich Allgemeine Java-Themen 10
G WSDL-Aufruf funktioniert nicht mehr nach Umstieg auf Maven Allgemeine Java-Themen 4
W ICEpdf PDF-Dateien werden mit Java 21 nicht nicht mehr vollständig dargestellt Allgemeine Java-Themen 3
Zrebna Berechnung der Zeit funktioniert nicht wie erwartet: Date, GregorianCalendar Allgemeine Java-Themen 16
Zrebna Wieso sollte man Null-Prüfungen nicht mit Optional-Objekten nutzen? Allgemeine Java-Themen 13
kodela Textfeld nicht rechteckig Allgemeine Java-Themen 10
G Doppelklick auf Javaprogramm klapt nicht Allgemeine Java-Themen 1
W Timer terminiert nicht Allgemeine Java-Themen 5
D Linux, Java-Version wird nicht erkannt bzw. welche Einstellung fehlt noch? Allgemeine Java-Themen 19
W Überflüssige Deklaration vermeiden...war da nicht mal was? Allgemeine Java-Themen 3
N lwjgl kann textureSampler nicht finden Allgemeine Java-Themen 4
P Fehler: Hauptklasse Main konnte nicht gefunden oder geladen werden Ursache: java.lang.ClassNotFoundException: Main Allgemeine Java-Themen 24
S Java Programm lässt sich vom USB-Stick starten, aber nicht von HDD Allgemeine Java-Themen 16
T .Jar kann man nicht ausführen Allgemeine Java-Themen 18
P JDK nicht installiert in Net Object Fusion Allgemeine Java-Themen 7
D Image bewegt sich nicht nach Klicken auf Button Allgemeine Java-Themen 15
N Regex schlägt nicht an Allgemeine Java-Themen 10
Y Wieso krieg ich die Unit Tests nicht hin Allgemeine Java-Themen 55
D Erste Schritte Mp3 Datei kann nicht von der Festplatte geöffnet werden - mit ChatGPT erstellt Allgemeine Java-Themen 7
G Popup wird nicht sichtbar Allgemeine Java-Themen 9
8u3631984 Funktions Parameter mit Lombok "NonNull" annotieren wird in Jacococ Testcoverage nicht herausgefiltert Allgemeine Java-Themen 3
kodela String kann nicht zu Pfad konvertiert werden Allgemeine Java-Themen 16
M Apache Proxy Weiterleitung auf Tomcat funktioniert nicht wie gewünscht Allgemeine Java-Themen 1
Momo16 Brauche Hilfe - Java Projekt kann nicht erstellt werden Allgemeine Java-Themen 12
OnDemand ApacheCommon FTP Client zuckt nicht Allgemeine Java-Themen 3
T JavaPoet - (noch) nicht existente Typen Allgemeine Java-Themen 2
E Es ist nicht möglich, eine Batch-Anweisung auszuführen. Allgemeine Java-Themen 9
C Was passt hier nicht bei der Calendar-Class Allgemeine Java-Themen 2
T Testing JUnit5: try ... catch arbeitet nicht sauber Allgemeine Java-Themen 6
W While Schleife funktioniert nicht ganz Allgemeine Java-Themen 4

Ähnliche Java Themen

Neue Themen


Oben