diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 5f2fe7536..790b9284d 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -3,7 +3,6 @@ """Registry module for PyRIT class and instance registries.""" -from pyrit.identifiers import Identifier, class_name_to_snake_case, snake_case_to_class_name from pyrit.registry.base import RegistryProtocol from pyrit.registry.class_registries import ( BaseClassRegistry, @@ -28,15 +27,12 @@ "BaseClassRegistry", "BaseInstanceRegistry", "ClassEntry", - "class_name_to_snake_case", "discover_in_directory", "discover_in_package", "discover_subclasses_in_loaded_modules", - "Identifier", "InitializerMetadata", "InitializerRegistry", "RegistryProtocol", - "snake_case_to_class_name", "ScenarioMetadata", "ScenarioRegistry", "ScorerRegistry", diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index 5f5e37400..ee9593d48 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -8,12 +8,43 @@ and instance registries (which store T instances). """ +from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Protocol, TypeVar, runtime_checkable +from pyrit.identifiers.class_name_utils import class_name_to_snake_case + # Type variable for metadata (invariant for Protocol compatibility) MetadataT = TypeVar("MetadataT") +@dataclass(frozen=True) +class ClassRegistryEntry: + """ + Minimal base for class-level registry metadata. + + Provides the common fields every registry metadata type needs for display, + lookup, and filtering in class registries. + + Attributes: + class_name (str): Python class name (e.g., "ContentHarmsScenario"). + class_module (str): Full module path (e.g., "pyrit.scenario.scenarios.content_harms"). + class_description (str): Human-readable description, typically from the class docstring. + """ + + class_name: str + class_module: str + class_description: str = "" + + @property + def snake_class_name(self) -> str: + """ + Snake_case version of class_name (e.g., "content_harms_scenario"). + + Used by CLI formatting and as registry display keys. + """ + return class_name_to_snake_case(self.class_name) + + @runtime_checkable class RegistryProtocol(Protocol[MetadataT]): """ diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index e7df37c78..f481ac31b 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -19,7 +19,6 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar -from pyrit.identifiers import Identifier from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.registry.base import RegistryProtocol @@ -183,36 +182,6 @@ def _build_metadata(self, name: str, entry: ClassEntry[T]) -> MetadataT: """ pass - def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> Identifier: - """ - Build the common base metadata for a registered class. - - This helper extracts fields common to all registries: name, class_name, class_description. - Subclasses can use this for building common fields if needed. - - Args: - name: The registry name (snake_case identifier). - entry: The ClassEntry containing the registered class. - - Returns: - An Identifier dataclass with common fields. - """ - registered_class = entry.registered_class - - # Extract description from docstring, clean up whitespace - doc = registered_class.__doc__ or "" - if doc: - description = " ".join(doc.split()) - else: - description = entry.description or "No description available" - - return Identifier( - identifier_type="class", - class_name=registered_class.__name__, - class_module=registered_class.__module__, - class_description=description, - ) - def get_class(self, name: str) -> Type[T]: """ Get a registered class by name. diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 2058d7608..147726fca 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -16,7 +16,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional -from pyrit.identifiers import Identifier +from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.class_registries.base_class_registry import ( BaseClassRegistry, ClassEntry, @@ -34,15 +34,20 @@ @dataclass(frozen=True) -class InitializerMetadata(Identifier): +class InitializerMetadata(ClassRegistryEntry): """ Metadata describing a registered PyRITInitializer class. Use get_class() to get the actual class. """ + # Human-readable display name (e.g., "Objective Target Setup"). display_name: str = field(kw_only=True) + + # Environment variables required by the initializer. required_env_vars: tuple[str, ...] = field(kw_only=True) + + # Execution order priority (lower = earlier). execution_order: int = field(kw_only=True) @@ -208,7 +213,6 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I try: instance = initializer_class() return InitializerMetadata( - identifier_type="class", class_name=initializer_class.__name__, class_module=initializer_class.__module__, class_description=instance.description, @@ -219,7 +223,6 @@ def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> I except Exception as e: logger.warning(f"Failed to get metadata for {name}: {e}") return InitializerMetadata( - identifier_type="class", class_name=initializer_class.__name__, class_module=initializer_class.__module__, class_description="Error loading initializer metadata", diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 083f2c8d0..d2a5bbb8f 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -15,8 +15,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from pyrit.identifiers import Identifier from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.class_registries.base_class_registry import ( BaseClassRegistry, ClassEntry, @@ -33,17 +33,26 @@ @dataclass(frozen=True) -class ScenarioMetadata(Identifier): +class ScenarioMetadata(ClassRegistryEntry): """ Metadata describing a registered Scenario class. Use get_class() to get the actual class. """ + # The default strategy name (e.g., "single_turn") default_strategy: str = field(kw_only=True) + + # All available strategy names for this scenario. all_strategies: tuple[str, ...] = field(kw_only=True) + + # Aggregate strategies that combine multiple attack approaches. aggregate_strategies: tuple[str, ...] = field(kw_only=True) + + # Default dataset names used by this scenario. default_datasets: tuple[str, ...] = field(kw_only=True) + + # Maximum number of items per dataset. max_dataset_size: Optional[int] = field(kw_only=True) @@ -131,7 +140,7 @@ def discover_user_scenarios(self) -> None: from pyrit.scenario.core import Scenario try: - for module_name, scenario_class in discover_subclasses_in_loaded_modules( + for _, scenario_class in discover_subclasses_in_loaded_modules( base_class=Scenario # type: ignore[type-abstract] ): # Check if this is a user-defined class (not from pyrit.scenario.scenarios) @@ -170,7 +179,6 @@ def _build_metadata(self, name: str, entry: ClassEntry["Scenario"]) -> ScenarioM max_dataset_size = dataset_config.max_dataset_size return ScenarioMetadata( - identifier_type="class", class_name=scenario_class.__name__, class_module=scenario_class.__module__, class_description=description, diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 5e7ce37e1..3bf264a50 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -327,7 +327,6 @@ async def test_print_scenarios_list_with_scenarios(self, capsys): mock_registry = MagicMock() mock_registry.list_metadata.return_value = [ ScenarioMetadata( - identifier_type="class", class_name="TestScenario", class_module="test.scenarios", class_description="Test description", @@ -369,7 +368,6 @@ async def test_print_initializers_list_with_initializers(self, capsys): mock_registry = MagicMock() mock_registry.list_metadata.return_value = [ InitializerMetadata( - identifier_type="class", class_name="TestInit", class_module="test.initializers", class_description="Test initializer", @@ -410,7 +408,6 @@ def test_format_scenario_metadata_basic(self, capsys): """Test format_scenario_metadata with basic metadata.""" scenario_metadata = ScenarioMetadata( - identifier_type="class", class_name="TestScenario", class_module="test.scenarios", class_description="", @@ -432,7 +429,6 @@ def test_format_scenario_metadata_with_description(self, capsys): """Test format_scenario_metadata with description.""" scenario_metadata = ScenarioMetadata( - identifier_type="class", class_name="TestScenario", class_module="test.scenarios", class_description="This is a test scenario", @@ -451,7 +447,6 @@ def test_format_scenario_metadata_with_description(self, capsys): def test_format_scenario_metadata_with_strategies(self, capsys): """Test format_scenario_metadata with strategies.""" scenario_metadata = ScenarioMetadata( - identifier_type="class", class_name="TestScenario", class_module="test.scenarios", class_description="", @@ -472,7 +467,6 @@ def test_format_scenario_metadata_with_strategies(self, capsys): def test_format_initializer_metadata_basic(self, capsys) -> None: """Test format_initializer_metadata with basic metadata.""" initializer_metadata = InitializerMetadata( - identifier_type="class", class_name="TestInit", class_module="test.initializers", class_description="", @@ -491,7 +485,6 @@ def test_format_initializer_metadata_basic(self, capsys) -> None: def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: """Test format_initializer_metadata with environment variables.""" initializer_metadata = InitializerMetadata( - identifier_type="class", class_name="TestInit", class_module="test.initializers", class_description="", @@ -509,7 +502,6 @@ def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: def test_format_initializer_metadata_with_description(self, capsys) -> None: """Test format_initializer_metadata with description.""" initializer_metadata = InitializerMetadata( - identifier_type="class", class_name="TestInit", class_module="test.initializers", class_description="Test description", diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index 3c8381dcd..f58ffdeb3 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -3,14 +3,11 @@ from dataclasses import dataclass, field -import pytest - -from pyrit.identifiers import Identifier -from pyrit.registry.base import _matches_filters +from pyrit.registry.base import ClassRegistryEntry, _matches_filters @dataclass(frozen=True) -class MetadataWithTags(Identifier): +class MetadataWithTags(ClassRegistryEntry): """Test metadata with a tags field for list filtering tests.""" tags: tuple[str, ...] = field(kw_only=True) @@ -21,8 +18,7 @@ class TestMatchesFilters: def test_matches_filters_exact_match_string(self): """Test that exact string matches work.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -32,8 +28,7 @@ def test_matches_filters_exact_match_string(self): def test_matches_filters_no_match_string(self): """Test that non-matching strings return False.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -43,8 +38,7 @@ def test_matches_filters_no_match_string(self): def test_matches_filters_multiple_filters_all_match(self): """Test that all filters must match.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -56,8 +50,7 @@ def test_matches_filters_multiple_filters_all_match(self): def test_matches_filters_multiple_filters_partial_match(self): """Test that partial matches return False when not all filters match.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -69,8 +62,7 @@ def test_matches_filters_multiple_filters_partial_match(self): def test_matches_filters_key_not_in_metadata(self): """Test that filtering on a non-existent key returns False.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -79,8 +71,7 @@ def test_matches_filters_key_not_in_metadata(self): def test_matches_filters_empty_filters(self): """Test that empty filters return True.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -90,7 +81,6 @@ def test_matches_filters_empty_filters(self): def test_matches_filters_list_value_contains_filter(self): """Test filtering when metadata value is a list and filter value is in the list.""" metadata = MetadataWithTags( - identifier_type="class", class_name="TestClass", class_module="test.module", class_description="A test item", @@ -102,7 +92,6 @@ def test_matches_filters_list_value_contains_filter(self): def test_matches_filters_list_value_not_contains_filter(self): """Test filtering when metadata value is a list and filter value is not in the list.""" metadata = MetadataWithTags( - identifier_type="class", class_name="TestClass", class_module="test.module", class_description="A test item", @@ -112,8 +101,7 @@ def test_matches_filters_list_value_not_contains_filter(self): def test_matches_filters_exclude_exact_match(self): """Test that exclude filters work for exact matches.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -124,7 +112,6 @@ def test_matches_filters_exclude_exact_match(self): def test_matches_filters_exclude_list_value(self): """Test exclude filters work for list values.""" metadata = MetadataWithTags( - identifier_type="class", class_name="TestClass", class_module="test.module", class_description="A test item", @@ -135,8 +122,7 @@ def test_matches_filters_exclude_list_value(self): def test_matches_filters_exclude_nonexistent_key(self): """Test that exclude filters for non-existent keys don't exclude the item.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -146,8 +132,7 @@ def test_matches_filters_exclude_nonexistent_key(self): def test_matches_filters_combined_include_and_exclude(self): """Test combined include and exclude filters.""" - metadata = Identifier( - identifier_type="class", + metadata = ClassRegistryEntry( class_name="TestClass", class_module="test.module", class_description="A test item", @@ -173,101 +158,3 @@ def test_matches_filters_combined_include_and_exclude(self): ) is False ) - - -class TestIdentifier: - """Tests for the Identifier dataclass and hash computation.""" - - def test_identifier_creation(self): - """Test creating an Identifier instance.""" - metadata = Identifier( - identifier_type="class", - class_name="TestScorer", - class_module="pyrit.test.scorer", - class_description="A test scorer for testing", - ) - assert metadata.identifier_type == "class" - assert metadata.class_name == "TestScorer" - assert metadata.class_module == "pyrit.test.scorer" - assert metadata.class_description == "A test scorer for testing" - # unique_name is auto-computed - assert metadata.unique_name is not None - assert "test_scorer" in metadata.unique_name - - def test_identifier_is_frozen(self): - """Test that Identifier is immutable.""" - metadata = Identifier( - identifier_type="class", - class_name="TestClass", - class_module="test.module", - class_description="Description here", - ) - - with pytest.raises(AttributeError): - metadata.unique_name = "new_name" # type: ignore[misc] - - def test_identifier_hash_computed_at_creation(self): - """Test that hash is computed when the Identifier is created.""" - identifier = Identifier( - identifier_type="instance", - class_name="TestClass", - class_module="test.module", - class_description="A test description", - ) - assert identifier.hash is not None - assert len(identifier.hash) == 64 # SHA256 hex length - - def test_identifier_hash_is_deterministic(self): - """Test that the same inputs produce the same hash.""" - identifier1 = Identifier( - identifier_type="class", - class_name="TestClass", - class_module="test.module", - class_description="A test description", - ) - identifier2 = Identifier( - identifier_type="class", - class_name="TestClass", - class_module="test.module", - class_description="A test description", - ) - assert identifier1.hash == identifier2.hash - - def test_identifier_hash_differs_for_different_inputs(self): - """Test that different inputs produce different hashes.""" - identifier1 = Identifier( - identifier_type="class", - class_name="TestClass", - class_module="test.module", - class_description="A test description", - ) - identifier2 = Identifier( - identifier_type="class", - class_name="DifferentClass", - class_module="test.module", - class_description="A test description", - ) - assert identifier1.hash != identifier2.hash - - def test_identifier_hash_is_immutable(self): - """Test that the hash cannot be modified.""" - identifier = Identifier( - identifier_type="class", - class_name="TestClass", - class_module="test.module", - class_description="A test description", - ) - with pytest.raises(AttributeError): - identifier.hash = "new_hash" # type: ignore[misc] - - def test_identifier_subclass_inherits_hash(self): - """Test that subclasses of Identifier also get a computed hash.""" - metadata = MetadataWithTags( - identifier_type="class", - class_name="TestClass", - class_module="test.module", - class_description="A test description", - tags=("tag1", "tag2"), - ) - assert metadata.hash is not None - assert len(metadata.hash) == 64