Skip to content

Commit

Permalink
Fixed bugs in accuracy test
Browse files Browse the repository at this point in the history
  • Loading branch information
AnotherSamWilson committed Jul 27, 2024
1 parent 2fde147 commit 60e498f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/test_imputed_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def make_dataset(seed):
random_state = np.random.RandomState(seed)
iris = pd.concat(load_iris(return_X_y=True, as_frame=True), axis=1)
iris["bi"] = random_state.binomial(
1, (iris["target"] == 0).map({True: 0.85, False: 0.15}), size=150
1, (iris["target"] == 0).map({True: 0.9, False: 0.10}), size=150
)
iris["bi"] = iris["bi"].astype("category")
iris["sp"] = iris["target"].map({0: "A", 1: "B", 2: "C"}).astype("category")
Expand All @@ -25,7 +25,7 @@ def make_dataset(seed):
axis=1,
inplace=True,
)
iris_amp = mf.utils.ampute_data(iris, perc=0.20)
iris_amp = mf.utils.ampute_data(iris, perc=0.20, random_state=random_state)

return iris, iris_amp

Expand Down Expand Up @@ -83,7 +83,7 @@ def get_categorical_performance(kernel: mf.ImputationKernel, variables, iris):
rocs[col] = roc_auc_score(orig, preds, multi_class="ovr", average="macro")
accs[col] = (imps == orig).mean()
rand_accs[col] = np.sum(
cand.value_counts(normalize=True) * imps.value_counts(normalize=True)
cand.value_counts(normalize=True) * orig.value_counts(normalize=True)
)
rocs = pd.Series(rocs)
accs = pd.Series(accs)
Expand All @@ -94,7 +94,8 @@ def get_categorical_performance(kernel: mf.ImputationKernel, variables, iris):
def test_defaults():

for i in range(10):
# i = 0
# i = 3
print(i)
iris, iris_amp = make_dataset(i)
kernel_1 = mf.ImputationKernel(
iris_amp,
Expand Down Expand Up @@ -166,7 +167,7 @@ def test_custom_params():
iterations=4,
verbose=False,
boosting="random_forest",
num_iterations=500,
num_iterations=200,
min_data_in_leaf=2,
)
kernel_1.complete_data(0, inplace=True)
Expand Down

0 comments on commit 60e498f

Please sign in to comment.