Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyrit/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

"""Registry module for PyRIT class and instance registries."""

from pyrit.identifiers import Identifier, class_name_to_snake_case, snake_case_to_class_name
from pyrit.registry.base import RegistryProtocol
from pyrit.registry.class_registries import (
BaseClassRegistry,
Expand All @@ -28,15 +27,12 @@
"BaseClassRegistry",
"BaseInstanceRegistry",
"ClassEntry",
"class_name_to_snake_case",
"discover_in_directory",
"discover_in_package",
"discover_subclasses_in_loaded_modules",
"Identifier",
"InitializerMetadata",
"InitializerRegistry",
"RegistryProtocol",
"snake_case_to_class_name",
"ScenarioMetadata",
"ScenarioRegistry",
"ScorerRegistry",
Expand Down
31 changes: 31 additions & 0 deletions pyrit/registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,43 @@
and instance registries (which store T instances).
"""

from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Protocol, TypeVar, runtime_checkable

from pyrit.identifiers.class_name_utils import class_name_to_snake_case

# Type variable for metadata (invariant for Protocol compatibility)
MetadataT = TypeVar("MetadataT")


@dataclass(frozen=True)
class ClassRegistryEntry:
"""
Minimal base for class-level registry metadata.

Provides the common fields every registry metadata type needs for display,
lookup, and filtering in class registries.

Attributes:
class_name (str): Python class name (e.g., "ContentHarmsScenario").
class_module (str): Full module path (e.g., "pyrit.scenario.scenarios.content_harms").
class_description (str): Human-readable description, typically from the class docstring.
"""

class_name: str
class_module: str
class_description: str = ""

@property
def snake_class_name(self) -> str:
"""
Snake_case version of class_name (e.g., "content_harms_scenario").

Used by CLI formatting and as registry display keys.
"""
return class_name_to_snake_case(self.class_name)


@runtime_checkable
class RegistryProtocol(Protocol[MetadataT]):
"""
Expand Down
31 changes: 0 additions & 31 deletions pyrit/registry/class_registries/base_class_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar

from pyrit.identifiers import Identifier
from pyrit.identifiers.class_name_utils import class_name_to_snake_case
from pyrit.registry.base import RegistryProtocol

Expand Down Expand Up @@ -183,36 +182,6 @@ def _build_metadata(self, name: str, entry: ClassEntry[T]) -> MetadataT:
"""
pass

def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> Identifier:
"""
Build the common base metadata for a registered class.

This helper extracts fields common to all registries: name, class_name, class_description.
Subclasses can use this for building common fields if needed.

Args:
name: The registry name (snake_case identifier).
entry: The ClassEntry containing the registered class.

Returns:
An Identifier dataclass with common fields.
"""
registered_class = entry.registered_class

# Extract description from docstring, clean up whitespace
doc = registered_class.__doc__ or ""
if doc:
description = " ".join(doc.split())
else:
description = entry.description or "No description available"

return Identifier(
identifier_type="class",
class_name=registered_class.__name__,
class_module=registered_class.__module__,
class_description=description,
)

def get_class(self, name: str) -> Type[T]:
"""
Get a registered class by name.
Expand Down
11 changes: 7 additions & 4 deletions pyrit/registry/class_registries/initializer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional

from pyrit.identifiers import Identifier
from pyrit.registry.base import ClassRegistryEntry
from pyrit.registry.class_registries.base_class_registry import (
BaseClassRegistry,
ClassEntry,
Expand All @@ -34,15 +34,20 @@


@dataclass(frozen=True)
class InitializerMetadata(Identifier):
class InitializerMetadata(ClassRegistryEntry):
"""
Metadata describing a registered PyRITInitializer class.

Use get_class() to get the actual class.
"""

# Human-readable display name (e.g., "Objective Target Setup").
display_name: str = field(kw_only=True)

# Environment variables required by the initializer.
required_env_vars: tuple[str, ...] = field(kw_only=True)

# Execution order priority (lower = earlier).
execution_order: int = field(kw_only=True)


Expand Down Expand Up @@ -208,7 +213,6 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I
try:
instance = initializer_class()
return InitializerMetadata(
identifier_type="class",
class_name=initializer_class.__name__,
class_module=initializer_class.__module__,
class_description=instance.description,
Expand All @@ -219,7 +223,6 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I
except Exception as e:
logger.warning(f"Failed to get metadata for {name}: {e}")
return InitializerMetadata(
identifier_type="class",
class_name=initializer_class.__name__,
class_module=initializer_class.__module__,
class_description="Error loading initializer metadata",
Expand Down
16 changes: 12 additions & 4 deletions pyrit/registry/class_registries/scenario_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from pyrit.identifiers import Identifier
from pyrit.identifiers.class_name_utils import class_name_to_snake_case
from pyrit.registry.base import ClassRegistryEntry
from pyrit.registry.class_registries.base_class_registry import (
BaseClassRegistry,
ClassEntry,
Expand All @@ -33,17 +33,26 @@


@dataclass(frozen=True)
class ScenarioMetadata(Identifier):
class ScenarioMetadata(ClassRegistryEntry):
"""
Metadata describing a registered Scenario class.

Use get_class() to get the actual class.
"""

# The default strategy name (e.g., "single_turn")
default_strategy: str = field(kw_only=True)

# All available strategy names for this scenario.
all_strategies: tuple[str, ...] = field(kw_only=True)

# Aggregate strategies that combine multiple attack approaches.
aggregate_strategies: tuple[str, ...] = field(kw_only=True)

# Default dataset names used by this scenario.
default_datasets: tuple[str, ...] = field(kw_only=True)

# Maximum number of items per dataset.
max_dataset_size: Optional[int] = field(kw_only=True)


Expand Down Expand Up @@ -131,7 +140,7 @@ def discover_user_scenarios(self) -> None:
from pyrit.scenario.core import Scenario

try:
for module_name, scenario_class in discover_subclasses_in_loaded_modules(
for _, scenario_class in discover_subclasses_in_loaded_modules(
base_class=Scenario # type: ignore[type-abstract]
):
# Check if this is a user-defined class (not from pyrit.scenario.scenarios)
Expand Down Expand Up @@ -170,7 +179,6 @@ def _build_metadata(self, name: str, entry: ClassEntry["Scenario"]) -> ScenarioM
max_dataset_size = dataset_config.max_dataset_size

return ScenarioMetadata(
identifier_type="class",
class_name=scenario_class.__name__,
class_module=scenario_class.__module__,
class_description=description,
Expand Down
8 changes: 0 additions & 8 deletions tests/unit/cli/test_frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ async def test_print_scenarios_list_with_scenarios(self, capsys):
mock_registry = MagicMock()
mock_registry.list_metadata.return_value = [
ScenarioMetadata(
identifier_type="class",
class_name="TestScenario",
class_module="test.scenarios",
class_description="Test description",
Expand Down Expand Up @@ -369,7 +368,6 @@ async def test_print_initializers_list_with_initializers(self, capsys):
mock_registry = MagicMock()
mock_registry.list_metadata.return_value = [
InitializerMetadata(
identifier_type="class",
class_name="TestInit",
class_module="test.initializers",
class_description="Test initializer",
Expand Down Expand Up @@ -410,7 +408,6 @@ def test_format_scenario_metadata_basic(self, capsys):
"""Test format_scenario_metadata with basic metadata."""

scenario_metadata = ScenarioMetadata(
identifier_type="class",
class_name="TestScenario",
class_module="test.scenarios",
class_description="",
Expand All @@ -432,7 +429,6 @@ def test_format_scenario_metadata_with_description(self, capsys):
"""Test format_scenario_metadata with description."""

scenario_metadata = ScenarioMetadata(
identifier_type="class",
class_name="TestScenario",
class_module="test.scenarios",
class_description="This is a test scenario",
Expand All @@ -451,7 +447,6 @@ def test_format_scenario_metadata_with_description(self, capsys):
def test_format_scenario_metadata_with_strategies(self, capsys):
"""Test format_scenario_metadata with strategies."""
scenario_metadata = ScenarioMetadata(
identifier_type="class",
class_name="TestScenario",
class_module="test.scenarios",
class_description="",
Expand All @@ -472,7 +467,6 @@ def test_format_scenario_metadata_with_strategies(self, capsys):
def test_format_initializer_metadata_basic(self, capsys) -> None:
"""Test format_initializer_metadata with basic metadata."""
initializer_metadata = InitializerMetadata(
identifier_type="class",
class_name="TestInit",
class_module="test.initializers",
class_description="",
Expand All @@ -491,7 +485,6 @@ def test_format_initializer_metadata_basic(self, capsys) -> None:
def test_format_initializer_metadata_with_env_vars(self, capsys) -> None:
"""Test format_initializer_metadata with environment variables."""
initializer_metadata = InitializerMetadata(
identifier_type="class",
class_name="TestInit",
class_module="test.initializers",
class_description="",
Expand All @@ -509,7 +502,6 @@ def test_format_initializer_metadata_with_env_vars(self, capsys) -> None:
def test_format_initializer_metadata_with_description(self, capsys) -> None:
"""Test format_initializer_metadata with description."""
initializer_metadata = InitializerMetadata(
identifier_type="class",
class_name="TestInit",
class_module="test.initializers",
class_description="Test description",
Expand Down
Loading