Skip to content

Commit 80e832f

Browse files
beygel (Alina Beygelzimer), ataymano
authored
feat: direct interface for active.cc and variable rename for understandability (#4671)
* updated active.cc * updates to active.cc * updates to tests * revert accidental help change in diff * removed diagnostic print statements --------- Co-authored-by: Alina Beygelzimer <[email protected]> Co-authored-by: Alexey Taymanov <[email protected]>
1 parent 2004e72 commit 80e832f

File tree

3 files changed

+78
-32
lines changed

3 files changed

+78
-32
lines changed

test/train-sets/ref/active-simulation.t24.stderr

+6-13
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,13 @@ Output pred = SCALAR
1111
average since example example current current current
1212
loss last counter weight label predict features
1313
1.000000 1.000000 1 1.0 -1.0000 0.0000 128
14-
0.791125 0.755288 2 6.8 -1.0000 -0.1309 44
15-
1.274829 1.444750 8 26.3 1.0000 -0.2020 34
16-
1.083985 0.895011 73 52.8 1.0000 0.0214 21
17-
0.887295 0.693362 130 106.3 -1.0000 -0.3071 146
18-
0.788245 0.690009 233 213.6 -1.0000 0.0421 47
19-
0.664628 0.541195 398 427.4 -1.0000 -0.1863 68
20-
0.634406 0.604328 835 856.9 -1.0000 -0.4327 40
2114

2215
finished run
2316
number of examples = 1000
24-
weighted example sum = 1014.004519
25-
weighted label sum = -68.618036
26-
average loss = 0.630964
27-
best constant = -0.067670
28-
best constant's loss = 0.995421
17+
weighted example sum = 1.000000
18+
weighted label sum = -1.000000
19+
average loss = 1.000000
20+
best constant = -1.000000
21+
best constant's loss = 0.000000
2922
total feature number = 78739
30-
total queries = 474
23+
total queries = 1

test/train-sets/ref/help.stdout

+6-2
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,12 @@ Weight Options:
221221
[Reduction] Active Learning Options:
222222
--active Enable active learning (type: bool, keep, necessary)
223223
--simulation Active learning simulation mode (type: bool)
224-
--mellowness arg Active learning mellowness parameter c_0. Default 8 (type: float,
225-
default: 8, keep)
224+
--direct Active learning via the tag and predictions interface. Tag should
225+
start with "query?" to get query decision. Returned prediction
226+
is either -1 for no or the importance weight for yes. (type:
227+
bool)
228+
--mellowness arg Active learning mellowness parameter c_0. Default 1. (type: float,
229+
default: 1, keep)
226230
[Reduction] Active Learning with Cover Options:
227231
--active_cover Enable active learning with cover (type: bool, keep, necessary)
228232
--mellowness arg Active learning mellowness parameter c_0 (type: float, default:

vowpalwabbit/core/src/reductions/active.cc

+66-17
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,41 @@ using namespace VW::config;
3131
using namespace VW::reductions;
3232
namespace
3333
{
34-
float get_active_coin_bias(float k, float avg_loss, float g, float c0)
35-
{
36-
const float b = c0 * (std::log(k + 1.f) + 0.0001f) / (k + 0.0001f);
37-
const float sb = std::sqrt(b);
34+
float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness)
35+
{//implementation follows https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
36+
const float mellow_log_e_count_over_e_count = mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f);
37+
const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count);
3838
// loss should be in [0,1]
3939
avg_loss = VW::math::clamp(avg_loss, 0.f, 1.f);
4040

41-
const float sl = std::sqrt(avg_loss) + std::sqrt(avg_loss + g);
42-
if (g <= sb * sl + b) { return 1; }
43-
const float rs = (sl + std::sqrt(sl * sl + 4 * g)) / (2 * g);
44-
return b * rs * rs;
41+
const float sqrt_avg_loss_plus_sqrt_alt_loss = std::min(1.f, //std::sqrt(avg_loss) + // commented out because two square roots appears too conservative.
42+
std::sqrt(avg_loss + alt_label_error_rate_diff));//empirical variance deflater.
43+
//std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count
44+
// << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss << std::endl;
45+
46+
if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss//deflater in use.
47+
+ mellow_log_e_count_over_e_count) { return 1; }
48+
//old equation
49+
// const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff);
50+
// return mellow_log_e_count_over_e_count * rs * rs;
51+
const float sqrt_s = (sqrt_mellow_lecoec + std::sqrt(mellow_log_e_count_over_e_count + 4 * alt_label_error_rate_diff * mellow_log_e_count_over_e_count)) / (2 * alt_label_error_rate_diff);
52+
// std::cout << "sqrt_s = " << sqrt_s << std::endl;
53+
return sqrt_s*sqrt_s;
4554
}
4655

