Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes 256 - ValueError mutable default #263

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from dataclasses import InitVar, dataclass, replace
from dataclasses import InitVar, dataclass, field, replace

import numpy as np
import torch
from datasets import get_dataset_config_info
from transformers import AutoConfig

from ..evaluation import Eval
from ..extraction import Extract
from ..files import memorably_named_dir, sweeps_dir
from ..plotting.visualize import visualize_sweep
from ..training.eigen_reporter import EigenReporterConfig
Expand Down Expand Up @@ -52,12 +51,7 @@ class Sweep:
name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)
run_template: Elicit = field(default_factory=Elicit.default)

def __post_init__(self, add_pooled: bool):
if not self.datasets:
Expand Down
10 changes: 10 additions & 0 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from simple_parsing import subgroups
from simple_parsing.helpers.serialization import save

from ..extraction import Extract
from ..metrics import evaluate_preds, to_one_hot
from ..run import Run
from ..training.supervised import train_supervised
Expand All @@ -34,6 +35,15 @@ class Elicit(Run):
cross-validation. Defaults to "single", which means to train a single classifier
on the training data. "cv" means to use cross-validation."""

@staticmethod
def default():
return Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)

def create_models_dir(self, out_dir: Path):
lr_dir = None
lr_dir = out_dir / "lr_models"
Expand Down