forked from jbloomAus/SAELens
-
Notifications
You must be signed in to change notification settings - Fork 0
/
toy_model_runner.py
64 lines (52 loc) · 1.74 KB
/
toy_model_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Any, cast
import einops
import torch
import wandb
from sae_lens.config import ToyModelSAERunnerConfig
from sae_lens.sae import SAE
from sae_lens.training.toy_models import ReluOutputModel as ToyModel
from sae_lens.training.toy_models import ToyConfig
from sae_lens.training.train_toy_sae import train_toy_sae
def toy_model_sae_runner(cfg: ToyModelSAERunnerConfig):
    """Train a sparse autoencoder (SAE) on a toy model's hidden activations.

    The runner (1) builds and optimizes a small ReLU-output toy model,
    (2) projects a batch of generated feature vectors into the model's
    hidden space, and (3) trains an SAE on those hidden activations,
    optionally logging to Weights & Biases.

    Args:
        cfg: Runner configuration bundling the toy-model hyperparameters,
            the SAE hyperparameters, and the wandb logging settings.

    Returns:
        The trained SAE returned by ``train_toy_sae``.
    """
    # Translate the runner config into the toy model's own config object.
    toy_cfg = ToyConfig(
        n_features=cfg.n_features,
        n_hidden=cfg.n_hidden,
        n_correlated_pairs=cfg.n_correlated_pairs,
        n_anticorrelated_pairs=cfg.n_anticorrelated_pairs,
        feature_probability=cfg.feature_probability,
    )

    # Build the toy model on the requested device and optimize it first,
    # so the SAE trains on a converged model's representations.
    toy_model = ToyModel(cfg=toy_cfg, device=torch.device(cfg.device))
    toy_model.optimize(steps=cfg.model_training_steps)

    # Generate sparse feature vectors and map them through the toy model's
    # weight matrix W to obtain the hidden activations the SAE will see.
    features = toy_model.generate_batch(cfg.total_training_tokens)
    hidden = einops.einsum(
        features,
        toy_model.W,
        "batch_size features, hidden features -> batch_size hidden",
    )

    # The runner config carries the SAE hyperparameters.
    sae = SAE(**cfg.get_base_sae_cfg_dict())

    if cfg.log_to_wandb:
        wandb.init(project=cfg.wandb_project, config=cast(Any, cfg))

    # Train on detached activations — gradients must not flow back into
    # the toy model during SAE training.
    sae = train_toy_sae(
        sae,
        activation_store=hidden.detach().squeeze(),
        batch_size=cfg.train_batch_size,
        feature_sampling_window=cfg.feature_sampling_window,
        dead_feature_threshold=cfg.dead_feature_threshold,
        use_wandb=cfg.log_to_wandb,
        wandb_log_frequency=cfg.wandb_log_frequency,
    )

    if cfg.log_to_wandb:
        wandb.finish()

    return sae