Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/cli/_banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo
"Commands:",
" • list-scenarios - See all available scenarios",
" • list-initializers - See all available initializers",
" • list-targets - See all available targets in the registry",
" • list-targets [opts] - See all available targets in the registry",
" • run <scenario> [opts] - Execute a security scenario",
" • scenario-history - View your session history",
" • print-scenario [N] - Display detailed results",
Expand Down
42 changes: 42 additions & 0 deletions pyrit/cli/_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,48 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
return result


def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]:
"""
Parse list-targets command arguments from a string (for shell mode).

Args:
args_string: Space-separated argument string (e.g., "--initializers target").

Returns:
Dictionary with parsed arguments:
- initializers: Optional[list[str | dict[str, Any]]]
- initialization_scripts: Optional[list[str]]

Raises:
ValueError: If parsing or validation fails.
"""
parts = args_string.split()

result: dict[str, Any] = {
"initializers": None,
"initialization_scripts": None,
}

i = 0
while i < len(parts):
if parts[i] == "--initializers":
result["initializers"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--"):
result["initializers"].append(_parse_initializer_arg(parts[i]))
i += 1
elif parts[i] == "--initialization-scripts":
result["initialization_scripts"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--"):
result["initialization_scripts"].append(parts[i])
i += 1
else:
raise ValueError(f"Unknown argument: {parts[i]}")

return result


# ---------------------------------------------------------------------------
# Shared argparse builder
# ---------------------------------------------------------------------------
Expand Down
32 changes: 15 additions & 17 deletions pyrit/cli/frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pyrit.cli._cli_args import _parse_initializer_arg as _parse_initializer_arg
from pyrit.cli._cli_args import add_common_arguments as add_common_arguments
from pyrit.cli._cli_args import non_negative_int as non_negative_int
from pyrit.cli._cli_args import parse_list_targets_arguments as parse_list_targets_arguments
from pyrit.cli._cli_args import parse_memory_labels as parse_memory_labels
from pyrit.cli._cli_args import parse_run_arguments as parse_run_arguments
from pyrit.cli._cli_args import positive_int as positive_int
Expand Down Expand Up @@ -254,44 +255,41 @@ async def list_initializers_async(
async def list_targets_async(
*,
context: FrontendCore,
initializer_names: Optional[list[Any]] = None,
) -> list[str]:
"""
List available target names from the TargetRegistry.

Since targets are registered by initializers, this function requires initializers
to have been run first. If initializer_names are provided, they will be resolved
and run before querying the registry.
to have been run first. Configure initializers on the FrontendCore context
(via initializer_names or initialization_scripts) before calling this function.

Args:
context: PyRIT context with loaded registries.
initializer_names: Optional list of initializer entries to run before listing.

Returns:
Sorted list of registered target names.
"""
if not context._initialized:
await context.initialize_async()

# If initializer names are provided, run them to populate the target registry
if initializer_names or context._initializer_configs:
configs = context._initializer_configs
if configs:
initializer_instances = []
for config in configs:
# Run initializers and/or initialization scripts to populate the target registry
if context._initializer_configs or context._initialization_scripts:
initializer_instances = []
if context._initializer_configs:
for config in context._initializer_configs:
initializer_class = context.initializer_registry.get_class(config.name)
instance = initializer_class()
if config.args:
instance.set_params_from_args(args=config.args)
initializer_instances.append(instance)

await initialize_pyrit_async(
memory_db_type=context._database,
initialization_scripts=context._initialization_scripts,
initializers=initializer_instances,
env_files=context._env_files,
silent=getattr(context, "_silent_reinit", False),
)
await initialize_pyrit_async(
memory_db_type=context._database,
initialization_scripts=context._initialization_scripts,
initializers=initializer_instances or None,
env_files=context._env_files,
silent=getattr(context, "_silent_reinit", False),
)

target_registry = TargetRegistry.get_registry_singleton()
return target_registry.get_names()
Expand Down
20 changes: 11 additions & 9 deletions pyrit/cli/pyrit_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,20 @@ def main(args: Optional[list[str]] = None) -> int:
return asyncio.run(frontend_core.print_initializers_list_async(context=context))

if parsed_args.list_targets:
# Need initializers to populate target registry
context = frontend_core.FrontendCore(
config_file=parsed_args.config_file,
initializer_names=parsed_args.initializers,
log_level=parsed_args.log_level,
)
return asyncio.run(frontend_core.print_targets_list_async(context=context))
# Need initializers or initialization scripts to populate the target registry
initialization_scripts = None
if parsed_args.initialization_scripts:
try:
initialization_scripts = frontend_core.resolve_initialization_scripts(
script_paths=parsed_args.initialization_scripts
)
except FileNotFoundError as e:
print(f"Error: {e}")
return 1

if parsed_args.list_targets:
# Need initializers to populate target registry
context = frontend_core.FrontendCore(
config_file=parsed_args.config_file,
initialization_scripts=initialization_scripts,
initializer_names=parsed_args.initializers,
log_level=parsed_args.log_level,
)
Expand Down
104 changes: 91 additions & 13 deletions pyrit/cli/pyrit_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PyRITShell(cmd.Cmd):
Commands:
list-scenarios - List all available scenarios
list-initializers - List all available initializers
list-targets - List all available targets from the registry
list-targets [opts] - List all available targets from the registry
run <scenario> [opts] - Run a scenario with optional parameters
scenario-history - List all previous scenario runs
print-scenario [N] - Print detailed results for scenario run(s)
Expand Down Expand Up @@ -166,6 +166,44 @@ def _ensure_initialized(self) -> None:
self._init_complete.wait()
self._raise_init_error()

def _rebuild_context(
self,
*,
initializer_names: Optional[list[Any]] = None,
initialization_scripts: Optional[list[Path]] = None,
log_level: Optional[int] = None,
) -> frontend_core.FrontendCore:
"""
Create a per-command FrontendCore that inherits the shell's startup config.

Propagates config_file, database, and env_files from the shell's startup
kwargs, then overrides initializer_names, initialization_scripts, and
log_level for the current command. Shares registries with the shell
context to avoid redundant re-discovery.

Args:
initializer_names (Optional[list[Any]]): Per-command initializer overrides.
initialization_scripts (Optional[list[Path]]): Per-command script overrides.
log_level (Optional[int]): Per-command log level override.

Returns:
frontend_core.FrontendCore: A new context ready for use in a command.
"""
cmd_kwargs = dict(self._context_kwargs)
if initializer_names is not None:
cmd_kwargs["initializer_names"] = initializer_names
if initialization_scripts is not None:
cmd_kwargs["initialization_scripts"] = initialization_scripts
cmd_kwargs["log_level"] = log_level if log_level is not None else self.default_log_level

cmd_context = self._fc.FrontendCore(**cmd_kwargs)
cmd_context._scenario_registry = self.context._scenario_registry
cmd_context._initializer_registry = self.context._initializer_registry
cmd_context._initialized = True
cmd_context._silent_reinit = True

return cmd_context

def cmdloop(self, intro: Optional[str] = None) -> None:
"""Override cmdloop to play animated banner before starting the REPL."""
if intro is None:
Expand All @@ -189,6 +227,9 @@ def cmdloop(self, intro: Optional[str] = None) -> None:

def do_list_scenarios(self, arg: str) -> None:
"""List all available scenarios."""
if arg.strip():
print(f"Error: list-scenarios does not accept arguments, got: {arg.strip()}")
return
self._ensure_initialized()
try:
asyncio.run(self._fc.print_scenarios_list_async(context=self.context))
Expand All @@ -197,17 +238,53 @@ def do_list_scenarios(self, arg: str) -> None:

def do_list_initializers(self, arg: str) -> None:
"""List all available initializers."""
if arg.strip():
print(f"Error: list-initializers does not accept arguments, got: {arg.strip()}")
return
self._ensure_initialized()
try:
asyncio.run(self._fc.print_initializers_list_async(context=self.context))
except Exception as e:
print(f"Error listing initializers: {e}")

def do_list_targets(self, arg: str) -> None:
"""List all available targets from the TargetRegistry."""
"""
List all available targets from the TargetRegistry.

Usage:
list-targets
list-targets --initializers <name> [<name> ...]
list-targets --initialization-scripts <path> [<path> ...]

Options:
--initializers <name> ... Built-in initializers to run first
--initialization-scripts <...> Custom Python scripts to run first

Examples:
list-targets --initializers target
list-targets --initializers target:tags=default,scorer
"""
self._ensure_initialized()
try:
asyncio.run(self._fc.print_targets_list_async(context=self.context))
list_targets_context = self.context
if arg.strip():
args = self._fc.parse_list_targets_arguments(args_string=arg)

resolved_scripts = None
if args["initialization_scripts"]:
resolved_scripts = self._fc.resolve_initialization_scripts(
script_paths=args["initialization_scripts"]
)
list_targets_context = self._rebuild_context(
initialization_scripts=resolved_scripts,
initializer_names=args["initializers"],
)

asyncio.run(self._fc.print_targets_list_async(context=list_targets_context))
except ValueError as e:
print(f"Error: {e}")
except FileNotFoundError as e:
print(f"Error: {e}")
except Exception as e:
print(f"Error listing targets: {e}")

Expand Down Expand Up @@ -292,16 +369,13 @@ def do_run(self, line: str) -> None:
print(f"Error: {e}")
return

# Create a context for this run with overrides
run_context = self._fc.FrontendCore(
initialization_scripts=resolved_scripts,
# Create a context for this run with per-command overrides,
# inheriting config_file, database, and env_files from startup.
run_context = self._rebuild_context(
initializer_names=args["initializers"],
log_level=args["log_level"] if args["log_level"] else self.default_log_level,
initialization_scripts=resolved_scripts,
log_level=args["log_level"],
)
# Use the existing registries (don't reinitialize)
run_context._scenario_registry = self.context._scenario_registry
run_context._initializer_registry = self.context._initializer_registry
run_context._initialized = True

try:
result = asyncio.run(
Expand Down Expand Up @@ -338,6 +412,9 @@ def do_scenario_history(self, arg: str) -> None:

Shows a numbered list of all scenario runs with the commands used.
"""
if arg.strip():
print(f"Error: scenario-history does not accept arguments, got: {arg.strip()}")
return
if not self._scenario_history:
print("No scenario runs in history.")
return
Expand Down Expand Up @@ -467,8 +544,9 @@ def do_help(self, arg: str) -> None:
print(" pyrit_shell")
print(" pyrit_shell --config-file ./my_config.yaml --log-level DEBUG")
else:
# Show help for specific command
super().do_help(arg)
# Convert hyphens to underscores (e.g. help list-targets -> help list_targets) for command lookup
normalized_arg = arg.replace("-", "_")
super().do_help(normalized_arg)

def do_exit(self, arg: str) -> bool:
"""
Expand Down
68 changes: 68 additions & 0 deletions tests/unit/cli/test_frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,46 @@ def test_parse_run_arguments_missing_value(self):
frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency")


class TestParseListTargetsArguments:
"""Tests for parse_list_targets_arguments function."""

def test_parse_list_targets_arguments_empty(self):
"""Test parsing empty string returns defaults."""
result = frontend_core.parse_list_targets_arguments(args_string="")
assert result["initializers"] is None
assert result["initialization_scripts"] is None

def test_parse_list_targets_arguments_with_initializers(self):
"""Test parsing with initializers."""
result = frontend_core.parse_list_targets_arguments(args_string="--initializers target init2")
assert result["initializers"] == ["target", "init2"]

def test_parse_list_targets_arguments_with_initializer_params(self):
"""Test parsing initializers with key=value params."""
result = frontend_core.parse_list_targets_arguments(args_string="--initializers target:tags=default,scorer")
assert result["initializers"] == [{"name": "target", "args": {"tags": ["default", "scorer"]}}]

def test_parse_list_targets_arguments_with_initialization_scripts(self):
"""Test parsing with initialization-scripts."""
result = frontend_core.parse_list_targets_arguments(
args_string="--initialization-scripts script1.py script2.py"
)
assert result["initialization_scripts"] == ["script1.py", "script2.py"]

def test_parse_list_targets_arguments_with_both(self):
"""Test parsing with both initializers and scripts."""
result = frontend_core.parse_list_targets_arguments(
args_string="--initializers target --initialization-scripts script1.py"
)
assert result["initializers"] == ["target"]
assert result["initialization_scripts"] == ["script1.py"]

def test_parse_list_targets_arguments_unknown_arg_raises(self):
"""Test parsing with unknown argument raises ValueError."""
with pytest.raises(ValueError, match="Unknown argument"):
frontend_core.parse_list_targets_arguments(args_string="--unknown-flag")


@pytest.mark.asyncio
@pytest.mark.usefixtures("patch_central_database")
class TestRunScenarioAsync:
Expand Down Expand Up @@ -1141,3 +1181,31 @@ async def test_print_targets_list_empty(
captured = capsys.readouterr()
assert "No targets found" in captured.out
assert "--initializers target" in captured.out

@patch("pyrit.cli.frontend_core.TargetRegistry")
@patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock)
async def test_list_targets_with_initialization_scripts_calls_initialize(
self,
mock_init: AsyncMock,
mock_target_registry_class: MagicMock,
):
"""Test list_targets_async calls initialize_pyrit_async when only scripts are configured."""
mock_registry = MagicMock()
mock_registry.get_names.return_value = ["script_target"]
mock_target_registry_class.get_registry_singleton.return_value = mock_registry

context = frontend_core.FrontendCore()
context._scenario_registry = MagicMock()
context._initializer_registry = MagicMock()
context._initialized = True
context._initialization_scripts = ["/path/to/script.py"]
context._initializer_configs = None

result = await frontend_core.list_targets_async(context=context)

assert result == ["script_target"]
# Verify initialize_pyrit_async was called with the scripts
mock_init.assert_called_once()
call_kwargs = mock_init.call_args[1]
assert call_kwargs["initialization_scripts"] == ["/path/to/script.py"]
assert call_kwargs["initializers"] is None
Loading
Loading