47-
float query_decision(const active& a, float ec_revert_weight, float k)
56+
float query_decision(const active& a, float updates_to_change_prediction, float example_count)
4857
{
4958
float bias;
50-
if (k <= 1.f) { bias = 1.f; }
59+
if (example_count <= 1.f) { bias = 1.f; }
5160
else
5261
{
53-
const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
54-
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / k) +
55-
std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f));
56-
bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0);
62+
// const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
63+
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / example_count);
64+
//+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory.
65+
// std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
66+
bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0);
5767
}
58-
68+
// std::cout << "bias = " << bias << std::endl;
5969
return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f;
6070
}
6171

@@ -110,6 +120,34 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec)
110120
}
111121
}
112122

123+
template <bool is_learn>
124+
void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec)
125+
{
126+
if (is_learn) { base.learn(ec); }
127+
else { base.predict(ec); }
128+
129+
if (ec.l.simple.label == FLT_MAX)
130+
{
131+
if (ec.tag.size() >= 6 && std::string(ec.tag.begin(), ec.tag.begin() + 6) == "query?")
132+
{
133+
const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f;
134+
// We want to understand the change in prediction if the label were to be
135+
// the opposite of what was predicted. 0 and 1 are used for the expected min
136+
// and max labels to be coming in from the active interactor.
137+
ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label;
138+
ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec);
139+
ec.l.simple.label = FLT_MAX;
140+
ec.pred.scalar = query_decision(a, ec.confidence, static_cast<float>(a._shared_data->weighted_unlabeled_examples));
141+
}
142+
}
143+
else
144+
{
145+
// Update seen labels based on the current example's label.
146+
a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label);
147+
a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label);
148+
}
149+
}
150+
113151
void active_print_result(
114152
VW::io::writer* f, float res, float weight, const VW::v_array<char>& tag, VW::io::logger& logger)
115153
{
@@ -189,14 +227,16 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
189227

190228
bool active_option = false;
191229
bool simulation = false;
230+
bool direct = false;
192231
float active_c0;
193232
option_group_definition new_options("[Reduction] Active Learning");
194233
new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning"))
195234
.add(make_option("simulation", simulation).help("Active learning simulation mode"))
235+
.add(make_option("direct", direct).help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes."))
196236
.add(make_option("mellowness", active_c0)
197237
.keep()
198-
.default_value(8.f)
199-
.help("Active learning mellowness parameter c_0. Default 8"));
238+
.default_value(1.f)
239+
.help("Active learning mellowness parameter c_0. Default 1."));
200240

201241
if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }
202242

@@ -223,6 +263,15 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::active_setup(VW::setup_bas
223263
print_update_func = VW::details::print_update_simple_label<active>;
224264
reduction_name.append("-simulation");
225265
}
266+
else if (direct)
267+
{
268+
learn_func = predict_or_learn_active_direct<true>;
269+
pred_func = predict_or_learn_active_direct<false>;
270+
update_stats_func = update_stats_active;
271+
output_example_prediction_func = VW::details::output_example_prediction_simple_label<active>;
272+
print_update_func = VW::details::print_update_simple_label<active>;
273+
learn_returns_prediction = base->learn_returns_prediction;
274+
}
226275
else
227276
{
228277
all.reduction_state.active = true;

0 commit comments

Comments
 (0)