Skip to content

Commit

Permalink
remove 'transient' keyword for workspace to ensure proper initializat…
Browse files Browse the repository at this point in the history
…ion after deserialization
  • Loading branch information
haifengl committed Oct 29, 2017
1 parent 01d86b9 commit a3a4074
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions core/src/main/java/smile/regression/RLS.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,20 @@
*/
public class RLS implements OnlineRegression<double[]>, Serializable {
private static final long serialVersionUID = 1L;
private static final Logger logger = LoggerFactory.getLogger(RLS.class);


/**
* The dimensionality.
*/
private int p;
/**
* The coefficients with intercept.
*/
private double[] w;
/**
* The forgetting factor in (0, 1]. Values closer to 1 will have
* longer memory and values closer to 0 will be have shorter memory.
*/
private double lambda;
/**
* First initialized to the matrix (X<sup>T</sup>X)<sup>-1</sup>,
* it is updated with each new learning instance.
Expand All @@ -63,21 +71,12 @@ public class RLS implements OnlineRegression<double[]>, Serializable {
/**
* A single learning instance X, padded with 1 for intercept.
*/
private transient double[] x1;
/**
* The coefficients with intercept.
*/
private transient double[] w;
private double[] x1;
/**
* A temporary array used in computing V * X .
*/
private transient double[] Vx;
/**
* The forgetting factor in (0, 1]. Values closer to 1 will have
* longer memory and values closer to 0 will be have shorter memory.
*/
private double lambda;

private double[] Vx;

/**
* Trainer for linear regression by recursive least squares.
*/
Expand Down Expand Up @@ -135,7 +134,9 @@ public RLS(double[][] x, double[] y, double lambda) {
X.set(i, p, 1.0);
}

// weights and intercept
// Always use SVD instead of QR because it is more stable
// when the data is close to rank deficient, which is more
// likely in RLS as the initial data size may be small.
this.w = new double[p+1];
SVD svd = X.svd();
svd.solve(y, w);
Expand Down Expand Up @@ -230,7 +231,7 @@ public double getForgettingFactor() {
* @param lambda the forgetting factor
*/
public void setForgettingFactor(double lambda) {
if (lambda<=0 || lambda>1){
if (lambda <= 0 || lambda > 1){
throw new IllegalArgumentException("The forgetting factor is not between 0 (exclusive) and 1 (inclusive)");
}
this.lambda = lambda;
Expand Down

0 comments on commit a3a4074

Please sign in to comment.