diff --git a/pytorch_influence_functions/calc_influence_function.py b/pytorch_influence_functions/calc_influence_function.py index 8861cd4..5abad63 100644 --- a/pytorch_influence_functions/calc_influence_function.py +++ b/pytorch_influence_functions/calc_influence_function.py @@ -343,8 +343,8 @@ def calc_influence_single(model, train_loader, test_loader, test_id_num, gpu, influences.append(tmp_influence) display_progress("Calc. influence function: ", i, train_dataset_size) - harmful = np.argsort(influences) - helpful = harmful[::-1] + helpful = np.argsort(influences) + harmful = helpful[::-1] return influences, harmful.tolist(), helpful.tolist(), test_id_num