diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index bd5a3d40fe..a76d286b27 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -255,7 +255,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo "Commands:", " • list-scenarios - See all available scenarios", " • list-initializers - See all available initializers", - " • list-targets - See all available targets in the registry", + " • list-targets [opts] - See all available targets in the registry", " • run [opts] - Execute a security scenario", " • scenario-history - View your session history", " • print-scenario [N] - Display detailed results", diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 1264956ccb..80e383dfe5 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -457,6 +457,48 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: return result +def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]: + """ + Parse list-targets command arguments from a string (for shell mode). + + Args: + args_string: Space-separated argument string (e.g., "--initializers target"). + + Returns: + Dictionary with parsed arguments: + - initializers: Optional[list[str | dict[str, Any]]] + - initialization_scripts: Optional[list[str]] + + Raises: + ValueError: If parsing or validation fails. + """ + parts = args_string.split() + + result: dict[str, Any] = { + "initializers": None, + "initialization_scripts": None, + } + + i = 0 + while i < len(parts): + if parts[i] == "--initializers": + result["initializers"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initializers"].append(_parse_initializer_arg(parts[i])) + i += 1 + elif parts[i] == "--initialization-scripts": + result["initialization_scripts"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initialization_scripts"].append(parts[i]) + i += 1 + else: + raise ValueError(f"Unknown argument: {parts[i]}") + + return result + + # --------------------------------------------------------------------------- # Shared argparse builder # --------------------------------------------------------------------------- diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 3db5552011..bc9519052a 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -27,6 +27,7 @@ from pyrit.cli._cli_args import _parse_initializer_arg as _parse_initializer_arg from pyrit.cli._cli_args import add_common_arguments as add_common_arguments from pyrit.cli._cli_args import non_negative_int as non_negative_int +from pyrit.cli._cli_args import parse_list_targets_arguments as parse_list_targets_arguments from pyrit.cli._cli_args import parse_memory_labels as parse_memory_labels from pyrit.cli._cli_args import parse_run_arguments as parse_run_arguments from pyrit.cli._cli_args import positive_int as positive_int @@ -135,9 +136,6 @@ def __init__( ) from e raise - # Store the merged configuration - self._config = config - # Extract values from config for internal use # Use canonical mapping from configuration_loader self._database = _MEMORY_DB_TYPE_MAP[config.memory_db_type] @@ -187,6 +185,63 @@ async def initialize_async(self) -> None: self._initialized = True + def with_overrides( + self, + *, + initializer_names: Optional[list[Any]] = None, + initialization_scripts: Optional[list[Path]] = None, + log_level: Optional[int] = None, + ) -> FrontendCore: + """ + Create a derived FrontendCore with per-command overrides. + + Copies inherited state (database, env_files, operator, operation, config) + from this instance and applies the given overrides. Shares registries + with the parent to avoid redundant re-discovery and skips re-reading + config files. + + Args: + initializer_names (Optional[list[Any]]): Per-command initializer overrides. + Each entry can be a string name or a dict with 'name' and optional 'args'. + None keeps the parent's value. + initialization_scripts (Optional[list[Path]]): Per-command script overrides. + None keeps the parent's value. + log_level (Optional[int]): Per-command log level override. + None keeps the parent's value. + + Returns: + FrontendCore: A new context ready for use, without re-reading config files. + """ + derived = object.__new__(FrontendCore) + + # Inherit from parent + derived._database = self._database + derived._env_files = self._env_files + derived._operator = self._operator + derived._operation = self._operation + + # Apply overrides or inherit + derived._log_level = log_level if log_level is not None else self._log_level + + if initializer_names is not None: + loader = ConfigurationLoader.from_dict({"initializers": initializer_names}) + derived._initializer_configs = loader._initializer_configs + else: + derived._initializer_configs = self._initializer_configs + + if initialization_scripts is not None: + derived._initialization_scripts = initialization_scripts + else: + derived._initialization_scripts = self._initialization_scripts + + # Share registries (singletons, no need to re-discover) + derived._scenario_registry = self._scenario_registry + derived._initializer_registry = self._initializer_registry + derived._initialized = True + derived._silent_reinit = True + + return derived + @property def scenario_registry(self) -> ScenarioRegistry: """ @@ -254,18 +309,16 @@ async def list_initializers_async( async def list_targets_async( *, context: FrontendCore, - initializer_names: Optional[list[Any]] = None, ) -> list[str]: """ List available target names from the TargetRegistry. Since targets are registered by initializers, this function requires initializers - to have been run first. If initializer_names are provided, they will be resolved - and run before querying the registry. + to have been run first. Configure initializers on the FrontendCore context + (via initializer_names or initialization_scripts) before calling this function. Args: context: PyRIT context with loaded registries. - initializer_names: Optional list of initializer entries to run before listing. Returns: Sorted list of registered target names. @@ -273,25 +326,24 @@ async def list_targets_async( if not context._initialized: await context.initialize_async() - # If initializer names are provided, run them to populate the target registry - if initializer_names or context._initializer_configs: - configs = context._initializer_configs - if configs: - initializer_instances = [] - for config in configs: + # Run initializers and/or initialization scripts to populate the target registry + if context._initializer_configs or context._initialization_scripts: + initializer_instances = [] + if context._initializer_configs: + for config in context._initializer_configs: initializer_class = context.initializer_registry.get_class(config.name) instance = initializer_class() if config.args: instance.set_params_from_args(args=config.args) initializer_instances.append(instance) - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - silent=getattr(context, "_silent_reinit", False), - ) + await initialize_pyrit_async( + memory_db_type=context._database, + initialization_scripts=context._initialization_scripts, + initializers=initializer_instances or None, + env_files=context._env_files, + silent=getattr(context, "_silent_reinit", False), + ) target_registry = TargetRegistry.get_registry_singleton() return target_registry.get_names() diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index aefdfa5f22..9a8ca771c4 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -196,18 +196,20 @@ def main(args: Optional[list[str]] = None) -> int: return asyncio.run(frontend_core.print_initializers_list_async(context=context)) if parsed_args.list_targets: - # Need initializers to populate target registry - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - initializer_names=parsed_args.initializers, - log_level=parsed_args.log_level, - ) - return asyncio.run(frontend_core.print_targets_list_async(context=context)) + # Need initializers or initialization scripts to populate the target registry + initialization_scripts = None + if parsed_args.initialization_scripts: + try: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + except FileNotFoundError as e: + print(f"Error: {e}") + return 1 - if parsed_args.list_targets: - # Need initializers to populate target registry context = frontend_core.FrontendCore( config_file=parsed_args.config_file, + initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, log_level=parsed_args.log_level, ) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index f19602bee0..ae38edcde8 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -33,7 +33,7 @@ class PyRITShell(cmd.Cmd): Commands: list-scenarios - List all available scenarios list-initializers - List all available initializers - list-targets - List all available targets from the registry + list-targets [opts] - List all available targets from the registry run [opts] - Run a scenario with optional parameters scenario-history - List all previous scenario runs print-scenario [N] - Print detailed results for scenario run(s) @@ -189,6 +189,9 @@ def cmdloop(self, intro: Optional[str] = None) -> None: def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" + if arg.strip(): + print(f"Error: list-scenarios does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_scenarios_list_async(context=self.context)) @@ -197,6 +200,9 @@ def do_list_scenarios(self, arg: str) -> None: def do_list_initializers(self, arg: str) -> None: """List all available initializers.""" + if arg.strip(): + print(f"Error: list-initializers does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_initializers_list_async(context=self.context)) @@ -204,10 +210,43 @@ def do_list_initializers(self, arg: str) -> None: print(f"Error listing initializers: {e}") def do_list_targets(self, arg: str) -> None: - """List all available targets from the TargetRegistry.""" + """ + List all available targets from the TargetRegistry. + + Usage: + list-targets + list-targets --initializers [ ...] + list-targets --initialization-scripts [ ...] + + Options: + --initializers ... Built-in initializers to run first + --initialization-scripts <...> Custom Python scripts to run first + + Examples: + list-targets --initializers target + list-targets --initializers target:tags=default,scorer + """ self._ensure_initialized() try: - asyncio.run(self._fc.print_targets_list_async(context=self.context)) + list_targets_context = self.context + if arg.strip(): + args = self._fc.parse_list_targets_arguments(args_string=arg) + + resolved_scripts = None + if args["initialization_scripts"]: + resolved_scripts = self._fc.resolve_initialization_scripts( + script_paths=args["initialization_scripts"] + ) + list_targets_context = self.context.with_overrides( + initialization_scripts=resolved_scripts, + initializer_names=args["initializers"], + ) + + asyncio.run(self._fc.print_targets_list_async(context=list_targets_context)) + except ValueError as e: + print(f"Error: {e}") + except FileNotFoundError as e: + print(f"Error: {e}") except Exception as e: print(f"Error listing targets: {e}") @@ -292,16 +331,13 @@ def do_run(self, line: str) -> None: print(f"Error: {e}") return - # Create a context for this run with overrides - run_context = self._fc.FrontendCore( - initialization_scripts=resolved_scripts, + # Create a context for this run with per-command overrides, + # inheriting config_file, database, and env_files from startup. + run_context = self.context.with_overrides( initializer_names=args["initializers"], - log_level=args["log_level"] if args["log_level"] else self.default_log_level, + initialization_scripts=resolved_scripts, + log_level=args["log_level"], ) - # Use the existing registries (don't reinitialize) - run_context._scenario_registry = self.context._scenario_registry - run_context._initializer_registry = self.context._initializer_registry - run_context._initialized = True try: result = asyncio.run( @@ -338,6 +374,9 @@ def do_scenario_history(self, arg: str) -> None: Shows a numbered list of all scenario runs with the commands used. """ + if arg.strip(): + print(f"Error: scenario-history does not accept arguments, got: {arg.strip()}") + return if not self._scenario_history: print("No scenario runs in history.") return @@ -467,8 +506,9 @@ def do_help(self, arg: str) -> None: print(" pyrit_shell") print(" pyrit_shell --config-file ./my_config.yaml --log-level DEBUG") else: - # Show help for specific command - super().do_help(arg) + # Convert hyphens to underscores (e.g. help list-targets -> help list_targets) for command lookup + normalized_arg = arg.replace("-", "_") + super().do_help(normalized_arg) def do_exit(self, arg: str) -> bool: """ diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 61b3c7bb50..422bc6a2d3 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -672,6 +672,46 @@ def test_parse_run_arguments_missing_value(self): frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency") +class TestParseListTargetsArguments: + """Tests for parse_list_targets_arguments function.""" + + def test_parse_list_targets_arguments_empty(self): + """Test parsing empty string returns defaults.""" + result = frontend_core.parse_list_targets_arguments(args_string="") + assert result["initializers"] is None + assert result["initialization_scripts"] is None + + def test_parse_list_targets_arguments_with_initializers(self): + """Test parsing with initializers.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target init2") + assert result["initializers"] == ["target", "init2"] + + def test_parse_list_targets_arguments_with_initializer_params(self): + """Test parsing initializers with key=value params.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target:tags=default,scorer") + assert result["initializers"] == [{"name": "target", "args": {"tags": ["default", "scorer"]}}] + + def test_parse_list_targets_arguments_with_initialization_scripts(self): + """Test parsing with initialization-scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initialization-scripts script1.py script2.py" + ) + assert result["initialization_scripts"] == ["script1.py", "script2.py"] + + def test_parse_list_targets_arguments_with_both(self): + """Test parsing with both initializers and scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initializers target --initialization-scripts script1.py" + ) + assert result["initializers"] == ["target"] + assert result["initialization_scripts"] == ["script1.py"] + + def test_parse_list_targets_arguments_unknown_arg_raises(self): + """Test parsing with unknown argument raises ValueError.""" + with pytest.raises(ValueError, match="Unknown argument"): + frontend_core.parse_list_targets_arguments(args_string="--unknown-flag") + + @pytest.mark.asyncio @pytest.mark.usefixtures("patch_central_database") class TestRunScenarioAsync: @@ -933,9 +973,125 @@ def test_parse_run_arguments_target_with_other_args(self): args_string="test_scenario --target my_target --initializers init1 --max-concurrency 5" ) - assert result["target"] == "my_target" - assert result["initializers"] == ["init1"] - assert result["max_concurrency"] == 5 + +class TestWithOverrides: + """Tests for FrontendCore.with_overrides method.""" + + def _make_initialized_parent(self) -> frontend_core.FrontendCore: + """Create a fully-initialized FrontendCore for testing with_overrides.""" + parent = frontend_core.FrontendCore( + database=frontend_core.IN_MEMORY, + initializer_names=["parent_init"], + log_level=logging.WARNING, + ) + parent._scenario_registry = MagicMock() + parent._initializer_registry = MagicMock() + parent._initialized = True + parent._silent_reinit = True + return parent + + def test_with_overrides_inherits_fields(self): + """Test that derived context inherits database, env_files, operator, operation.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._database == parent._database + assert derived._env_files == parent._env_files + assert derived._operator == parent._operator + assert derived._operation == parent._operation + + def test_with_overrides_shares_registries(self): + """Test that derived context shares scenario and initializer registries.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._scenario_registry is parent._scenario_registry + assert derived._initializer_registry is parent._initializer_registry + + def test_with_overrides_sets_initialized_and_silent(self): + """Test that derived context is marked initialized with silent reinit.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._initialized is True + assert derived._silent_reinit is True + + def test_with_overrides_none_keeps_parent_values(self): + """Test that passing None for all overrides keeps parent's values.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides( + initializer_names=None, + initialization_scripts=None, + log_level=None, + ) + + assert derived._initializer_configs == parent._initializer_configs + assert derived._initialization_scripts == parent._initialization_scripts + assert derived._log_level == parent._log_level + + def test_with_overrides_initializer_names(self): + """Test that initializer_names override normalizes to InitializerConfig objects.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(initializer_names=["target", "dataset"]) + + assert derived._initializer_configs is not None + names = [ic.name for ic in derived._initializer_configs] + assert names == ["target", "dataset"] + # Parent should still have original + assert [ic.name for ic in parent._initializer_configs] == ["parent_init"] + + def test_with_overrides_initializer_names_dict(self): + """Test initializer_names with dict entries (name + args).""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(initializer_names=[{"name": "target", "args": {"tags": "default"}}]) + + assert derived._initializer_configs is not None + assert len(derived._initializer_configs) == 1 + assert derived._initializer_configs[0].name == "target" + assert derived._initializer_configs[0].args == {"tags": "default"} + + def test_with_overrides_initialization_scripts(self): + """Test that initialization_scripts override replaces parent's scripts.""" + parent = self._make_initialized_parent() + new_scripts = [Path("/new/script.py")] + + derived = parent.with_overrides(initialization_scripts=new_scripts) + + assert derived._initialization_scripts == new_scripts + # Parent should be unchanged + assert parent._initialization_scripts != new_scripts + + def test_with_overrides_log_level(self): + """Test that log_level override replaces parent's log level.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(log_level=logging.DEBUG) + + assert derived._log_level == logging.DEBUG + assert parent._log_level == logging.WARNING + + def test_with_overrides_does_not_mutate_parent(self): + """Test that with_overrides does not modify the parent context.""" + parent = self._make_initialized_parent() + original_configs = parent._initializer_configs + original_log_level = parent._log_level + original_scripts = parent._initialization_scripts + + parent.with_overrides( + initializer_names=["new_init"], + initialization_scripts=[Path("/new.py")], + log_level=logging.DEBUG, + ) + + assert parent._initializer_configs is original_configs + assert parent._log_level == original_log_level + assert parent._initialization_scripts is original_scripts def test_parse_run_arguments_target_missing_value(self): """Test parsing --target without a value raises ValueError.""" @@ -1141,3 +1297,31 @@ async def test_print_targets_list_empty( captured = capsys.readouterr() assert "No targets found" in captured.out assert "--initializers target" in captured.out + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_list_targets_with_initialization_scripts_calls_initialize( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + ): + """Test list_targets_async calls initialize_pyrit_async when only scripts are configured.""" + mock_registry = MagicMock() + mock_registry.get_names.return_value = ["script_target"] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + context._initialization_scripts = ["/path/to/script.py"] + context._initializer_configs = None + + result = await frontend_core.list_targets_async(context=context) + + assert result == ["script_target"] + # Verify initialize_pyrit_async was called with the scripts + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + assert call_kwargs["initialization_scripts"] == ["/path/to/script.py"] + assert call_kwargs["initializers"] is None diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 34a8b8ad52..a4c3620ca7 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -214,6 +214,55 @@ def test_main_list_scenarios_with_missing_script(self, mock_resolve_scripts: Mag assert result == 1 + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_initializers( + self, + mock_frontend_core: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initializers passes initializers to FrontendCore.""" + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initializers", "target"]) + + assert result == 0 + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initializer_names"] == ["target"] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_scripts( + self, + mock_frontend_core: MagicMock, + mock_resolve_scripts: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initialization-scripts passes scripts to FrontendCore.""" + mock_resolve_scripts.return_value = [Path("/test/script.py")] + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "script.py"]) + + assert result == 0 + mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initialization_scripts"] == [Path("/test/script.py")] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_main_list_targets_with_missing_script(self, mock_resolve_scripts: MagicMock): + """Test main with --list-targets and missing script file.""" + mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "missing.py"]) + + assert result == 1 + def test_main_no_scenario_specified(self, capsys): """Test main without scenario name.""" result = pyrit_scan.main([]) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 89a218644d..4f562f9917 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -151,6 +151,15 @@ def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, captured = capsys.readouterr() assert "Error listing scenarios" in captured.out + def test_do_list_scenarios_rejects_args(self, shell, capsys): + """Test do_list_scenarios rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_scenarios("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) def test_do_list_initializers(self, mock_print_initializers: AsyncMock, shell): """Test do_list_initializers command.""" @@ -171,6 +180,67 @@ def test_do_list_initializers_with_exception(self, mock_print_initializers: Asyn captured = capsys.readouterr() assert "Error listing initializers" in captured.out + def test_do_list_initializers_rejects_args(self, shell, capsys): + """Test do_list_initializers rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_initializers("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_no_args(self, mock_print_targets: AsyncMock, shell): + """Test do_list_targets with no arguments uses the default context.""" + s, ctx, _ = shell + + s.do_list_targets("") + + mock_print_targets.assert_called_once_with(context=ctx) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.parse_list_targets_arguments") + def test_do_list_targets_with_initializers( + self, + mock_parse: MagicMock, + mock_print_targets: AsyncMock, + shell, + ): + """Test do_list_targets with --initializers uses context.with_overrides.""" + s, ctx, _ = shell + mock_parse.return_value = {"initializers": ["target"], "initialization_scripts": None} + mock_derived = MagicMock() + ctx.with_overrides = MagicMock(return_value=mock_derived) + + s.do_list_targets("--initializers target") + + mock_parse.assert_called_once_with(args_string="--initializers target") + ctx.with_overrides.assert_called_once_with( + initialization_scripts=None, + initializer_names=["target"], + ) + mock_print_targets.assert_called_once_with(context=mock_derived) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_with_exception(self, mock_print_targets: AsyncMock, shell, capsys): + """Test do_list_targets handles exceptions.""" + s, ctx, _ = shell + mock_print_targets.side_effect = RuntimeError("Test error") + + s.do_list_targets("") + + captured = capsys.readouterr() + assert "Error listing targets" in captured.out + + def test_do_list_targets_parse_error(self, shell, capsys): + """Test do_list_targets shows error for invalid args.""" + s, ctx, _ = shell + + s.do_list_targets("--unknown-flag") + + captured = capsys.readouterr() + assert "Error" in captured.out + def test_do_run_empty_line(self, shell, capsys): """Test do_run with empty line.""" s, ctx, _ = shell @@ -380,6 +450,15 @@ def test_do_scenario_history_empty(self, shell, capsys): captured = capsys.readouterr() assert "No scenario runs in history" in captured.out + def test_do_scenario_history_rejects_args(self, shell, capsys): + """Test do_scenario_history rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_scenario_history("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + def test_do_scenario_history_with_runs(self, shell, capsys): """Test do_scenario_history with scenario runs.""" s, ctx, _ = shell @@ -502,6 +581,14 @@ def test_do_help_with_arg(self, shell): s.do_help("run") mock_parent_help.assert_called_with("run") + def test_do_help_with_hyphenated_arg(self, shell): + """Test do_help converts hyphens to underscores for command lookup.""" + s, ctx, _ = shell + + with patch("cmd.Cmd.do_help") as mock_parent_help: + s.do_help("list-targets") + mock_parent_help.assert_called_with("list_targets") + @patch.object(cmd.Cmd, "cmdloop") @patch.object(banner, "play_animation") def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, shell):