Refactor training #158

Merged
merged 5 commits on May 22, 2024
Changes from 1 commit
update docs
jbloom-md committed May 22, 2024
commit 2282f5b8eff00fb84476b6907f8ea0fc2c7e8944
Empty file removed docs/about/reference.md
Empty file.
File renamed without changes.
File renamed without changes.
72 changes: 17 additions & 55 deletions docs/index.md
@@ -38,86 +38,48 @@ sparse_autoencoder = SparseAutoencoder.from_pretrained(
    "gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre"
)
```
Currently, only `gpt2-small-res-jb` SAEs for the gpt2-small residual-stream are available via this method, but more SAEs will be added soon!

#### Loading SAEs, ActivationsStore, and Sparsity from Huggingface
You can see other importable SAEs in `sae_lens/pretrained_saes.yaml`.

For more advanced use-cases like fine-tuning a pre-trained SAE, [previously trained sparse autoencoders](https://huggingface.co/jbloom/GPT2-Small-SAEs) can be loaded from Hugging Face with close to a single line of code. For more details and performance metrics for these sparse autoencoders, read my [blog post](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream).
(We'd accept a PR that converts this yaml to a nice table in the docs!)
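If you just want a quick look at what is registered there, a minimal sketch along these lines may help. It assumes a local checkout of the repo with PyYAML installed, and that the YAML's top-level keys are release names; check the file itself if the layout differs in your version.

```python
# List the SAE releases registered in sae_lens/pretrained_saes.yaml.
# The YAML layout may change between versions, so we only print the
# top-level release names rather than relying on nested fields.
import yaml

with open("sae_lens/pretrained_saes.yaml") as f:
    pretrained_saes = yaml.safe_load(f)

for release_name in pretrained_saes:
    print(release_name)
```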

#### Loading SAEs, ActivationsStore, and Models from HuggingFace

For more advanced use-cases like fine-tuning a pre-trained SAE, [previously trained sparse autoencoders](https://huggingface.co/jbloom/GPT2-Small-SAEs-Reformatted) can be loaded from Hugging Face with close to a single line of code. For more details and performance metrics for these sparse autoencoders, read my [blog post](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream).

```python
import torch

# picking up from the above chunk.
from sae_lens import LMSparseAutoencoderSessionloader
from huggingface_hub import hf_hub_download

layer = 8 # pick a layer you want.
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
model, _, activation_store = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg).load_sae_training_group_session(
path = path
)
sparse_autoencoder.eval()
```

You can also load the feature sparsity from huggingface.

```python
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576_log_feature_sparsity.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
log_feature_sparsity = torch.load(path, map_location=sparse_autoencoder.cfg.device)

```
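As a rough illustration of how the loaded sparsity tensor might be used, here is a short sketch. The reading of the values as log10 firing frequencies and the -5 cutoff are assumptions made for the example, not claims from the original docs.

```python
# log_feature_sparsity has one entry per SAE feature; very negative values indicate
# features that almost never fire. The -5 threshold (roughly "fires on fewer than
# 1 in 100,000 tokens" if values are log10 frequencies) is an illustrative choice.
rarely_firing = (log_feature_sparsity < -5).sum().item()
print(f"{rarely_firing} of {log_feature_sparsity.numel()} features fire very rarely")
```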
### Background

We highly recommend this [tutorial](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab).



## Code Overview

The codebase contains 2 folders worth caring about:

- training: The main body of the code is here. Everything required for training SAEs.
- analysis: This code mainly houses the feature visualizer code we use to generate dashboards. It was written by Callum McDougall, but I've ported it here with permission and edited it to work with a few different activation types.

Some other folders:

- tutorials: These aren't well maintained but I'll aim to clean them up soon.
- tests: When first developing the codebase, I was writing more tests. I have no idea whether they are currently working!


## Loading a Pretrained Language Model

Once your SAE is trained, the final SAE weights will be saved to wandb and are loadable via the session loader. The session loader will return:
- The model your SAE was trained on (presumably you're interested in studying this; it's always a HookedTransformer).
- Your SAE.
- An activations loader, from which you can get randomly sampled activations or batches of tokens from the dataset you used to train the SAE (more on this in the tutorial).

```python
from sae_lens import LMSparseAutoencoderSessionloader

path = "path/to/sparse_autoencoder.pt"
model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

```
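As a quick sanity check that the returned pieces fit together, something like the following sketch can be used. The prompt is arbitrary, and the exact return signature of the SAE's forward pass may differ between versions, so treat the last line as illustrative rather than definitive.

```python
# Run the model on a short prompt and grab activations at the SAE's hook point.
prompt = "The quick brown fox jumps over the lazy dog."
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)

acts = cache[sparse_autoencoder.cfg.hook_point]
assert acts.shape[-1] == sparse_autoencoder.cfg.d_in  # activations match the SAE input width

# Passing these activations through the SAE yields reconstructions / feature activations;
# check your installed version for exactly what sparse_autoencoder(acts) returns.
sae_output = sparse_autoencoder(acts)
```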
## Tutorials

I've written some tutorials to show users how to do basic exploration of their SAEs:

- `evaluating_your_sae.ipynb`: A quick-and-dirty notebook showing how to check L0 and prediction loss with your SAE, as well as how to generate interactive dashboards using Callum's reproduction of [Anthropic's interface](https://transformer-circuits.pub/2023/monosemantic-features#setup-interface).
- `logits_lens_with_features.ipynb`: A notebook showing how to reproduce the analysis from this [LessWrong post](https://www.lesswrong.com/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens).
- [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
- [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)


## Example Dashboard
## Example WandB Training Dashboard

WandB dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.

![screenshot](dashboard_screenshot.png)




## Citations and References:

Research:
7 changes: 0 additions & 7 deletions docs/installation.md

This file was deleted.

75 changes: 4 additions & 71 deletions docs/training_saes.md
@@ -1,75 +1,8 @@
# Training Sparse Autoencoders

Sparse autoencoders can be intimidating at first, but it's fairly simple to train one once you know what each part of the config does. I've created a config class which you instantiate and pass to the runner, which will complete your training run and log its progress to wandb.
Methods development for training SAEs is rapidly evolving, so we're not attempting to maintain detailed, up-to-date documentation.

Let's go through the major components of the config:
However, we are attempting to maintain this [tutorial](tutorials/training_a_sparse_autoencoder.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb).

- Data: SAEs autoencode model activations. We need to specify the model, the part of the model's activations we want to autoencode, and the dataset the model is operating on when generating those activations. We now automatically detect whether that dataset is tokenized, and most Huggingface datasets should be fine. One slightly annoying detail is that you need to know the dimensionality of those activations when constructing your SAE, but you can get that from the TransformerLens [docs](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html). Any language model in the table from those docs should work.
- SAE Parameters: Your expansion factor will determine the size of your SAE, and the decoder bias initialization method should always be geometric_median or mean. Mean is faster but theoretically sub-optimal. I use another package to get the geometric median and it can be quite slow.
- Training Parameters: These are the most critical. The right L1 coefficient (the coefficient on the activation-sparsity-inducing term in the loss) changes with your learning rate, but a good bet would be to use LR 4e-4 and L1 8e-5 for GPT2 small. These will vary for other models, and playing around with them / short runs can be helpful. A training batch size of 4096 is standard and I'm not really sure whether there's benefit to playing with it. In theory a larger context size (one accurate to whatever the model was trained with) seems good, but it's computationally cheaper to use 128. Learning rate warm-up is important to avoid dead neurons.
- Activation Store Parameters: The activation store shuffles activations from forward passes over samples from your data. The larger it is, the better shuffling you'll get. In theory more shuffling is good. The total training tokens is a very important parameter: the more the better, but you'll often see good results having trained on a few hundred million tokens. Store batch size is a function of your GPU and how many forward passes of your model you want to do simultaneously when collecting activations.
- Dead Neurons / Sparsity Metrics: The config around resampling was more important when we were using resampling to avoid dead neurons (see Anthropic's post on this), but with ghost gradients the resampling protocol is much simpler. I'd always set ghost grad to True and feature sampling method to None. The feature sampling window affects the dashboard statistics tracking feature occurrence, and the dead feature window tracks how many forward passes a neuron must not activate before we apply ghost grads to it.
- WANDB: Fairly straightforward. Don't set the log frequency too high or your dashboard will be slow!
- Device: I can run this code on my MacBook with "mps" but mostly do runs with cuda.
- Dtype: Float16 might work, but I had some funky results and have left it at float32 for the time being.
- Checkpoints: I'd collect checkpoints on runs you care about, but turn them off when tuning since it can be slow.


```python
import torch
import os
import sys

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_lens import LanguageModelSAERunnerConfig, language_model_sae_runner

# the config below expects a `device` variable to be defined
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: Refer to training tutorials for updated parameter configurations.
# Tutorial notebook: https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb
# PRs to update docs welcome
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_point="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_point_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the MSE loss.
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="geometric_median",  # The geometric median can be used to initialize the decoder weights.
    # Training Parameters
    lr=0.0008,  # the lower the better; we'll go fairly high to speed up the tutorial.
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=10000,  # this can help avoid too many dead features initially.
    l1_coefficient=0.001,  # will control how sparse the feature activations are
    lp_norm=1.0,  # the L1 penalty (and not an Lp for p < 1)
    train_batch_size=4096,
    context_size=512,  # will control the length of the prompts we feed to the model. Larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=1_000_000 * 50,  # 50 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size=16,
    # Resampling protocol
    use_ghost_grads=False,
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=10,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)
```
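For reference, the refactor in this PR renames the runner entry point. A minimal sketch reusing the `cfg` above, with the import path and class name taken from the updated tutorial notebook further down in this diff:

```python
# New-style entry point after the refactor (see the notebook diff below);
# assumes `cfg` is the LanguageModelSAERunnerConfig built above.
from sae_lens.training.lm_runner import SAETrainingRunner

sparse_autoencoder = SAETrainingRunner(cfg).run()
```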
We encourage readers to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-1qosyh8g3-9bF3gamhLNJiqCL_QqLFrA) for support!
5 changes: 2 additions & 3 deletions mkdocs.yml
@@ -47,14 +47,13 @@ nav:
  - Roadmap: roadmap.md
  - Installation: installation.md
  - Training SAEs: training_saes.md
  - Citation: citation.md
  - Contributing: contributing.md
  # - Analysis: usage/examples.md
  - Reference:
      - Language Models: reference/language_models.md
      - Toy Models: reference/toy_models.md
      - Misc: reference/misc.md
  - About:
      - Citation: about/citation.md
      - Contributing: about/contributing.md

plugins:
  - search
6 changes: 3 additions & 3 deletions tutorials/training_a_sparse_autoencoder.ipynb
@@ -63,7 +63,7 @@
"import os\n",
"\n",
"from sae_lens.training.config import LanguageModelSAERunnerConfig\n",
"from sae_lens.training.lm_runner import language_model_sae_runner\n",
"from sae_lens.training.lm_runner import SAETrainingRunner\n",
"\n",
"if torch.cuda.is_available():\n",
" device = \"cuda\"\n",
@@ -336,7 +336,7 @@
" scale_sparsity_penalty_by_decoder_norm=True,\n",
" decoder_heuristic_init=True,\n",
" init_encoder_as_decoder_transpose=True,\n",
" normalize_activations=False,\n",
" normalize_activations=True,\n",
" # Training Parameters\n",
" lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n",
" adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n",
@@ -372,7 +372,7 @@
")\n",
"\n",
"# look at the next cell to see some instruction for what to do while this is running.\n",
"sparse_autoencoder = language_model_sae_runner(cfg)"
"sparse_autoencoder = SAETrainingRunner(cfg).run()"
]
},
{