@@ -32,25 +32,37 @@ using namespace VW::reductions;
32
32
namespace
33
33
{
34
34
float get_active_coin_bias (float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness)
35
- {// implementation follows https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
36
- const float mellow_log_e_count_over_e_count = mellowness * (std::log (example_count + 1 .f ) + 0 .0001f ) / (example_count + 0 .0001f );
35
+ { // implementation follows
36
+ // https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
37
+ const float mellow_log_e_count_over_e_count =
38
+ mellowness * (std::log (example_count + 1 .f ) + 0 .0001f ) / (example_count + 0 .0001f );
37
39
const float sqrt_mellow_lecoec = std::sqrt (mellow_log_e_count_over_e_count);
38
40
// loss should be in [0,1]
39
41
avg_loss = VW::math::clamp (avg_loss, 0 .f , 1 .f );
40
42
41
- const float sqrt_avg_loss_plus_sqrt_alt_loss = std::min (1 .f , // std::sqrt(avg_loss) + // commented out because two square roots appears to conservative.
42
- std::sqrt (avg_loss + alt_label_error_rate_diff));// emperical variance deflater.
43
- // std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count
44
- // << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss << std::endl;
45
-
46
- if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss// deflater in use.
47
- + mellow_log_e_count_over_e_count) { return 1 ; }
48
- // old equation
49
- // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff);
50
- // return mellow_log_e_count_over_e_count * rs * rs;
51
- const float sqrt_s = (sqrt_mellow_lecoec + std::sqrt (mellow_log_e_count_over_e_count+4 *alt_label_error_rate_diff*mellow_log_e_count_over_e_count)) / 2 *alt_label_error_rate_diff;
43
+ const float sqrt_avg_loss_plus_sqrt_alt_loss =
44
+ std::min (1 .f , // std::sqrt(avg_loss) + // commented out because two square roots appears to conservative.
45
+ std::sqrt (avg_loss + alt_label_error_rate_diff)); // emperical variance deflater.
46
+ // std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = "
47
+ // << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count
48
+ // << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss
49
+ // << std::endl;
50
+
51
+ if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss // deflater in use.
52
+ + mellow_log_e_count_over_e_count)
53
+ {
54
+ return 1 ;
55
+ }
56
+ // old equation
57
+ // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss *
58
+ // sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff); return
59
+ // mellow_log_e_count_over_e_count * rs * rs;
60
+ const float sqrt_s = (sqrt_mellow_lecoec +
61
+ std::sqrt (mellow_log_e_count_over_e_count +
62
+ 4 * alt_label_error_rate_diff * mellow_log_e_count_over_e_count)) /
63
+ 2 * alt_label_error_rate_diff;
52
64
// std::cout << "sqrt_s = " << sqrt_s << std::endl;
53
- return sqrt_s* sqrt_s;
65
+ return sqrt_s * sqrt_s;
54
66
}
55
67
56
68
float query_decision (const active& a, float updates_to_change_prediction, float example_count)
@@ -61,8 +73,10 @@ float query_decision(const active& a, float updates_to_change_prediction, float
61
73
{
62
74
// const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
63
75
const float avg_loss = (static_cast <float >(a._shared_data ->sum_loss ) / example_count);
64
- // + std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory.
65
- // std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
76
+ // + std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not
77
+ // following why we need it from the theory.
78
+ // std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " <<
79
+ // a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
66
80
bias = get_active_coin_bias (example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0 );
67
81
}
68
82
// std::cout << "bias = " << bias << std::endl;
@@ -122,31 +136,32 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec)
122
136
123
137
template <bool is_learn>
124
138
void predict_or_learn_active_direct (active& a, learner& base, VW::example& ec)
125
- {
139
+ {
126
140
if (is_learn) { base.learn (ec); }
127
141
else { base.predict (ec); }
128
-
142
+
129
143
if (ec.l .simple .label == FLT_MAX)
130
144
{
131
- if (std::string (ec.tag .begin (), ec.tag .begin ()+ 6 ) == " query?" )
132
- {
145
+ if (std::string (ec.tag .begin (), ec.tag .begin () + 6 ) == " query?" )
146
+ {
133
147
const float threshold = (a._shared_data ->max_label + a._shared_data ->min_label ) * 0 .5f ;
134
148
// We want to understand the change in prediction if the label were to be
135
149
// the opposite of what was predicted. 0 and 1 are used for the expected min
136
150
// and max labels to be coming in from the active interactor.
137
151
ec.l .simple .label = (ec.pred .scalar >= threshold) ? a._min_seen_label : a._max_seen_label ;
138
152
ec.confidence = std::abs (ec.pred .scalar - threshold) / base.sensitivity (ec);
139
153
ec.l .simple .label = FLT_MAX;
140
- ec.pred .scalar = query_decision (a, ec.confidence , static_cast <float >(a._shared_data ->weighted_unlabeled_examples ));
154
+ ec.pred .scalar =
155
+ query_decision (a, ec.confidence , static_cast <float >(a._shared_data ->weighted_unlabeled_examples ));
141
156
}
142
157
}
143
158
else
144
- {
159
+ {
145
160
// Update seen labels based on the current example's label.
146
161
a._min_seen_label = std::min (ec.l .simple .label , a._min_seen_label );
147
162
a._max_seen_label = std::max (ec.l .simple .label , a._max_seen_label );
148
- }
149
- }
163
+ }
164
+ }
150
165
151
166
void active_print_result (
152
167
VW::io::writer* f, float res, float weight, const VW::v_array<char >& tag, VW::io::logger& logger)
@@ -232,7 +247,9 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
232
247
option_group_definition new_options (" [Reduction] Active Learning" );
233
248
new_options.add (make_option (" active" , active_option).keep ().necessary ().help (" Enable active learning" ))
234
249
.add (make_option (" simulation" , simulation).help (" Active learning simulation mode" ))
235
- .add (make_option (" direct" , direct).help (" Active learning via the tag and predictions interface. Tag should start with \" query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes." ))
250
+ .add (make_option (" direct" , direct)
251
+ .help (" Active learning via the tag and predictions interface. Tag should start with \" query?\" to get "
252
+ " query decision. Returned prediction is either -1 for no or the importance weight for yes." ))
236
253
.add (make_option (" mellowness" , active_c0)
237
254
.keep ()
238
255
.default_value (1 .f )
0 commit comments