Skip to content

Commit c04a86d

Browse files
authored
Merge pull request #18 from EleutherAI/interventions
import classifier at module level
2 parents a9fb28f + 53bad05 commit c04a86d

File tree

8 files changed

+23
-18
lines changed

8 files changed

+23
-18
lines changed

elk_generalization/elk/ccs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
import torch
99
import torch.nn as nn
10-
from classifier import Classifier
1110
from concept_erasure import LeaceFitter
1211
from einops import repeat
1312
from torch import Tensor, optim
1413
from typing_extensions import override
1514

1615
from elk_generalization.elk.burns_norm import BurnsNorm
1716
from elk_generalization.elk.ccs_losses import LOSSES, parse_loss
17+
from elk_generalization.elk.classifier import Classifier
1818

1919

2020
@dataclass

elk_generalization/elk/crc.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
22
import torch.nn.functional as F
3-
from classifier import Classifier
43
from concept_erasure import LeaceEraser
54
from torch import Tensor, nn, optim
65

6+
from elk_generalization.elk.classifier import Classifier
7+
78

89
class CrcReporter(Classifier):
910
def __init__(self, in_features: int, device: torch.device, dtype: torch.dtype):

elk_generalization/elk/lda.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
2-
from classifier import Classifier
32
from concept_erasure.shrinkage import optimal_linear_shrinkage
43
from sklearn.metrics import accuracy_score, roc_auc_score
54
from torch import Tensor, nn
65

6+
from elk_generalization.elk.classifier import Classifier
7+
78

89
class LdaReporter(Classifier):
910
def __init__(self, in_features: int, device: torch.device, dtype: torch.dtype):

elk_generalization/elk/lr_classifier.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
2-
from classifier import Classifier
32
from torch import Tensor
43
from torch.nn.functional import binary_cross_entropy_with_logits as bce_with_logits
54
from torch.nn.functional import cross_entropy
65

6+
from elk_generalization.elk.classifier import Classifier
7+
78

89
class LogisticRegression(Classifier):
910
"""Linear classifier trained with supervised learning."""

elk_generalization/elk/mean_diff.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import torch
2-
from classifier import Classifier
32
from sklearn.metrics import accuracy_score, roc_auc_score
43
from torch import Tensor, nn
54

5+
from elk_generalization.elk.classifier import Classifier
6+
67

78
class MeanDiffReporter(Classifier):
89
def __init__(self, in_features: int, device: torch.device, dtype: torch.dtype):

elk_generalization/elk/run_transfers.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
get_ceiling_latent_knowledge = False
4747

4848
# code to modify models and datasets based on rank
49+
models = models[args.rank :: 8]
4950
print(ds_names, models)
5051

5152

@@ -61,14 +62,14 @@ def unpack_abbrev(ds_name, abbrev):
6162
exps = {k: ["B->B", "BE->B,BH"] for k in ["lr", "mean-diff", "lda"]}
6263
else:
6364
exps = {
64-
"lr": ["A->A,B,AH,BH", "B->B,A,BH", "B->BH", "AE->AE,AH,BE,BH"],
65-
"mean-diff": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH"],
66-
"lda": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH"],
67-
"lr-on-pair": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH"],
65+
# "lr": ["A->A,B,AH,BH", "B->B,A,BH", "B->BH", "AE->AE,AH,BE,BH"],
66+
# "mean-diff": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH"],
67+
# "lda": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH"],
68+
# "lr-on-pair": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH"],
6869
"mean-diff-on-pair": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH"],
69-
"ccs": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH", "all->all,BH"],
70-
"crc": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH", "all->all,BH"],
71-
"random": ["AE->AE,BH"],
70+
# "ccs": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH", "all->all,BH"],
71+
# "crc": ["A->A,B,AH,BH", "B->B,A", "AE->AE,AH,BH", "all->all,BH"],
72+
# "random": ["AE->AE,BH"],
7273
}
7374

7475
experiments_dir = "../../experiments"

elk_generalization/elk/transfer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
from pathlib import Path
33

44
import torch
5-
from classifier import Classifier
65
from sklearn.metrics import accuracy_score, roc_auc_score
76
from tqdm import tqdm
87

98
from elk_generalization.elk.ccs import CcsConfig, CcsReporter
9+
from elk_generalization.elk.classifier import Classifier
1010
from elk_generalization.elk.crc import CrcReporter
1111
from elk_generalization.elk.lda import LdaReporter
1212
from elk_generalization.elk.lr_classifier import LogisticRegression

elk_generalization/results/figures.ipynb

+5-5
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@
10451045
"anomaly_ds_names = ds_names.copy()\n",
10461046
"anomaly_ds_names.remove(\"population\") # population only has false labels on H\n",
10471047
"\n",
1048-
"subtract_diag = False\n",
1048+
"subtract_diag = True\n",
10491049
"root = \"../../anomaly-results/\"\n",
10501050
"caption = \"Mechanistic anomaly detection AUROC. Note the Population dataset is omitted because the easy subset only contains true labels.\"\n",
10511051
"if subtract_diag:\n",
@@ -1189,15 +1189,15 @@
11891189
"model_sizes = {'Pythia': [0.41, 1, 1.4, 2.8, 6.9, 12],\n",
11901190
" 'Llama': [7],\n",
11911191
" 'Mistral': [7]}\n",
1192-
"pgr_values = {'Pythia': [0.48, 0.46, 0.56, 0.56, 0.55, 0.57],\n",
1193-
" 'Llama': [0.61],\n",
1194-
" 'Mistral': [0.61]}\n",
1192+
"pgr_values = {'Pythia': [0.47, 0.47, 0.55, 0.56, 0.56, 0.57],\n",
1193+
" 'Llama': [0.62],\n",
1194+
" 'Mistral': [0.63]}\n",
11951195
"\n",
11961196
"\n",
11971197
"plt.figure(figsize=(4.3, 3), dpi=150)\n",
11981198
"plt.plot(model_sizes['Pythia'], pgr_values['Pythia'], label=\"Pythia\", color=colors[0], marker=\"o\")\n",
11991199
"plt.plot(model_sizes['Llama'], pgr_values['Llama'], label=\"Llama\", color=colors[1], marker=\"*\", markersize=10)\n",
1200-
"plt.plot(model_sizes['Mistral'], pgr_values['Mistral'], label=\"Mistral\", color=colors[2], marker=\"x\", markersize=10)\n",
1200+
"plt.plot(model_sizes['Mistral'], pgr_values['Mistral'], label=\"Mistral\", color=colors[2], marker=\"x\", markersize=8)\n",
12011201
"plt.xlabel(\"model size\")\n",
12021202
"plt.ylabel(\"PGR\")\n",
12031203
"plt.legend()\n",

0 commit comments

Comments
 (0)