@@ -31,31 +31,41 @@ using namespace VW::config;
31
31
using namespace VW ::reductions;
32
32
namespace
33
33
{
34
- float get_active_coin_bias (float k , float avg_loss, float g , float c0 )
35
- {
36
- const float b = c0 * (std::log (k + 1 .f ) + 0 .0001f ) / (k + 0 .0001f );
37
- const float sb = std::sqrt (b );
34
+ float get_active_coin_bias (float example_count , float avg_loss, float alt_label_error_rate_diff , float mellowness )
35
+ {// implementation follows https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
36
+ const float mellow_log_e_count_over_e_count = mellowness * (std::log (example_count + 1 .f ) + 0 .0001f ) / (example_count + 0 .0001f );
37
+ const float sqrt_mellow_lecoec = std::sqrt (mellow_log_e_count_over_e_count );
38
38
// loss should be in [0,1]
39
39
avg_loss = VW::math::clamp (avg_loss, 0 .f , 1 .f );
40
40
41
- const float sl = std::sqrt (avg_loss) + std::sqrt (avg_loss + g);
42
- if (g <= sb * sl + b) { return 1 ; }
43
- const float rs = (sl + std::sqrt (sl * sl + 4 * g)) / (2 * g);
44
- return b * rs * rs;
41
+ const float sqrt_avg_loss_plus_sqrt_alt_loss = std::min (1 .f , // std::sqrt(avg_loss) + // commented out because two square roots appears to conservative.
42
+ std::sqrt (avg_loss + alt_label_error_rate_diff));// emperical variance deflater.
43
+ // std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count
44
+ // << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss << std::endl;
45
+
46
+ if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss// deflater in use.
47
+ + mellow_log_e_count_over_e_count) { return 1 ; }
48
+ // old equation
49
+ // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff);
50
+ // return mellow_log_e_count_over_e_count * rs * rs;
51
+ const float sqrt_s = (sqrt_mellow_lecoec + std::sqrt (mellow_log_e_count_over_e_count+4 *alt_label_error_rate_diff*mellow_log_e_count_over_e_count)) / 2 *alt_label_error_rate_diff;
52
+ // std::cout << "sqrt_s = " << sqrt_s << std::endl;
53
+ return sqrt_s*sqrt_s;
45
54
}
46
55
47
- float query_decision (const active& a, float ec_revert_weight , float k )
56
+ float query_decision (const active& a, float updates_to_change_prediction , float example_count )
48
57
{
49
58
float bias;
50
- if (k <= 1 .f ) { bias = 1 .f ; }
59
+ if (example_count <= 1 .f ) { bias = 1 .f ; }
51
60
else
52
61
{
53
- const auto weighted_queries = static_cast <float >(a._shared_data ->weighted_labeled_examples );
54
- const float avg_loss = (static_cast <float >(a._shared_data ->sum_loss ) / k) +
55
- std::sqrt ((1 .f + 0 .5f * std::log (k)) / (weighted_queries + 0 .0001f ));
56
- bias = get_active_coin_bias (k, avg_loss, ec_revert_weight / k, a.active_c0 );
62
+ // const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
63
+ const float avg_loss = (static_cast <float >(a._shared_data ->sum_loss ) / example_count);
64
+ // + std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory.
65
+ // std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
66
+ bias = get_active_coin_bias (example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0 );
57
67
}
58
-
68
+ // std::cout << "bias = " << bias << std::endl;
59
69
return (a._random_state ->get_and_update_random () < bias) ? 1 .f / bias : -1 .f ;
60
70
}
61
71
@@ -110,6 +120,34 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec)
110
120
}
111
121
}
112
122
123
+ template <bool is_learn>
124
+ void predict_or_learn_active_direct (active& a, learner& base, VW::example& ec)
125
+ {
126
+ if (is_learn) { base.learn (ec); }
127
+ else { base.predict (ec); }
128
+
129
+ if (ec.l .simple .label == FLT_MAX)
130
+ {
131
+ if (std::string (ec.tag .begin (), ec.tag .begin ()+6 ) == " query?" )
132
+ {
133
+ const float threshold = (a._shared_data ->max_label + a._shared_data ->min_label ) * 0 .5f ;
134
+ // We want to understand the change in prediction if the label were to be
135
+ // the opposite of what was predicted. 0 and 1 are used for the expected min
136
+ // and max labels to be coming in from the active interactor.
137
+ ec.l .simple .label = (ec.pred .scalar >= threshold) ? a._min_seen_label : a._max_seen_label ;
138
+ ec.confidence = std::abs (ec.pred .scalar - threshold) / base.sensitivity (ec);
139
+ ec.l .simple .label = FLT_MAX;
140
+ ec.pred .scalar = query_decision (a, ec.confidence , static_cast <float >(a._shared_data ->weighted_unlabeled_examples ));
141
+ }
142
+ }
143
+ else
144
+ {
145
+ // Update seen labels based on the current example's label.
146
+ a._min_seen_label = std::min (ec.l .simple .label , a._min_seen_label );
147
+ a._max_seen_label = std::max (ec.l .simple .label , a._max_seen_label );
148
+ }
149
+ }
150
+
113
151
void active_print_result (
114
152
VW::io::writer* f, float res, float weight, const VW::v_array<char >& tag, VW::io::logger& logger)
115
153
{
@@ -189,14 +227,16 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
189
227
190
228
bool active_option = false ;
191
229
bool simulation = false ;
230
+ bool direct = false ;
192
231
float active_c0;
193
232
option_group_definition new_options (" [Reduction] Active Learning" );
194
233
new_options.add (make_option (" active" , active_option).keep ().necessary ().help (" Enable active learning" ))
195
234
.add (make_option (" simulation" , simulation).help (" Active learning simulation mode" ))
235
+ .add (make_option (" direct" , direct).help (" Active learning via the tag and predictions interface. Tag should start with \" query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes." ))
196
236
.add (make_option (" mellowness" , active_c0)
197
237
.keep ()
198
- .default_value (8 .f )
199
- .help (" Active learning mellowness parameter c_0. Default 8 " ));
238
+ .default_value (1 .f )
239
+ .help (" Active learning mellowness parameter c_0. Default 1. " ));
200
240
201
241
if (!options.add_parse_and_check_necessary (new_options)) { return nullptr ; }
202
242
@@ -223,6 +263,15 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
223
263
print_update_func = VW::details::print_update_simple_label<active>;
224
264
reduction_name.append (" -simulation" );
225
265
}
266
+ else if (direct)
267
+ {
268
+ learn_func = predict_or_learn_active_direct<true >;
269
+ pred_func = predict_or_learn_active_direct<false >;
270
+ update_stats_func = update_stats_active;
271
+ output_example_prediction_func = VW::details::output_example_prediction_simple_label<active>;
272
+ print_update_func = VW::details::print_update_simple_label<active>;
273
+ learn_returns_prediction = base->learn_returns_prediction ;
274
+ }
226
275
else
227
276
{
228
277
all.reduction_state .active = true ;
0 commit comments