Skip to content

Commit 9837a0e

Browse files
authored
style: forgotten lint fix (#4688)
* lint * undo fake commit * undo fake commit * forgotten space
1 parent 80e832f commit 9837a0e

File tree

1 file changed

+42
-25
lines changed

1 file changed

+42
-25
lines changed

vowpalwabbit/core/src/reductions/active.cc

+42-25
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,37 @@ using namespace VW::reductions;
3232
namespace
3333
{
3434
float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness)
35-
{//implementation follows https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
36-
const float mellow_log_e_count_over_e_count = mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f);
35+
{ // implementation follows
36+
// https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
37+
const float mellow_log_e_count_over_e_count =
38+
mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f);
3739
const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count);
3840
// loss should be in [0,1]
3941
avg_loss = VW::math::clamp(avg_loss, 0.f, 1.f);
4042

41-
const float sqrt_avg_loss_plus_sqrt_alt_loss = std::min(1.f, //std::sqrt(avg_loss) + // commented out because two square roots appears to conservative.
42-
std::sqrt(avg_loss + alt_label_error_rate_diff));//emperical variance deflater.
43-
//std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count
44-
// << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss << std::endl;
45-
46-
if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss//deflater in use.
47-
+ mellow_log_e_count_over_e_count) { return 1; }
48-
//old equation
49-
// const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff);
50-
// return mellow_log_e_count_over_e_count * rs * rs;
51-
const float sqrt_s = (sqrt_mellow_lecoec + std::sqrt(mellow_log_e_count_over_e_count+4*alt_label_error_rate_diff*mellow_log_e_count_over_e_count)) / 2*alt_label_error_rate_diff;
43+
const float sqrt_avg_loss_plus_sqrt_alt_loss =
44+
std::min(1.f, // std::sqrt(avg_loss) + // commented out because two square roots appears to conservative.
45+
std::sqrt(avg_loss + alt_label_error_rate_diff)); // emperical variance deflater.
46+
// std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = "
47+
// << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count
48+
// << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss
49+
//<< std::endl;
50+
51+
if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss // deflater in use.
52+
+ mellow_log_e_count_over_e_count)
53+
{
54+
return 1;
55+
}
56+
// old equation
57+
// const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss *
58+
// sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff); return
59+
// mellow_log_e_count_over_e_count * rs * rs;
60+
const float sqrt_s = (sqrt_mellow_lecoec +
61+
std::sqrt(mellow_log_e_count_over_e_count +
62+
4 * alt_label_error_rate_diff * mellow_log_e_count_over_e_count)) /
63+
2 * alt_label_error_rate_diff;
5264
// std::cout << "sqrt_s = " << sqrt_s << std::endl;
53-
return sqrt_s*sqrt_s;
65+
return sqrt_s * sqrt_s;
5466
}
5567

5668
float query_decision(const active& a, float updates_to_change_prediction, float example_count)
@@ -61,8 +73,10 @@ float query_decision(const active& a, float updates_to_change_prediction, float
6173
{
6274
// const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
6375
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / example_count);
64-
//+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory.
65-
// std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
76+
//+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not
77+
// following why we need it from the theory.
78+
// std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " <<
79+
// a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
6680
bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0);
6781
}
6882
// std::cout << "bias = " << bias << std::endl;
@@ -122,31 +136,32 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec)
122136

123137
template <bool is_learn>
124138
void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec)
125-
{
139+
{
126140
if (is_learn) { base.learn(ec); }
127141
else { base.predict(ec); }
128-
142+
129143
if (ec.l.simple.label == FLT_MAX)
130144
{
131-
if (std::string(ec.tag.begin(), ec.tag.begin()+6) == "query?")
132-
{
145+
if (std::string(ec.tag.begin(), ec.tag.begin() + 6) == "query?")
146+
{
133147
const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f;
134148
// We want to understand the change in prediction if the label were to be
135149
// the opposite of what was predicted. 0 and 1 are used for the expected min
136150
// and max labels to be coming in from the active interactor.
137151
ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label;
138152
ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec);
139153
ec.l.simple.label = FLT_MAX;
140-
ec.pred.scalar = query_decision(a, ec.confidence, static_cast<float>(a._shared_data->weighted_unlabeled_examples));
154+
ec.pred.scalar =
155+
query_decision(a, ec.confidence, static_cast<float>(a._shared_data->weighted_unlabeled_examples));
141156
}
142157
}
143158
else
144-
{
159+
{
145160
// Update seen labels based on the current example's label.
146161
a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label);
147162
a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label);
148-
}
149-
}
163+
}
164+
}
150165

151166
void active_print_result(
152167
VW::io::writer* f, float res, float weight, const VW::v_array<char>& tag, VW::io::logger& logger)
@@ -232,7 +247,9 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
232247
option_group_definition new_options("[Reduction] Active Learning");
233248
new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning"))
234249
.add(make_option("simulation", simulation).help("Active learning simulation mode"))
235-
.add(make_option("direct", direct).help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes."))
250+
.add(make_option("direct", direct)
251+
.help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get "
252+
"query decision. Returned prediction is either -1 for no or the importance weight for yes."))
236253
.add(make_option("mellowness", active_c0)
237254
.keep()
238255
.default_value(1.f)

0 commit comments

Comments
 (0)