-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsteer_stories.py
150 lines (99 loc) · 3.64 KB
/
steer_stories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
torch.set_grad_enabled(False);
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float
# Select the compute device: prefer Apple MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")
from transformer_lens import HookedTransformer
from sae_lens import SAE
# Layer whose SAE we load and whose activations we steer.
# NOTE(review): "1L" in the model name suggests a single-block model, in which
# case 0 is the only valid layer -- confirm before changing this.
layer = 0
# Load the pretrained TinyStories model onto the selected device.
model = HookedTransformer.from_pretrained("tiny-stories-1L-21M", device=device)
# Load the SAE for this layer. The path is built from `layer` so the two cannot
# drift apart; for layer 0 it is exactly
# "sae_tiny-stories-1L-21M_blocks.0.hook_mlp_out_16384".
# NOTE(review): presumably a local directory next to this script -- confirm.
sae = SAE.load_from_pretrained(
    f"sae_tiny-stories-1L-21M_blocks.{layer}.hook_mlp_out_16384", device=device
)
# Name of the activation the SAE reads, taken from the SAE's own config so that
# cache lookups below always use the matching hook.
hook_point = sae.cfg.hook_name
print(hook_point)
# Probe which SAE features fire on the prompt " Lily"; the printed top features
# are candidates for the steering vector chosen below. (Previously this was
# done twice; one forward pass is enough.)
sv_prompt = " Lily"
sv_logits, activationCache = model.run_with_cache(sv_prompt, prepend_bos=True)
tokens = model.to_tokens(sv_prompt)
print(tokens)
# Encode the cached activations at the SAE's hook point into feature activations.
sv_feature_acts = sae.encode(activationCache[hook_point])
# Reconstruct from the features; steering_hook uses sae_out's sequence length.
sae_out = sae.decode(sv_feature_acts)
# Top-3 activations along the last dimension (presumably the SAE feature axis,
# giving the top features per token position -- confirm). Computed once and
# printed both in full and as plain index lists.
sv_top_features = torch.topk(sv_feature_acts, 3)
print(sv_top_features)
print(sv_top_features.indices.tolist())
# Steering configuration. The row index into sae.W_dec selects the SAE feature
# whose decoder direction is used for steering; pick it from the top features
# printed above (10284 is the author's current choice).
steering_vector = sae.W_dec[10284]
example_prompt = "Once upon a time"
# Scale applied to the steering vector before it is added to the activations.
coeff = 1000
# Extra sampling options forwarded to model.generate via hooked_generate.
sampling_kwargs = {"temperature": 1.0, "top_p": 0.1, "freq_penalty": 1.0}
# apply steering vector when the model generates
def steering_hook(resid_pre, hook):
    """Add ``coeff * steering_vector`` to the hooked activation, in place.

    Used as a forward hook by run_generate. It edits ``resid_pre`` in place and
    returns None. A stray ``breakpoint()`` that halted every steered forward
    pass has been removed.

    Args:
        resid_pre: activation tensor at the hook point, indexed here as
            (batch, position, d_model) -- assumes that layout, TODO confirm.
            Despite the name, run_generate registers this hook on
            ``hook_resid_post``.
        hook: the hook point object passed by the framework (unused).

    Reads the module-level globals ``sae_out``, ``steering_on``, ``coeff`` and
    ``steering_vector``.
    """
    # Skip forward passes over a single position (presumably the token-by-token
    # decoding steps of model.generate), so steering only touches the prompt.
    if resid_pre.shape[1] == 1:
        return

    # Number of positions in the " Lily" probe; the first (position - 1)
    # positions of the current sequence receive the steering vector.
    position = sae_out.shape[1]
    if steering_on:
        # using our steering vector and applying the coefficient
        resid_pre[:, :position - 1, :] += coeff * steering_vector
def hooked_generate(prompt_batch, fwd_hooks=None, seed=None, *, max_new_tokens=50, **kwargs):
    """Sample continuations for ``prompt_batch`` with ``fwd_hooks`` attached.

    Args:
        prompt_batch: prompt string(s), tokenized with ``model.to_tokens``.
        fwd_hooks: list of ``(hook_name, hook_fn)`` pairs that are active only
            for the duration of this call. ``None`` (the default) means no
            hooks; a ``None`` default replaces the former shared ``[]`` default.
        seed: if not None, passed to ``torch.manual_seed`` before sampling.
        max_new_tokens: number of tokens to generate (default 50, as before).
        **kwargs: extra sampling options forwarded to ``model.generate``.

    Returns:
        The result of ``model.generate`` for the tokenized batch.
    """
    if fwd_hooks is None:
        fwd_hooks = []
    if seed is not None:
        torch.manual_seed(seed)
    # Hooks are attached only inside this context and removed on exit.
    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=False,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **kwargs,
        )
    return result
def run_generate(example_prompt):
    """Generate three steered continuations of ``example_prompt`` and print them.

    Any hooks left on ``model`` are cleared first. ``steering_hook`` is then
    attached at ``blocks.{layer}.hook_resid_post`` and ``sampling_kwargs`` is
    used for sampling. The first token column (the beginning-of-sequence token)
    is dropped before decoding. The texts are printed separated by a dashed rule.
    """
    model.reset_hooks()
    hook_name = f"blocks.{layer}.hook_resid_post"
    batch = [example_prompt for _ in range(3)]
    generated = hooked_generate(
        batch, [(hook_name, steering_hook)], seed=None, **sampling_kwargs
    )
    texts = model.to_string(generated[:, 1:])
    rule = "\n\n" + "-" * 80 + "\n\n"
    print(rule.join(texts))
# Global flag read by steering_hook; it must be set before generation runs.
steering_on = True
# Sample continuations of example_prompt with steering_hook attached.
run_generate(example_prompt)
# evaluate features
import pandas as pd

# Project every SAE decoder direction onto the unembedding: one row of
# vocabulary logits per feature.
# NOTE(review): this materializes the full features x vocab matrix at once,
# which may be large for a 16384-feature SAE -- consider chunking if memory
# is tight.
projection_onto_unembed = sae.W_dec @ model.W_U
# For each feature, the 10 vocabulary entries it promotes most strongly.
vals, inds = torch.topk(projection_onto_unembed, 10, dim=1)
# Sample 10 random features to inspect (unseeded, so this differs per run).
random_indices = torch.randint(0, projection_onto_unembed.shape[0], (10,))
# One column per sampled feature (labelled by feature index), one row per rank.
top_10_logits_df = pd.DataFrame(
    [model.to_str_tokens(i) for i in inds[random_indices]],
    index=random_indices.tolist(),
).T
# A bare expression only renders in a notebook; print so scripts show it too.
print(top_10_logits_df)
# Candidate features noted by the author: [7195, 5910, 2041]
# Inspect the tokens most promoted by feature 5910.
# NOTE(review): the original comment here said feature 7195 should show
# "Golden", but the code inspects 5910 -- confirm which feature is intended.
top_10_associated_words_logits_df = model.to_str_tokens(inds[5910])
print(top_10_associated_words_logits_df)