Startwerte bei univariater linearer Regression

CptK

CptK

Bekanntes Mitglied
Hallo,
ich habe eine Klasse für univariate lineare Regression:
UnivariateLinearRegression:
public class UnivariateLinearRegression implements LinearModel {
    
    private double w0, w1; // weights hw(x) = w1*x + w0

    public UnivariateLinearRegression() {
        w0 = Math.random();
        w1 = Math.random();
    }

    @Override
    public void train(double[][] data, int epochs, double alpha) {
        for(int i = 0; i < epochs; i++) {
            double errorw0 = sumErrorW0(data);
            double errorw1 = sumErrorW1(data);
            w0 = w0 + alpha * errorw0;
            w1 = w1 + alpha * errorw1;
        }
    }

    private double sumErrorW0(double[][] data) {
        double sum = 0;
        for (int i = 0; i < data.length; i++) {
            if (data[i].length != 2)
                throw new IllegalArgumentException("Each Data-Point must have two elements");
            sum += data[i][1] - (w1 * data[i][0] + w0);
        }
        return sum;
    }

    private double sumErrorW1(double[][] data) {
        double sum = 0;
        for (int i = 0; i < data.length; i++) {
            if (data[i].length != 2)
                throw new IllegalArgumentException("Each Data-Point must have two elements");
            sum += (data[i][1] - (w1 * data[i][0] + w0)) * data[i][0];
        }
        return sum;
    }
}
Nun habe ich das Problem, dass die Qualität des Ergebnisses sehr stark von den Startwerten abhängt. Momentan mache ich das ja mit Math.random() was (offensichtlich) nicht die beste lösung ist. Wie würde man das besser machen?

LG
 
mihe7

mihe7

Top Contributor
Nun habe ich das Problem, dass die Qualität des Ergebnisses sehr stark von den Startwerten abhängt.
Ich würde nicht (nur) nach Epochen gehen, sondern insbesondere das Delta des Fehlers im Blick behalten. Wenn keine nennenswerte Verbesserung mehr eintritt -> beenden. Evtl. ist auch das alpha zu groß, so dass das Optimum übersprungen wird. Außerdem würde ich sum noch durch data.length dividieren, sonst muss man das alpha an die Anzahl von Training-Samples anpassen.
 
CptK

CptK

Bekanntes Mitglied
Ich würde nicht (nur) nach Epochen gehen, sondern insbesondere das Delta des Fehlers im Blick behalten. Wenn keine nennenswerte Verbesserung mehr eintritt -> beenden. Evtl. ist auch das alpha zu groß, so dass das Optimum übersprungen wird. Außerdem würde ich sum noch durch data.length dividieren, sonst muss man das alpha an die Anzahl von Training-Samples anpassen.

Wäre das so korrekt:
Java:
public void train(double[][] data, int epochs, double alpha, double epsilon) {
        for(int i = 0; i < epochs; i++) {

            double errorw0 = 0;
            double errorw1 = 0;
            for (int j = 0; j < data.length; j++) {
                if (data[j].length != 2)
                    throw new IllegalArgumentException("Each Data-Point must have two elements");
                errorw0 += data[j][1] - (w1 * data[j][0] + w0);
                errorw1 += (data[j][1] - (w1 * data[j][0] + w0)) * data[j][0];
            }

            if (errorw0 < epsilon && errorw1 < epsilon)
                return;

            w0 = w0 + alpha * (errorw0 / data.length);
            w1 = w1 + alpha * (errorw1 / data.length);
        }
    }
Ich habe jetzt gelesen, dass die Konvergenz gegen das globale Minimum garantiert ist, sofern alpha klein genug ist. würde es also Sinn machen die epochs komplett rauszunehmen und das in dieser Art zu implementieren:
Java:
while(errorw0 >= epsilon || errorw1 >= epsilon) {
    //update w0 & w1
    ....
}
 
Zuletzt bearbeitet:
B

betatwo

Mitglied
Ehrlich gesagt verstehe ich nicht so wirklich, wie ich das verwenden soll.
Musst du das denn unbedingt per Hand machen??? Es macht mathematisch keinen Sinn, aber so funktioniert die Inter- und Exploration mit commons math:
Java:
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.analysis.interpolation.LinearInterpolator;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;

public class LinearRegression {
    public static double linearRegression(double... ds) {
        double[] xs = new double[ds.length];
        for (int i = 0; i < xs.length; i++) {
            xs[i] = i;
        }
        LinearInterpolator linearInterpolator = new LinearInterpolator();
        PolynomialSplineFunction splineFunction = linearInterpolator.interpolate(xs, ds);
        PolynomialFunction[] polynomials = splineFunction.getPolynomials();
        return polynomials[polynomials.length - 1].value(2);
    }

    public static void main(String[] args) {
        List<Double> ds = new ArrayList<>(List.of(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
        for (int i = 0; i < 5; i++) {
            double prediction = linearRegression(ds.stream().mapToDouble(Double::doubleValue).toArray());
            System.out.println(prediction);
            ds.add(prediction);
        }
    }
}

Das Ergebnis ist wenig spektakulär. 😊
 
Ähnliche Java Themen
  Titel Forum Antworten Datum
C linearer Suchalgorithmus Allgemeine Java-Themen 2

Ähnliche Java Themen


Oben