From b7d77757ef82620254aa0162bd2eae57f6230d67 Mon Sep 17 00:00:00 2001 From: frankstein Date: Sun, 1 Mar 2026 21:29:24 +0800 Subject: [PATCH 1/7] refactor(language-model): WIP hook-based replacement model refactor in TransformerLensLanguageModel --- src/lm_saes/backend/language_model.py | 270 +++++++++++++++++- .../circuit/utils/attribution_utils.py | 1 + src/lm_saes/clt.py | 8 + src/lm_saes/crosscoder.py | 8 + src/lm_saes/lorsa.py | 8 + src/lm_saes/molt.py | 8 + src/lm_saes/sae.py | 8 + src/lm_saes/utils/misc.py | 14 + 8 files changed, 322 insertions(+), 3 deletions(-) diff --git a/src/lm_saes/backend/language_model.py b/src/lm_saes/backend/language_model.py index b400398b..8f49da4b 100644 --- a/src/lm_saes/backend/language_model.py +++ b/src/lm_saes/backend/language_model.py @@ -4,9 +4,11 @@ import warnings from abc import ABC, abstractmethod from contextlib import contextmanager +from functools import partial from itertools import accumulate from typing import Any, Callable, Literal, Optional, Union, cast +import einops import torch import torch.distributed as dist import torch.utils._pytree as pytree @@ -25,11 +27,11 @@ from lm_saes.backend.run_with_cache_until import run_with_cache_until from lm_saes.config import BaseModelConfig +from lm_saes.lorsa import LowRankSparseAttention +from lm_saes.sae import AbstractSparseAutoEncoder, SparseAutoEncoder from lm_saes.utils.auto import PretrainedSAEType, auto_infer_pretrained_sae_type from lm_saes.utils.distributed import DimMap -from lm_saes.utils.misc import ( - pad_and_truncate_tokens, -) +from lm_saes.utils.misc import ensure_tokenized, pad_and_truncate_tokens from lm_saes.utils.timer import timer @@ -235,8 +237,89 @@ def pad_token_id(self) -> int | None: pass +def hook_in_fn_builder(replacement_module: AbstractSparseAutoEncoder) -> Callable: + if isinstance(replacement_module, SparseAutoEncoder): + + def hook_in_fn( + x: torch.Tensor, + hook: str, + replacement_module: AbstractSparseAutoEncoder, + cache_activations: dict[str, torch.Tensor], + ): + assert hook in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" + cache_activations[hook + ".x"] = x + cache_activations[hook + ".feature_acts.up"] = replacement_module.encode(x).to_sparse(-1).coalesce() + cache_activations[hook + ".feature_acts.down"] = ( + cache_activations[hook + ".feature_acts.up"].detach().coalesce() + ) + cache_activations[hook + ".feature_acts.down"].requires_grad_(True) + return x.detach() + + return hook_in_fn + + elif isinstance(replacement_module, LowRankSparseAttention): + + def hook_in_fn( + x: torch.Tensor, + hook: str, + replacement_module: LowRankSparseAttention, + cache_activations: dict[str, torch.Tensor], + ): + assert hook in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" + cache_activations[hook + ".x"] = x + encode_result = replacement_module.encode( + x, + return_hidden_pre=False, + return_attention_pattern=True, + return_attention_score=True, + ) + cache_activations[hook + ".attn_pattern"] = encode_result[1].detach() + cache_activations[hook + ".feature_acts.up"] = ( + encode_result[0].to_sparse(-1).coalesce() + ) # batch, seq_len, d_sae + cache_activations[hook + ".feature_acts.down"] = ( + cache_activations[hook + ".feature_acts.up"].detach().coalesce() + ) + cache_activations[hook + ".feature_acts.down"].requires_grad_(True) + return x.detach() + + return hook_in_fn + + else: + # TODO: handle other replacement modules such as CLT + raise ValueError(f"Unsupported replacement module type: {type(replacement_module)}") + + +def hook_out_fn_builder(replacement_module: AbstractSparseAutoEncoder) -> Callable: + def hook_out_fn( + x: torch.Tensor, + hook: str, + replacement_module: AbstractSparseAutoEncoder, + cache_activations: dict[str, torch.Tensor], + update_error_cache: bool = False, + ): + assert hook in replacement_module.cfg.hooks_out, "Hook point must be in hook points out" + hook_in = replacement_module.cfg.hooks_in[0] # TODO: handle multiple hook points in for CLT + reconstructed = replacement_module.decode(cache_activations[hook_in + ".feature_acts.down"].to_dense()) + cache_activations[hook + ".reconstructed"] = reconstructed + assert hook + ".error" in cache_activations or update_error_cache, ( + "There must be an error cache for the hook point" + ) + if update_error_cache: + error = x - reconstructed + cache_activations[hook + ".error"] = error + else: + error = cache_activations[hook + ".error"] + error = error.detach() + error.requires_grad_(True) + return reconstructed + error + + return hook_out_fn + + class TransformerLensLanguageModel(LanguageModel): def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = None): + self.cache_activations: dict[str, torch.Tensor] = {} self.cfg = cfg self.device_mesh = device_mesh if cfg.device == "cuda": @@ -286,6 +369,187 @@ def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = No else None ) + def initialize_replacement_model( + self, inputs: torch.Tensor | str, replacement_modules: list[AbstractSparseAutoEncoder] + ): + tokens = ensure_tokenized(inputs, self.tokenizer, device=self.device) + hook_in_fn_list: list[tuple[Union[str, Callable], Callable]] = [ + ( + hook_in, + partial( + hook_in_fn_builder(replacement_module), + replacement_module=replacement_module, + cache_activations=self.cache_activations, + ), + ) + for replacement_module in replacement_modules + for hook_in in replacement_module.cfg.hooks_in + ] + hook_out_fn_list: list[tuple[Union[str, Callable], Callable]] = [ + ( + hook_out, + partial( + hook_out_fn_builder(replacement_module), + replacement_module=replacement_module, + cache_activations=self.cache_activations, + update_error_cache=True, + ), + ) + for replacement_module in replacement_modules + for hook_out in replacement_module.cfg.hooks_out + ] + with self.hooks(fwd_hooks=hook_in_fn_list + hook_out_fn_list): + logits = self.forward(tokens) + return logits + + def get_cache_activations(self, suffix: str) -> list[torch.Tensor]: + return [v for k, v in self.cache_activations.items() if k.endswith(suffix)] + + def _clear_cache_activation_grads(self) -> None: + """Clear gradients for all cached activations in-place.""" + for activation in self.cache_activations.values(): + if activation.grad is not None: + activation.grad = None + + def _build_edge_row_block_from_cache( + self, + *, + cur_batch_size: int, + edge_matrix: torch.Tensor, + row_slice: slice, + token_embed: torch.Tensor, # batch, pos, d_model + feature_acts_down: list[torch.Tensor], + errors: list[torch.Tensor], + token_col_slice: slice, + feature_col_slice: slice, + error_col_slice: slice, + ) -> None: + token_grad = token_embed.grad + if token_grad is None: + token_block = torch.zeros(cur_batch_size, token_embed.shape[1], dtype=edge_matrix.dtype, device="cpu") + else: + token_block = ( + einops.reduce((token_embed * token_grad)[:cur_batch_size], "b pos d_model -> b pos", "sum") + .detach() + .cpu() + ) + edge_matrix[row_slice, token_col_slice] = token_block + + feature_blocks: list[torch.Tensor] = [] + # token + feature + error + for acts in feature_acts_down: + grad = acts.grad + if grad is None: + feature_blocks.append(torch.zeros(cur_batch_size, acts._nnz(), dtype=edge_matrix.dtype, device="cpu")) + continue + grad = grad.coalesce().to_dense() if grad.is_sparse else grad + + indices = acts.indices() + mask = indices[0] < cur_batch_size + idx = indices[:, mask] + products = acts.values()[mask] * grad[idx[0], idx[1], idx[2]] + feature_blocks.append(einops.rearrange(products, "(b k) -> b k", b=cur_batch_size).detach().cpu()) + + edge_matrix[row_slice, feature_col_slice] = torch.cat(feature_blocks, dim=1) + + error_blocks: list[torch.Tensor] = [] + for error in errors: + grad = error.grad + if grad is None: + error_blocks.append(torch.zeros(cur_batch_size, 1, dtype=edge_matrix.dtype, device="cpu")) + continue + contrib = einops.reduce((error * grad)[:cur_batch_size], "b ... -> b", "sum") + error_blocks.append(contrib.detach().cpu().unsqueeze(-1)) + + edge_matrix[row_slice, error_col_slice] = torch.cat(error_blocks, dim=1) + + def attribute( + self, + inputs: torch.Tensor | str, + replacement_modules: list[AbstractSparseAutoEncoder], + max_n_logits: int = 10, + desired_logit_prob: float = 0.95, + batch_size: int = 512, + max_feature_nodes: Optional[int] = None, + ): + assert self.model is not None, "model must be initialized" + tokens = ensure_tokenized(inputs, self.tokenizer, device=self.device) + n_tokens = tokens.shape[0] + fwd_hooks_in: list[tuple[Union[str, Callable], Callable]] = [ + ( + hook_in, + partial( + hook_in_fn_builder(replacement_module), + replacement_module=replacement_module, + cache_activations=self.cache_activations, + ), + ) + for replacement_module in replacement_modules + for hook_in in replacement_module.cfg.hooks_in + ] + fwd_hooks_out: list[tuple[Union[str, Callable], Callable]] = [ + ( + hook_out, + partial( + hook_out_fn_builder(replacement_module), + replacement_module=replacement_module, + cache_activations=self.cache_activations, + update_error_cache=True, + ), + ) + for replacement_module in replacement_modules + for hook_out in replacement_module.cfg.hooks_out + ] + + def token_fwd_hook_fn( + x: torch.Tensor, + hook: Any, + ): + self.cache_activations["hook_embed"] = x + self.cache_activations["hook_embed"].retain_grad() + return x + + token_fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [("hook_embed", token_fwd_hook_fn)] + with self.hooks(fwd_hooks=fwd_hooks_in + fwd_hooks_out + token_fwd_hooks): + batch_logits = self.forward(einops.repeat(tokens, "n -> b n", b=batch_size))[:, -1] # batch, d_vocab + + feature_acts_down = self.get_cache_activations(".feature_acts.down") + feature_acts_up = self.get_cache_activations(".feature_acts.up") + errors = self.get_cache_activations(".error") + + with torch.no_grad(): + probs = torch.softmax(batch_logits[0], dim=-1) + top_p, top_idx = torch.topk(probs, max_n_logits) + cutoff = int(torch.searchsorted(torch.cumsum(top_p, 0), desired_logit_prob)) + 1 + top_p, top_idx = top_p[:cutoff], top_idx[:cutoff] + n_logits = len(top_idx) + num_active_features = sum(acts._nnz() for acts in feature_acts_up) + num_error_nodes = len(errors) + feature_nodes = min(max_feature_nodes or num_active_features, num_active_features) + edge_matrix = torch.zeros( + feature_nodes + n_logits, n_tokens + num_active_features + num_error_nodes + ) # [downstream_nodes, upstream_nodes], on cpu, [features+logits, tokens+features+errors] + + for i in range(0, n_logits, batch_size): + cur_batch_size = min(batch_size, n_logits - i) + batch_nodes = batch_logits[:, top_idx[i : i + cur_batch_size]] - torch.mean(batch_logits[0], dim=-1) + self._clear_cache_activation_grads() + batch_nodes.diagonal().sum().backward(retain_graph=True) + row_slice = slice(feature_nodes + i, feature_nodes + i + cur_batch_size) + self._build_edge_row_block_from_cache( + cur_batch_size=cur_batch_size, + edge_matrix=edge_matrix, + row_slice=row_slice, + token_embed=self.cache_activations["hook_embed"], + feature_acts_down=feature_acts_down, + errors=errors, + token_col_slice=slice(0, n_tokens), + feature_col_slice=slice(n_tokens, n_tokens + num_active_features), + error_col_slice=slice(n_tokens + num_active_features, n_tokens + num_active_features + num_error_nodes), + ) + + # TODO: trace from features + @property def eos_token_id(self) -> int | None: return self.tokenizer.eos_token_id diff --git a/src/lm_saes/circuit/utils/attribution_utils.py b/src/lm_saes/circuit/utils/attribution_utils.py index 549db0ca..957ab01e 100644 --- a/src/lm_saes/circuit/utils/attribution_utils.py +++ b/src/lm_saes/circuit/utils/attribution_utils.py @@ -183,6 +183,7 @@ def select_feature_activations( return torch.stack(activations) +# TODO: remove this function def ensure_tokenized(prompt: Union[str, torch.Tensor, List[int]], tokenizer) -> torch.Tensor: """Convert *prompt* → 1-D tensor of token ids (no batch dim).""" diff --git a/src/lm_saes/clt.py b/src/lm_saes/clt.py index a80138db..4bc7daa5 100644 --- a/src/lm_saes/clt.py +++ b/src/lm_saes/clt.py @@ -104,6 +104,14 @@ def associated_hook_points(self) -> list[str]: """All hook points used by the CLT.""" return self.hook_points_in + self.hook_points_out + @property + def hooks_in(self) -> list[str]: + return self.hook_points_in + + @property + def hooks_out(self) -> list[str]: + return self.hook_points_out + def model_post_init(self, __context): super().model_post_init(__context) assert len(self.hook_points_in) == len(self.hook_points_out), ( diff --git a/src/lm_saes/crosscoder.py b/src/lm_saes/crosscoder.py index e7179161..c370c2f0 100644 --- a/src/lm_saes/crosscoder.py +++ b/src/lm_saes/crosscoder.py @@ -61,6 +61,14 @@ class CrossCoderConfig(SparseDictionaryConfig): def associated_hook_points(self) -> list[str]: return self.hook_points + @property + def hooks_in(self) -> list[str]: + return self.hook_points + + @property + def hooks_out(self) -> list[str]: + return self.hook_points + @property def n_heads(self) -> int: return len(self.hook_points) diff --git a/src/lm_saes/lorsa.py b/src/lm_saes/lorsa.py index 4d52eabd..abc1b6e0 100644 --- a/src/lm_saes/lorsa.py +++ b/src/lm_saes/lorsa.py @@ -79,6 +79,14 @@ def associated_hook_points(self) -> list[str]: """All hook points used by Lorsa.""" return [self.hook_point_in, self.hook_point_out] + @property + def hooks_in(self) -> list[str]: + return [self.hook_point_in] + + @property + def hooks_out(self) -> list[str]: + return [self.hook_point_out] + def model_post_init(self, __context): super().model_post_init(__context) assert self.hook_point_in is not None and self.hook_point_out is not None, ( diff --git a/src/lm_saes/molt.py b/src/lm_saes/molt.py index b905cd72..b45c1adc 100644 --- a/src/lm_saes/molt.py +++ b/src/lm_saes/molt.py @@ -116,6 +116,14 @@ def num_rank_types(self) -> int: def associated_hook_points(self) -> list[str]: return [self.hook_point_in, self.hook_point_out] + @property + def hooks_in(self) -> list[str]: + return [self.hook_point_in] + + @property + def hooks_out(self) -> list[str]: + return [self.hook_point_out] + @register_sae_model("molt") class MixtureOfLinearTransform(SparseDictionary): diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index a7b03170..6d3f5cf0 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -32,6 +32,14 @@ class SAEConfig(SparseDictionaryConfig): def associated_hook_points(self) -> list[str]: return [self.hook_point_in, self.hook_point_out] + @property + def hooks_in(self) -> list[str]: + return [self.hook_point_in] + + @property + def hooks_out(self) -> list[str]: + return [self.hook_point_out] + @register_sae_model("sae") class SparseAutoEncoder(SparseDictionary): diff --git a/src/lm_saes/utils/misc.py b/src/lm_saes/utils/misc.py index 65f0467d..25e78b98 100644 --- a/src/lm_saes/utils/misc.py +++ b/src/lm_saes/utils/misc.py @@ -177,3 +177,17 @@ def get_slice_length(s: slice, length: int): start, stop, step = s.indices(length) length = (stop - start + step - 1) // step return length + + +def ensure_tokenized( + prompt: str | torch.Tensor | list[int], tokenizer, device: torch.device | str = "cpu" +) -> torch.Tensor: + """Convert *prompt* → 1-D tensor of token ids (no batch dim).""" + + if isinstance(prompt, str): + return tokenizer(prompt, return_tensors="pt").input_ids[0].to(device) + if isinstance(prompt, torch.Tensor): + return prompt.squeeze(0).to(device) if prompt.ndim == 2 else prompt.to(device) + if isinstance(prompt, list): + return torch.tensor(prompt, dtype=torch.long, device=device) + raise TypeError(f"Unsupported prompt type: {type(prompt)}") From ccfb2ee52de2b1baea2cefacb64f273162cd84fa Mon Sep 17 00:00:00 2001 From: frankstein Date: Mon, 2 Mar 2026 15:42:17 +0800 Subject: [PATCH 2/7] fix(language-model): fix bug in WIP hook-based replacement model --- src/lm_saes/backend/language_model.py | 100 ++++++++++++-------------- 1 file changed, 44 insertions(+), 56 deletions(-) diff --git a/src/lm_saes/backend/language_model.py b/src/lm_saes/backend/language_model.py index 8f49da4b..a14e2055 100644 --- a/src/lm_saes/backend/language_model.py +++ b/src/lm_saes/backend/language_model.py @@ -249,9 +249,7 @@ def hook_in_fn( assert hook in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" cache_activations[hook + ".x"] = x cache_activations[hook + ".feature_acts.up"] = replacement_module.encode(x).to_sparse(-1).coalesce() - cache_activations[hook + ".feature_acts.down"] = ( - cache_activations[hook + ".feature_acts.up"].detach().coalesce() - ) + cache_activations[hook + ".feature_acts.down"] = cache_activations[hook + ".feature_acts.up"].detach() cache_activations[hook + ".feature_acts.down"].requires_grad_(True) return x.detach() @@ -277,9 +275,7 @@ def hook_in_fn( cache_activations[hook + ".feature_acts.up"] = ( encode_result[0].to_sparse(-1).coalesce() ) # batch, seq_len, d_sae - cache_activations[hook + ".feature_acts.down"] = ( - cache_activations[hook + ".feature_acts.up"].detach().coalesce() - ) + cache_activations[hook + ".feature_acts.down"] = cache_activations[hook + ".feature_acts.up"].detach() cache_activations[hook + ".feature_acts.down"].requires_grad_(True) return x.detach() @@ -411,57 +407,51 @@ def _clear_cache_activation_grads(self) -> None: if activation.grad is not None: activation.grad = None - def _build_edge_row_block_from_cache( + def _update_edge_matrix_from_cache( self, - *, - cur_batch_size: int, - edge_matrix: torch.Tensor, - row_slice: slice, + edge_rows: torch.Tensor, token_embed: torch.Tensor, # batch, pos, d_model feature_acts_down: list[torch.Tensor], errors: list[torch.Tensor], - token_col_slice: slice, - feature_col_slice: slice, - error_col_slice: slice, ) -> None: - token_grad = token_embed.grad - if token_grad is None: - token_block = torch.zeros(cur_batch_size, token_embed.shape[1], dtype=edge_matrix.dtype, device="cpu") - else: - token_block = ( - einops.reduce((token_embed * token_grad)[:cur_batch_size], "b pos d_model -> b pos", "sum") - .detach() - .cpu() + assert token_embed.grad is not None, "token_embed must have a gradient" + col_start = 0 + token_block = ( + einops.einsum( + token_embed[: edge_rows.shape[0]], + token_embed.grad[: edge_rows.shape[0]], + "b pos d_model, b pos d_model -> b pos", ) - edge_matrix[row_slice, token_col_slice] = token_block + .detach() + .cpu() + ) + edge_rows[:, col_start : col_start + token_embed.shape[1]] = token_block + col_start += token_embed.shape[1] - feature_blocks: list[torch.Tensor] = [] - # token + feature + error for acts in feature_acts_down: - grad = acts.grad - if grad is None: - feature_blocks.append(torch.zeros(cur_batch_size, acts._nnz(), dtype=edge_matrix.dtype, device="cpu")) - continue - grad = grad.coalesce().to_dense() if grad.is_sparse else grad - + assert acts.grad is not None, "feature_acts_down must have a gradient" + grad = acts.grad.coalesce().to_dense() if acts.grad.is_sparse else acts.grad indices = acts.indices() - mask = indices[0] < cur_batch_size + mask = indices[0] < edge_rows.shape[0] idx = indices[:, mask] - products = acts.values()[mask] * grad[idx[0], idx[1], idx[2]] - feature_blocks.append(einops.rearrange(products, "(b k) -> b k", b=cur_batch_size).detach().cpu()) + values = acts.values()[mask].reshape(edge_rows.shape[0], -1) + grad_values = grad[idx[0], idx[1], idx[2]].reshape(edge_rows.shape[0], -1) + products = einops.einsum(values, grad_values, "b k, b k -> b k").detach().cpu() + edge_rows[:, col_start : col_start + products.shape[1]] = products + col_start += products.shape[1] - edge_matrix[row_slice, feature_col_slice] = torch.cat(feature_blocks, dim=1) - - error_blocks: list[torch.Tensor] = [] for error in errors: - grad = error.grad - if grad is None: - error_blocks.append(torch.zeros(cur_batch_size, 1, dtype=edge_matrix.dtype, device="cpu")) - continue - contrib = einops.reduce((error * grad)[:cur_batch_size], "b ... -> b", "sum") - error_blocks.append(contrib.detach().cpu().unsqueeze(-1)) - - edge_matrix[row_slice, error_col_slice] = torch.cat(error_blocks, dim=1) + assert error.grad is not None, "error must have a gradient" + edge_rows[:, col_start : col_start + error.shape[1]] = ( + einops.einsum( + error[: edge_rows.shape[0]], + error.grad[: edge_rows.shape[0]], + "b pos ..., b pos ... -> b pos", + ) + .detach() + .cpu() + ) + col_start += error.shape[1] def attribute( self, @@ -523,8 +513,8 @@ def token_fwd_hook_fn( cutoff = int(torch.searchsorted(torch.cumsum(top_p, 0), desired_logit_prob)) + 1 top_p, top_idx = top_p[:cutoff], top_idx[:cutoff] n_logits = len(top_idx) - num_active_features = sum(acts._nnz() for acts in feature_acts_up) - num_error_nodes = len(errors) + num_active_features = sum(acts._nnz() // batch_size for acts in feature_acts_up) + num_error_nodes = len(errors) * n_tokens feature_nodes = min(max_feature_nodes or num_active_features, num_active_features) edge_matrix = torch.zeros( feature_nodes + n_logits, n_tokens + num_active_features + num_error_nodes @@ -535,20 +525,18 @@ def token_fwd_hook_fn( batch_nodes = batch_logits[:, top_idx[i : i + cur_batch_size]] - torch.mean(batch_logits[0], dim=-1) self._clear_cache_activation_grads() batch_nodes.diagonal().sum().backward(retain_graph=True) - row_slice = slice(feature_nodes + i, feature_nodes + i + cur_batch_size) - self._build_edge_row_block_from_cache( - cur_batch_size=cur_batch_size, - edge_matrix=edge_matrix, - row_slice=row_slice, + self._update_edge_matrix_from_cache( + edge_rows=edge_matrix[feature_nodes + i : feature_nodes + i + cur_batch_size, :], token_embed=self.cache_activations["hook_embed"], feature_acts_down=feature_acts_down, errors=errors, - token_col_slice=slice(0, n_tokens), - feature_col_slice=slice(n_tokens, n_tokens + num_active_features), - error_col_slice=slice(n_tokens + num_active_features, n_tokens + num_active_features + num_error_nodes), ) - # TODO: trace from features + def normalize_edge_maxtrix(edge_matrix: torch.Tensor) -> torch.Tensor: + edge_matrix = edge_matrix.abs_() + return edge_matrix / edge_matrix.sum(dim=1, keepdim=True).clamp(min=1e-8) + + feature_row_to_col = torch.zeros(feature_nodes, dtype=torch.int64, device="cpu") @property def eos_token_id(self) -> int | None: From 6946ce509277960eff8894a61b59be9fe464b0aa Mon Sep 17 00:00:00 2001 From: frankstein Date: Mon, 2 Mar 2026 18:22:18 +0800 Subject: [PATCH 3/7] refactor(language-model): complete attribute method implementation (untested) - Remove stateful `cache_activations` from `TransformerLensLanguageModel`, passing it as an argument instead to make the class stateless. - Remove the unused `initialize_replacement_model` method. - Add `column_id_to_info` method to map column IDs to specific feature, error, and embedding information. - Implement the core logic of the `attribute` method, including forward pass, building normalized edge matrix via gradient backpropagation, and feature attribution computation (tests not yet written). --- src/lm_saes/backend/language_model.py | 153 +++++++++++++++++--------- 1 file changed, 100 insertions(+), 53 deletions(-) diff --git a/src/lm_saes/backend/language_model.py b/src/lm_saes/backend/language_model.py index a14e2055..b766459d 100644 --- a/src/lm_saes/backend/language_model.py +++ b/src/lm_saes/backend/language_model.py @@ -315,7 +315,6 @@ def hook_out_fn( class TransformerLensLanguageModel(LanguageModel): def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = None): - self.cache_activations: dict[str, torch.Tensor] = {} self.cfg = cfg self.device_mesh = device_mesh if cfg.device == "cuda": @@ -365,45 +364,12 @@ def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = No else None ) - def initialize_replacement_model( - self, inputs: torch.Tensor | str, replacement_modules: list[AbstractSparseAutoEncoder] - ): - tokens = ensure_tokenized(inputs, self.tokenizer, device=self.device) - hook_in_fn_list: list[tuple[Union[str, Callable], Callable]] = [ - ( - hook_in, - partial( - hook_in_fn_builder(replacement_module), - replacement_module=replacement_module, - cache_activations=self.cache_activations, - ), - ) - for replacement_module in replacement_modules - for hook_in in replacement_module.cfg.hooks_in - ] - hook_out_fn_list: list[tuple[Union[str, Callable], Callable]] = [ - ( - hook_out, - partial( - hook_out_fn_builder(replacement_module), - replacement_module=replacement_module, - cache_activations=self.cache_activations, - update_error_cache=True, - ), - ) - for replacement_module in replacement_modules - for hook_out in replacement_module.cfg.hooks_out - ] - with self.hooks(fwd_hooks=hook_in_fn_list + hook_out_fn_list): - logits = self.forward(tokens) - return logits - - def get_cache_activations(self, suffix: str) -> list[torch.Tensor]: - return [v for k, v in self.cache_activations.items() if k.endswith(suffix)] + def _get_cache_activations(self, cache_activations: dict[str, torch.Tensor], suffix: str) -> list[torch.Tensor]: + return [v for k, v in cache_activations.items() if k.endswith(suffix)] - def _clear_cache_activation_grads(self) -> None: + def _clear_cache_activation_grads(self, cache_activations: dict[str, torch.Tensor]) -> None: """Clear gradients for all cached activations in-place.""" - for activation in self.cache_activations.values(): + for activation in cache_activations.values(): if activation.grad is not None: activation.grad = None @@ -453,6 +419,45 @@ def _update_edge_matrix_from_cache( ) col_start += error.shape[1] + assert col_start == edge_rows.shape[1], ( + f"col_start {col_start} must equal edge_rows.shape[1] {edge_rows.shape[1]}" + ) + + def column_id_to_info( + self, + cache_activations: dict[str, torch.Tensor], + ) -> list[tuple[int, int | None, str, int | None]]: + """Return ``(pos, layer, hook_name, dense_idx)`` for each active feature, error, and embedding.""" + result: list[tuple[int, int | None, str, int | None]] = [] + + if "hook_embed" in cache_activations: + embed = cache_activations["hook_embed"] + for p in range(embed.shape[1]): + result.append((p, None, "hook_embed", None)) + + for key, acts in cache_activations.items(): + if not key.endswith(".feature_acts.up"): + continue + hook_name = key.removesuffix(".feature_acts.up") + layer = int(hook_name.split(".")[1]) + idx = acts.indices() + if idx.shape[0] == 3: # (batch, pos, d_sae) + mask = idx[0] == 0 + pos, dense = idx[1, mask].tolist(), idx[2, mask].tolist() + else: # (pos, d_sae) + pos, dense = idx[0].tolist(), idx[1].tolist() + result += [(p, layer, hook_name, d) for p, d in zip(pos, dense)] + + for key, error in cache_activations.items(): + if not key.endswith(".error"): + continue + hook_name = key.removesuffix(".error") + layer = int(hook_name.split(".")[1]) + for p in range(error.shape[1]): + result.append((p, layer, hook_name, None)) + + return result + def attribute( self, inputs: torch.Tensor | str, @@ -463,6 +468,7 @@ def attribute( max_feature_nodes: Optional[int] = None, ): assert self.model is not None, "model must be initialized" + cache_activations: dict[str, torch.Tensor] = {} tokens = ensure_tokenized(inputs, self.tokenizer, device=self.device) n_tokens = tokens.shape[0] fwd_hooks_in: list[tuple[Union[str, Callable], Callable]] = [ @@ -471,7 +477,7 @@ def attribute( partial( hook_in_fn_builder(replacement_module), replacement_module=replacement_module, - cache_activations=self.cache_activations, + cache_activations=cache_activations, ), ) for replacement_module in replacement_modules @@ -483,7 +489,7 @@ def attribute( partial( hook_out_fn_builder(replacement_module), replacement_module=replacement_module, - cache_activations=self.cache_activations, + cache_activations=cache_activations, update_error_cache=True, ), ) @@ -494,18 +500,25 @@ def attribute( def token_fwd_hook_fn( x: torch.Tensor, hook: Any, + cache_activations: dict[str, torch.Tensor], ): - self.cache_activations["hook_embed"] = x - self.cache_activations["hook_embed"].retain_grad() + cache_activations["hook_embed"] = x + cache_activations["hook_embed"].retain_grad() return x - token_fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [("hook_embed", token_fwd_hook_fn)] + token_fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [ + ( + "hook_embed", + partial(token_fwd_hook_fn, cache_activations=cache_activations), + ) + ] with self.hooks(fwd_hooks=fwd_hooks_in + fwd_hooks_out + token_fwd_hooks): batch_logits = self.forward(einops.repeat(tokens, "n -> b n", b=batch_size))[:, -1] # batch, d_vocab - feature_acts_down = self.get_cache_activations(".feature_acts.down") - feature_acts_up = self.get_cache_activations(".feature_acts.up") - errors = self.get_cache_activations(".error") + feature_acts_down = self._get_cache_activations(cache_activations, ".feature_acts.down") + errors = self._get_cache_activations(cache_activations, ".error") + + column_id_to_info = self.column_id_to_info(cache_activations) with torch.no_grad(): probs = torch.softmax(batch_logits[0], dim=-1) @@ -513,8 +526,8 @@ def token_fwd_hook_fn( cutoff = int(torch.searchsorted(torch.cumsum(top_p, 0), desired_logit_prob)) + 1 top_p, top_idx = top_p[:cutoff], top_idx[:cutoff] n_logits = len(top_idx) - num_active_features = sum(acts._nnz() // batch_size for acts in feature_acts_up) num_error_nodes = len(errors) * n_tokens + num_active_features = len(column_id_to_info) - n_tokens - num_error_nodes feature_nodes = min(max_feature_nodes or num_active_features, num_active_features) edge_matrix = torch.zeros( feature_nodes + n_logits, n_tokens + num_active_features + num_error_nodes @@ -523,20 +536,54 @@ def token_fwd_hook_fn( for i in range(0, n_logits, batch_size): cur_batch_size = min(batch_size, n_logits - i) batch_nodes = batch_logits[:, top_idx[i : i + cur_batch_size]] - torch.mean(batch_logits[0], dim=-1) - self._clear_cache_activation_grads() + self._clear_cache_activation_grads(cache_activations) batch_nodes.diagonal().sum().backward(retain_graph=True) self._update_edge_matrix_from_cache( edge_rows=edge_matrix[feature_nodes + i : feature_nodes + i + cur_batch_size, :], - token_embed=self.cache_activations["hook_embed"], + token_embed=cache_activations["hook_embed"], feature_acts_down=feature_acts_down, errors=errors, ) - def normalize_edge_maxtrix(edge_matrix: torch.Tensor) -> torch.Tensor: - edge_matrix = edge_matrix.abs_() - return edge_matrix / edge_matrix.sum(dim=1, keepdim=True).clamp(min=1e-8) + def get_normalize_edge_maxtrix(edge_matrix: torch.Tensor): + return torch.abs(edge_matrix) / torch.abs(edge_matrix).sum(dim=1, keepdim=True).clamp(min=1e-8) + + feature_row_to_col = torch.full((feature_nodes,), -1, dtype=torch.int64, device="cpu") + features_attributions = einops.einsum( + get_normalize_edge_maxtrix(edge_matrix[feature_nodes:, :]), + top_p, + "l t, l -> t", + ) + for i in range(0, feature_nodes, batch_size): + cur_batch = min(batch_size, feature_nodes - i) + already_assigned = feature_row_to_col[feature_row_to_col != -1] + features_attributions[already_assigned] = float("-inf") + _, batch_feature_ids = torch.topk(features_attributions, cur_batch) + feature_row_to_col[i : i + cur_batch] = batch_feature_ids + self._clear_cache_activation_grads(cache_activations) + bwd = None + for batch_idx, feat_id in enumerate(batch_feature_ids.tolist()): + info: Any = column_id_to_info[feat_id] + acts = cache_activations[info[2] + ".feature_acts.up"] + idx = acts.indices() + mask = (idx[0] == batch_idx) & (idx[1] == info[0]) & (idx[2] == info[3]) + bwd = acts.values()[mask][0] if bwd is None else bwd + acts.values()[mask][0] + assert bwd is not None, "bwd must not be None" + bwd.backward(retain_graph=True) + self._update_edge_matrix_from_cache( + edge_rows=edge_matrix[i : i + cur_batch, :], + token_embed=cache_activations["hook_embed"], + feature_acts_down=feature_acts_down, + errors=errors, + ) + features_attributions += einops.einsum( + get_normalize_edge_maxtrix(edge_matrix[feature_nodes:, batch_feature_ids]), + get_normalize_edge_maxtrix(edge_matrix[i : i + cur_batch, :]), + top_p, + "l b, b t, l -> t", + ) - feature_row_to_col = torch.zeros(feature_nodes, dtype=torch.int64, device="cpu") + return edge_matrix, cache_activations, column_id_to_info @property def eos_token_id(self) -> int | None: From 690bc1d86e847458692ecd0cfa2f2ffa2ba0b5c8 Mon Sep 17 00:00:00 2001 From: frankstein Date: Tue, 3 Mar 2026 16:06:58 +0800 Subject: [PATCH 4/7] refactor(language-model): introduce EdgeMatrix class for cleaner attribute implementation (untested) Add EdgeMatrix, a torch.Tensor wrapper subclass that carries cache_activations metadata through tensor operations. This simplifies the attribute method by encapsulating row/column info, activation lookups, and edge matrix construction into the class itself. --- src/lm_saes/backend/language_model.py | 432 ++++++++++++++++---------- 1 file changed, 275 insertions(+), 157 deletions(-) diff --git a/src/lm_saes/backend/language_model.py b/src/lm_saes/backend/language_model.py index b766459d..951f7aef 100644 --- a/src/lm_saes/backend/language_model.py +++ b/src/lm_saes/backend/language_model.py @@ -31,7 +31,7 @@ from lm_saes.sae import AbstractSparseAutoEncoder, SparseAutoEncoder from lm_saes.utils.auto import PretrainedSAEType, auto_infer_pretrained_sae_type from lm_saes.utils.distributed import DimMap -from lm_saes.utils.misc import ensure_tokenized, pad_and_truncate_tokens +from lm_saes.utils.misc import ensure_tokenized, item, pad_and_truncate_tokens from lm_saes.utils.timer import timer @@ -248,7 +248,7 @@ def hook_in_fn( ): assert hook in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" cache_activations[hook + ".x"] = x - cache_activations[hook + ".feature_acts.up"] = replacement_module.encode(x).to_sparse(-1).coalesce() + cache_activations[hook + ".feature_acts.up"] = replacement_module.encode(x) cache_activations[hook + ".feature_acts.down"] = cache_activations[hook + ".feature_acts.up"].detach() cache_activations[hook + ".feature_acts.down"].requires_grad_(True) return x.detach() @@ -272,9 +272,7 @@ def hook_in_fn( return_attention_score=True, ) cache_activations[hook + ".attn_pattern"] = encode_result[1].detach() - cache_activations[hook + ".feature_acts.up"] = ( - encode_result[0].to_sparse(-1).coalesce() - ) # batch, seq_len, d_sae + cache_activations[hook + ".feature_acts.up"] = encode_result[0] # batch, seq_len, d_sae cache_activations[hook + ".feature_acts.down"] = cache_activations[hook + ".feature_acts.up"].detach() cache_activations[hook + ".feature_acts.down"].requires_grad_(True) return x.detach() @@ -296,7 +294,7 @@ def hook_out_fn( ): assert hook in replacement_module.cfg.hooks_out, "Hook point must be in hook points out" hook_in = replacement_module.cfg.hooks_in[0] # TODO: handle multiple hook points in for CLT - reconstructed = replacement_module.decode(cache_activations[hook_in + ".feature_acts.down"].to_dense()) + reconstructed = replacement_module.decode(cache_activations[hook_in + ".feature_acts.down"]) cache_activations[hook + ".reconstructed"] = reconstructed assert hook + ".error" in cache_activations or update_error_cache, ( "There must be an error cache for the hook point" @@ -313,6 +311,229 @@ def hook_out_fn( return hook_out_fn +class EdgeMatrix(torch.Tensor): + """A Tensor wrapper that preserves ``cache_activations`` through operations. + + Args: + elem: The underlying tensor data. + cache_activations: Cached activation tensors keyed by hook point name. + """ + + matrix: torch.Tensor + cache_activations: dict[str, torch.Tensor] + n_tokens: int + n_logits: int + n_active_features: int + n_error: int + max_features: int + + @classmethod + def _wrap( + cls, + matrix: torch.Tensor, + cache_activations: dict[str, torch.Tensor], + n_tokens: int, + n_logits: int, + n_active_features: int, + n_error: int, + max_features: int | None = None, + ) -> "EdgeMatrix": + """Low-level constructor that assembles an EdgeMatrix from pre-computed parts.""" + obj = torch.Tensor._make_wrapper_subclass( + cls, + size=matrix.shape, + dtype=matrix.dtype, + layout=matrix.layout, + device=matrix.device, + requires_grad=matrix.requires_grad, + ) + obj.matrix = matrix + obj.cache_activations = cache_activations + obj.n_tokens = n_tokens + obj.n_logits = n_logits + obj.n_active_features = n_active_features + obj.n_error = n_error + obj.max_features = max_features or n_active_features + return obj + + @staticmethod + def __new__(cls, cache_activations: dict[str, torch.Tensor], max_features: int | None = None) -> "EdgeMatrix": + n_tokens = cache_activations["hook_embed"].shape[1] + n_logits = cache_activations["logits"].shape[-1] + n_active_features = int( + item( + cast( + torch.Tensor, + sum([v[0].gt(0).sum() for k, v in cache_activations.items() if k.endswith(".feature_acts.up")]), + ) + ) + ) + max_features = min(max_features or n_active_features, n_active_features) + n_error = n_tokens * len([v for k, v in cache_activations.items() if k.endswith(".error")]) + matrix = torch.zeros(max_features + n_logits, n_tokens + n_active_features + n_error) + return cls._wrap(matrix, cache_activations, n_tokens, n_logits, n_active_features, n_error, max_features) + + @classmethod + def __torch_dispatch__(cls, func: Any, types: list[type], args: Any = (), kwargs: Any = None) -> Any: + from torch.utils._pytree import tree_map + + if kwargs is None: + kwargs = {} + + source: EdgeMatrix | None = None + for a in tree_map(lambda x: x, args): + if isinstance(a, EdgeMatrix): + source = a + break + + def unwrap(x: Any) -> Any: + return x.matrix if isinstance(x, EdgeMatrix) else x + + out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + + def wrap(x: Any) -> Any: + if isinstance(x, torch.Tensor) and not isinstance(x, EdgeMatrix) and source is not None: + return cls._wrap( + x, + source.cache_activations, + source.n_tokens, + source.n_logits, + source.n_active_features, + source.n_error, + source.max_features, + ) + return x + + return tree_map(wrap, out) + + def __init__(self, cache_activations: dict[str, torch.Tensor], max_features: int | None = None): + self.row_info: list[tuple[int, str, int]] = self._build_row_info() + self.column_info: list[tuple[int, str, int | str]] = self._build_column_info() + self.row2column = [-1] * self.n_logits + + def _build_column_info(self) -> list[tuple[int, str, int | str]]: + """Build column metadata for the edge matrix. + + Each entry is ``(position, hook_name, feature_id)`` where *feature_id* + is ``str("embed")`` for embed and ``str("error")`` for error columns. + + Returns: + A list of ``(pos, hook_name, feature_id)`` tuples, one per column, + ordered as: embed tokens, active features, error terms. + """ + result: list[tuple[int, str, int | str]] = [] + if "hook_embed" in self.cache_activations: + n_pos = self.cache_activations["hook_embed"].shape[1] + result.extend([(p, "hook_embed", "embed") for p in range(n_pos)]) + for key in self.cache_activations: + if key.endswith(".feature_acts.down"): + active_mask = self.cache_activations[key][0] > 0 # [n_pos, n_features] + pos_ids, feat_ids = active_mask.nonzero(as_tuple=True) + result.extend( + [ + (p, key.removesuffix(".feature_acts.down"), i) + for p, i in zip(pos_ids.tolist(), feat_ids.tolist()) + ] + ) + for key in self.cache_activations: + if key.endswith(".error"): + n_pos = self.cache_activations[key].shape[1] + result.extend([(p, key, "error") for p in range(n_pos)]) + return result + + def _build_row_info(self) -> list[tuple[int, str, int]]: + return [(self.n_tokens, "logits", i) for i in range(self.n_logits)] + + def get_row_values(self, rows: int | list[int], feature_type: Literal["up", "down"] = "up") -> list[torch.Tensor]: + """Return activation values for the given row(s).""" + rows = [rows] if isinstance(rows, int) else rows + values: list[torch.Tensor] = [] + for idx, row_idx in enumerate(rows): + row_info = self.row_info[row_idx] + if row_info[1] == "logits": + values.append(self.cache_activations["logits"][idx, -1, row_info[2]]) + else: + values.append( + self.cache_activations[row_info[1] + f".feature_acts.{feature_type}"][idx, row_info[0], row_info[2]] + ) + return values + + def get_column_values( + self, columns: int | list[int], feature_type: Literal["up", "down"] = "down" + ) -> list[torch.Tensor]: + columns = [columns] if isinstance(columns, int) else columns + values = [] + for idx, column_idx in enumerate(columns): + column_info = self.column_info[column_idx] + if isinstance(column_info[2], str): + values.append(self.cache_activations[column_info[1]][idx, column_info[0]]) + else: + values.append( + self.cache_activations[column_info[1] + f".feature_acts.{feature_type}"][ + idx, column_info[0], column_info[2] + ] + ) + return values + + def update_row_info_from_columns(self, column_ids: int | list[int]) -> None: + column_ids = [column_ids] if isinstance(column_ids, int) else column_ids + self.row_info += [cast(tuple[int, str, int], self.column_info[column_id]) for column_id in column_ids] + self.row2column += column_ids + + +def _get_cache_activations(cache_activations: dict[str, torch.Tensor], suffix: str) -> list[torch.Tensor]: + return [v for k, v in cache_activations.items() if k.endswith(suffix)] + + +def _clear_cache_activation_grads(cache_activations: dict[str, torch.Tensor]) -> None: + """Clear gradients for all cached activations in-place.""" + for activation in cache_activations.values(): + if activation.grad is not None: + activation.grad = None + + +def update_edge_matrix_rows(edge_matrix: EdgeMatrix) -> None: + cur_batch_size = edge_matrix.shape[0] + column_start = 0 + assert edge_matrix.cache_activations["hook_embed"].grad is not None, "hook_embed must have a gradient" + edge_matrix[:, column_start : column_start + edge_matrix.n_tokens] = ( + einops.einsum( + edge_matrix.cache_activations["hook_embed"][:cur_batch_size], + edge_matrix.cache_activations["hook_embed"].grad[:cur_batch_size], + "b pos d_model, b pos d_model -> b pos", + ) + .detach() + .cpu() + ) + column_start += edge_matrix.n_tokens + + for feature_acts_down in _get_cache_activations(edge_matrix.cache_activations, ".feature_acts.down"): + assert feature_acts_down.grad is not None, "feature_acts_down must have a gradient" + attribution = feature_acts_down[:cur_batch_size] * feature_acts_down.grad[:cur_batch_size] # [b, pos, d_sae] + active_mask = feature_acts_down[0] > 0 # [pos, d_sae] + pos_ids, feat_ids = active_mask.nonzero(as_tuple=True) + n_active = pos_ids.shape[0] + edge_matrix[:, column_start : column_start + n_active] = ( + attribution[:, pos_ids, feat_ids].detach().cpu() # [b, n_active] + ) + column_start += n_active + + for error in _get_cache_activations(edge_matrix.cache_activations, ".error"): + assert error.grad is not None, "error must have a gradient" + edge_matrix[:, column_start : column_start + error.shape[1]] = ( + einops.einsum( + error[:cur_batch_size], + error.grad[:cur_batch_size], + "b pos ..., b pos ... -> b pos", + ) + .detach() + .cpu() + ) + column_start += error.shape[1] + + assert column_start == edge_matrix.shape[1], "column_start must be equal to the number of columns" + + class TransformerLensLanguageModel(LanguageModel): def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = None): self.cfg = cfg @@ -364,100 +585,6 @@ def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = No else None ) - def _get_cache_activations(self, cache_activations: dict[str, torch.Tensor], suffix: str) -> list[torch.Tensor]: - return [v for k, v in cache_activations.items() if k.endswith(suffix)] - - def _clear_cache_activation_grads(self, cache_activations: dict[str, torch.Tensor]) -> None: - """Clear gradients for all cached activations in-place.""" - for activation in cache_activations.values(): - if activation.grad is not None: - activation.grad = None - - def _update_edge_matrix_from_cache( - self, - edge_rows: torch.Tensor, - token_embed: torch.Tensor, # batch, pos, d_model - feature_acts_down: list[torch.Tensor], - errors: list[torch.Tensor], - ) -> None: - assert token_embed.grad is not None, "token_embed must have a gradient" - col_start = 0 - token_block = ( - einops.einsum( - token_embed[: edge_rows.shape[0]], - token_embed.grad[: edge_rows.shape[0]], - "b pos d_model, b pos d_model -> b pos", - ) - .detach() - .cpu() - ) - edge_rows[:, col_start : col_start + token_embed.shape[1]] = token_block - col_start += token_embed.shape[1] - - for acts in feature_acts_down: - assert acts.grad is not None, "feature_acts_down must have a gradient" - grad = acts.grad.coalesce().to_dense() if acts.grad.is_sparse else acts.grad - indices = acts.indices() - mask = indices[0] < edge_rows.shape[0] - idx = indices[:, mask] - values = acts.values()[mask].reshape(edge_rows.shape[0], -1) - grad_values = grad[idx[0], idx[1], idx[2]].reshape(edge_rows.shape[0], -1) - products = einops.einsum(values, grad_values, "b k, b k -> b k").detach().cpu() - edge_rows[:, col_start : col_start + products.shape[1]] = products - col_start += products.shape[1] - - for error in errors: - assert error.grad is not None, "error must have a gradient" - edge_rows[:, col_start : col_start + error.shape[1]] = ( - einops.einsum( - error[: edge_rows.shape[0]], - error.grad[: edge_rows.shape[0]], - "b pos ..., b pos ... -> b pos", - ) - .detach() - .cpu() - ) - col_start += error.shape[1] - - assert col_start == edge_rows.shape[1], ( - f"col_start {col_start} must equal edge_rows.shape[1] {edge_rows.shape[1]}" - ) - - def column_id_to_info( - self, - cache_activations: dict[str, torch.Tensor], - ) -> list[tuple[int, int | None, str, int | None]]: - """Return ``(pos, layer, hook_name, dense_idx)`` for each active feature, error, and embedding.""" - result: list[tuple[int, int | None, str, int | None]] = [] - - if "hook_embed" in cache_activations: - embed = cache_activations["hook_embed"] - for p in range(embed.shape[1]): - result.append((p, None, "hook_embed", None)) - - for key, acts in cache_activations.items(): - if not key.endswith(".feature_acts.up"): - continue - hook_name = key.removesuffix(".feature_acts.up") - layer = int(hook_name.split(".")[1]) - idx = acts.indices() - if idx.shape[0] == 3: # (batch, pos, d_sae) - mask = idx[0] == 0 - pos, dense = idx[1, mask].tolist(), idx[2, mask].tolist() - else: # (pos, d_sae) - pos, dense = idx[0].tolist(), idx[1].tolist() - result += [(p, layer, hook_name, d) for p, d in zip(pos, dense)] - - for key, error in cache_activations.items(): - if not key.endswith(".error"): - continue - hook_name = key.removesuffix(".error") - layer = int(hook_name.split(".")[1]) - for p in range(error.shape[1]): - result.append((p, layer, hook_name, None)) - - return result - def attribute( self, inputs: torch.Tensor | str, @@ -465,12 +592,11 @@ def attribute( max_n_logits: int = 10, desired_logit_prob: float = 0.95, batch_size: int = 512, - max_feature_nodes: Optional[int] = None, + max_features: int | None = None, ): assert self.model is not None, "model must be initialized" cache_activations: dict[str, torch.Tensor] = {} tokens = ensure_tokenized(inputs, self.tokenizer, device=self.device) - n_tokens = tokens.shape[0] fwd_hooks_in: list[tuple[Union[str, Callable], Callable]] = [ ( hook_in, @@ -513,77 +639,69 @@ def token_fwd_hook_fn( ) ] with self.hooks(fwd_hooks=fwd_hooks_in + fwd_hooks_out + token_fwd_hooks): - batch_logits = self.forward(einops.repeat(tokens, "n -> b n", b=batch_size))[:, -1] # batch, d_vocab - - feature_acts_down = self._get_cache_activations(cache_activations, ".feature_acts.down") - errors = self._get_cache_activations(cache_activations, ".error") - - column_id_to_info = self.column_id_to_info(cache_activations) + batch_logits = self.forward(einops.repeat(tokens, "n -> b n", b=batch_size)) with torch.no_grad(): - probs = torch.softmax(batch_logits[0], dim=-1) + probs = torch.softmax(batch_logits[0, -1], dim=-1) top_p, top_idx = torch.topk(probs, max_n_logits) cutoff = int(torch.searchsorted(torch.cumsum(top_p, 0), desired_logit_prob)) + 1 top_p, top_idx = top_p[:cutoff], top_idx[:cutoff] - n_logits = len(top_idx) - num_error_nodes = len(errors) * n_tokens - num_active_features = len(column_id_to_info) - n_tokens - num_error_nodes - feature_nodes = min(max_feature_nodes or num_active_features, num_active_features) - edge_matrix = torch.zeros( - feature_nodes + n_logits, n_tokens + num_active_features + num_error_nodes - ) # [downstream_nodes, upstream_nodes], on cpu, [features+logits, tokens+features+errors] - - for i in range(0, n_logits, batch_size): - cur_batch_size = min(batch_size, n_logits - i) - batch_nodes = batch_logits[:, top_idx[i : i + cur_batch_size]] - torch.mean(batch_logits[0], dim=-1) - self._clear_cache_activation_grads(cache_activations) - batch_nodes.diagonal().sum().backward(retain_graph=True) - self._update_edge_matrix_from_cache( - edge_rows=edge_matrix[feature_nodes + i : feature_nodes + i + cur_batch_size, :], - token_embed=cache_activations["hook_embed"], - feature_acts_down=feature_acts_down, - errors=errors, - ) + cache_activations["logits"] = batch_logits[:, :, top_idx] + + edge_matrix = EdgeMatrix( + cache_activations, max_features=max_features + ) # [downstream_nodes, upstream_nodes], on cpu, [logits+max_features, tokens+features+errors] - def get_normalize_edge_maxtrix(edge_matrix: torch.Tensor): - return torch.abs(edge_matrix) / torch.abs(edge_matrix).sum(dim=1, keepdim=True).clamp(min=1e-8) + for i in range(0, edge_matrix.n_logits, batch_size): + cur_batch_size = min(batch_size, edge_matrix.n_logits - i) + batch_nodes: torch.Tensor = torch.stack(edge_matrix.get_row_values(list(range(i, i + cur_batch_size)))) + batch_nodes -= torch.mean(batch_logits[:cur_batch_size, -1, :], dim=-1).squeeze() # [cur_batch_size, ] + _clear_cache_activation_grads(cache_activations) + batch_nodes.sum().backward(retain_graph=True) + update_edge_matrix_rows(cast(EdgeMatrix, edge_matrix[i : i + cur_batch_size])) + + def get_normalize_edge_matrix(matrix: torch.Tensor): + return torch.abs(matrix) / torch.abs(matrix).sum(dim=1, keepdim=True).clamp(min=1e-8) - feature_row_to_col = torch.full((feature_nodes,), -1, dtype=torch.int64, device="cpu") features_attributions = einops.einsum( - get_normalize_edge_maxtrix(edge_matrix[feature_nodes:, :]), + get_normalize_edge_matrix( + edge_matrix[ + : edge_matrix.n_logits, edge_matrix.n_tokens : edge_matrix.n_tokens + edge_matrix.n_active_features + ] + ), top_p, - "l t, l -> t", + "logits features, logits -> features", ) - for i in range(0, feature_nodes, batch_size): - cur_batch = min(batch_size, feature_nodes - i) - already_assigned = feature_row_to_col[feature_row_to_col != -1] - features_attributions[already_assigned] = float("-inf") + visited_features: torch.Tensor | None = None + for i in range(0, edge_matrix.max_features, batch_size): + cur_batch = min(batch_size, edge_matrix.max_features - i) + if visited_features is not None: + features_attributions[visited_features] = float("-inf") _, batch_feature_ids = torch.topk(features_attributions, cur_batch) - feature_row_to_col[i : i + cur_batch] = batch_feature_ids - self._clear_cache_activation_grads(cache_activations) - bwd = None - for batch_idx, feat_id in enumerate(batch_feature_ids.tolist()): - info: Any = column_id_to_info[feat_id] - acts = cache_activations[info[2] + ".feature_acts.up"] - idx = acts.indices() - mask = (idx[0] == batch_idx) & (idx[1] == info[0]) & (idx[2] == info[3]) - bwd = acts.values()[mask][0] if bwd is None else bwd + acts.values()[mask][0] - assert bwd is not None, "bwd must not be None" - bwd.backward(retain_graph=True) - self._update_edge_matrix_from_cache( - edge_rows=edge_matrix[i : i + cur_batch, :], - token_embed=cache_activations["hook_embed"], - feature_acts_down=feature_acts_down, - errors=errors, + batch_column_ids = batch_feature_ids + edge_matrix.n_tokens + _clear_cache_activation_grads(cache_activations) + bwd = edge_matrix.get_column_values(batch_column_ids.tolist(), "up") + torch.stack(bwd).sum().backward(retain_graph=True) + update_edge_matrix_rows( + cast(EdgeMatrix, edge_matrix[i + edge_matrix.n_logits : i + cur_batch + edge_matrix.n_logits]) + ) + edge_matrix.update_row_info_from_columns(batch_column_ids.tolist()) + visited_features = ( + batch_feature_ids if visited_features is None else torch.cat([visited_features, batch_feature_ids]) ) features_attributions += einops.einsum( - get_normalize_edge_maxtrix(edge_matrix[feature_nodes:, batch_feature_ids]), - get_normalize_edge_maxtrix(edge_matrix[i : i + cur_batch, :]), + get_normalize_edge_matrix(edge_matrix[: edge_matrix.n_logits, batch_column_ids]), + get_normalize_edge_matrix( + edge_matrix[ + i + edge_matrix.n_logits : i + cur_batch + edge_matrix.n_logits, + edge_matrix.n_tokens : edge_matrix.n_tokens + edge_matrix.n_active_features, + ] + ), top_p, - "l b, b t, l -> t", + "logits batch_features, batch_features features, logits -> features", ) - return edge_matrix, cache_activations, column_id_to_info + return edge_matrix @property def eos_token_id(self) -> int | None: From b6046c1a9ee4a5b9b359e07d1b7bac5f75e800b9 Mon Sep 17 00:00:00 2001 From: frankstein Date: Tue, 3 Mar 2026 16:20:50 +0800 Subject: [PATCH 5/7] refactor(sparse-dictionary): add hooks_in and hooks_out properties to SparseDictionaryConfig --- src/lm_saes/sparse_dictionary.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/lm_saes/sparse_dictionary.py b/src/lm_saes/sparse_dictionary.py index cd200add..2dd54628 100644 --- a/src/lm_saes/sparse_dictionary.py +++ b/src/lm_saes/sparse_dictionary.py @@ -188,6 +188,18 @@ def associated_hook_points(self) -> list[str]: """List of hook points used by the sparse dictionary, including all input and label hook points. This is used to retrieve useful data from the input activation source.""" raise NotImplementedError("Subclasses must implement this method") + @property + @abstractmethod + def hooks_in(self) -> list[str]: + """List of hook points used by the sparse dictionary, including all input hook points. This is used to retrieve useful data from the input activation source.""" + raise NotImplementedError("Subclasses must implement this method") + + @property + @abstractmethod + def hooks_out(self) -> list[str]: + """List of hook points used by the sparse dictionary, including all output hook points. This is used to retrieve useful data from the output activation source.""" + raise NotImplementedError("Subclasses must implement this method") + class SparseDictionary(HookedRootModule, ABC): """Abstract base class for all sparse dictionary models. From e6685883d84108210116ea18a9d22fa0da00f2b6 Mon Sep 17 00:00:00 2001 From: frankstein Date: Tue, 3 Mar 2026 16:21:05 +0800 Subject: [PATCH 6/7] refactor(language-model): update hook functions to use SparseDictionary instead of AbstractSparseAutoEncoder --- src/lm_saes/backend/language_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lm_saes/backend/language_model.py b/src/lm_saes/backend/language_model.py index 951f7aef..3c40f28b 100644 --- a/src/lm_saes/backend/language_model.py +++ b/src/lm_saes/backend/language_model.py @@ -28,7 +28,7 @@ from lm_saes.backend.run_with_cache_until import run_with_cache_until from lm_saes.config import BaseModelConfig from lm_saes.lorsa import LowRankSparseAttention -from lm_saes.sae import AbstractSparseAutoEncoder, SparseAutoEncoder +from lm_saes.sae import SparseAutoEncoder, SparseDictionary from lm_saes.utils.auto import PretrainedSAEType, auto_infer_pretrained_sae_type from lm_saes.utils.distributed import DimMap from lm_saes.utils.misc import ensure_tokenized, item, pad_and_truncate_tokens @@ -237,13 +237,13 @@ def pad_token_id(self) -> int | None: pass -def hook_in_fn_builder(replacement_module: AbstractSparseAutoEncoder) -> Callable: +def hook_in_fn_builder(replacement_module: SparseDictionary) -> Callable: if isinstance(replacement_module, SparseAutoEncoder): def hook_in_fn( x: torch.Tensor, hook: str, - replacement_module: AbstractSparseAutoEncoder, + replacement_module: SparseDictionary, cache_activations: dict[str, torch.Tensor], ): assert hook in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" @@ -284,11 +284,11 @@ def hook_in_fn( raise ValueError(f"Unsupported replacement module type: {type(replacement_module)}") -def hook_out_fn_builder(replacement_module: AbstractSparseAutoEncoder) -> Callable: +def hook_out_fn_builder(replacement_module: SparseDictionary) -> Callable: def hook_out_fn( x: torch.Tensor, hook: str, - replacement_module: AbstractSparseAutoEncoder, + replacement_module: SparseDictionary, cache_activations: dict[str, torch.Tensor], update_error_cache: bool = False, ): @@ -588,7 +588,7 @@ def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = No def attribute( self, inputs: torch.Tensor | str, - replacement_modules: list[AbstractSparseAutoEncoder], + replacement_modules: list[SparseDictionary], max_n_logits: int = 10, desired_logit_prob: float = 0.95, batch_size: int = 512, From 26b6076797d74d4be08063b6588090c53e9024ea Mon Sep 17 00:00:00 2001 From: frankstein Date: Tue, 3 Mar 2026 19:32:25 +0800 Subject: [PATCH 7/7] refactor(language-model): enhance hook functions to utilize HookPoint and improve type hints --- src/lm_saes/backend/language_model.py | 166 ++++++++++++++++---------- 1 file changed, 106 insertions(+), 60 deletions(-) diff --git a/src/lm_saes/backend/language_model.py b/src/lm_saes/backend/language_model.py index 3c40f28b..75d671fa 100644 --- a/src/lm_saes/backend/language_model.py +++ b/src/lm_saes/backend/language_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import re @@ -6,7 +8,7 @@ from contextlib import contextmanager from functools import partial from itertools import accumulate -from typing import Any, Callable, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast import einops import torch @@ -16,7 +18,9 @@ from torch.distributed import DeviceMesh from torch.distributed.tensor import DTensor from torch.distributed.tensor.experimental import local_map +from tqdm import tqdm from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookPoint from transformers import ( AutoModelForCausalLM, AutoProcessor, @@ -27,13 +31,15 @@ from lm_saes.backend.run_with_cache_until import run_with_cache_until from lm_saes.config import BaseModelConfig -from lm_saes.lorsa import LowRankSparseAttention -from lm_saes.sae import SparseAutoEncoder, SparseDictionary from lm_saes.utils.auto import PretrainedSAEType, auto_infer_pretrained_sae_type from lm_saes.utils.distributed import DimMap from lm_saes.utils.misc import ensure_tokenized, item, pad_and_truncate_tokens from lm_saes.utils.timer import timer +if TYPE_CHECKING: + from lm_saes.lorsa import LowRankSparseAttention + from lm_saes.sae import SparseDictionary + def to_tokens(tokenizer, text, max_length, device="cpu", prepend_bos=True): tokens = tokenizer( @@ -238,19 +244,23 @@ def pad_token_id(self) -> int | None: def hook_in_fn_builder(replacement_module: SparseDictionary) -> Callable: + from lm_saes.lorsa import LowRankSparseAttention + from lm_saes.sae import SparseAutoEncoder + if isinstance(replacement_module, SparseAutoEncoder): def hook_in_fn( x: torch.Tensor, - hook: str, + hook: HookPoint, replacement_module: SparseDictionary, cache_activations: dict[str, torch.Tensor], ): - assert hook in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" - cache_activations[hook + ".x"] = x - cache_activations[hook + ".feature_acts.up"] = replacement_module.encode(x) - cache_activations[hook + ".feature_acts.down"] = cache_activations[hook + ".feature_acts.up"].detach() - cache_activations[hook + ".feature_acts.down"].requires_grad_(True) + assert hook.name in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" + cache_activations[hook.name + ".x"] = x + cache_activations[hook.name + ".feature_acts.up"] = replacement_module.encode(x) + cache_activations[hook.name + ".feature_acts.down"] = ( + cache_activations[hook.name + ".feature_acts.up"].detach().requires_grad_(True) + ) return x.detach() return hook_in_fn @@ -259,22 +269,24 @@ def hook_in_fn( def hook_in_fn( x: torch.Tensor, - hook: str, + hook: HookPoint, replacement_module: LowRankSparseAttention, cache_activations: dict[str, torch.Tensor], ): - assert hook in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" - cache_activations[hook + ".x"] = x + assert hook.name in replacement_module.cfg.hooks_in, "Hook point must be in hook points in" + cache_activations[hook.name + ".x"] = x encode_result = replacement_module.encode( x, return_hidden_pre=False, return_attention_pattern=True, return_attention_score=True, ) - cache_activations[hook + ".attn_pattern"] = encode_result[1].detach() - cache_activations[hook + ".feature_acts.up"] = encode_result[0] # batch, seq_len, d_sae - cache_activations[hook + ".feature_acts.down"] = cache_activations[hook + ".feature_acts.up"].detach() - cache_activations[hook + ".feature_acts.down"].requires_grad_(True) + cache_activations[hook.name + ".attn_pattern"] = encode_result[1].detach() + cache_activations[hook.name + ".feature_acts.up"] = encode_result[0] # batch, seq_len, d_sae + cache_activations[hook.name + ".feature_acts.down"] = ( + cache_activations[hook.name + ".feature_acts.up"].detach().requires_grad_(True) + ) + return x.detach() return hook_in_fn @@ -287,25 +299,24 @@ def hook_in_fn( def hook_out_fn_builder(replacement_module: SparseDictionary) -> Callable: def hook_out_fn( x: torch.Tensor, - hook: str, + hook: HookPoint, replacement_module: SparseDictionary, cache_activations: dict[str, torch.Tensor], update_error_cache: bool = False, ): - assert hook in replacement_module.cfg.hooks_out, "Hook point must be in hook points out" + assert hook.name in replacement_module.cfg.hooks_out, "Hook point must be in hook points out" hook_in = replacement_module.cfg.hooks_in[0] # TODO: handle multiple hook points in for CLT reconstructed = replacement_module.decode(cache_activations[hook_in + ".feature_acts.down"]) - cache_activations[hook + ".reconstructed"] = reconstructed - assert hook + ".error" in cache_activations or update_error_cache, ( + cache_activations[hook.name + ".reconstructed"] = reconstructed + assert hook.name + ".error" in cache_activations or update_error_cache, ( "There must be an error cache for the hook point" ) if update_error_cache: - error = x - reconstructed - cache_activations[hook + ".error"] = error + error = (x - reconstructed).detach().requires_grad_(True) + cache_activations[hook.name + ".error"] = error else: - error = cache_activations[hook + ".error"] - error = error.detach() - error.requires_grad_(True) + error = cache_activations[hook.name + ".error"] + return reconstructed + error return hook_out_fn @@ -327,6 +338,23 @@ class EdgeMatrix(torch.Tensor): n_error: int max_features: int + _wrap_enabled: bool = True + + @classmethod + @contextmanager + def no_wrap(cls): + """Context manager that disables automatic wrapping in ``__torch_dispatch__``. + + Any torch operation on an ``EdgeMatrix`` inside this context will return + a plain ``torch.Tensor`` instead of being re-wrapped. + """ + prev = cls._wrap_enabled + cls._wrap_enabled = False + try: + yield + finally: + cls._wrap_enabled = prev + @classmethod def _wrap( cls, @@ -391,6 +419,9 @@ def unwrap(x: Any) -> Any: out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + if not cls._wrap_enabled: + return out + def wrap(x: Any) -> Any: if isinstance(x, torch.Tensor) and not isinstance(x, EdgeMatrix) and source is not None: return cls._wrap( @@ -509,25 +540,36 @@ def update_edge_matrix_rows(edge_matrix: EdgeMatrix) -> None: for feature_acts_down in _get_cache_activations(edge_matrix.cache_activations, ".feature_acts.down"): assert feature_acts_down.grad is not None, "feature_acts_down must have a gradient" - attribution = feature_acts_down[:cur_batch_size] * feature_acts_down.grad[:cur_batch_size] # [b, pos, d_sae] - active_mask = feature_acts_down[0] > 0 # [pos, d_sae] - pos_ids, feat_ids = active_mask.nonzero(as_tuple=True) - n_active = pos_ids.shape[0] - edge_matrix[:, column_start : column_start + n_active] = ( - attribution[:, pos_ids, feat_ids].detach().cpu() # [b, n_active] - ) + if feature_acts_down.grad is not None: + attribution = ( + feature_acts_down[:cur_batch_size] * feature_acts_down.grad[:cur_batch_size] + ) # [b, pos, d_sae] + active_mask = feature_acts_down[0] > 0 # [pos, d_sae] + pos_ids, feat_ids = active_mask.nonzero(as_tuple=True) + n_active = pos_ids.shape[0] + edge_matrix[:, column_start : column_start + n_active] = ( + attribution[:, pos_ids, feat_ids].detach().cpu() # [b, n_active] + ) + else: + n_active = item((feature_acts_down[0] > 0).sum()) + edge_matrix[:, column_start : column_start + n_active] = ( + torch.zeros_like(edge_matrix[:, column_start : column_start + n_active]).detach().cpu() + ) column_start += n_active for error in _get_cache_activations(edge_matrix.cache_activations, ".error"): - assert error.grad is not None, "error must have a gradient" edge_matrix[:, column_start : column_start + error.shape[1]] = ( - einops.einsum( - error[:cur_batch_size], - error.grad[:cur_batch_size], - "b pos ..., b pos ... -> b pos", + ( + einops.einsum( + error[:cur_batch_size], + error.grad[:cur_batch_size], + "b pos ..., b pos ... -> b pos", + ) + .detach() + .cpu() ) - .detach() - .cpu() + if error.grad is not None + else torch.zeros_like(edge_matrix[:, column_start : column_start + error.shape[1]]).detach().cpu() ) column_start += error.shape[1] @@ -663,17 +705,20 @@ def token_fwd_hook_fn( def get_normalize_edge_matrix(matrix: torch.Tensor): return torch.abs(matrix) / torch.abs(matrix).sum(dim=1, keepdim=True).clamp(min=1e-8) - features_attributions = einops.einsum( - get_normalize_edge_matrix( - edge_matrix[ - : edge_matrix.n_logits, edge_matrix.n_tokens : edge_matrix.n_tokens + edge_matrix.n_active_features - ] - ), - top_p, - "logits features, logits -> features", - ) + with EdgeMatrix.no_wrap(): + features_attributions = einops.einsum( + get_normalize_edge_matrix( + edge_matrix[ + : edge_matrix.n_logits, + edge_matrix.n_tokens : edge_matrix.n_tokens + edge_matrix.n_active_features, + ] + ).to(self.device), + top_p, + "logits features, logits -> features", + ) visited_features: torch.Tensor | None = None - for i in range(0, edge_matrix.max_features, batch_size): + + for i in tqdm(range(0, edge_matrix.max_features, batch_size)): cur_batch = min(batch_size, edge_matrix.max_features - i) if visited_features is not None: features_attributions[visited_features] = float("-inf") @@ -689,17 +734,18 @@ def get_normalize_edge_matrix(matrix: torch.Tensor): visited_features = ( batch_feature_ids if visited_features is None else torch.cat([visited_features, batch_feature_ids]) ) - features_attributions += einops.einsum( - get_normalize_edge_matrix(edge_matrix[: edge_matrix.n_logits, batch_column_ids]), - get_normalize_edge_matrix( - edge_matrix[ - i + edge_matrix.n_logits : i + cur_batch + edge_matrix.n_logits, - edge_matrix.n_tokens : edge_matrix.n_tokens + edge_matrix.n_active_features, - ] - ), - top_p, - "logits batch_features, batch_features features, logits -> features", - ) + with EdgeMatrix.no_wrap(): + features_attributions += einops.einsum( + get_normalize_edge_matrix(edge_matrix[: edge_matrix.n_logits, batch_column_ids]).to(self.device), + get_normalize_edge_matrix( + edge_matrix[ + i + edge_matrix.n_logits : i + cur_batch + edge_matrix.n_logits, + edge_matrix.n_tokens : edge_matrix.n_tokens + edge_matrix.n_active_features, + ] + ).to(self.device), + top_p, + "logits batch_features, batch_features features, logits -> features", + ) return edge_matrix