Skip to content

Commit db81cab

Browse files
author
Congrui Yi
authored Oct 3, 2016
fitting the intercept and adding screening rules
1 parent 124a3ca commit db81cab

File tree

1 file changed

+94
-51
lines changed

1 file changed

+94
-51
lines changed
 

‎src/raw.c

+94 -51
Original file line number | Diff line number | Diff line change
@@ -5,33 +5,32 @@
55
#include <R.h>
66
#include <R_ext/Applic.h>
77
double crossprod(double *X, double *y, int n, int j);
8-
double wcrossprod(double *X, double *y, double *w, int n, int j);
98
int checkConvergence(double *beta, double *beta_old, double eps, int l, int J);
10-
double S(double z, double l);
119
double MCP(double z, double l1, double l2, double gamma, double v);
1210
double SCAD(double z, double l1, double l2, double gamma, double v);
1311
double lasso(double z, double l1, double l2, double v);
1412
double gLoss(double *r, int n);
1513
double sqsum(double *X, int n, int j);
1614

1715
// Memory handling, output formatting (raw)
18-
SEXP cleanupR(double *r, double *a, double *v, int *e, SEXP beta0, SEXP beta, SEXP Dev, SEXP iter) {
19-
Free(r);
16+
SEXP cleanupR(double *a, double *r, double *v, double *z, int *e1, int *e2, SEXP beta0, SEXP beta, SEXP loss, SEXP iter) {
2017
Free(a);
18+
Free(r);
2119
Free(v);
22-
Free(e);
20+
Free(z);
21+
Free(e1);
22+
Free(e2);
2323
SEXP res;
2424
PROTECT(res = allocVector(VECSXP, 4));
2525
SET_VECTOR_ELT(res, 0, beta0);
2626
SET_VECTOR_ELT(res, 1, beta);
27-
SET_VECTOR_ELT(res, 2, Dev);
27+
SET_VECTOR_ELT(res, 2, loss);
2828
SET_VECTOR_ELT(res, 3, iter);
2929
UNPROTECT(5);
3030
return(res);
3131
}
3232

33-
// Coordinate descent for raw, unstandardized least squares
34-
// need to fit an intercept
33+
// Coordinate descent for gaussian models
3534
SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max_iter_, SEXP gamma_, SEXP multiplier, SEXP alpha_, SEXP dfmax_, SEXP user_) {
3635

3736
// Declarations
@@ -48,9 +47,9 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
4847
PROTECT(loss = allocVector(REALSXP, L));
4948
PROTECT(iter = allocVector(INTSXP, L));
5049
for (int i=0; i<L; i++) INTEGER(iter)[i] = 0;
51-
double *a = Calloc(p, double); // Beta from previous iteration
50+
double *a = Calloc(p, double); // Beta from previous iteration
5251
for (int j=0; j<p; j++) a[j]=0;
53-
double a0 = 0; // Beta0 from previous iteration, initially 0 from KKT since y is mean-centered
52+
double a0 = 0; // Beta0 from previous iteration, initially 0 from KKT since y is mean-centered
5453
double *X = REAL(X_);
5554
double *y = REAL(y_);
5655
const char *penalty = CHAR(STRING_ELT(penalty_, 0));
@@ -64,13 +63,15 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
6463
int user = INTEGER(user_)[0];
6564
double *r = Calloc(n, double);
6665
for (int i=0; i<n; i++) r[i] = y[i];
67-
double *z = Calloc(p, double);
68-
for (int j=0; j<p; j++) z[j] = crossprod(X, r, n, j)/n;
6966
double *v = Calloc(p, double);
7067
for (int j=0; j<p; j++) v[j] = sqsum(X, n, j)/n;
71-
int *e = Calloc(p, int);
72-
for (int j=0; j<p; j++) e[j] = 1;
73-
double l1, l2, u, mean_resid, shift;
68+
double *z = Calloc(p, double);
69+
for (int j=0; j<p; j++) z[j] = crossprod(X, r, n, j)/n; // initial a[j] = 0
70+
int *e1 = Calloc(p, int);
71+
for (int j=0; j<p; j++) e1[j] = 0;
72+
int *e2 = Calloc(p, int);
73+
for (int j=0; j<p; j++) e2[j] = 0;
74+
double cutoff, l1, l2, mean_resid, shift;
7475
int converged, lstart;
7576

7677
// If lam[0]=lam_max, skip lam[0] -- closed form sol'n available
@@ -82,7 +83,7 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
8283
}
8384

