
Commit 0e0cea7

Fix the experiments
1 parent 42d80d3 commit 0e0cea7

9 files changed (+243 -76 lines)

experiments/custom_utilities_methods.py (+7 -7)
@@ -64,7 +64,7 @@ def bc_macro_min_tp_tn(
     alpha: float = 1,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -80,7 +80,7 @@ def bc_macro_min_tp_tn(
         skip_tn=True,
         tolerance=tolerance,
         init_y_pred=init_y_pred,
-        max_iter=max_iter,
+        max_iters=max_iters,
         shuffle_order=shuffle_order,
         verbose=verbose,
         return_meta=return_meta,
@@ -94,7 +94,7 @@ def bc_micro_f1(
     alpha: float = 1,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -109,7 +109,7 @@ def bc_macro_hmean(
     alpha: float = 1,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -125,7 +125,7 @@ def bc_macro_hmean(
         skip_tn=True,
         tolerance=tolerance,
         init_y_pred=init_y_pred,
-        max_iter=max_iter,
+        max_iters=max_iters,
         shuffle_order=shuffle_order,
         verbose=verbose,
         return_meta=return_meta,
@@ -139,7 +139,7 @@ def bc_macro_gmean(
     alpha: float = 1,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -155,7 +155,7 @@ def bc_macro_gmean(
         skip_tn=True,
         tolerance=tolerance,
         init_y_pred=init_y_pred,
-        max_iter=max_iter,
+        max_iters=max_iters,
         shuffle_order=shuffle_order,
         verbose=verbose,
         return_meta=return_meta,

experiments/requirements.txt (+1)
@@ -1,2 +1,3 @@
 xcolumns
 click
+tqdm

experiments/run_iclr_2024_fw_experiment.py (+25 -13)
@@ -23,12 +23,6 @@
 from xcolumns.weighted_prediction import *
 
 
-# TODO: refactor this
-RECALCULATE_RESUTLS = False
-RECALCULATE_PREDICTION = False
-RETRAIN_MODEL = False
-
-
 def frank_wolfe_wrapper(
     Y_val,
     pred_val,
@@ -592,7 +586,25 @@ def predict_proba(self, X, top_k):
 @click.option("-s", "--seed", type=int, required=True)
 @click.option("-m", "--method", type=str, required=False, default=None)
 @click.option("-t", "--testsplit", type=float, required=False, default=0)
-def main(experiment, k, seed, method, testsplit):
+@click.option("-r", "--results_dir", type=str, required=False, default="results_fw/")
+@click.option(
+    "--recalculate_predictions", is_flag=True, type=bool, required=False, default=False
+)
+@click.option(
+    "--recalculate_results", is_flag=True, type=bool, required=False, default=False
+)
+@click.option("--retrain_model", is_flag=True, type=bool, required=False, default=False)
+def main(
+    experiment,
+    k,
+    seed,
+    method,
+    testsplit,
+    results_dir,
+    recalculate_predictions,
+    recalculate_results,
+    retrain_model,
+):
     print(experiment)
 
     if method is None:
@@ -894,7 +906,7 @@ def main(experiment, k, seed, method, testsplit):
             model = PytorchModel(model_path, model_seed, loss="asym")
 
     if isinstance(model, ModelWrapper):
-        if not os.path.exists(model_path) or RETRAIN_MODEL:
+        if not os.path.exists(model_path) or retrain_model:
             with Timer():
                 model.fit(X_train, Y_train, X_test, Y_test)
         # else:
@@ -907,7 +919,7 @@ def main(experiment, k, seed, method, testsplit):
     top_k = 100
     print("Predicting for validation set ...")
     val_pred_path = f"models_and_predictions/{experiment}_seed={model_seed}_split={1 - testsplit}_top_k={top_k}_pred_val.pkl"
-    if not os.path.exists(val_pred_path) or RETRAIN_MODEL:
+    if not os.path.exists(val_pred_path) or retrain_model:
        with Timer():
            pred_val = model.predict_proba(X_val, top_k=top_k)
        align_dim1(Y_train, pred_val)
@@ -920,7 +932,7 @@ def main(experiment, k, seed, method, testsplit):
 
     print("Predicting for test set ...")
     test_pred_path = f"models_and_predictions/{experiment}_seed={model_seed}_split={1 - testsplit}_top_k={top_k}_pred_test.pkl"
-    if not os.path.exists(test_pred_path) or RETRAIN_MODEL:
+    if not os.path.exists(test_pred_path) or retrain_model:
        with Timer():
            pred_test = model.predict_proba(X_test, top_k=top_k)
        align_dim1(Y_train, pred_test)
@@ -941,7 +953,7 @@ def main(experiment, k, seed, method, testsplit):
         pred_test = sp.csr_matrix(pred_test)
 
     print("Calculating metrics ...")
-    output_path_prefix = f"results_fw/{experiment}/"
+    output_path_prefix = f"{results_dir}/{experiment}/"
     os.makedirs(output_path_prefix, exist_ok=True)
     for method, func in methods.items():
         print(f"{experiment} - {method} @ {k}: ")
@@ -950,9 +962,9 @@ def main(experiment, k, seed, method, testsplit):
         results_path = f"{output_path}_results.json"
         pred_path = f"{output_path}_pred.pkl"
 
-        if not os.path.exists(results_path) or RECALCULATE_RESUTLS:
+        if not os.path.exists(results_path) or recalculate_results:
             results = {}
-            if not os.path.exists(pred_path) or RECALCULATE_PREDICTION:
+            if not os.path.exists(pred_path) or recalculate_predictions:
                 # results["test_log_loss"] = log_loss(Y_test, pred_test)
                 # results["val_log_loss"] = log_loss(Y_val, pred_val)
 
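Note: the removed module-level constants (RECALCULATE_RESUTLS, RECALCULATE_PREDICTION, RETRAIN_MODEL) are now supplied per run through click boolean flags. A minimal sketch of how such `is_flag` options behave; the `demo` command below is illustrative and not part of the repository:

import click


@click.command()
@click.option("--retrain_model", is_flag=True, default=False)
@click.option("--recalculate_results", is_flag=True, default=False)
def demo(retrain_model, recalculate_results):
    # Each flag defaults to False; passing e.g. `--retrain_model` on the command
    # line flips the value to True, replacing the old edit-the-constant workflow.
    click.echo(f"retrain_model={retrain_model}, recalculate_results={recalculate_results}")


if __name__ == "__main__":
    demo()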

experiments/run_neurips_2023_bca_experiment.py (+10 -10)
@@ -27,7 +27,7 @@
     "optimal-instance-prec": (predict_optimizing_instance_precision, {}),
     # "block-coord-instance-prec": (bc_instance_precision_at_k, {}), # This is the same as optimal-instance-prec but using block coordinate, for sanity-check purposes only
     "optimal-instance-ps-prec": (
-        predict_optimizing_instance_propensity_weighted_precision,
+        predict_optimizing_instance_propensity_scored_precision,
         {},
     ),
     "power-law-with-beta=0.75": (
@@ -136,34 +136,34 @@
     # Greedy / 1 iter variants
     "greedy-macro-prec": (
         predict_optimizing_macro_precision_using_bc,
-        {"init_y_pred": "greedy", "max_iter": 1},
+        {"init_y_pred": "greedy", "max_iters": 1},
     ),
     "greedy-macro-recall": (
         predict_optimizing_macro_precision_using_bc,
-        {"init_y_pred": "greedy", "max_iter": 1},
+        {"init_y_pred": "greedy", "max_iters": 1},
     ),
     "greedy-macro-f1": (
         predict_optimizing_macro_f1_score_using_bc,
-        {"init_y_pred": "greedy", "max_iter": 1},
+        {"init_y_pred": "greedy", "max_iters": 1},
     ),
     "greedy-cov": (
         predict_optimizing_coverage_using_bc,
-        {"init_y_pred": "greedy", "max_iter": 1},
+        {"init_y_pred": "greedy", "max_iters": 1},
     ),
     #
     "block-coord-macro-prec-iter=1": (
         predict_optimizing_macro_precision_using_bc,
-        {"max_iter": 1},
+        {"max_iters": 1},
     ),
     "block-coord-macro-recall-iter=1": (
         predict_optimizing_macro_precision_using_bc,
-        {"max_iter": 1},
+        {"max_iters": 1},
     ),
     "block-coord-macro-f1-iter=1": (
         predict_optimizing_macro_f1_score_using_bc,
-        {"max_iter": 1},
+        {"max_iters": 1},
     ),
-    "block-coord-cov-iter=1": (predict_optimizing_coverage_using_bc, {"max_iter": 1}),
+    "block-coord-cov-iter=1": (predict_optimizing_coverage_using_bc, {"max_iters": 1}),
     #
     # Similar results to the above
     # "greedy-start-block-coord-macro-prec": (predict_optimizing_macro_precision_using_bc, {"init_y_pred": "greedy"},),
@@ -227,7 +227,7 @@ def calculate_and_report_metrics(y_true, y_pred, k, metrics):
 @click.option("-m", "--method", type=str, required=False, default=None)
 @click.option("-p", "--probabilities_path", type=str, required=False, default=None)
 @click.option("-l", "--labels_path", type=str, required=False, default=None)
-@click.option("-r", "--results_dir", type=str, required=False, default="results/")
+@click.option("-r", "--results_dir", type=str, required=False, default="results_bc/")
 @click.option(
     "--recalculate_predictions", is_flag=True, type=bool, required=False, default=False
 )

experiments/utils.py (+5 -3)
@@ -14,9 +14,11 @@
 
 
 def call_function_with_supported_kwargs(func, *args, **kwargs):
-    selected_kwargs = {
-        k: v for k, v in kwargs.items() if k in func.__code__.co_varnames
-    }
+    if hasattr(func, "__signature__"):
+        params = func.__signature__.parameters
+    else:
+        params = func.__code__.co_varnames
+    selected_kwargs = {k: v for k, v in kwargs.items() if k in params}
     return func(*args, **selected_kwargs)
 
 
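The new branch prefers an explicitly attached `__signature__`, which matters for callables whose `__code__.co_varnames` does not list the real parameters (for example a generic wrapper that only takes `*args, **kwargs`). A rough, self-contained sketch of that situation; the `with_signature` decorator and `score` function are illustrative only, not taken from the repository:

import functools
import inspect


def call_function_with_supported_kwargs(func, *args, **kwargs):
    # New behaviour: prefer an explicitly attached __signature__ over co_varnames.
    if hasattr(func, "__signature__"):
        params = func.__signature__.parameters
    else:
        params = func.__code__.co_varnames
    selected_kwargs = {k: v for k, v in kwargs.items() if k in params}
    return func(*args, **selected_kwargs)


def with_signature(func):
    # Generic wrapper: its own __code__ only knows *args/**kwargs, so the old
    # co_varnames-based filter would silently drop every keyword argument.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    wrapper.__signature__ = inspect.signature(func)
    return wrapper


@with_signature
def score(y_true, y_pred, beta=1.0):
    return beta


# `beta` survives the filter thanks to __signature__; `unknown` is dropped.
print(call_function_with_supported_kwargs(score, None, None, beta=2.0, unknown=1))  # -> 2.0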
xcolumns/block_coordinate.py (+26 -20)
@@ -277,12 +277,12 @@ def predict_using_bc_with_0approx(
     binary_metric_func: Union[Callable, List[Callable]],
     k: int,
     metric_aggregation: str = "mean", # "mean" or "sum"
-    maximize=True,
+    maximize: bool = True,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, Matrix] = "random", # "random", "top", "greedy", Matrix
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
-    skip_tn=False,
+    skip_tn: bool = False,
     return_meta: bool = False,
     seed: Optional[int] = None,
     verbose: bool = False,
@@ -313,7 +313,7 @@ def predict_using_bc_with_0approx(
         maximize: Whether to maximize the metric.
         tolerance: Defines the stopping condition, if the expected improvement of the metric is smaller than **tolerance** the algorithm stops.
         init_y_pred: The initial prediction matrix. It can be either "random", "top", "greedy" or a matrix of shape (n, m).
-        max_iter: The maximum number of iterations.
+        max_iters: The maximum number of iterations.
         shuffle_order: Whether to shuffle the order of instances in each iteration.
         skip_tn: Whether to skip the calculation of True Negatives in the confusion matrix, if the metric does not use the True Negatives, this can speed up the calculation, especially when using sparse matrices.
         return_meta: Whether to return the meta information.
@@ -345,7 +345,13 @@ def my_binary_f1_score_on_conf_matrix(tp, fp, fn, tn):
     """
 
     log_info(
-        f"Starting optimization of ETU metric using block coordinate {'ascent' if maximize else 'descent'} algorithm ...",
+        f"Starting optimization of ETU metric using block coordinate {'ascent (maximization)' if maximize else 'descent (minimization)'} algorithm ...",
+        verbose,
+    )
+    if k > 0:
+        log_info(f" Budget k: {k}", verbose)
+    log_info(
+        f" Tolerance (stopping condition): {tolerance}, max iterations: {max_iters}",
         verbose,
     )
 
@@ -378,8 +384,8 @@ def my_binary_f1_score_on_conf_matrix(tp, fp, fn, tn):
     # Initialize the instance order and set seed for shuffling
     rng = np.random.default_rng(seed)
     order = np.arange(n)
-    for j in range(1, max_iter + 1):
-        log_info(f" Starting iteration {j}/{max_iter} ...", verbose)
+    for j in range(1, max_iters + 1):
+        log_info(f" Starting iteration {j}/{max_iters} ...", verbose)
 
         if shuffle_order:
             rng.shuffle(order)
@@ -436,7 +442,7 @@ def my_binary_f1_score_on_conf_matrix(tp, fp, fn, tn):
             meta["utilities"].append(new_utility)
 
         log_info(
-            f" Iteration {j}/{max_iter} finished, expected metric value: {old_utility} -> {new_utility}",
+            f" Iteration {j}/{max_iters} finished, expected metric value: {old_utility} -> {new_utility}",
             verbose,
         )
         if (
@@ -565,7 +571,7 @@ def predict_optimizing_coverage_using_bc(
     init_y_pred: Union[
         str, np.ndarray, csr_matrix
     ] = "random", # "random", "topk", "random", or csr_matrix
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     return_meta: bool = False,
     seed: Optional[int] = None,
@@ -614,8 +620,8 @@ def predict_optimizing_coverage_using_bc(
     # Initialize the instance order and set seed for shuffling
     rng = np.random.default_rng(seed)
     order = np.arange(n)
-    for j in range(1, max_iter + 1):
-        log_info(f" Starting iteration {j}/{max_iter} ...", verbose)
+    for j in range(1, max_iters + 1):
+        log_info(f" Starting iteration {j}/{max_iters} ...", verbose)
 
         if shuffle_order:
             rng.shuffle(order)
@@ -642,7 +648,7 @@ def predict_optimizing_coverage_using_bc(
             meta["utilities"].append(new_cov)
 
         log_info(
-            f" Iteration {j}/{max_iter} finished, expected coverage: {old_cov} -> {new_cov}",
+            f" Iteration {j}/{max_iters} finished, expected coverage: {old_cov} -> {new_cov}",
             verbose,
         )
         if new_cov <= old_cov + tolerance:
@@ -759,7 +765,7 @@ def predict_optimizing_instance_precision_using_bc(
     k: int,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -779,7 +785,7 @@ def instance_precision_with_specific_k(tp, fp, fn, tn):
         metric_aggregation="sum",
         tolerance=tolerance,
         init_y_pred=init_y_pred,
-        max_iter=max_iter,
+        max_iters=max_iters,
         shuffle_order=shuffle_order,
         verbose=verbose,
         return_meta=return_meta,
@@ -792,7 +798,7 @@ def predict_optimizing_mixed_instance_precision_and_macro_precision_using_bc(
     alpha: float = 1,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -817,7 +823,7 @@ def mixed_utility_fn(tp, fp, fn, tn):
         skip_tn=True,
         tolerance=tolerance,
         init_y_pred=init_y_pred,
-        max_iter=max_iter,
+        max_iters=max_iters,
         shuffle_order=shuffle_order,
         verbose=verbose,
         return_meta=return_meta,
@@ -830,7 +836,7 @@ def predict_optimizing_mixed_instance_precision_and_macro_recall_using_bc(
     alpha: float = 1,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -855,7 +861,7 @@ def mixed_utility_fn(tp, fp, fn, tn):
         skip_tn=True,
         tolerance=tolerance,
         init_y_pred=init_y_pred,
-        max_iter=max_iter,
+        max_iters=max_iters,
         shuffle_order=shuffle_order,
         verbose=verbose,
         return_meta=return_meta,
@@ -868,7 +874,7 @@ def predict_optimizing_mixed_instance_precision_and_macro_f1_score_using_bc(
     alpha: float = 1,
     tolerance: float = 1e-6,
     init_y_pred: Union[str, np.ndarray, csr_matrix] = "random",
-    max_iter: int = 100,
+    max_iters: int = 100,
     shuffle_order: bool = True,
     verbose: bool = False,
     return_meta: bool = False,
@@ -893,7 +899,7 @@ def mixed_utility_fn(tp, fp, fn, tn):
         skip_tn=True,
         tolerance=tolerance,
         init_y_pred=init_y_pred,
-        max_iter=max_iter,
+        max_iters=max_iters,
         shuffle_order=shuffle_order,
         verbose=verbose,
         return_meta=return_meta,
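For reference, a hedged usage sketch of `predict_using_bc_with_0approx` with the renamed `max_iters` keyword and a custom confusion-matrix metric, modeled on the `my_binary_f1_score_on_conf_matrix` example named in the docstring above. The toy probability matrix is assumed to be the only positional argument before `binary_metric_func` (its name is not visible in this diff), and the prediction matrix is assumed to be returned when `return_meta=False`:

import numpy as np
from xcolumns.block_coordinate import predict_using_bc_with_0approx


def my_binary_f1_score_on_conf_matrix(tp, fp, fn, tn):
    # Per-label F1 built from expected confusion-matrix entries; tn is unused,
    # which is why skip_tn=True is safe for this metric.
    return 2 * tp / (2 * tp + fp + fn + 1e-9)


y_proba = np.random.default_rng(0).random((100, 20))  # toy marginal probabilities (n x m)

y_pred = predict_using_bc_with_0approx(
    y_proba,  # assumed leading positional argument: predicted marginal probabilities
    binary_metric_func=my_binary_f1_score_on_conf_matrix,
    k=3,
    metric_aggregation="mean",
    max_iters=10,  # renamed keyword (was max_iter)
    skip_tn=True,
    seed=0,
    verbose=True,
)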
