BREAKING CHANGE: Quality of Life Refactor of SAE Lens adding SAE Analysis with HookedSAETransformer and some other breaking changes. (jbloomAus#162)

* move HookedSAETransformer from TL

* add tests

* move runners one level up

* fix docs name

* trainer clean up

* create training sae, not fully separate yet

* remove accidentally committed notebook

* commit working code in the middle of refactor, more work to do

* don't use act layers plural

* make tutorial not use the activation store

* moved this file

* move import of toy model runner

* saes need to store at least enough information to run them

* further refactor and add tests

* finish act store device rebase

* fix config type not caught by test

* partial progress, not yet handling error term for hooked sae transformer

* bring tests in line with trainer doing more work

* revert some of the simplification to preserve various features, ghost grads, noising

* hooked sae transformer is working

* homogenize configs

* re-enable sae compilation

* remove old file that doesn't belong

* include normalize activations in base sae config

* make sure tutorial works

* don't forget to update pbar

* rename sparse autoencoder to sae for brevity

* move non-training specific modules out of training

* rename to remove _point

* first steps towards better docs

* final cleanup

* have ci use same test coverage total as make check-ci

* clean up docs a bit

---------

Co-authored-by: ckkissane <[email protected]>
jbloomAus and ckkissane committed May 28, 2024
1 parent eb9489a commit e4eaccc
Showing 88 changed files with 5,831 additions and 8,947 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -70,7 +70,7 @@ jobs:
run: poetry run pyright
- name: Run Unit Tests
# Would use make, but want cov report in xml format
run: poetry run pytest -v --cov=sae_lens/training/ --cov-report=term-missing --cov-branch tests/unit --cov-report=xml
run: poetry run pytest -v --cov=sae_lens/ --cov-report=term-missing --cov-branch tests/unit --cov-report=xml
- name: Upload coverage reports to Codecov
uses: codecov/[email protected]
with:
3 changes: 3 additions & 0 deletions docs/api.md
@@ -0,0 +1,3 @@
# API

::: sae_lens
5 changes: 4 additions & 1 deletion docs/contributing.md
@@ -7,7 +7,10 @@ Contributions are welcome! To get setup for development, follow the instructions
Make sure you have [poetry](https://python-poetry.org/) installed, clone the repository, and install dependencies with:

```bash
poetry install
git clone https://github.com/jbloomAus/SAELens.git # we recommend you make a fork for submitting PR's and clone that!
poetry lock # can take a while.
poetry install
make check-ci # validate the install
```

## Testing, Linting, and Formatting
65 changes: 22 additions & 43 deletions docs/index.md
@@ -25,71 +25,50 @@ pip install sae-lens

### Loading Sparse Autoencoders from Huggingface


#### Loading officially supported SAEs

To load an officially supported sparse autoencoder, you can use `SparseAutoencoder.from_pretrained()` as below:
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the *original cfg dict* from the huggingface repo, so that it's easy to debug older configs that have to be handled when we import an SAE. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).

```python
from sae_lens import SparseAutoencoder
from sae_lens import SAE

layer = 8 # pick a layer you want.
sparse_autoencoder = SparseAutoencoder.from_pretrained(
"gpt2-small-res-jb", f"blocks.{layer}.hook_resid_pre"
sae, cfg_dict, sparsity = SAE.from_pretrained(
release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
device = device
)
```

You can see other importable SAEs in `sae_lens/pretrained_saes.yaml`.

(We'd accept a PR that converts this yaml to a nice table in the docs!)
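
Since the headline feature of this commit is SAE analysis with `HookedSAETransformer`, a minimal sketch of attaching the loaded SAE to a model is shown below. This is illustrative only: it assumes the `run_with_cache_with_saes` interface carried over from TransformerLens, that the SAE exposes its hook location as `sae.cfg.hook_name`, and that feature activations land in the cache under a `hook_sae_acts_post` suffix; the model name and prompt are arbitrary.

```python
import torch
from sae_lens import HookedSAETransformer

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model that the gpt2-small-res-jb SAEs were trained on
model = HookedSAETransformer.from_pretrained("gpt2", device=device)

# run the model with the SAE spliced in at its hook point, caching activations
prompt = "When Mary and John went to the store, John gave a drink to"
logits, cache = model.run_with_cache_with_saes(prompt, saes=[sae])

# SAE feature activations should now sit in the cache next to model activations
feature_acts = cache[f"{sae.cfg.hook_name}.hook_sae_acts_post"]
print(feature_acts.shape)  # expected: [batch, seq_len, d_sae]
```

If the method or hook names differ in a given release, the docstrings in `sae_lens/analysis/hooked_sae_transformer.py` are the source of truth.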

#### Loading SAEs, ActivationsStore and Models from HuggingFace.

For more advanced use-cases like fine-tuning a pre-trained SAE, [previously trained sparse autoencoders](https://huggingface.co/jbloom/GPT2-Small-SAEs-Reformatted) can be loaded from huggingface with close to a single line of code. For more details and performance metrics for these sparse autoencoders, read my [blog post](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream).

```python

# picking up from the above chunk.
from sae_lens import LMSparseAutoencoderSessionloader

model, _, activation_store = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg).load_sae_training_group_session(
path = path
)
sparse_autoencoder.eval()
```

### Background
### Background and Further Readings

We highly recommend this [tutorial](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab).

For recent progress in SAEs, we recommend the LessWrong forum's [Sparse Autoencoder tag](https://www.lesswrong.com/tag/sparse-autoencoders-saes)

## Tutorials

I wrote a tutorial to show users how to do some basic exploration of their SAE:

- [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
- [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
- Loading and Analysing Pre-Trained Sparse Autoencoders [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- Understanding SAE Features with the Logit Lens [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
- Training a Sparse Autoencoder [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)


## Example WandB Training Dashboard
## Example WandB Dashboard

WandB Dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.

![screenshot](dashboard_screenshot.png)

## Citations and References:

Research:
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features)
- [Sparse Autoencoders Find Highly Interpretable Features in Language Models](https://arxiv.org/abs/2309.08600)
## Citation



Reference Implementations:
- [Neel Nanda](https://github.com/neelnanda-io/1L-Sparse-Autoencoder)
- [AI-Safety-Foundation](https://github.com/ai-safety-foundation/sparse_autoencoder).
- [Arthur Conmy](https://github.com/ArthurConmy/sae).
- [Callum McDougall](https://github.com/callummcdougall/sae-exercises-mats/tree/main)
```
@misc{bloom2024saetrainingcodebase,
    title = {SAELens Training},
    author = {Joseph Bloom, David Channin},
    year = {2024},
    howpublished = {\url{https://github.com/jbloomAus/SAELens}},
}
```
9 changes: 0 additions & 9 deletions docs/reference/language_models.md

This file was deleted.

7 changes: 0 additions & 7 deletions docs/reference/misc.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/reference/runners.md

This file was deleted.

6 changes: 0 additions & 6 deletions docs/reference/toy_models.md

This file was deleted.

7 changes: 5 additions & 2 deletions makefile
@@ -15,15 +15,18 @@ test:
make acceptance-test

unit-test:
poetry run pytest -v --cov=sae_lens/training/ --cov-report=term-missing --cov-branch tests/unit
poetry run pytest -v --cov=sae_lens/ --cov-report=term-missing --cov-branch tests/unit

acceptance-test:
poetry run pytest -v --cov=sae_lens/training/ --cov-report=term-missing --cov-branch tests/acceptance
poetry run pytest -v --cov=sae_lens/ --cov-report=term-missing --cov-branch tests/acceptance

check-ci:
make check-format
make check-type
make unit-test

docstring-coverage:
poetry run docstr-coverage sae_lens --skip-file-doc

docs-serve:
poetry run mkdocs serve
14 changes: 5 additions & 9 deletions mkdocs.yml
@@ -1,8 +1,8 @@
site_name: SAELens Training
site_description: Docs for Sparse Autoencoder Training Library
site_name: SAE Lens
site_description: Docs for Sparse Autoencoder Training and Analysis Library
site_author: Joseph Bloom
repo_url: http://github.com/jbloomAus/mats_sae_training/
repo_name: jbloomAus/mats_sae_training
repo_url: http://github.com/jbloomAus/SAELens
repo_name: jbloomAus/SAELens
edit_uri: ""

theme:
@@ -45,15 +45,11 @@ extra_javascript:
nav:
- Home: index.md
- Roadmap: roadmap.md
- Installation: installation.md
- Training SAEs: training_saes.md
- Citation: citation.md
- Contributing: contributing.md
# - Analysis: usage/examples.md
- Reference:
- Language Models: reference/language_models.md
- Toy Models: reference/toy_models.md
- Misc: reference/misc.md
- API: api.md

plugins:
- search
1 change: 1 addition & 0 deletions pyproject.toml
@@ -50,6 +50,7 @@ mamba-lens = "^0.0.4"
ansible-lint = { version = "^24.2.3", markers = "platform_system != 'Windows'" }
botocore = "^1.34.101"
boto3 = "^1.34.101"
docstr-coverage = "^2.3.2"

[tool.poetry.extras]
mamba = ["mamba-lens"]
25 changes: 15 additions & 10 deletions sae_lens/__init__.py
@@ -1,27 +1,32 @@
__version__ = "2.1.3"

from .training.activations_store import ActivationsStore
from .training.cache_activations_runner import CacheActivationsRunner
from .training.config import (

from .analysis.hooked_sae_transformer import HookedSAETransformer
from .cache_activations_runner import CacheActivationsRunner
from .config import (
CacheActivationsRunnerConfig,
LanguageModelSAERunnerConfig,
PretokenizeRunnerConfig,
)
from .training.evals import run_evals
from .training.lm_runner import SAETrainingRunner
from .training.pretokenize_runner import pretokenize_runner
from .training.session_loader import LMSparseAutoencoderSessionloader
from .training.sparse_autoencoder import SparseAutoencoder
from .evals import run_evals
from .pretokenize_runner import pretokenize_runner
from .sae import SAE, SAEConfig
from .sae_training_runner import SAETrainingRunner
from .training.activations_store import ActivationsStore
from .training.training_sae import TrainingSAE, TrainingSAEConfig

__all__ = [
"SparseAutoencoder",
"SAE",
"SAEConfig",
"TrainingSAE",
"TrainingSAEConfig",
"HookedSAETransformer",
"ActivationsStore",
"LanguageModelSAERunnerConfig",
"SAETrainingRunner",
"CacheActivationsRunnerConfig",
"CacheActivationsRunner",
"PretokenizeRunnerConfig",
"pretokenize_runner",
"LMSparseAutoencoderSessionloader",
"run_evals",
]
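
For anyone migrating existing code, a small sketch of the renamed import surface follows, based only on the removed and added import lines above (the old names are shown purely for contrast):

```python
# pre-refactor top-level imports such as these no longer resolve:
#   from sae_lens import SparseAutoencoder, LMSparseAutoencoderSessionloader

# the refactored package exposes the renamed classes directly:
from sae_lens import (
    SAE,
    SAEConfig,
    TrainingSAE,
    TrainingSAEConfig,
    HookedSAETransformer,
    SAETrainingRunner,
    LanguageModelSAERunnerConfig,
)
```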