8485
// Path
85-
for (int l=lstart; l<L; l++) {
86+
for (int l=lstart;l<L;l++) {
8687
R_CheckUserInterrupt();
8788
if (l != 0) {
8889
// Assign a0, a
@@ -92,66 +93,108 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
9293
// Check dfmax
9394
int nv = 0;
9495
for (int j=0; j<p; j++) {
95-
if (a[j] != 0) nv++;
96+
if (a[j] != 0) nv++;
9697
}
9798
if (nv > dfmax) {
9899
for (int ll=l; ll<L; ll++) INTEGER(iter)[ll] = NA_INTEGER;
99-
res = cleanupR(r, a, v, e, beta0, beta, loss, iter);
100+
res = cleanupR(a, r, v, z, e1, e2, beta0, beta, loss, iter);
100101
return(res);
101102
}
103+
104+
// Determine eligible set
105+
if (strcmp(penalty, "lasso")==0) cutoff = 2*lam[l] - lam[l-1];
106+
if (strcmp(penalty, "MCP")==0) cutoff = lam[l] + gamma/(gamma-1)*(lam[l] - lam[l-1]);
107+
if (strcmp(penalty, "SCAD")==0) cutoff = lam[l] + gamma/(gamma-2)*(lam[l] - lam[l-1]);
108+
for (int j=0; j<p; j++) if (fabs(z[j]) > (cutoff * alpha * m[j])) e2[j] = 1;
109+
} else {
110+
// Determine eligible set
111+
double lmax = 0;
112+
for (int j=0; j<p; j++) if (fabs(z[j]) > lmax) lmax = fabs(z[j]);
113+
if (strcmp(penalty, "lasso")==0) cutoff = 2*lam[l] - lmax;
114+
if (strcmp(penalty, "MCP")==0) cutoff = lam[l] + gamma/(gamma-1)*(lam[l] - lmax);
115+
if (strcmp(penalty, "SCAD")==0) cutoff = lam[l] + gamma/(gamma-2)*(lam[l] - lmax);
116+
for (int j=0; j<p; j++) if (fabs(z[j]) > (cutoff * alpha * m[j])) e2[j] = 1;
102117
}
103118

104119
while (INTEGER(iter)[l] < max_iter) {
105120
while (INTEGER(iter)[l] < max_iter) {
106-
INTEGER(iter)[l]++;
107-
// intercept
108-
mean_resid = 0.0;
109-
for (int i=0; i<n; i++) mean_resid += r[i];
110-
mean_resid /= n;
111-
b0[l] = mean_resid + a0;
112-
for (int i=0; i<n; i++) r[i] -= mean_resid;
121+
while (INTEGER(iter)[l] < max_iter) {
122+
// Solve over the active set
123+
INTEGER(iter)[l]++;
124+
// intercept
125+
mean_resid = 0.0;
126+
for (int i=0; i<n; i++) mean_resid += r[i];
127+
mean_resid /= n;
128+
b0[l] = mean_resid + a0;
129+
for (int i=0; i<n; i++) r[i] -= mean_resid;
113130

114-
// covariates
131+
for (int j=0; j<p; j++) {
132+
if (e1[j]) {
133+
z[j] = crossprod(X, r, n, j)/n + v[j]*a[j];
134+
135+
// Update beta_j
136+
l1 = lam[l] * m[j] * alpha;
137+
l2 = lam[l] * m[j] * (1-alpha);
138+
if (strcmp(penalty,"MCP")==0) b[l*p+j] = MCP(z[j], l1, l2, gamma, v[j]);
139+
if (strcmp(penalty,"SCAD")==0) b[l*p+j] = SCAD(z[j], l1, l2, gamma, v[j]);
140+
if (strcmp(penalty,"lasso")==0) b[l*p+j] = lasso(z[j], l1, l2, v[j]);
141+
142+
// Update r
143+
shift = b[l*p+j] - a[j];
144+
if (shift !=0) for (int i=0;i<n;i++) r[i] -= shift*X[j*n+i];
145+
}
146+
}
147+
148+
// Check for convergence
149+
converged = checkConvergence(b, a, eps, l, p);
150+
a0 = b0[l];
151+
for (int j=0; j<p; j++) a[j] = b[l*p+j];
152+
if (converged) break;
153+
}
154+
155+
// Scan for violations in strong set
156+
int violations = 0;
115157
for (int j=0; j<p; j++) {
116-
if (e[j]) {
117-
u = crossprod(X, r, n, j)/n + v[j]*a[j];
158+
if (e1[j]==0 && e2[j]==1) {
118159

119-
// Update b_j
160+
z[j] = crossprod(X, r, n, j)/n; // a[j] = 0
161+
162+
// Update beta_j
120163
l1 = lam[l] * m[j] * alpha;
121164
l2 = lam[l] * m[j] * (1-alpha);
122-
if (strcmp(penalty,"MCP")==0) b[l*p+j] = MCP(u, l1, l2, gamma, v[j]);
123-
if (strcmp(penalty,"SCAD")==0) b[l*p+j] = SCAD(u, l1, l2, gamma, v[j]);
124-
if (strcmp(penalty,"lasso")==0) b[l*p+j] = lasso(u, l1, l2, v[j]);
165+
if (strcmp(penalty,"MCP")==0) b[l*p+j] = MCP(z[j], l1, l2, gamma, v[j]);
166+
if (strcmp(penalty,"SCAD")==0) b[l*p+j] = SCAD(z[j], l1, l2, gamma, v[j]);
167+
if (strcmp(penalty,"lasso")==0) b[l*p+j] = lasso(z[j], l1, l2, v[j]);
125168

126-
// Update r
127-
shift = b[l*p+j] - a[j];
128-
if (shift !=0) for (int i=0;i<n;i++) r[i] -= shift*X[j*n+i];
169+
// If something enters the eligible set, update eligible set & residuals
170+
if (b[l*p+j] !=0) {
171+
e1[j] = e2[j] = 1;
172+
for (int i=0; i<n; i++) r[i] -= b[l*p+j]*X[j*n+i];
173+
a[j] = b[l*p+j];
174+
violations++;
175+
}
129176
}
130177
}
131-
132-
// Check for convergence
133-
converged = checkConvergence(b, a, eps, l, p);
134-
a0 = b0[l];
135-
for (int j=0; j<p; j++) a[j] = b[l*p+j];
136-
if (converged) break;
178+
if (violations==0) break;
137179
}
138180

139-
// Scan for violations
181+
// Scan for violations in rest
140182
int violations = 0;
141183
for (int j=0; j<p; j++) {
142-
if (e[j]==0) {
143-
u = crossprod(X, r, n, j)/n + v[j]*a[j];
184+
if (e2[j]==0) {
185+
186+
z[j] = crossprod(X, r, n, j)/n; // a[j] = 0
144187

145-
// Update b_j
188+
// Update beta_j
146189
l1 = lam[l] * m[j] * alpha;
147190
l2 = lam[l] * m[j] * (1-alpha);
148-
if (strcmp(penalty,"MCP")==0) b[l*p+j] = MCP(u, l1, l2, gamma, v[j]);
149-
if (strcmp(penalty,"SCAD")==0) b[l*p+j] = SCAD(u, l1, l2, gamma, v[j]);
150-
if (strcmp(penalty,"lasso")==0) b[l*p+j] = lasso(u, l1, l2, v[j]);
191+
if (strcmp(penalty,"MCP")==0) b[l*p+j] = MCP(z[j], l1, l2, gamma, v[j]);
192+
if (strcmp(penalty,"SCAD")==0) b[l*p+j] = SCAD(z[j], l1, l2, gamma, v[j]);
193+
if (strcmp(penalty,"lasso")==0) b[l*p+j] = lasso(z[j], l1, l2, v[j]);
151194

152195
// If something enters the eligible set, update eligible set & residuals
153196
if (b[l*p+j] !=0) {
154-
e[j] = 1;
197+
e1[j] = e2[j] = 1;
155198
for (int i=0; i<n; i++) r[i] -= b[l*p+j]*X[j*n+i];
156199
a[j] = b[l*p+j];
157200
violations++;
@@ -160,11 +203,11 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
160203
}
161204

162205
if (violations==0) {
163-
REAL(loss)[l] = gLoss(r, n);
164206
break;
165207
}
166208
}
209+
REAL(loss)[l] = gLoss(r, n);
167210
}
168-
res = cleanupR(r, a, v, e, beta0, beta, loss, iter);
211+
res = cleanupR(a, r, v, z, e1, e2, beta0, beta, loss, iter);
169212
return(res);
170213
}

0 commit comments

Comments
 (0)
Please sign in to comment.