Hallo zusammen,
ich habe damit begonnen, mich in neuronale Netze einzuarbeiten und dachte, ich starte mit etwas "Simplerem" – einem Netzwerk zur Handschrifterkennung. Als Framework nutzte ich dafür Deeplearning4j und ND4J. Das Netzwerk wurde über 30 Generationen mit den Daten von MNIST trainiert.
Wenn ich jetzt jedoch ein Bild selbst auswähle und es bewerten lasse, wird als Antwort immer die Zahl 8 mit einer sehr hohen Sicherheit ausgegeben. Habt ihr eine Idee, woran es liegen kann? Nebenbei bemerkt, wenn ich die Zahl ändere, bleibt die Antwort weiterhin 8 und das bei genau gleicher Sicherheit.
Konsole:
In diesem Fall wurde eine 2 erwartet. Vielen Dank im Voraus
ich habe damit begonnen, mich in neuronale Netze einzuarbeiten und dachte, ich starte mit etwas "Simplerem" – einem Netzwerk zur Handschrifterkennung. Als Framework nutzte ich dafür Deeplearning4j und ND4J. Das Netzwerk wurde über 30 Generationen mit den Daten von MNIST trainiert.
Wenn ich jetzt jedoch ein Bild selbst auswähle und es bewerten lasse, wird als Antwort immer die Zahl 8 mit einer sehr hohen Sicherheit ausgegeben. Habt ihr eine Idee, woran es liegen kann? Nebenbei bemerkt, wenn ich die Zahl ändere, bleibt die Antwort weiterhin 8 und das bei genau gleicher Sicherheit.
Java:
package net.tim;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
public class DigitRecognizer {
private MultiLayerNetwork model;
private DataSetIterator trainIter;
private DataSetIterator testIter;
private final int outputNum;
private final int batchSize;
public DigitRecognizer(int outputNum, int batchSize) {
this.outputNum = outputNum;
this.batchSize = batchSize;
}
public MultiLayerNetwork generateModel(int numInput, int numHidden, double learningRate, int outputNum) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.updater(new Adam(learningRate))
.list()
.layer(new DenseLayer.Builder().nIn(numInput).nOut(numHidden)
.activation(Activation.RELU)
.build())
.layer(new DenseLayer.Builder().nIn(numHidden).nOut(numHidden)
.activation(Activation.RELU)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(numHidden).nOut(outputNum)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(10));
return model;
}
public void setModel(MultiLayerNetwork model) {
this.model = model;
}
public void fetchData(String trainDataPath, String testDataPath) throws IOException, InterruptedException {
System.out.println("Loading Training data...");
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new File(trainDataPath)));
this.trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 10);
System.out.println("Loading Test data...");
RecordReader rrTest = new CSVRecordReader();
rrTest.initialize(new FileSplit(new File(testDataPath)));
this.testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 10);
if (!trainIter.hasNext() || !testIter.hasNext()) {
throw new InterruptedException("No data found in the CSV files");
} else {
System.out.println("Data loaded successfully");
}
}
public void train(int nEpochs) {
System.out.println("Training and evaluating model for " + nEpochs + " epochs...");
for (int i = 0; i < nEpochs; i++) {
System.out.println("Epoch " + i + " of " + nEpochs);
while (trainIter.hasNext()) {
DataSet next = trainIter.next();
next.shuffle();
model.fit(next);
}
trainIter.reset();
}
System.out.println("Training completed");
}
public void test() {
testIter.reset();
Evaluation eval = new Evaluation(outputNum);
while (testIter.hasNext()) {
DataSet next = testIter.next();
next.shuffle();
eval.eval(next.getLabels(), model.output(next.getFeatures()));
}
System.out.println(eval.stats());
}
public void saveModel(String filePath) throws IOException {
File file = new File(filePath);
ModelSerializer.writeModel(model, file, true);
System.out.println("Model saved to " + filePath);
}
public void loadModel(String filePath) throws IOException {
File file = new File(filePath);
this.model = ModelSerializer.restoreMultiLayerNetwork(file);
System.out.println("Model loaded from " + filePath);
}
public int predictDigitFromImage(String imagePath) throws IOException {
BufferedImage img = ImageIO.read(new File(imagePath));
if (img.getHeight() != 28 || img.getWidth() != 28) {
throw new IllegalArgumentException("Image dimensions must be 28x28 pixels");
}
float[] pixels = new float[28 * 28];
for (int y = 0; y < 28; y++) {
for (int x = 0; x < 28; x++) {
int rgb = img.getRGB(x, y);
float gray = 0.299f * ((rgb >> 16) & 0xFF) + 0.587f * ((rgb >> 8) & 0xFF) + 0.114f * (rgb & 0xFF);
// Normalize grayscale value to be between 0 and 1 and invert colors
pixels[y * 28 + x] = 1 - (gray / 255f);
}
}
INDArray features = Nd4j.create(pixels, new int[]{1, 784});
INDArray output = model.output(features);
// Get the digit with the highest probability
int predictedDigit = Nd4j.argMax(output, 1).getInt(0);
//System out raw output values for all digits
System.out.println(output);
return predictedDigit;
}
public static void main(String[] args) throws IOException, InterruptedException {
int batchSize = 64;
int outputNum = 10; // Number of possible outcomes (0-9)
int numInput = 28 * 28; // Data input size
int numHidden = 500; // Number of hidden nodes
double learningRate = 0.001;
DigitRecognizer recognizer = new DigitRecognizer(outputNum, batchSize);
recognizer.setModel(recognizer.generateModel(numInput, numHidden, learningRate, outputNum));
recognizer.fetchData("src/main/resources/mnist_train.csv", "src/main/resources/mnist_test.csv");
recognizer.loadModel("src/main/resources/model2.zip");
recognizer.test();
//recognizer.saveModel("src/main/resources/model2.zip");
System.out.println("Predicted digit: " + recognizer.predictDigitFromImage("src/main/resources/2.png"));
}
}
Konsole:
Code:
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
Loading Training data...
Loading Test data...
Data loaded successfully
Model loaded from src/main/resources/model2.zip
========================Evaluation Metrics========================
# of classes: 10
Accuracy: 0,9684
Precision: 0,9681
Recall: 0,9681
F1 Score: 0,9680
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
---------------------------------------------------
972 1 1 0 1 1 2 1 1 0 | 0 = 0
0 1123 3 0 0 1 2 2 3 1 | 1 = 1
13 5 987 8 0 1 4 6 7 1 | 2 = 2
0 0 2 976 0 15 0 5 6 6 | 3 = 3
4 3 1 1 945 0 4 2 4 18 | 4 = 4
3 0 0 6 0 866 2 1 1 13 | 5 = 5
11 2 0 0 6 26 910 0 3 0 | 6 = 6
0 4 12 2 1 0 0 1001 3 5 | 7 = 7
3 3 2 4 3 16 0 1 940 2 | 8 = 8
1 2 0 2 6 5 0 21 8 964 | 9 = 9
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
[[ 0.0803, 0.0614, 0.0966, 0.1142, 0.0882, 0.0981, 0.0905, 0.0893, 0.1646, 0.1168]]
Predicted digit: 8
Process finished with exit code 0