5
5
#include <R.h>
6
6
#include <R_ext/Applic.h>
7
7
double crossprod (double * X , double * y , int n , int j );
8
- double wcrossprod (double * X , double * y , double * w , int n , int j );
9
8
int checkConvergence (double * beta , double * beta_old , double eps , int l , int J );
10
- double S (double z , double l );
11
9
double MCP (double z , double l1 , double l2 , double gamma , double v );
12
10
double SCAD (double z , double l1 , double l2 , double gamma , double v );
13
11
double lasso (double z , double l1 , double l2 , double v );
14
12
double gLoss (double * r , int n );
15
13
double sqsum (double * X , int n , int j );
16
14
17
15
// Memory handling, output formatting (raw)
18
- SEXP cleanupR (double * r , double * a , double * v , int * e , SEXP beta0 , SEXP beta , SEXP Dev , SEXP iter ) {
19
- Free (r );
16
+ SEXP cleanupR (double * a , double * r , double * v , double * z , int * e1 , int * e2 , SEXP beta0 , SEXP beta , SEXP loss , SEXP iter ) {
20
17
Free (a );
18
+ Free (r );
21
19
Free (v );
22
- Free (e );
20
+ Free (z );
21
+ Free (e1 );
22
+ Free (e2 );
23
23
SEXP res ;
24
24
PROTECT (res = allocVector (VECSXP , 4 ));
25
25
SET_VECTOR_ELT (res , 0 , beta0 );
26
26
SET_VECTOR_ELT (res , 1 , beta );
27
- SET_VECTOR_ELT (res , 2 , Dev );
27
+ SET_VECTOR_ELT (res , 2 , loss );
28
28
SET_VECTOR_ELT (res , 3 , iter );
29
29
UNPROTECT (5 );
30
30
return (res );
31
31
}
32
32
33
- // Coordinate descent for raw, unstandardized least squares
34
- // need to fit an intercept
33
+ // Coordinate descent for gaussian models
35
34
SEXP cdfit_raw (SEXP X_ , SEXP y_ , SEXP penalty_ , SEXP lambda , SEXP eps_ , SEXP max_iter_ , SEXP gamma_ , SEXP multiplier , SEXP alpha_ , SEXP dfmax_ , SEXP user_ ) {
36
35
37
36
// Declarations
@@ -48,9 +47,9 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
48
47
PROTECT (loss = allocVector (REALSXP , L ));
49
48
PROTECT (iter = allocVector (INTSXP , L ));
50
49
for (int i = 0 ; i < L ; i ++ ) INTEGER (iter )[i ] = 0 ;
51
- double * a = Calloc (p , double ); // Beta from previous iteration
50
+ double * a = Calloc (p , double ); // Beta from previous iteration
52
51
for (int j = 0 ; j < p ; j ++ ) a [j ]= 0 ;
53
- double a0 = 0 ; // Beta0 from previous iteration, initially 0 from KKT since y is mean-centered
52
+ double a0 = 0 ; // Beta0 from previous iteration, initially 0 from KKT since y is mean-centered
54
53
double * X = REAL (X_ );
55
54
double * y = REAL (y_ );
56
55
const char * penalty = CHAR (STRING_ELT (penalty_ , 0 ));
@@ -64,13 +63,15 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
64
63
int user = INTEGER (user_ )[0 ];
65
64
double * r = Calloc (n , double );
66
65
for (int i = 0 ; i < n ; i ++ ) r [i ] = y [i ];
67
- double * z = Calloc (p , double );
68
- for (int j = 0 ; j < p ; j ++ ) z [j ] = crossprod (X , r , n , j )/n ;
69
66
double * v = Calloc (p , double );
70
67
for (int j = 0 ; j < p ; j ++ ) v [j ] = sqsum (X , n , j )/n ;
71
- int * e = Calloc (p , int );
72
- for (int j = 0 ; j < p ; j ++ ) e [j ] = 1 ;
73
- double l1 , l2 , u , mean_resid , shift ;
68
+ double * z = Calloc (p , double );
69
+ for (int j = 0 ; j < p ; j ++ ) z [j ] = crossprod (X , r , n , j )/n ; // initial a[j] = 0
70
+ int * e1 = Calloc (p , int );
71
+ for (int j = 0 ; j < p ; j ++ ) e1 [j ] = 0 ;
72
+ int * e2 = Calloc (p , int );
73
+ for (int j = 0 ; j < p ; j ++ ) e2 [j ] = 0 ;
74
+ double cutoff , l1 , l2 , mean_resid , shift ;
74
75
int converged , lstart ;
75
76
76
77
// If lam[0]=lam_max, skip lam[0] -- closed form sol'n available
@@ -82,7 +83,7 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
82
83
}
83
84
84
85
// Path
85
- for (int l = lstart ; l < L ; l ++ ) {
86
+ for (int l = lstart ;l < L ;l ++ ) {
86
87
R_CheckUserInterrupt ();
87
88
if (l != 0 ) {
88
89
// Assign a0, a
@@ -92,66 +93,108 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
92
93
// Check dfmax
93
94
int nv = 0 ;
94
95
for (int j = 0 ; j < p ; j ++ ) {
95
- if (a [j ] != 0 ) nv ++ ;
96
+ if (a [j ] != 0 ) nv ++ ;
96
97
}
97
98
if (nv > dfmax ) {
98
99
for (int ll = l ; ll < L ; ll ++ ) INTEGER (iter )[ll ] = NA_INTEGER ;
99
- res = cleanupR (r , a , v , e , beta0 , beta , loss , iter );
100
+ res = cleanupR (a , r , v , z , e1 , e2 , beta0 , beta , loss , iter );
100
101
return (res );
101
102
}
103
+
104
+ // Determine eligible set
105
+ if (strcmp (penalty , "lasso" )== 0 ) cutoff = 2 * lam [l ] - lam [l - 1 ];
106
+ if (strcmp (penalty , "MCP" )== 0 ) cutoff = lam [l ] + gamma /(gamma - 1 )* (lam [l ] - lam [l - 1 ]);
107
+ if (strcmp (penalty , "SCAD" )== 0 ) cutoff = lam [l ] + gamma /(gamma - 2 )* (lam [l ] - lam [l - 1 ]);
108
+ for (int j = 0 ; j < p ; j ++ ) if (fabs (z [j ]) > (cutoff * alpha * m [j ])) e2 [j ] = 1 ;
109
+ } else {
110
+ // Determine eligible set
111
+ double lmax = 0 ;
112
+ for (int j = 0 ; j < p ; j ++ ) if (fabs (z [j ]) > lmax ) lmax = fabs (z [j ]);
113
+ if (strcmp (penalty , "lasso" )== 0 ) cutoff = 2 * lam [l ] - lmax ;
114
+ if (strcmp (penalty , "MCP" )== 0 ) cutoff = lam [l ] + gamma /(gamma - 1 )* (lam [l ] - lmax );
115
+ if (strcmp (penalty , "SCAD" )== 0 ) cutoff = lam [l ] + gamma /(gamma - 2 )* (lam [l ] - lmax );
116
+ for (int j = 0 ; j < p ; j ++ ) if (fabs (z [j ]) > (cutoff * alpha * m [j ])) e2 [j ] = 1 ;
102
117
}
103
118
104
119
while (INTEGER (iter )[l ] < max_iter ) {
105
120
while (INTEGER (iter )[l ] < max_iter ) {
106
- INTEGER (iter )[l ]++ ;
107
- // intercept
108
- mean_resid = 0.0 ;
109
- for (int i = 0 ; i < n ; i ++ ) mean_resid += r [i ];
110
- mean_resid /= n ;
111
- b0 [l ] = mean_resid + a0 ;
112
- for (int i = 0 ; i < n ; i ++ ) r [i ] -= mean_resid ;
121
+ while (INTEGER (iter )[l ] < max_iter ) {
122
+ // Solve over the active set
123
+ INTEGER (iter )[l ]++ ;
124
+ // intercept
125
+ mean_resid = 0.0 ;
126
+ for (int i = 0 ; i < n ; i ++ ) mean_resid += r [i ];
127
+ mean_resid /= n ;
128
+ b0 [l ] = mean_resid + a0 ;
129
+ for (int i = 0 ; i < n ; i ++ ) r [i ] -= mean_resid ;
113
130
114
- // covariates
131
+ for (int j = 0 ; j < p ; j ++ ) {
132
+ if (e1 [j ]) {
133
+ z [j ] = crossprod (X , r , n , j )/n + v [j ]* a [j ];
134
+
135
+ // Update beta_j
136
+ l1 = lam [l ] * m [j ] * alpha ;
137
+ l2 = lam [l ] * m [j ] * (1 - alpha );
138
+ if (strcmp (penalty ,"MCP" )== 0 ) b [l * p + j ] = MCP (z [j ], l1 , l2 , gamma , v [j ]);
139
+ if (strcmp (penalty ,"SCAD" )== 0 ) b [l * p + j ] = SCAD (z [j ], l1 , l2 , gamma , v [j ]);
140
+ if (strcmp (penalty ,"lasso" )== 0 ) b [l * p + j ] = lasso (z [j ], l1 , l2 , v [j ]);
141
+
142
+ // Update r
143
+ shift = b [l * p + j ] - a [j ];
144
+ if (shift != 0 ) for (int i = 0 ;i < n ;i ++ ) r [i ] -= shift * X [j * n + i ];
145
+ }
146
+ }
147
+
148
+ // Check for convergence
149
+ converged = checkConvergence (b , a , eps , l , p );
150
+ a0 = b0 [l ];
151
+ for (int j = 0 ; j < p ; j ++ ) a [j ] = b [l * p + j ];
152
+ if (converged ) break ;
153
+ }
154
+
155
+ // Scan for violations in strong set
156
+ int violations = 0 ;
115
157
for (int j = 0 ; j < p ; j ++ ) {
116
- if (e [j ]) {
117
- u = crossprod (X , r , n , j )/n + v [j ]* a [j ];
158
+ if (e1 [j ]== 0 && e2 [j ]== 1 ) {
118
159
119
- // Update b_j
160
+ z [j ] = crossprod (X , r , n , j )/n ; // a[j] = 0
161
+
162
+ // Update beta_j
120
163
l1 = lam [l ] * m [j ] * alpha ;
121
164
l2 = lam [l ] * m [j ] * (1 - alpha );
122
- if (strcmp (penalty ,"MCP" )== 0 ) b [l * p + j ] = MCP (u , l1 , l2 , gamma , v [j ]);
123
- if (strcmp (penalty ,"SCAD" )== 0 ) b [l * p + j ] = SCAD (u , l1 , l2 , gamma , v [j ]);
124
- if (strcmp (penalty ,"lasso" )== 0 ) b [l * p + j ] = lasso (u , l1 , l2 , v [j ]);
165
+ if (strcmp (penalty ,"MCP" )== 0 ) b [l * p + j ] = MCP (z [ j ] , l1 , l2 , gamma , v [j ]);
166
+ if (strcmp (penalty ,"SCAD" )== 0 ) b [l * p + j ] = SCAD (z [ j ] , l1 , l2 , gamma , v [j ]);
167
+ if (strcmp (penalty ,"lasso" )== 0 ) b [l * p + j ] = lasso (z [ j ] , l1 , l2 , v [j ]);
125
168
126
- // Update r
127
- shift = b [l * p + j ] - a [j ];
128
- if (shift != 0 ) for (int i = 0 ;i < n ;i ++ ) r [i ] -= shift * X [j * n + i ];
169
+ // If something enters the eligible set, update eligible set & residuals
170
+ if (b [l * p + j ] != 0 ) {
171
+ e1 [j ] = e2 [j ] = 1 ;
172
+ for (int i = 0 ; i < n ; i ++ ) r [i ] -= b [l * p + j ]* X [j * n + i ];
173
+ a [j ] = b [l * p + j ];
174
+ violations ++ ;
175
+ }
129
176
}
130
177
}
131
-
132
- // Check for convergence
133
- converged = checkConvergence (b , a , eps , l , p );
134
- a0 = b0 [l ];
135
- for (int j = 0 ; j < p ; j ++ ) a [j ] = b [l * p + j ];
136
- if (converged ) break ;
178
+ if (violations == 0 ) break ;
137
179
}
138
180
139
- // Scan for violations
181
+ // Scan for violations in rest
140
182
int violations = 0 ;
141
183
for (int j = 0 ; j < p ; j ++ ) {
142
- if (e [j ]== 0 ) {
143
- u = crossprod (X , r , n , j )/n + v [j ]* a [j ];
184
+ if (e2 [j ]== 0 ) {
185
+
186
+ z [j ] = crossprod (X , r , n , j )/n ; // a[j] = 0
144
187
145
- // Update b_j
188
+ // Update beta_j
146
189
l1 = lam [l ] * m [j ] * alpha ;
147
190
l2 = lam [l ] * m [j ] * (1 - alpha );
148
- if (strcmp (penalty ,"MCP" )== 0 ) b [l * p + j ] = MCP (u , l1 , l2 , gamma , v [j ]);
149
- if (strcmp (penalty ,"SCAD" )== 0 ) b [l * p + j ] = SCAD (u , l1 , l2 , gamma , v [j ]);
150
- if (strcmp (penalty ,"lasso" )== 0 ) b [l * p + j ] = lasso (u , l1 , l2 , v [j ]);
191
+ if (strcmp (penalty ,"MCP" )== 0 ) b [l * p + j ] = MCP (z [ j ] , l1 , l2 , gamma , v [j ]);
192
+ if (strcmp (penalty ,"SCAD" )== 0 ) b [l * p + j ] = SCAD (z [ j ] , l1 , l2 , gamma , v [j ]);
193
+ if (strcmp (penalty ,"lasso" )== 0 ) b [l * p + j ] = lasso (z [ j ] , l1 , l2 , v [j ]);
151
194
152
195
// If something enters the eligible set, update eligible set & residuals
153
196
if (b [l * p + j ] != 0 ) {
154
- e [j ] = 1 ;
197
+ e1 [ j ] = e2 [j ] = 1 ;
155
198
for (int i = 0 ; i < n ; i ++ ) r [i ] -= b [l * p + j ]* X [j * n + i ];
156
199
a [j ] = b [l * p + j ];
157
200
violations ++ ;
@@ -160,11 +203,11 @@ SEXP cdfit_raw(SEXP X_, SEXP y_, SEXP penalty_, SEXP lambda, SEXP eps_, SEXP max
160
203
}
161
204
162
205
if (violations == 0 ) {
163
- REAL (loss )[l ] = gLoss (r , n );
164
206
break ;
165
207
}
166
208
}
209
+ REAL (loss )[l ] = gLoss (r , n );
167
210
}
168
- res = cleanupR (r , a , v , e , beta0 , beta , loss , iter );
211
+ res = cleanupR (a , r , v , z , e1 , e2 , beta0 , beta , loss , iter );
169
212
return (res );
170
213
}
0 commit comments