Skip to content
471 changes: 467 additions & 4 deletions src/lm_saes/backend/language_model.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/lm_saes/circuit/utils/attribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def select_feature_activations(
return torch.stack(activations)


# TODO: remove this function
def ensure_tokenized(prompt: Union[str, torch.Tensor, List[int]], tokenizer) -> torch.Tensor:
"""Convert *prompt* → 1-D tensor of token ids (no batch dim)."""

Expand Down
8 changes: 8 additions & 0 deletions src/lm_saes/clt.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ def associated_hook_points(self) -> list[str]:
"""All hook points used by the CLT."""
return self.hook_points_in + self.hook_points_out

@property
def hooks_in(self) -> list[str]:
    """Hook points whose activations the CLT reads as inputs."""
    return self.hook_points_in

@property
def hooks_out(self) -> list[str]:
    """Hook points whose activations the CLT reconstructs as outputs."""
    return self.hook_points_out

def model_post_init(self, __context):
super().model_post_init(__context)
assert len(self.hook_points_in) == len(self.hook_points_out), (
Expand Down
8 changes: 8 additions & 0 deletions src/lm_saes/crosscoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ class CrossCoderConfig(SparseDictionaryConfig):
def associated_hook_points(self) -> list[str]:
return self.hook_points

@property
def hooks_in(self) -> list[str]:
    """Input hook points. For a crosscoder every hook point serves as both
    input and output, so this is the full ``hook_points`` list."""
    return self.hook_points

@property
def hooks_out(self) -> list[str]:
    """Output hook points. Identical to ``hooks_in`` — a crosscoder reads and
    reconstructs the same set of hook points."""
    return self.hook_points

@property
def n_heads(self) -> int:
return len(self.hook_points)
Expand Down
8 changes: 8 additions & 0 deletions src/lm_saes/lorsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def associated_hook_points(self) -> list[str]:
"""All hook points used by Lorsa."""
return [self.hook_point_in, self.hook_point_out]

@property
def hooks_in(self) -> list[str]:
    """Input hook points, as a single-element list for interface uniformity."""
    return [self.hook_point_in]

@property
def hooks_out(self) -> list[str]:
    """Output hook points, as a single-element list for interface uniformity."""
    return [self.hook_point_out]

def model_post_init(self, __context):
super().model_post_init(__context)
assert self.hook_point_in is not None and self.hook_point_out is not None, (
Expand Down
8 changes: 8 additions & 0 deletions src/lm_saes/molt.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ def num_rank_types(self) -> int:
def associated_hook_points(self) -> list[str]:
return [self.hook_point_in, self.hook_point_out]

@property
def hooks_in(self) -> list[str]:
    """Input hook points, as a single-element list for interface uniformity."""
    return [self.hook_point_in]

@property
def hooks_out(self) -> list[str]:
    """Output hook points, as a single-element list for interface uniformity."""
    return [self.hook_point_out]


@register_sae_model("molt")
class MixtureOfLinearTransform(SparseDictionary):
Expand Down
8 changes: 8 additions & 0 deletions src/lm_saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ class SAEConfig(SparseDictionaryConfig):
def associated_hook_points(self) -> list[str]:
return [self.hook_point_in, self.hook_point_out]

@property
def hooks_in(self) -> list[str]:
    """Input hook points, as a single-element list for interface uniformity."""
    return [self.hook_point_in]

@property
def hooks_out(self) -> list[str]:
    """Output hook points, as a single-element list for interface uniformity."""
    return [self.hook_point_out]


@register_sae_model("sae")
class SparseAutoEncoder(SparseDictionary):
Expand Down
12 changes: 12 additions & 0 deletions src/lm_saes/sparse_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ def associated_hook_points(self) -> list[str]:
"""List of hook points used by the sparse dictionary, including all input and label hook points. This is used to retrieve useful data from the input activation source."""
raise NotImplementedError("Subclasses must implement this method")

@property
@abstractmethod
def hooks_in(self) -> list[str]:
    """All input hook points of the sparse dictionary. Used to retrieve the
    input activations from the activation source."""
    raise NotImplementedError("Subclasses must implement this method")

@property
@abstractmethod
def hooks_out(self) -> list[str]:
    """All output hook points of the sparse dictionary. Used to retrieve the
    output (label) activations from the activation source."""
    raise NotImplementedError("Subclasses must implement this method")


class SparseDictionary(HookedRootModule, ABC):
"""Abstract base class for all sparse dictionary models.
Expand Down
14 changes: 14 additions & 0 deletions src/lm_saes/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,17 @@ def get_slice_length(s: slice, length: int):
start, stop, step = s.indices(length)
length = (stop - start + step - 1) // step
return length


def ensure_tokenized(
prompt: str | torch.Tensor | list[int], tokenizer, device: torch.device | str = "cpu"
) -> torch.Tensor:
"""Convert *prompt* → 1-D tensor of token ids (no batch dim)."""

if isinstance(prompt, str):
return tokenizer(prompt, return_tensors="pt").input_ids[0].to(device)
if isinstance(prompt, torch.Tensor):
return prompt.squeeze(0).to(device) if prompt.ndim == 2 else prompt.to(device)
if isinstance(prompt, list):
return torch.tensor(prompt, dtype=torch.long, device=device)
raise TypeError(f"Unsupported prompt type: {type(prompt)}")