Skip to content

Commit 670eaec

Browse files
Fixes 256 - ValueError mutable default (#263)
* Fixes 256 - ValueError mutable default Fixes issue #256. Error message: > ValueError: mutable default <class 'elk.training.train.Elicit'> for field run_template is not allowed: use default_factory * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * To lowercase --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 14669b1 commit 670eaec

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

elk/training/sweep.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from dataclasses import InitVar, dataclass, replace
1+
from dataclasses import InitVar, dataclass, field, replace
22

33
import numpy as np
44
import torch
55
from datasets import get_dataset_config_info
66
from transformers import AutoConfig
77

88
from ..evaluation import Eval
9-
from ..extraction import Extract
109
from ..files import memorably_named_dir, sweeps_dir
1110
from ..plotting.visualize import visualize_sweep
1211
from ..training.eigen_reporter import EigenFitterConfig
@@ -53,12 +52,7 @@ class Sweep:
5352
name: str | None = None
5453

5554
# A bit of a hack to add all the command line arguments from Elicit
56-
run_template: Elicit = Elicit(
57-
data=Extract(
58-
model="<placeholder>",
59-
datasets=("<placeholder>",),
60-
)
61-
)
55+
run_template: Elicit = field(default_factory=Elicit.default)
6256

6357
def __post_init__(self, add_pooled: bool):
6458
if not self.datasets:

elk/training/train.py

+10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from simple_parsing import subgroups
1212
from simple_parsing.helpers.serialization import save
1313

14+
from ..extraction import Extract
1415
from ..metrics import evaluate_preds, to_one_hot
1516
from ..run import Run
1617
from ..training.supervised import train_supervised
@@ -34,6 +35,15 @@ class Elicit(Run):
3435
cross-validation. Defaults to "single", which means to train a single classifier
3536
on the training data. "cv" means to use cross-validation."""
3637

38+
@staticmethod
39+
def default():
40+
return Elicit(
41+
data=Extract(
42+
model="<placeholder>",
43+
datasets=("<placeholder>",),
44+
)
45+
)
46+
3747
def create_models_dir(self, out_dir: Path):
3848
lr_dir = None
3949
lr_dir = out_dir / "lr_models"

0 commit comments

Comments
 (0)