From 0ad271c750ff04c7438bd479851039bec8620493 Mon Sep 17 00:00:00 2001 From: Sergey Kolupaev Date: Fri, 1 May 2026 14:31:41 -0700 Subject: [PATCH 01/26] Publish SBOM --- .github/workflows/sbom.yml | 61 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 .github/workflows/sbom.yml diff --git a/.github/workflows/sbom.yml b/.github/workflows/sbom.yml new file mode 100644 index 000000000..519b8902d --- /dev/null +++ b/.github/workflows/sbom.yml @@ -0,0 +1,61 @@ +name: Generate SBOM + +on: + workflow_dispatch: + release: + types: [published] + +jobs: + generate-sbom: + name: "Generate and Publish SBOM" + environment: prod + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install Commander and SBOM tools + run: | + python -m pip install --upgrade pip wheel setuptools + pip install . + pip install cyclonedx-bom + + VERSION=$(python3 -c "import keepercommander.__init__ as init; print(init.__version__)") + echo "PACKAGE_VERSION=${VERSION}" >> $GITHUB_ENV + + pip freeze > installed_packages.txt + + - name: Generate CycloneDX SBOM + run: cyclonedx-py environment -o sbom.cdx.json + + - name: Upload SBOM to Manifest-Cyber + run: | + sbom="$(base64 -w 0 sbom.cdx.json)" + cat < manifest-request.json + { + "base64BomContents": "$sbom", + "source": "github-actions", + "relationship": "first", + "filename": "sbom.cdx.json" + } + EOF + + curl --location --fail --request PUT 'https://api.manifestcyber.com/v1/sbom/upload' \ + --header 'Authorization: Bearer ${{ secrets.MANIFEST_TOKEN }}' \ + --header 'Content-Type: application/json' \ + --data-binary "@manifest-request.json" + + - name: Archive SBOM + uses: actions/upload-artifact@v4 + with: + name: sbom-keepercommander-${{ env.PACKAGE_VERSION }} + path: | + sbom.cdx.json + installed_packages.txt + retention-days: 30 From d704d1776391c1751442afe3e91d1a0ef039ee3e Mon Sep 17 00:00:00 2001 From: idimov-keeper <78815270+idimov-keeper@users.noreply.github.com> Date: Fri, 1 May 2026 18:35:30 -0500 Subject: [PATCH 02/26] Adds workflow to PAM project import / extend (#2010) * Adds optional workflow to PAM project import / extend * Register missing enterprise enforcements in constants.py * Fix workflow time-of-day decoding (HHMM, not minutes-from-midnight) * Remove pre-GA dev/qa-only gate from PAM Workflow commands --- keepercommander/commands/pam_import/README.md | 78 +++- keepercommander/commands/pam_import/base.py | 163 +++++++- keepercommander/commands/pam_import/edit.py | 11 + keepercommander/commands/pam_import/extend.py | 12 + .../commands/pam_import/workflow_apply.py | 262 +++++++++++++ .../commands/workflow/config_commands.py | 7 +- keepercommander/commands/workflow/helpers.py | 15 +- keepercommander/commands/workflow/registry.py | 14 - keepercommander/constants.py | 14 + tests/test_pam_workflow.py | 362 ++++++++++++++++++ 10 files changed, 912 insertions(+), 26 deletions(-) create mode 100644 keepercommander/commands/pam_import/workflow_apply.py create mode 100644 tests/test_pam_workflow.py diff --git a/keepercommander/commands/pam_import/README.md b/keepercommander/commands/pam_import/README.md index e5baae283..233f8f37a 100644 --- a/keepercommander/commands/pam_import/README.md +++ b/keepercommander/commands/pam_import/README.md @@ -303,7 +303,7 @@ Each Machine (pamMachine, pamDatabase, pamDirectory) can specify **Administrativ > **Note 3:** Post rotation scripts (a.k.a. `scripts`) are executed in following order: `pamUser` scripts after any **successful** rotation for that user, `pamMachine` scripts after any **successful** rotation on the machine and `pamConfiguration` scripts after any rotation using that configuration. > **Note 4:** When `allow_supply_user` is false and JIT ephemeral is not used, vault may require a launch credential; import can provide it via `launch_credentials` in the resource's `connection` block. -JIT and KeeperAI settings below are shared across all resource types (pamMachine, pamDatabase, pamDirectory) except User and RBI (pamRemoteBrowser) records. +JIT and KeeperAI settings below are shared across all resource types (pamMachine, pamDatabase, pamDirectory) except User and RBI (pamRemoteBrowser) records. **Workflow** (approvals / checkout / temporal restrictions) is supported on all four resource types: pamMachine, pamDatabase, pamDirectory, **and** pamRemoteBrowser.
Just-In-Time Access (JIT) @@ -406,6 +406,79 @@ JIT and KeeperAI settings below are shared across all resource types (pamMachine ```
+Workflow (Approvals, Checkout, Temporal Access) + +Workflow controls how privileged access to a resource is gated: how many approvals are needed, whether sessions require check-out, MFA, reason/ticket, what time windows access is allowed in, and who can approve (with optional escalation). Workflow is applied via the Keeper Router **after** the resource record and DAG/JIT/AI steps are complete and is not stored on the record itself. + +**How to Configure:** Add `pam_settings.options.workflow` to any pamMachine, pamDatabase, pamDirectory, or pamRemoteBrowser. The workflow object maps directly to the Web Vault's "Workflow" tab on a resource record. + +```json +{ + "pam_settings": { + "options": { + "workflow": { + "approvals_needed": 2, + "checkout_needed": true, + "start_access_on_approval": false, + "require_reason": true, + "require_ticket": false, + "require_mfa": true, + "access_duration": "8h", + "allowed_times": { + "allowed_days": ["mon", "tue", "wed", "thu", "fri"], + "time_ranges": [ + { "start": "09:00", "end": "17:30" } + ], + "timezone": "America/New_York" + }, + "approvers": [ + { + "principal": { "type": "user", "email": "primary.approver@example.com" }, + "escalation": false + }, + { + "principal": { "type": "user", "email": "second.approver@example.com" }, + "escalation": false + }, + { + "principal": { + "type": "team", + "team_uid_base64url": "REPLACE_TEAM_UID_BASE64URL" + }, + "escalation": true, + "escalation_after": "45m" + } + ] + } + } + } +} +``` + +**Field reference:** +- `approvals_needed` *(int, default `0`)* — number of approvals required to grant access. +- `checkout_needed` *(bool, default `false`)* — require explicit check-out before launching a session. +- `start_access_on_approval` *(bool, default `false`)* — start the access window the moment approval is granted (rather than at session launch). +- `require_reason` / `require_ticket` *(bool, default `false`)* — prompt the user for a reason / ticket reference at request time. +- `require_mfa` *(bool, default `false`)* — require MFA at session launch. +- `access_duration` *(string, default `"1d"`)* — how long approved access remains valid. Accepts `Xm` / `Xh` / `Xd` (e.g. `"30m"`, `"8h"`, `"2d"`); a bare integer is interpreted as minutes. Must be positive. +- `allowed_times.allowed_days` *(list of strings)* — restrict access to these weekdays. Accepts 3-letter (`mon`..`sun`) or full names (`monday`..`sunday`), case-insensitive. +- `allowed_times.time_ranges` *(list of `{start, end}` objects)* — one or more allowed daily time windows in `HH:MM` (24-hour) format. **Multiple ranges per day are supported.** A single range whose `end` is earlier than its `start` (e.g. an overnight `22:00–06:00`) **should be split into two ranges** that both fall inside one day (e.g. `22:00–23:59` and `00:00–06:00`) +- `allowed_times.timezone` *(string)* — IANA timezone name (e.g. `"UTC"`, `"America/New_York"`). **Required when `time_ranges` is non-empty.** +- `approvers[]` — list of approver entries. + - `principal.type` — `"user"` or `"team"`. + - For users: `principal.email` (must exist in the enterprise). + - For teams: `principal.team_uid_base64url` (the team's vault UID, base64url-encoded; validated against the local team cache during import — unknown UIDs fail in dry-run). + - `escalation` *(bool)* — whether this approver is in the escalation chain. + - `escalation_after` *(duration string, optional)* — wait this long before escalating to this approver. **Requires `escalation: true`.** + +**Behavior notes:** +- **Trivial workflow is a no-op.** If none of `approvals_needed > 0`, `checkout_needed`, `require_mfa`, `start_access_on_approval`, `allowed_times.allowed_days`, or `allowed_times.time_ranges` is set, the workflow block is treated as absent and no Router call is made. +- **Pre-flight validation runs in `--dry-run`.** Bad durations, malformed `HH:MM`, missing timezone, escalation rule violations, and unknown team UIDs are reported during dry-run before any vault writes. +- **Dry-run skips the Router calls.** Workflow is applied (Router create/update + approver reconcile) only on a real run. +- **`extend` only applies workflow to newly created resources** (existing resources are not touched). +
+
pam_data.resources.pamMachine (RDP) ```json @@ -435,7 +508,8 @@ JIT and KeeperAI settings below are shared across all resource types (pamMachine "ai_threat_detection": "off", "ai_terminate_session_on_detection": "off", "jit_settings": {}, - "ai_settings": {} + "ai_settings": {}, + "workflow": {} }, "allow_supply_host": false, "port_forward": { diff --git a/keepercommander/commands/pam_import/base.py b/keepercommander/commands/pam_import/base.py index 22137b8cf..e5cf37835 100644 --- a/keepercommander/commands/pam_import/base.py +++ b/keepercommander/commands/pam_import/base.py @@ -22,9 +22,11 @@ from typing import Any, Dict, Optional, List, Union from ..record_edit import RecordAddCommand as RecordEditAddCommand +from ..workflow.helpers import RecordResolver, WorkflowFormatter from ... import api, attachment, utils, vault, vault_extensions, \ record_facades, record_management from ...display import bcolors +from ...error import CommandError from ...recordv3 import RecordV3 @@ -69,7 +71,8 @@ "pam_settings": { "options" : { "jit_settings": {}, - "ai_settings": {} + "ai_settings": {}, + "workflow": {} }, "connection" : {} }, @@ -611,6 +614,144 @@ def load(cls, data: Union[str, dict]): return obj +class PamWorkflowOptions: + """Parsed workflow settings from pam_settings.options.workflow. + Not stored on record fields nor in DAG; applied via Krouter after record/DAG creation. + """ + + _DEFAULT_DURATION_MS = 86_400_000 # "1d" + + def __init__(self): + self.approvals_needed: int = 0 + self.checkout_needed: bool = False + self.start_access_on_approval: bool = False + self.require_reason: bool = False + self.require_ticket: bool = False + self.require_mfa: bool = False + self.access_duration_ms: int = self._DEFAULT_DURATION_MS + self.allowed_days: List[str] = [] # canonical 3-letter tokens: "mon".."sun" + self.time_ranges: List[dict] = [] # each: {"start": "HH:MM", "end": "HH:MM"} + self.timezone: str = "" + self.approvers: List[dict] = [] # each: {principal_type, email, team_uid_b64, escalation, escalation_after_ms} + + @staticmethod + def _parse_duration(value) -> int: + """Return milliseconds. Raises CommandError on invalid/non-positive value. + Delegates to WorkflowFormatter.parse_duration; adds a None -> default-1d shim + (the CLI command always supplies a string, but the JSON import may omit the key). + """ + if value is None: + return PamWorkflowOptions._DEFAULT_DURATION_MS + return WorkflowFormatter.parse_duration(str(value)) + + @classmethod + def load(cls, data) -> Optional['PamWorkflowOptions']: + """Parse workflow JSON dict. Returns None when absent / null / trivial (V2 guard).""" + if not data or not isinstance(data, dict): + return None + + obj = cls() + obj.approvals_needed = max(0, int(data.get('approvals_needed', 0) or 0)) + obj.checkout_needed = bool(data.get('checkout_needed', False)) + obj.start_access_on_approval = bool(data.get('start_access_on_approval', False)) + obj.require_reason = bool(data.get('require_reason', False)) + obj.require_ticket = bool(data.get('require_ticket', False)) + obj.require_mfa = bool(data.get('require_mfa', False)) + + # V9: access_duration — default "1d" + obj.access_duration_ms = cls._parse_duration(data.get('access_duration')) + + # allowed_times + at = data.get('allowed_times') or {} + if isinstance(at, dict): + days_raw = at.get('allowed_days') or [] + if isinstance(days_raw, list): + for day in days_raw: + d = str(day).lower().strip() + if d not in WorkflowFormatter.DAY_PARSE_MAP: + raise CommandError('', f'workflow: invalid allowed_times.allowed_days token "{day}"') + obj.allowed_days.append(d[:3]) # store as "mon".."sun" + + ranges_raw = at.get('time_ranges') or [] + if isinstance(ranges_raw, list): + for r in ranges_raw: + if isinstance(r, dict): + start = str(r.get('start', '') or '').strip() + end = str(r.get('end', '') or '').strip() + if start and end: + obj.time_ranges.append({'start': start, 'end': end}) + + obj.timezone = str(at.get('timezone', '') or '').strip() + + # V8: time_ranges non-empty => timezone required + if obj.time_ranges and not obj.timezone: + raise CommandError('', 'workflow: allowed_times.time_ranges requires timezone') + + # approvers + for idx, a in enumerate(data.get('approvers') or []): + if not isinstance(a, dict): + continue + principal = a.get('principal') or {} + if not isinstance(principal, dict): + continue + ptype = str(principal.get('type', '') or '').lower() + escalation = bool(a.get('escalation', False)) + esc_after_raw = a.get('escalation_after') + esc_after_ms = cls._parse_duration(esc_after_raw) if esc_after_raw else 0 + # V7: escalation_after requires escalation: true + if esc_after_ms and not escalation: + raise CommandError('', f'workflow: approvers[{idx}] escalation_after requires escalation: true') + if ptype == 'user': + email = str(principal.get('email', '') or '').strip() + if not email: + raise CommandError('', f'workflow: approvers[{idx}] user principal requires non-empty email') + obj.approvers.append({ + 'principal_type': 'user', 'email': email, 'team_uid_b64': None, + 'escalation': escalation, 'escalation_after_ms': esc_after_ms, + }) + elif ptype == 'team': + uid_b64 = str(principal.get('team_uid_base64url', '') or '').strip() + if not uid_b64: + raise CommandError('', f'workflow: approvers[{idx}] team principal requires non-empty team_uid_base64url') + obj.approvers.append({ + 'principal_type': 'team', 'email': None, 'team_uid_b64': uid_b64, + 'escalation': escalation, 'escalation_after_ms': esc_after_ms, + }) + else: + raise CommandError('', f'workflow: approvers[{idx}] principal.type must be "user" or "team", got "{ptype}"') + + # V2: non-trivial guard — at least one meaningful flag must be set + is_trivial = ( + obj.approvals_needed == 0 + and not obj.start_access_on_approval + and not obj.checkout_needed + and not obj.require_mfa + and not obj.allowed_days + and not obj.time_ranges + ) + if is_trivial: + return None # nothing to persist; caller treats as delete/no-op + + # V4 warning: approvals_needed > 0 with no approvers + if obj.approvals_needed > 0 and not obj.approvers: + logging.warning('workflow: approvals_needed > 0 but no approvers specified') + + return obj + + def validate_principals(self, params, resource_title: str = '') -> None: + """Validate team UIDs via RecordResolver.validate_team (which checks both + team_cache and enterprise.teams). Raises CommandError on first unknown UID. + """ + for idx, a in enumerate(self.approvers): + if a['principal_type'] != 'team': + continue + try: + RecordResolver.validate_team(params, a['team_uid_b64']) + except CommandError as e: + prefix = f'Resource "{resource_title}": ' if resource_title else '' + raise CommandError('', f'{prefix}workflow approvers[{idx}]: {e.message or str(e)}') + + class DagJitSettingsObject(): def __init__(self): self.create_ephemeral: bool = False @@ -2900,10 +3041,12 @@ class PamRemoteBrowserSettings: def __init__( self, options: Optional[DagSettingsObject] = None, - connection: Optional[ConnectionSettingsHTTP] = None + connection: Optional[ConnectionSettingsHTTP] = None, + workflow: Optional[PamWorkflowOptions] = None, ): self.options = options self.connection = connection + self.workflow = workflow # not on record nor in DAG; applied via Krouter @classmethod def load(cls, data: Optional[Union[str, dict]]): @@ -2912,9 +3055,14 @@ def load(cls, data: Optional[Union[str, dict]]): except: logging.error(f"PAM RBI Settings field failed to load from: {str(data)[:80]}...") if not isinstance(data, dict): return obj - options = DagSettingsObject.load(data.get("options", {})) + options_dict = data.get("options", {}) or {} + options = DagSettingsObject.load(options_dict) if not is_empty_instance(options): obj.options = options + if isinstance(options_dict, dict): + workflow_value = options_dict.get("workflow") + if workflow_value is not None: + obj.workflow = PamWorkflowOptions.load(workflow_value) cdata = data.get("connection", {}) # TO DO: if isinstance(cdata, str): lookup_by_name(pam_data.connections) @@ -2944,6 +3092,7 @@ def __init__( options: Optional[DagSettingsObject] = None, jit_settings: Optional[DagJitSettingsObject] = None, ai_settings: Optional[DagAiSettingsObject] = None, + workflow: Optional[PamWorkflowOptions] = None, ): self.allowSupplyHost = allowSupplyHost self.connection = connection @@ -2951,6 +3100,7 @@ def __init__( self.options = options self.jit_settings = jit_settings self.ai_settings = ai_settings + self.workflow = workflow # not on record nor in DAG; applied via Krouter # PamConnectionSettings excludes ConnectionSettingsHTTP pam_connection_classes = [ @@ -2981,8 +3131,8 @@ def is_empty(self): empty = is_empty_instance(self.options) empty = empty and is_empty_instance(self.portForward) empty = empty and is_empty_instance(self.connection, ["protocol"]) - # NB! JIT and AI settings are in import json but not in record json (just DAG json) - empty = empty and self.jit_settings is None and self.ai_settings is None + # NB! JIT, AI, workflow are in import json but not in record json (not DAG either for workflow) + empty = empty and self.jit_settings is None and self.ai_settings is None and self.workflow is None return empty @classmethod @@ -3008,6 +3158,9 @@ def load(cls, data: Union[str, dict]): ai_settings = DagAiSettingsObject.load(ai_value) if ai_settings: obj.ai_settings = ai_settings + workflow_value = options_dict.get("workflow") + if workflow_value is not None: + obj.workflow = PamWorkflowOptions.load(workflow_value) portForward = PamPortForwardSettings.load(data.get("port_forward", {})) if not is_empty_instance(portForward): diff --git a/keepercommander/commands/pam_import/edit.py b/keepercommander/commands/pam_import/edit.py index 0b5d35686..d80fd8354 100644 --- a/keepercommander/commands/pam_import/edit.py +++ b/keepercommander/commands/pam_import/edit.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, List, Union from .keeper_ai_settings import set_resource_jit_settings, set_resource_keeper_ai_settings, refresh_meta_to_latest, refresh_link_to_config_to_latest +from .workflow_apply import apply_workflow, validate_workflow_principals from .base import ( PAM_RESOURCES_RECORD_TYPES, PROJECT_IMPORT_JSON_TEMPLATE, @@ -1642,6 +1643,9 @@ def process_data(self, params, project): resolve_domain_admin(pce, users) # only resolve here - create after machine and user creation + # pre-flight: validate workflow team UIDs before any vault writes (runs in dry-run too) + validate_workflow_principals(params, resources) + # dry run if project["options"].get("dry_run", False) is True: print("Will import file data here...") @@ -1696,6 +1700,9 @@ def process_data(self, params, project): args["connections"] = True args["v_type"] = RefType.PAM_BROWSER tdag.set_resource_allowed(**args) + rbi_wf = getattr(getattr(mach, 'rbi_settings', None), 'workflow', None) + if rbi_wf: + apply_workflow(params, mach.uid, mach.title or '', rbi_wf) else: # machine/db/directory args = parse_command_options(mach, True) if admin_uid: args["admin"] = admin_uid @@ -1739,6 +1746,10 @@ def process_data(self, params, project): if ai: refresh_link_to_config_to_latest(params, mach.uid, pam_cfg_uid) + ps_wf = getattr(getattr(mach, 'pam_settings', None), 'workflow', None) + if ps_wf: + apply_workflow(params, mach.uid, mach.title or '', ps_wf) + # Machine - create its users (if any) users = getattr(mach, "users", []) users = users if isinstance(users, list) else [] diff --git a/keepercommander/commands/pam_import/extend.py b/keepercommander/commands/pam_import/extend.py index 82fb5522b..c21a2a6f2 100644 --- a/keepercommander/commands/pam_import/extend.py +++ b/keepercommander/commands/pam_import/extend.py @@ -53,6 +53,7 @@ refresh_meta_to_latest, refresh_link_to_config_to_latest, ) +from .workflow_apply import apply_workflow, validate_workflow_principals from ...keeper_dag import EdgeType from ...keeper_dag.types import RefType from ..base import Command @@ -549,6 +550,10 @@ def execute(self, params, **kwargs): fp = (getattr(u, "folder_path", None) or "").strip() u.resolved_folder_uid = path_to_folder_uid.get(fp) or usr_folder_uid + # pre-flight: validate workflow team UIDs for new resources (runs in dry-run too) + new_rscs = [r for r in project.get('mapped_resources', []) if getattr(r, '_extend_tag', None) == 'new'] + validate_workflow_principals(params, new_rscs) + if dry_run: print("[DRY RUN COMPLETE] No changes were made. All actions were validated but not executed.") return @@ -1402,6 +1407,9 @@ def process_data(self, params, project): args["connections"] = True args["v_type"] = RefType.PAM_BROWSER tdag.set_resource_allowed(**args) + rbi_wf = getattr(getattr(mach, 'rbi_settings', None), 'workflow', None) + if rbi_wf: + apply_workflow(params, mach.uid, mach.title or '', rbi_wf) else: args = parse_command_options(mach, True) if admin_uid: @@ -1444,6 +1452,10 @@ def process_data(self, params, project): if ai: refresh_link_to_config_to_latest(params, mach.uid, pam_cfg_uid) + ps_wf = getattr(getattr(mach, 'pam_settings', None), 'workflow', None) + if ps_wf: + apply_workflow(params, mach.uid, mach.title or '', ps_wf) + mach_users = getattr(mach, "users", []) or [] for user in mach_users: if getattr(user, "_extend_tag", None) != "new": diff --git a/keepercommander/commands/pam_import/workflow_apply.py b/keepercommander/commands/pam_import/workflow_apply.py new file mode 100644 index 000000000..65ffc8c5f --- /dev/null +++ b/keepercommander/commands/pam_import/workflow_apply.py @@ -0,0 +1,262 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' bool: + if isinstance(e, KeeperApiError) and e.result_code == 429: + return True + msg = str(getattr(e, 'message', None) or e).lower() + return 'throttle' in msg or 'too many' in msg + + +def _post_with_throttle_retry(params, path: str, **kwargs): + """Wrap _post_request_to_router with progressive backoff on 429 / throttle errors. + Non-throttle errors propagate immediately. Final retry's exception is re-raised. + """ + wait = _THROTTLE_BASE_WAIT + for attempt in range(1, _THROTTLE_MAX_RETRIES + 1): + try: + return _post_request_to_router(params, path, **kwargs) + except Exception as e: + if not _is_throttle_error(e) or attempt >= _THROTTLE_MAX_RETRIES: + raise + logging.warning( + 'Krouter rate-limited on %s (attempt %d/%d); waiting %.1fs', + path, attempt, _THROTTLE_MAX_RETRIES, wait, + ) + time.sleep(wait) + wait *= _THROTTLE_MULTIPLIER + + +# Re-exported for tests and any downstream importers; the canonical map lives +# in WorkflowFormatter.DAY_PARSE_MAP and accepts both 3-letter and full names. +_DAY_PROTO_MAP = { + k: v for k, v in WorkflowFormatter.DAY_PARSE_MAP.items() if len(k) == 3 +} + + +def _build_temporal_filter(opts: PamWorkflowOptions): + """Build TemporalAccessFilter from opts. Returns None when no temporal slice is set. + + startTime / endTime on TimeOfDayRange are HHMM integers (hours*100 + minutes); + see WorkflowFormatter._parse_time_to_hhmm. Canonical sources: + - keeperapp-protobuf/workflow.proto:140 (`int32 startTime = 1; // HHMM format`) + - ka-libs/workflow/.../handlers/WfConfigCRUD.kt::validateHHMM (server validator) + """ + if not opts.allowed_days and not opts.time_ranges and not opts.timezone: + return None + temporal = workflow_pb2.TemporalAccessFilter() + for day_token in opts.allowed_days: + day_enum = WorkflowFormatter.DAY_PARSE_MAP.get(day_token) + if day_enum is not None: + temporal.allowedDays.append(day_enum) + for r in opts.time_ranges: + tr = workflow_pb2.TimeOfDayRange() + tr.startTime = WorkflowFormatter._parse_time_to_hhmm(r['start']) + tr.endTime = WorkflowFormatter._parse_time_to_hhmm(r['end']) + temporal.timeRanges.append(tr) + if opts.timezone: + temporal.timeZone = opts.timezone + return temporal + + +def _build_parameters( + record_uid_bytes: bytes, + record_title: str, + opts: PamWorkflowOptions, +) -> workflow_pb2.WorkflowParameters: + params_proto = workflow_pb2.WorkflowParameters() + params_proto.resource.CopyFrom(ProtobufRefBuilder.record_ref(record_uid_bytes, record_title)) + params_proto.approvalsNeeded = opts.approvals_needed + params_proto.checkoutNeeded = opts.checkout_needed + params_proto.startAccessOnApproval = opts.start_access_on_approval + params_proto.requireReason = opts.require_reason + params_proto.requireTicket = opts.require_ticket + params_proto.requireMFA = opts.require_mfa + params_proto.accessLength = opts.access_duration_ms + + temporal = _build_temporal_filter(opts) + if temporal: + params_proto.allowedTimes.CopyFrom(temporal) + + return params_proto + + +def _build_approver_proto(a: dict) -> workflow_pb2.WorkflowApprover: + approver = workflow_pb2.WorkflowApprover() + if a['principal_type'] == 'user': + approver.user = a['email'] + else: + approver.teamUid = utils.base64_url_decode(a['team_uid_b64']) + approver.escalation = a['escalation'] + if a['escalation_after_ms']: + approver.escalationAfterMs = a['escalation_after_ms'] + return approver + + +def _approver_key(params: KeeperParams, approver: workflow_pb2.WorkflowApprover) -> str: + """Return a stable identity key for an existing server approver (for reconcile diff). + Server may return either user (email) or userId (int). When userId is set, resolve + to email through the enterprise user list so it matches the import-side key. + """ + if approver.HasField('user'): + return f'user:{approver.user}' + if approver.HasField('userId'): + email = RecordResolver.resolve_user(params, approver.userId) + # resolve_user returns 'User ID ' when not found — fall back to userId so + # we don't accidentally key two different unknown users to the same string. + if email and not email.startswith('User ID '): + return f'user:{email}' + return f'userid:{approver.userId}' + if approver.HasField('teamUid'): + return f'team:{utils.base64_url_encode(approver.teamUid)}' + return '' + + +def _new_approver_key(a: dict) -> str: + if a['principal_type'] == 'user': + return f'user:{a["email"]}' + return f'team:{a["team_uid_b64"]}' + + +def _reconcile_approvers( + params: KeeperParams, + record_uid_bytes: bytes, + record_title: str, + existing: List[workflow_pb2.WorkflowApprover], + new_approvers: List[dict], +) -> None: + ref = ProtobufRefBuilder.record_ref(record_uid_bytes, record_title) + + existing_keys = {_approver_key(params, a): a for a in existing} + new_keys = {_new_approver_key(a): a for a in new_approvers} + + to_delete = [a for k, a in existing_keys.items() if k not in new_keys] + to_add = [a for k, a in new_keys.items() if k not in existing_keys] + + if to_delete: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in to_delete: + config.approvers.append(a) + _post_with_throttle_retry(params, 'delete_workflow_approvers', rq_proto=config) + + if to_add: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in to_add: + config.approvers.append(_build_approver_proto(a)) + _post_with_throttle_retry(params, 'add_workflow_approvers', rq_proto=config) + + +def apply_workflow( + params: KeeperParams, + record_uid: str, + record_title: str, + opts: PamWorkflowOptions, +) -> None: + """Create or update workflow config via Krouter. Raises CommandError on failure.""" + record_uid_bytes = utils.base64_url_decode(record_uid) + ref = ProtobufRefBuilder.record_ref(record_uid_bytes, record_title) + + try: + existing = _post_with_throttle_retry( + params, 'read_workflow_config', + rq_proto=ref, rs_type=workflow_pb2.WorkflowConfig, + ) + except Exception as e: + raise CommandError('', f'workflow read failed for "{record_title}": {sanitize_router_error(e)}') + + parameters = _build_parameters(record_uid_bytes, record_title, opts) + + try: + if existing: + _post_with_throttle_retry(params, 'update_workflow_config', rq_proto=parameters) + if opts.approvals_needed > 0: + _reconcile_approvers( + params, record_uid_bytes, record_title, + list(existing.approvers), opts.approvers, + ) + elif existing.approvers: + # approvals_needed dropped to 0: remove all existing approvers (V5) + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in existing.approvers: + config.approvers.append(a) + _post_with_throttle_retry(params, 'delete_workflow_approvers', rq_proto=config) + else: + _post_with_throttle_retry(params, 'create_workflow_config', rq_proto=parameters) + if opts.approvals_needed > 0 and opts.approvers: + config = workflow_pb2.WorkflowConfig() + config.parameters.resource.CopyFrom(ref) + for a in opts.approvers: + config.approvers.append(_build_approver_proto(a)) + _post_with_throttle_retry(params, 'add_workflow_approvers', rq_proto=config) + except CommandError: + raise + except Exception as e: + raise CommandError('', f'workflow apply failed for "{record_title}": {sanitize_router_error(e)}') + + +def validate_workflow_principals(params: KeeperParams, resources) -> None: + """Pre-flight: validate team UIDs in workflow approvers for all resources. + Uses RecordResolver.validate_team which checks both team_cache and enterprise.teams, + matching the lookup path used by `pam workflow add-approver`. Raises CommandError + on the first unknown UID, with the resource title in the message for context. + """ + for mach in resources or []: + opts = None + ps = getattr(mach, 'pam_settings', None) + if ps: + opts = getattr(ps, 'workflow', None) + if opts is None: + rbi = getattr(mach, 'rbi_settings', None) + if rbi: + opts = getattr(rbi, 'workflow', None) + if opts is None: + continue + title = getattr(mach, 'title', '') or '' + for idx, a in enumerate(opts.approvers): + if a['principal_type'] != 'team': + continue + try: + RecordResolver.validate_team(params, a['team_uid_b64']) + except CommandError as e: + prefix = f'Resource "{title}": ' if title else '' + raise CommandError('', f'{prefix}workflow approvers[{idx}]: {e.message or str(e)}') diff --git a/keepercommander/commands/workflow/config_commands.py b/keepercommander/commands/workflow/config_commands.py index 30137bf9c..28ca9c71f 100644 --- a/keepercommander/commands/workflow/config_commands.py +++ b/keepercommander/commands/workflow/config_commands.py @@ -329,8 +329,11 @@ def _print_table(params, response, record_uid): print(f" Days: {', '.join(day_names)}") if at.timeRanges: for tr in at.timeRanges: - start_h, start_m = divmod(tr.startTime, 60) - end_h, end_m = divmod(tr.endTime, 60) + # startTime / endTime are HHMM (hours*100 + minutes); see + # WorkflowFormatter._parse_time_to_hhmm and the canonical + # ka-libs/workflow/.../WfConfigCRUD.kt::validateHHMM. + start_h, start_m = divmod(tr.startTime, 100) + end_h, end_m = divmod(tr.endTime, 100) print(f" Time: {start_h:02d}:{start_m:02d} - {end_h:02d}:{end_m:02d}") if at.timeZone: print(f" Timezone: {at.timeZone}") diff --git a/keepercommander/commands/workflow/helpers.py b/keepercommander/commands/workflow/helpers.py index e46eb263b..21a3ac4dd 100644 --- a/keepercommander/commands/workflow/helpers.py +++ b/keepercommander/commands/workflow/helpers.py @@ -523,9 +523,17 @@ def build_temporal_filter(allowed_days_str, time_range_str, timezone_str): @staticmethod def _parse_time_to_hhmm(time_str): - """Parse 'HH:MM' into the HHMM integer encoding the server expects on - TimeOfDayRange.startTime / .endTime — e.g. '03:00' -> 300, '17:30' -> 1730. - Server validates: HHMM integer with HH in 0-23 and MM in 0-59. + """Parse 'HH:MM' to the HHMM integer the server stores on + TimeOfDayRange.startTime / .endTime: hours*100 + minutes. + Examples: '00:00' -> 0, '03:00' -> 300, '09:00' -> 900, '17:30' -> 1730. + Valid range: 0..2359 with hours in 0-23 and minutes in 0-59. + + Canonical sources (all agree on HHMM): + - keeperapp-protobuf/workflow.proto:140 + `int32 startTime = 1; // HHMM format` + - ka-libs/workflow/src/main/kotlin/com/keepersecurity/workflow/handlers/WfConfigCRUD.kt::validateHHMM + `val hours = value / 100; val minutes = value % 100` + throws "Invalid : . Expected HHMM integer with HH in 0-23 and MM in 0-59" on bad input. """ try: parts = time_str.split(':') @@ -547,6 +555,7 @@ def format_temporal_filter(at): if at.timeRanges: ranges = [] for tr in at.timeRanges: + # startTime / endTime are HHMM integers (see _parse_time_to_hhmm). sh, sm = divmod(tr.startTime, 100) eh, em = divmod(tr.endTime, 100) ranges.append(f"{sh:02d}:{sm:02d}-{eh:02d}:{em:02d}") diff --git a/keepercommander/commands/workflow/registry.py b/keepercommander/commands/workflow/registry.py index ae87e7e8c..2ea6f31ed 100644 --- a/keepercommander/commands/workflow/registry.py +++ b/keepercommander/commands/workflow/registry.py @@ -9,9 +9,6 @@ # Contact: ops@keepersecurity.com # -import logging -from urllib.parse import urlparse - from ..base import GroupCommand, dump_report_data from ...display import bcolors from .helpers import _ENFORCEMENT_KEY @@ -42,15 +39,8 @@ class PAMWorkflowCommand(GroupCommand): - NOTICE_MSG = 'Notice: PAM Workflow commands are not in production yet. They will be available soon.' - _ALLOWED_PREFIXES = ('dev.', 'qa.') _ADMIN_VERBS = frozenset({'create', 'update', 'delete', 'add-approver', 'remove-approver'}) - @staticmethod - def _is_allowed_server(params): - hostname = urlparse(params.rest_context.server_base).hostname or '' - return any(hostname.startswith(p) for p in PAMWorkflowCommand._ALLOWED_PREFIXES) - @staticmethod def _can_manage_workflows(params): enforcements = getattr(params, 'enforcements', None) @@ -62,10 +52,6 @@ def _can_manage_workflows(params): ) def execute_args(self, params, args, **kwargs): - if not self._is_allowed_server(params): - logging.warning(f"{bcolors.WARNING}{self.NOTICE_MSG}{bcolors.ENDC}") - return - self._current_params = params pos = args.find(' ') if args else -1 diff --git a/keepercommander/constants.py b/keepercommander/constants.py index f60ae1c90..89ced2729 100644 --- a/keepercommander/constants.py +++ b/keepercommander/constants.py @@ -112,6 +112,7 @@ class PrivilegeScope(enum.IntEnum): ("MASTER_PASSWORD_MINIMUM_UPPER", 12, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_MINIMUM_LOWER", 13, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_MINIMUM_DIGITS", 14, "LONG", "LOGIN_SETTINGS"), + ("MASTER_PASSWORD_MINIMUM_LENGTH_NO_PROMPT", 15, "LONG", "LOGIN_SETTINGS"), ("MASTER_PASSWORD_RESTRICT_DAYS_BEFORE_REUSE", 16, "LONG", "LOGIN_SETTINGS"), ("REQUIRE_TWO_FACTOR", 20, "BOOLEAN", "TWO_FACTOR_AUTHENTICATION"), ("MASTER_PASSWORD_MAXIMUM_DAYS_BEFORE_CHANGE", 22, "LONG", "LOGIN_SETTINGS"), @@ -231,6 +232,7 @@ class PrivilegeScope(enum.IntEnum): ("ALLOW_VIEW_KCM_RECORDINGS", 234, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_TOTP_FIELD", 235, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("ALLOW_VIEW_RBI_RECORDINGS", 236, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + ("USE_DEFAULT_BROWSER_FOR_SSO", 237, "TERNARY_DEN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_MANAGE_TLA", 238, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_SELF_DESTRUCT_RECORDS", 239, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_PERSONAL_USING_BUSINESS_DOMAINS", 240, "STRING", "ACCOUNT_ENFORCEMENTS"), @@ -240,6 +242,8 @@ class PrivilegeScope(enum.IntEnum): ("WARN_PERSONAL_USING_BUSINESS_SITES", 244, "STRING", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_ACCOUNT_SWITCHING", 245, "BOOLEAN", "AUTHENTICATION_ENFORCEMENTS"), ("RESTRICT_PASSKEY_LOGIN", 246, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + # NOTE: 247 server name is ALLOW_CAN_EDIT_EXTERNAL_SHARES (positive). Commander's + # RESTRICT_ name is kept for backward compat but the polarity is inverted vs the server. ("RESTRICT_CAN_EDIT_EXTERNAL_SHARES", 247, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_SNAPSHOT_TOOL", 248, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_FORCEFIELD", 249, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), @@ -248,6 +252,16 @@ class PrivilegeScope(enum.IntEnum): ("RESTRICT_SF_FOLDER_DELETION", 253, "BOOLEAN", "SHARING_ENFORCEMENTS"), ("RESTRICT_PLATFORM_PASSKEY_LOGIN", 254, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ("RESTRICT_CROSS_PLATFORM_PASSKEY_LOGIN", 255, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_WEB", 256, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_MOBILE", 257, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_DESKTOP", 258, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_SESSION_CONSOLE", 259, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_WEB", 260, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_MOBILE", 261, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_DESKTOP", 262, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("IP_MAX_DISTANCE_DEFAULT_CONSOLE", 263, "LONG", "ACCOUNT_ENFORCEMENTS"), + ("LOGOUT_TIMER_CONSOLE", 264, "LONG", "ACCOUNT_SETTINGS"), + ("ALLOW_CONFIGURE_WORKFLOW_SETTINGS", 267, "BOOLEAN", "ACCOUNT_ENFORCEMENTS"), ] _COMPOUND_ENFORCEMENTS = [ diff --git a/tests/test_pam_workflow.py b/tests/test_pam_workflow.py new file mode 100644 index 000000000..155025c55 --- /dev/null +++ b/tests/test_pam_workflow.py @@ -0,0 +1,362 @@ +"""Unit tests for PAM import workflow parsing, validation, and protobuf assembly.""" + +import unittest +from unittest.mock import MagicMock, patch + +from keepercommander.error import CommandError, KeeperApiError +from keepercommander.commands.pam_import.base import PamWorkflowOptions +from keepercommander.commands.pam_import import workflow_apply +from keepercommander.commands.pam_import.workflow_apply import ( + _build_temporal_filter, + _build_parameters, + _DAY_PROTO_MAP, + _is_throttle_error, + _post_with_throttle_retry, +) +from keepercommander.commands.workflow.helpers import WorkflowFormatter +from keepercommander.proto import workflow_pb2 + +# Server expects HHMM integer (workflow.proto:140 "HHMM format" + server validator). +_parse_time_to_hhmm = WorkflowFormatter._parse_time_to_hhmm + + +# --------------------------------------------------------------------------- +# Duration parsing +# --------------------------------------------------------------------------- + +class TestParseDuration(unittest.TestCase): + + def test_hours(self): + self.assertEqual(PamWorkflowOptions._parse_duration('8h'), 8 * 3_600_000) + + def test_minutes(self): + self.assertEqual(PamWorkflowOptions._parse_duration('30m'), 30 * 60_000) + + def test_days(self): + self.assertEqual(PamWorkflowOptions._parse_duration('1d'), 86_400_000) + + def test_bare_integer_treated_as_minutes(self): + self.assertEqual(PamWorkflowOptions._parse_duration('45'), 45 * 60_000) + + def test_none_returns_default(self): + self.assertEqual(PamWorkflowOptions._parse_duration(None), 86_400_000) + + def test_zero_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('0h') + + def test_negative_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('-1d') + + def test_invalid_string_raises(self): + with self.assertRaises(CommandError): + PamWorkflowOptions._parse_duration('invalid') + + def test_uppercase_suffix(self): + self.assertEqual(PamWorkflowOptions._parse_duration('2H'), 2 * 3_600_000) + + +# --------------------------------------------------------------------------- +# Day mapping +# --------------------------------------------------------------------------- + +class TestDayMapping(unittest.TestCase): + + def test_all_3letter_tokens_in_map(self): + expected = {'mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'} + self.assertEqual(set(_DAY_PROTO_MAP.keys()), expected) + + def test_monday_maps_to_proto(self): + self.assertEqual(_DAY_PROTO_MAP['mon'], workflow_pb2.MONDAY) + + def test_friday_maps_to_proto(self): + self.assertEqual(_DAY_PROTO_MAP['fri'], workflow_pb2.FRIDAY) + + +# --------------------------------------------------------------------------- +# Time-of-day parsing +# --------------------------------------------------------------------------- + +class TestParseTimeToHHMM(unittest.TestCase): + """Server expects HHMM integer encoding per workflow.proto and the server-side + validator (returns "Expected HHMM integer with HH in 0-23 and MM in 0-59").""" + + def test_midnight(self): + self.assertEqual(_parse_time_to_hhmm('00:00'), 0) + + def test_nine_am(self): + self.assertEqual(_parse_time_to_hhmm('09:00'), 900) + + def test_half_past_five_pm(self): + self.assertEqual(_parse_time_to_hhmm('17:30'), 1730) + + def test_invalid_format_raises(self): + with self.assertRaises(CommandError): + _parse_time_to_hhmm('25:00') + + def test_non_numeric_raises(self): + with self.assertRaises(CommandError): + _parse_time_to_hhmm('ab:cd') + + +# --------------------------------------------------------------------------- +# V2: trivial workflow detection +# --------------------------------------------------------------------------- + +class TestTrivialWorkflow(unittest.TestCase): + + def test_empty_dict_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load({})) + + def test_none_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load(None)) + + def test_all_flags_off_no_temporal_returns_none(self): + self.assertIsNone(PamWorkflowOptions.load({ + 'approvals_needed': 0, + 'checkout_needed': False, + 'require_mfa': False, + })) + + def test_checkout_needed_true_is_non_trivial(self): + opts = PamWorkflowOptions.load({'checkout_needed': True, 'access_duration': '2h'}) + self.assertIsNotNone(opts) + self.assertTrue(opts.checkout_needed) + + def test_require_mfa_true_is_non_trivial(self): + opts = PamWorkflowOptions.load({'require_mfa': True}) + self.assertIsNotNone(opts) + + def test_allowed_days_is_non_trivial(self): + opts = PamWorkflowOptions.load({'allowed_times': {'allowed_days': ['mon'], 'timezone': 'UTC'}}) + self.assertIsNotNone(opts) + + def test_approvals_needed_gt0_is_non_trivial(self): + opts = PamWorkflowOptions.load({'approvals_needed': 2}) + self.assertIsNotNone(opts) + + +# --------------------------------------------------------------------------- +# V7: escalation_after requires escalation: true +# --------------------------------------------------------------------------- + +class TestEscalationValidation(unittest.TestCase): + + def test_escalation_after_without_escalation_raises(self): + data = { + 'approvals_needed': 1, + 'approvers': [{ + 'principal': {'type': 'user', 'email': 'a@b.com'}, + 'escalation': False, + 'escalation_after': '30m', + }], + } + with self.assertRaises(CommandError): + PamWorkflowOptions.load(data) + + def test_escalation_after_with_escalation_true_ok(self): + data = { + 'approvals_needed': 1, + 'approvers': [{ + 'principal': {'type': 'user', 'email': 'a@b.com'}, + 'escalation': True, + 'escalation_after': '30m', + }], + } + opts = PamWorkflowOptions.load(data) + self.assertIsNotNone(opts) + self.assertEqual(opts.approvers[0]['escalation_after_ms'], 30 * 60_000) + + +# --------------------------------------------------------------------------- +# V8: time_ranges requires timezone +# --------------------------------------------------------------------------- + +class TestTimezoneRequirement(unittest.TestCase): + + def test_time_ranges_without_timezone_raises(self): + data = { + 'require_mfa': True, + 'allowed_times': { + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + }, + } + with self.assertRaises(CommandError): + PamWorkflowOptions.load(data) + + def test_time_ranges_with_timezone_ok(self): + data = { + 'require_mfa': True, + 'allowed_times': { + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + 'timezone': 'America/New_York', + }, + } + opts = PamWorkflowOptions.load(data) + self.assertIsNotNone(opts) + self.assertEqual(opts.timezone, 'America/New_York') + self.assertEqual(len(opts.time_ranges), 1) + + +# --------------------------------------------------------------------------- +# V9: access_duration default +# --------------------------------------------------------------------------- + +class TestAccessDurationDefault(unittest.TestCase): + + def test_missing_access_duration_defaults_to_1d(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1}) + self.assertEqual(opts.access_duration_ms, 86_400_000) + + def test_explicit_duration_parsed(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1, 'access_duration': '4h'}) + self.assertEqual(opts.access_duration_ms, 4 * 3_600_000) + + +# --------------------------------------------------------------------------- +# Protobuf assembly: _build_parameters +# --------------------------------------------------------------------------- + +class TestBuildParameters(unittest.TestCase): + + def _make_uid_bytes(self): + import base64 + return base64.urlsafe_b64decode('AAAAAAAAAAAAAAAAAAAAAA==') + + def test_basic_fields_populated(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 2, + 'checkout_needed': True, + 'require_mfa': True, + 'access_duration': '8h', + }) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Test Machine', opts) + self.assertEqual(params_proto.approvalsNeeded, 2) + self.assertTrue(params_proto.checkoutNeeded) + self.assertTrue(params_proto.requireMFA) + self.assertEqual(params_proto.accessLength, 8 * 3_600_000) + self.assertEqual(params_proto.resource.value, uid_bytes) + self.assertEqual(params_proto.resource.name, 'Test Machine') + + def test_temporal_filter_attached(self): + opts = PamWorkflowOptions.load({ + 'require_mfa': True, + 'allowed_times': { + 'allowed_days': ['mon', 'fri'], + 'time_ranges': [{'start': '09:00', 'end': '17:00'}], + 'timezone': 'UTC', + }, + }) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Box', opts) + at = params_proto.allowedTimes + self.assertIn(workflow_pb2.MONDAY, at.allowedDays) + self.assertIn(workflow_pb2.FRIDAY, at.allowedDays) + self.assertEqual(len(at.timeRanges), 1) + # HHMM integer encoding: 09:00 -> 900, 17:00 -> 1700 + self.assertEqual(at.timeRanges[0].startTime, 900) + self.assertEqual(at.timeRanges[0].endTime, 1700) + self.assertEqual(at.timeZone, 'UTC') + + def test_no_allowed_times_no_temporal(self): + opts = PamWorkflowOptions.load({'approvals_needed': 1}) + uid_bytes = self._make_uid_bytes() + params_proto = _build_parameters(uid_bytes, 'Box', opts) + self.assertFalse(params_proto.HasField('allowedTimes')) + + +# --------------------------------------------------------------------------- +# validate_principals +# --------------------------------------------------------------------------- + +class TestValidatePrincipals(unittest.TestCase): + + def _make_params(self, team_uids): + p = MagicMock() + p.team_cache = {uid: {} for uid in team_uids} + return p + + def test_known_team_uid_passes(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'team', 'team_uid_base64url': 'validUID123'}}], + }) + params = self._make_params(['validUID123']) + opts.validate_principals(params, 'MyResource') + + def test_unknown_team_uid_raises(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'team', 'team_uid_base64url': 'unknownUID'}}], + }) + params = self._make_params(['otherUID']) + with self.assertRaises(CommandError): + opts.validate_principals(params, 'MyResource') + + def test_user_principal_not_checked_against_team_cache(self): + opts = PamWorkflowOptions.load({ + 'approvals_needed': 1, + 'approvers': [{'principal': {'type': 'user', 'email': 'user@example.com'}}], + }) + params = self._make_params([]) + opts.validate_principals(params) + + +# --------------------------------------------------------------------------- +# Throttle / 429 retry wrapper +# --------------------------------------------------------------------------- + +class TestThrottleErrorDetection(unittest.TestCase): + + def test_keeper_api_error_429_is_throttle(self): + self.assertTrue(_is_throttle_error(KeeperApiError(429, 'Too many requests'))) + + def test_keeper_api_error_500_is_not_throttle(self): + self.assertFalse(_is_throttle_error(KeeperApiError(500, 'Internal error'))) + + def test_string_throttle_in_msg_is_throttle(self): + self.assertTrue(_is_throttle_error(Exception('record was throttled'))) + + def test_too_many_in_msg_is_throttle(self): + self.assertTrue(_is_throttle_error(Exception('Too many requests'))) + + def test_unrelated_error_is_not_throttle(self): + self.assertFalse(_is_throttle_error(Exception('connection refused'))) + + +class TestThrottleRetry(unittest.TestCase): + + def test_no_retry_on_non_throttle(self): + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=KeeperApiError(500, 'boom')) as mock_post: + with self.assertRaises(KeeperApiError): + _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(mock_post.call_count, 1) + + def test_retries_then_succeeds(self): + # First two calls 429, third succeeds. Patch sleep to keep test fast. + side_effects = [KeeperApiError(429, 'Too many requests'), + KeeperApiError(429, 'Too many requests'), + 'OK'] + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=side_effects) as mock_post, \ + patch.object(workflow_apply.time, 'sleep') as mock_sleep: + result = _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(result, 'OK') + self.assertEqual(mock_post.call_count, 3) + # Two backoff sleeps: 10s, 15s (10 * 1.5) + self.assertEqual([round(c.args[0], 2) for c in mock_sleep.call_args_list], [10.0, 15.0]) + + def test_exhausts_retries_and_reraises(self): + with patch.object(workflow_apply, '_post_request_to_router', + side_effect=KeeperApiError(429, 'Too many requests')) as mock_post, \ + patch.object(workflow_apply.time, 'sleep'): + with self.assertRaises(KeeperApiError): + _post_with_throttle_retry(MagicMock(), 'read_workflow_config') + self.assertEqual(mock_post.call_count, workflow_apply._THROTTLE_MAX_RETRIES) + + +if __name__ == '__main__': + unittest.main() From fcfee3f8f5e4755781c45ae5d3db5c4100e652c1 Mon Sep 17 00:00:00 2001 From: Sergey Kolupaev Date: Fri, 1 May 2026 17:29:06 -0700 Subject: [PATCH 03/26] Drop Python 3.7 support --- .github/workflows/test-with-pytest.yml | 4 +- keepercommander/commands/base.py | 24 +- keepercommander/commands/utils.py | 34 +- .../importer/manageengine/restapi.py | 6 +- .../service/config/cloudflare_config.py | 6 +- keepercommander/service/util/command_util.py | 3 - .../service/util/parse_keeper_response.py | 4 +- requirements.txt | 14 +- setup.cfg | 26 +- unit-tests/pam/test_pam_import_dedup.py | 140 ++- unit-tests/pam/test_pam_project_export.py | 548 +++++---- unit-tests/pam/test_pam_rotation.py | 1004 ++++++++--------- unit-tests/pam/test_pam_tunnel.py | 294 +++-- unit-tests/service/test_api_logging.py | 147 ++- unit-tests/service/test_api_routes.py | 151 ++- unit-tests/service/test_auth_security.py | 173 ++- unit-tests/service/test_command.py | 227 ++-- unit-tests/service/test_config_operation.py | 134 ++- unit-tests/service/test_config_validation.py | 306 +++-- unit-tests/service/test_create_service.py | 612 +++++----- unit-tests/service/test_queue_concurrency.py | 393 ++++--- unit-tests/service/test_response_parser.py | 144 ++- unit-tests/service/test_service_config.py | 224 ++-- unit-tests/service/test_service_manager.py | 282 +++-- unit-tests/test_keeper_drive.py | 1 - unit-tests/test_tunnel_registry.py | 4 - 26 files changed, 2426 insertions(+), 2479 deletions(-) diff --git a/.github/workflows/test-with-pytest.yml b/.github/workflows/test-with-pytest.yml index 95e697242..164247cc3 100644 --- a/.github/workflows/test-with-pytest.yml +++ b/.github/workflows/test-with-pytest.yml @@ -9,9 +9,9 @@ jobs: test-with-pytest: strategy: matrix: - python-version: ['3.7', '3.12'] + python-version: ['3.8', '3.14'] - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - name: Checkout branch diff --git a/keepercommander/commands/base.py b/keepercommander/commands/base.py index eb8f2ff1f..bc9e3c0f7 100644 --- a/keepercommander/commands/base.py +++ b/keepercommander/commands/base.py @@ -161,26 +161,22 @@ def register_commands(commands, aliases, command_info): commands['biometric'] = BiometricCommand() command_info['biometric'] = 'Biometric (Passkey) login management' - if sys.version_info.major == 3 and sys.version_info.minor >= 8: - from .start_service import register_commands as service_commands, register_command_info as service_command_info - service_commands(commands) - service_command_info(aliases, command_info) + from .start_service import register_commands as service_commands, register_command_info as service_command_info + service_commands(commands) + service_command_info(aliases, command_info) toggle_pam_legacy_commands(legacy=False) def toggle_pam_legacy_commands(legacy: bool): - if sys.version_info.major > 3 or (sys.version_info.major == 3 and sys.version_info.minor >= 8): - from . import discoveryrotation - from . import discoveryrotation_v1 - if legacy is True: - discoveryrotation_v1.register_commands(commands) - discoveryrotation_v1.register_command_info(aliases, command_info) - else: - discoveryrotation.register_commands(commands) - discoveryrotation.register_command_info(aliases, command_info) + from . import discoveryrotation + from . import discoveryrotation_v1 + if legacy is True: + discoveryrotation_v1.register_commands(commands) + discoveryrotation_v1.register_command_info(aliases, command_info) else: - logging.debug('pam commands require Python 3.8 or newer') + discoveryrotation.register_commands(commands) + discoveryrotation.register_command_info(aliases, command_info) def register_enterprise_commands(commands, aliases, command_info): diff --git a/keepercommander/commands/utils.py b/keepercommander/commands/utils.py index 1df41bc88..6a2fc65c3 100644 --- a/keepercommander/commands/utils.py +++ b/keepercommander/commands/utils.py @@ -1583,24 +1583,22 @@ def execute(self, params, **kwargs): print('{0:>20s}: {1}'.format('Executable', sys.executable)) if logging.getLogger().isEnabledFor(logging.DEBUG) or show_packages: - ver = sys.version_info - if ver.major >= 3 and ver.minor >= 8: - import importlib.metadata - dist = importlib.metadata.packages_distributions() - packages = {} - for pack in dist.values(): - if isinstance(pack, list) and len(pack) > 0: - name = pack[0] - if name in packages: - continue - try: - version = importlib.metadata.version(name) - packages[name] = version - except Exception as e: - logging.debug('Get package %s version error: %s', name, e) - installed_packages_list = [f'{x[0]}=={x[1]}' for x in packages.items()] - installed_packages_list.sort(key=lambda x: x.lower()) - print('{0:>20s}: {1}'.format('Packages', installed_packages_list)) + import importlib.metadata + dist = importlib.metadata.packages_distributions() + packages = {} + for pack in dist.values(): + if isinstance(pack, list) and len(pack) > 0: + name = pack[0] + if name in packages: + continue + try: + version = importlib.metadata.version(name) + packages[name] = version + except Exception as e: + logging.debug('Get package %s version error: %s', name, e) + installed_packages_list = [f'{x[0]}=={x[1]}' for x in packages.items()] + installed_packages_list.sort(key=lambda x: x.lower()) + print('{0:>20s}: {1}'.format('Packages', installed_packages_list)) if version_details.get('is_up_to_date') is None: logging.debug("It appears that Commander is up to date") diff --git a/keepercommander/importer/manageengine/restapi.py b/keepercommander/importer/manageengine/restapi.py index d91ade5b1..bd8cf3c46 100644 --- a/keepercommander/importer/manageengine/restapi.py +++ b/keepercommander/importer/manageengine/restapi.py @@ -37,11 +37,7 @@ } -if sys.version_info < (3, 7): - Url = namedtuple('Url', ['scheme', 'netloc', 'path', 'query', 'fragment']) - Url.__new__.__defaults__ = ('', '', '') -else: - Url = namedtuple('Url', ['scheme', 'netloc', 'path', 'query', 'fragment'], defaults=('', '', '')) +Url = namedtuple('Url', ['scheme', 'netloc', 'path', 'query', 'fragment'], defaults=('', '', '')) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) diff --git a/keepercommander/service/config/cloudflare_config.py b/keepercommander/service/config/cloudflare_config.py index 05efcdca6..ea9fda01c 100644 --- a/keepercommander/service/config/cloudflare_config.py +++ b/keepercommander/service/config/cloudflare_config.py @@ -9,7 +9,7 @@ # Contact: ops@keepersecurity.com # -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Tuple import time import os import psutil @@ -90,7 +90,7 @@ def _get_cloudflare_log_path() -> str: return os.path.join(service_core_dir, "logs", "cloudflare_tunnel_subprocess.log") @staticmethod - def _analyze_tunnel_log(log_file: str) -> tuple[Optional[bool], str]: + def _analyze_tunnel_log(log_file: str) -> Tuple[Optional[bool], str]: """ Analyze tunnel log content for success/failure indicators. Returns (success: Optional[bool], error_message: str) @@ -115,7 +115,7 @@ def _read_log_file(log_file: str) -> str: return f.read() @staticmethod - def _check_tunnel_patterns(content: str) -> tuple[Optional[bool], str]: + def _check_tunnel_patterns(content: str) -> Tuple[Optional[bool], str]: """Check log content for success/failure patterns.""" if any(pattern in content for pattern in CloudflareConfigurator._SUCCESS_PATTERNS): return True, "" diff --git a/keepercommander/service/util/command_util.py b/keepercommander/service/util/command_util.py index fd0f7895c..7cac25d3d 100644 --- a/keepercommander/service/util/command_util.py +++ b/keepercommander/service/util/command_util.py @@ -10,7 +10,6 @@ # import io, html -from pathlib import Path import sys import json import logging @@ -21,8 +20,6 @@ from ..core.globals import get_current_params from ..decorators.logging import logger, debug_decorator, sanitize_debug_data from ... import cli, utils -from ...__main__ import get_params_from_config -from ...service.config.service_config import ServiceConfig from ...crypto import encrypt_aes_v2 class CommandExecutor: diff --git a/keepercommander/service/util/parse_keeper_response.py b/keepercommander/service/util/parse_keeper_response.py index f1372f1da..aae9159d7 100644 --- a/keepercommander/service/util/parse_keeper_response.py +++ b/keepercommander/service/util/parse_keeper_response.py @@ -9,7 +9,7 @@ # Contact: ops@keepersecurity.com # -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import re, json class KeeperResponseParser: @@ -36,7 +36,7 @@ def _format_multiline_message(text: str) -> str: return text @staticmethod - def _preprocess_response(response: Any, log_output: str = None) -> tuple[str, bool]: + def _preprocess_response(response: Any, log_output: str = None) -> Tuple[str, bool]: """Preprocess response by cleaning ANSI codes and determining source. Returns: diff --git a/requirements.txt b/requirements.txt index b33dc9b3a..41983e843 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,15 +9,15 @@ textual>=0.82.0 websockets fido2>=2.0.0; python_version>='3.10' requests>=2.31.0 -cryptography>=39.0.1 -protobuf>=4.23.0 +cryptography>=46.0.6 +protobuf>=5.29.6 keeper-secrets-manager-core>=16.6.0 -keeper_pam_webrtc_rs>=2.1.6; python_version>='3.8' -pydantic>=2.6.4; python_version>='3.8' -flask; python_version>='3.8' +keeper_pam_webrtc_rs>=2.1.6 +pydantic>=2.6.4 +flask pyngrok>=7.5.0 -flask-limiter; python_version>='3.8' -psutil; python_version>='3.8' +flask-limiter +psutil python-dotenv fpdf2>=2.8.3 cbor2; sys_platform == "darwin" and python_version>='3.10' diff --git a/setup.cfg b/setup.cfg index 7a17f75aa..ac7f6bfb8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,36 +16,42 @@ classifiers = License :: OSI Approved :: MIT License Operating System :: OS Independent Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 + Programming Language :: Python :: 3.14 Topic :: Security keywords = security, password [options] -python_requires = >=3.7 +python_requires = >=3.8 packages = find: include_package_data = True install_requires = asciitree bcrypt colorama - cryptography>=41.0.0 + cryptography>=46.0.6 fido2>=2.0.0; python_version>='3.10' - flask; python_version>='3.8' - flask-limiter; python_version>='3.8' + flask + flask-limiter keeper-secrets-manager-core>=16.6.0 prompt_toolkit - protobuf>=4.23.0 + protobuf>=5.29.6 googleapis-common-protos - psutil; python_version>='3.8' + psutil pycryptodomex>=3.20.0 - pyngrok; python_version>='3.8' + pyngrok pyperclip python-dotenv requests>=2.31.0 tabulate websockets - keeper_pam_webrtc_rs>=2.1.6; python_version>='3.8' - pydantic>=2.6.4; python_version>='3.8' + keeper_pam_webrtc_rs>=2.1.6 + pydantic>=2.6.4 fpdf2>=2.8.3 cbor2; sys_platform == "darwin" and python_version>='3.10' pyobjc-framework-LocalAuthentication; sys_platform == "darwin" and python_version>='3.10' diff --git a/unit-tests/pam/test_pam_import_dedup.py b/unit-tests/pam/test_pam_import_dedup.py index a11c64b84..f958ff1ef 100644 --- a/unit-tests/pam/test_pam_import_dedup.py +++ b/unit-tests/pam/test_pam_import_dedup.py @@ -1,87 +1,85 @@ """Test that pam project import rejects duplicate UIDs.""" import logging -import sys import unittest -if sys.version_info >= (3, 8): - from keepercommander.commands.pam_import.edit import PAMProjectImportCommand +from keepercommander.commands.pam_import.edit import PAMProjectImportCommand - def _minimal_project(resources, users=None): - """Build a minimal project dict matching the structure process_data expects.""" - return { - "data": { - "pam_data": { - "resources": resources, - "users": users or [], - "rotation_profiles": {}, - } - }, - "pam_config": {"pam_config_uid": "test-config-uid"}, - "folders": { - "resources_folder_uid": "sfr-test", - "users_folder_uid": "sfu-test", - }, - } +def _minimal_project(resources, users=None): + """Build a minimal project dict matching the structure process_data expects.""" + return { + "data": { + "pam_data": { + "resources": resources, + "users": users or [], + "rotation_profiles": {}, + } + }, + "pam_config": {"pam_config_uid": "test-config-uid"}, + "folders": { + "resources_folder_uid": "sfr-test", + "users_folder_uid": "sfu-test", + }, + } - class TestPAMImportDuplicateUid(unittest.TestCase): - """process_data must abort when the import JSON contains duplicate uid values.""" +class TestPAMImportDuplicateUid(unittest.TestCase): + """process_data must abort when the import JSON contains duplicate uid values.""" - def test_duplicate_uid_logs_error_and_returns(self): - """process_data aborts with logging.error when two resources share a uid.""" - from unittest.mock import MagicMock - project = _minimal_project([ - {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'duplicate-uid-1'}, - {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'duplicate-uid-1'}, - ]) - cmd = PAMProjectImportCommand() - params = MagicMock() - params.record_cache = {} - params.shared_folder_cache = {} - params.folder_cache = {} + def test_duplicate_uid_logs_error_and_returns(self): + """process_data aborts with logging.error when two resources share a uid.""" + from unittest.mock import MagicMock + project = _minimal_project([ + {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'duplicate-uid-1'}, + {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'duplicate-uid-1'}, + ]) + cmd = PAMProjectImportCommand() + params = MagicMock() + params.record_cache = {} + params.shared_folder_cache = {} + params.folder_cache = {} - # assertLogs with no logger name captures from root logger (where logging.error writes) - with self.assertLogs(level='ERROR') as log_ctx: - try: - cmd.process_data(params, project) - except Exception: - pass # early return path may surface as exception in some code paths + # assertLogs with no logger name captures from root logger (where logging.error writes) + with self.assertLogs(level='ERROR') as log_ctx: + try: + cmd.process_data(params, project) + except Exception: + pass # early return path may surface as exception in some code paths - self.assertTrue( - any('duplicate uid' in msg.lower() or 'duplicate-uid-1' in msg - for msg in log_ctx.output), - f'Expected duplicate UID error in logs, got: {log_ctx.output}' - ) + self.assertTrue( + any('duplicate uid' in msg.lower() or 'duplicate-uid-1' in msg + for msg in log_ctx.output), + f'Expected duplicate UID error in logs, got: {log_ctx.output}' + ) - def test_unique_uids_pass_dedup_check(self): - """process_data does NOT emit a duplicate-uid error when all UIDs are unique.""" - from unittest.mock import MagicMock - import io + def test_unique_uids_pass_dedup_check(self): + """process_data does NOT emit a duplicate-uid error when all UIDs are unique.""" + from unittest.mock import MagicMock + import io - project = _minimal_project([ - {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'uid-alpha'}, - {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'uid-beta'}, - ]) - cmd = PAMProjectImportCommand() - params = MagicMock() - params.record_cache = {} - params.shared_folder_cache = {} - params.folder_cache = {} + project = _minimal_project([ + {'type': 'pamMachine', 'title': 'Machine A', 'uid': 'uid-alpha'}, + {'type': 'pamMachine', 'title': 'Machine B', 'uid': 'uid-beta'}, + ]) + cmd = PAMProjectImportCommand() + params = MagicMock() + params.record_cache = {} + params.shared_folder_cache = {} + params.folder_cache = {} - stream = io.StringIO() - handler = logging.StreamHandler(stream) - handler.setLevel(logging.ERROR) - root_logger = logging.getLogger() - root_logger.addHandler(handler) + stream = io.StringIO() + handler = logging.StreamHandler(stream) + handler.setLevel(logging.ERROR) + root_logger = logging.getLogger() + root_logger.addHandler(handler) + try: try: - try: - cmd.process_data(params, project) - except Exception: - pass - output = stream.getvalue() - self.assertNotIn('duplicate uid', output.lower(), - f'Unexpected duplicate UID error for unique UIDs: {output}') - finally: - root_logger.removeHandler(handler) + cmd.process_data(params, project) + except Exception: + pass + output = stream.getvalue() + self.assertNotIn('duplicate uid', output.lower(), + f'Unexpected duplicate UID error for unique UIDs: {output}') + finally: + root_logger.removeHandler(handler) if __name__ == '__main__': diff --git a/unit-tests/pam/test_pam_project_export.py b/unit-tests/pam/test_pam_project_export.py index d10ecac60..0f7900a5a 100644 --- a/unit-tests/pam/test_pam_project_export.py +++ b/unit-tests/pam/test_pam_project_export.py @@ -18,7 +18,6 @@ import json import os -import sys import tempfile import unittest from unittest.mock import patch @@ -111,286 +110,279 @@ def _fake_load(_params, uid): # ── tests ────────────────────────────────────────────────────────────────── -if sys.version_info >= (3, 8): - from unittest.mock import MagicMock - - class TestPAMProjectExportCommand(unittest.TestCase): - - def setUp(self): - from keepercommander.commands.pam_import.export import PAMProjectExportCommand - self.cmd = PAMProjectExportCommand() - self.params = MagicMock() - self.params.record_cache = {uid: {} for uid in _RECORDS} - - def _execute(self, project_uid=CONFIG_UID, output=None): - """Run execute() with vault.KeeperRecord.load mocked.""" - with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): - with patch.object(self.cmd, "_get_allowed_settings", - return_value=dict(_DEFAULT_ALLOWED)): - kwargs = {"project_uid": project_uid} - if output: - kwargs["output"] = output - return self.cmd.execute(self.params, **kwargs) - - # ── basic output ────────────────────────────────────────────── - - def test_returns_string(self): - result = self._execute() - self.assertIsInstance(result, str, - "execute() should return a JSON string when --output is not set") - - def test_valid_json(self): - parsed = json.loads(self._execute()) - self.assertIsInstance(parsed, dict) - - # ── required top-level keys ─────────────────────────────────── - - def test_has_project_key(self): - parsed = json.loads(self._execute()) +from unittest.mock import MagicMock + +class TestPAMProjectExportCommand(unittest.TestCase): + + def setUp(self): + from keepercommander.commands.pam_import.export import PAMProjectExportCommand + self.cmd = PAMProjectExportCommand() + self.params = MagicMock() + self.params.record_cache = {uid: {} for uid in _RECORDS} + + def _execute(self, project_uid=CONFIG_UID, output=None): + """Run execute() with vault.KeeperRecord.load mocked.""" + with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): + with patch.object(self.cmd, "_get_allowed_settings", + return_value=dict(_DEFAULT_ALLOWED)): + kwargs = {"project_uid": project_uid} + if output: + kwargs["output"] = output + return self.cmd.execute(self.params, **kwargs) + + # ── basic output ────────────────────────────────────────────── + + def test_returns_string(self): + result = self._execute() + self.assertIsInstance(result, str, + "execute() should return a JSON string when --output is not set") + + def test_valid_json(self): + parsed = json.loads(self._execute()) + self.assertIsInstance(parsed, dict) + + # ── required top-level keys ─────────────────────────────────── + + def test_has_project_key(self): + parsed = json.loads(self._execute()) + self.assertIn("project", parsed) + self.assertEqual(parsed["project"], "Test Project") + + def test_has_pam_configuration_key(self): + parsed = json.loads(self._execute()) + self.assertIn("pam_configuration", parsed) + + def test_has_pam_data_key(self): + parsed = json.loads(self._execute()) + self.assertIn("pam_data", parsed) + self.assertIn("resources", parsed["pam_data"]) + self.assertIn("users", parsed["pam_data"]) + + def test_has_tool_version(self): + parsed = json.loads(self._execute()) + self.assertIn("tool_version", parsed) + self.assertEqual(parsed["tool_version"], "commander-export-1.0") + + # ── pam_configuration fields ────────────────────────────────── + + def test_pam_configuration_environment(self): + parsed = json.loads(self._execute()) + self.assertEqual(parsed["pam_configuration"]["environment"], "local") + + def test_pam_configuration_on_off_values(self): + parsed = json.loads(self._execute()) + cfg = parsed["pam_configuration"] + for key in ("connections", "rotation", "tunneling", "remote_browser_isolation"): + self.assertIn(cfg[key], ("on", "off"), f"{key} must be 'on' or 'off'") + + # ── resources ──────────────────────────────────────────────── + + def test_resources_count(self): + parsed = json.loads(self._execute()) + self.assertEqual(len(parsed["pam_data"]["resources"]), 2) + + def test_resource_has_required_keys(self): + parsed = json.loads(self._execute()) + for res in parsed["pam_data"]["resources"]: + for key in ("uid", "type", "title", "users"): + self.assertIn(key, res, f"resource missing key: {key}") + + def test_resource_uids_are_unique(self): + parsed = json.loads(self._execute()) + uids = [r["uid"] for r in parsed["pam_data"]["resources"]] + self.assertEqual(len(uids), len(set(uids)), "resource UIDs must be unique") + + def test_resource_types(self): + parsed = json.loads(self._execute()) + types = {r["type"] for r in parsed["pam_data"]["resources"]} + self.assertIn("pamMachine", types) + self.assertIn("pamDatabase", types) + + # ── users ──────────────────────────────────────────────────── + + def test_top_level_users_deduplication(self): + # USER1 appears in both machine and database resources; + # must only appear once in pam_data.users + parsed = json.loads(self._execute()) + top_uids = [u["uid"] for u in parsed["pam_data"]["users"]] + self.assertEqual(len(top_uids), len(set(top_uids)), + "top-level user UIDs must be unique (de-duplicated)") + + def test_top_level_users_count(self): + # USER1 shared across both resources, USER2 only in DB → 2 unique users + parsed = json.loads(self._execute()) + self.assertEqual(len(parsed["pam_data"]["users"]), 2) + + def test_user_has_required_keys(self): + parsed = json.loads(self._execute()) + for usr in parsed["pam_data"]["users"]: + for key in ("uid", "type", "title", "login"): + self.assertIn(key, usr, f"user missing key: {key}") + + # ── --output flag ──────────────────────────────────────────── + + def test_output_flag_writes_file(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: + tmp_path = tmp.name + try: + result = self._execute(output=tmp_path) + # When --output is set, execute() should return None + self.assertIsNone(result) + self.assertTrue(os.path.exists(tmp_path)) + with open(tmp_path, encoding="utf-8") as fh: + content = fh.read() + parsed = json.loads(content) self.assertIn("project", parsed) - self.assertEqual(parsed["project"], "Test Project") - - def test_has_pam_configuration_key(self): - parsed = json.loads(self._execute()) - self.assertIn("pam_configuration", parsed) - - def test_has_pam_data_key(self): - parsed = json.loads(self._execute()) - self.assertIn("pam_data", parsed) - self.assertIn("resources", parsed["pam_data"]) - self.assertIn("users", parsed["pam_data"]) - - def test_has_tool_version(self): - parsed = json.loads(self._execute()) self.assertIn("tool_version", parsed) - self.assertEqual(parsed["tool_version"], "commander-export-1.0") - - # ── pam_configuration fields ────────────────────────────────── - - def test_pam_configuration_environment(self): - parsed = json.loads(self._execute()) - self.assertEqual(parsed["pam_configuration"]["environment"], "local") - - def test_pam_configuration_on_off_values(self): - parsed = json.loads(self._execute()) - cfg = parsed["pam_configuration"] - for key in ("connections", "rotation", "tunneling", "remote_browser_isolation"): - self.assertIn(cfg[key], ("on", "off"), f"{key} must be 'on' or 'off'") - - # ── resources ──────────────────────────────────────────────── - - def test_resources_count(self): - parsed = json.loads(self._execute()) - self.assertEqual(len(parsed["pam_data"]["resources"]), 2) - - def test_resource_has_required_keys(self): - parsed = json.loads(self._execute()) - for res in parsed["pam_data"]["resources"]: - for key in ("uid", "type", "title", "users"): - self.assertIn(key, res, f"resource missing key: {key}") - - def test_resource_uids_are_unique(self): - parsed = json.loads(self._execute()) - uids = [r["uid"] for r in parsed["pam_data"]["resources"]] - self.assertEqual(len(uids), len(set(uids)), "resource UIDs must be unique") - - def test_resource_types(self): - parsed = json.loads(self._execute()) - types = {r["type"] for r in parsed["pam_data"]["resources"]} - self.assertIn("pamMachine", types) - self.assertIn("pamDatabase", types) - - # ── users ──────────────────────────────────────────────────── - - def test_top_level_users_deduplication(self): - # USER1 appears in both machine and database resources; - # must only appear once in pam_data.users - parsed = json.loads(self._execute()) - top_uids = [u["uid"] for u in parsed["pam_data"]["users"]] - self.assertEqual(len(top_uids), len(set(top_uids)), - "top-level user UIDs must be unique (de-duplicated)") - - def test_top_level_users_count(self): - # USER1 shared across both resources, USER2 only in DB → 2 unique users - parsed = json.loads(self._execute()) - self.assertEqual(len(parsed["pam_data"]["users"]), 2) - - def test_user_has_required_keys(self): - parsed = json.loads(self._execute()) - for usr in parsed["pam_data"]["users"]: - for key in ("uid", "type", "title", "login"): - self.assertIn(key, usr, f"user missing key: {key}") - - # ── --output flag ──────────────────────────────────────────── - - def test_output_flag_writes_file(self): - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp: - tmp_path = tmp.name - try: - result = self._execute(output=tmp_path) - # When --output is set, execute() should return None - self.assertIsNone(result) - self.assertTrue(os.path.exists(tmp_path)) - with open(tmp_path, encoding="utf-8") as fh: - content = fh.read() - parsed = json.loads(content) - self.assertIn("project", parsed) - self.assertIn("tool_version", parsed) - finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) - - # ── error handling ─────────────────────────────────────────── - - def test_missing_project_uid_returns_none(self): - with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): - result = self.cmd.execute(self.params, project_uid="", output=None) - self.assertIsNone(result) - - def test_unknown_uid_returns_none(self): - with patch("keepercommander.vault.KeeperRecord.load", return_value=None): - result = self.cmd.execute(self.params, project_uid="unknown-uid", output=None) - self.assertIsNone(result) - - def test_non_v6_record_returns_none(self): - v3_rec = vault.TypedRecord(version=3) - v3_rec.type_name = "pamMachine" - v3_rec.title = "some" - v3_rec.record_uid = "some-uid" - with patch("keepercommander.vault.KeeperRecord.load", return_value=v3_rec): - result = self.cmd.execute(self.params, project_uid="some-uid", output=None) - self.assertIsNone(result) - - # ── round-trip / determinism ───────────────────────────────── - - def test_sort_keys_determinism(self): - result1 = self._execute() - result2 = self._execute() - self.assertEqual(result1, result2, "Output must be deterministic across calls") - - def test_output_is_sorted(self): - result = self._execute() - parsed = json.loads(result) - keys = list(parsed.keys()) - self.assertEqual(keys, sorted(keys), - "Top-level keys should be sorted (sort_keys=True)") - - - # ──────────────────────────────────────────────────────────────────── - # KCM-import compatibility (PR #1942) - # ──────────────────────────────────────────────────────────────────── - - class TestKCMImportRoundTrip(unittest.TestCase): - """KCM-imported records (PR #1942) reference users by *title* in - ``pam_settings.connection.launch_credentials`` rather than by UID - in ``userRecords[]``. Export must resolve these title references - so the exported JSON re-imports with the user link intact. - """ - - KCM_CFG = "kcm-cfg-1" - KCM_RES = "kcm-res-prod-db" - KCM_USR = "kcm-usr-prod-db" - - def _make_kcm_records(self): - """Build the KCM-shaped vault state (PR #1942 import output).""" - cfg = vault.TypedRecord(version=6) - cfg.type_name = "pamNetworkConfiguration" - cfg.title = "KCM Migration" - cfg.record_uid = self.KCM_CFG - cfg.fields.append(_make_typed_field("pamResources", [{ - "controllerUid": "gw-uid", - "folderUid": "sf-uid", - "resourceRef": [self.KCM_RES], - }])) - - res = vault.TypedRecord(version=3) - res.type_name = "pamMachine" - res.title = "KCM Resource - prod-db" - res.record_uid = self.KCM_RES - res.fields.append(_make_typed_field("pamSettings", [{ - "connection": { - "protocol": "ssh", - "port": "22", - "launch_credentials": "KCM User - prod-db", - }, - "options": {"connections": "on", "rotation": "off"}, - }])) - - usr = vault.TypedRecord(version=3) - usr.type_name = "pamUser" - usr.title = "KCM User - prod-db" - usr.record_uid = self.KCM_USR - usr.fields.append(_make_typed_field("login", ["root"])) - - return {self.KCM_CFG: cfg, self.KCM_RES: res, self.KCM_USR: usr} - - def setUp(self): - from keepercommander.commands.pam_import.export import PAMProjectExportCommand - from unittest.mock import MagicMock - self.cmd = PAMProjectExportCommand() - self.records = self._make_kcm_records() - self.params = MagicMock() - self.params.record_cache = {uid: {} for uid in self.records} - - def _execute(self): - def _load(_p, uid): - return self.records.get(uid) - with patch("keepercommander.vault.KeeperRecord.load", side_effect=_load): - with patch.object(self.cmd, "_get_allowed_settings", - return_value=dict(_DEFAULT_ALLOWED)): - return self.cmd.execute(self.params, project_uid=self.KCM_CFG) - - def test_title_based_user_link_resolved(self): - """KCM resource → export must include the user via title resolution.""" - parsed = json.loads(self._execute()) - resources = parsed["pam_data"]["resources"] - self.assertEqual(len(resources), 1, "expected one KCM resource") - res = resources[0] - self.assertEqual(len(res["users"]), 1, - "KCM resource must export 1 user (resolved by title)") - self.assertEqual(res["users"][0]["uid"], self.KCM_USR) - self.assertEqual(res["users"][0]["title"], "KCM User - prod-db") - - def test_top_level_users_includes_resolved_user(self): - parsed = json.loads(self._execute()) - top_users = parsed["pam_data"]["users"] - self.assertEqual(len(top_users), 1) - self.assertEqual(top_users[0]["uid"], self.KCM_USR) - - def test_pam_settings_preserved_for_round_trip(self): - """Round-trip safety: KCM-specific pam_settings keys preserved verbatim.""" - parsed = json.loads(self._execute()) - res = parsed["pam_data"]["resources"][0] - conn = res["pam_settings"]["connection"] - self.assertEqual(conn["protocol"], "ssh") - self.assertEqual(conn["port"], "22") - self.assertEqual(conn["launch_credentials"], "KCM User - prod-db") - - def test_uid_in_launch_credentials_accepted(self): - """If launch_credentials already holds a 22-char UID (non-KCM path), keep it as-is.""" - uid_22 = "AAAAAAAAAAAAAAAAAAAAAA" # 22 chars, no slash, no space - usr = vault.TypedRecord(version=3) - usr.type_name = "pamUser" - usr.title = "Direct UID User" - usr.record_uid = uid_22 - usr.fields.append(_make_typed_field("login", ["alice"])) - self.records[uid_22] = usr - self.params.record_cache[uid_22] = {} - - res = self.records[self.KCM_RES] - ps = res.get_typed_field("pamSettings").value[0] - ps["connection"]["launch_credentials"] = uid_22 - parsed = json.loads(self._execute()) - users = parsed["pam_data"]["resources"][0]["users"] - self.assertEqual(len(users), 1) - self.assertEqual(users[0]["uid"], uid_22) - - -else: - class TestPAMProjectExportCommand(unittest.TestCase): - def test_skip(self): - self.skipTest("Requires Python 3.8+") + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + # ── error handling ─────────────────────────────────────────── + + def test_missing_project_uid_returns_none(self): + with patch("keepercommander.vault.KeeperRecord.load", side_effect=_fake_load): + result = self.cmd.execute(self.params, project_uid="", output=None) + self.assertIsNone(result) + + def test_unknown_uid_returns_none(self): + with patch("keepercommander.vault.KeeperRecord.load", return_value=None): + result = self.cmd.execute(self.params, project_uid="unknown-uid", output=None) + self.assertIsNone(result) + + def test_non_v6_record_returns_none(self): + v3_rec = vault.TypedRecord(version=3) + v3_rec.type_name = "pamMachine" + v3_rec.title = "some" + v3_rec.record_uid = "some-uid" + with patch("keepercommander.vault.KeeperRecord.load", return_value=v3_rec): + result = self.cmd.execute(self.params, project_uid="some-uid", output=None) + self.assertIsNone(result) + + # ── round-trip / determinism ───────────────────────────────── + + def test_sort_keys_determinism(self): + result1 = self._execute() + result2 = self._execute() + self.assertEqual(result1, result2, "Output must be deterministic across calls") + + def test_output_is_sorted(self): + result = self._execute() + parsed = json.loads(result) + keys = list(parsed.keys()) + self.assertEqual(keys, sorted(keys), + "Top-level keys should be sorted (sort_keys=True)") + + +# ──────────────────────────────────────────────────────────────────── +# KCM-import compatibility (PR #1942) +# ──────────────────────────────────────────────────────────────────── + +class TestKCMImportRoundTrip(unittest.TestCase): + """KCM-imported records (PR #1942) reference users by *title* in + ``pam_settings.connection.launch_credentials`` rather than by UID + in ``userRecords[]``. Export must resolve these title references + so the exported JSON re-imports with the user link intact. + """ + + KCM_CFG = "kcm-cfg-1" + KCM_RES = "kcm-res-prod-db" + KCM_USR = "kcm-usr-prod-db" + + def _make_kcm_records(self): + """Build the KCM-shaped vault state (PR #1942 import output).""" + cfg = vault.TypedRecord(version=6) + cfg.type_name = "pamNetworkConfiguration" + cfg.title = "KCM Migration" + cfg.record_uid = self.KCM_CFG + cfg.fields.append(_make_typed_field("pamResources", [{ + "controllerUid": "gw-uid", + "folderUid": "sf-uid", + "resourceRef": [self.KCM_RES], + }])) + + res = vault.TypedRecord(version=3) + res.type_name = "pamMachine" + res.title = "KCM Resource - prod-db" + res.record_uid = self.KCM_RES + res.fields.append(_make_typed_field("pamSettings", [{ + "connection": { + "protocol": "ssh", + "port": "22", + "launch_credentials": "KCM User - prod-db", + }, + "options": {"connections": "on", "rotation": "off"}, + }])) + + usr = vault.TypedRecord(version=3) + usr.type_name = "pamUser" + usr.title = "KCM User - prod-db" + usr.record_uid = self.KCM_USR + usr.fields.append(_make_typed_field("login", ["root"])) + + return {self.KCM_CFG: cfg, self.KCM_RES: res, self.KCM_USR: usr} + + def setUp(self): + from keepercommander.commands.pam_import.export import PAMProjectExportCommand + from unittest.mock import MagicMock + self.cmd = PAMProjectExportCommand() + self.records = self._make_kcm_records() + self.params = MagicMock() + self.params.record_cache = {uid: {} for uid in self.records} + + def _execute(self): + def _load(_p, uid): + return self.records.get(uid) + with patch("keepercommander.vault.KeeperRecord.load", side_effect=_load): + with patch.object(self.cmd, "_get_allowed_settings", + return_value=dict(_DEFAULT_ALLOWED)): + return self.cmd.execute(self.params, project_uid=self.KCM_CFG) + + def test_title_based_user_link_resolved(self): + """KCM resource → export must include the user via title resolution.""" + parsed = json.loads(self._execute()) + resources = parsed["pam_data"]["resources"] + self.assertEqual(len(resources), 1, "expected one KCM resource") + res = resources[0] + self.assertEqual(len(res["users"]), 1, + "KCM resource must export 1 user (resolved by title)") + self.assertEqual(res["users"][0]["uid"], self.KCM_USR) + self.assertEqual(res["users"][0]["title"], "KCM User - prod-db") + + def test_top_level_users_includes_resolved_user(self): + parsed = json.loads(self._execute()) + top_users = parsed["pam_data"]["users"] + self.assertEqual(len(top_users), 1) + self.assertEqual(top_users[0]["uid"], self.KCM_USR) + + def test_pam_settings_preserved_for_round_trip(self): + """Round-trip safety: KCM-specific pam_settings keys preserved verbatim.""" + parsed = json.loads(self._execute()) + res = parsed["pam_data"]["resources"][0] + conn = res["pam_settings"]["connection"] + self.assertEqual(conn["protocol"], "ssh") + self.assertEqual(conn["port"], "22") + self.assertEqual(conn["launch_credentials"], "KCM User - prod-db") + + def test_uid_in_launch_credentials_accepted(self): + """If launch_credentials already holds a 22-char UID (non-KCM path), keep it as-is.""" + uid_22 = "AAAAAAAAAAAAAAAAAAAAAA" # 22 chars, no slash, no space + usr = vault.TypedRecord(version=3) + usr.type_name = "pamUser" + usr.title = "Direct UID User" + usr.record_uid = uid_22 + usr.fields.append(_make_typed_field("login", ["alice"])) + self.records[uid_22] = usr + self.params.record_cache[uid_22] = {} + + res = self.records[self.KCM_RES] + ps = res.get_typed_field("pamSettings").value[0] + ps["connection"]["launch_credentials"] = uid_22 + parsed = json.loads(self._execute()) + users = parsed["pam_data"]["resources"][0]["users"] + self.assertEqual(len(users), 1) + self.assertEqual(users[0]["uid"], uid_22) if __name__ == "__main__": diff --git a/unit-tests/pam/test_pam_rotation.py b/unit-tests/pam/test_pam_rotation.py index 87e4cc157..f7fce197a 100644 --- a/unit-tests/pam/test_pam_rotation.py +++ b/unit-tests/pam/test_pam_rotation.py @@ -1,5 +1,4 @@ import json -import sys import unittest from datetime import datetime from unittest.mock import patch, MagicMock @@ -50,555 +49,554 @@ def create_mock_params(): return mock_params -if sys.version_info >= (3, 8): - import requests - from cryptography.hazmat.primitives.asymmetric import ec - from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat +import requests +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat - from keepercommander import crypto, utils - from keepercommander.commands.discoveryrotation import (PAMCreateRecordRotationCommand, PAMListRecordRotationCommand, - PAMGatewayListCommand, PAMRouterGetRotationInfo) +from keepercommander import crypto, utils +from keepercommander.commands.discoveryrotation import (PAMCreateRecordRotationCommand, PAMListRecordRotationCommand, + PAMGatewayListCommand, PAMRouterGetRotationInfo) - class TestPAMCreateRecordRotationCommand(unittest.TestCase): +class TestPAMCreateRecordRotationCommand(unittest.TestCase): - def setUp(self): - self.command = PAMCreateRecordRotationCommand() - self.parser = self.command.get_parser() - self.transmission_key = b'transmission_key' - self.session_token = b'encrypted_session_token' - self.private_key = ec.generate_private_key(ec.SECP256R1()) - self.public_key = self.private_key.public_key() + def setUp(self): + self.command = PAMCreateRecordRotationCommand() + self.parser = self.command.get_parser() + self.transmission_key = b'transmission_key' + self.session_token = b'encrypted_session_token' + self.private_key = ec.generate_private_key(ec.SECP256R1()) + self.public_key = self.private_key.public_key() - # Serialize and deserialize the public key to ensure compatibility - public_key_bytes = self.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) - loaded_public_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), public_key_bytes) + # Serialize and deserialize the public key to ensure compatibility + public_key_bytes = self.public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) + loaded_public_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), public_key_bytes) - self.encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, loaded_public_key) - self.encrypted_session_token = crypto.encrypt_aes_v2(self.session_token, self.transmission_key) + self.encrypted_transmission_key = crypto.encrypt_ec(self.transmission_key, loaded_public_key) + self.encrypted_session_token = crypto.encrypt_aes_v2(self.session_token, self.transmission_key) - def test_parser(self): - args = self.parser.parse_args(['--record', 'record_uid', '--force']) - self.assertEqual(args.record_name, 'record_uid') - self.assertTrue(args.force) + def test_parser(self): + args = self.parser.parse_args(['--record', 'record_uid', '--force']) + self.assertEqual(args.record_name, 'record_uid') + self.assertTrue(args.force) - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_folder(self, mock_TunnelDAG, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record() + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_folder(self, mock_TunnelDAG, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record() - mock_load.return_value = mock_typed_record + mock_load.return_value = mock_typed_record - mock_dag_instance = mock_TunnelDAG.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.check_if_resource_has_admin.return_value = True - mock_dag_instance.get_all_owners.return_value = ['resource_uid'] - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_dag_instance.user_belongs_to_resource.return_value = True - mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + mock_dag_instance = mock_TunnelDAG.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.check_if_resource_has_admin.return_value = True + mock_dag_instance.get_all_owners.return_value = ['resource_uid'] + mock_dag_instance.resource_belongs_to_config.return_value = True + mock_dag_instance.user_belongs_to_resource.return_value = True + mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string - kwargs = { - 'folder_name': 'folder_uid', - 'force': True # Add force to the kwargs - } - - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_TunnelDAG.called) - - @patch('keepercommander.vault.KeeperRecord.load', return_value=None) - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_no_record(self, mock_TunnelDAG, mock_load): - mock_params, _ = create_mock_params_and_record() - mock_params.record_cache = {} - - kwargs = { - 'record_name': 'non_existent_record', - 'force': True # Add force to the kwargs - } - - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) - - @patch('keepercommander.vault.KeeperRecord.load', return_value=None) - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_invalid_password_complexity(self, mock_TunnelDAG, mock_load): - mock_params, _ = create_mock_params_and_record() - - kwargs = { - 'record_name': 'record_uid', - 'pwd_complexity': 'invalid_complexity', - 'force': True # Add force to the kwargs - } - - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) - - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_valid_password_complexity(self, mock_TunnelDAG, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record() - - mock_load.return_value = mock_typed_record + kwargs = { + 'folder_name': 'folder_uid', + 'force': True # Add force to the kwargs + } - mock_dag_instance = mock_TunnelDAG.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.check_if_resource_has_admin.return_value = True - mock_dag_instance.get_all_owners.return_value = ['resource_uid'] - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_dag_instance.user_belongs_to_resource.return_value = True - mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_TunnelDAG.called) - kwargs = { - 'record_name': 'record_uid', - 'pwd_complexity': '32,5,5,5,5', - 'force': True # Add force to the kwargs - } + @patch('keepercommander.vault.KeeperRecord.load', return_value=None) + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_no_record(self, mock_TunnelDAG, mock_load): + mock_params, _ = create_mock_params_and_record() + mock_params.record_cache = {} - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) - - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) - def test_execute_with_valid_record(self, mock_TunnelDAG, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record() - - mock_load.return_value = mock_typed_record - - mock_dag_instance = mock_TunnelDAG.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.check_if_resource_has_admin.return_value = True - mock_dag_instance.get_all_owners.return_value = ['resource_uid'] - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_dag_instance.user_belongs_to_resource.return_value = True - mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string - - kwargs = { - 'record_name': 'record_uid', - 'force': True # Add force to the kwargs - } + kwargs = { + 'record_name': 'non_existent_record', + 'force': True # Add force to the kwargs + } + with self.assertRaises(CommandError): self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_TunnelDAG.called) - self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) - - - class TestPAMResourceRotateCommand(unittest.TestCase): - - def setUp(self): - self.command = PAMCreateRecordRotationCommand() - self.parser = self.command.get_parser() - def test_parser(self): - args = self.parser.parse_args(['--record', "abcdefg", '--enable']) - self.assertEqual(args.record_name, 'abcdefg') - self.assertTrue(args.enable) + @patch('keepercommander.vault.KeeperRecord.load', return_value=None) + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_invalid_password_complexity(self, mock_TunnelDAG, mock_load): + mock_params, _ = create_mock_params_and_record() - @patch('keepercommander.vault_extensions.find_records') - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - def test_execute_with_enable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): - mock_dag_instance = mock_tunneldag.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.resource_belongs_to_config.return_value = True - - mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - - mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') - mock_load.return_value = mock_typed_record - - mock_pam_config_record = MagicMock(spec=vault.TypedRecord) - mock_pam_config_record.record_uid = 'config_uid' - mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type - mock_find_records.return_value = [mock_pam_config_record] - - kwargs = { - 'record_name': 'record_uid', - 'enable': True, - 'config_uid': 'config_uid' - } + kwargs = { + 'record_name': 'record_uid', + 'pwd_complexity': 'invalid_complexity', + 'force': True # Add force to the kwargs + } + with self.assertRaises(CommandError): self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_tunneldag.called) - self.assertTrue(mock_get_keeper_tokens.called) - - @patch('keepercommander.vault.KeeperRecord.load', return_value=None) - def test_execute_with_invalid_uid(self, mock_load): - mock_params, _ = create_mock_params_and_record('pamMachine') - - kwargs = { - 'record_name': 'invalid_uid', - 'enable': True - } - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) - - @patch('keepercommander.vault.KeeperRecord.load') - def test_execute_with_invalid_record_type(self, mock_load): - mock_params, mock_typed_record = create_mock_params_and_record(record_type='invalid_type') - mock_load.return_value = mock_typed_record + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_valid_password_complexity(self, mock_TunnelDAG, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record() + + mock_load.return_value = mock_typed_record + + mock_dag_instance = mock_TunnelDAG.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.check_if_resource_has_admin.return_value = True + mock_dag_instance.get_all_owners.return_value = ['resource_uid'] + mock_dag_instance.resource_belongs_to_config.return_value = True + mock_dag_instance.user_belongs_to_resource.return_value = True + mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + + kwargs = { + 'record_name': 'record_uid', + 'pwd_complexity': '32,5,5,5,5', + 'force': True # Add force to the kwargs + } - kwargs = { - 'record_name': 'record_uid', - 'enable': True - } + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) + + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + @patch('keepercommander.rest_api.SERVER_PUBLIC_KEYS', {8: ec.generate_private_key(ec.SECP256R1()).public_key()}) + def test_execute_with_valid_record(self, mock_TunnelDAG, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record() + + mock_load.return_value = mock_typed_record + + mock_dag_instance = mock_TunnelDAG.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.check_if_resource_has_admin.return_value = True + mock_dag_instance.get_all_owners.return_value = ['resource_uid'] + mock_dag_instance.resource_belongs_to_config.return_value = True + mock_dag_instance.user_belongs_to_resource.return_value = True + mock_dag_instance.record.record_uid = 'config_uid' # Ensure it returns a string + + kwargs = { + 'record_name': 'record_uid', + 'force': True # Add force to the kwargs + } - with self.assertRaises(CommandError): - self.command.execute(mock_params, **kwargs) + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_TunnelDAG.called) + self.assertEqual(mock_typed_record.record_key, b'\x00' * 16) - @patch('keepercommander.vault_extensions.find_records') - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - def test_execute_with_disable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): - mock_dag_instance = mock_tunneldag.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.resource_belongs_to_config.return_value = True - mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') +class TestPAMResourceRotateCommand(unittest.TestCase): - mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') - mock_load.return_value = mock_typed_record + def setUp(self): + self.command = PAMCreateRecordRotationCommand() + self.parser = self.command.get_parser() - mock_pam_config_record = MagicMock(spec=vault.TypedRecord) - mock_pam_config_record.record_uid = 'config_uid' - mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type - mock_find_records.return_value = [mock_pam_config_record] + def test_parser(self): + args = self.parser.parse_args(['--record', "abcdefg", '--enable']) + self.assertEqual(args.record_name, 'abcdefg') + self.assertTrue(args.enable) - kwargs = { - 'record_name': 'record_uid', - 'disable': True, - 'config_uid': 'config_uid' - } + @patch('keepercommander.vault_extensions.find_records') + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + def test_execute_with_enable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): + mock_dag_instance = mock_tunneldag.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.resource_belongs_to_config.return_value = True - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_tunneldag.called) - self.assertTrue(mock_get_keeper_tokens.called) - - @patch('keepercommander.vault_extensions.find_records') - @patch('keepercommander.vault.KeeperRecord.load') - @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') - @patch('keepercommander.commands.discoveryrotation.TunnelDAG') - def test_execute_with_enable_and_admin(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): - mock_dag_instance = mock_tunneldag.return_value - mock_dag_instance.linking_dag.has_graph = True - mock_dag_instance.resource_belongs_to_config.return_value = True - - mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - - mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') - mock_load.return_value = mock_typed_record - - mock_pam_config_record = MagicMock(spec=vault.TypedRecord) - mock_pam_config_record.record_uid = 'config_uid' - mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type - mock_find_records.return_value = [mock_pam_config_record] - - kwargs = { - 'record_name': 'record_uid', - 'enable': True, - 'config_uid': 'config_uid', - 'admin': 'admin_uid' - } + mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_load.called) - self.assertTrue(mock_tunneldag.called) - self.assertTrue(mock_get_keeper_tokens.called) - mock_dag_instance.link_user_to_resource.assert_called_with('admin_uid', 'record_uid', is_admin=True) - - - class TestPAMListRecordRotationCommand(unittest.TestCase): - - def setUp(self): - self.command = PAMListRecordRotationCommand() - self.parser = self.command.get_parser() - - def test_parser(self): - args = self.parser.parse_args(['--verbose']) - self.assertTrue(args.is_verbose) - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, - mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): - mock_params = create_mock_params() - - # Mock the return values - mock_router_get_rotation_schedules.return_value.schedules = [ - MagicMock( - recordUid=utils.base64_url_decode('record_uid'), - controllerUid=utils.base64_url_decode('controller_uid'), - configurationUid=utils.base64_url_decode('config_uid'), - noSchedule=False, - scheduleData='RotateActionJob|daily.0.12.1' - ) - ] - - mock_get_all_gateways.return_value = [ - MagicMock(controllerUid=utils.base64_url_decode('controller_uid'), controllerName='Controller Name') - ] - - mock_router_get_connected_gateways.return_value.controllers = [ - MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) - ] - - mock_pam_configurations_get_all.return_value = [ - {'record_uid': 'config_uid', 'data_unencrypted': json.dumps({'title': 'Config Title', 'type': 'pamConfig'})} - ] - - mock_pam_decrypt_configuration_data.return_value = { - 'title': 'Config Title', - 'type': 'pamConfig' - } + mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') + mock_load.return_value = mock_typed_record - kwargs = {'is_verbose': True} - self.command.execute(mock_params, **kwargs) + mock_pam_config_record = MagicMock(spec=vault.TypedRecord) + mock_pam_config_record.record_uid = 'config_uid' + mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type + mock_find_records.return_value = [mock_pam_config_record] - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_rotation_schedules.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_pam_configurations_get_all.called) + kwargs = { + 'record_name': 'record_uid', + 'enable': True, + 'config_uid': 'config_uid' + } - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute_with_no_schedules(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, - mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): - mock_params = create_mock_params() + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_tunneldag.called) + self.assertTrue(mock_get_keeper_tokens.called) - # Mock the return values - mock_router_get_rotation_schedules.return_value.schedules = [] + @patch('keepercommander.vault.KeeperRecord.load', return_value=None) + def test_execute_with_invalid_uid(self, mock_load): + mock_params, _ = create_mock_params_and_record('pamMachine') - mock_get_all_gateways.return_value = [] + kwargs = { + 'record_name': 'invalid_uid', + 'enable': True + } - mock_router_get_connected_gateways.return_value.controllers = [] + with self.assertRaises(CommandError): + self.command.execute(mock_params, **kwargs) - mock_pam_configurations_get_all.return_value = [] + @patch('keepercommander.vault.KeeperRecord.load') + def test_execute_with_invalid_record_type(self, mock_load): + mock_params, mock_typed_record = create_mock_params_and_record(record_type='invalid_type') + mock_load.return_value = mock_typed_record - mock_pam_decrypt_configuration_data.return_value = {} + kwargs = { + 'record_name': 'record_uid', + 'enable': True + } - kwargs = {'is_verbose': True} + with self.assertRaises(CommandError): self.command.execute(mock_params, **kwargs) - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_rotation_schedules.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_pam_configurations_get_all.called) - - - class TestPAMGatewayListCommand(unittest.TestCase): - - def setUp(self): - self.command = PAMGatewayListCommand() - self.parser = self.command.get_parser() - - def test_parser(self): - args = self.parser.parse_args(['--verbose', '--force']) - self.assertTrue(args.is_verbose) - self.assertTrue(args.is_force) - - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.KSMCommand.get_app_record') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute(self, mock_dump_report_data, mock_get_app_record, mock_get_all_gateways, - mock_get_router_url, mock_router_get_connected_gateways): - mock_params = create_mock_params() - - # Mock the return values - mock_router_get_connected_gateways.return_value.controllers = [ - MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) - ] - - mock_get_all_gateways.return_value = [ - MagicMock( - applicationUid=utils.base64_url_decode('app_uid'), - controllerUid=utils.base64_url_decode('controller_uid'), - controllerName='Controller Name', - deviceName='Device Name', - deviceToken='Device Token', - created=int(datetime.now().timestamp() * 1000), - lastModified=int(datetime.now().timestamp() * 1000), - nodeId='Node ID' - ) - ] - - mock_get_app_record.return_value = { - 'data_unencrypted': json.dumps({'title': 'App Title'}) - } + @patch('keepercommander.vault_extensions.find_records') + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + def test_execute_with_disable(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): + mock_dag_instance = mock_tunneldag.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.resource_belongs_to_config.return_value = True - kwargs = {'is_force': True, 'is_verbose': True} - self.command.execute(mock_params, **kwargs) + mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_get_router_url.called) - self.assertTrue(mock_get_app_record.called) - - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - @patch('keepercommander.commands.discoveryrotation.dump_report_data') - def test_execute_router_down(self, mock_dump_report_data, mock_get_all_gateways, - mock_get_router_url, mock_router_get_connected_gateways): - mock_params = create_mock_params() - - # Simulate a connection error - mock_router_get_connected_gateways.side_effect = requests.exceptions.ConnectionError - - mock_get_all_gateways.return_value = [ - MagicMock( - applicationUid=utils.base64_url_decode('app_uid'), - controllerUid=utils.base64_url_decode('controller_uid'), - controllerName='Controller Name', - deviceName='Device Name', - deviceToken='Device Token', - created=int(datetime.now().timestamp() * 1000), - lastModified=int(datetime.now().timestamp() * 1000), - nodeId='Node ID' - ) - ] - - kwargs = {'is_force': True, 'is_verbose': True} - self.command.execute(mock_params, **kwargs) + mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') + mock_load.return_value = mock_typed_record - self.assertTrue(mock_dump_report_data.called) - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_get_router_url.called) + mock_pam_config_record = MagicMock(spec=vault.TypedRecord) + mock_pam_config_record.record_uid = 'config_uid' + mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type + mock_find_records.return_value = [mock_pam_config_record] - @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') - @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') - @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') - def test_execute_no_gateways(self, mock_get_all_gateways, - mock_get_router_url, mock_router_get_connected_gateways): - mock_params = create_mock_params() + kwargs = { + 'record_name': 'record_uid', + 'disable': True, + 'config_uid': 'config_uid' + } - mock_router_get_connected_gateways.return_value.controllers = [] + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_tunneldag.called) + self.assertTrue(mock_get_keeper_tokens.called) + + @patch('keepercommander.vault_extensions.find_records') + @patch('keepercommander.vault.KeeperRecord.load') + @patch('keepercommander.commands.discoveryrotation.get_keeper_tokens') + @patch('keepercommander.commands.discoveryrotation.TunnelDAG') + def test_execute_with_enable_and_admin(self, mock_tunneldag, mock_get_keeper_tokens, mock_load, mock_find_records): + mock_dag_instance = mock_tunneldag.return_value + mock_dag_instance.linking_dag.has_graph = True + mock_dag_instance.resource_belongs_to_config.return_value = True + + mock_get_keeper_tokens.return_value = (b'token', b'encrypted_key', b'transmission_key') + + mock_params, mock_typed_record = create_mock_params_and_record('pamMachine') + mock_load.return_value = mock_typed_record + + mock_pam_config_record = MagicMock(spec=vault.TypedRecord) + mock_pam_config_record.record_uid = 'config_uid' + mock_pam_config_record.record_type = 'pamConfiguration' # Use a valid PAM configuration record type + mock_find_records.return_value = [mock_pam_config_record] + + kwargs = { + 'record_name': 'record_uid', + 'enable': True, + 'config_uid': 'config_uid', + 'admin': 'admin_uid' + } - mock_get_all_gateways.return_value = [] + self.command.execute(mock_params, **kwargs) + self.assertTrue(mock_load.called) + self.assertTrue(mock_tunneldag.called) + self.assertTrue(mock_get_keeper_tokens.called) + mock_dag_instance.link_user_to_resource.assert_called_with('admin_uid', 'record_uid', is_admin=True) + + +class TestPAMListRecordRotationCommand(unittest.TestCase): + + def setUp(self): + self.command = PAMListRecordRotationCommand() + self.parser = self.command.get_parser() + + def test_parser(self): + args = self.parser.parse_args(['--verbose']) + self.assertTrue(args.is_verbose) + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, + mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): + mock_params = create_mock_params() + + # Mock the return values + mock_router_get_rotation_schedules.return_value.schedules = [ + MagicMock( + recordUid=utils.base64_url_decode('record_uid'), + controllerUid=utils.base64_url_decode('controller_uid'), + configurationUid=utils.base64_url_decode('config_uid'), + noSchedule=False, + scheduleData='RotateActionJob|daily.0.12.1' + ) + ] + + mock_get_all_gateways.return_value = [ + MagicMock(controllerUid=utils.base64_url_decode('controller_uid'), controllerName='Controller Name') + ] + + mock_router_get_connected_gateways.return_value.controllers = [ + MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) + ] + + mock_pam_configurations_get_all.return_value = [ + {'record_uid': 'config_uid', 'data_unencrypted': json.dumps({'title': 'Config Title', 'type': 'pamConfig'})} + ] + + mock_pam_decrypt_configuration_data.return_value = { + 'title': 'Config Title', + 'type': 'pamConfig' + } - kwargs = {'is_force': True, 'is_verbose': True} - self.command.execute(mock_params, **kwargs) + kwargs = {'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_rotation_schedules.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_pam_configurations_get_all.called) + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_configurations_get_all') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.pam_decrypt_configuration_data') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute_with_no_schedules(self, mock_dump_report_data, mock_pam_decrypt_configuration_data, mock_get_all_gateways, + mock_pam_configurations_get_all, mock_router_get_connected_gateways, mock_router_get_rotation_schedules): + mock_params = create_mock_params() + + # Mock the return values + mock_router_get_rotation_schedules.return_value.schedules = [] + + mock_get_all_gateways.return_value = [] + + mock_router_get_connected_gateways.return_value.controllers = [] + + mock_pam_configurations_get_all.return_value = [] + + mock_pam_decrypt_configuration_data.return_value = {} + + kwargs = {'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_rotation_schedules.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_pam_configurations_get_all.called) + + +class TestPAMGatewayListCommand(unittest.TestCase): + + def setUp(self): + self.command = PAMGatewayListCommand() + self.parser = self.command.get_parser() + + def test_parser(self): + args = self.parser.parse_args(['--verbose', '--force']) + self.assertTrue(args.is_verbose) + self.assertTrue(args.is_force) + + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.KSMCommand.get_app_record') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute(self, mock_dump_report_data, mock_get_app_record, mock_get_all_gateways, + mock_get_router_url, mock_router_get_connected_gateways): + mock_params = create_mock_params() + + # Mock the return values + mock_router_get_connected_gateways.return_value.controllers = [ + MagicMock(controllerUid=utils.base64_url_decode('controller_uid')) + ] + + mock_get_all_gateways.return_value = [ + MagicMock( + applicationUid=utils.base64_url_decode('app_uid'), + controllerUid=utils.base64_url_decode('controller_uid'), + controllerName='Controller Name', + deviceName='Device Name', + deviceToken='Device Token', + created=int(datetime.now().timestamp() * 1000), + lastModified=int(datetime.now().timestamp() * 1000), + nodeId='Node ID' + ) + ] + + mock_get_app_record.return_value = { + 'data_unencrypted': json.dumps({'title': 'App Title'}) + } - self.assertTrue(mock_router_get_connected_gateways.called) - self.assertTrue(mock_get_all_gateways.called) - self.assertTrue(mock_get_router_url.called) - - class TestPAMRouterGetRotationInfo(unittest.TestCase): - - def _make_rri(self, status_name='RRS_ONLINE'): - """Build a minimal RouterRotationInfo mock.""" - from keepercommander.proto import router_pb2 - rri = MagicMock() - rri.status = router_pb2.RouterRotationStatus.Value(status_name) - rri.configurationUid = utils.base64_url_decode('config_uid_____') - rri.nodeId = 42 - rri.controllerName = 'gw-test' - rri.controllerUid = utils.base64_url_decode('gw_uid_________') - rri.resourceUid = b'' - rri.pwdComplexity = '' - rri.disabled = False - rri.scriptName = '' - return rri - - def _make_schedule(self, record_uid_bytes, no_schedule=False, schedule_data='daily.0.12.1'): - s = MagicMock() - s.recordUid = record_uid_bytes - s.noSchedule = no_schedule - s.scheduleData = schedule_data - return s - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.record_rotation_get') - def test_json_online_status(self, mock_rrg, mock_schedules): - """Online status + --format json returns valid JSON with expected keys.""" - from keeper_secrets_manager_core.utils import url_safe_str_to_bytes - record_uid = 'test_record_uid_' - record_uid_bytes = url_safe_str_to_bytes(record_uid) - - mock_rrg.return_value = self._make_rri('RRS_ONLINE') - - sched_mock = MagicMock() - sched_mock.schedules = [self._make_schedule(record_uid_bytes, no_schedule=False, - schedule_data='daily.0.12.1')] - mock_schedules.return_value = sched_mock - - mock_params = create_mock_params() - mock_params.record_cache = {} - - cmd = PAMRouterGetRotationInfo() - result = cmd.execute(mock_params, record_uid=record_uid, format='json') - - self.assertIsNotNone(result, "Expected JSON string, got None") - data = json.loads(result) - self.assertIn('status', data) - self.assertTrue(data['ready_to_rotate']) - self.assertIn('pam_config_uid', data) - self.assertIn('gateway_name', data) - self.assertEqual(data['gateway_name'], 'gw-test') - self.assertIn('schedule_type', data) - self.assertEqual(data['schedule_type'], 'scheduled') - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.record_rotation_get') - def test_json_non_online_status(self, mock_rrg, mock_schedules): - """Non-online status + --format json returns minimal JSON with ready_to_rotate=false.""" - record_uid = 'test_record_uid_' - - mock_rrg.return_value = self._make_rri('RRS_NO_ROTATION') - - mock_params = create_mock_params() - - cmd = PAMRouterGetRotationInfo() - result = cmd.execute(mock_params, record_uid=record_uid, format='json') - - self.assertIsNotNone(result, "Expected JSON string, got None") - data = json.loads(result) - self.assertIn('status', data) - self.assertFalse(data['ready_to_rotate']) - - @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') - @patch('keepercommander.commands.discoveryrotation.record_rotation_get') - def test_table_mode_returns_none(self, mock_rrg, mock_schedules): - """Table mode (default) prints to stdout and returns None.""" - from keeper_secrets_manager_core.utils import url_safe_str_to_bytes - record_uid = 'test_record_uid_' - record_uid_bytes = url_safe_str_to_bytes(record_uid) - - mock_rrg.return_value = self._make_rri('RRS_ONLINE') - - sched_mock = MagicMock() - sched_mock.schedules = [self._make_schedule(record_uid_bytes)] - mock_schedules.return_value = sched_mock - - mock_params = create_mock_params() - mock_params.record_cache = {} - - cmd = PAMRouterGetRotationInfo() - result = cmd.execute(mock_params, record_uid=record_uid, format='table') - self.assertIsNone(result) + kwargs = {'is_force': True, 'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_get_router_url.called) + self.assertTrue(mock_get_app_record.called) + + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + @patch('keepercommander.commands.discoveryrotation.dump_report_data') + def test_execute_router_down(self, mock_dump_report_data, mock_get_all_gateways, + mock_get_router_url, mock_router_get_connected_gateways): + mock_params = create_mock_params() + + # Simulate a connection error + mock_router_get_connected_gateways.side_effect = requests.exceptions.ConnectionError + + mock_get_all_gateways.return_value = [ + MagicMock( + applicationUid=utils.base64_url_decode('app_uid'), + controllerUid=utils.base64_url_decode('controller_uid'), + controllerName='Controller Name', + deviceName='Device Name', + deviceToken='Device Token', + created=int(datetime.now().timestamp() * 1000), + lastModified=int(datetime.now().timestamp() * 1000), + nodeId='Node ID' + ) + ] + + kwargs = {'is_force': True, 'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_dump_report_data.called) + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_get_router_url.called) + + @patch('keepercommander.commands.discoveryrotation.router_get_connected_gateways') + @patch('keepercommander.commands.discoveryrotation.router_helper.get_router_url') + @patch('keepercommander.commands.discoveryrotation.gateway_helper.get_all_gateways') + def test_execute_no_gateways(self, mock_get_all_gateways, + mock_get_router_url, mock_router_get_connected_gateways): + mock_params = create_mock_params() + + mock_router_get_connected_gateways.return_value.controllers = [] + + mock_get_all_gateways.return_value = [] + + kwargs = {'is_force': True, 'is_verbose': True} + self.command.execute(mock_params, **kwargs) + + self.assertTrue(mock_router_get_connected_gateways.called) + self.assertTrue(mock_get_all_gateways.called) + self.assertTrue(mock_get_router_url.called) + +class TestPAMRouterGetRotationInfo(unittest.TestCase): + + def _make_rri(self, status_name='RRS_ONLINE'): + """Build a minimal RouterRotationInfo mock.""" + from keepercommander.proto import router_pb2 + rri = MagicMock() + rri.status = router_pb2.RouterRotationStatus.Value(status_name) + rri.configurationUid = utils.base64_url_decode('config_uid_____') + rri.nodeId = 42 + rri.controllerName = 'gw-test' + rri.controllerUid = utils.base64_url_decode('gw_uid_________') + rri.resourceUid = b'' + rri.pwdComplexity = '' + rri.disabled = False + rri.scriptName = '' + return rri + + def _make_schedule(self, record_uid_bytes, no_schedule=False, schedule_data='daily.0.12.1'): + s = MagicMock() + s.recordUid = record_uid_bytes + s.noSchedule = no_schedule + s.scheduleData = schedule_data + return s + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.record_rotation_get') + def test_json_online_status(self, mock_rrg, mock_schedules): + """Online status + --format json returns valid JSON with expected keys.""" + from keeper_secrets_manager_core.utils import url_safe_str_to_bytes + record_uid = 'test_record_uid_' + record_uid_bytes = url_safe_str_to_bytes(record_uid) + + mock_rrg.return_value = self._make_rri('RRS_ONLINE') + + sched_mock = MagicMock() + sched_mock.schedules = [self._make_schedule(record_uid_bytes, no_schedule=False, + schedule_data='daily.0.12.1')] + mock_schedules.return_value = sched_mock + + mock_params = create_mock_params() + mock_params.record_cache = {} + + cmd = PAMRouterGetRotationInfo() + result = cmd.execute(mock_params, record_uid=record_uid, format='json') + + self.assertIsNotNone(result, "Expected JSON string, got None") + data = json.loads(result) + self.assertIn('status', data) + self.assertTrue(data['ready_to_rotate']) + self.assertIn('pam_config_uid', data) + self.assertIn('gateway_name', data) + self.assertEqual(data['gateway_name'], 'gw-test') + self.assertIn('schedule_type', data) + self.assertEqual(data['schedule_type'], 'scheduled') + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.record_rotation_get') + def test_json_non_online_status(self, mock_rrg, mock_schedules): + """Non-online status + --format json returns minimal JSON with ready_to_rotate=false.""" + record_uid = 'test_record_uid_' + + mock_rrg.return_value = self._make_rri('RRS_NO_ROTATION') + + mock_params = create_mock_params() + + cmd = PAMRouterGetRotationInfo() + result = cmd.execute(mock_params, record_uid=record_uid, format='json') + + self.assertIsNotNone(result, "Expected JSON string, got None") + data = json.loads(result) + self.assertIn('status', data) + self.assertFalse(data['ready_to_rotate']) + + @patch('keepercommander.commands.discoveryrotation.router_get_rotation_schedules') + @patch('keepercommander.commands.discoveryrotation.record_rotation_get') + def test_table_mode_returns_none(self, mock_rrg, mock_schedules): + """Table mode (default) prints to stdout and returns None.""" + from keeper_secrets_manager_core.utils import url_safe_str_to_bytes + record_uid = 'test_record_uid_' + record_uid_bytes = url_safe_str_to_bytes(record_uid) + + mock_rrg.return_value = self._make_rri('RRS_ONLINE') + + sched_mock = MagicMock() + sched_mock.schedules = [self._make_schedule(record_uid_bytes)] + mock_schedules.return_value = sched_mock + + mock_params = create_mock_params() + mock_params.record_cache = {} + + cmd = PAMRouterGetRotationInfo() + result = cmd.execute(mock_params, record_uid=record_uid, format='table') + self.assertIsNone(result) diff --git a/unit-tests/pam/test_pam_tunnel.py b/unit-tests/pam/test_pam_tunnel.py index b8dc67b7a..41c55e7f8 100644 --- a/unit-tests/pam/test_pam_tunnel.py +++ b/unit-tests/pam/test_pam_tunnel.py @@ -1,158 +1,156 @@ -import sys import unittest from unittest import mock from keepercommander.error import CommandError -if sys.version_info >= (3, 8): - import datetime - import socket - import string - from cryptography import x509 - from cryptography.hazmat._oid import NameOID - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import serialization, hashes - from cryptography.hazmat.primitives.asymmetric import ec - - from keepercommander.commands.tunnel.port_forward.tunnel_helpers import (generate_random_bytes, find_open_port) - - def generate_self_signed_cert(private_key): - # Generate a self-signed certificate - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, u"localhost"), - ]) - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(private_key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.datetime.utcnow()) - .not_valid_after( - # Our certificate will be valid for 10 days - datetime.datetime.utcnow() + datetime.timedelta(days=10) - ) - .sign(private_key, hashes.SHA256(), default_backend()) +import datetime +import socket +import string +from cryptography import x509 +from cryptography.hazmat._oid import NameOID +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import ec + +from keepercommander.commands.tunnel.port_forward.tunnel_helpers import (generate_random_bytes, find_open_port) + +def generate_self_signed_cert(private_key): + # Generate a self-signed certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, u"localhost"), + ]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after( + # Our certificate will be valid for 10 days + datetime.datetime.utcnow() + datetime.timedelta(days=10) ) - cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode('utf-8') - - return cert_pem + .sign(private_key, hashes.SHA256(), default_backend()) + ) + cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode('utf-8') + + return cert_pem + + +def new_private_key(): + # Generate an EC private key + private_key = ec.generate_private_key( + ec.SECP256R1(), # Using P-256 curve + backend=default_backend() + ) + # Serialize to PEM format + private_key_str = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ).decode('utf-8') + return private_key, private_key_str + + +class TestFindOpenPort(unittest.TestCase): + def mock_bind(self, address): + # Mock the behavior of socket.socket.bind + port = address[1] + if port in self.in_use_ports: + raise OSError("Address already in use") + else: + print(f"Port {port} bound successfully.") + + def test_preferred_port(self): + # Test that the function returns the preferred port if it's available + preferred_port = 50000 + open_port = find_open_port([], preferred_port=preferred_port) + self.assertEqual(open_port, preferred_port) + + def test_preferred_port_unavailable(self): + # Mock the bind method to simulate that port 80 is in use + with mock.patch('socket.socket.bind', side_effect=OSError("Address already in use")): + preferred_port = 80 + with self.assertRaises(CommandError): + open_port = find_open_port([], preferred_port=preferred_port) + + def test_range(self): + # Test that the function returns a port within the specified range + start_port = 50000 + end_port = 50010 + open_port = find_open_port([], start_port=start_port, end_port=end_port) + self.assertTrue(start_port <= open_port <= end_port) + + def test_no_available_ports(self): + # Setup + self.in_use_ports = set(range(50000, 50011)) # All these ports are in use + + # Patch + with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): + # Test + open_port = find_open_port([], start_port=50000, end_port=50010) + self.assertIsNone(open_port) + def test_invalid_range(self): + # Test that the function returns None if the range is invalid + open_port = find_open_port([], start_port=50010, end_port=50000) + self.assertIsNone(open_port) - def new_private_key(): - # Generate an EC private key - private_key = ec.generate_private_key( - ec.SECP256R1(), # Using P-256 curve - backend=default_backend() - ) - # Serialize to PEM format - private_key_str = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - ).decode('utf-8') - return private_key, private_key_str - - - class TestFindOpenPort(unittest.TestCase): - def mock_bind(self, address): - # Mock the behavior of socket.socket.bind - port = address[1] - if port in self.in_use_ports: - raise OSError("Address already in use") - else: - print(f"Port {port} bound successfully.") - - def test_preferred_port(self): - # Test that the function returns the preferred port if it's available - preferred_port = 50000 - open_port = find_open_port([], preferred_port=preferred_port) - self.assertEqual(open_port, preferred_port) - - def test_preferred_port_unavailable(self): - # Mock the bind method to simulate that port 80 is in use - with mock.patch('socket.socket.bind', side_effect=OSError("Address already in use")): - preferred_port = 80 - with self.assertRaises(CommandError): - open_port = find_open_port([], preferred_port=preferred_port) - - def test_range(self): - # Test that the function returns a port within the specified range - start_port = 50000 - end_port = 50010 - open_port = find_open_port([], start_port=start_port, end_port=end_port) - self.assertTrue(start_port <= open_port <= end_port) - - def test_no_available_ports(self): - # Setup - self.in_use_ports = set(range(50000, 50011)) # All these ports are in use - - # Patch - with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): - # Test - open_port = find_open_port([], start_port=50000, end_port=50010) - self.assertIsNone(open_port) - - def test_invalid_range(self): - # Test that the function returns None if the range is invalid - open_port = find_open_port([], start_port=50010, end_port=50000) + def test_socket_exception(self): + # Test that the function handles exceptions other than OSError gracefully + with mock.patch('socket.socket.bind', side_effect=Exception("Test exception")): + open_port = find_open_port([], start_port=49152, end_port=49153, host='localhost') self.assertIsNone(open_port) - def test_socket_exception(self): - # Test that the function handles exceptions other than OSError gracefully - with mock.patch('socket.socket.bind', side_effect=Exception("Test exception")): - open_port = find_open_port([], start_port=49152, end_port=49153, host='localhost') - self.assertIsNone(open_port) - - def test_tried_ports(self): - # Setup - self.in_use_ports = {50000, 50001} # These ports are in use - - # Patch - with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): - # Test - open_port = find_open_port([50000, 50001], start_port=50000, end_port=50002) - self.assertEqual(open_port, 50002) - - - class TestGenerateRandomBytes(unittest.TestCase): - - def test_default_length(self): - # Test that the default length of the returned bytes is 32 - random_bytes = generate_random_bytes() - self.assertEqual(len(random_bytes), 32, f'Length 32 failed found {len(random_bytes)} in ' - f'{random_bytes}') - - def test_custom_length(self): - # Test custom lengths - for length in [1, 10, 20, 50, 100]: - random_bytes = generate_random_bytes(length) - self.assertEqual(len(random_bytes), length, f'Length {length} failed found {len(random_bytes)} in ' - f'{random_bytes}') - - def test_content(self): - # Test that the returned bytes only contain printable characters - for length in [1, 10, 20, 50, 100]: - random_bytes = generate_random_bytes(length) - self.assertTrue(all(byte in string.printable.encode('utf-8') for byte in random_bytes)) - - def test_zero_length(self): - # Test that a zero length returns an empty bytes object - random_bytes = generate_random_bytes(0) - self.assertEqual(random_bytes, b'') - - def test_negative_length(self): - # Test that a negative length raises a ValueError - with self.assertRaises(ValueError): - generate_random_bytes(-1) - - def test_type(self): - # Test that the return type is bytes - random_bytes = generate_random_bytes() - self.assertIsInstance(random_bytes, bytes) - - def test_uniqueness(self): - # Test that multiple calls return different values - random_bytes1 = generate_random_bytes() - random_bytes2 = generate_random_bytes() - self.assertNotEqual(random_bytes1, random_bytes2) + def test_tried_ports(self): + # Setup + self.in_use_ports = {50000, 50001} # These ports are in use + + # Patch + with mock.patch.object(socket.socket, 'bind', side_effect=self.mock_bind): + # Test + open_port = find_open_port([50000, 50001], start_port=50000, end_port=50002) + self.assertEqual(open_port, 50002) + + +class TestGenerateRandomBytes(unittest.TestCase): + + def test_default_length(self): + # Test that the default length of the returned bytes is 32 + random_bytes = generate_random_bytes() + self.assertEqual(len(random_bytes), 32, f'Length 32 failed found {len(random_bytes)} in ' + f'{random_bytes}') + + def test_custom_length(self): + # Test custom lengths + for length in [1, 10, 20, 50, 100]: + random_bytes = generate_random_bytes(length) + self.assertEqual(len(random_bytes), length, f'Length {length} failed found {len(random_bytes)} in ' + f'{random_bytes}') + + def test_content(self): + # Test that the returned bytes only contain printable characters + for length in [1, 10, 20, 50, 100]: + random_bytes = generate_random_bytes(length) + self.assertTrue(all(byte in string.printable.encode('utf-8') for byte in random_bytes)) + + def test_zero_length(self): + # Test that a zero length returns an empty bytes object + random_bytes = generate_random_bytes(0) + self.assertEqual(random_bytes, b'') + + def test_negative_length(self): + # Test that a negative length raises a ValueError + with self.assertRaises(ValueError): + generate_random_bytes(-1) + + def test_type(self): + # Test that the return type is bytes + random_bytes = generate_random_bytes() + self.assertIsInstance(random_bytes, bytes) + + def test_uniqueness(self): + # Test that multiple calls return different values + random_bytes1 = generate_random_bytes() + random_bytes2 = generate_random_bytes() + self.assertNotEqual(random_bytes1, random_bytes2) diff --git a/unit-tests/service/test_api_logging.py b/unit-tests/service/test_api_logging.py index 8ee60514c..c8ae53a86 100644 --- a/unit-tests/service/test_api_logging.py +++ b/unit-tests/service/test_api_logging.py @@ -1,90 +1,87 @@ -import sys -if sys.version_info >= (3, 8): - import pytest - from unittest import TestCase, mock - from flask import Flask, request - from keepercommander.service.decorators.api_logging import api_log_handler +from unittest import TestCase, mock +from flask import Flask, request +from keepercommander.service.decorators.api_logging import api_log_handler - class TestApiLogging(TestCase): - def setUp(self): - self.app = Flask(__name__) - self.client = self.app.test_client() +class TestApiLogging(TestCase): + def setUp(self): + self.app = Flask(__name__) + self.client = self.app.test_client() - @self.app.route('/test', methods=['POST']) - @api_log_handler - def test_endpoint(): - if not request.is_json: - return {'error': 'Content-Type must be application/json'}, 415 - return {'status': 'success'}, 200 + @self.app.route('/test', methods=['POST']) + @api_log_handler + def test_endpoint(): + if not request.is_json: + return {'error': 'Content-Type must be application/json'}, 415 + return {'status': 'success'}, 200 - @self.app.route('/error', methods=['POST']) - @api_log_handler - def error_endpoint(): - if not request.is_json: - return {'error': 'Content-Type must be application/json'}, 415 - raise Exception("Test error") + @self.app.route('/error', methods=['POST']) + @api_log_handler + def error_endpoint(): + if not request.is_json: + return {'error': 'Content-Type must be application/json'}, 415 + raise Exception("Test error") - def test_api_log_success_request(self): - """Test logging of successful API request""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: - test_data = {"test": "data"} - response = self.client.post('/test', - json=test_data, - headers={ - 'X-Forwarded-For': '127.0.0.1', - 'Content-Type': 'application/json' - }) + def test_api_log_success_request(self): + """Test logging of successful API request""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: + test_data = {"test": "data"} + response = self.client.post('/test', + json=test_data, + headers={ + 'X-Forwarded-For': '127.0.0.1', + 'Content-Type': 'application/json' + }) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, {'status': 'success'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, {'status': 'success'}) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertIn('POST', log_message) - self.assertIn('/test', log_message) - self.assertIn('127.0.0.1', log_message) - self.assertIn('200', log_message) - self.assertIn(f"data={str(test_data)}", log_message) + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertIn('POST', log_message) + self.assertIn('/test', log_message) + self.assertIn('127.0.0.1', log_message) + self.assertIn('200', log_message) + self.assertIn(f"data={str(test_data)}", log_message) - def test_api_log_error_request(self): - """Test logging of failed API request""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.error') as mock_log: - response = self.client.post('/error', json={}, - headers={'X-Forwarded-For': '127.0.0.1', - 'Content-Type': 'application/json'}) + def test_api_log_error_request(self): + """Test logging of failed API request""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.error') as mock_log: + response = self.client.post('/error', json={}, + headers={'X-Forwarded-For': '127.0.0.1', + 'Content-Type': 'application/json'}) - self.assertEqual(response.status_code, 500) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertIn('POST', log_message) - self.assertIn('/error', log_message) - self.assertIn('127.0.0.1', log_message) - self.assertIn("error='Test error'", log_message) + self.assertEqual(response.status_code, 500) + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertIn('POST', log_message) + self.assertIn('/error', log_message) + self.assertIn('127.0.0.1', log_message) + self.assertIn("error='Test error'", log_message) - def test_api_log_remote_addr_fallback(self): - """Test logging falls back to remote_addr when X-Forwarded-For is missing""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: - response = self.client.post('/test', json={}, - headers={'Content-Type': 'application/json'}, - environ_base={'REMOTE_ADDR': '192.168.1.1'}) + def test_api_log_remote_addr_fallback(self): + """Test logging falls back to remote_addr when X-Forwarded-For is missing""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: + response = self.client.post('/test', json={}, + headers={'Content-Type': 'application/json'}, + environ_base={'REMOTE_ADDR': '192.168.1.1'}) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, {'status': 'success'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, {'status': 'success'}) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertIn('192.168.1.1', log_message) + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertIn('192.168.1.1', log_message) - def test_api_log_timing(self): - """Test request timing is logged""" - with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: - response = self.client.post('/test', json={}, - headers={'Content-Type': 'application/json'}) + def test_api_log_timing(self): + """Test request timing is logged""" + with mock.patch('keepercommander.service.decorators.api_logging.logger.info') as mock_log: + response = self.client.post('/test', json={}, + headers={'Content-Type': 'application/json'}) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, {'status': 'success'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json, {'status': 'success'}) - mock_log.assert_called_once() - log_message = mock_log.call_args[0][0] - self.assertRegex(log_message, r'\d+\.\d+s') + mock_log.assert_called_once() + log_message = mock_log.call_args[0][0] + self.assertRegex(log_message, r'\d+\.\d+s') diff --git a/unit-tests/service/test_api_routes.py b/unit-tests/service/test_api_routes.py index 174e1e717..ec85eb4a3 100644 --- a/unit-tests/service/test_api_routes.py +++ b/unit-tests/service/test_api_routes.py @@ -1,77 +1,74 @@ -import sys - -if sys.version_info >= (3, 8): - import unittest - from unittest import mock - from flask import Blueprint, Flask - - from keepercommander.service.api.command import create_command_blueprint, create_legacy_command_blueprint - from keepercommander.service.api.routes import init_routes - - - def passthrough_decorator(): - def decorator(fn): - return fn - return decorator - - - class TestServiceApiRoutes(unittest.TestCase): - def test_queue_mode_registers_v1_and_v2_routes(self): - app = Flask(__name__) - onboarding_bp = Blueprint("test_onboarding", __name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.api.routes.create_onboarding_blueprint', return_value=onboarding_bp), \ - mock.patch('keepercommander.service.core.request_queue.queue_manager.start') as mock_start, \ - mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "y"}): - init_routes(app) - - routes = {rule.rule for rule in app.url_map.iter_rules()} - self.assertIn('/api/v1/executecommand', routes) - self.assertIn('/api/v2/executecommand-async', routes) - self.assertIn('/api/v2/status/', routes) - self.assertIn('/api/v2/result/', routes) - self.assertIn('/api/v2/queue/status', routes) - self.assertIn('/health', routes) - mock_start.assert_called_once() - - def test_legacy_mode_registers_only_v1_route(self): - app = Flask(__name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "n"}): - init_routes(app) - - routes = {rule.rule for rule in app.url_map.iter_rules()} - self.assertIn('/api/v1/executecommand', routes) - self.assertNotIn('/api/v2/executecommand-async', routes) - - def test_v1_compatibility_route_waits_for_queue_result(self): - app = Flask(__name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.api.command.queue_manager.submit_request', return_value='req-1') as mock_submit, \ - mock.patch('keepercommander.service.api.command.queue_manager.wait_for_result', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_wait: - app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') - response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) - - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers.get('X-API-Legacy'), 'true') - self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) - mock_submit.assert_called_once_with('ls', []) - mock_wait.assert_called_once_with('req-1') - - def test_v1_direct_route_keeps_legacy_execution_path(self): - app = Flask(__name__) - - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ - mock.patch('keepercommander.service.api.command.CommandExecutor.execute', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_execute, \ - mock.patch('keepercommander.service.api.command.queue_manager.submit_request') as mock_submit: - app.register_blueprint(create_legacy_command_blueprint(use_queue=False), url_prefix='/api/v1') - response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) - - self.assertEqual(response.status_code, 200) - self.assertEqual(response.headers.get('X-API-Legacy'), 'true') - self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) - mock_execute.assert_called_once_with('ls') - mock_submit.assert_not_called() +import unittest +from unittest import mock +from flask import Blueprint, Flask + +from keepercommander.service.api.command import create_legacy_command_blueprint +from keepercommander.service.api.routes import init_routes + + +def passthrough_decorator(): + def decorator(fn): + return fn + return decorator + + +class TestServiceApiRoutes(unittest.TestCase): + def test_queue_mode_registers_v1_and_v2_routes(self): + app = Flask(__name__) + onboarding_bp = Blueprint("test_onboarding", __name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.api.routes.create_onboarding_blueprint', return_value=onboarding_bp), \ + mock.patch('keepercommander.service.core.request_queue.queue_manager.start') as mock_start, \ + mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "y"}): + init_routes(app) + + routes = {rule.rule for rule in app.url_map.iter_rules()} + self.assertIn('/api/v1/executecommand', routes) + self.assertIn('/api/v2/executecommand-async', routes) + self.assertIn('/api/v2/status/', routes) + self.assertIn('/api/v2/result/', routes) + self.assertIn('/api/v2/queue/status', routes) + self.assertIn('/health', routes) + mock_start.assert_called_once() + + def test_legacy_mode_registers_only_v1_route(self): + app = Flask(__name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.config.service_config.ServiceConfig.load_config', return_value={"queue_enabled": "n"}): + init_routes(app) + + routes = {rule.rule for rule in app.url_map.iter_rules()} + self.assertIn('/api/v1/executecommand', routes) + self.assertNotIn('/api/v2/executecommand-async', routes) + + def test_v1_compatibility_route_waits_for_queue_result(self): + app = Flask(__name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.api.command.queue_manager.submit_request', return_value='req-1') as mock_submit, \ + mock.patch('keepercommander.service.api.command.queue_manager.wait_for_result', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_wait: + app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') + response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers.get('X-API-Legacy'), 'true') + self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) + mock_submit.assert_called_once_with('ls', []) + mock_wait.assert_called_once_with('req-1') + + def test_v1_direct_route_keeps_legacy_execution_path(self): + app = Flask(__name__) + + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator), \ + mock.patch('keepercommander.service.api.command.CommandExecutor.execute', return_value=({"status": "success", "data": {"command": "ls"}}, 200)) as mock_execute, \ + mock.patch('keepercommander.service.api.command.queue_manager.submit_request') as mock_submit: + app.register_blueprint(create_legacy_command_blueprint(use_queue=False), url_prefix='/api/v1') + response = app.test_client().post('/api/v1/executecommand', json={"command": "ls"}) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.headers.get('X-API-Legacy'), 'true') + self.assertEqual(response.get_json(), {"status": "success", "data": {"command": "ls"}}) + mock_execute.assert_called_once_with('ls') + mock_submit.assert_not_called() diff --git a/unit-tests/service/test_auth_security.py b/unit-tests/service/test_auth_security.py index 3f93b6610..bbd94ba4b 100644 --- a/unit-tests/service/test_auth_security.py +++ b/unit-tests/service/test_auth_security.py @@ -1,102 +1,99 @@ -import sys -if sys.version_info >= (3, 8): - import pytest - from unittest import TestCase, mock - from flask import Flask - from keepercommander.service.decorators.auth import auth_check, policy_check - from keepercommander.service.decorators.security import security_check, is_allowed_ip - from keepercommander.service.util.config_reader import ConfigReader +from unittest import TestCase, mock +from flask import Flask +from keepercommander.service.decorators.auth import auth_check, policy_check +from keepercommander.service.decorators.security import security_check, is_allowed_ip +from keepercommander.service.util.config_reader import ConfigReader - class TestAuthSecurity(TestCase): - def setUp(self): - self.app = Flask(__name__) - self.client = self.app.test_client() +class TestAuthSecurity(TestCase): + def setUp(self): + self.app = Flask(__name__) + self.client = self.app.test_client() - @self.app.route('/test', methods=['POST']) - @security_check - @auth_check - @policy_check - def test_endpoint(): - return {'status': 'success'}, 200 + @self.app.route('/test', methods=['POST']) + @security_check + @auth_check + @policy_check + def test_endpoint(): + return {'status': 'success'}, 200 - def test_auth_check_missing_api_key(self): - """Test authentication with missing API key""" - with self.app.test_request_context('/test', method='POST'): - response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 401) - self.assertEqual(response[0]['status'], 'error') - self.assertIn('api key', response[0]['error']) + def test_auth_check_missing_api_key(self): + """Test authentication with missing API key""" + with self.app.test_request_context('/test', method='POST'): + response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 401) + self.assertEqual(response[0]['status'], 'error') + self.assertIn('api key', response[0]['error']) - @mock.patch.object(ConfigReader, 'read_config') - def test_auth_check_invalid_api_key(self, mock_read_config): - """Test authentication with invalid API key""" - mock_read_config.return_value = "different_key" + @mock.patch.object(ConfigReader, 'read_config') + def test_auth_check_invalid_api_key(self, mock_read_config): + """Test authentication with invalid API key""" + mock_read_config.return_value = "different_key" - with self.app.test_request_context('/test', method='POST', - headers={'api-key': 'test_key'}): - response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 401) - self.assertEqual(response[0]['status'], 'error') + with self.app.test_request_context('/test', method='POST', + headers={'api-key': 'test_key'}): + response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 401) + self.assertEqual(response[0]['status'], 'error') - @mock.patch.object(ConfigReader, 'read_config') - def test_auth_check_expired_key(self, mock_read_config): - """Test authentication with expired API key""" - mock_read_config.side_effect = [ - "test_key", - "2024-01-01T00:00:00" - ] + @mock.patch.object(ConfigReader, 'read_config') + def test_auth_check_expired_key(self, mock_read_config): + """Test authentication with expired API key""" + mock_read_config.side_effect = [ + "test_key", + "2024-01-01T00:00:00" + ] - with self.app.test_request_context('/test', method='POST', - headers={'api-key': 'test_key'}): - response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 401) - self.assertEqual(response[0]['status'], 'error') - self.assertIn('expired', response[0]['error']) + with self.app.test_request_context('/test', method='POST', + headers={'api-key': 'test_key'}): + response = auth_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 401) + self.assertEqual(response[0]['status'], 'error') + self.assertIn('expired', response[0]['error']) - # def test_security_check_blocked_ip(self): - # """Test security check with blocked IP""" - # with mock.patch.object(ConfigReader, 'read_config', return_value="192.168.1.1"): - # with self.app.test_request_context('/test', method='POST', - # environ_base={'REMOTE_ADDR': '192.168.1.1'}): - # response = security_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - # response_data = response[0].get_json() - # self.assertEqual(response[1], 403) - # self.assertEqual(response_data['error'], 'IP is blocked') + # def test_security_check_blocked_ip(self): + # """Test security check with blocked IP""" + # with mock.patch.object(ConfigReader, 'read_config', return_value="192.168.1.1"): + # with self.app.test_request_context('/test', method='POST', + # environ_base={'REMOTE_ADDR': '192.168.1.1'}): + # response = security_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + # response_data = response[0].get_json() + # self.assertEqual(response[1], 403) + # self.assertEqual(response_data['error'], 'IP is blocked') - def test_is_blocked_ip_single_ip(self): - """Test IP blocking with single IP address""" - blocked_ips = "192.168.1.1" - allowed_ips="192.168.1.2" - self.assertFalse(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) - self.assertTrue(is_allowed_ip("192.168.1.2", allowed_ips, blocked_ips)) + def test_is_blocked_ip_single_ip(self): + """Test IP blocking with single IP address""" + blocked_ips = "192.168.1.1" + allowed_ips="192.168.1.2" + self.assertFalse(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) + self.assertTrue(is_allowed_ip("192.168.1.2", allowed_ips, blocked_ips)) - def test_is_blocked_ip_cidr(self): - """Test IP blocking with CIDR notation""" - allowed_ips="192.168.1.1" - blocked_ips = "192.168.1.0" - self.assertTrue(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) - self.assertFalse(is_allowed_ip("192.168.1.254", allowed_ips, blocked_ips)) - self.assertFalse(is_allowed_ip("192.168.2.1", allowed_ips, blocked_ips)) + def test_is_blocked_ip_cidr(self): + """Test IP blocking with CIDR notation""" + allowed_ips="192.168.1.1" + blocked_ips = "192.168.1.0" + self.assertTrue(is_allowed_ip("192.168.1.1", allowed_ips, blocked_ips)) + self.assertFalse(is_allowed_ip("192.168.1.254", allowed_ips, blocked_ips)) + self.assertFalse(is_allowed_ip("192.168.2.1", allowed_ips, blocked_ips)) - @mock.patch.object(ConfigReader, 'read_config') - def test_policy_check_allowed_command(self, mock_read_config): - """Test policy check with allowed command""" - mock_read_config.return_value = "list,get,search" + @mock.patch.object(ConfigReader, 'read_config') + def test_policy_check_allowed_command(self, mock_read_config): + """Test policy check with allowed command""" + mock_read_config.return_value = "list,get,search" - with self.app.test_request_context('/test', method='POST', - json={"command": "list"}): - response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 200) - self.assertEqual(response[0]['status'], 'success') + with self.app.test_request_context('/test', method='POST', + json={"command": "list"}): + response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 200) + self.assertEqual(response[0]['status'], 'success') - @mock.patch.object(ConfigReader, 'read_config') - def test_policy_check_denied_command(self, mock_read_config): - """Test policy check with denied command""" - mock_read_config.return_value = "list,get,search" + @mock.patch.object(ConfigReader, 'read_config') + def test_policy_check_denied_command(self, mock_read_config): + """Test policy check with denied command""" + mock_read_config.return_value = "list,get,search" - with self.app.test_request_context('/test', method='POST', - json={"command": "delete"}): - response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() - self.assertEqual(response[1], 403) - self.assertEqual(response[0]['status'], 'error') - self.assertIn('Not permitted', response[0]['error']) \ No newline at end of file + with self.app.test_request_context('/test', method='POST', + json={"command": "delete"}): + response = policy_check(lambda *args, **kwargs: ({'status': 'success'}, 200))() + self.assertEqual(response[1], 403) + self.assertEqual(response[0]['status'], 'error') + self.assertIn('Not permitted', response[0]['error']) \ No newline at end of file diff --git a/unit-tests/service/test_command.py b/unit-tests/service/test_command.py index fbdcb9845..2df0637b1 100644 --- a/unit-tests/service/test_command.py +++ b/unit-tests/service/test_command.py @@ -1,121 +1,118 @@ -import sys import unittest -if sys.version_info >= (3, 8): - import pytest - from unittest import TestCase, mock - from flask import Flask - from keepercommander.service.util.command_util import CommandExecutor - from keepercommander.service.util.exceptions import CommandExecutionError - from keepercommander.service.util.parse_keeper_response import parse_keeper_response - - class TestCommandAPI(TestCase): - def setUp(self): - self.app = Flask(__name__) - self.client = self.app.test_client() +from unittest import TestCase, mock +from flask import Flask +from keepercommander.service.util.command_util import CommandExecutor +from keepercommander.service.util.exceptions import CommandExecutionError +from keepercommander.service.util.parse_keeper_response import parse_keeper_response + +class TestCommandAPI(TestCase): + def setUp(self): + self.app = Flask(__name__) + self.client = self.app.test_client() - @self.app.route('/api/v1/executecommand', methods=['POST']) - def execute_command(): - command = "ls" - response, status_code = CommandExecutor.execute(command) - return {'response': response}, status_code - - def test_validate_command(self): - """Test command validation""" - result, status_code = CommandExecutor.validate_command("") - self.assertIsNotNone(result) - self.assertEqual(status_code, 400) - self.assertEqual(result["error"], "No command provided.") - - result = CommandExecutor.validate_command("ls") + @self.app.route('/api/v1/executecommand', methods=['POST']) + def execute_command(): + command = "ls" + response, status_code = CommandExecutor.execute(command) + return {'response': response}, status_code + + def test_validate_command(self): + """Test command validation""" + result, status_code = CommandExecutor.validate_command("") + self.assertIsNotNone(result) + self.assertEqual(status_code, 400) + self.assertEqual(result["error"], "No command provided.") + + result = CommandExecutor.validate_command("ls") + self.assertIsNone(result) + + def test_validate_session(self): + """Test session validation""" + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=None): + result, status_code = CommandExecutor.validate_session() + self.assertEqual(status_code, 401) + self.assertIn("No active session", result["error"]) + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value={"session": "active"}): + result = CommandExecutor.validate_session() self.assertIsNone(result) - def test_validate_session(self): - """Test session validation""" - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=None): - result, status_code = CommandExecutor.validate_session() - self.assertEqual(status_code, 401) - self.assertIn("No active session", result["error"]) - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value={"session": "active"}): - result = CommandExecutor.validate_session() - self.assertIsNone(result) - - @unittest.skip - def test_command_execution_success(self): - """Test successful command execution""" - mock_params = {"session": "active"} - test_command = "ls" - expected_output = "Folder1\nFolder2\n" - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ - mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ - mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): - - response, status_code = CommandExecutor.execute(test_command) - self.assertEqual(status_code, 200) - self.assertIsNotNone(response) - - @unittest.skip - def test_command_execution_failure(self): - """Test command execution failure""" - mock_params = {"session": "active"} - test_command = "invalid_command" - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ - mock.patch('keepercommander.cli.do_command', side_effect=Exception("Command failed")), \ - self.assertRaises(CommandExecutionError): + @unittest.skip + def test_command_execution_success(self): + """Test successful command execution""" + mock_params = {"session": "active"} + test_command = "ls" + expected_output = "Folder1\nFolder2\n" + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ + mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ + mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): + + response, status_code = CommandExecutor.execute(test_command) + self.assertEqual(status_code, 200) + self.assertIsNotNone(response) + + @unittest.skip + def test_command_execution_failure(self): + """Test command execution failure""" + mock_params = {"session": "active"} + test_command = "invalid_command" + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ + mock.patch('keepercommander.cli.do_command', side_effect=Exception("Command failed")), \ + self.assertRaises(CommandExecutionError): - CommandExecutor.execute(test_command) - - def test_response_encryption(self): - """Test response encryption when key is present""" - test_response = {"status": "success", "data": "test"} - - mock_key = "0" * 32 - - with mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=mock_key): - encrypted_response = CommandExecutor.encrypt_response(test_response) - self.assertIsInstance(encrypted_response, bytes) - self.assertGreater(len(encrypted_response), 0) - - def test_response_parsing(self): - """Test response parsing for different commands""" - - ls_response = "# Folder UID\n1 folder1_uid folder1 rw\n# Record UID\n1 record1_uid login record1" - parsed = parse_keeper_response("ls", ls_response) - self.assertEqual(parsed["status"], "success") - self.assertEqual(parsed["command"], "ls") - self.assertIn("folders", parsed["data"]) - self.assertIn("records", parsed["data"]) - - tree_response = "Root\n Folder1\n SubFolder1" - parsed = parse_keeper_response("tree", tree_response) - self.assertEqual(parsed["command"], "tree") - self.assertIsInstance(parsed["data"], dict) - self.assertIn("tree", parsed["data"]) - - def test_capture_output(self): - """Test command output capture""" - test_command = "ls" - expected_output = "test output" - mock_params = {"session": "active"} - - with mock.patch('keepercommander.cli.do_command', return_value=expected_output): - return_value, output, logs = CommandExecutor.capture_output_and_logs(mock_params, test_command) - self.assertEqual(return_value, expected_output) - - @unittest.skip - def test_integration_command_flow(self): - """Test the complete command execution flow""" - test_command = "ls" - mock_params = {"session": "active"} - expected_output = "# Folder UID\n1 folder1_uid folder1 rw" - - with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ - mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ - mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): - - response, status_code = CommandExecutor.execute(test_command) - self.assertEqual(status_code, 200) - self.assertIsNotNone(response) \ No newline at end of file + CommandExecutor.execute(test_command) + + def test_response_encryption(self): + """Test response encryption when key is present""" + test_response = {"status": "success", "data": "test"} + + mock_key = "0" * 32 + + with mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=mock_key): + encrypted_response = CommandExecutor.encrypt_response(test_response) + self.assertIsInstance(encrypted_response, bytes) + self.assertGreater(len(encrypted_response), 0) + + def test_response_parsing(self): + """Test response parsing for different commands""" + + ls_response = "# Folder UID\n1 folder1_uid folder1 rw\n# Record UID\n1 record1_uid login record1" + parsed = parse_keeper_response("ls", ls_response) + self.assertEqual(parsed["status"], "success") + self.assertEqual(parsed["command"], "ls") + self.assertIn("folders", parsed["data"]) + self.assertIn("records", parsed["data"]) + + tree_response = "Root\n Folder1\n SubFolder1" + parsed = parse_keeper_response("tree", tree_response) + self.assertEqual(parsed["command"], "tree") + self.assertIsInstance(parsed["data"], dict) + self.assertIn("tree", parsed["data"]) + + def test_capture_output(self): + """Test command output capture""" + test_command = "ls" + expected_output = "test output" + mock_params = {"session": "active"} + + with mock.patch('keepercommander.cli.do_command', return_value=expected_output): + return_value, output, logs = CommandExecutor.capture_output_and_logs(mock_params, test_command) + self.assertEqual(return_value, expected_output) + + @unittest.skip + def test_integration_command_flow(self): + """Test the complete command execution flow""" + test_command = "ls" + mock_params = {"session": "active"} + expected_output = "# Folder UID\n1 folder1_uid folder1 rw" + + with mock.patch('keepercommander.service.util.command_util.get_current_params', return_value=mock_params), \ + mock.patch('keepercommander.cli.do_command', return_value=expected_output), \ + mock.patch('keepercommander.service.util.command_util.ConfigReader.read_config', return_value=None): + + response, status_code = CommandExecutor.execute(test_command) + self.assertEqual(status_code, 200) + self.assertIsNotNone(response) \ No newline at end of file diff --git a/unit-tests/service/test_config_operation.py b/unit-tests/service/test_config_operation.py index 3db9c6be0..6c16ae796 100644 --- a/unit-tests/service/test_config_operation.py +++ b/unit-tests/service/test_config_operation.py @@ -1,83 +1,81 @@ -import sys -if sys.version_info >= (3, 8): - from unittest import TestCase, mock - from keepercommander.params import KeeperParams - from keepercommander.service.config.service_config import ServiceConfig - from keepercommander.service.commands.config_operation import AddConfigService +from unittest import TestCase, mock +from keepercommander.params import KeeperParams +from keepercommander.service.config.service_config import ServiceConfig +from keepercommander.service.commands.config_operation import AddConfigService - class TestConfigOperation(TestCase): - def setUp(self): - self.mock_params = mock.Mock(spec=KeeperParams) - self.cmd = AddConfigService() +class TestConfigOperation(TestCase): + def setUp(self): + self.mock_params = mock.Mock(spec=KeeperParams) + self.cmd = AddConfigService() - def test_execute_with_existing_config(self): - mock_config = { - "is_advanced_security_enabled": "y", - "records": [] - } - mock_record = { - "api-key": "test-api-key", - "command_list": "list", - "expiration_timestamp": "2024-12-31T23:59:59", - #"expiration_of_token": "" - } + def test_execute_with_existing_config(self): + mock_config = { + "is_advanced_security_enabled": "y", + "records": [] + } + mock_record = { + "api-key": "test-api-key", + "command_list": "list", + "expiration_timestamp": "2024-12-31T23:59:59", + #"expiration_of_token": "" + } - with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ - mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ - mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ - mock.patch('builtins.print'): + with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ + mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ + mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ + mock.patch('builtins.print'): - self.cmd.execute(self.mock_params) + self.cmd.execute(self.mock_params) - expected_config = { - "is_advanced_security_enabled": "y", - "records": [mock_record] - } - mock_save.assert_called_once_with(expected_config) + expected_config = { + "is_advanced_security_enabled": "y", + "records": [mock_record] + } + mock_save.assert_called_once_with(expected_config) - def test_execute_when_config_not_found(self): - with mock.patch.object(ServiceConfig, 'load_config', side_effect=FileNotFoundError), \ - mock.patch('builtins.print') as mock_print: + def test_execute_when_config_not_found(self): + with mock.patch.object(ServiceConfig, 'load_config', side_effect=FileNotFoundError), \ + mock.patch('builtins.print') as mock_print: - result = self.cmd.execute(self.mock_params) + result = self.cmd.execute(self.mock_params) - mock_print.assert_called_with( - "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." - ) - self.assertEqual(result, '') + mock_print.assert_called_with( + "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." + ) + self.assertEqual(result, '') - def test_execute_with_general_error(self): - with mock.patch.object(ServiceConfig, 'load_config', side_effect=Exception("Test error")), \ - mock.patch('builtins.print') as mock_print: + def test_execute_with_general_error(self): + with mock.patch.object(ServiceConfig, 'load_config', side_effect=Exception("Test error")), \ + mock.patch('builtins.print') as mock_print: - result = self.cmd.execute(self.mock_params) + result = self.cmd.execute(self.mock_params) - mock_print.assert_called_with( - "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." - ) - self.assertEqual(result, '') + mock_print.assert_called_with( + "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." + ) + self.assertEqual(result, '') - def test_create_and_add_record(self): - mock_config = { - "is_advanced_security_enabled": "n", - "records": [{"existing": "record"}] - } - mock_record = { - "api-key": "new-api-key", - "command_list": "list", - "expiration_timestamp": "2024-12-31T23:59:59", - #"expiration_of_token": "" - } + def test_create_and_add_record(self): + mock_config = { + "is_advanced_security_enabled": "n", + "records": [{"existing": "record"}] + } + mock_record = { + "api-key": "new-api-key", + "command_list": "list", + "expiration_timestamp": "2024-12-31T23:59:59", + #"expiration_of_token": "" + } - with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ - mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ - mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ - mock.patch('builtins.print'): + with mock.patch.object(ServiceConfig, 'load_config', return_value=mock_config), \ + mock.patch.object(ServiceConfig, 'create_record', return_value=mock_record), \ + mock.patch.object(ServiceConfig, 'save_config') as mock_save, \ + mock.patch('builtins.print'): - self.cmd.execute(self.mock_params) + self.cmd.execute(self.mock_params) - expected_config = { - "is_advanced_security_enabled": "n", - "records": [{"existing": "record"}, mock_record] - } - mock_save.assert_called_once_with(expected_config) \ No newline at end of file + expected_config = { + "is_advanced_security_enabled": "n", + "records": [{"existing": "record"}, mock_record] + } + mock_save.assert_called_once_with(expected_config) \ No newline at end of file diff --git a/unit-tests/service/test_config_validation.py b/unit-tests/service/test_config_validation.py index 77b28aa9e..5fb40ac53 100644 --- a/unit-tests/service/test_config_validation.py +++ b/unit-tests/service/test_config_validation.py @@ -1,168 +1,166 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest.mock import patch - import socket - from datetime import timedelta - from keepercommander.service.config.config_validation import ConfigValidator - from keepercommander.service.util.exceptions import ValidationError +import unittest +from unittest.mock import patch +import socket +from datetime import timedelta +from keepercommander.service.config.config_validation import ConfigValidator +from keepercommander.service.util.exceptions import ValidationError - class TestConfigValidator(unittest.TestCase): - def setUp(self): - self.validator = ConfigValidator() +class TestConfigValidator(unittest.TestCase): + def setUp(self): + self.validator = ConfigValidator() - def test_validate_port_valid(self): - """Test port validation with valid port numbers""" - test_ports = [1024, 8080, 8900, 9000, 65535] - for port in test_ports: - with self.subTest(port=port): - with patch('socket.socket') as mock_socket: - mock_socket.return_value.__enter__.return_value.bind.return_value = None - result = self.validator.validate_port(port) - self.assertEqual(result, port) + def test_validate_port_valid(self): + """Test port validation with valid port numbers""" + test_ports = [1024, 8080, 8900, 9000, 65535] + for port in test_ports: + with self.subTest(port=port): + with patch('socket.socket') as mock_socket: + mock_socket.return_value.__enter__.return_value.bind.return_value = None + result = self.validator.validate_port(port) + self.assertEqual(result, port) - def test_validate_port_invalid_number(self): - """Test port validation with invalid port numbers""" - invalid_ports = [-1, 0, 80, 443, 1023, 65536, 'abc', ''] - for port in invalid_ports: - with self.subTest(port=port): - with self.assertRaises(ValidationError): - self.validator.validate_port(port) + def test_validate_port_invalid_number(self): + """Test port validation with invalid port numbers""" + invalid_ports = [-1, 0, 80, 443, 1023, 65536, 'abc', ''] + for port in invalid_ports: + with self.subTest(port=port): + with self.assertRaises(ValidationError): + self.validator.validate_port(port) - def test_validate_port_in_use(self): - """Test port validation when port is already in use""" - with patch('socket.socket') as mock_socket: - mock_socket.return_value.__enter__.return_value.bind.side_effect = socket.error() - with self.assertRaises(ValidationError) as context: - self.validator.validate_port(8080) - self.assertIn("is already in use", str(context.exception)) + def test_validate_port_in_use(self): + """Test port validation when port is already in use""" + with patch('socket.socket') as mock_socket: + mock_socket.return_value.__enter__.return_value.bind.side_effect = socket.error() + with self.assertRaises(ValidationError) as context: + self.validator.validate_port(8080) + self.assertIn("is already in use", str(context.exception)) - def test_validate_ngrok_token_valid(self): - """Test ngrok token validation with valid tokens""" - valid_tokens = [ - '1234567890abcdef', - 'abcdef1234567890', - 'abc123_def456-789' - ] - for token in valid_tokens: - with self.subTest(token=token): - result = self.validator.validate_ngrok_token(token) - self.assertEqual(result, token) + def test_validate_ngrok_token_valid(self): + """Test ngrok token validation with valid tokens""" + valid_tokens = [ + '1234567890abcdef', + 'abcdef1234567890', + 'abc123_def456-789' + ] + for token in valid_tokens: + with self.subTest(token=token): + result = self.validator.validate_ngrok_token(token) + self.assertEqual(result, token) - def test_validate_ngrok_token_invalid(self): - """Test ngrok token validation with invalid tokens""" - invalid_tokens = [ - '', - '123', - 'abc@def', - None - ] - for token in invalid_tokens: - with self.subTest(token=token): - with self.assertRaises(ValidationError): - self.validator.validate_ngrok_token(token) + def test_validate_ngrok_token_invalid(self): + """Test ngrok token validation with invalid tokens""" + invalid_tokens = [ + '', + '123', + 'abc@def', + None + ] + for token in invalid_tokens: + with self.subTest(token=token): + with self.assertRaises(ValidationError): + self.validator.validate_ngrok_token(token) - def test_validate_rate_limit_valid(self): - """Test rate limit validation with valid formats""" - valid_limits = [ - '10/minute', - '100/hour', - '1000/day', - '50 per minute', - '200 per hour', - '5000 per day' - ] - for limit in valid_limits: - with self.subTest(limit=limit): - result = self.validator.validate_rate_limit(limit) - self.assertEqual(result, limit) + def test_validate_rate_limit_valid(self): + """Test rate limit validation with valid formats""" + valid_limits = [ + '10/minute', + '100/hour', + '1000/day', + '50 per minute', + '200 per hour', + '5000 per day' + ] + for limit in valid_limits: + with self.subTest(limit=limit): + result = self.validator.validate_rate_limit(limit) + self.assertEqual(result, limit) - def test_validate_rate_limit_invalid(self): - """Test rate limit validation with invalid formats""" - invalid_limits = [ - 'abc', - '10/second', - '100 by hour', - '0/minute', - '0/hour', - '0/day', - '0 per minute', - ] - for limit in invalid_limits: - with self.subTest(limit=limit): - with self.assertRaises(ValidationError): - self.validator.validate_rate_limit(limit) + def test_validate_rate_limit_invalid(self): + """Test rate limit validation with invalid formats""" + invalid_limits = [ + 'abc', + '10/second', + '100 by hour', + '0/minute', + '0/hour', + '0/day', + '0 per minute', + ] + for limit in invalid_limits: + with self.subTest(limit=limit): + with self.assertRaises(ValidationError): + self.validator.validate_rate_limit(limit) - def test_validate_ip_list_valid(self): - """Test IP list validation with valid IPs and CIDR blocks""" - valid_ips = [ - '192.168.1.1', - '10.0.0.0/24', - '192.168.1.1,10.0.0.0/24', - '2001:db8::1', - 'fe80::/10' - ] - for ip_list in valid_ips: - with self.subTest(ip_list=ip_list): - result = self.validator.validate_ip_list(ip_list) - self.assertEqual(result, ip_list) + def test_validate_ip_list_valid(self): + """Test IP list validation with valid IPs and CIDR blocks""" + valid_ips = [ + '192.168.1.1', + '10.0.0.0/24', + '192.168.1.1,10.0.0.0/24', + '2001:db8::1', + 'fe80::/10' + ] + for ip_list in valid_ips: + with self.subTest(ip_list=ip_list): + result = self.validator.validate_ip_list(ip_list) + self.assertEqual(result, ip_list) - def test_validate_ip_list_invalid(self): - """Test IP list validation with invalid IPs""" - invalid_ips = [ - '256.256.256.256', - '192.168.1', - '2001:xyz::1', - '192.168.1.1/33', - ] - for ip_list in invalid_ips: - with self.subTest(ip_list=ip_list): - with self.assertRaises(ValidationError): - self.validator.validate_ip_list(ip_list) + def test_validate_ip_list_invalid(self): + """Test IP list validation with invalid IPs""" + invalid_ips = [ + '256.256.256.256', + '192.168.1', + '2001:xyz::1', + '192.168.1.1/33', + ] + for ip_list in invalid_ips: + with self.subTest(ip_list=ip_list): + with self.assertRaises(ValidationError): + self.validator.validate_ip_list(ip_list) - def test_validate_encryption_key_valid(self): - """Test encryption key validation with valid keys""" - valid_key = 'abcdef1234567890ABCDEF1234567890' - result = self.validator.validate_encryption_key(valid_key) - self.assertEqual(result, valid_key) + def test_validate_encryption_key_valid(self): + """Test encryption key validation with valid keys""" + valid_key = 'abcdef1234567890ABCDEF1234567890' + result = self.validator.validate_encryption_key(valid_key) + self.assertEqual(result, valid_key) - def test_validate_encryption_key_invalid(self): - """Test encryption key validation with invalid keys""" - invalid_keys = [ - '', - '123456', - 'a' * 31, - 'a' * 33, - 'abc$%^&*()', - None - ] - for key in invalid_keys: - with self.subTest(key=key): - with self.assertRaises(ValidationError): - self.validator.validate_encryption_key(key) + def test_validate_encryption_key_invalid(self): + """Test encryption key validation with invalid keys""" + invalid_keys = [ + '', + '123456', + 'a' * 31, + 'a' * 33, + 'abc$%^&*()', + None + ] + for key in invalid_keys: + with self.subTest(key=key): + with self.assertRaises(ValidationError): + self.validator.validate_encryption_key(key) - def test_parse_expiration_time_valid(self): - """Test expiration time parsing with valid formats""" - test_cases = [ - ('30m', timedelta(minutes=30)), - ('24h', timedelta(hours=24)), - ('7d', timedelta(days=7)) - ] - for input_str, expected in test_cases: - with self.subTest(input_str=input_str): - result = self.validator.parse_expiration_time(input_str) - self.assertEqual(result, expected) + def test_parse_expiration_time_valid(self): + """Test expiration time parsing with valid formats""" + test_cases = [ + ('30m', timedelta(minutes=30)), + ('24h', timedelta(hours=24)), + ('7d', timedelta(days=7)) + ] + for input_str, expected in test_cases: + with self.subTest(input_str=input_str): + result = self.validator.parse_expiration_time(input_str) + self.assertEqual(result, expected) - def test_parse_expiration_time_invalid(self): - """Test expiration time parsing with invalid formats""" - invalid_times = [ - '', - '30x', - '-30m', - '0m', - 'abc', - ] - for time_str in invalid_times: - with self.subTest(time_str=time_str): - with self.assertRaises(ValidationError): - self.validator.parse_expiration_time(time_str) \ No newline at end of file + def test_parse_expiration_time_invalid(self): + """Test expiration time parsing with invalid formats""" + invalid_times = [ + '', + '30x', + '-30m', + '0m', + 'abc', + ] + for time_str in invalid_times: + with self.subTest(time_str=time_str): + with self.assertRaises(ValidationError): + self.validator.parse_expiration_time(time_str) \ No newline at end of file diff --git a/unit-tests/service/test_create_service.py b/unit-tests/service/test_create_service.py index a81017ea8..d761034a4 100644 --- a/unit-tests/service/test_create_service.py +++ b/unit-tests/service/test_create_service.py @@ -1,349 +1,347 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest.mock import Mock, patch - from keepercommander.params import KeeperParams - from keepercommander.service.commands.create_service import CreateService, StreamlineArgs +import unittest +from unittest.mock import Mock, patch +from keepercommander.params import KeeperParams +from keepercommander.service.commands.create_service import CreateService, StreamlineArgs - class TestCreateService(unittest.TestCase): - def setUp(self): - self.params = Mock(spec=KeeperParams) - self.command = CreateService() +class TestCreateService(unittest.TestCase): + def setUp(self): + self.params = Mock(spec=KeeperParams) + self.command = CreateService() - def test_get_parser(self): - """Test parser creation with correct arguments.""" - parser = self.command.get_parser() + def test_get_parser(self): + """Test parser creation with correct arguments.""" + parser = self.command.get_parser() - args = parser.parse_args(['--port', '8080']) - self.assertEqual(args.port, 8080) + args = parser.parse_args(['--port', '8080']) + self.assertEqual(args.port, 8080) - args = parser.parse_args(['--commands', 'record-list']) - self.assertEqual(args.commands, 'record-list') + args = parser.parse_args(['--commands', 'record-list']) + self.assertEqual(args.commands, 'record-list') - args = parser.parse_args(['--ngrok', 'token123']) - self.assertEqual(args.ngrok, 'token123') + args = parser.parse_args(['--ngrok', 'token123']) + self.assertEqual(args.ngrok, 'token123') - args = parser.parse_args(['--cloudflare', 'cf_token123']) - self.assertEqual(args.cloudflare, 'cf_token123') + args = parser.parse_args(['--cloudflare', 'cf_token123']) + self.assertEqual(args.cloudflare, 'cf_token123') - args = parser.parse_args(['--cloudflare_custom_domain', 'example.com']) - self.assertEqual(args.cloudflare_custom_domain, 'example.com') + args = parser.parse_args(['--cloudflare_custom_domain', 'example.com']) + self.assertEqual(args.cloudflare_custom_domain, 'example.com') - @patch('keepercommander.service.core.service_manager.ServiceManager') - def test_execute_service_already_running(self, mock_service_manager): - """Test execute when service is already running.""" - mock_service_manager.get_status.return_value = "Commander Service is Running on port 8080" + @patch('keepercommander.service.core.service_manager.ServiceManager') + def test_execute_service_already_running(self, mock_service_manager): + """Test execute when service is already running.""" + mock_service_manager.get_status.return_value = "Commander Service is Running on port 8080" - with patch('builtins.print') as mock_print: - self.command.execute(self.params) - mock_print.assert_called_with("Error: Commander Service is already running.") + with patch('builtins.print') as mock_print: + self.command.execute(self.params) + mock_print.assert_called_with("Error: Commander Service is already running.") - def test_handle_configuration_streamlined(self): - """Test streamlined configuration handling.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_handle_configuration_streamlined(self): + """Test streamlined configuration handling.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - def test_handle_configuration_interactive(self): - """Test interactive configuration handling.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs(port=None, commands=None, ngrok=None, allowedip='' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled=None, update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_handle_configuration_interactive(self): + """Test interactive configuration handling.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs(port=None, commands=None, ngrok=None, allowedip='' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled=None, update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch.object(self.command.config_handler, 'handle_interactive_config') as mock_interactive, \ - patch.object(self.command.security_handler, 'configure_security') as mock_security: - self.command._handle_configuration(config_data, self.params, args) - mock_interactive.assert_called_once_with(config_data, self.params) - mock_security.assert_called_once_with(config_data) + with patch.object(self.command.config_handler, 'handle_interactive_config') as mock_interactive, \ + patch.object(self.command.security_handler, 'configure_security') as mock_security: + self.command._handle_configuration(config_data, self.params, args) + mock_interactive.assert_called_once_with(config_data, self.params) + mock_security.assert_called_once_with(config_data) - def test_create_and_save_record(self): - """Test record creation and saving.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_create_and_save_record(self): + """Test record creation and saving.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs(port=8080, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch.object(self.command.service_config, 'create_record') as mock_create_record, \ - patch.object(self.command.service_config, 'save_config') as mock_save_config: + with patch.object(self.command.service_config, 'create_record') as mock_create_record, \ + patch.object(self.command.service_config, 'save_config') as mock_save_config: - mock_create_record.return_value = {'api-key': 'test-key'} - self.command._create_and_save_record(config_data, self.params, args) + mock_create_record.return_value = {'api-key': 'test-key'} + self.command._create_and_save_record(config_data, self.params, args) - mock_create_record.assert_called_once_with( - config_data["is_advanced_security_enabled"], - self.params, - args.commands, - args.token_expiration, - None # record_uid (update_vault_record is None) - ) - if(args.fileformat): - config_data["fileformat"]= args.fileformat - else: - mock_save_config.assert_called_once_with(config_data, 'create') + mock_create_record.assert_called_once_with( + config_data["is_advanced_security_enabled"], + self.params, + args.commands, + args.token_expiration, + None # record_uid (update_vault_record is None) + ) + if(args.fileformat): + config_data["fileformat"]= args.fileformat + else: + mock_save_config.assert_called_once_with(config_data, 'create') - def test_validation_error_handling(self): - """Test handling of validation errors during execution.""" - args = StreamlineArgs(port=-1, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) + def test_validation_error_handling(self): + """Test handling of validation errors during execution.""" + args = StreamlineArgs(port=-1, commands='record-list', ngrok=None, allowedip='0.0.0.0' ,deniedip='', ngrok_custom_domain=None, cloudflare=None, cloudflare_custom_domain=None, certfile='', certpassword='', fileformat='json', run_mode='foreground', queue_enabled='y', update_vault_record=None, ratelimit=None, encryption_key=None, token_expiration=None) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, port=-1) + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, port=-1) - mock_print.assert_called() + mock_print.assert_called() - def test_cloudflare_streamlined_configuration(self): - """Test streamlined configuration with Cloudflare tunnel.""" - config_data = self.command.service_config.create_default_config() - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_streamlined_configuration(self): + """Test streamlined configuration with Cloudflare tunnel.""" + config_data = self.command.service_config.create_default_config() + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - def test_cloudflare_validation_missing_token(self): - """Test validation error when Cloudflare token is missing but domain is provided.""" - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare=None, - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_validation_missing_token(self): + """Test validation error when Cloudflare token is missing but domain is provided.""" + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare=None, + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, cloudflare_custom_domain='tunnel.example.com') - mock_print.assert_called() + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, cloudflare_custom_domain='tunnel.example.com') + mock_print.assert_called() - def test_cloudflare_validation_missing_domain(self): - """Test validation error when Cloudflare domain is missing but token is provided.""" - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain=None, - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_validation_missing_domain(self): + """Test validation error when Cloudflare domain is missing but token is provided.""" + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain=None, + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, cloudflare='cf_token123') - mock_print.assert_called() + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, cloudflare='cf_token123') + mock_print.assert_called() - def test_cloudflare_and_ngrok_mutual_exclusion(self): - """Test that Cloudflare and ngrok cannot be used together.""" - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok='ngrok_token123', - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain='ngrok.example.com', - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_and_ngrok_mutual_exclusion(self): + """Test that Cloudflare and ngrok cannot be used together.""" + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok='ngrok_token123', + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain='ngrok.example.com', + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - mock_create_config.return_value = {} - self.command.execute(self.params, ngrok='ngrok_token123', cloudflare='cf_token123') - mock_print.assert_called() + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + mock_create_config.return_value = {} + self.command.execute(self.params, ngrok='ngrok_token123', cloudflare='cf_token123') + mock_print.assert_called() - @patch('keepercommander.service.config.cloudflare_config.CloudflareConfigurator.configure_cloudflare') - def test_cloudflare_tunnel_startup_success(self, mock_cloudflare_configure): - """Test successful Cloudflare tunnel startup.""" - config_data = self.command.service_config.create_default_config() - config_data.update({ - 'cloudflare': 'y', - 'cloudflare_tunnel_token': 'cf_token123', - 'cloudflare_custom_domain': 'tunnel.example.com', - 'port': 8080 - }) + @patch('keepercommander.service.config.cloudflare_config.CloudflareConfigurator.configure_cloudflare') + def test_cloudflare_tunnel_startup_success(self, mock_cloudflare_configure): + """Test successful Cloudflare tunnel startup.""" + config_data = self.command.service_config.create_default_config() + config_data.update({ + 'cloudflare': 'y', + 'cloudflare_tunnel_token': 'cf_token123', + 'cloudflare_custom_domain': 'tunnel.example.com', + 'port': 8080 + }) - mock_cloudflare_configure.return_value = 12345 # Mock PID + mock_cloudflare_configure.return_value = 12345 # Mock PID - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - @patch('keepercommander.service.core.globals.init_globals') - @patch('keepercommander.service.core.service_manager.ServiceManager.start_service') - @patch('keepercommander.service.core.service_manager.ServiceManager.get_status') - def test_cloudflare_tunnel_startup_failure(self, mock_get_status, mock_start_service, mock_init_globals): - """Test Cloudflare tunnel startup failure due to firewall.""" - # Mock that service is not already running - mock_get_status.return_value = "Commander Service is not running" + @patch('keepercommander.service.core.globals.init_globals') + @patch('keepercommander.service.core.service_manager.ServiceManager.start_service') + @patch('keepercommander.service.core.service_manager.ServiceManager.get_status') + def test_cloudflare_tunnel_startup_failure(self, mock_get_status, mock_start_service, mock_init_globals): + """Test Cloudflare tunnel startup failure due to firewall.""" + # Mock that service is not already running + mock_get_status.return_value = "Commander Service is not running" - # Mock service startup failure due to Cloudflare tunnel issues - mock_start_service.side_effect = Exception("Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") + # Mock service startup failure due to Cloudflare tunnel issues + mock_start_service.side_effect = Exception("Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") - with patch('builtins.print') as mock_print: - with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: - with patch.object(self.command.service_config, 'create_record') as mock_create_record: - with patch.object(self.command.service_config, 'save_config') as mock_save_config: - with patch.object(self.command.service_config, 'update_or_add_record') as mock_update_record: - with patch.object(self.command.service_config.validator, 'validate_cloudflare_token') as mock_validate_token: - mock_create_config.return_value = { - 'is_advanced_security_enabled': 'n', - 'fileformat': 'json' - } - mock_create_record.return_value = {'api-key': 'test-key'} - mock_validate_token.return_value = 'cf_token123' # Mock valid token + with patch('builtins.print') as mock_print: + with patch.object(self.command.service_config, 'create_default_config') as mock_create_config: + with patch.object(self.command.service_config, 'create_record') as mock_create_record: + with patch.object(self.command.service_config, 'save_config') as mock_save_config: + with patch.object(self.command.service_config, 'update_or_add_record') as mock_update_record: + with patch.object(self.command.service_config.validator, 'validate_cloudflare_token') as mock_validate_token: + mock_create_config.return_value = { + 'is_advanced_security_enabled': 'n', + 'fileformat': 'json' + } + mock_create_record.return_value = {'api-key': 'test-key'} + mock_validate_token.return_value = 'cf_token123' # Mock valid token - # This should trigger the exception handling in execute() - self.command.execute( - self.params, - port=8080, - allowedip='0.0.0.0', - deniedip='', - commands='record-list', - ngrok=None, - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + # This should trigger the exception handling in execute() + self.command.execute( + self.params, + port=8080, + allowedip='0.0.0.0', + deniedip='', + commands='record-list', + ngrok=None, + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - # Verify that the error was printed - mock_print.assert_called_with("Unexpected error: Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") + # Verify that the error was printed + mock_print.assert_called_with("Unexpected error: Commander Service failed to start: Cloudflare tunnel failed to connect. This is likely due to firewall/proxy blocking the connection.") - def test_cloudflare_token_validation(self): - """Test Cloudflare token format validation.""" - # Test valid token format - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='eyJhIjoiYWJjZGVmZ2hpams', # Base64-like token - cloudflare_custom_domain='tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_token_validation(self): + """Test Cloudflare token format validation.""" + # Test valid token format + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='eyJhIjoiYWJjZGVmZ2hpams', # Base64-like token + cloudflare_custom_domain='tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - config_data = self.command.service_config.create_default_config() - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + config_data = self.command.service_config.create_default_config() + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - def test_cloudflare_domain_validation(self): - """Test Cloudflare custom domain validation.""" - # Test valid domain format - args = StreamlineArgs( - port=8080, - commands='record-list', - ngrok=None, - allowedip='0.0.0.0', - deniedip='', - ngrok_custom_domain=None, - cloudflare='cf_token123', - cloudflare_custom_domain='my-tunnel.example.com', - certfile='', - certpassword='', - fileformat='json', - run_mode='foreground', - queue_enabled='y', - update_vault_record=None, - ratelimit=None, - encryption_key=None, - token_expiration=None - ) + def test_cloudflare_domain_validation(self): + """Test Cloudflare custom domain validation.""" + # Test valid domain format + args = StreamlineArgs( + port=8080, + commands='record-list', + ngrok=None, + allowedip='0.0.0.0', + deniedip='', + ngrok_custom_domain=None, + cloudflare='cf_token123', + cloudflare_custom_domain='my-tunnel.example.com', + certfile='', + certpassword='', + fileformat='json', + run_mode='foreground', + queue_enabled='y', + update_vault_record=None, + ratelimit=None, + encryption_key=None, + token_expiration=None + ) - with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: - config_data = self.command.service_config.create_default_config() - self.command._handle_configuration(config_data, self.params, args) - mock_streamlined.assert_called_once_with(config_data, args, self.params) + with patch.object(self.command.config_handler, 'handle_streamlined_config') as mock_streamlined: + config_data = self.command.service_config.create_default_config() + self.command._handle_configuration(config_data, self.params, args) + mock_streamlined.assert_called_once_with(config_data, args, self.params) - if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/unit-tests/service/test_queue_concurrency.py b/unit-tests/service/test_queue_concurrency.py index 991ae4bc0..5a2800e30 100644 --- a/unit-tests/service/test_queue_concurrency.py +++ b/unit-tests/service/test_queue_concurrency.py @@ -1,209 +1,206 @@ -import sys - -if sys.version_info >= (3, 8): - import queue - import threading - import time - import unittest - from unittest import mock - from flask import Flask - - from keepercommander.service.api.command import create_command_blueprint, create_legacy_command_blueprint - from keepercommander.service.core.request_queue import ( - DEFAULT_QUEUE_MAX_SIZE, - DEFAULT_REQUEST_TIMEOUT, - DEFAULT_RESULT_RETENTION, - RequestQueueManager, - ) - - - def passthrough_decorator(): - def decorator(fn): - return fn - return decorator - - - class TestQueueConcurrency(unittest.TestCase): - def setUp(self): - self.manager = RequestQueueManager() - self._reset_manager() - - def tearDown(self): - self._reset_manager() - - def _reset_manager(self): - self.manager.stop() - self.manager.request_queue = queue.Queue(maxsize=DEFAULT_QUEUE_MAX_SIZE) - self.manager.active_requests = {} - self.manager.completed_requests = {} - self.manager.worker_thread = None - self.manager.is_running = False - self.manager.current_request_id = None - self.manager.request_timeout = DEFAULT_REQUEST_TIMEOUT - self.manager.result_retention = DEFAULT_RESULT_RETENTION - - def _create_app(self, include_v2=False): - app = Flask(__name__) - with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator): - app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') - if include_v2: - app.register_blueprint(create_command_blueprint(), url_prefix='/api/v2') - return app - - def test_queue_manager_serializes_concurrent_submissions(self): - state_lock = threading.Lock() - inflight = {"count": 0, "max": 0} - results = {} - - def fake_execute(command): - with state_lock: - inflight["count"] += 1 - inflight["max"] = max(inflight["max"], inflight["count"]) - - time.sleep(0.05) - - with state_lock: - inflight["count"] -= 1 - - return {"status": "success", "data": {"command": command}}, 200 - - with mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() - - def submit_and_wait(index): - request_id = self.manager.submit_request(f"cmd-{index}") - results[index] = self.manager.wait_for_result(request_id, timeout=2) - - threads = [threading.Thread(target=submit_and_wait, args=(i,)) for i in range(5)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - self.assertEqual(inflight["max"], 1) - self.assertEqual(len(results), 5) - for index in range(5): - payload, status_code = results[index] - self.assertEqual(status_code, 200) - self.assertEqual(payload["data"]["command"], f"cmd-{index}") - - def test_v1_and_v2_share_single_queue_worker(self): - app = self._create_app(include_v2=True) - state_lock = threading.Lock() - inflight = {"count": 0, "max": 0} - outputs = {} - start_barrier = threading.Barrier(3) - - def fake_execute(command): - with state_lock: - inflight["count"] += 1 - inflight["max"] = max(inflight["max"], inflight["count"]) - - time.sleep(0.05) - - with state_lock: - inflight["count"] -= 1 - - return {"status": "success", "data": {"command": command}}, 200 - - with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ - mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() - - def call_v1(): - with app.test_client() as client: - start_barrier.wait() - response = client.post('/api/v1/executecommand', json={"command": "legacy-cmd"}) - outputs["v1"] = (response.status_code, response.get_json(), response.headers.get('X-API-Legacy')) - - def call_v2(): - with app.test_client() as client: - start_barrier.wait() - response = client.post('/api/v2/executecommand-async', json={"command": "async-cmd"}) - response_data = response.get_json() - outputs["v2_submit"] = (response.status_code, response_data) - outputs["v2_result"] = self.manager.wait_for_result(response_data["request_id"], timeout=2) - - v1_thread = threading.Thread(target=call_v1) - v2_thread = threading.Thread(target=call_v2) - v1_thread.start() - v2_thread.start() - start_barrier.wait() - v1_thread.join() - v2_thread.join() - - self.assertEqual(inflight["max"], 1) - self.assertEqual(outputs["v1"][0], 200) - self.assertEqual(outputs["v1"][1]["data"]["command"], "legacy-cmd") - self.assertEqual(outputs["v1"][2], "true") - self.assertEqual(outputs["v2_submit"][0], 202) - self.assertEqual(outputs["v2_submit"][1]["status"], "queued") - self.assertEqual(outputs["v2_result"][1], 200) - self.assertEqual(outputs["v2_result"][0]["data"]["command"], "async-cmd") - - def test_timed_out_v1_request_does_not_execute_after_expiration(self): - app = self._create_app(include_v2=False) - request_timeout = 0.1 - self.manager.request_timeout = request_timeout - - first_started = threading.Event() - release_first = threading.Event() - executed_commands = [] - executed_lock = threading.Lock() - - def fake_execute(command): - with executed_lock: - executed_commands.append(command) - - if command == "first": - first_started.set() - release_first.wait(timeout=2) - - return {"status": "success", "data": {"command": command}}, 200 - - with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ - mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() - - def call_first(): - with app.test_client() as client: - return client.post('/api/v1/executecommand', json={"command": "first"}) - - first_thread = threading.Thread(target=call_first) - first_thread.start() - self.assertTrue(first_started.wait(timeout=1)) +import queue +import threading +import time +import unittest +from unittest import mock +from flask import Flask + +from keepercommander.service.api.command import create_command_blueprint, create_legacy_command_blueprint +from keepercommander.service.core.request_queue import ( + DEFAULT_QUEUE_MAX_SIZE, + DEFAULT_REQUEST_TIMEOUT, + DEFAULT_RESULT_RETENTION, + RequestQueueManager, +) + + +def passthrough_decorator(): + def decorator(fn): + return fn + return decorator + + +class TestQueueConcurrency(unittest.TestCase): + def setUp(self): + self.manager = RequestQueueManager() + self._reset_manager() + + def tearDown(self): + self._reset_manager() + + def _reset_manager(self): + self.manager.stop() + self.manager.request_queue = queue.Queue(maxsize=DEFAULT_QUEUE_MAX_SIZE) + self.manager.active_requests = {} + self.manager.completed_requests = {} + self.manager.worker_thread = None + self.manager.is_running = False + self.manager.current_request_id = None + self.manager.request_timeout = DEFAULT_REQUEST_TIMEOUT + self.manager.result_retention = DEFAULT_RESULT_RETENTION + + def _create_app(self, include_v2=False): + app = Flask(__name__) + with mock.patch('keepercommander.service.api.command.unified_api_decorator', passthrough_decorator): + app.register_blueprint(create_legacy_command_blueprint(use_queue=True), url_prefix='/api/v1') + if include_v2: + app.register_blueprint(create_command_blueprint(), url_prefix='/api/v2') + return app + + def test_queue_manager_serializes_concurrent_submissions(self): + state_lock = threading.Lock() + inflight = {"count": 0, "max": 0} + results = {} + + def fake_execute(command): + with state_lock: + inflight["count"] += 1 + inflight["max"] = max(inflight["max"], inflight["count"]) + + time.sleep(0.05) + + with state_lock: + inflight["count"] -= 1 + + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + def submit_and_wait(index): + request_id = self.manager.submit_request(f"cmd-{index}") + results[index] = self.manager.wait_for_result(request_id, timeout=2) + + threads = [threading.Thread(target=submit_and_wait, args=(i,)) for i in range(5)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + self.assertEqual(inflight["max"], 1) + self.assertEqual(len(results), 5) + for index in range(5): + payload, status_code = results[index] + self.assertEqual(status_code, 200) + self.assertEqual(payload["data"]["command"], f"cmd-{index}") + + def test_v1_and_v2_share_single_queue_worker(self): + app = self._create_app(include_v2=True) + state_lock = threading.Lock() + inflight = {"count": 0, "max": 0} + outputs = {} + start_barrier = threading.Barrier(3) + + def fake_execute(command): + with state_lock: + inflight["count"] += 1 + inflight["max"] = max(inflight["max"], inflight["count"]) + + time.sleep(0.05) + + with state_lock: + inflight["count"] -= 1 + + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ + mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + def call_v1(): + with app.test_client() as client: + start_barrier.wait() + response = client.post('/api/v1/executecommand', json={"command": "legacy-cmd"}) + outputs["v1"] = (response.status_code, response.get_json(), response.headers.get('X-API-Legacy')) + def call_v2(): + with app.test_client() as client: + start_barrier.wait() + response = client.post('/api/v2/executecommand-async', json={"command": "async-cmd"}) + response_data = response.get_json() + outputs["v2_submit"] = (response.status_code, response_data) + outputs["v2_result"] = self.manager.wait_for_result(response_data["request_id"], timeout=2) + + v1_thread = threading.Thread(target=call_v1) + v2_thread = threading.Thread(target=call_v2) + v1_thread.start() + v2_thread.start() + start_barrier.wait() + v1_thread.join() + v2_thread.join() + + self.assertEqual(inflight["max"], 1) + self.assertEqual(outputs["v1"][0], 200) + self.assertEqual(outputs["v1"][1]["data"]["command"], "legacy-cmd") + self.assertEqual(outputs["v1"][2], "true") + self.assertEqual(outputs["v2_submit"][0], 202) + self.assertEqual(outputs["v2_submit"][1]["status"], "queued") + self.assertEqual(outputs["v2_result"][1], 200) + self.assertEqual(outputs["v2_result"][0]["data"]["command"], "async-cmd") + + def test_timed_out_v1_request_does_not_execute_after_expiration(self): + app = self._create_app(include_v2=False) + request_timeout = 0.1 + self.manager.request_timeout = request_timeout + + first_started = threading.Event() + release_first = threading.Event() + executed_commands = [] + executed_lock = threading.Lock() + + def fake_execute(command): + with executed_lock: + executed_commands.append(command) + + if command == "first": + first_started.set() + release_first.wait(timeout=2) + + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ + mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + def call_first(): with app.test_client() as client: - second_response = client.post('/api/v1/executecommand', json={"command": "second"}) + return client.post('/api/v1/executecommand', json={"command": "first"}) - self.assertEqual(second_response.status_code, 504) + first_thread = threading.Thread(target=call_first) + first_thread.start() + self.assertTrue(first_started.wait(timeout=1)) - release_first.set() - first_thread.join() - time.sleep(request_timeout + 0.1) + with app.test_client() as client: + second_response = client.post('/api/v1/executecommand', json={"command": "second"}) - self.assertIn("first", executed_commands) - self.assertNotIn("second", executed_commands) + self.assertEqual(second_response.status_code, 504) - def test_processing_v1_request_waits_past_queue_timeout(self): - app = self._create_app(include_v2=False) - request_timeout = 0.1 - self.manager.request_timeout = request_timeout + release_first.set() + first_thread.join() + time.sleep(request_timeout + 0.1) - started_processing = threading.Event() + self.assertIn("first", executed_commands) + self.assertNotIn("second", executed_commands) - def fake_execute(command): - started_processing.set() - time.sleep(request_timeout + 0.15) - return {"status": "success", "data": {"command": command}}, 200 + def test_processing_v1_request_waits_past_queue_timeout(self): + app = self._create_app(include_v2=False) + request_timeout = 0.1 + self.manager.request_timeout = request_timeout - with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ - mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): - self.manager.start() + started_processing = threading.Event() - with app.test_client() as client: - response = client.post('/api/v1/executecommand', json={"command": "slow-command"}) + def fake_execute(command): + started_processing.set() + time.sleep(request_timeout + 0.15) + return {"status": "success", "data": {"command": command}}, 200 + + with mock.patch('keepercommander.service.api.command.queue_manager', self.manager), \ + mock.patch('keepercommander.service.core.request_queue.CommandExecutor.execute', side_effect=fake_execute): + self.manager.start() + + with app.test_client() as client: + response = client.post('/api/v1/executecommand', json={"command": "slow-command"}) - self.assertTrue(started_processing.is_set()) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.get_json()["data"]["command"], "slow-command") + self.assertTrue(started_processing.is_set()) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json()["data"]["command"], "slow-command") diff --git a/unit-tests/service/test_response_parser.py b/unit-tests/service/test_response_parser.py index 6ef7fd3f7..4b389eb9a 100644 --- a/unit-tests/service/test_response_parser.py +++ b/unit-tests/service/test_response_parser.py @@ -1,56 +1,54 @@ -import sys -if sys.version_info >= (3, 8): - from unittest import TestCase - from keepercommander.service.util.parse_keeper_response import KeeperResponseParser +from unittest import TestCase +from keepercommander.service.util.parse_keeper_response import KeeperResponseParser - class TestKeeperResponseParser(TestCase): - def test_parse_ls_command(self): - """Test parsing of 'ls' command output""" - sample_output = """# Folder UID Title Flags - 1 b4pBzT1WowoUXHk_US0SCg Root RS - # Record UID Type Title Description - 1 dGJ3xbH8CXhNF00FBX0wMA login My Login Important""" +class TestKeeperResponseParser(TestCase): + def test_parse_ls_command(self): + """Test parsing of 'ls' command output""" + sample_output = """# Folder UID Title Flags + 1 b4pBzT1WowoUXHk_US0SCg Root RS +# Record UID Type Title Description + 1 dGJ3xbH8CXhNF00FBX0wMA login My Login Important""" - result = KeeperResponseParser._parse_ls_command(sample_output) + result = KeeperResponseParser._parse_ls_command(sample_output) - self.assertEqual(result['status'], 'success') - self.assertEqual(result['command'], 'ls') - self.assertEqual(len(result['data']['folders']), 1) - self.assertEqual(len(result['data']['records']), 1) + self.assertEqual(result['status'], 'success') + self.assertEqual(result['command'], 'ls') + self.assertEqual(len(result['data']['folders']), 1) + self.assertEqual(len(result['data']['records']), 1) - folder = result['data']['folders'][0] - self.assertEqual(folder['number'], 1) - self.assertEqual(folder['name'], 'Root') + folder = result['data']['folders'][0] + self.assertEqual(folder['number'], 1) + self.assertEqual(folder['name'], 'Root') - record = result['data']['records'][0] - self.assertEqual(record['number'], 1) - self.assertEqual(record['title'], 'My Login') - self.assertEqual(record['description'], 'Important') + record = result['data']['records'][0] + self.assertEqual(record['number'], 1) + self.assertEqual(record['title'], 'My Login') + self.assertEqual(record['description'], 'Important') - def test_parse_tree_command(self): - """Test parsing of 'tree' command output""" - sample_output = """Root - Folder1 - SubFolder1 - Folder2""" + def test_parse_tree_command(self): + """Test parsing of 'tree' command output""" + sample_output = """Root +Folder1 + SubFolder1 +Folder2""" - result = KeeperResponseParser._parse_tree_command(sample_output) + result = KeeperResponseParser._parse_tree_command(sample_output) - self.assertEqual(result['status'], 'success') - self.assertEqual(result['command'], 'tree') - self.assertEqual(len(result['data']['tree']), 4) # Updated: now returns dict with 'tree' key + self.assertEqual(result['status'], 'success') + self.assertEqual(result['command'], 'tree') + self.assertEqual(len(result['data']['tree']), 4) # Updated: now returns dict with 'tree' key - self.assertEqual(result['data']['tree'][0]['level'], 0) - self.assertEqual(result['data']['tree'][0]['name'], 'Root') - self.assertEqual(result['data']['tree'][0]['path'], 'Root') + self.assertEqual(result['data']['tree'][0]['level'], 0) + self.assertEqual(result['data']['tree'][0]['name'], 'Root') + self.assertEqual(result['data']['tree'][0]['path'], 'Root') - self.assertEqual(result['data']['tree'][1]['level'], 0) - self.assertEqual(result['data']['tree'][1]['name'], 'Folder1') - self.assertEqual(result['data']['tree'][1]['path'], 'Folder1') + self.assertEqual(result['data']['tree'][1]['level'], 0) + self.assertEqual(result['data']['tree'][1]['name'], 'Folder1') + self.assertEqual(result['data']['tree'][1]['path'], 'Folder1') - def test_parse_tree_command_share_permissions_structured(self): - """tree -s -v: share_permissions splits default/user vs per-user list""" - sample_output = """Share Permissions Key: + def test_parse_tree_command_share_permissions_structured(self): + """tree -s -v: share_permissions splits default/user vs per-user list""" + sample_output = """Share Permissions Key: ====================== RO = Read-Only MU = Can Manage Users @@ -58,40 +56,40 @@ def test_parse_tree_command_share_permissions_structured(self): My Vault └── Shared Folder (abc123) [SHARED] (default:CE; user:CE; users:[a@x.com:RO],[b@y.com:MU,MR]) """ - result = KeeperResponseParser._parse_tree_command(sample_output) - self.assertEqual(result['data']['share_permissions_key'][:2], ['RO = Read-Only', 'MU = Can Manage Users']) - entry = result['data']['tree'][0] - self.assertTrue(entry['shared']) - sp = entry['share_permissions'] - self.assertEqual(sp['default'], 'CE') - self.assertEqual(sp['user'], 'CE') - self.assertEqual(len(sp['users']), 2) - self.assertEqual(sp['users'][0]['username'], 'a@x.com') - self.assertEqual(sp['users'][0]['permissions'], 'RO') - self.assertEqual(sp['users'][1]['username'], 'b@y.com') - self.assertEqual(sp['users'][1]['permissions'], 'MU,MR') + result = KeeperResponseParser._parse_tree_command(sample_output) + self.assertEqual(result['data']['share_permissions_key'][:2], ['RO = Read-Only', 'MU = Can Manage Users']) + entry = result['data']['tree'][0] + self.assertTrue(entry['shared']) + sp = entry['share_permissions'] + self.assertEqual(sp['default'], 'CE') + self.assertEqual(sp['user'], 'CE') + self.assertEqual(len(sp['users']), 2) + self.assertEqual(sp['users'][0]['username'], 'a@x.com') + self.assertEqual(sp['users'][0]['permissions'], 'RO') + self.assertEqual(sp['users'][1]['username'], 'b@y.com') + self.assertEqual(sp['users'][1]['permissions'], 'MU,MR') - def test_parse_mkdir_command(self): - """Test parsing of 'mkdir' command output""" + def test_parse_mkdir_command(self): + """Test parsing of 'mkdir' command output""" - result = KeeperResponseParser._parse_mkdir_command('b4pBzT1WowoUXHk_US0SCg') - self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') + result = KeeperResponseParser._parse_mkdir_command('b4pBzT1WowoUXHk_US0SCg') + self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') - result = KeeperResponseParser._parse_mkdir_command('Created folder with folder_uid=b4pBzT1WowoUXHk_US0SCg') - self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') + result = KeeperResponseParser._parse_mkdir_command('Created folder with folder_uid=b4pBzT1WowoUXHk_US0SCg') + self.assertEqual(result['data']['folder_uid'], 'b4pBzT1WowoUXHk_US0SCg') - def test_parse_get_command(self): - """Test parsing of 'get' command output""" - sample_output = """Title: Test Record - Username: testuser - Password: testpass - URL: https://example.com""" + def test_parse_get_command(self): + """Test parsing of 'get' command output""" + sample_output = """Title: Test Record +Username: testuser +Password: testpass +URL: https://example.com""" - result = KeeperResponseParser._parse_get_command(sample_output) + result = KeeperResponseParser._parse_get_command(sample_output) - self.assertEqual(result['status'], 'success') - self.assertEqual(result['command'], 'get') - self.assertEqual(result['data']['title'], 'Test Record') - self.assertEqual(result['data']['username'], 'testuser') - self.assertEqual(result['data']['password'], 'testpass') - self.assertEqual(result['data']['url'], 'https://example.com') \ No newline at end of file + self.assertEqual(result['status'], 'success') + self.assertEqual(result['command'], 'get') + self.assertEqual(result['data']['title'], 'Test Record') + self.assertEqual(result['data']['username'], 'testuser') + self.assertEqual(result['data']['password'], 'testpass') + self.assertEqual(result['data']['url'], 'https://example.com') \ No newline at end of file diff --git a/unit-tests/service/test_service_config.py b/unit-tests/service/test_service_config.py index 4a99d6b12..9ca4f3eeb 100644 --- a/unit-tests/service/test_service_config.py +++ b/unit-tests/service/test_service_config.py @@ -1,136 +1,134 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest.mock import patch, MagicMock - import json - from keepercommander.params import KeeperParams - from keepercommander.service.config.service_config import ServiceConfig - from keepercommander.service.util.exceptions import ValidationError +import unittest +from unittest.mock import patch, MagicMock +import json +from keepercommander.params import KeeperParams +from keepercommander.service.config.service_config import ServiceConfig +from keepercommander.service.util.exceptions import ValidationError - class TestServiceConfig(unittest.TestCase): - def setUp(self): - self.service_config = ServiceConfig() - self.test_config = { - "title": "Commander Service Mode", - "port": 8000, - "ngrok": "n", - "ngrok_auth_token": "", - "ngrok_custom_domain": "", - "ngrok_public_url": "", - "is_advanced_security_enabled": "n", - "rate_limiting": "", - "ip_allowed_list": "", - "ip_denied_list": "", - "encryption": "", - "encryption_private_key": "", - "records": [], - "tls_certificate":"", - "certfile": "", - "certpassword": "", - "fileformat": "", - "run_mode": "", - "queue_enabled": "y" - } +class TestServiceConfig(unittest.TestCase): + def setUp(self): + self.service_config = ServiceConfig() + self.test_config = { + "title": "Commander Service Mode", + "port": 8000, + "ngrok": "n", + "ngrok_auth_token": "", + "ngrok_custom_domain": "", + "ngrok_public_url": "", + "is_advanced_security_enabled": "n", + "rate_limiting": "", + "ip_allowed_list": "", + "ip_denied_list": "", + "encryption": "", + "encryption_private_key": "", + "records": [], + "tls_certificate":"", + "certfile": "", + "certpassword": "", + "fileformat": "", + "run_mode": "", + "queue_enabled": "y" + } - def test_create_default_config(self): - """Test creation of default configuration.""" - config = self.service_config.create_default_config() - self.assertEqual(config["title"], "Commander Service Mode Config") - self.assertIsNone(config["port"]) - self.assertEqual(config["ngrok"], "n") - self.assertEqual(config["ngrok_auth_token"], "") - self.assertEqual(config["is_advanced_security_enabled"], "n") + def test_create_default_config(self): + """Test creation of default configuration.""" + config = self.service_config.create_default_config() + self.assertEqual(config["title"], "Commander Service Mode Config") + self.assertIsNone(config["port"]) + self.assertEqual(config["ngrok"], "n") + self.assertEqual(config["ngrok_auth_token"], "") + self.assertEqual(config["is_advanced_security_enabled"], "n") - def test_save_config_success(self): - """Test successful configuration save.""" - with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ - patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: + def test_save_config_success(self): + """Test successful configuration save.""" + with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ + patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: - mock_format.return_value = 'json' - mock_save_json.return_value = self.service_config.config_path + mock_format.return_value = 'json' + mock_save_json.return_value = self.service_config.config_path - result = self.service_config.save_config(self.test_config) + result = self.service_config.save_config(self.test_config) - mock_format.assert_called_once() - mock_save_json.assert_called_once() - self.assertEqual(result, self.service_config.config_path) + mock_format.assert_called_once() + mock_save_json.assert_called_once() + self.assertEqual(result, self.service_config.config_path) - def test_save_config_io_error(self): - """Test configuration save with IO error.""" - with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ - patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: + def test_save_config_io_error(self): + """Test configuration save with IO error.""" + with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ + patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: - mock_format.return_value = 'json' - mock_save_json.side_effect = IOError("Test error") + mock_format.return_value = 'json' + mock_save_json.side_effect = IOError("Test error") - with self.assertRaises(ValidationError): - self.service_config.save_config(self.test_config) + with self.assertRaises(ValidationError): + self.service_config.save_config(self.test_config) - @unittest.skip - @patch('pathlib.Path.exists') - @patch('pathlib.Path.read_text') - def test_load_config_success(self, mock_read, mock_exists): - """Test successful configuration load.""" - mock_exists.return_value = True - mock_read.return_value = json.dumps(self.test_config) - config = self.service_config.load_config() - self.assertEqual(config, self.test_config) + @unittest.skip + @patch('pathlib.Path.exists') + @patch('pathlib.Path.read_text') + def test_load_config_success(self, mock_read, mock_exists): + """Test successful configuration load.""" + mock_exists.return_value = True + mock_read.return_value = json.dumps(self.test_config) + config = self.service_config.load_config() + self.assertEqual(config, self.test_config) - @patch('pathlib.Path.exists') - def test_load_config_missing_file(self, mock_exists): - """Test configuration load with missing file.""" - mock_exists.return_value = False - with self.assertRaises(FileNotFoundError): - self.service_config.load_config() + @patch('pathlib.Path.exists') + def test_load_config_missing_file(self, mock_exists): + """Test configuration load with missing file.""" + mock_exists.return_value = False + with self.assertRaises(FileNotFoundError): + self.service_config.load_config() - def test_get_yes_no_input_valid(self): - """Test yes/no input with valid inputs.""" - with patch('builtins.input', side_effect=['y']): - result = self.service_config._get_yes_no_input("Test prompt") - self.assertEqual(result, 'y') + def test_get_yes_no_input_valid(self): + """Test yes/no input with valid inputs.""" + with patch('builtins.input', side_effect=['y']): + result = self.service_config._get_yes_no_input("Test prompt") + self.assertEqual(result, 'y') - with patch('builtins.input', side_effect=['n']): - result = self.service_config._get_yes_no_input("Test prompt") - self.assertEqual(result, 'n') + with patch('builtins.input', side_effect=['n']): + result = self.service_config._get_yes_no_input("Test prompt") + self.assertEqual(result, 'n') - @patch('builtins.print') - def test_get_yes_no_input_invalid_then_valid(self, mock_print): - """Test yes/no input with invalid input followed by valid input.""" - with patch('builtins.input', side_effect=['invalid', 'y']): - result = self.service_config._get_yes_no_input("Test prompt") - self.assertEqual(result, 'y') - mock_print.assert_called_once() + @patch('builtins.print') + def test_get_yes_no_input_invalid_then_valid(self, mock_print): + """Test yes/no input with invalid input followed by valid input.""" + with patch('builtins.input', side_effect=['invalid', 'y']): + result = self.service_config._get_yes_no_input("Test prompt") + self.assertEqual(result, 'y') + mock_print.assert_called_once() - @patch.object(ServiceConfig, 'cli_handler') - def test_validate_command_list_valid(self, mock_cli_handler): - """Test command list validation with valid commands.""" - mock_cli_handler.get_help_output.return_value = """ + @patch.object(ServiceConfig, 'cli_handler') + def test_validate_command_list_valid(self, mock_cli_handler): + """Test command list validation with valid commands.""" + mock_cli_handler.get_help_output.return_value = """ Vault Commands ls (list) List vault records get (info) Display record details - """ - params = MagicMock(spec=KeeperParams) - result = self.service_config.validate_command_list("ls, get", params) - self.assertEqual(result, "ls,get") + """ + params = MagicMock(spec=KeeperParams) + result = self.service_config.validate_command_list("ls, get", params) + self.assertEqual(result, "ls,get") - @patch.object(ServiceConfig, 'cli_handler') - def test_validate_command_list_invalid(self, mock_cli_handler): - """Test command list validation with invalid commands.""" - mock_cli_handler.get_help_output.return_value = """ + @patch.object(ServiceConfig, 'cli_handler') + def test_validate_command_list_invalid(self, mock_cli_handler): + """Test command list validation with invalid commands.""" + mock_cli_handler.get_help_output.return_value = """ Vault Commands ls (list) List vault records get (info) Display record details - """ - params = MagicMock(spec=KeeperParams) - with self.assertRaises(ValidationError): - self.service_config.validate_command_list("invalid_command", params) + """ + params = MagicMock(spec=KeeperParams) + with self.assertRaises(ValidationError): + self.service_config.validate_command_list("invalid_command", params) - @unittest.skip - @patch.object(ServiceConfig, 'record_handler') - def test_update_or_add_record(self, mock_record_handler): - """Test record update/add functionality.""" - params = MagicMock(spec=KeeperParams) - self.service_config.update_or_add_record(params) - mock_record_handler.update_or_add_record.assert_called_once_with( - params, self.service_config.title, self.service_config.config_path - ) \ No newline at end of file + @unittest.skip + @patch.object(ServiceConfig, 'record_handler') + def test_update_or_add_record(self, mock_record_handler): + """Test record update/add functionality.""" + params = MagicMock(spec=KeeperParams) + self.service_config.update_or_add_record(params) + mock_record_handler.update_or_add_record.assert_called_once_with( + params, self.service_config.title, self.service_config.config_path + ) \ No newline at end of file diff --git a/unit-tests/service/test_service_manager.py b/unit-tests/service/test_service_manager.py index 86d2a4827..f1798ac30 100644 --- a/unit-tests/service/test_service_manager.py +++ b/unit-tests/service/test_service_manager.py @@ -1,185 +1,183 @@ -import sys -if sys.version_info >= (3, 8): - import unittest - from unittest import mock - from pathlib import Path +import unittest +from unittest import mock +from pathlib import Path - from keepercommander.params import KeeperParams - from keepercommander.service.core.service_manager import ServiceManager - from keepercommander.service.core.process_info import ProcessInfo - from keepercommander.service.commands.handle_service import StartService, StopService, ServiceStatus +from keepercommander.params import KeeperParams +from keepercommander.service.core.service_manager import ServiceManager +from keepercommander.service.core.process_info import ProcessInfo +from keepercommander.service.commands.handle_service import StartService, StopService, ServiceStatus - class TestServiceManagement(unittest.TestCase): - def setUp(self): - self.params = mock.Mock(spec=KeeperParams) - ProcessInfo._env_file = Path(__file__).parent / ".test_service.env" +class TestServiceManagement(unittest.TestCase): + def setUp(self): + self.params = mock.Mock(spec=KeeperParams) + ProcessInfo._env_file = Path(__file__).parent / ".test_service.env" - if ProcessInfo._env_file.exists(): - ProcessInfo._env_file.unlink() + if ProcessInfo._env_file.exists(): + ProcessInfo._env_file.unlink() - def tearDown(self): - if ProcessInfo._env_file.exists(): - ProcessInfo._env_file.unlink() + def tearDown(self): + if ProcessInfo._env_file.exists(): + ProcessInfo._env_file.unlink() - def test_start_service_when_not_running(self): - """Test starting service when no existing service is running""" - with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ - mock.patch('os.getpid', return_value=12345), \ - mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ - mock.patch('keepercommander.service.core.terminal_handler.TerminalHandler.get_terminal_info', return_value="/dev/test"): + def test_start_service_when_not_running(self): + """Test starting service when no existing service is running""" + with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ + mock.patch('os.getpid', return_value=12345), \ + mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ + mock.patch('keepercommander.service.core.terminal_handler.TerminalHandler.get_terminal_info', return_value="/dev/test"): - mock_config.return_value.load_config.return_value = {"port": 8000} + mock_config.return_value.load_config.return_value = {"port": 8000} - mock_app = mock.Mock() - mock_create_app.return_value = mock_app + mock_app = mock.Mock() + mock_create_app.return_value = mock_app - start_cmd = StartService() - start_cmd.execute(self.params) + start_cmd = StartService() + start_cmd.execute(self.params) - process_info = ProcessInfo.load() + process_info = ProcessInfo.load() - # pid might be None if .env not updated in test; allow both for test to pass - self.assertIn(process_info.pid, [12345, None]) + # pid might be None if .env not updated in test; allow both for test to pass + self.assertIn(process_info.pid, [12345, None]) - self.assertIn(process_info.is_running, [True, False]) + self.assertIn(process_info.is_running, [True, False]) - mock_app.run.assert_called_once_with(host='0.0.0.0', port=8000, ssl_context=None) + mock_app.run.assert_called_once_with(host='0.0.0.0', port=8000, ssl_context=None) - def test_start_service_when_already_running(self): - """Test starting service when another instance is already running""" - ProcessInfo.save(pid=12345, is_running=True) - with mock.patch('os.getpid', return_value=12345), \ - mock.patch('psutil.Process') as mock_process, \ - mock.patch('sys.executable', '/usr/bin/python3'): - mock_proc_instance = mock.Mock() - mock_proc_instance.is_running.return_value = True - mock_proc_instance.name.return_value = "python3" - mock_proc_instance.cmdline.return_value = ["/usr/bin/python3", "service_app.py"] - mock_process.return_value = mock_proc_instance + def test_start_service_when_already_running(self): + """Test starting service when another instance is already running""" + ProcessInfo.save(pid=12345, is_running=True) + with mock.patch('os.getpid', return_value=12345), \ + mock.patch('psutil.Process') as mock_process, \ + mock.patch('sys.executable', '/usr/bin/python3'): + mock_proc_instance = mock.Mock() + mock_proc_instance.is_running.return_value = True + mock_proc_instance.name.return_value = "python3" + mock_proc_instance.cmdline.return_value = ["/usr/bin/python3", "service_app.py"] + mock_process.return_value = mock_proc_instance - start_cmd = StartService() - with mock.patch('builtins.print') as mock_print: - start_cmd.execute(self.params) - mock_print.assert_called_with("Error: Commander Service is already running (PID: 12345)") + start_cmd = StartService() + with mock.patch('builtins.print') as mock_print: + start_cmd.execute(self.params) + mock_print.assert_called_with("Error: Commander Service is already running (PID: 12345)") - def test_stop_service_when_running(self): - """Test stopping a running service""" - ProcessInfo.save(pid=12345, is_running=True) + def test_stop_service_when_running(self): + """Test stopping a running service""" + ProcessInfo.save(pid=12345, is_running=True) - with mock.patch('sys.platform', 'linux'), \ - mock.patch('os.getpid', return_value=9999), \ - mock.patch('psutil.Process') as mock_process, \ - mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_process_by_pid', return_value=True) as mock_kill, \ - mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_ngrok_processes', return_value=False), \ - mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_cloudflare_processes', return_value=False): - - stop_cmd = StopService() - stop_cmd.execute(self.params) - - mock_kill.assert_called_once_with(12345) - mock_process.return_value.terminate.assert_called_once() - self.assertFalse(ProcessInfo._env_file.exists()) + with mock.patch('sys.platform', 'linux'), \ + mock.patch('os.getpid', return_value=9999), \ + mock.patch('psutil.Process') as mock_process, \ + mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_process_by_pid', return_value=True) as mock_kill, \ + mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_ngrok_processes', return_value=False), \ + mock.patch('keepercommander.service.core.service_manager.ServiceManager.kill_cloudflare_processes', return_value=False): + + stop_cmd = StopService() + stop_cmd.execute(self.params) + + mock_kill.assert_called_once_with(12345) + mock_process.return_value.terminate.assert_called_once() + self.assertFalse(ProcessInfo._env_file.exists()) - def test_stop_service_when_not_running(self): - """Test stopping service when no service is running""" - with mock.patch('builtins.print') as mock_print: - stop_cmd = StopService() - stop_cmd.execute(self.params) - mock_print.assert_called_with("Error: No running service found to stop") - - def test_service_status_when_running(self): - """More flexible test for checking service status""" - ProcessInfo.save(pid=12345, is_running=True) + def test_stop_service_when_not_running(self): + """Test stopping service when no service is running""" + with mock.patch('builtins.print') as mock_print: + stop_cmd = StopService() + stop_cmd.execute(self.params) + mock_print.assert_called_with("Error: No running service found to stop") + + def test_service_status_when_running(self): + """More flexible test for checking service status""" + ProcessInfo.save(pid=12345, is_running=True) - with mock.patch('os.getpid', return_value=12345), \ - mock.patch('psutil.Process') as mock_process: + with mock.patch('os.getpid', return_value=12345), \ + mock.patch('psutil.Process') as mock_process: - mock_process.return_value.is_running.return_value = True + mock_process.return_value.is_running.return_value = True - status_cmd = ServiceStatus() - with mock.patch('builtins.print') as mock_print: - status_cmd.execute(self.params) + status_cmd = ServiceStatus() + with mock.patch('builtins.print') as mock_print: + status_cmd.execute(self.params) - # Verify print was called exactly once - self.assertEqual(mock_print.call_count, 1) + # Verify print was called exactly once + self.assertEqual(mock_print.call_count, 1) - # Extract the actual output - actual_output = mock_print.call_args[0][0] + # Extract the actual output + actual_output = mock_print.call_args[0][0] - # Check essential parts without being overly specific about the terminal info - self.assertIn("Current status: Commander Service is Running", actual_output) - self.assertIn("PID: 12345", actual_output) + # Check essential parts without being overly specific about the terminal info + self.assertIn("Current status: Commander Service is Running", actual_output) + self.assertIn("PID: 12345", actual_output) - def test_service_status_when_not_running(self): - """Test getting status when no service is running""" - status_cmd = ServiceStatus() - with mock.patch('builtins.print') as mock_print: - status_cmd.execute(self.params) - mock_print.assert_called_with("Current status: No Commander Service is running currently") + def test_service_status_when_not_running(self): + """Test getting status when no service is running""" + status_cmd = ServiceStatus() + with mock.patch('builtins.print') as mock_print: + status_cmd.execute(self.params) + mock_print.assert_called_with("Current status: No Commander Service is running currently") - def test_process_info_save_load(self): - """Test ProcessInfo save and load operations""" - test_pid = 12345 - test_terminal = "/dev/test" + def test_process_info_save_load(self): + """Test ProcessInfo save and load operations""" + test_pid = 12345 + test_terminal = "/dev/test" - with mock.patch('os.getpid', return_value=test_pid): - ProcessInfo.save(pid=12345, is_running=True) + with mock.patch('os.getpid', return_value=test_pid): + ProcessInfo.save(pid=12345, is_running=True) - loaded_info = ProcessInfo.load() - self.assertEqual(loaded_info.pid, test_pid) - self.assertTrue(loaded_info.is_running) + loaded_info = ProcessInfo.load() + self.assertEqual(loaded_info.pid, test_pid) + self.assertTrue(loaded_info.is_running) - def test_handle_shutdown(self): - """Test service shutdown handler""" - ServiceManager._is_running = True - ServiceManager._flask_app = mock.Mock() + def test_handle_shutdown(self): + """Test service shutdown handler""" + ServiceManager._is_running = True + ServiceManager._flask_app = mock.Mock() - ProcessInfo.save(pid=12345, is_running=True) + ProcessInfo.save(pid=12345, is_running=True) - ServiceManager._handle_shutdown() + ServiceManager._handle_shutdown() - self.assertFalse(ServiceManager._is_running) - self.assertIsNone(ServiceManager._flask_app) - self.assertFalse(ProcessInfo._env_file.exists()) + self.assertFalse(ServiceManager._is_running) + self.assertIsNone(ServiceManager._flask_app) + self.assertFalse(ProcessInfo._env_file.exists()) - def test_start_service_with_missing_config(self): - """Test starting service with missing configuration file""" - with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ - mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ - mock.patch('builtins.print') as mock_print: + def test_start_service_with_missing_config(self): + """Test starting service with missing configuration file""" + with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ + mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ + mock.patch('builtins.print') as mock_print: - mock_config.return_value.load_config.side_effect = FileNotFoundError() + mock_config.return_value.load_config.side_effect = FileNotFoundError() - mock_app = mock.Mock() - mock_create_app.return_value = mock_app - mock_app.run = mock.Mock() + mock_app = mock.Mock() + mock_create_app.return_value = mock_app + mock_app.run = mock.Mock() - start_cmd = StartService() - start_cmd.execute(self.params) + start_cmd = StartService() + start_cmd.execute(self.params) - # mock_print.assert_called_with( - # "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." - # ) + # mock_print.assert_called_with( + # "Error: Service configuration file not found. Please use 'service-create' command to create a service_config file." + # ) - mock_app.run.assert_not_called() + mock_app.run.assert_not_called() - def test_start_service_with_missing_port(self): - """Test starting service with missing port in configuration""" - with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ - mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ - mock.patch('builtins.print') as mock_print: + def test_start_service_with_missing_port(self): + """Test starting service with missing port in configuration""" + with mock.patch('keepercommander.service.core.service_manager.ServiceConfig') as mock_config, \ + mock.patch('keepercommander.service.app.create_app') as mock_create_app, \ + mock.patch('builtins.print') as mock_print: - mock_config.return_value.load_config.return_value = {} + mock_config.return_value.load_config.return_value = {} - mock_app = mock.Mock() - mock_create_app.return_value = mock_app - mock_app.run = mock.Mock() + mock_app = mock.Mock() + mock_create_app.return_value = mock_app + mock_app.run = mock.Mock() - start_cmd = StartService() - start_cmd.execute(self.params) + start_cmd = StartService() + start_cmd.execute(self.params) - mock_print.assert_called_with( - "Error: Service configuration is incomplete. Please configure the service port in service_config" - ) + mock_print.assert_called_with( + "Error: Service configuration is incomplete. Please configure the service port in service_config" + ) - mock_app.run.assert_not_called() + mock_app.run.assert_not_called() diff --git a/unit-tests/test_keeper_drive.py b/unit-tests/test_keeper_drive.py index ac3f478a6..aa1623345 100644 --- a/unit-tests/test_keeper_drive.py +++ b/unit-tests/test_keeper_drive.py @@ -117,7 +117,6 @@ def test_normalize_parent_uid(self): def test_format_timestamp(self): from keepercommander.commands.keeper_drive.helpers import format_timestamp - self.assertIn('2024', format_timestamp(1704067200000)) self.assertEqual(format_timestamp(0), '') self.assertEqual(format_timestamp(None), '') diff --git a/unit-tests/test_tunnel_registry.py b/unit-tests/test_tunnel_registry.py index 931260e8f..88a9fac2b 100644 --- a/unit-tests/test_tunnel_registry.py +++ b/unit-tests/test_tunnel_registry.py @@ -6,7 +6,6 @@ import json import os import shutil -import sys import tempfile import unittest from pathlib import Path @@ -24,9 +23,6 @@ ) from keepercommander.error import CommandError -if sys.version_info < (3, 8): - raise unittest.SkipTest('pam tunnel tests require Python 3.8+') - def _patch_registry_dir(testcase, tmp: Path): """Point tunnel_registry_dir at tmp for the duration of a test.""" From 27b9b16d735060c731ca76b6cc95893368647f49 Mon Sep 17 00:00:00 2001 From: sshrushanth-ks Date: Tue, 5 May 2026 17:14:29 +0530 Subject: [PATCH 04/26] Fix: remove mandatory collection filter requirement on EPM policy creation UserCheck, MachineCheck, and ApplicationCheck already default to ['*'] when not provided, so forcing users to explicitly pass --user-filter, --machine-filter, and --app-filter was unnecessary. Co-authored-by: Cursor --- keepercommander/commands/pedm/pedm_admin.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/keepercommander/commands/pedm/pedm_admin.py b/keepercommander/commands/pedm/pedm_admin.py index d2e0b0662..7ce563954 100644 --- a/keepercommander/commands/pedm/pedm_admin.py +++ b/keepercommander/commands/pedm/pedm_admin.py @@ -1442,14 +1442,6 @@ def execute(self, context: KeeperParams, **kwargs) -> None: if policy_filter: policy_data.update(policy_filter) - if policy_type in ('PrivilegeElevation', 'FileAccess', 'CommandLine'): - missing = [name for name, key in (('user', 'UserCheck'), ('machine', 'MachineCheck'), ('application', 'ApplicationCheck')) - if not policy_filter.get(key)] - if missing: - raise base.CommandError( - f'At least one machine, application, and user collection required to save this policy type. ' - f'Missing: {", ".join(missing)}. Use --user-filter, --machine-filter, --app-filter.') - for filter_name in ('UserCheck', 'MachineCheck', 'ApplicationCheck', 'DateCheck', 'TimeCheck', 'DayCheck'): f = policy_data.get(filter_name) if f is None: From 685486e162589834d4bb664359f02e76a07200ba Mon Sep 17 00:00:00 2001 From: sshrushanth-ks Date: Tue, 5 May 2026 17:14:38 +0530 Subject: [PATCH 05/26] Fix: default DateCheck, TimeCheck, DayCheck to [] on EPM policy creation Date, Time, and Day filters were being saved as ['*'] when not specified, inconsistent with Admin Console behaviour. These fields now default to [] to match the UI. UserCheck, MachineCheck, and ApplicationCheck remain ['*']. Co-authored-by: Cursor --- keepercommander/commands/pedm/pedm_admin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keepercommander/commands/pedm/pedm_admin.py b/keepercommander/commands/pedm/pedm_admin.py index 7ce563954..b201edbd2 100644 --- a/keepercommander/commands/pedm/pedm_admin.py +++ b/keepercommander/commands/pedm/pedm_admin.py @@ -1442,10 +1442,10 @@ def execute(self, context: KeeperParams, **kwargs) -> None: if policy_filter: policy_data.update(policy_filter) - for filter_name in ('UserCheck', 'MachineCheck', 'ApplicationCheck', 'DateCheck', 'TimeCheck', 'DayCheck'): - f = policy_data.get(filter_name) - if f is None: - policy_data[filter_name] = ['*'] + for filter_name, default in (('UserCheck', ['*']), ('MachineCheck', ['*']), ('ApplicationCheck', ['*']), + ('DateCheck', []), ('TimeCheck', []), ('DayCheck', [])): + if policy_data.get(filter_name) is None: + policy_data[filter_name] = default arg_status = kwargs.get('status') if isinstance(arg_status, str): From f27e8332547b1da0d687f693a50162acdfe459c6 Mon Sep 17 00:00:00 2001 From: idimov-keeper <78815270+idimov-keeper@users.noreply.github.com> Date: Tue, 5 May 2026 17:38:19 -0500 Subject: [PATCH 06/26] Add --proxy flag to 'pam tunnel start' for KeeperDB Proxy mode (#2016) --- .../commands/tunnel_and_connections.py | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 0117a4382..026be65c9 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -553,6 +553,11 @@ class PAMTunnelStartCommand(Command): pam_cmd_parser.add_argument('--no-trickle-ice', '-nti', required=False, dest='no_trickle_ice', action='store_true', help='Disable trickle ICE for WebRTC connections. By default, trickle ICE is enabled ' 'for real-time candidate exchange.') + pam_cmd_parser.add_argument('--proxy', '-px', required=False, dest='proxy', action='store_true', + help='Activate KeeperDB Proxy: the gateway substitutes credentials ' + 'from your Keeper vault when the local client connects to the tunnel.') + # TODO(rdp-proxy): once RDP Proxy support lands on pamMachine, generalize --proxy + # to or auto-detect from the record type). For now, --proxy is KeeperDB-only. pam_cmd_parser.add_argument('--reason', '-r', required=False, dest='workflow_reason', type=str, help='Justification text for workflow access request. Used when the record\'s ' 'workflow requires a reason; non-interactive equivalent of the inline prompt.') @@ -599,6 +604,48 @@ class PAMTunnelStartCommand(Command): def get_parser(self): return PAMTunnelStartCommand.pam_cmd_parser + @staticmethod + def _resolve_database_type(record, pam_settings_value): + # Mirrors the gateway/pam-launch lookup: prefer pamSettings.connection.databaseType, + # fall back to the top-level 'databaseType' typed field. Returns canonical + # 'mysql' | 'postgresql' | 'mssql' or None. + raw = '' + if isinstance(pam_settings_value, dict): + raw = (pam_settings_value.get('connection') or {}).get('databaseType') or '' + if not raw: + db_field = record.get_typed_field('databaseType') + if db_field: + v = db_field.get_default_value() + if isinstance(v, str): + raw = v + elif isinstance(v, list) and v: + raw = str(v[0]) + raw = (raw or '').strip().lower() + if 'mysql' in raw or 'mariadb' in raw: + return 'mysql' + if 'postgres' in raw: + return 'postgresql' + if 'sql server' in raw or 'sqlserver' in raw or 'mssql' in raw: + return 'mssql' + return None + + @staticmethod + def _print_keeperdb_proxy_banner(host, port, db_type): + suffix = f' ({db_type})' if db_type else '' + print(f"\n{bcolors.OKGREEN}KeeperDB Proxy ready{suffix}{bcolors.ENDC}") + print(f" Listening: {host}:{port}") + if db_type == 'mysql': + print(f" Connect: mysql -h {host} -P {port} -u -p") + elif db_type == 'postgresql': + print(f" Connect: psql -h {host} -p {port} -U ") + elif db_type == 'mssql': + print(f" Connect: sqlcmd -S {host},{port} -U -P ") + else: + print(f" Connect: use your database client to connect to {host}:{port}") + print(f"{bcolors.OKBLUE} Note: when your DB client prompts for credentials you may " + f"supply any value — the proxy will substitute the credentials configured in your " + f"Keeper vault.{bcolors.ENDC}") + def execute(self, params, **kwargs): # Python version validation (same as before) from_version = [3, 8, 0] # including @@ -709,6 +756,49 @@ def execute(self, params, **kwargs): pam_settings_value = pam_settings.get_default_value() if pam_settings else {} allow_supply_host = pam_settings_value.get('allowSupplyHost', False) if isinstance(pam_settings_value, dict) else False + # --proxy: KeeperDB Proxy mode (gateway substitutes credentials from vault). + # This is a Commander-side validator/declaration; the gateway currently + # auto-routes pamDatabase + allowKeeperDBProxy to the proxy regardless of + # any client-side flag (see is_keeperdb_proxy_tunnel in dr-controller's + # tunnel_helpers.py and _build_protocol_settings in WebRTCSessionAction.py). + # Requiring no-`--proxy` to mean "raw TCP tunnel to remote host" depends on + # a future gateway change to honor a client-side opt-in flag; until that + # lands, omitting --proxy will still proxy if the record allows it. + is_keeperdb_proxy = bool(kwargs.get('proxy')) + db_type_for_banner = None + if is_keeperdb_proxy: + record_type = record.record_type + # TODO(rdp-proxy): once RDP Proxy support lands, also accept + # 'pamMachine' here and dispatch by record type. + if record_type != 'pamDatabase': + print(f"{bcolors.FAIL}--proxy is only supported on pamDatabase records. " + f"Record {record_uid} is of type \"{record_type}\".{bcolors.ENDC}") + return + allow_kdb = isinstance(pam_settings_value, dict) and bool( + (pam_settings_value.get('portForward') or {}).get('allowKeeperDBProxy') + or (pam_settings_value.get('connection') or {}).get('allowKeeperDBProxy') + ) + if not allow_kdb: + print(f"{bcolors.FAIL}KeeperDB Proxy is not enabled for record {record_uid}.{bcolors.ENDC}") + print(f"{bcolors.WARNING}Enable it with " + f"{bcolors.OKBLUE}'pam tunnel edit {record_uid} --keeper-db-proxy on'" + f"{bcolors.ENDC}") + return + # Mirror the launch-credential pre-flight from PAMTunnelEditCommand + # (--keeper-db-proxy on path) so the failure message and timing are + # identical between edit and start. + _est, _ett, _tk = get_keeper_tokens(params) + _existing_cfg = get_config_uid(params, _est, _ett, record_uid) + _proxy_dag = TunnelDAG(params, _est, _ett, _existing_cfg, transmission_key=_tk) + if not _proxy_dag.check_if_resource_has_launch_credential(record_uid): + print(f"{bcolors.FAIL}No Launch Credentials assigned to record \"{record_uid}\". " + f"Please assign launch credentials before using --proxy.{bcolors.ENDC}") + print(f"{bcolors.WARNING}Use: " + f"{bcolors.OKBLUE}pam connection edit --launch-user (-lu) " + f"{bcolors.ENDC}") + return + db_type_for_banner = self._resolve_database_type(record, pam_settings_value) + # Get target host and port if allow_supply_host: # User must supply target host and port via command arguments or interactive prompt @@ -912,6 +1002,15 @@ def execute(self, params, **kwargs): result = start_rust_tunnel(params, record_uid, gateway_uid, host, port, seed, target_host, target_port, socks, trickle_ice, record.title, allow_supply_host=allow_supply_host, two_factor_value=two_factor_value) if result and result.get("success"): + # When --proxy was used, print the KeeperDB Proxy info banner once. + # Local listener is already bound (start_rust_tunnel returns the + # actual_local_listen_addr from Rust), so the connect string is + # valid even though the WebRTC handshake may still be in progress. + # Single call covers interactive, foreground, run, and background- + # child modes — the background parent returns earlier and never + # reaches this branch. + if is_keeperdb_proxy: + self._print_keeperdb_proxy_banner(host, port, db_type_for_banner) # Workflow lease expiry handling. # # Behavior note: at expiresOn we want to terminate the tunnel and From b8fccee4c843b1002df3312a6d799ddeab459cf6 Mon Sep 17 00:00:00 2001 From: idimov-keeper <78815270+idimov-keeper@users.noreply.github.com> Date: Wed, 6 May 2026 19:53:08 -0500 Subject: [PATCH 07/26] Workflow window session hard close (#2021) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Hard-close PAM session tubes on workflow lease expiry When the workflow lease expires for `pam tunnel start` or `pam launch`, soft-close the tube now and escalate to keeper_pam_webrtc_rs's new force_close_tube after 3s. Force-close drops the local TCP listener and severs in-flight forwarded TCP streams (SSH, MySQL, etc.), which the prior soft-close did not — an active session would linger past expiry until the user disconnected manually. Escalation is gated on both endpoints supporting force_close_tube: hasattr(tube_registry, "force_close_tube") on the local crate AND remote SDP-advertised version >= FORCE_CLOSE_MIN_VERSION ("2.1.18"). Older peers fall back to soft close only. - tunnel_helpers: add escalate_close() shared helper plus FORCE_CLOSE_MIN_VERSION / FORCE_CLOSE_DELAY_SECONDS constants; consolidate _version_at_least here as the single source of truth (terminal_connection.py re-exports for back-compat with launch.py) - tunnel_and_connections: replace the previously-commented-out soft close in `pam tunnel start`'s lease-expiry callback with escalate_close - pam_launch/launch: wire escalate_close into `pam launch`'s _on_lease_expired (was only setting flags before, no tube close) * Escalate-close when launch lease already expired at start * Suppress runaway TURN refresh-permission log leak; harden rust-log filter for concurrent sessions * Redraw keeper-shell prompt after async lease-expiry message --- keepercommander/commands/pam_launch/launch.py | 49 +++++-- .../commands/pam_launch/rust_log_filter.py | 138 +++++++++++++++++- .../pam_launch/terminal_connection.py | 32 +--- .../tunnel/port_forward/tunnel_helpers.py | 138 ++++++++++++++++++ .../commands/tunnel_and_connections.py | 77 ++++------ requirements.txt | 2 +- setup.cfg | 2 +- 7 files changed, 340 insertions(+), 98 deletions(-) diff --git a/keepercommander/commands/pam_launch/launch.py b/keepercommander/commands/pam_launch/launch.py index 6d64cf993..a6114b7a7 100644 --- a/keepercommander/commands/pam_launch/launch.py +++ b/keepercommander/commands/pam_launch/launch.py @@ -62,6 +62,8 @@ unregister_tunnel_session, unregister_conversation_key, get_keeper_tokens, + escalate_close, + CloseConnectionReasons, ) from ..tunnel.port_forward.TunnelGraph import TunnelDAG from .rust_log_filter import ( @@ -1415,13 +1417,11 @@ def _start_cli_session( (banner printed there); do not create a second spinner or duplicate the launching line. preserve_crlf: When True (default), STDOUT keeps raw CRLF; False when ``pam launch -n`` / ``--normalize-crlf``. """ - import sys as _sys - # Non-interactive stdin guard: key-event mode requires a real TTY. # --stdin (pipe mode) is fine with redirected stdin, but key mode is not — # tty.setraw() will raise and character-at-a-time mapping makes no sense # for piped/scripted input. - if not use_stdin and not _sys.stdin.isatty(): + if not use_stdin and not sys.stdin.isatty(): if pre_connect_spinner is not None: pre_connect_spinner.stop() raise CommandError( @@ -1452,19 +1452,48 @@ def signal_handler_fn(signum, frame): # the web vault (immediate teardown, no grace period, no reconnect). # The "Access expired" line is printed AFTER terminal reset in finally # so the message survives reset_local_terminal_after_pam_session(). + # On expiry we soft-close the tube and escalate to force_close_tube + # after FORCE_CLOSE_DELAY_SECONDS so any in-flight forwarded streams + # (SSH bytes etc.) are severed instead of lingering until the user + # disconnects manually. Escalation is gated on local hasattr + + # remote SDP version (FORCE_CLOSE_MIN_VERSION). lease_timer = None + force_close_timer_holder = {} # mutable holder so cleanup can cancel if workflow_expires_on_ms and workflow_expires_on_ms > 0: - import time as _time - seconds_until_expiry = (workflow_expires_on_ms / 1000.0) - _time.time() - if seconds_until_expiry <= 0: + seconds_until_expiry = (workflow_expires_on_ms / 1000.0) - time.time() + _lease_tube_id = tunnel_result['tunnel'].get('tube_id') + _lease_tube_registry = tunnel_result['tunnel'].get('tube_registry') + + def _on_lease_expired(): + nonlocal shutdown_requested, lease_expired lease_expired = True shutdown_requested = True + if _lease_tube_id and _lease_tube_registry is not None: + # Fetch remote version lazily: the SDP answer arrives + # asynchronously; capturing eagerly at schedule time + # would race for short leases scheduled before SDP. + remote_ver = tunnel_result['tunnel'].get('remote_webrtc_version') + if not remote_ver: + sess = get_tunnel_session(_lease_tube_id) + remote_ver = ( + getattr(sess, 'remote_webrtc_version', None) + if sess else None + ) + force_close_timer_holder['t'] = escalate_close( + _lease_tube_registry, + _lease_tube_id, + remote_webrtc_version=remote_ver, + reason=CloseConnectionReasons.AdminClosed, + log_prefix=f"[lease-expiry launch tube={_lease_tube_id[:8]}] ", + ) + + if seconds_until_expiry <= 0: + # Already expired at session start: run the close-and-escalate + # path immediately so cleanup goes through the same flow as a + # mid-session expiry. + _on_lease_expired() else: import threading as _threading - def _on_lease_expired(): - nonlocal shutdown_requested, lease_expired - lease_expired = True - shutdown_requested = True lease_timer = _threading.Timer(seconds_until_expiry, _on_lease_expired) lease_timer.daemon = True lease_timer.start() diff --git a/keepercommander/commands/pam_launch/rust_log_filter.py b/keepercommander/commands/pam_launch/rust_log_filter.py index cded0e9bd..b1ee55e96 100644 --- a/keepercommander/commands/pam_launch/rust_log_filter.py +++ b/keepercommander/commands/pam_launch/rust_log_filter.py @@ -9,6 +9,79 @@ import threading +# Patterns for known leak messages from turn-0.11.0's relay-conn task. +# The webrtc-rs ICE agent does not synchronously cancel its TURN refresh task +# on PeerConnection.close(); the task survives indefinitely and re-fires every +# few minutes (TURN permission lifetime ~5 min, refresh at ~3/4 of that). Each +# iteration logs: +# "fail to refresh permissions: CreatePermission error response (error 400: Bad Request)" +# "refresh permissions failed" +# from turn-0.11.0/src/client/relay_conn.rs:528 / :618. +# +# Until the upstream leak is fixed, suppress these messages permanently — they +# are post-close stragglers from a deallocated TURN allocation and have no +# diagnostic value to the user. +_TURN_REFRESH_LEAK_PATTERNS = ( + 'fail to refresh permissions', + 'refresh permissions failed', +) + + +class _PermanentTurnLeakFilter(logging.Filter): + """Always drop the known turn-rs refresh-permission leak messages. + + Installed once at module import time on the root logger, never removed. + Independent of the session-scoped _RustWebrtcToDebugFilter — that one + flips with --debug; this one is an upstream-bug workaround that should + fire regardless of debug state. + """ + + def filter(self, record: logging.LogRecord) -> bool: + try: + msg = record.getMessage() + except Exception: + return True + for needle in _TURN_REFRESH_LEAK_PATTERNS: + if needle in msg: + return False + return True + + +# Loggers known to emit the leak. Both dot- and colon-separated names cover +# the Rust→Python bridge formats. ``turn`` and ``turn.client`` cover any +# parent that records may originate from depending on rust-log target style. +_TURN_LEAK_LOGGER_NAMES = ( + 'turn', + 'turn.client', + 'turn.client.relay_conn', + 'turn::client::relay_conn', +) + +_PERMANENT_TURN_FILTER = _PermanentTurnLeakFilter() + + +def _install_permanent_turn_filter(): + """Attach the content filter to the known leaky loggers AND root. + + Python's logger filters fire only at the originating logger (filters do + NOT re-check during propagation up the hierarchy via callHandlers), so we + must attach to the actual emitting logger names rather than relying on + root.addFilter alone. Idempotent — safe to call again. + """ + for name in _TURN_LEAK_LOGGER_NAMES: + log = logging.getLogger(name) + if _PERMANENT_TURN_FILTER not in log.filters: + log.addFilter(_PERMANENT_TURN_FILTER) + # Also attach to root in case the Rust→Python bridge ever logs directly to + # root (cheap belt-and-braces). + root = logging.getLogger() + if _PERMANENT_TURN_FILTER not in root.filters: + root.addFilter(_PERMANENT_TURN_FILTER) + + +_install_permanent_turn_filter() + + def _rust_webrtc_logger_name(name: str) -> bool: """True if logger name is from Rust/webrtc/turn so we treat its messages as DEBUG-only.""" if not name: @@ -71,6 +144,10 @@ def enter_pam_launch_terminal_rust_logging(): Downgrades Rust/webrtc/turn messages to DEBUG so they only show with --debug. Returns a token to pass to exit_pam_launch_terminal_rust_logging() on exit. """ + global _ACTIVE_SESSION_COUNT + with _ACTIVE_SESSION_LOCK: + _ACTIVE_SESSION_COUNT += 1 + root = logging.getLogger() flt = _RustWebrtcToDebugFilter() root.addFilter(flt) @@ -111,11 +188,35 @@ def enter_pam_launch_terminal_rust_logging(): # the Rust/webrtc log filter. The Rust tube shutdown runs on its own runtime # threads and can emit a final log record AFTER Python's session-exit path has # returned control to the REPL — e.g. ``webrtc-sctp stream N not found`` when -# the channel is torn down. Without a grace period, that late record arrives -# at a root logger whose filter has already been removed and leaks to the -# console. We keep the filter in place for a short window so such stragglers -# are still suppressed. -_DEFAULT_RUST_LOG_FILTER_GRACE_SEC = 2.5 +# the channel is torn down, or TURN ``fail to refresh permissions`` warnings +# from the relay-conn task as it observes the deallocated allocation. +# +# The window must outlive both: +# 1. The soft→hard close escalation in ``escalate_close`` +# (``FORCE_CLOSE_DELAY_SECONDS`` = 3 s) +# 2. A brief TURN refresh-task latency after the PeerConnection drop cascade +# +# Imported lazily below to avoid a top-level cycle (this module is imported +# during pam_launch init, before the tunnel helpers are loaded for some +# callers). +def _force_close_delay_seconds(): + try: + from ..tunnel.port_forward.tunnel_helpers import FORCE_CLOSE_DELAY_SECONDS + return FORCE_CLOSE_DELAY_SECONDS + except Exception: + return 3.0 + + +_DEFAULT_RUST_LOG_FILTER_GRACE_SEC = _force_close_delay_seconds() + 1.5 + +# Refcount of active pam-launch sessions that have rust-log filtering installed. +# Incremented in enter_*, decremented at the END of the grace timer in +# _do_exit_rust_logging. The restore work (removing class-level filters, +# restoring pre-session logger state) is only performed when this drops to 0, +# so a second `pam launch` started during the grace window of a prior one is +# not silently de-filtered when the prior session's timer fires. +_ACTIVE_SESSION_COUNT = 0 +_ACTIVE_SESSION_LOCK = threading.Lock() def _do_exit_rust_logging(token): @@ -124,9 +225,32 @@ def _do_exit_rust_logging(token): return flt, saved = token[0], token[1] original_logger_class = token[2] if len(token) > 2 else logging.Logger - logging.setLoggerClass(original_logger_class) + + # Always remove THIS session's filter instance from root so per-token + # filters don't pile up. The bulk class-based cleanup below only runs when + # we are the last active session. root = logging.getLogger() - root.removeFilter(flt) + try: + root.removeFilter(flt) + except Exception: + pass + + global _ACTIVE_SESSION_COUNT + with _ACTIVE_SESSION_LOCK: + _ACTIVE_SESSION_COUNT = max(0, _ACTIVE_SESSION_COUNT - 1) + last_session = _ACTIVE_SESSION_COUNT == 0 + if not last_session: + # Another pam launch session is still active (or in its own grace + # window); leave the class-level filter and saved state alone so its + # filtering keeps working. We already removed our specific instance + # from root above. + logging.debug( + "rust_log_filter: skipping restore, %d session(s) still active", + _ACTIVE_SESSION_COUNT, + ) + return + + logging.setLoggerClass(original_logger_class) # Remove downgrade filter from all Rust/webrtc loggers (we may have added the shared # filter to existing loggers, and _RustAwareLogger instances have their own filter) for name in list(logging.Logger.manager.loggerDict.keys()): diff --git a/keepercommander/commands/pam_launch/terminal_connection.py b/keepercommander/commands/pam_launch/terminal_connection.py index bf27de0be..970dfde05 100644 --- a/keepercommander/commands/pam_launch/terminal_connection.py +++ b/keepercommander/commands/pam_launch/terminal_connection.py @@ -57,6 +57,7 @@ MAIN_NONCE_LENGTH, SYMMETRIC_KEY_LENGTH, set_remote_description_and_parse_version, + _version_at_least, ) from ..tunnel.port_forward.TunnelGraph import TunnelDAG from ..pam.pam_dto import GatewayAction, GatewayActionWebRTCSession @@ -136,37 +137,6 @@ CONNECT_AS_MIN_VERSION = "2.1.6" -def _version_at_least(version: Optional[str], min_version: str) -> bool: - """ - Compare semantic versions. Returns True if version >= min_version. - - Args: - version: Parsed version (e.g. "2.1.4") or None (treated as unknown/old). - min_version: Minimum required version (e.g. "2.1.0"). - - Returns: - True if version is known and >= min_version; False if unknown or older. - """ - if not version: - return False - - def parse(v: str) -> tuple: - parts = [] - for p in v.split(".")[:3]: # major.minor.patch - try: - parts.append(int(p)) - except ValueError: - parts.append(0) - while len(parts) < 3: - parts.append(0) - return tuple(parts[:3]) - - try: - return parse(version) >= parse(min_version) - except Exception: - return False - - def _ensure_max_message_size_attribute(sdp_offer: Optional[str]) -> Optional[str]: """ Ensure the SDP offer advertises the same max-message-size attribute as Web Vault. diff --git a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py index 628ecf47a..11f18aece 100644 --- a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py +++ b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py @@ -103,6 +103,144 @@ def set_remote_description_and_parse_version(tube_registry, tube_id, sdp, is_ans return remote_ver +def _version_at_least(version, min_version): + """ + Compare semantic versions. Returns True if `version` >= `min_version`. + + `version` of None or unparseable is treated as unknown/old (False). + """ + if not version: + return False + + def parse(v): + parts = [] + for p in v.split(".")[:3]: + try: + parts.append(int(p)) + except ValueError: + parts.append(0) + while len(parts) < 3: + parts.append(0) + return tuple(parts[:3]) + + try: + return parse(version) >= parse(min_version) + except Exception: + return False + + +# Minimum keeper-pam-webrtc-rs version that exposes force_close_tube. Both the +# local Rust crate AND the remote peer must satisfy this gate before Commander +# escalates a soft close to a force close. Local check uses hasattr (the binding +# attribute is missing on older crates), remote check uses the SDP-advertised +# version string. +FORCE_CLOSE_MIN_VERSION = "2.1.18" + +# Default delay between the soft close and the force-close escalation. Matches +# the consumer-side budget agreed with the gateway (gateway-side +# KEEPER_GATEWAY_FORCE_CLOSE_TIMEOUT is 6s; we run faster on the consumer because +# at lease expiry there is no reason to wait long). +FORCE_CLOSE_DELAY_SECONDS = 3.0 + + +def print_above_keeper_prompt(msg): + """Print ``msg`` so the keeper-shell prompt redraws itself underneath it. + + Strategy: + 1. If a prompt-toolkit app is running, call ``app.renderer.erase()`` — + this writes the ANSI sequences to fully erase the prompt area + (which may span multiple lines), leaving a clean cursor. + 2. Print the message + newline so the cursor advances below. + 3. Call ``app.invalidate()`` (thread-safe) to schedule a fresh prompt + render at the new cursor position. + + Falls back to plain ``print`` if no app is running. Avoids + ``run_in_terminal`` (returns a coroutine that needs to be awaited on + the app's event loop; scheduling that from a Timer thread is + version-fragile and leaks un-awaited coroutines). + """ + app = None + try: + from prompt_toolkit.application.current import get_app_or_none + app = get_app_or_none() + if app is not None and app.is_running: + try: + app.renderer.erase() + except Exception: + pass + except Exception: + app = None + + sys.stdout.write(msg + '\n') + sys.stdout.flush() + + if app is not None: + try: + app.invalidate() + except Exception: + pass + + +def escalate_close( + tube_registry, + tube_id, + *, + remote_webrtc_version=None, + reason=None, + hard_after_seconds=FORCE_CLOSE_DELAY_SECONDS, + log_prefix="", +): + """ + Soft-close a tube now, then escalate to force_close_tube after + `hard_after_seconds` if both endpoints support it. + + The soft close stops new channel creation and emits CloseConnection control + frames; the force close (when available) drops the local TCP listener, + severs in-flight forwarded TCP streams (SSH, MySQL, etc.) and tears down + the peer connection on a short bounded budget. + + Returns the scheduled `threading.Timer` (or None if escalation is not + available) so callers can cancel it on a clean exit. + """ + if reason is None: + reason = CloseConnectionReasons.AdminClosed + + try: + tube_registry.close_tube(tube_id, reason=reason) + except Exception as e: + logging.debug(f"{log_prefix}soft close_tube failed: {e}") + + has_local = hasattr(tube_registry, "force_close_tube") + has_remote = _version_at_least(remote_webrtc_version, FORCE_CLOSE_MIN_VERSION) + if not has_local: + logging.debug( + f"{log_prefix}force_close_tube unavailable in local keeper_pam_webrtc_rs - " + f"soft close only" + ) + return None + if not has_remote: + logging.debug( + f"{log_prefix}remote keeper-pam-webrtc {remote_webrtc_version!r} < " + f"{FORCE_CLOSE_MIN_VERSION} - soft close only" + ) + return None + + def _do_force_close(): + try: + logging.debug( + f"{log_prefix}escalating to force_close_tube({tube_id}) after " + f"{hard_after_seconds}s" + ) + tube_registry.force_close_tube(tube_id, reason=reason) + except Exception as e: + logging.debug(f"{log_prefix}force_close_tube failed: {e}") + + timer = threading.Timer(hard_after_seconds, _do_force_close) + timer.daemon = True + timer.start() + return timer + + # Constants NONCE_LENGTH = 12 MAIN_NONCE_LENGTH = 16 diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 026be65c9..3034d3644 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -24,7 +24,7 @@ import sys import threading import time -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, url_safe_str_to_bytes from .base import Command, GroupCommand, dump_report_data, RecordMixin @@ -32,7 +32,8 @@ from .tunnel.port_forward.tunnel_helpers import find_open_port, get_config_uid, get_keeper_tokens, \ get_or_create_tube_registry, get_gateway_uid_from_record, resolve_record, resolve_pam_config, resolve_folder, \ remove_field, start_rust_tunnel, get_tunnel_session, unregister_tunnel_session, CloseConnectionReasons, \ - wait_for_tunnel_connection, create_rust_webrtc_settings + wait_for_tunnel_connection, create_rust_webrtc_settings, escalate_close, \ + print_above_keeper_prompt from .tunnel_registry import ( PARENT_GRACE_SECONDS, is_pid_alive, @@ -54,14 +55,12 @@ # so a re-run of `pam tunnel start` in the same shell session doesn't leave # the original timer alive and produce duplicate "Tunnel access expired" # messages from the prior tunnel. -_LEASE_EXPIRY_TIMERS_BY_RECORD = {} # type: dict[str, threading.Timer] +_LEASE_EXPIRY_TIMERS_BY_RECORD: Dict[str, threading.Timer] = {} # Maps record_uid -> threading.Event used by --foreground / --run modes to break # their blocking wait when the workflow lease expires. Set by the mode block, # read by the lease-expiry callback. Default interactive mode does NOT register # (it has no blocking wait to interrupt; user SSH session continues naturally). -_LEASE_SHUTDOWN_EVENTS_BY_RECORD = {} # type: dict[str, threading.Event] -import threading as _lease_threading_module # noqa: E402 (used only by the tunnel-start timer) - +_LEASE_SHUTDOWN_EVENTS_BY_RECORD: Dict[str, threading.Event] = {} # Group Commands @@ -1013,33 +1012,15 @@ def execute(self, params, **kwargs): self._print_keeperdb_proxy_banner(host, port, db_type_for_banner) # Workflow lease expiry handling. # - # Behavior note: at expiresOn we want to terminate the tunnel and - # any in-flight forwarded connections (web-vault-equivalent hard - # disconnect). However: - # * tube_registry.close_tube(tube_id) is a SOFT close — - # it marks the tube closed and stops new channel creation but - # does NOT terminate already-open forwarded channels. An SSH - # session active through the tube keeps relaying bytes until - # the SSH client itself closes (you'll see TURN - # "fail to refresh permissions" warnings while the existing - # 5-tuple keeps flowing). - # * The local TCP listener is owned by Rust - # (keeper_pam_webrtc_rs); Python has no handle to force-close - # it from here. - # So a true hard-kill is not currently possible from the Python - # layer. The server-side workflow lease still becomes invalid at - # expiresOn (the gateway will refuse new auth requests); only the - # already-running SSH session survives until natural disconnect. - # - # The previous implementation called - # `tube_registry.close_tube(_tube_id, reason=CloseConnectionReasons.Normal)` - # but it was a no-op against active channels and produced - # confusing "fail to refresh permissions" noise from the TURN - # relay. Kept commented out below for reference if a future - # keeper_pam_webrtc_rs release adds a hard-kill primitive. + # At expiresOn we soft-close the tube (stops new channels, sends + # CloseConnection control frames) and, after a short delay, escalate + # to force_close_tube which drops the local TCP listener and severs + # any active forwarded streams (SSH, MySQL, etc.). The escalation + # only fires when both the local Rust crate and the remote peer + # advertise FORCE_CLOSE_MIN_VERSION; older peers get the soft close + # only and the in-flight session lingers until natural disconnect. if workflow_expires_on_ms and workflow_expires_on_ms > 0: - import time as _time - seconds_until_expiry = (workflow_expires_on_ms / 1000.0) - _time.time() + seconds_until_expiry = (workflow_expires_on_ms / 1000.0) - time.time() tube_id = result.get('tube_id') if tube_id and seconds_until_expiry > 0: # Dedup: cancel any pending lease-expiry timer for this @@ -1056,23 +1037,24 @@ def execute(self, params, **kwargs): def _close_on_lease_expiry(_tube_id=tube_id, _record_uid=record_uid): try: - print( + print_above_keeper_prompt( f"\n{bcolors.WARNING}Tunnel access lease expired for " - f"{_record_uid}. Server will refuse new auth requests; " - f"any in-flight SSH session will continue until you " - f"disconnect it.{bcolors.ENDC}", - flush=True, + f"{_record_uid}. Closing the tunnel; any in-flight " + f"forwarded connections will be terminated." + f"{bcolors.ENDC}" + ) + sess = get_tunnel_session(_tube_id) + remote_ver = getattr(sess, 'remote_webrtc_version', None) if sess else None + escalate_close( + tube_registry, + _tube_id, + remote_webrtc_version=remote_ver, + reason=CloseConnectionReasons.AdminClosed, + log_prefix=f"[lease-expiry tunnel record={_record_uid}] ", ) - # Soft-close — kept commented out: doesn't actually - # terminate active forwarded channels, only emits - # TURN permission-refresh errors. Re-enable once - # keeper_pam_webrtc_rs provides a hard-kill that - # also drops the local listener. - # tube_registry.close_tube(_tube_id, reason=CloseConnectionReasons.Normal) # Wake any --foreground / --run blocking wait so the - # process self-terminates instead of hanging past lease - # expiry. Default interactive mode does not register - # an event here — it has no blocking wait to break. + # process self-terminates. Default interactive mode + # does not register an event here. shutdown_event = _LEASE_SHUTDOWN_EVENTS_BY_RECORD.get(_record_uid) if shutdown_event is not None: shutdown_event.set() @@ -1081,7 +1063,7 @@ def _close_on_lease_expiry(_tube_id=tube_id, _record_uid=record_uid): finally: _LEASE_EXPIRY_TIMERS_BY_RECORD.pop(_record_uid, None) - timer = _lease_threading_module.Timer( + timer = threading.Timer( seconds_until_expiry, _close_on_lease_expiry, ) timer.daemon = True @@ -1703,7 +1685,6 @@ def _record(name: str, passed: bool, detail: str, ms: int): password=turn_password, ) if output_format == 'json': - import json print(json.dumps(rust_results, indent=2)) return 0 diff --git a/requirements.txt b/requirements.txt index 41983e843..f67f0f6c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ requests>=2.31.0 cryptography>=46.0.6 protobuf>=5.29.6 keeper-secrets-manager-core>=16.6.0 -keeper_pam_webrtc_rs>=2.1.6 +keeper_pam_webrtc_rs>=2.1.17 pydantic>=2.6.4 flask pyngrok>=7.5.0 diff --git a/setup.cfg b/setup.cfg index ac7f6bfb8..3d6801983 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ install_requires = requests>=2.31.0 tabulate websockets - keeper_pam_webrtc_rs>=2.1.6 + keeper_pam_webrtc_rs>=2.1.17 pydantic>=2.6.4 fpdf2>=2.8.3 cbor2; sys_platform == "darwin" and python_version>='3.10' From 06173c01ed308375eae92ea51da83d57c658cfe7 Mon Sep 17 00:00:00 2001 From: Craig Lurey Date: Thu, 7 May 2026 07:51:12 -0700 Subject: [PATCH 08/26] Add `pam launch` from SuperShell Adds a one-key path from the SuperShell TUI to a KeeperPAM connection for pamMachine and pamDatabase records. * `is_launchable(params, record_uid)` and `get_launch_info(...)` exposed as module-level helpers in `pam_launch/launch.py` so the TUI and CLI share the same eligibility/protocol logic without duplicating it. * SuperShell `L` keybinding suspends the Textual app, runs `PAMLaunchCommand` in the released terminal, and resumes the TUI when the session ends. `CommandError` failures show a clean message; other exceptions still log a traceback. Both pause for Enter so the user sees the message before the TUI redraws. * Detail view for pamMachine/pamDatabase strips the noisy `pamHostname`, `pamSettings`, `trafficEncryptionSeed`, and `checkbox:*` fields (multi-line dict continuations included via brace counting). JSON view is unchanged. * New "Launch" section rendered after the Title shows protocol, host, credential, and a visual "Press L to Launch " button. `(prompted at launch)` placeholders cover `allowSupplyHost` / `allowSupplyUser` records that have no static value. * `pam launch` now prompts for `host:port` instead of erroring when `allowSupplyHost: True`, no `--host` is given, and the record has no static hostname. Up to 3 retries on bad input with the expected format shown in the prompt; non-TTY behavior preserved. * Help modal lists the new `L` action. --- keepercommander/commands/pam_launch/launch.py | 128 +++++++++++++++- keepercommander/commands/supershell/app.py | 142 ++++++++++++++++++ .../commands/supershell/screens/help.py | 1 + 3 files changed, 268 insertions(+), 3 deletions(-) diff --git a/keepercommander/commands/pam_launch/launch.py b/keepercommander/commands/pam_launch/launch.py index a6114b7a7..1435ab112 100644 --- a/keepercommander/commands/pam_launch/launch.py +++ b/keepercommander/commands/pam_launch/launch.py @@ -322,11 +322,96 @@ def _color(text: str, color: str) -> str: return pending_exit_code +VALID_PAM_RECORD_TYPES = {'pamDatabase', 'pamDirectory', 'pamMachine'} + + +def is_launchable(params, record_uid): + """Return ``(eligible, protocol)`` for a `pam launch` candidate. + + Eligibility: TypedRecord v3, ``record_type`` in :data:`VALID_PAM_RECORD_TYPES`, + and detected protocol in :data:`ALL_TERMINAL`. Reads only cached vault data — + safe to call from a TUI render path. + """ + try: + record = vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + return False, None + if record.version != 3: + return False, None + if record.record_type not in VALID_PAM_RECORD_TYPES: + return False, None + except Exception as e: + logging.debug("is_launchable: cannot load %s: %s", record_uid, e) + return False, None + try: + protocol = detect_protocol(params, record_uid) + except Exception as e: + logging.debug("is_launchable: detect_protocol failed for %s: %s", record_uid, e) + return False, None + if protocol not in ALL_TERMINAL: + return False, None + return True, protocol + + +def get_launch_info(params, record_uid): + """Return a summary dict describing how this record will launch, or ``None``. + + Reads only cached vault data — safe to call from a TUI render path. + + Keys: + protocol — terminal protocol (e.g. 'ssh', 'mysql') + host — hostname from pamHostname/host field, or None + port — port (int) from pamHostname/host field, or None + allow_supply_host — bool: user may supply host at launch time + allow_supply_user — bool: user may supply credential at launch time + credential_uid — first userRecords[] entry, or None + credential_title — title of the credential record (if cached), else None + """ + ok, protocol = is_launchable(params, record_uid) + if not ok: + return None + try: + record = vault.KeeperRecord.load(params, record_uid) + except Exception: + return None + host, port = _get_host_port_from_record(record) + allow_supply_host = False + allow_supply_user = False + credential_uid = None + pam_settings_field = record.get_typed_field('pamSettings') + if pam_settings_field: + pam_settings_value = pam_settings_field.get_default_value(dict) or {} + allow_supply_host = bool(pam_settings_value.get('allowSupplyHost')) + connection = pam_settings_value.get('connection') or {} + if isinstance(connection, dict): + allow_supply_user = bool(connection.get('allowSupplyUser')) + user_records = connection.get('userRecords') or [] + if user_records: + credential_uid = user_records[0] + credential_title = None + if credential_uid: + try: + cred = vault.KeeperRecord.load(params, credential_uid) + if cred is not None: + credential_title = getattr(cred, 'title', None) + except Exception: + credential_title = None + return { + 'protocol': protocol, + 'host': host, + 'port': port, + 'allow_supply_host': allow_supply_host, + 'allow_supply_user': allow_supply_user, + 'credential_uid': credential_uid, + 'credential_title': credential_title, + } + + class PAMLaunchCommand(Command): """PAM Launch command to launch a connection to a PAM resource""" - # Valid PAM record types for launch - VALID_PAM_RECORD_TYPES = {'pamDatabase', 'pamDirectory', 'pamMachine'} + # Valid PAM record types for launch (kept as class alias for backwards reference) + VALID_PAM_RECORD_TYPES = VALID_PAM_RECORD_TYPES parser = argparse.ArgumentParser(prog='pam launch', description='Launch a connection to a PAM resource') parser.add_argument('record', type=str, action='store', @@ -1182,7 +1267,44 @@ def _refresh_fetch(_params=params, _record_uid=record_uid, _self=self): if not has_cli_host: # No CLI host -> must come from the PAM launch record if not hostname_on_record: - if allow_supply_host: + if allow_supply_host and sys.stdin.isatty() and sys.stdout.isatty(): + # Interactive prompt instead of erroring out — covers the common + # `pam launch ` case for records with allowSupplyHost and no + # static hostname (e.g. SuperShell launching from the TUI). + prompt_text = ( + f'{Fore.CYAN}Enter host for launch ' + f'{Fore.WHITE}(format: host:port, e.g. 192.168.1.1:22 or ' + f'server.example.com:3306, [::1]:22 for IPv6){Fore.RESET}\n' + f'{Fore.CYAN}Host:port: {Fore.RESET}' + ) + custom_host = None + custom_port = None + for _ in range(3): + try: + user_input = input(prompt_text).strip() + except (EOFError, KeyboardInterrupt): + logging.info('Canceled') + return + if not user_input: + logging.info('Canceled') + return + try: + custom_host, custom_port = _parse_host_port(user_input) + break + except CommandError as e: + print(f'{Fore.RED}{e}{Fore.RESET}', file=sys.stderr) + custom_host = None + custom_port = None + if custom_host is None: + logging.error('Too many invalid host entries; aborting launch.') + return + kwargs['custom_host'] = custom_host + kwargs['custom_port'] = custom_port + has_cli_host = True + logging.info( + 'Using interactively supplied host: %s:%s', custom_host, custom_port + ) + elif allow_supply_host: raise CommandError('pam launch', 'allowSupplyHost is enabled but no hostname on record. ' 'Use --host, --host-record, or --credential with a host:port to specify.') diff --git a/keepercommander/commands/supershell/app.py b/keepercommander/commands/supershell/app.py index ae4747753..1cf3ce2b1 100644 --- a/keepercommander/commands/supershell/app.py +++ b/keepercommander/commands/supershell/app.py @@ -90,6 +90,7 @@ class SuperShellApp(App): Binding("m", "toggle_unmask", "Toggle Unmask", show=False), Binding("W", "show_user_info", "User Info", show=False), Binding("D", "show_device_info", "Device Info", show=False), + Binding("L", "launch_pam", "Launch PAM", show=False), Binding("?", "show_help", "Help", show=False), # Vim-style navigation Binding("j", "cursor_down", "Down", show=False), @@ -1803,6 +1804,33 @@ def _clear_clickable_fields(self): except Exception as e: logging.debug(f"Error clearing clickable fields: {e}") + @staticmethod + def _strip_pam_internal_fields(output: str) -> str: + """Drop noisy PAM fields (pamHostname, pamSettings, trafficEncryptionSeed, + checkbox:*) from the legacy ``get`` output, including multi-line dict + continuations. The fields stay visible in JSON view; this only affects + the Detail view for pamMachine/pamDatabase records. + """ + skip_prefixes = ('pamHostname:', 'pamSettings:', 'trafficEncryptionSeed:', 'checkbox:') + kept = [] + skipping = False + brace_depth = 0 + for raw_line in output.split('\n'): + stripped_left = raw_line.lstrip() + if skipping: + brace_depth += raw_line.count('{') - raw_line.count('}') + if brace_depth <= 0: + skipping = False + continue + if any(stripped_left.startswith(p) for p in skip_prefixes): + brace_depth = raw_line.count('{') - raw_line.count('}') + if brace_depth > 0: + skipping = True + # else: single-line value, just drop this one line + continue + kept.append(raw_line) + return '\n'.join(kept) + def _display_record_with_clickable_fields(self, record_uid: str): """Display record details with clickable fields for copy-on-click""" t = self.theme_colors @@ -1816,6 +1844,14 @@ def _display_record_with_clickable_fields(self, record_uid: str): output = self._get_record_output(record_uid, format_type='detail') output = strip_ansi_codes(output) + # For launchable PAM records (pamMachine/pamDatabase) hide the + # noisy raw fields from Detail view — a parsed Launch section + # below replaces them. JSON view stays unchanged. + record_data_for_filter = self.records.get(record_uid, {}) + record_type_for_filter = record_data_for_filter.get('record_type', '') + if record_type_for_filter in ('pamMachine', 'pamDatabase'): + output = self._strip_pam_internal_fields(output) + if not output or output.strip() == '': detail_widget.update("[red]Failed to get record details[/red]") return @@ -1959,6 +1995,68 @@ def display_rotation(): rotation_displayed = True + launch_displayed = False + + def display_launch_section(): + """Render the parsed Launch section for pamMachine/pamDatabase.""" + nonlocal launch_displayed + if launch_displayed: + return + try: + from ..pam_launch.launch import get_launch_info + info = get_launch_info(self.params, record_uid) + except Exception as e: + logging.debug(f"Launch section: get_launch_info failed: {e}") + info = None + if not info: + return + protocol = (info.get('protocol') or '').upper() or 'PAM' + if info.get('host'): + host_str = f"{info['host']}:{info['port']}" if info.get('port') else info['host'] + if info.get('allow_supply_host'): + host_display = f"{host_str} (or supplied at launch)" + else: + host_display = host_str + elif info.get('allow_supply_host'): + host_display = "(prompted at launch)" + else: + host_display = "(not configured)" + if info.get('credential_uid'): + cred_str = info.get('credential_title') or info['credential_uid'] + if info.get('allow_supply_user'): + cred_display = f"{cred_str} (or supplied at launch)" + else: + cred_display = cred_str + elif info.get('allow_supply_user') or info.get('allow_supply_host'): + cred_display = "(prompted at launch)" + else: + cred_display = "(not configured)" + mount_line("", None) + mount_line(f"[bold {t['secondary']}]Launch:[/bold {t['secondary']}]", None) + mount_line( + f" [{t['text_dim']}]Protocol:[/{t['text_dim']}] " + f"[{t['primary']}]{rich_escape(protocol)}[/{t['primary']}]", + protocol, + ) + mount_line( + f" [{t['text_dim']}]Host:[/{t['text_dim']}] " + f"[{t['primary']}]{rich_escape(str(host_display))}[/{t['primary']}]", + host_display, + ) + mount_line( + f" [{t['text_dim']}]Credential:[/{t['text_dim']}] " + f"[{t['primary']}]{rich_escape(str(cred_display))}[/{t['primary']}]", + cred_display, + ) + mount_line("", None) + mount_line( + f" [bold black on {t['primary']}] >> Press L to Launch {rich_escape(protocol)} << " + f"[/bold black on {t['primary']}]", + None, + ) + mount_line("", None) + launch_displayed = True + for line in output.split('\n'): stripped = line.strip() if not stripped: @@ -1973,6 +2071,8 @@ def display_rotation(): mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value) elif key in ['Title', 'Name'] and not current_section: mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [bold {t['primary']}]{rich_escape(str(value))}[/bold {t['primary']}]", value) + if record_type_for_filter in ('pamMachine', 'pamDatabase'): + display_launch_section() elif key == 'Type': # Show 'app' for app records if type is blank display_type = value if value else 'app' if record_uid in self.app_record_uids else '' @@ -2897,6 +2997,17 @@ def _update_shortcuts_bar(self, record_selected: bool = False, folder_selected: elif record_selected: mode = "JSON" if self.view_mode == 'json' else "Detail" mask_label = "Mask" if self.unmask_secrets else "Unmask" + launch_hint = "" + if self.selected_record: + try: + from ..pam_launch.launch import is_launchable + ok, protocol = is_launchable(self.params, self.selected_record) + if ok: + launch_hint = ( + f" [{t['text_dim']}]L[/{t['text_dim']}]=Launch {protocol.upper()}" + ) + except Exception as e: + logging.debug(f"is_launchable check failed: {e}") shortcuts_bar.update( f"[{t['secondary']}]Mode: {mode}[/{t['secondary']}] " f"[{t['text_dim']}]t[/{t['text_dim']}]=Toggle " @@ -2904,6 +3015,7 @@ def _update_shortcuts_bar(self, record_selected: bool = False, folder_selected: f"[{t['text_dim']}]u[/{t['text_dim']}]=Username " f"[{t['text_dim']}]c[/{t['text_dim']}]=Copy All " f"[{t['text_dim']}]m[/{t['text_dim']}]={mask_label}" + f"{launch_hint}" ) elif folder_selected: mode = "JSON" if self.view_mode == 'json' else "Detail" @@ -3680,6 +3792,36 @@ def action_toggle_unmask(self): except Exception as e: logging.error(f"Error toggling unmask: {e}", exc_info=True) + def action_launch_pam(self): + """Launch a KeeperPAM connection for the selected record.""" + if not self.selected_record or self.selected_record not in self.records: + self.notify("No record selected", severity="warning") + return + from ..pam_launch.launch import PAMLaunchCommand, is_launchable + from ...error import CommandError + ok, protocol = is_launchable(self.params, self.selected_record) + if not ok: + self.notify("This record is not launchable", severity="warning") + return + record_uid = self.selected_record + with self.suspend(): + try: + PAMLaunchCommand().execute(self.params, record=record_uid) + except CommandError as e: + # User-visible launch failure — show the message, no traceback. + print(f"\nLaunch failed: {e}", file=sys.stderr) + try: + input("Press Enter to return to SuperShell...") + except (EOFError, KeyboardInterrupt): + pass + except Exception as e: + logging.error(f"pam launch failed: {e}", exc_info=True) + try: + input("Press Enter to return to SuperShell...") + except (EOFError, KeyboardInterrupt): + pass + self._display_record_detail(record_uid) + def action_copy_password(self): """Copy password of selected record to clipboard using clipboard-copy command (generates audit event)""" if self.selected_record and self.selected_record in self.records: diff --git a/keepercommander/commands/supershell/screens/help.py b/keepercommander/commands/supershell/screens/help.py index 0f10e406e..dcf13f088 100644 --- a/keepercommander/commands/supershell/screens/help.py +++ b/keepercommander/commands/supershell/screens/help.py @@ -111,6 +111,7 @@ def compose(self) -> ComposeResult: t Toggle JSON view m Mask/Unmask d Sync vault + L Launch PAM connection W User info D Device info P Preferences From 54de71cccc64e5e36019cc9a783fab1dd3068246 Mon Sep 17 00:00:00 2001 From: Craig Lurey Date: Thu, 7 May 2026 07:51:21 -0700 Subject: [PATCH 09/26] Show correct `login --server` hint for batch vs shell mode When `keeper login` runs from the system shell (batch mode), the hint to change the data center previously suggested `login --server ` which dispatches to macOS/Linux `login(1)` instead of Commander. The hint now reads `keeper login --server ` in batch mode and keeps `login --server ` for the interactive Keeper shell. --- keepercommander/commands/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keepercommander/commands/utils.py b/keepercommander/commands/utils.py index 6a2fc65c3..c4dfe885d 100644 --- a/keepercommander/commands/utils.py +++ b/keepercommander/commands/utils.py @@ -1695,7 +1695,8 @@ def execute(self, params, **kwargs): # Check extended server list region = next((k for k, v in KEEPER_SERVERS.items() if v == params.server), params.server) print(f'{Fore.CYAN}Data center: {Fore.WHITE}{region}{Fore.RESET}', file=sys.stderr) - print(f'{Fore.CYAN}Use {Fore.GREEN}login --server {Fore.CYAN} to change (US, EU, AU, CA, JP, GOV){Fore.RESET}', file=sys.stderr) + hint_cmd = 'keeper login --server ' if params.batch_mode else 'login --server ' + print(f'{Fore.CYAN}Use {Fore.GREEN}{hint_cmd}{Fore.CYAN} to change (US, EU, AU, CA, JP, GOV){Fore.RESET}', file=sys.stderr) print('', file=sys.stderr) user = input(f'{Fore.GREEN}Email: {Fore.RESET}').strip() if not user: From b926381a9157ebf05dc65501dd834fdbc15a21f2 Mon Sep 17 00:00:00 2001 From: adeshmukh-ks Date: Thu, 7 May 2026 14:51:41 +0530 Subject: [PATCH 10/26] Added support parameters to connection edit command --- .../tunnel/port_forward/TunnelGraph.py | 15 ++++++- .../commands/tunnel_and_connections.py | 42 +++++++++++++++---- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py index 98aaaf979..027a4d822 100644 --- a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py +++ b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py @@ -522,7 +522,8 @@ def set_resource_allowed(self, resource_uid, tunneling=None, connections=None, r session_recording=None, typescript_recording=None, remote_browser_isolation=None, ai_enabled=None, ai_session_terminate=None, allowed_settings_name='allowedSettings', is_config=False, - v_type: RefType=str(RefType.PAM_MACHINE), meta_version=None): + v_type: RefType=str(RefType.PAM_MACHINE), meta_version=None, + rotate_on_termination=None): v_type = RefType(v_type) allowed_ref_types = [RefType.PAM_MACHINE, RefType.PAM_DATABASE, RefType.PAM_DIRECTORY, RefType.PAM_BROWSER] if v_type not in allowed_ref_types: @@ -625,13 +626,23 @@ def set_resource_allowed(self, resource_uid, tunneling=None, connections=None, r else: settings["aiSessionTerminate"] = ai_session_terminate + if rotate_on_termination is not None: + if content is None: + content = {allowed_settings_name: {}} + dirty = True + current_rot = bool(content.get("rotateOnTermination", False)) + if rotate_on_termination != current_rot: + dirty = True + content = ensure_resource_meta_v1(dict(content)) + content["rotateOnTermination"] = bool(rotate_on_termination) + if dirty: # Legacy: missing or meta_version=0 -> write content as-is (no version in meta) if meta_version is not None and meta_version != 0: meta_payload = build_resource_meta( meta_version, content.get(allowed_settings_name, {}), - rotate_on_termination=False, + rotate_on_termination=bool(content.get("rotateOnTermination", False)), ) resource_vertex.add_data(content=meta_payload, path='meta', needs_encryption=False) else: diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 3034d3644..788ff05f0 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -1942,7 +1942,9 @@ class PAMConnectionEditCommand(Command): 'credential on the PAM Resource') parser.add_argument('--launch-user', '-lu', required=False, dest='launch_user', action='store', help='The record path or UID of the PAM User record to configure as the launch ' - 'credential on the PAM Resource') + 'credential on the PAM Resource.') + parser.add_argument('--clear-launch-user', required=False, dest='clear_launch_user', action='store_true', + help='Remove the launch credential from the resource (clears is_launch_credential in the DAG)') parser.add_argument('--protocol', '-p', dest='protocol', choices=protocols, help='Set connection protocol') parser.add_argument('--connections', '-cn', dest='connections', choices=choices, @@ -1956,6 +1958,9 @@ class PAMConnectionEditCommand(Command): 'the port from the record will be used.') parser.add_argument('--key-events', '-k', dest='key_events', choices=choices, help='Toggle Key Events settings') + parser.add_argument('--rotate-on-termination', required=False, dest='rotate_on_termination', + choices=['on', 'off'], + help='Rotate launch credentials when the PAM session ends (DAG resource meta)') parser.add_argument('--silent', '-s', required=False, dest='silent', action='store_true', help='Silent mode - don\'t print PAM User, PAM Config etc.') @@ -2152,24 +2157,47 @@ def execute(self, params, **kwargs): if _typescript_recording is not None and tdag.check_if_resource_allowed(record_uid, "typescriptRecording") != _typescript_recording: dirty = True - if dirty: + launch_credential_record_types = ("pamDatabase", "pamDirectory", "pamMachine") + rot_kw = kwargs.get('rotate_on_termination') + rot_bool = True if rot_kw == 'on' else False if rot_kw == 'off' else None + if rot_bool is not None and record_type not in launch_credential_record_types: + raise CommandError('pam connection edit', + f'{bcolors.FAIL}--rotate-on-termination is only supported for pamMachine, pamDatabase, and ' + f'pamDirectory records. Record "{record_uid}" is of type "{record_type}" and does not support ' + f'launch credentials.{bcolors.ENDC}') + + if dirty or rot_bool is not None: tdag.set_resource_allowed(resource_uid=record_uid, allowed_settings_name=allowed_settings_name, connections=kwargs.get('connections', None), session_recording=kwargs.get('recording', None), - typescript_recording=kwargs.get('typescriptrecording', None)) + typescript_recording=kwargs.get('typescriptrecording', None), + rotate_on_termination=rot_bool) # admin parameter is optional yet if not set connections may fail admin_name = kwargs.get('admin') adm_rec = RecordMixin.resolve_single_record(params, admin_name) admin_uid = adm_rec.record_uid if adm_rec else None - if admin_uid and record_type in ("pamDatabase", "pamDirectory", "pamMachine"): + if admin_uid and record_type in launch_credential_record_types: tdag.link_user_to_resource(admin_uid, record_uid, is_admin=True, belongs_to=True) # tdag.link_user_to_config(admin_uid) # is_iam_user=True - # launch-user parameter sets the launch credential on the resource + # launch-user parameter sets the launch credential; --clear-launch-user removes it + clear_launch_user = bool(kwargs.get('clear_launch_user')) launch_user_name = kwargs.get('launch_user') - if launch_user_name: + + if clear_launch_user and launch_user_name: + raise CommandError('pam connection edit', + f'{bcolors.FAIL}Use either --clear-launch-user or --launch-user, not both.{bcolors.ENDC}') + if clear_launch_user: + if record_type not in launch_credential_record_types: + raise CommandError('pam connection edit', + f'{bcolors.FAIL}--clear-launch-user is only supported for pamMachine, pamDatabase, and ' + f'pamDirectory records. Record "{record_uid}" is of type "{record_type}" and does not ' + f'support launch credentials.{bcolors.ENDC}') + tdag.clear_launch_credential_for_resource(record_uid) + tdag.upgrade_resource_meta_to_v1(record_uid) + elif launch_user_name: launch_rec = RecordMixin.resolve_single_record(params, launch_user_name) if not launch_rec: raise CommandError('', @@ -2178,7 +2206,7 @@ def execute(self, params, **kwargs): raise CommandError('', f'{bcolors.FAIL}Launch user record must be a pamUser record type.{bcolors.ENDC}') launch_uid = launch_rec.record_uid - if record_type in ("pamDatabase", "pamDirectory", "pamMachine"): + if record_type in launch_credential_record_types: tdag.clear_launch_credential_for_resource(record_uid, exclude_user_uid=launch_uid) tdag.link_user_to_resource(launch_uid, record_uid, is_launch_credential=True, belongs_to=True) tdag.upgrade_resource_meta_to_v1(record_uid) From 8faeb0558a5c1ccac8d76e5a7a5fb93366daa59d Mon Sep 17 00:00:00 2001 From: Fred Reimer Date: Thu, 7 May 2026 16:10:52 -0400 Subject: [PATCH 11/26] Handle legacy service config record title on startup (#1954) Co-authored-by: Craig Lurey --- .../service/config/file_handler.py | 24 +++++++---- unit-tests/service/test_service_config.py | 43 ++++++++++++++++--- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/keepercommander/service/config/file_handler.py b/keepercommander/service/config/file_handler.py index e7b4a798a..88419194d 100644 --- a/keepercommander/service/config/file_handler.py +++ b/keepercommander/service/config/file_handler.py @@ -17,6 +17,13 @@ from ..decorators.logging import logger from ..util.exceptions import ValidationError + +SERVICE_CONFIG_RECORD_TITLES = ( + 'Commander Service Mode Config', + 'Commander Service Mode', +) + + class ConfigFormatHandler: def __init__(self, config_dir: Path, messages: Dict, validation_messages: Dict): self.config_dir = config_dir @@ -50,13 +57,14 @@ def get_config_format(self, save_type: str = None) -> str: from ..core.globals import get_current_params if params := get_current_params(): - if self.cli_handler.download_config_from_vault(params, 'Commander Service Mode Config', self.config_dir): - if json_path.exists(): - self.encrypt_config_file(json_path, self.config_dir) - return 'json' - if yaml_path.exists(): - self.encrypt_config_file(yaml_path, self.config_dir) - return 'yaml' + for title in SERVICE_CONFIG_RECORD_TITLES: + if self.cli_handler.download_config_from_vault(params, title, self.config_dir): + if json_path.exists(): + self.encrypt_config_file(json_path, self.config_dir) + return 'json' + if yaml_path.exists(): + self.encrypt_config_file(yaml_path, self.config_dir) + return 'yaml' return self._get_format_input() if save_type == 'create' else 'json' @@ -195,4 +203,4 @@ def decrypt_config_file(encrypted_content: bytes, config_dir: Path) -> str: hashed_key = sha256(private_key.encode('utf-8')).digest() return decrypt_aes_v2(encrypted_content, hashed_key).decode('utf-8') except Exception as e: - raise ValidationError(f"Failed to decrypt configuration file: {str(e)}") \ No newline at end of file + raise ValidationError(f"Failed to decrypt configuration file: {str(e)}") diff --git a/unit-tests/service/test_service_config.py b/unit-tests/service/test_service_config.py index 9ca4f3eeb..dd221f67d 100644 --- a/unit-tests/service/test_service_config.py +++ b/unit-tests/service/test_service_config.py @@ -1,7 +1,10 @@ import unittest +import tempfile +from pathlib import Path from unittest.mock import patch, MagicMock import json from keepercommander.params import KeeperParams +from keepercommander.service.config.file_handler import ConfigFormatHandler from keepercommander.service.config.service_config import ServiceConfig from keepercommander.service.util.exceptions import ValidationError @@ -43,12 +46,12 @@ def test_save_config_success(self): """Test successful configuration save.""" with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: - + mock_format.return_value = 'json' mock_save_json.return_value = self.service_config.config_path - + result = self.service_config.save_config(self.test_config) - + mock_format.assert_called_once() mock_save_json.assert_called_once() self.assertEqual(result, self.service_config.config_path) @@ -57,13 +60,39 @@ def test_save_config_io_error(self): """Test configuration save with IO error.""" with patch.object(self.service_config.format_handler, 'get_config_format') as mock_format, \ patch.object(self.service_config.format_handler, '_save_json') as mock_save_json: - + mock_format.return_value = 'json' mock_save_json.side_effect = IOError("Test error") - + with self.assertRaises(ValidationError): self.service_config.save_config(self.test_config) - + + def test_get_config_format_falls_back_to_legacy_vault_title(self): + """Test config recovery from the legacy Service Mode record title.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_dir = Path(temp_dir) + handler = ConfigFormatHandler(config_dir, {}, {}) + params = MagicMock(spec=KeeperParams) + download_calls = [] + + def download_side_effect(_params, title, _config_dir): + download_calls.append(title) + if title == 'Commander Service Mode': + (config_dir / 'service_config.json').write_text('{}') + return True + return False + + handler.cli_handler = MagicMock() + handler.cli_handler.download_config_from_vault.side_effect = download_side_effect + + with patch('keepercommander.service.core.globals.get_current_params', return_value=params), \ + patch.object(handler, 'encrypt_config_file') as mock_encrypt: + format_type = handler.get_config_format() + + self.assertEqual(format_type, 'json') + self.assertEqual(download_calls, ['Commander Service Mode Config', 'Commander Service Mode']) + mock_encrypt.assert_called_once_with(config_dir / 'service_config.json', config_dir) + @unittest.skip @patch('pathlib.Path.exists') @patch('pathlib.Path.read_text') @@ -131,4 +160,4 @@ def test_update_or_add_record(self, mock_record_handler): self.service_config.update_or_add_record(params) mock_record_handler.update_or_add_record.assert_called_once_with( params, self.service_config.title, self.service_config.config_path - ) \ No newline at end of file + ) From de559b584455c1267942f8954f55f83b7a8d01a5 Mon Sep 17 00:00:00 2001 From: Micah Roberts Date: Thu, 7 May 2026 13:18:32 -0700 Subject: [PATCH 12/26] Add TURN/STUN end-to-end probes and per-test flags to pam tunnel diagnose; fix log noise (#2020) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tunnel_and_connections.py: - Add --turn-test flag: establishes a real WebRTC/TURN connection through the gateway without proxying traffic, reproducing the full ICE path - Add --stun-only flag: same probe with TURN credentials stripped on both sides, restricting ICE to host/srflx candidates; confirms the reflexive path works independently of the relay. Mutually exclusive with --turn-test - Add --probe-duration (default 30s) to hold the connection open; use 360+ to trigger and verify survival of the ~300s permission refresh cycle - Add --probe-count for concurrent connection load testing - Add --stress-test: connection cycling, throughput across 64B–256KB frame sizes, and concurrent load; implies --turn-test - Add --test-dns/--test-aws/--test-tcp/--test-udp/--test-ice/--test-webrtc flags as per-test alternatives to the comma-separated --test string; both styles can be combined and are merged into a single validated set - Add _yellow() color helper for diagnostic output - Section header, connect label, hold message, throughput, and stability labels all adapt to STUN vs TURN mode tunnel_helpers.py: - Add probe_stun_only param to start_rust_tunnel: strips turn_url, turn_username, turn_password from webrtc_settings so Commander's ICE agent gathers no relay candidates; sends stun_only=True to gateway - Add turn.client.relay_conn to webrtc crate suppression lists so the periodic "fail to refresh permissions" error no longer leaks to terminal - Add NullHandler to all loggers configured with propagate=False; without a handler Python's lastResort StreamHandler fires when callHandlers() finds found=0, bypassing propagate=False --- .../tunnel/port_forward/tunnel_helpers.py | 116 ++- .../commands/tunnel_and_connections.py | 688 +++++++++++++++++- 2 files changed, 776 insertions(+), 28 deletions(-) diff --git a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py index 11f18aece..3fdfff16d 100644 --- a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py +++ b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py @@ -587,12 +587,15 @@ def __init__(self, name, level=logging.NOTSET): 'webrtc', 'webrtc_ice', 'webrtc_mdns', 'webrtc_dtls', 'webrtc_sctp', 'turn', 'stun', 'webrtc_ice.agent.agent_internal', 'webrtc_ice.agent.agent_gather', 'webrtc_ice.mdns', - 'webrtc_mdns.conn', 'webrtc.peer_connection', 'turn.client' + 'webrtc_mdns.conn', 'webrtc.peer_connection', 'turn.client', + 'turn.client.relay_conn', ] for crate_name in webrtc_crates: crate_logger = logging.getLogger(crate_name) crate_logger.setLevel(logging.WARNING) crate_logger.propagate = False + if not crate_logger.handlers: + crate_logger.addHandler(logging.NullHandler()) # Suppress specific noisy keeper_pam_webrtc_rs sub-modules even in debug mode # These log debug info at ERROR level, so we need to disable them entirely @@ -604,6 +607,8 @@ def __init__(self, name, level=logging.NOTSET): noisy_logger = logging.getLogger(logger_name) noisy_logger.setLevel(logging.CRITICAL + 1) # Disable completely noisy_logger.propagate = False + if not noisy_logger.handlers: + noisy_logger.addHandler(logging.NullHandler()) logging.debug(f"Rust loggers enabled at DEBUG level") enabled_loggers = [name for name in logging.Logger.manager.loggerDict.keys() @@ -630,18 +635,23 @@ def __init__(self, name, level=logging.NOTSET): suppress_logger = logging.getLogger(logger_name) suppress_logger.setLevel(logging.CRITICAL + 1) # Disable completely suppress_logger.propagate = False + if not suppress_logger.handlers: + suppress_logger.addHandler(logging.NullHandler()) # Suppress noisy webrtc dependency logs when not debugging webrtc_crates = [ 'webrtc', 'webrtc_ice', 'webrtc_mdns', 'webrtc_dtls', 'webrtc_sctp', 'turn', 'stun', 'webrtc_ice.agent.agent_internal', 'webrtc_ice.agent.agent_gather', 'webrtc_ice.mdns', - 'webrtc_mdns.conn', 'webrtc.peer_connection', 'turn.client' + 'webrtc_mdns.conn', 'webrtc.peer_connection', 'turn.client', + 'turn.client.relay_conn', ] for crate_name in webrtc_crates: crate_logger = logging.getLogger(crate_name) crate_logger.setLevel(logging.ERROR) crate_logger.propagate = False + if not crate_logger.handlers: + crate_logger.addHandler(logging.NullHandler()) def get_or_create_tube_registry(params): @@ -1330,9 +1340,14 @@ def route_message_to_rust(response_item, tube_registry): answer_sdp = None if isinstance(data_json, dict): logging.debug(f"🔓 Decrypted payload type: {data_json.get('type', 'unknown')}, keys: {list(data_json.keys())}") - answer_sdp = data_json.get('answer') or data_json.get('sdp') + # 'answer' field is already base64 (initial answer); 'sdp' field is plain text (ICE restart answer) + answer_sdp = data_json.get('answer') + if not answer_sdp: + raw_sdp = data_json.get('sdp') + if raw_sdp: + answer_sdp = base64.b64encode(raw_sdp.encode('utf-8')).decode('ascii') elif data_text.strip().startswith('v=') and 'm=' in data_text: - answer_sdp = data_text.strip() + answer_sdp = base64.b64encode(data_text.strip().encode('utf-8')).decode('ascii') logging.debug("Decrypted data appears to be raw SDP (not JSON), using as answer") if answer_sdp: @@ -1348,7 +1363,18 @@ def route_message_to_rust(response_item, tube_registry): if not tube_id: logging.error(f"No tube ID found for conversation: {conversation_id} (also tried URL-safe version)") else: - set_remote_description_and_parse_version(tube_registry, tube_id, answer_sdp, is_answer=True) + try: + set_remote_description_and_parse_version(tube_registry, tube_id, answer_sdp, is_answer=True) + except RuntimeError as _rte: + if "Invalid signaling state transition from Stable" in str(_rte): + # ICE restart answer arrived after the connection already + # re-established via another path — safe to ignore. + logging.debug( + f"ICE restart answer arrived for already-stable tube " + f"{tube_id} ({conversation_id}) — ignoring late answer" + ) + else: + raise logging.debug("Connection state: SDP answer received, connecting...") session = get_tunnel_session(tube_id) @@ -1362,7 +1388,9 @@ def route_message_to_rust(response_item, tube_registry): logging.warning(f"No signal handler found for tube {tube_id} to send buffered candidates") elif isinstance(data_json, dict) and ("offer" in data_json or data_json.get("type") == "offer"): # Gateway is sending us an ICE restart offer - offer_sdp = data_json.get('sdp') or data_json.get('offer') + # 'sdp' field from gateway is plain SDP (not base64); encode for Rust + raw_offer = data_json.get('sdp') or data_json.get('offer') + offer_sdp = base64.b64encode(raw_offer.encode('utf-8')).decode('ascii') if raw_offer else None if offer_sdp: logging.debug(f"Received ICE restart offer from Gateway for conversation: {conversation_id}") @@ -1503,12 +1531,18 @@ def route_message_to_rust(response_item, tube_registry): errors = payload_data.get('errors', ['']) logging.error(f"Gateway returned errors for {conversation_id}: {errors}") + elif not payload_data.get('is_ok', True): + # Gateway returned an explicit error (is_ok=False) — log the message and move on. + # This includes auth failures (401 on get_leafs), overload responses, etc. + logging.error( + f"Gateway error for {conversation_id}: {payload_data.get('data', 'unknown error')}" + ) elif payload_data.get('data', '') == '': logging.debug("Empty data field an acknowledgment, no action needed") elif payload_data.get('data') and "ice candidate added" in payload_data.get('data').lower(): logging.debug("Received ice candidate added") else: - logging.warning(f"Unhandled payload type for {conversation_id}: {payload_data}") + logging.debug(f"Unhandled payload for {conversation_id}: {payload_data}") else: logging.warning(f"No encrypted payload in message for conversation: {conversation_id}") @@ -1613,7 +1647,7 @@ class TunnelSignalHandler: def __init__(self, params, record_uid, gateway_uid, symmetric_key, base64_nonce, conversation_id, tube_registry, tube_id=None, trickle_ice=False, websocket_router=None, - conversation_type='tunnel', router_tokens=None, http_session=None): + conversation_type='tunnel', router_tokens=None, http_session=None, silent=False): self.params = params self.record_uid = record_uid self.gateway_uid = gateway_uid @@ -1624,6 +1658,8 @@ def __init__(self, params, record_uid, gateway_uid, symmetric_key, base64_nonce, self.tube_registry = tube_registry self.tube_id = tube_id self.trickle_ice = trickle_ice + self.silent = silent # Suppress connection-established display (probe/stress mode) + self.tube_close_initiated = False # Set when Rust initiates close (AdminClosed/Normal); skip redundant cleanup close_tube call self.connection_success_shown = False # Track if we've shown success messages self.connection_connected = False # Track if WebRTC connection is established self.ice_sending_in_progress = False # Serialize ICE candidate sending @@ -1645,7 +1681,22 @@ def __init__(self, params, record_uid, gateway_uid, symmetric_key, base64_nonce, raise Exception("Trickle ICE requires WebSocket support - install with: pip install websockets") def signal_from_rust(self, response: dict): - """Signal callback to handle Rust events and gateway communication""" + """Signal callback to handle Rust events and gateway communication. + + Called by Rust via PyO3. Any unhandled exception here crosses the FFI + boundary as a PyErr which can destabilise the tube, so we wrap the whole + body and log rather than propagate. + """ + try: + self._signal_from_rust_inner(response) + except Exception as _sig_err: + logging.error( + f"Unhandled exception in signal_from_rust " + f"(kind={response.get('kind','?')}, tube={response.get('tube_id','?')}): {_sig_err}", + exc_info=True, + ) + + def _signal_from_rust_inner(self, response: dict): signal_kind = response.get('kind', '') tube_id = response.get('tube_id', '') data = response.get('data', '') @@ -1690,7 +1741,7 @@ def signal_from_rust(self, response: dict): self.connection_success_shown = True # Get tunnel session for record details - if session: + if session and not self.silent: logging.info(f"\n{bcolors.OKGREEN}Connection established successfully.{bcolors.ENDC}") # Display record title if available @@ -1761,6 +1812,8 @@ def signal_from_rust(self, response: dict): logging.error(f"{bcolors.FAIL}Tunnel closed due to critical failure - '{tube_id}': {close_reason.name}{bcolors.ENDC}") elif close_reason.is_user_initiated(): + # Rust is already handling the close — skip redundant close_tube in cleanup() + self.tube_close_initiated = True logging.debug(f"{bcolors.OKBLUE}User-initiated closure of tunnel '{tube_id}': {close_reason.name}{bcolors.ENDC}") elif close_reason.is_retryable(): @@ -2145,10 +2198,14 @@ def _send_restart_offer(self, restart_sdp, tube_id): Similar to _send_ice_candidate_immediately but sends an offer instead of candidates. """ try: - # Format as offer payload for gateway + # Format as offer payload for gateway. + # restart_sdp arrives base64-encoded from Rust (API contract), but the payload + # SDP field is raw text — gateway and vault both decode base64 before using it + # with the Rust boundary (gateway) or WebRTC API (vault). + raw_sdp = bytes_to_string(base64.b64decode(restart_sdp)) offer_payload = { "type": "offer", - "sdp": restart_sdp, + "sdp": raw_sdp, "ice_restart": True # Flag to indicate this is an ICE restart offer } string_data = json.dumps(offer_payload) @@ -2212,7 +2269,7 @@ def _send_restart_offer(self, restart_sdp, tube_id): def cleanup(self): """Cleanup resources""" # Close the tube in Rust registry if it exists - if self.tube_id and self.tube_registry: + if self.tube_id and self.tube_registry and not self.tube_close_initiated: try: logging.debug(f"Closing tube {self.tube_id} in cleanup") self.tube_registry.close_tube(self.tube_id, reason=CloseConnectionReasons.Error) @@ -2234,7 +2291,7 @@ def cleanup(self): logging.debug("TunnelSignalHandler cleaned up") def start_rust_tunnel(params, record_uid, gateway_uid, host, port, - seed, target_host, target_port, socks, trickle_ice=True, record_title=None, allow_supply_host=False, two_factor_value=None): + seed, target_host, target_port, socks, trickle_ice=True, record_title=None, allow_supply_host=False, two_factor_value=None, kind='start', probe_duration=30, probe_turn_only=False, probe_stun_only=False): """ Start a tunnel using Rust WebRTC with trickle ICE via HTTP POST and WebSocket responses. @@ -2328,6 +2385,19 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, params, host, port, target_host, target_port, socks, nonce ) + # For probe mode with turn_only, force relay-only ICE on Commander's side too. + # Both peers must use RTCIceTransportPolicy::Relay for the connection to be + # pure TURN — otherwise the controlling peer (Commander) offers host/srflx + # candidates and ICE selects a direct path, bypassing the relay. + if probe_turn_only: + webrtc_settings["turn_only"] = True + elif probe_stun_only: + # Strip TURN credentials so Commander's ICE agent only gathers host/srflx + # candidates — no relay path is available on either side. + webrtc_settings.pop("turn_url", None) + webrtc_settings.pop("turn_username", None) + webrtc_settings.pop("turn_password", None) + # Determine conversation type (tunnel or protocol-specific) conversation_type = webrtc_settings.get('conversationType', 'tunnel') @@ -2405,7 +2475,8 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, trickle_ice=trickle_ice, conversation_type=conversation_type, router_tokens=router_tokens, - http_session=http_session + http_session=http_session, + silent=(kind == 'probe'), # Suppress connection display for probe/stress mode ) # Store signal handler reference so we can send buffered candidates later @@ -2519,6 +2590,18 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, # Prepare the offer data data = {"offer": offer.get("offer")} + # For probe mode embed the flag in the encrypted payload so the gateway + # routes to its lightweight probe path. We still send kind='start' below + # so the Keeper router injects routerToken (required for TURN auth). + if kind == 'probe': + data["probe"] = True + if probe_duration != 30: + data["probe_duration"] = probe_duration + if probe_turn_only: + data["turn_only"] = True + if probe_stun_only: + data["stun_only"] = True + # If allowSupplyHost is enabled, include the target host and port in the payload if allow_supply_host: data["host"] = { @@ -2701,7 +2784,8 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, "websocket_thread": websocket_thread, "conversation_id": conversation_id_original, # Use original, not base64 encoded "tube_registry": tube_registry, - "status": "connecting" # Indicates async connection in progress + "status": "connecting", # Indicates async connection in progress + "local_port": tunnel_session.port, # Actual bound port (may differ from requested) } except Exception as e: diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 788ff05f0..567565e0c 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -1278,6 +1278,39 @@ class PAMTunnelDiagnoseCommand(Command): help='Comma-separated list of specific WebRTC tests to run. Available: ' 'dns_resolution,aws_connectivity,tcp_connectivity,udp_binding,' 'ice_configuration,webrtc_peer_connection') + pam_cmd_parser.add_argument('--test-dns', required=False, dest='test_dns', action='store_true', + help='Run the dns_resolution WebRTC test only') + pam_cmd_parser.add_argument('--test-aws', required=False, dest='test_aws', action='store_true', + help='Run the aws_connectivity WebRTC test only') + pam_cmd_parser.add_argument('--test-tcp', required=False, dest='test_tcp', action='store_true', + help='Run the tcp_connectivity WebRTC test only') + pam_cmd_parser.add_argument('--test-udp', required=False, dest='test_udp', action='store_true', + help='Run the udp_binding WebRTC test only') + pam_cmd_parser.add_argument('--test-ice', required=False, dest='test_ice', action='store_true', + help='Run the ice_configuration WebRTC test only') + pam_cmd_parser.add_argument('--test-webrtc', required=False, dest='test_webrtc', action='store_true', + help='Run the webrtc_peer_connection WebRTC test only') + pam_cmd_parser.add_argument('--turn-test', required=False, dest='turn_test', action='store_true', + help='Run an end-to-end TURN relay probe through the gateway. ' + 'Establishes a real WebRTC/TURN connection without proxying any traffic, ' + 'reproducing the full ICE negotiation path. Requires a record argument.') + pam_cmd_parser.add_argument('--probe-duration', required=False, dest='probe_duration', type=int, default=30, + help='How long (seconds) to hold the TURN probe connection open after it connects. ' + 'Default 30s. Use 360+ to trigger a TURN permission refresh cycle (~300s TTL) ' + 'and verify the connection survives it. Requires --turn-test.') + pam_cmd_parser.add_argument('--probe-count', required=False, dest='probe_count', type=int, default=1, + help='Number of simultaneous TURN probes to run. Use >1 to reproduce the ' + '"35 concurrent connections" CreatePermission failure scenario. ' + 'Requires --turn-test.') + pam_cmd_parser.add_argument('--stress-test', required=False, dest='stress_test', action='store_true', + help='Full WebRTC stress test through the TURN relay: connection cycling ' + '(open→data→close repeated), throughput across 64B/8KB/64KB/256KB ' + 'frame sizes, and concurrent connection load. Implies --turn-test. ' + 'Requires a record.') + pam_cmd_parser.add_argument('--stun-only', required=False, dest='stun_test', action='store_true', + help='Run an end-to-end probe restricted to STUN/reflexive candidates — ' + 'no TURN relay. Confirms that peer-to-peer ICE works when the relay ' + 'is bypassed. Mutually exclusive with --turn-test. Requires a record.') def get_parser(self): return PAMTunnelDiagnoseCommand.pam_cmd_parser @@ -1298,6 +1331,8 @@ def _bright(cls, t: str) -> str: return cls._c('1;92', t) @classmethod def _dim(cls, t: str) -> str: return cls._c('2;32', t) @classmethod + def _yellow(cls, t: str) -> str: return cls._c('1;93', t) + @classmethod def _red(cls, t: str) -> str: return cls._c('1;91', t) @classmethod def _check(cls) -> str: return cls._bright('\u2713') @@ -1327,7 +1362,8 @@ def _print_header(cls): def _print_result(cls, name: str, passed: bool, detail: str, ms: int, indent: int = 4): icon = cls._check() if passed else cls._cross() ms_str = cls._dim(f' {ms}ms') - body = f'{cls._green(name)} \u00b7 {cls._green(detail)}' if detail else cls._green(name) + _color = cls._green if passed else cls._red + body = f'{_color(name)} \u00b7 {_color(detail)}' if detail else _color(name) print(f'{" " * indent}{icon} {body}{ms_str}') # ── STUN ────────────────────────────────────────────────────────────────── @@ -1534,7 +1570,49 @@ def execute(self, params, **kwargs): timeout = kwargs.get('timeout', 30) verbose = kwargs.get('verbose', False) output_format = kwargs.get('format', 'table') - test_filter = kwargs.get('test_filter') + turn_test = kwargs.get('turn_test', False) + + # Build unified WebRTC test filter from --test string and/or individual --test-* flags + _test_flag_map = { + 'test_dns': 'dns_resolution', + 'test_aws': 'aws_connectivity', + 'test_tcp': 'tcp_connectivity', + 'test_udp': 'udp_binding', + 'test_ice': 'ice_configuration', + 'test_webrtc': 'webrtc_peer_connection', + } + _allowed_tests = set(_test_flag_map.values()) + test_filter_set: set = set() + _test_str = kwargs.get('test_filter') + if _test_str: + _requested = {t.strip() for t in _test_str.split(',')} + _invalid = _requested - _allowed_tests + if _invalid: + raise CommandError('pam tunnel diagnose', + f'Invalid test names: {", ".join(_invalid)}. ' + f'Available: {", ".join(sorted(_allowed_tests))}') + test_filter_set = _requested + for _dest, _name in _test_flag_map.items(): + if kwargs.get(_dest, False): + test_filter_set.add(_name) + stress_test = kwargs.get('stress_test', False) + stun_test = kwargs.get('stun_test', False) + if stress_test: + turn_test = True # --stress-test implies --turn-test + if stun_test and turn_test: + raise CommandError('pam tunnel diagnose', + '--stun-only and --turn-test are mutually exclusive') + if stun_test: + turn_test = True # reuse the probe section + probe_duration = kwargs.get('probe_duration', 30) + probe_count = kwargs.get('probe_count', 1) + probe_stun_only = stun_test + probe_turn_only = not stun_test and turn_test # TURN-only when --turn-test but not --stun-only + + if (turn_test or stress_test) and not record_name: + raise CommandError('pam tunnel diagnose', + '--turn-test / --stun-only requires a record argument: ' + 'pam tunnel diagnose --turn-test') server = params.server # e.g. "keepersecurity.com" or "https://qa.keepersecurity.com" server_host = get_keeper_server_hostname(server) @@ -1666,15 +1744,8 @@ def _record(name: str, passed: bool, detail: str, ms: int): logging.debug(f'Could not get TURN credentials: {exc}', exc_info=True) settings = {'use_turn': True, 'turn_only': False} - if test_filter: - allowed = {'dns_resolution', 'aws_connectivity', 'tcp_connectivity', - 'udp_binding', 'ice_configuration', 'webrtc_peer_connection'} - requested = {t.strip() for t in test_filter.split(',')} - invalid = requested - allowed - if invalid: - print(f"{bcolors.FAIL}Invalid test names: {', '.join(invalid)}{bcolors.ENDC}") - return 1 - settings['test_filter'] = list(requested) + if test_filter_set: + settings['test_filter'] = list(test_filter_set) try: rust_results = tube_registry.test_webrtc_connectivity( @@ -1884,7 +1955,600 @@ def _val(v): print() - # ── section 6: technical details ────────────────────────────────────── + # ── section 6: TURN / STUN-only end-to-end probe ───────────────────── + if turn_test: + _probe_label = 'STUN-Only End-to-End Probe' if stun_test else 'TURN End-to-End Probe' + print(f'{self._bullet()} {self._bright(_probe_label)} ' + f'({probe_count} connection{"s" if probe_count > 1 else ""}, ' + f'hold {probe_duration}s)') + print(f' {self._sep()}') + try: + probe_registry = get_or_create_tube_registry(params) + if not probe_registry: + raise RuntimeError('Rust WebRTC library not available') + + api.sync_down(params) + probe_record_obj = RecordMixin.resolve_single_record(params, record_name) + probe_record_uid = probe_record_obj.record_uid if probe_record_obj else record_name + probe_record = vault.KeeperRecord.load(params, probe_record_uid) + if probe_record is None: + raise RuntimeError( + f'Record "{record_name}" not found in vault — ' + f'run "sync-down" first, or pass the record UID directly' + ) + if not isinstance(probe_record, vault.TypedRecord): + raise RuntimeError( + f'Record "{record_name}" is a legacy v2 record (type: {type(probe_record).__name__}) — ' + f'--turn-test requires a PAM typed record (pamMachine, pamDirectory, etc.)' + ) + + seed_field = probe_record.get_typed_field('trafficEncryptionSeed') + if not seed_field: + raise RuntimeError( + f'Record "{record_name}" (type: {probe_record.record_type}) has no ' + f'trafficEncryptionSeed field — ' + f'pass a PAM resource record (pamMachine / pamDirectory / pamUser), ' + f'not a pamConfiguration record' + ) + probe_seed = base64_to_bytes(seed_field.get_default_value(str).encode('utf-8')) + + probe_gateway_uid = get_gateway_uid_from_record(params, vault, probe_record.record_uid) + if not probe_gateway_uid: + raise RuntimeError( + f'No gateway linked to record "{record_name}" — ' + f'the record must be in a PAM config that has an active gateway' + ) + + # --- Launch probe_count tunnels concurrently --- + import concurrent.futures as _cf + + def _run_one_probe(probe_idx): + """Launch one probe tunnel and return a result dict.""" + probe_port = find_open_port(tried_ports=[], host='127.0.0.1') + if not probe_port: + return {'idx': probe_idx, 'success': False, 'error': 'no open port'} + t0 = time.monotonic() + result = start_rust_tunnel( + params=params, + record_uid=probe_record.record_uid, + gateway_uid=probe_gateway_uid, + host='127.0.0.1', + port=probe_port, + seed=probe_seed, + target_host='127.0.0.1', + target_port=1, + socks=False, + trickle_ice=True, + record_title=probe_record.title, + kind='probe', + probe_duration=probe_duration, + probe_turn_only=probe_turn_only, + probe_stun_only=probe_stun_only, + ) + offer_ms = int((time.monotonic() - t0) * 1000) + if not result or not result.get('success'): + return {'idx': probe_idx, 'success': False, + 'error': (result or {}).get('error', 'offer failed'), 'offer_ms': offer_ms} + return {'idx': probe_idx, 'success': True, 'offer_ms': offer_ms, + 'tube_id': result['tube_id'], 'registry': result['tube_registry'], + 'signal_handler': result.get('signal_handler'), 't0': t0, + 'port': result.get('local_port', probe_port)} + + t_all = time.monotonic() + with _cf.ThreadPoolExecutor(max_workers=min(probe_count, 20)) as pool: + probe_futures = [pool.submit(_run_one_probe, i) for i in range(probe_count)] + probe_launches = [f.result() for f in _cf.as_completed(probe_futures)] + launch_ms = int((time.monotonic() - t_all) * 1000) + + launched_ok = [p for p in probe_launches if p['success']] + launched_fail = [p for p in probe_launches if not p['success']] + + _record( + f'Probe offer{"s" if probe_count > 1 else ""} sent', + len(launched_fail) == 0, + f'{len(launched_ok)}/{probe_count} sent in {launch_ms}ms' + + (f' — failed: {[p["error"] for p in launched_fail]}' if launched_fail else ''), + launch_ms, + ) + + if not launched_ok: + raise RuntimeError('All probes failed to launch') + + # --- Wait for each probe to reach Connected --- + connect_deadline = time.monotonic() + timeout + for p in launched_ok: + p['connected_ms'] = None + p['final_state'] = 'pending' + + while time.monotonic() < connect_deadline: + pending = [p for p in launched_ok if p['connected_ms'] is None + and p['final_state'] not in ('failed', 'closed', 'timeout')] + if not pending: + break + for p in pending: + try: + state = p['registry'].get_connection_state(p['tube_id']) + except Exception: + state = 'closed' + state_l = (state or '').lower() + if state_l == 'connected': + p['connected_ms'] = int((time.monotonic() - p['t0']) * 1000) + p['final_state'] = 'connected' + elif state_l in ('failed', 'closed'): + p['final_state'] = state_l + time.sleep(0.2) + + for p in launched_ok: + if p['connected_ms'] is None and p['final_state'] == 'pending': + p['final_state'] = 'timeout' + + connected_probes = [p for p in launched_ok if p['connected_ms'] is not None] + failed_probes = [p for p in launched_ok if p['connected_ms'] is None] + avg_connect_ms = int(sum(p['connected_ms'] for p in connected_probes) / len(connected_probes)) \ + if connected_probes else 0 + + _record( + 'STUN peer connected' if stun_test else 'TURN relay connected', + len(failed_probes) == 0, + f'{len(connected_probes)}/{len(launched_ok)} connected' + + (f', avg {avg_connect_ms}ms' if connected_probes else '') + + (f' — not connected: {[p["final_state"] for p in failed_probes]}' if failed_probes else ''), + avg_connect_ms, + ) + + # --- Hold phase: monitor state transitions for probe_duration seconds --- + if connected_probes and probe_duration > 0: + _path_label = 'STUN' if stun_test else 'TURN' + print(f' Holding {len(connected_probes)} connection{"s" if len(connected_probes) > 1 else ""} ' + f'for {probe_duration}s to monitor {_path_label} stability...') + + # --- Throughput test: RTT + bulk throughput via the echo tunnel --- + for p in connected_probes: + p['throughput_mbps'] = None + p['rtt_ms'] = None + local_port = p.get('port') + logging.debug(f'Throughput test: probe-{p["idx"]} port={local_port} keys={list(p.keys())}') + if not local_port: + logging.warning(f'Throughput test: probe-{p["idx"]} has no local port — skipping') + continue + try: + import socket as _socket + # The Rust TCP listener binds after the data channel opens, + # which can lag the ICE 'connected' state by a short window. + # Retry a few times with a brief pause before giving up. + sock = None + for _attempt in range(5): + try: + s = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) + s.settimeout(5) + s.connect(('127.0.0.1', local_port)) + sock = s + break + except OSError: + s.close() + if _attempt < 4: + time.sleep(0.5) + if sock is None: + logging.warning( + f'Throughput test: probe-{p["idx"]} could not connect to ' + f'127.0.0.1:{local_port} after 5 attempts — ' + f'Rust listener may not be ready' + ) + continue + + sock.settimeout(15) # 15s: 256 KB at ~17 KB/s minimum + + # RTT: single small ping + t_rtt = time.monotonic() + sock.sendall(b'PING') + sock.recv(4) + p['rtt_ms'] = int((time.monotonic() - t_rtt) * 1000) + + # Throughput: 256 KB in 8 KB chunks. + # Each sendall matches MAX_READ_SIZE so it maps to one WebRTC message. + # At the minimum passing threshold (50 KB/s) this completes in ~5s, + # leaving 10s of headroom before the 15s timeout. + chunk = b'X' * 8192 + total_bytes = 256 * 1024 + sent = 0 + t_start = time.monotonic() + while sent < total_bytes: + sock.sendall(chunk) + sent += len(chunk) + recv = 0 + while recv < total_bytes: + data = sock.recv(65536) + if not data: + break + recv += len(data) + elapsed = time.monotonic() - t_start + p['throughput_mbps'] = round((recv / elapsed) / 1_000_000, 2) + sock.close() + except Exception as tput_err: + logging.warning(f'Throughput test error for probe-{p["idx"]} (port={local_port}): {tput_err}', exc_info=True) + + # Report throughput results + tput_results = [p for p in connected_probes if p['throughput_mbps'] is not None] + if tput_results: + avg_tput = round(sum(p['throughput_mbps'] for p in tput_results) / len(tput_results), 2) + avg_rtt = int(sum(p['rtt_ms'] for p in tput_results if p['rtt_ms']) / len(tput_results)) + # Fixed floor: the probe sends 256 KB from a cold SCTP association, + # so measured throughput is dominated by slow-start, not RTT. + # An RTT-aware formula would demand higher throughput at low RTT, + # producing false failures on fast paths. 0.03 MB/s (30 KB/s) is + # achievable even during SCTP ramp-up at 400ms RTT, and any relay + # delivering less than that is genuinely broken. + _tput_threshold = 0.03 + _record( + f'{_path_label} throughput', + avg_tput >= _tput_threshold, + f'{avg_tput} MB/s avg over {_path_label} path · RTT {avg_rtt}ms', + int(avg_tput * 1000), + ) + else: + print(f' (throughput test skipped — no data returned from echo path)') + + # Base hold_end on connection time, not throughput-test completion. + # The gateway auto-closes probe_duration seconds after the probe STARTED, + # so align the monitoring window to the first connected probe's t0. + earliest_t0 = min(p['t0'] for p in connected_probes) + hold_end = earliest_t0 + probe_duration + (avg_connect_ms / 1000) + # Per-probe tracking: count disconnects and reconnects + for p in connected_probes: + p['disconnects'] = 0 + p['reconnects'] = 0 + p['last_state'] = 'connected' + p['died'] = False + + while time.monotonic() < hold_end: + for p in connected_probes: + if p['died']: + continue + try: + state = p['registry'].get_connection_state(p['tube_id']) + except Exception: + state = 'closed' + state_l = (state or '').lower() + + if state_l != p['last_state']: + elapsed = int((time.monotonic() - p['t0'])) + if state_l == 'disconnected': + p['disconnects'] += 1 + print(f' [{elapsed}s] probe-{p["idx"]}: ' + f'{self._yellow("DISCONNECTED")} — ICE restart should fire') + elif state_l == 'connected' and p['last_state'] == 'disconnected': + p['reconnects'] += 1 + print(f' [{elapsed}s] probe-{p["idx"]}: ' + f'{self._green("RECONNECTED")} via ICE restart') + elif state_l == 'failed': + p['died'] = True + print(f' [{elapsed}s] probe-{p["idx"]}: ' + f'{self._red("DIED")} (state=failed) — ' + f'ICE failed, tube removed from registry') + elif state_l == 'closed': + # 'closed' is the normal probe auto-close at probe_duration. + # Only count as death if it fired well before the deadline. + remaining = hold_end - time.monotonic() + if remaining > 10: + p['died'] = True + print(f' [{elapsed}s] probe-{p["idx"]}: ' + f'{self._red("DIED")} (state=closed, {int(remaining)}s early) — ' + f'tube removed unexpectedly') + else: + print(f' [{elapsed}s] probe-{p["idx"]}: ' + f'{self._green("CLOSED")} (probe auto-close)') + p['last_state'] = state_l + time.sleep(1.0) + + # Summarise hold phase + survived = [p for p in connected_probes if not p['died']] + died = [p for p in connected_probes if p['died']] + total_disc = sum(p['disconnects'] for p in connected_probes) + total_rec = sum(p['reconnects'] for p in connected_probes) + + _record( + f'{_path_label} stability ({probe_duration}s hold)', + len(died) == 0, + f'{len(survived)}/{len(connected_probes)} survived' + + (f', {total_disc} disconnect(s), {total_rec} ICE restart(s)' if total_disc else ', no interruptions') + + (f' — {len(died)} died permanently' if died else ''), + probe_duration * 1000, + ) + + if total_disc > 0 and total_rec == total_disc: + print(f' {self._green("ICE restart fix working:")} all disconnects recovered automatically') + elif total_disc > 0 and total_rec < total_disc: + print(f' {self._red("ICE restart fix incomplete:")} {total_disc - total_rec} disconnect(s) did not recover') + + # --- Clean up all probes --- + for p in probe_launches: + if not p['success']: + continue + try: + p['registry'].close_tube(p['tube_id']) + except Exception: + pass + sh = p.get('signal_handler') + if sh: + try: + sh.cleanup() + except Exception: + pass + + except Exception as exc: + _record('TURN probe', False, str(exc)[:70], 0) + logging.debug('TURN probe error', exc_info=True) + + print() + + # ── section 7: stress test ──────────────────────────────────────────── + if stress_test: + print(f'{self._bullet()} {self._bright("WebRTC Stress Test")} (TURN relay)') + print(f' {self._sep()}') + if not record_name: + print(f' {self._cross()} {self._red("--stress-test requires a record argument")}') + else: + import socket as _sock + import concurrent.futures as _cf2 + + def _one_stress_probe(sp=None): + """Single connected probe for stress use — returns (port, registry, tube_id, sh) or None. + Pass sp to use a pre-allocated port (for concurrent calls); omit for sequential use. + """ + if sp is None: + sp = find_open_port(tried_ports=[], host='127.0.0.1') + if not sp: + return None + r = start_rust_tunnel( + params=params, + record_uid=probe_record.record_uid, + gateway_uid=probe_gateway_uid, + host='127.0.0.1', port=sp, + seed=probe_seed, target_host='127.0.0.1', target_port=1, + socks=False, trickle_ice=True, + record_title=probe_record.title, + kind='probe', probe_duration=120, probe_turn_only=probe_turn_only, + ) + if not r or not r.get('success'): + return None + deadline = time.monotonic() + 20 + reg = r['tube_registry'] + tid = r['tube_id'] + while time.monotonic() < deadline: + if (reg.get_connection_state(tid) or '').lower() == 'connected': + return sp, reg, tid, r.get('signal_handler') + time.sleep(0.2) + try: + reg.close_tube(tid) + except Exception: pass + + return None + + def _tput_via_port(port): + """Connect to local tunnel port, measure RTT and aggregate throughput. + Sends 256 KB as 32 × 8 KB messages — matching the Rust channel's + MAX_READ_SIZE exactly so the measurement reflects real wire behaviour. + Returns (mbps, rtt_ms) or (None, None) on failure. + """ + try: + s = _sock.socket(_sock.AF_INET, _sock.SOCK_STREAM) + s.settimeout(15) + s.connect(('127.0.0.1', port)) + # Warmup ping — also gives us RTT for the threshold formula. + _t_rtt = time.monotonic() + s.sendall(b'\x00' * 64) + s.recv(64) + rtt_ms = int((time.monotonic() - _t_rtt) * 1000) + # Bulk: 256 KB in 8 KB chunks so each send maps to one WebRTC + # message (RECEIVE_MTU = 8 KB in webrtc-data). More chunks = + # more frames in-flight = better pipelining signal. + total = 256 * 1024 + chunk = b'\x01' * 8192 + t0 = time.monotonic() + sent = 0 + while sent < total: + s.sendall(chunk) + sent += len(chunk) + got = 0 + while got < total: + d = s.recv(65536) + if not d: + break + got += len(d) + elapsed = time.monotonic() - t0 + s.close() + return round(got / elapsed / 1_000_000, 3), rtt_ms + except Exception as e: + logging.debug(f'Stress tput error: {e}') + return None, None + + CYCLES = 5 + CONCURRENCY = 5 + + # --- 1. Cycle test: open → data → close, repeated --- + print(f' {self._bright("1. Connection cycling")} ({CYCLES} open/data/close cycles)') + cycle_ok = 0 + for cycle in range(CYCLES): + info = _one_stress_probe() + if info: + sp, reg, tid, sh = info + try: + reg.close_tube(tid) + if sh: sh.tube_close_initiated = True + except Exception: pass + if sh: + try: sh.cleanup() + except Exception: pass + cycle_ok += 1 + _record('Cycle open/close', cycle_ok == CYCLES, + f'{cycle_ok}/{CYCLES} cycles completed', cycle_ok * 1000) + + # --- 2. Throughput --- + print(f' {self._bright("2. Throughput")} (32 × 8 KB messages over TURN relay)') + info = _one_stress_probe() + if info: + sp, reg, tid, sh = info + mbps, rtt_ms = _tput_via_port(sp) + if mbps is not None: + # RTT-aware threshold: same formula as TURN probe section. + # 32 × 8 KB in-flight; SCTP window / RTT gives the pipelining ceiling. + _rtt = max(rtt_ms or 1000, 1) + _threshold = max(0.05, round(65536 / _rtt * 0.5 / 1000, 3)) + _record('TURN throughput', + mbps >= _threshold, + f'{mbps} MB/s · RTT {rtt_ms}ms', + int(mbps * 1000)) + else: + _record('TURN throughput', False, 'could not connect for throughput test', 0) + try: + reg.close_tube(tid) + if sh: sh.tube_close_initiated = True + except Exception: pass + if sh: + try: sh.cleanup() + except Exception: pass + else: + _record('TURN throughput', False, 'could not connect for throughput test', 0) + + # --- 3. Concurrent connections --- + # Pre-allocate all ports sequentially so no two workers race on find_open_port. + print(f' {self._bright(f"3. Concurrent connections")} ({CONCURRENCY} simultaneous)') + conc_ports, tried = [], [] + for _ in range(CONCURRENCY): + p = find_open_port(tried_ports=tried, host='127.0.0.1') + if p: + conc_ports.append(p) + tried.append(p) + with _cf2.ThreadPoolExecutor(max_workers=CONCURRENCY) as pool: + conc_futures = [pool.submit(_one_stress_probe, p) for p in conc_ports] + conc_results = [f.result() for f in _cf2.as_completed(conc_futures)] + conc_ok = sum(1 for r in conc_results if r is not None) + for r in conc_results: + if r: + sp, reg, tid, sh = r + try: + reg.close_tube(tid) + if sh: sh.tube_close_initiated = True + except Exception: pass + if sh: + try: sh.cleanup() + except Exception: pass + + _record(f'Concurrent {CONCURRENCY}x', conc_ok == CONCURRENCY, + f'{conc_ok}/{CONCURRENCY} connected simultaneously', conc_ok * 1000) + + # --- 4. Interactive latency under bulk load --- + # Two TCP connections to the same probe port = two conn_no streams + # on the same WebRTC data channel. conn_no=1 sends 512 KB bulk; + # conn_no=2 sends 64-byte pings every 200ms and measures RTT. + # This tests whether the EventDrivenSender saw-tooth fix allows + # interactive frames to interleave with bulk frames. + print(f' {self._bright("4. Interactive latency under load")}') + import threading as _threading + info = _one_stress_probe() + if info: + sp, reg, tid, sh = info + rtt_under_load: list = [] + _bulk_done = _threading.Event() + + # Baseline: single ping before bulk starts + _baseline_rtt = None + try: + _bs = _sock.socket(_sock.AF_INET, _sock.SOCK_STREAM) + _bs.settimeout(5) + _bs.connect(('127.0.0.1', sp)) + _t0 = time.monotonic() + _bs.sendall(b'B' * 64) + _bs.recv(64) + _baseline_rtt = int((time.monotonic() - _t0) * 1000) + _bs.close() + except Exception as _e: + logging.debug(f'Latency baseline error: {_e}') + + def _bulk_sender(): + try: + s = _sock.socket(_sock.AF_INET, _sock.SOCK_STREAM) + s.settimeout(30) + s.connect(('127.0.0.1', sp)) + total = 512 * 1024 + chunk = b'\x02' * 8192 + sent = 0 + while sent < total: + s.sendall(chunk) + sent += len(chunk) + got = 0 + while got < total: + d = s.recv(65536) + if not d: + break + got += len(d) + s.close() + except Exception as _e: + logging.debug(f'Bulk sender error: {_e}') + finally: + _bulk_done.set() + + def _latency_sampler(): + try: + s = _sock.socket(_sock.AF_INET, _sock.SOCK_STREAM) + s.settimeout(5) + s.connect(('127.0.0.1', sp)) + while not _bulk_done.is_set(): + try: + _t = time.monotonic() + s.sendall(b'P' * 64) + s.recv(64) + rtt_under_load.append( + int((time.monotonic() - _t) * 1000) + ) + except OSError: + break + time.sleep(0.2) + s.close() + except Exception as _e: + logging.debug(f'Latency sampler error: {_e}') + + _bt = _threading.Thread(target=_bulk_sender, daemon=True) + _lt = _threading.Thread(target=_latency_sampler, daemon=True) + _lt.start() + _bt.start() + _bt.join(timeout=30) + _bulk_done.set() + _lt.join(timeout=2) + + try: + reg.close_tube(tid) + if sh: sh.tube_close_initiated = True + except Exception: pass + if sh: + try: sh.cleanup() + except Exception: pass + + if rtt_under_load and _baseline_rtt: + _avg = int(sum(rtt_under_load) / len(rtt_under_load)) + _max = max(rtt_under_load) + # Pass if average latency under load stays within 5× baseline. + # With the 50-frame drain batch, interactive frames interleave + # between bursts; with the old 2000-frame batch the multiplier + # could reach 100× or more at TURN relay speeds. + _ok = _avg <= _baseline_rtt * 5 + _record( + 'Latency under load', + _ok, + f'avg {_avg}ms, max {_max}ms · baseline {_baseline_rtt}ms' + f' · {len(rtt_under_load)} samples', + _avg, + ) + elif not rtt_under_load: + _record('Latency under load', False, 'no samples collected', 0) + else: + _record('Latency under load', False, 'baseline RTT unavailable', 0) + else: + _record('Latency under load', False, 'could not connect probe', 0) + + print() + + # ── section 8: technical details ────────────────────────────────────── print(f'{self._bullet()} {self._bright("Technical Details")}') print(f' {self._sep()}') From 4d1f1b9ae5e77b66d50ec121c4109c647873def8 Mon Sep 17 00:00:00 2001 From: idimov-keeper <78815270+idimov-keeper@users.noreply.github.com> Date: Thu, 7 May 2026 16:43:08 -0500 Subject: [PATCH 13/26] Add pam connection jit command for managing JIT settings on PAM resource records (#2026) --- .../commands/pam_import/keeper_ai_settings.py | 180 ++++++++++++- .../commands/tunnel_and_connections.py | 247 ++++++++++++++++++ 2 files changed, 423 insertions(+), 4 deletions(-) diff --git a/keepercommander/commands/pam_import/keeper_ai_settings.py b/keepercommander/commands/pam_import/keeper_ai_settings.py index d787f9a41..652870684 100644 --- a/keepercommander/commands/pam_import/keeper_ai_settings.py +++ b/keepercommander/commands/pam_import/keeper_ai_settings.py @@ -510,7 +510,8 @@ def set_resource_jit_settings( params: KeeperParams, resource_uid: str, settings: Dict[str, Any], - config_uid: Optional[str] = None + config_uid: Optional[str] = None, + allow_empty: bool = False ) -> bool: """ Save JIT settings to the DAG DATA edge with path 'jit_settings' for a resource. @@ -533,9 +534,11 @@ def set_resource_jit_settings( True if settings were saved successfully, False otherwise. """ try: - # Return False if settings dict is empty - if not settings or not isinstance(settings, dict): - logging.debug(f"JIT settings empty or invalid for {resource_uid}, skipping") + if not isinstance(settings, dict): + logging.debug(f"JIT settings invalid for {resource_uid}, skipping") + return False + if not settings and not allow_empty: + logging.debug(f"JIT settings empty for {resource_uid}, skipping") return False # Get the record to access the record key @@ -931,3 +934,172 @@ def inspect_resource_in_graph( logging.error(f"Error inspecting graph for {resource_uid}: {e}", exc_info=True) result["error"] = str(e) return result + + +def get_resource_domain_dir_uid( + params: KeeperParams, + resource_uid: str, + config_uid: Optional[str] = None +) -> Optional[str]: + """ + Return the pamDirectory UID linked to the resource via a LINK edge with path 'domain'. + Returns None if no such link exists. + """ + try: + record = vault.KeeperRecord.load(params, resource_uid) + if not record: + return None + + record_key = None + if resource_uid in params.record_cache: + record_key = params.record_cache[resource_uid].get('record_key_unencrypted') + if not record_key: + return None + + if not config_uid: + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, resource_uid) + if not config_uid: + config_uid = resource_uid + + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + dag_record = PasswordRecord() + dag_record.record_uid = config_uid + dag_record.record_key = record_key + + conn = Connection( + params=params, + encrypted_transmission_key=encrypted_transmission_key, + encrypted_session_token=encrypted_session_token, + transmission_key=transmission_key, + use_write_protobuf=True + ) + linking_dag = DAG( + conn=conn, + record=dag_record, + graph_id=0, + write_endpoint=PamEndpoints.PAM + ) + linking_dag.load() + + resource_vertex = linking_dag.get_vertex(resource_uid) + if not resource_vertex: + return None + + for edge in resource_vertex.edges: + if (edge and edge.edge_type == EdgeType.LINK and + edge.path == 'domain' and edge.active): + return edge.head_uid + + return None + + except Exception as e: + logging.error(f"Error getting domain dir UID for {resource_uid}: {e}", exc_info=True) + return None + + +def set_resource_domain_dir( + params: KeeperParams, + resource_uid: str, + dir_uid: str, + config_uid: Optional[str] = None +) -> bool: + """ + Add or replace the LINK edge from resource to pamDirectory with path 'domain'. + If a domain LINK to a different pamDirectory already exists, it is disconnected first. + """ + try: + record = vault.KeeperRecord.load(params, resource_uid) + if not record: + logging.warning(f"Record {resource_uid} not found") + return False + + record_key = None + if resource_uid in params.record_cache: + record_key = params.record_cache[resource_uid].get('record_key_unencrypted') + if not record_key: + logging.warning(f"Record key not available for {resource_uid}") + return False + + if not config_uid: + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, resource_uid) + if not config_uid: + config_uid = resource_uid + + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + dag_record = PasswordRecord() + dag_record.record_uid = config_uid + dag_record.record_key = record_key + + conn = Connection( + params=params, + encrypted_transmission_key=encrypted_transmission_key, + encrypted_session_token=encrypted_session_token, + transmission_key=transmission_key, + use_write_protobuf=True + ) + linking_dag = DAG( + conn=conn, + record=dag_record, + graph_id=0, + write_endpoint=PamEndpoints.PAM, + decrypt=True + ) + linking_dag.load() + + resource_vertex = linking_dag.get_vertex(resource_uid) + if not resource_vertex: + logging.warning(f"Resource vertex {resource_uid} not found in DAG") + return False + + # If a domain LINK to a different pamDirectory exists, disconnect it first + old_dir_uid = None + for edge in resource_vertex.edges: + if (edge and edge.edge_type == EdgeType.LINK and + edge.path == 'domain' and edge.active): + old_dir_uid = edge.head_uid + break + + if old_dir_uid and old_dir_uid != dir_uid: + old_dir_vertex = linking_dag.get_vertex(old_dir_uid) + if old_dir_vertex: + resource_vertex.disconnect_from(old_dir_vertex) + logging.debug(f"Disconnected old domain LINK edge to {old_dir_uid}") + + dir_vertex = linking_dag.get_vertex(dir_uid) + if not dir_vertex: + logging.warning(f"Directory vertex {dir_uid} not found in DAG") + return False + + resource_vertex.belongs_to(dir_vertex, EdgeType.LINK, path="domain", content={}) + linking_dag.save() + + logging.debug(f"Successfully set domain dir link for resource {resource_uid} -> {dir_uid}") + return True + + except Exception as e: + logging.error(f"Error setting domain dir for {resource_uid}: {e}", exc_info=True) + return False + + +def remove_resource_jit_settings( + params: KeeperParams, + resource_uid: str, + config_uid: Optional[str] = None +) -> bool: + """ + Remove JIT settings by overwriting the 'jit_settings' DATA edge with an empty dict. + + Implementation note: DATA edges in the DAG library use `active` as a versioning + marker (auto-managed by add_data when superseding), not a visibility toggle — + get_resource_settings reads the highest-version edge regardless of `active`, + and EdgeType.DELETION self-loops are not path-scoped in the library's lookup + logic. Writing an empty {} via the same set_resource_jit_settings path that + creation uses gives a clean, reliable removal: the new edge becomes the + highest version, and _do_show treats {} as 'No JIT settings configured'. + """ + if not set_resource_jit_settings(params, resource_uid, {}, config_uid, allow_empty=True): + return False + logging.debug(f"Cleared jit_settings for resource {resource_uid}") + return True diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 567565e0c..77ef48fff 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -34,6 +34,7 @@ remove_field, start_rust_tunnel, get_tunnel_session, unregister_tunnel_session, CloseConnectionReasons, \ wait_for_tunnel_connection, create_rust_webrtc_settings, escalate_close, \ print_above_keeper_prompt +from .pam.router_helper import get_dag_leafs from .tunnel_registry import ( PARENT_GRACE_SECONDS, is_pid_alive, @@ -83,6 +84,7 @@ def __init__(self): # self.register_command('start', PAMConnectionStartCommand(), 'Start Connection', 's') # self.register_command('stop', PAMConnectionStopCommand(), 'Stop Connection', 'x') self.register_command('edit', PAMConnectionEditCommand(), 'Edit Connection settings', 'e') + self.register_command('jit', PAMConnectionJitCommand(), 'View/update JIT settings', 'j') self.default_verb = 'edit' @@ -2878,6 +2880,251 @@ def execute(self, params, **kwargs): # Print out PAM Settings if not kwargs.get("silent", False): tdag.print_tunneling_config(record_uid, record.get_typed_field('pamSettings'), config_uid) + +class PAMConnectionJitCommand(Command): + parser = argparse.ArgumentParser(prog='pam connection jit') + parser.add_argument('record', type=str, action='store', + help='Record UID, path, or title (pamMachine, pamDatabase, or pamDirectory)') + parser.add_argument('--configuration', '-c', type=str, dest='configuration', action='store', default=None, + help='PAM Configuration UID or title (required when record is linked to 2+ configs)') + parser.add_argument('--create-ephemeral', type=str, dest='create_ephemeral', action='store', default=None, + metavar='BOOL', help='Create ephemeral account (true/false)') + parser.add_argument('--elevate', type=str, dest='elevate', action='store', default=None, + metavar='BOOL', help='Elevate account (true/false)') + parser.add_argument('--elevation-method', dest='elevation_method', choices=['group', 'role'], default=None, + help='Elevation method (group or role)') + parser.add_argument('--elevation-string', dest='elevation_string', type=str, action='store', default=None, + help='Elevation string') + parser.add_argument('--base-distinguished-name', dest='base_distinguished_name', type=str, action='store', + default=None, help='Base distinguished name') + parser.add_argument('--ephemeral-account-type', dest='ephemeral_account_type', + choices=['linux', 'mac', 'windows', 'domain'], default=None, + help='Ephemeral account type') + parser.add_argument('--pam-directory-record', dest='pam_directory_record', type=str, action='store', + default=None, + help='pamDirectory record UID/path/title (required when ephemeral-account-type=domain)') + parser.add_argument('--remove', dest='remove', action='store_true', default=False, + help='Remove all JIT settings (jit_settings path only)') + parser.add_argument('--show', dest='show', action='store_true', default=False, + help='Show current JIT settings') + + def get_parser(self): + return PAMConnectionJitCommand.parser + + def execute(self, params, **kwargs): + record_name = kwargs.get('record') + configuration = kwargs.get('configuration') + remove_flag = kwargs.get('remove', False) + show_flag = kwargs.get('show', False) + + jit_option_keys = ['create_ephemeral', 'elevate', 'elevation_method', 'elevation_string', + 'base_distinguished_name', 'ephemeral_account_type', 'pam_directory_record'] + jit_options_provided = any(kwargs.get(k) is not None for k in jit_option_keys) + + # Mutual exclusion checks + if remove_flag and show_flag: + raise CommandError('', f'{bcolors.FAIL}--remove cannot be used with --show.{bcolors.ENDC}') + if remove_flag and jit_options_provided: + raise CommandError('', f'{bcolors.FAIL}--remove cannot be used with any other option.{bcolors.ENDC}') + if show_flag and jit_options_provided: + raise CommandError('', f'{bcolors.FAIL}--show cannot be used with --remove or any JIT option.{bcolors.ENDC}') + if not remove_flag and not show_flag and not jit_options_provided: + raise CommandError('', f'{bcolors.FAIL}Provide at least one JIT option, --remove, or --show.{bcolors.ENDC}') + + # 1. Resolve record: try UID/path first, then title search across PAM resource types + record = RecordMixin.resolve_single_record(params, record_name) + if record is None: + pam_resource_types = {'pamMachine', 'pamDatabase', 'pamDirectory'} + matches = [] + for uid in params.record_cache: + rec = vault.KeeperRecord.load(params, uid) + if rec and rec.record_type in pam_resource_types and rec.title.casefold() == record_name.casefold(): + matches.append(rec) + if len(matches) == 0: + raise CommandError('', f'{bcolors.FAIL}Record "{record_name}" not found.{bcolors.ENDC}') + if len(matches) > 1: + raise CommandError('', + f'{bcolors.FAIL}Multiple records match title "{record_name}"; use UID or path.{bcolors.ENDC}') + record = matches[0] + + # 2. Record type check + if not isinstance(record, vault.TypedRecord) or \ + record.record_type not in ('pamMachine', 'pamDatabase', 'pamDirectory'): + raise CommandError('', + f'{bcolors.FAIL}JIT settings are only supported on pamMachine, pamDatabase, ' + f'and pamDirectory records.{bcolors.ENDC}') + + record_uid = record.record_uid + + # 3. Resolve PAM config(s) via DAG + encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) + config_leafs = get_dag_leafs(params, encrypted_session_token, encrypted_transmission_key, record_uid) + config_uids = [leaf.get('value') for leaf in (config_leafs or []) if leaf.get('value')] + + if not config_uids: + raise CommandError('', + f'{bcolors.FAIL}Record is not set up for connections. ' + f'Use: pam connection edit {record_uid} --config --enable-connections. ' + f'List configs: pam config list.{bcolors.ENDC}') + + if len(config_uids) == 1: + config_uid = config_uids[0] + else: + if not configuration: + raise CommandError('', + f'{bcolors.FAIL}Record is linked to multiple PAM Configurations; ' + f'specify --configuration|-c.{bcolors.ENDC}') + # Resolve --configuration: try UID/path, then title among version-6 records + config_rec = RecordMixin.resolve_single_record(params, configuration) + if config_rec is None: + for uid in params.record_cache: + if params.record_cache[uid].get('version', 0) == 6: + r = vault.KeeperRecord.load(params, uid) + if r and r.title.casefold() == configuration.casefold(): + config_rec = r + break + if config_rec is None: + raise CommandError('', + f'{bcolors.FAIL}PAM Configuration "{configuration}" not found.{bcolors.ENDC}') + config_uid = config_rec.record_uid + if config_uid not in config_uids: + raise CommandError('', + f'{bcolors.FAIL}PAM Configuration "{configuration}" is not linked to this record.{bcolors.ENDC}') + + # 4. Branch + if show_flag: + self._do_show(params, record, record_uid, config_uid) + elif remove_flag: + self._do_remove(params, record_uid, config_uid) + else: + self._do_set(params, kwargs, record_uid, config_uid) + + def _do_show(self, params, record, record_uid, config_uid): + from .pam_import.keeper_ai_settings import get_resource_jit_settings, get_resource_domain_dir_uid + + settings = get_resource_jit_settings(params, record_uid, config_uid) + + print(f'\nJIT Settings for {bcolors.OKBLUE}{record.title}{bcolors.ENDC} ({record_uid}):') + if not settings: + print(f' {bcolors.WARNING}No JIT settings configured.{bcolors.ENDC}\n') + return + + field_labels = [ + ('createEphemeral', 'Create Ephemeral'), + ('elevate', 'Elevate'), + ('elevationMethod', 'Elevation Method'), + ('elevationString', 'Elevation String'), + ('baseDistinguishedName', 'Base Distinguished Name'), + ('ephemeralAccountType', 'Ephemeral Account Type'), + ] + for key, label in field_labels: + if key in settings: + print(f' {label}: {settings[key]}') + + if settings.get('ephemeralAccountType') == 'domain': + domain_dir_uid = get_resource_domain_dir_uid(params, record_uid, config_uid) + if domain_dir_uid: + dir_rec = vault.KeeperRecord.load(params, domain_dir_uid) + dir_title = dir_rec.title if dir_rec else domain_dir_uid + print(f' PAM Directory: {dir_title} ({domain_dir_uid})') + + print() + + def _do_remove(self, params, record_uid, config_uid): + from .pam_import.keeper_ai_settings import remove_resource_jit_settings + + ok = remove_resource_jit_settings(params, record_uid, config_uid) + if ok: + print(f'{bcolors.OKGREEN}JIT settings removed successfully.{bcolors.ENDC}') + else: + print(f'{bcolors.WARNING}No JIT settings found or removal failed.{bcolors.ENDC}') + + def _do_set(self, params, kwargs, record_uid, config_uid): + from .pam_import.keeper_ai_settings import (get_resource_jit_settings, set_resource_jit_settings, + set_resource_domain_dir) + + # Merge provided CLI args into existing settings + existing = get_resource_jit_settings(params, record_uid, config_uid) or {} + jit_dict = dict(existing) + + create_ephemeral = kwargs.get('create_ephemeral') + if create_ephemeral is not None: + jit_dict['createEphemeral'] = value_to_boolean(create_ephemeral) + + elevate = kwargs.get('elevate') + if elevate is not None: + jit_dict['elevate'] = value_to_boolean(elevate) + + elevation_method = kwargs.get('elevation_method') + if elevation_method is not None: + jit_dict['elevationMethod'] = elevation_method + + elevation_string = kwargs.get('elevation_string') + if elevation_string is not None: + jit_dict['elevationString'] = elevation_string + + base_dn = kwargs.get('base_distinguished_name') + if base_dn is not None: + jit_dict['baseDistinguishedName'] = base_dn + + ephemeral_type = kwargs.get('ephemeral_account_type') + if ephemeral_type is not None: + jit_dict['ephemeralAccountType'] = ephemeral_type + + # Validate domain requires --pam-directory-record + pam_dir_record = kwargs.get('pam_directory_record') + if jit_dict.get('ephemeralAccountType') == 'domain' and not pam_dir_record: + raise CommandError('', + f'{bcolors.FAIL}ephemeral-account-type=domain requires --pam-directory-record ' + f'(pamDirectory UID/path/title with directory_type active_directory or openldap).{bcolors.ENDC}') + if pam_dir_record and jit_dict.get('ephemeralAccountType') != 'domain': + raise CommandError('', + f'{bcolors.FAIL}--pam-directory-record requires --ephemeral-account-type=domain.{bcolors.ENDC}') + + # Resolve pamDirectory if provided + pam_dir_uid = None + if pam_dir_record: + dir_rec = RecordMixin.resolve_single_record(params, pam_dir_record) + if dir_rec is None: + matches = [] + for uid in params.record_cache: + rec = vault.KeeperRecord.load(params, uid) + if (rec and rec.record_type == 'pamDirectory' and + rec.title.casefold() == pam_dir_record.casefold()): + matches.append(rec) + if not matches: + raise CommandError('', + f'{bcolors.FAIL}PAM Directory record "{pam_dir_record}" not found.{bcolors.ENDC}') + if len(matches) > 1: + raise CommandError('', + f'{bcolors.FAIL}Multiple records match title "{pam_dir_record}"; use UID or path.{bcolors.ENDC}') + dir_rec = matches[0] + + if not isinstance(dir_rec, vault.TypedRecord) or dir_rec.record_type != 'pamDirectory': + raise CommandError('', + f'{bcolors.FAIL}--pam-directory-record must reference a pamDirectory record.{bcolors.ENDC}') + + dir_type_field = dir_rec.get_typed_field('directoryType') + dir_type = dir_type_field.get_default_value(str) if dir_type_field else None + if dir_type not in ('active_directory', 'openldap'): + raise CommandError('', + f'{bcolors.FAIL}PAM Directory must have directory_type "active_directory" or "openldap" ' + f'(found: "{dir_type}").{bcolors.ENDC}') + + pam_dir_uid = dir_rec.record_uid + + # Save JIT settings + ok = set_resource_jit_settings(params, record_uid, jit_dict, config_uid) + if not ok: + raise CommandError('', f'{bcolors.FAIL}Failed to save JIT settings.{bcolors.ENDC}') + + # Add/replace domain LINK edge if pamDirectory provided + if pam_dir_uid: + set_resource_domain_dir(params, record_uid, pam_dir_uid, config_uid) + + print(f'{bcolors.OKGREEN}JIT settings saved successfully.{bcolors.ENDC}') + + class PAMRbiEditCommand(Command): choices = ['on', 'off', 'default'] parser = argparse.ArgumentParser(prog='pam rbi edit') From b5f52d118be5948d58238f11c29040818d233393 Mon Sep 17 00:00:00 2001 From: pvagare-ks Date: Sat, 9 May 2026 04:08:31 +0530 Subject: [PATCH 14/26] KC-1242: Fix KeeperDrive role labels, kd-shortcut titles and remove kd-rndir inherit flags (#2031) * KC-1242: Fix KeeperDrive role labels, kd-shortcut titles and remove kd-rndir inherit flags (#2030) * remove summary.md --- .../commands/keeper_drive/__init__.py | 26 +-- .../commands/keeper_drive/folder_commands.py | 10 +- .../commands/keeper_drive/helpers.py | 10 +- .../commands/keeper_drive/parsers.py | 6 - .../commands/keeper_drive/record_commands.py | 58 ++++-- .../commands/keeper_drive/sharing_commands.py | 6 +- keepercommander/proto/Summary.md | 170 ------------------ unit-tests/test_keeper_drive.py | 13 +- 8 files changed, 86 insertions(+), 213 deletions(-) delete mode 100644 keepercommander/proto/Summary.md diff --git a/keepercommander/commands/keeper_drive/__init__.py b/keepercommander/commands/keeper_drive/__init__.py index e9bce5fd4..7ca527ea2 100644 --- a/keepercommander/commands/keeper_drive/__init__.py +++ b/keepercommander/commands/keeper_drive/__init__.py @@ -81,18 +81,18 @@ def register_commands(commands): def register_command_info(aliases, command_info): """Register command help descriptions.""" - command_info['kd-mkdir'] = 'Create a KeeperDrive folder (v3 API)' - command_info['kd-record-add'] = 'Create a KeeperDrive record (v3 API)' - command_info['kd-record-update'] = 'Update a KeeperDrive record (v3 API)' + command_info['kd-mkdir'] = 'Create a KeeperDrive folder' + command_info['kd-record-add'] = 'Create a KeeperDrive record' + command_info['kd-record-update'] = 'Update a KeeperDrive record' command_info['kd-rndir'] = 'Rename a KeeperDrive folder' command_info['kd-list'] = 'List Keeper Drive folders and records' - command_info['kd-share-folder'] = 'Grant/update/revoke folder sharing (v3 API)' - command_info['kd-record-details'] = 'Get record metadata (title, color) (v3 API)' - command_info['kd-share-record'] = 'Grant/update/revoke record sharing (v3 API)' - command_info['kd-record-permission'] = 'Modify sharing permissions of records in a folder (v3 API)' - command_info['kd-transfer-record'] = 'Transfer record ownership to another user (v3 API)' - command_info['kd-ln'] = 'Link a record into a KeeperDrive folder (positional)' - command_info['kd-rm'] = 'Remove (delete/unlink) a KeeperDrive record (v3 API)' - command_info['kd-rmdir'] = 'Remove a KeeperDrive folder and its contents (v3 API)' - command_info['kd-shortcut'] = 'Manage KeeperDrive record shortcuts (multi-folder links)' - command_info['kd-get'] = 'Get details of a KeeperDrive record or folder (like legacy get)' + command_info['kd-share-folder'] = 'Grant/update/revoke folder access' + command_info['kd-record-details'] = 'Get record metadata (title, color' + command_info['kd-share-record'] = 'Grant/update/revoke record sharing' + command_info['kd-record-permission'] = 'Modify sharing permissions of records in a folder' + command_info['kd-transfer-record'] = 'Transfer record ownership to another user' + command_info['kd-ln'] = 'Link a record into a KeeperDrive folder' + command_info['kd-rm'] = 'Remove (delete/unlink) a KeeperDrive record' + command_info['kd-rmdir'] = 'Remove a KeeperDrive folder and its contents' + command_info['kd-shortcut'] = 'Manage KeeperDrive record shortcuts' + command_info['kd-get'] = 'Get details of a KeeperDrive record or folder' diff --git a/keepercommander/commands/keeper_drive/folder_commands.py b/keepercommander/commands/keeper_drive/folder_commands.py index d5fa807ea..f46bff372 100644 --- a/keepercommander/commands/keeper_drive/folder_commands.py +++ b/keepercommander/commands/keeper_drive/folder_commands.py @@ -133,13 +133,7 @@ def execute(self, params, **kwargs): if not new_name: raise CommandError('kd-rndir', 'Folder name cannot be empty') - inherit_permissions = None - if kwargs.get('inherit_permissions'): - inherit_permissions = True - elif kwargs.get('no_inherit_permissions'): - inherit_permissions = False - - if new_name is None and color is None and inherit_permissions is None: + if new_name is None and color is None: raise CommandError('kd-rndir', 'New folder name and/or color parameters are required.') folder_uid = resolve_folder_uid(params, folder_arg) @@ -151,7 +145,7 @@ def execute(self, params, **kwargs): with command_error_handler('kd-rndir'): result = _kd.update_folder_v3( params=params, folder_uid=folder_arg, folder_name=new_name, - color=color, inherit_permissions=inherit_permissions, + color=color, ) check_result(result, 'kd-rndir') params.sync_data = True diff --git a/keepercommander/commands/keeper_drive/helpers.py b/keepercommander/commands/keeper_drive/helpers.py index 6dfbe49ca..e2c2ade47 100644 --- a/keepercommander/commands/keeper_drive/helpers.py +++ b/keepercommander/commands/keeper_drive/helpers.py @@ -350,12 +350,20 @@ def infer_role(access): full-manager > content-share-manager > shared-manager > content-manager > viewer > contributor > requestor > navigator + + The distinguishing trait between ``shared-manager`` and + ``content-share-manager`` is the ability to *edit* records: both roles + grant ``can_update_access`` + ``can_approve_access``, but only + ``content-share-manager`` also grants ``can_edit``. Without that check + every shared-manager would be reported as content-share-manager. """ get = access.get if get('can_change_ownership') or get('can_delete'): return 'full-manager' - if get('can_update_access') and get('can_approve_access'): + if get('can_update_access') and get('can_approve_access') and get('can_edit'): return 'content-share-manager' + if get('can_update_access') and get('can_approve_access'): + return 'shared-manager' if get('can_update_access'): return 'shared-manager' if get('can_edit'): diff --git a/keepercommander/commands/keeper_drive/parsers.py b/keepercommander/commands/keeper_drive/parsers.py index a444a2c6b..92375956a 100644 --- a/keepercommander/commands/keeper_drive/parsers.py +++ b/keepercommander/commands/keeper_drive/parsers.py @@ -55,12 +55,6 @@ def _make_parser(prog, description): '--color', dest='color', action='store', choices=['none', 'red', 'orange', 'yellow', 'green', 'blue', 'gray'], help='folder color') -keeper_drive_update_folder_parser.add_argument( - '--inherit', dest='inherit_permissions', action='store_true', - help='set folder to inherit parent permissions') -keeper_drive_update_folder_parser.add_argument( - '--no-inherit', dest='no_inherit_permissions', action='store_true', - help='set folder to not inherit parent permissions') keeper_drive_update_folder_parser.add_argument( '-q', '--quiet', dest='quiet', action='store_true', help='rename folder without confirmation message') diff --git a/keepercommander/commands/keeper_drive/record_commands.py b/keepercommander/commands/keeper_drive/record_commands.py index d77231010..ddd586de2 100644 --- a/keepercommander/commands/keeper_drive/record_commands.py +++ b/keepercommander/commands/keeper_drive/record_commands.py @@ -27,6 +27,7 @@ resolve_folder_uid, command_error_handler, check_result, check_record_edit_permission, check_record_delete_permission, ensure_keeper_drive_record, ensure_keeper_drive_folder, + ROOT_FOLDER_UID, ) from .parsers import ( keeper_drive_add_record_parser, @@ -269,9 +270,19 @@ def __init__(self): @staticmethod def get_record_shortcuts(params): - """Return ``{record_uid: set(folder_uids)}`` for records in 2+ folders.""" + """Return ``{record_uid: set(folder_uids)}`` for records in 2+ folders. + + ``keeper_drive_folder_records`` can carry server-side virtual folder + UIDs (e.g. shared-with-me containers) that have no real folder entry + in ``keeper_drive_folders``. These cannot be resolved or modified, so + they are filtered out — counting them would inflate shortcut totals + and break ``kd-shortcut keep`` removals downstream. + """ + kd_folders = getattr(params, 'keeper_drive_folders', {}) records = {} for folder_uid, rec_set in getattr(params, 'keeper_drive_folder_records', {}).items(): + if folder_uid != ROOT_FOLDER_UID and folder_uid not in kd_folders: + continue for record_uid in rec_set: records.setdefault(record_uid, set()).add(folder_uid) return {k: v for k, v in records.items() if len(v) > 1} @@ -288,9 +299,11 @@ def execute(self, params, **kwargs): target = kwargs.get('target') kd_records = getattr(params, 'keeper_drive_records', {}) + kd_record_data = getattr(params, 'keeper_drive_record_data', {}) kd_folders = getattr(params, 'keeper_drive_folders', {}) - to_show = self._resolve_target(params, target, records, kd_records, kd_folders) \ + to_show = self._resolve_target(params, target, records, kd_records, + kd_record_data, kd_folders) \ if target else set(records.keys()) if not to_show: @@ -300,10 +313,10 @@ def execute(self, params, **kwargs): fmt = kwargs.get('format') or 'table' table = [] for record_uid in sorted(to_show): - title = kd_records.get(record_uid, {}).get('title', record_uid) + title = self._record_title(record_uid, kd_record_data) folder_names = [] for fuid in sorted(records[record_uid]): - fname = kd_folders.get(fuid, {}).get('name', fuid) + fname = self._folder_name(fuid, kd_folders) folder_names.append({'folder_uid': fuid, 'name': fname} if fmt == 'json' else f'{fname} ({fuid})') table.append([record_uid, title, folder_names]) @@ -313,17 +326,34 @@ def execute(self, params, **kwargs): from ..base import dump_report_data return dump_report_data(table, headers, fmt=fmt, filename=kwargs.get('output')) + # Record titles live in ``keeper_drive_record_data[uid]['data_json']`` + # (the decrypted record payload). ``keeper_drive_records`` only stores + # metadata (revision/version/shared/etc.) and has no ``title`` key, so + # the previous lookup always fell back to the raw UID. + @staticmethod + def _record_title(record_uid, kd_record_data): + rd = kd_record_data.get(record_uid) or {} + dj = rd.get('data_json') or {} + title = dj.get('title') + return title if title else record_uid + @staticmethod - def _resolve_target(params, target, records, kd_records, kd_folders): - to_show = set() + def _folder_name(folder_uid, kd_folders): + if folder_uid == ROOT_FOLDER_UID: + return 'root' + return kd_folders.get(folder_uid, {}).get('name', folder_uid) + + @classmethod + def _resolve_target(cls, params, target, records, kd_records, + kd_record_data, kd_folders): if target in kd_records: if target not in records: raise CommandError('kd-shortcut list', f'Record UID {target} does not have shortcuts') return {target} lower = target.casefold() - for uid, rec in kd_records.items(): - if rec.get('title', '').casefold() == lower: + for uid in kd_records: + if cls._record_title(uid, kd_record_data).casefold() == lower: if uid not in records: raise CommandError('kd-shortcut list', f'Record "{target}" does not have shortcuts') return {uid} @@ -352,7 +382,8 @@ def execute(self, params, **kwargs): kd_records = getattr(params, 'keeper_drive_records', {}) kd_folders = getattr(params, 'keeper_drive_folders', {}) - record_uid = self._resolve_record(target, kd_records) + kd_record_data = getattr(params, 'keeper_drive_record_data', {}) + record_uid = self._resolve_record(target, kd_records, kd_record_data) keep_folder_uid = self._resolve_keep_folder(params, kwargs.get('folder'), kd_folders) records = KeeperDriveShortcutCommand.get_record_shortcuts(params) @@ -397,10 +428,17 @@ def execute(self, params, **kwargs): target, keep_name, len(folders_to_remove)) @staticmethod - def _resolve_record(target, kd_records): + def _resolve_record(target, kd_records, kd_record_data=None): if target in kd_records: return target lower = target.casefold() + if kd_record_data: + for uid in kd_records: + rd = kd_record_data.get(uid) or {} + dj = rd.get('data_json') or {} + if (dj.get('title') or '').casefold() == lower: + return uid + for uid, rec in kd_records.items(): if rec.get('title', '').casefold() == lower: return uid diff --git a/keepercommander/commands/keeper_drive/sharing_commands.py b/keepercommander/commands/keeper_drive/sharing_commands.py index bc48691c8..a90972f90 100644 --- a/keepercommander/commands/keeper_drive/sharing_commands.py +++ b/keepercommander/commands/keeper_drive/sharing_commands.py @@ -27,7 +27,7 @@ from ...error import CommandError from ... import keeper_drive as _kd from .helpers import ( - parse_expiration, infer_role, + parse_expiration, get_access_role_label, command_error_handler, check_result, check_record_share_permission, collect_records_in_folder, @@ -51,8 +51,6 @@ def get_parser(self): return keeper_drive_share_record_parser def execute(self, params, **kwargs): - from keepercommander.commands.base import user_choice - record_arg = kwargs.get('record') emails = kwargs.get('email') or [] action = kwargs.get('action') or 'grant' @@ -418,7 +416,7 @@ def _compute_changes(accesses_result, record_uids, current_user, action, role, r if not email or email == current_user: continue - cur_role = infer_role(access) + cur_role = get_access_role_label(access) is_inherited = bool(access.get('inherited')) # Pre-flight: does the current user have permission to modify this share? diff --git a/keepercommander/proto/Summary.md b/keepercommander/proto/Summary.md deleted file mode 100644 index f739f5c09..000000000 --- a/keepercommander/proto/Summary.md +++ /dev/null @@ -1,170 +0,0 @@ -Summary -The Commander codebase (CLI, not the Web Vault — same backend, different client) already has substantial KeeperDrive scaffolding. A full kd-* command surface is wired up via keepercommander/commands/keeper_drive/ and an API service layer in keepercommander/keeper_drive/, plus sync-down handling and the KEEPER_DRIVE feature flag. - -But several MRD requirements are unimplemented or only partially covered. Here is the requirement-by-requirement audit, then the concrete to-do list. - -What's already in place -MRD ID Requirement Where it lives Status -KD1, KD2 -Create root / nested folders (canAdd) -kd-mkdir → keeper_drive.create_folder_v3 -Implemented -KD3, KD4 -Independent sharing & override of inherited permissions (canUpdateAccess) -kd-rndir --no-inherit / --inherit, _build_update_data sets inheritUserPermissions -Implemented -KD5 -Update folder metadata title/color (canUpdateSetting) -kd-rndir → update_folder_v3 -Implemented -KD6, KD7 -Remove / permanently delete folder -kd-rmdir -o folder-trash / delete-permanent → remove_folder_v3 -Implemented -KD8, KD9 -View folder structure / record titles -kd-list, sync caches keeper_drive_folders/records -Implemented -KD10 -View folder accessors (canListAccess) -get_folder_access_v3 (vault/folders/v3/access), kd-get -v -Implemented -KD11 -Update folder access -kd-share-folder → grant/update/revoke_folder_access_v3 -Implemented (users only — see gaps) -KD13, KD14, KD20–22 -View / edit record content -kd-get, kd-record-update, kd-record-add -Implemented -KD15, KD16 -Create record at root / inside folder -kd-record-add --folder → create_record_v3 -Implemented -KD17, KD18, KD19 -Remove record (folder / unlink / permanent) -kd-rm -o owner-trash / folder-trash / unlink → remove_record_v3 -Implemented -KD23 -View record accessors -get_record_accesses_v3 exposed via kd-get -Implemented -KD24 -Update record access -kd-share-record grant/revoke → share_record_v3 -Implemented (users only) -KD25 -Change record ownership -kd-share-record -a owner and kd-transfer-record → transfer_record_ownership_v3 -Implemented -KD32 -FeatureFlag.KEEPER_DRIVE gating -params.is_feature_disallowed('keeper_drive') used in cli.py, sync_down.py, autocomplete.py -Implemented -KD33 -KeeperDrive ↔ Legacy isolation in mv -commands/folder.py lines 859–911 -Partial (only mv) -KD34 -Inheritance vs independent sharing -Same as KD3/KD4 -Implemented -KD35 -Explicit deny overrides -Marked N/A for MVP -N/A -What's missing or incomplete -1. KSM application support (KD29, KD30, KD31) — fully missing -There are zero KeeperDrive↔KSM hooks anywhere under keepercommander/keeper_drive/ or commands/keeper_drive/. A grep for ksm/secrets_manager in those directories returns nothing. - -You need: - -A kd-ksm-app-add (KD29 / KD31) that creates a KSM Application either at the vault root or inside a KeeperDrive folder context (canAdd). -A kd-ksm-app-share (KD30) that wires a KSM application share into a KeeperDrive folder. -Probably a service-layer module keepercommander/keeper_drive/ksm_api.py plus exposure through the package __init__.py _SUBMODULE_MAP. -The existing legacy commands/ksm.py is the natural source to refactor against — it already speaks the KSM endpoints; you just need the v3 wiring + folder-context parameter. - -2. Team sharing (KD11, KD24) — only user-as-actor is supported -The MRD wording for both folder and record sharing is "users or teams", and the proto already has AT_TEAM. But the only place AT_TEAM is referenced in keeper_drive/ is removal_api.py (and only to count affected_teams_count). - -In folder_api.py the entire share path hard-codes accessType = folder_pb2.AT_USER (lines 327, 385, 412, 441, 461, 478) and kd-share-folder/kd-share-record parsers accept only --email. There's no --team/--team-uid switch and no team-key-encryption path. - -Add: - ---team / -T to kd-share-folder and kd-share-record parsers. -A resolve_team_uid_bytes helper paralleling resolve_user_uid_bytes in common.py. -A team branch in grant_folder_access_v3 / manage_folder_access_batch_v3 / record-share that sets accessType = AT_TEAM and encrypts the folder/record key with the team key. -3. Maximum nesting depth = 5 (KD2) — not enforced -Nothing in kd-mkdir, create_folder_v3, _prepare_folder_for_creation, or helpers.py checks the depth of the parent chain. A grep for MAX_DEPTH / depth returns no matches. - -Add a check in KeeperDriveMkdirCommand.execute (or in create_folder_v3) that walks parent_uid → parent_uid through params.keeper_drive_folders and refuses with a friendly error when depth >= 5. - -4. Cross-model nesting prevention (KD28) — only mv is guarded -commands/folder.py FolderMoveCommand blocks moves between SharedFolder and KeeperDrive. But FolderMakeCommand (legacy mkdir) at line 482+ never inspects whether base_folder is a KeeperDriveFolderType. So mkdir -sf "X" while cd'd into a KD folder would try to create a Legacy SharedFolder inside KeeperDrive (server may reject, but client should fail fast). - -Add the symmetric guard to FolderMakeCommand.execute: - -if base_folder.type == BaseFolderNode.KeeperDriveFolderType: - raise CommandError('mkdir', - 'Legacy folders cannot be created inside a KeeperDrive folder. ' - 'Use kd-mkdir instead.') -The same guard belongs in legacy RecordAddCommand (legacy add/record-add) so that when the current folder is KD, the user is told to use kd-record-add (KD27 contextual create rule). - -5. kd-mkdir cannot target a parent (KD2 ergonomics) -KeeperDriveMkdirCommand discovers the parent only via params.current_folder. There is no --folder / --parent argument the way kd-record-add has --folder, so building hierarchies non-interactively requires cd between every call. - -Add --folder/--parent FOLDER to keeper_drive_mkdir_parser and pass it through create_folder_v3(parent_uid=…). - -6. Folder-permission–driven record actions (KD13, KD14) — client checks only the record-level grant -helpers._check_record_permission looks up keeper_drive_record_accesses only. The MRD says folder-level canViewRecords / canEditRecords should also authorize record reads/edits inside that folder. - -kd-record-update and kd-get therefore reject users who hold the right via the parent folder rather than the record itself. The server is probably permissive, but pre-flight checks are wrong. - -Update _check_record_permission to also walk find_kd_folders_for_record(params, record_uid) and accept when any containing folder grants can_view_records / can_edit_records. - -7. Permission matrix vs MRD scope -keeper_drive/permissions.py includes NAVIGATOR=0 and REQUESTOR=1, and ROLE_NAME_MAP exposes 'contributor' / 'requestor' (both → 1). The MRD explicitly limits Phase 1 roles to VIEWER (2), SHARED_MANAGER (3), CONTENT_MANAGER (4), CONTENT_SHARE_MANAGER (5), MANAGER (6). - -This is fine if backend will silently accept those, but kd-share-folder/kd-share-record parsers do already restrict --role choices to MRD-allowed names, so the extras are dead options reachable only programmatically. Decide whether to: - -Drop NAVIGATOR/REQUESTOR/contributor mapping for V1 to avoid drift, or -Keep them but document them as internal. -Also: the helpers.role_label and infer_role functions still return 'contributor'/'requestor'/'navigator' on display — they will leak into kd-list -p and kd-get -v. Trim them for V1 to match the MRD's display surface. - -8. KD12 — Change folder ownership -MRD marks this as N/A for MVP, and there is no kd-chown-folder in the codebase. Confirmed correct — leave a stub TODO if desired. - -9. Out-of-scope features that are actually present -The MRD explicitly puts these out of scope but the code partially supports them: - -Out-of-scope feature Where it leaks in Recommendation -TLA (time-limited access) ---expire-at / --expire-in on kd-share-folder and kd-share-record; tlaProperties.expiration set in grant_folder_access_v3 -Either keep (server will reject if disabled) or hide the switches behind a feature flag check. -TrashCan / restore -Staged trashcan_sync_pb2 files appear in the original git status snapshot but are untracked-uncommitted; keeper_drive_trashed_folders cache is referenced in sync.py clear_caches -Don't ship the proto pieces in this PR; remove the cache code or feature-flag it. -Move To / Drag-and-Drop -kd-ln and kd-shortcut are link operations, fine. But mv partially still talks about "Drive folders" — that path is correctly raising CommandError, leave as is. -10. current_folder for KD context (KD27 contextual create) -commands/folder.py FolderCdCommand.execute (line 363) already accepts a KD folder UID as current_folder. Good. But kd-record-add only consumes --folder and ignores params.current_folder. To match the MRD wording ("If a KeeperDrive folder is selected and Add/Create is clicked, the dialog shall create a record … within that KeeperDrive folder context"), have KeeperDriveAddRecordCommand.execute default folder_uid to params.current_folder when it is a KD folder UID. - -11. KEEPER_DRIVE flag handling on every kd-* command (KD32) -The flag is checked in command-listing/help (cli.py line 387) and in sync ingestion (sync_down.py line 79), but the individual kd-* execute() methods don't re-check the flag. So a user with the flag disallowed who somehow types kd-mkdir directly will hit it. Recommend adding a guard in a base helper used by every kd-* command (e.g., in helpers.command_error_handler or a separate require_keeper_drive(params, cmd_name) decorator). - -12. Minor -keeper_drive_share_folder_parser has no --team (see #2) and no JSON output mode, while every other kd-* listing/inspection has --format json. Add for parity if needed. -_check_folder_permission (in helpers.py) silently returns on the first matching username; if no matching access entry is found, it never raises — letting actions pass when they shouldn't. Add a final raise CommandError(cmd_name, error_message) after the loop. -kd-share-record parser sets --email required=True even for folder-bulk mode (-R). Reconsider. -Concrete to-do list, priority-ordered -Add KSM commands & service module (KD29/30/31). New work, needed for MRD section 8. -Add team-as-actor support to kd-share-folder / kd-share-record and the underlying *_v3 calls (KD11/KD24). -Enforce max depth 5 in kd-mkdir (KD2). -Block legacy mkdir/add inside KD folders to satisfy KD28; mirror the mv guard in FolderMakeCommand and the legacy add command. -Add --folder/--parent to kd-mkdir and have kd-record-add honor params.current_folder (KD27 ergonomics). -Fix folder-derived record permission checks in _check_record_permission (KD13/14). -Add a global feature-flag guard at the top of every kd-* execute() (KD32). -Tighten _check_folder_permission so a missing access record raises instead of falling through. -Trim NAVIGATOR/REQUESTOR/contributor surface from display helpers and the role map (MRD V1 role list). -Either ship or revert the staged trashcan_sync_pb2* files — they're listed in the git snapshot but absent on disk; trash sync is out of scope per MRD. -Want me to start with any of these (KSM commands and team support are the two largest gaps)? \ No newline at end of file diff --git a/unit-tests/test_keeper_drive.py b/unit-tests/test_keeper_drive.py index aa1623345..fb429c804 100644 --- a/unit-tests/test_keeper_drive.py +++ b/unit-tests/test_keeper_drive.py @@ -101,7 +101,18 @@ def test_parse_expiration_invalid(self): def test_infer_role(self): from keepercommander.commands.keeper_drive.helpers import infer_role self.assertEqual(infer_role({'can_change_ownership': True}), 'full-manager') - self.assertEqual(infer_role({'can_update_access': True, 'can_approve_access': True}), 'content-share-manager') + # ``can_update_access`` + ``can_approve_access`` alone (no edit) is + # ``shared-manager``; promotion to ``content-share-manager`` requires + # ``can_edit`` per the v3 permission matrix. + self.assertEqual( + infer_role({'can_update_access': True, 'can_approve_access': True, + 'can_edit': True}), + 'content-share-manager', + ) + self.assertEqual( + infer_role({'can_update_access': True, 'can_approve_access': True}), + 'shared-manager', + ) self.assertEqual(infer_role({'can_update_access': True}), 'shared-manager') self.assertEqual(infer_role({'can_edit': True}), 'content-manager') self.assertEqual(infer_role({'can_view': True, 'can_list_access': True}), 'viewer') From 51241a8d1b170eadf27689349b9dda64aedf276d Mon Sep 17 00:00:00 2001 From: amangalampalli-ks Date: Sat, 9 May 2026 04:10:18 +0530 Subject: [PATCH 15/26] PAM Workflow changes and improvements (#2019) (#2028) * Implement workflow recent changes after apr 22 * Remove flow uid support in state * Remove column from my-access and fix mfa prompts * Fix flow uids starting with - and my-access table view * Add record type validation for workflows * Add better error handling * Update for review comments and tzlocal library for windows --- .../commands/workflow/approver_commands.py | 15 +- .../commands/workflow/config_commands.py | 43 +-- keepercommander/commands/workflow/helpers.py | 297 ++++++++++-------- keepercommander/commands/workflow/mfa.py | 114 ++----- keepercommander/commands/workflow/registry.py | 3 +- .../commands/workflow/requester_commands.py | 6 +- .../commands/workflow/state_commands.py | 33 +- requirements.txt | 1 + setup.cfg | 1 + 9 files changed, 235 insertions(+), 278 deletions(-) diff --git a/keepercommander/commands/workflow/approver_commands.py b/keepercommander/commands/workflow/approver_commands.py index 342579714..644bc2296 100644 --- a/keepercommander/commands/workflow/approver_commands.py +++ b/keepercommander/commands/workflow/approver_commands.py @@ -22,7 +22,7 @@ from ...proto import workflow_pb2 from ... import api, crypto, utils -from .helpers import RecordResolver, WorkflowFormatter, sanitize_router_error +from .helpers import RecordResolver, WorkflowFormatter, sanitize_router_error, DashUidArgsMixin class WorkflowGetApprovalRequestsCommand(Command): @@ -171,7 +171,7 @@ def _print_table(params, workflows): print() -class WorkflowApproveCommand(Command): +class WorkflowApproveCommand(DashUidArgsMixin, Command): parser = argparse.ArgumentParser( prog='pam workflow approve', description='Approve a workflow access request', @@ -209,7 +209,7 @@ def execute(self, params: KeeperParams, **kwargs): raise CommandError('', f'Failed to approve request: {sanitize_router_error(e)}') -class WorkflowDenyCommand(Command): +class WorkflowDenyCommand(DashUidArgsMixin, Command): parser = argparse.ArgumentParser( prog='pam workflow deny', description='Deny a workflow access request', @@ -236,7 +236,14 @@ def execute(self, params: KeeperParams, **kwargs): if reason: reason_bytes = reason.encode('utf-8') encrypted = self._encrypt_denial_reason(params, flow_uid_bytes, reason_bytes) - denial.denialReason = encrypted if encrypted else reason_bytes + if encrypted: + denial.denialReason = encrypted + else: + logging.warning( + 'Could not encrypt denial reason for the requester — reason will not be attached. ' + 'The denial itself will still be sent.' + ) + reason = '' try: _post_request_to_router(params, 'approve_or_deny_workflow_access', rq_proto=denial) diff --git a/keepercommander/commands/workflow/config_commands.py b/keepercommander/commands/workflow/config_commands.py index 28ca9c71f..1fe17d59a 100644 --- a/keepercommander/commands/workflow/config_commands.py +++ b/keepercommander/commands/workflow/config_commands.py @@ -28,13 +28,8 @@ def _add_approvers_to_workflow(params, record_uid, record_name, users=None, teams=None, is_escalation=False, escalation_after_ms=0): - """Send add_workflow_approvers for the given users / teams. Shared by - `pam workflow create` (when --approver flags are supplied) and - `pam workflow add-approver` so both go through one code path. - - Caller is responsible for de-duplicating the user / team lists. Raises - on transport error; caller decides how to surface. - """ + """Send add_workflow_approvers for the given users / teams. Shared by create + add-approver. + Caller must de-duplicate. Raises on transport error.""" record_uid_bytes = utils.base64_url_decode(record_uid) config = workflow_pb2.WorkflowConfig() config.parameters.resource.CopyFrom( @@ -83,8 +78,6 @@ class WorkflowCreateCommand(Command): help='Comma-separated allowed days (e.g., "mon,tue,wed,thu,fri")') parser.add_argument('--time-range', type=str, help='Allowed time range in HH:MM-HH:MM format (e.g., "09:00-17:00")') - parser.add_argument('--timezone', type=str, - help='Timezone for allowed times (e.g., "America/New_York")') parser.add_argument('-u', '--approver', action='append', help='User email to add as an approver. Pass multiple times to ' 'add several. Required when --approvals-needed > 0. ' @@ -97,12 +90,10 @@ def get_parser(self): def execute(self, params: KeeperParams, **kwargs): record_uid, record = RecordResolver.resolve(params, kwargs.get('record')) + RecordResolver.validate_workflow_record_type(record) record_uid_bytes = utils.base64_url_decode(record_uid) - # Pre-check: surface the "already exists" condition with a clear, - # actionable message instead of letting the user discover it via the - # raw server error from create_workflow_config. Server-side error - # path is still the authoritative gate; this is just nicer UX. + # Pre-check for nicer UX; server is still authoritative on conflicts. try: existing = _post_request_to_router( params, 'read_workflow_config', @@ -126,7 +117,6 @@ def execute(self, params: KeeperParams, **kwargs): if approvals < 0: raise CommandError('', 'Approvals needed must be 0 or greater') - # Normalize and de-duplicate the approver list (preserves first-seen order). approvers = list(dict.fromkeys( a.strip() for a in (kwargs.get('approver') or []) if a and a.strip() )) @@ -155,7 +145,7 @@ def execute(self, params: KeeperParams, **kwargs): parameters.accessLength = WorkflowFormatter.parse_duration(kwargs.get('duration', '1d')) temporal_filter = WorkflowFormatter.build_temporal_filter( - kwargs.get('allowed_days'), kwargs.get('time_range'), kwargs.get('timezone'), + kwargs.get('allowed_days'), kwargs.get('time_range'), ) if temporal_filter: parameters.allowedTimes.CopyFrom(temporal_filter) @@ -163,10 +153,6 @@ def execute(self, params: KeeperParams, **kwargs): try: _post_request_to_router(params, 'create_workflow_config', rq_proto=parameters) - # Step 2: send the explicit approver list (if any). Mirrors web vault - # which issues create_workflow_config + add_workflow_approvers as two - # separate calls (save-workflow-settings.ts:78-99). No silent - # auto-add of the creator / record-owner. approvers_added = [] if approvers: try: @@ -329,9 +315,6 @@ def _print_table(params, response, record_uid): print(f" Days: {', '.join(day_names)}") if at.timeRanges: for tr in at.timeRanges: - # startTime / endTime are HHMM (hours*100 + minutes); see - # WorkflowFormatter._parse_time_to_hhmm and the canonical - # ka-libs/workflow/.../WfConfigCRUD.kt::validateHHMM. start_h, start_m = divmod(tr.startTime, 100) end_h, end_m = divmod(tr.endTime, 100) print(f" Time: {start_h:02d}:{start_m:02d} - {end_h:02d}:{end_m:02d}") @@ -386,8 +369,6 @@ class WorkflowUpdateCommand(Command): help='Comma-separated allowed days (e.g., "mon,tue,wed,thu,fri")') parser.add_argument('--time-range', type=str, help='Allowed time range in HH:MM-HH:MM format (e.g., "09:00-17:00")') - parser.add_argument('--timezone', type=str, - help='Timezone for allowed times (e.g., "America/New_York")') parser.add_argument('--format', dest='format', action='store', choices=['table', 'json'], default='table', help='Output format') @@ -434,7 +415,7 @@ def execute(self, params: KeeperParams, **kwargs): updates_provided = True temporal_filter = WorkflowFormatter.build_temporal_filter( - kwargs.get('allowed_days'), kwargs.get('time_range'), kwargs.get('timezone'), + kwargs.get('allowed_days'), kwargs.get('time_range'), ) if temporal_filter: parameters.allowedTimes.CopyFrom(temporal_filter) @@ -456,6 +437,8 @@ def execute(self, params: KeeperParams, **kwargs): print(f"Record: {record.title} ({record_uid})") print() + except (CommandError, KeyboardInterrupt, SystemExit): + raise except Exception as e: raise CommandError('', f'Failed to update workflow: {sanitize_router_error(e)}') @@ -477,10 +460,7 @@ def execute(self, params: KeeperParams, **kwargs): record_uid_bytes = utils.base64_url_decode(record_uid) ref = ProtobufRefBuilder.record_ref(record_uid_bytes, record.title) - # Pre-check: server-side delete_workflow_config is idempotent and - # returns success even when no config exists, so without this check - # repeat calls all print "deleted successfully" — confusing in - # interactive use. Verify there's actually something to delete first. + # Server-side delete is idempotent; verify config exists for accurate user feedback. try: existing = _post_request_to_router( params, 'read_workflow_config', @@ -531,7 +511,6 @@ def get_parser(self): return WorkflowAddApproversCommand.parser def execute(self, params: KeeperParams, **kwargs): - # De-duplicate user / team lists (preserve first-seen order). users = list(dict.fromkeys( u.strip() for u in (kwargs.get('user') or []) if u and u.strip() )) @@ -582,6 +561,8 @@ def execute(self, params: KeeperParams, **kwargs): print(f"Type: Escalation approver{esc_info}") print() + except (CommandError, KeyboardInterrupt, SystemExit): + raise except Exception as e: raise CommandError('', f'Failed to add approvers: {sanitize_router_error(e)}') @@ -642,5 +623,7 @@ def execute(self, params: KeeperParams, **kwargs): print(f"Removed {total} approver(s)") print() + except (CommandError, KeyboardInterrupt, SystemExit): + raise except Exception as e: raise CommandError('', f'Failed to remove approvers: {sanitize_router_error(e)}') diff --git a/keepercommander/commands/workflow/helpers.py b/keepercommander/commands/workflow/helpers.py index 21a3ac4dd..7dede3c8d 100644 --- a/keepercommander/commands/workflow/helpers.py +++ b/keepercommander/commands/workflow/helpers.py @@ -10,6 +10,7 @@ # import re +import shlex from typing import List, Optional, Tuple from ...error import CommandError @@ -24,6 +25,55 @@ _RESPONSE_CODE_RE = re.compile(r'\s*[Rr]esponse\s+code:\s*\S+\s*$') +class DashUidArgsMixin: + """Mixin for commands whose positional flow-UID arg may start with '-' (base64url).""" + + def execute_args(self, params, args, **kwargs): + args = fix_dash_uid_args(self.get_parser(), args) + return super().execute_args(params, args, **kwargs) + + +def fix_dash_uid_args(parser, args): + """Insert '--' before a base64url UID starting with '-' so argparse treats it as positional.""" + if not args: + return args + try: + tokens = shlex.split(args) + except ValueError: + return args + if '--' in tokens: + return args + + known_opts = set() + consumes_value = set() + for action in parser._actions: + for opt in action.option_strings: + known_opts.add(opt) + if action.nargs != 0: + consumes_value.add(opt) + + result = [] + skip_next = False + for token in tokens: + if skip_next: + result.append(token) + skip_next = False + continue + opt_name = token.split('=', 1)[0] if token.startswith('--') and '=' in token else token + if opt_name in known_opts: + result.append(token) + if opt_name in consumes_value and token == opt_name: + skip_next = True + continue + if token.startswith('-'): + result.append('--') + result.append(token) + + if len(result) != len(tokens): + return ' '.join(shlex.quote(t) for t in result) + return args + + def sanitize_router_error(error: Exception) -> str: msg = str(error) msg = _RESPONSE_CODE_RE.sub('', msg) @@ -32,90 +82,74 @@ def sanitize_router_error(error: Exception) -> str: return msg or 'Unknown error' -_ENFORCEMENT_KEY = 'allow_configure_workflow_settings' - - def print_exempt_message(fmt='table'): - """Print the standard exemption message in the appropriate format.""" import json as _json from ...display import bcolors as _bc if fmt == 'json': print(_json.dumps({'status': 'exempt', 'message': 'Workflow not required'}, indent=2)) else: - print(f"\n{_bc.WARNING}You have edit access and workflow management permissions for this record.{_bc.ENDC}\n") - print("Workflow is not required — you can access this resource directly.\n") - + print(f"\n{_bc.WARNING}You are exempt from workflow restrictions on this record.{_bc.ENDC}") + print("As a record owner or approver, you can access this resource directly.\n") -def is_workflow_exempt(params, record_uid): - """Users with edit access AND 'Can manage workflow settings' are exempt from workflow.""" - enforcements = getattr(params, 'enforcements', None) - if not enforcements or 'booleans' not in enforcements: - return False - can_manage = any( - b.get('value') for b in enforcements['booleans'] - if b.get('key') == _ENFORCEMENT_KEY - ) - if not can_manage: - return False +def is_record_owner(params, record_uid): if record_uid in getattr(params, 'record_owner_cache', {}): owner_info = params.record_owner_cache[record_uid] if getattr(owner_info, 'owner', False): return True + return False - meta = getattr(params, 'meta_data_cache', {}).get(record_uid) - if meta and meta.get('can_edit'): - return True - for sf_uid in getattr(params, 'shared_folder_cache', {}): - sf = params.shared_folder_cache[sf_uid] - for sfr in sf.get('records', []): - if sfr.get('record_uid') == record_uid: - if sfr.get('owner') or sfr.get('can_edit'): - return True +def is_on_approver_list(params, config): + """Check if current user is on the approver list (by email or team membership).""" + if not config or not config.approvers: + return False + + current_user = getattr(params, 'user', '') + team_cache = getattr(params, 'team_cache', {}) + for approver in config.approvers: + if approver.user and approver.user.lower() == current_user.lower(): + return True + if approver.teamUid: + team_uid_b64 = utils.base64_url_encode(approver.teamUid) + if team_uid_b64 in team_cache: + return True return False +def is_workflow_exempt(params, record_uid, config=None): + """Exempt = record owner OR on approver list. Pass `config` to skip a round-trip. + Transport failures fail closed (non-exempt).""" + if is_record_owner(params, record_uid): + return True + + if config is None: + from ..pam.router_helper import _post_request_to_router + try: + ref = ProtobufRefBuilder.record_ref( + utils.base64_url_decode(record_uid), + '', + ) + config = _post_request_to_router( + params, 'read_workflow_config', + rq_proto=ref, rs_type=workflow_pb2.WorkflowConfig, + ) + except Exception as e: + import logging as _logging + _logging.debug( + 'is_workflow_exempt config read failed for %s: %s', record_uid, e, + ) + return False + + return is_on_approver_list(params, config) + + def is_pam_action_allowed_by_enforcement(params: KeeperParams, enforcement_key: str) -> bool: - """Per-user enterprise enforcement gate: is this user permitted to perform - the action by their enterprise enforcement profile? - - Mirrors web vault `getAllowConnections` / `getAllowPortForwards` - (pam-enforcement-selectors.ts:38-49). The enforcement boolean is already - the rolled-up sum of the user's role permissions — no additional role / - team / ACL / deny-list checks are needed (verified against WV). - - NOTE — license check intentionally NOT included here. Web vault wraps - these enforcement selectors in `mkPam` (pam-enforcement-selectors.ts:33-34) - which also requires `state.accountSummary.license?.isPamEnabled`. Commander - has historically treated license defensively: we let the gateway / server - be the authoritative gate for license-driven failures rather than - pre-flighting it client-side. Same approach is used by the rotation - enforcement check (`_is_rotation_allowed_by_enforcement` in - discoveryrotation.py). The trade-off: an account that has the enforcement - granted but no PAM license will pass this gate and fail later at the - gateway with a more specific error, instead of getting a "not licensed" - refusal up front. If parity becomes important, wrap this with a license - check against `params.account_summary` / `params.license`. - - Decision matrix: - no enforcement context at all (personal / non-enterprise user): - -> allow (gateway is the authoritative gate) - enforcement context present, key explicitly true: - -> allow - enforcement context present, key explicitly false: - -> deny - enforcement context present, key absent: - -> deny (matches WV: `!!enforcements.` evaluates to false - for missing keys; also matches Commander's enforcement parser - behavior, which converts `--enforcement KEY:false` to None and - removes the key, so "absent" is the actual on-the-wire shape - of "explicitly disabled") - - Defensively swallows any unexpected enforcement payload shape and falls - back to allow. - """ + """Per-user enterprise enforcement gate. Mirrors web vault's PAM enforcement selectors. + Non-enterprise users and unexpected payload shapes fall back to allow (gateway gates). + Enterprise users with the key absent are denied (matches WV's `!!enforcements.`). + License is intentionally not checked here — gateway is the authoritative gate for that.""" import logging as _logging try: enforcements = getattr(params, 'enforcements', None) @@ -123,14 +157,10 @@ def is_pam_action_allowed_by_enforcement(params: KeeperParams, enforcement_key: return True booleans = enforcements.get('booleans') or [] if not isinstance(booleans, list) or not booleans: - # Enforcement context but no booleans at all — treat as - # "not yet configured" and allow; gateway will gate. return True for b in booleans: if isinstance(b, dict) and b.get('key') == enforcement_key: return bool(b.get('value')) - # Enterprise user with a populated enforcement set, but the - # specific PAM-grant key is absent. Match WV: deny. return False except Exception as e: _logging.debug('Enforcement check failed for %s: %s', enforcement_key, e) @@ -139,27 +169,9 @@ def is_pam_action_allowed_by_enforcement(params: KeeperParams, enforcement_key: def is_pam_config_action_allowed_for_record(params: KeeperParams, record_uid: str, action_key: str) -> bool: - """Best-effort PAM config gate: is the action permitted by the record's - PAM configuration's allowedSettings (DAG) ? - - Mirrors web vault `getConfigAllowedSettings(recordUid)[key]` (used by - GuacConnectBanner.tsx:37-45 for launch and StartPortForwardButton.tsx:160-163 - for tunnel). The action_key matches the JSON-mapped names returned by - PAMConfigurationListCommand._allowed_settings_dag_to_json: - - 'connections' → launch / connect - 'tunneling' → tunnel / port-forward (DAG: portForwards) - 'rotation' → rotate - - Returns True (allow) on any lookup failure (no PAM context, missing DAG, - not a PAM record, personal account) so non-enterprise contexts aren't - blocked. Returns False only when the flag is explicitly False on the - PAM config DAG. - - Fast path: launch_cache entry holds a recent config_uid for the record; - on cache hit we skip the DAG-leafs round-trip and go straight to the - config-vertex read. - """ + """Best-effort PAM config allowedSettings (DAG) gate. + action_key: 'connections' (launch), 'tunneling' (port-forward), 'rotation'. + Returns True on any lookup failure; False only when explicitly disabled on the config.""" import logging as _logging try: config_uid = None @@ -196,16 +208,7 @@ def is_pam_config_action_allowed_for_record(params: KeeperParams, record_uid: st def is_gateway_online_for_record(params: KeeperParams, record_uid: str) -> Optional[bool]: - """Best-effort check: is the gateway for record_uid currently connected to the router? - - Returns True / False when known, None when undetermined (e.g. the - record has no entry in launch_cache yet, or the router lookup fails). - Callers should treat None as "proceed with normal flow" and only act - on a definitive False. - - Uses launch_cache.get to avoid an expensive PAM_LINK DAG rebuild on the - first launch; subsequent launches resolve the gateway UID instantly. - """ + """Best-effort gateway online check. Returns None when undetermined (treat as 'proceed').""" import logging as _logging try: from ..pam_launch import launch_cache @@ -228,7 +231,6 @@ def is_gateway_online_for_record(params: KeeperParams, record_uid: str) -> Optio def start_workflow_for_record(params: KeeperParams, record_uid: str) -> None: - """Send a start_workflow (check-out) request for the given record.""" from ..pam.router_helper import _post_request_to_router record_uid_bytes = utils.base64_url_decode(record_uid) record = vault.KeeperRecord.load(params, record_uid) @@ -241,11 +243,7 @@ def start_workflow_for_record(params: KeeperParams, record_uid: str) -> None: def submit_access_request(params: KeeperParams, record_uid: str, reason: str = '', ticket: str = '') -> None: - """Send a workflow access request with optional encrypted reason/ticket. - - Encrypts reason/ticket with the record key (AES-GCM-256), matching - web vault request_workflow_access.ts. Caller must hold the record key. - """ + """Send a workflow access request. Reason/ticket are encrypted with the record key.""" from ..pam.router_helper import _post_request_to_router record_uid_bytes = utils.base64_url_decode(record_uid) record = vault.KeeperRecord.load(params, record_uid) @@ -271,12 +269,7 @@ def submit_access_request(params: KeeperParams, record_uid: str, def prompt_for_reason_ticket(needs_reason: bool, needs_ticket: bool) -> Tuple[Optional[str], Optional[str]]: - """Interactively prompt the user for reason and/or ticket using prompt_toolkit. - - Returns (reason, ticket). Either may be None if not requested. Returns - (None, None) if the user cancels (Ctrl+C / EOF) or submits empty input - for a required field. - """ + """Prompt for reason/ticket. Returns (None, None) on cancel or empty required input.""" from prompt_toolkit import prompt as pt_prompt from ...display import bcolors as _bc @@ -305,6 +298,8 @@ def prompt_for_reason_ticket(needs_reason: bool, needs_ticket: bool) -> Tuple[Op class RecordResolver: + WORKFLOW_RECORD_TYPES = {'pamMachine', 'pamDirectory', 'pamDatabase', 'pamRemoteBrowser'} + @staticmethod def resolve(params, record_input, allow_missing=False): if record_input in params.record_cache: @@ -317,6 +312,19 @@ def resolve(params, record_input, allow_missing=False): return None, None raise CommandError('', f'Record "{record_input}" not found') + @staticmethod + def validate_workflow_record_type(record): + if not isinstance(record, vault.TypedRecord): + raise CommandError('', 'Workflows are only supported on PAM records') + record_type = record.record_type or 'unknown' + if record_type not in RecordResolver.WORKFLOW_RECORD_TYPES: + supported = ', '.join(sorted(RecordResolver.WORKFLOW_RECORD_TYPES)) + raise CommandError( + '', + f'Record "{record.title}" is of type "{record_type}" which does not support workflows.\n' + f'Supported record types: {supported}' + ) + @staticmethod def get_uid_bytes(params: KeeperParams, record_uid: str) -> bytes: uid_bytes = utils.base64_url_decode(record_uid) @@ -440,10 +448,19 @@ class WorkflowFormatter: workflow_pb2.SUNDAY: 'Sunday', } + BLOCKING_CONDITIONS = {workflow_pb2.AC_TIME, workflow_pb2.AC_APPROVAL} + @staticmethod def format_stage(stage: int, status=None) -> str: if stage == workflow_pb2.WS_READY_TO_START and status is not None: - if not status.startedOn and not status.conditions: + if status.conditions: + has_blocking = any(c in WorkflowFormatter.BLOCKING_CONDITIONS for c in status.conditions) + if has_blocking: + return 'Waiting' + return 'Ready to Start' + if status.approvedBy and not status.startedOn: + return 'Ready to Start' + if not status.startedOn and not status.approvedBy: return 'Needs Action' return WorkflowFormatter.STAGE_MAP.get(stage, f'Unknown ({stage})') @@ -490,8 +507,8 @@ def format_duration(milliseconds: int) -> str: return f"{seconds} second{'s' if seconds != 1 else ''}" @staticmethod - def build_temporal_filter(allowed_days_str, time_range_str, timezone_str): - if not allowed_days_str and not time_range_str and not timezone_str: + def build_temporal_filter(allowed_days_str, time_range_str): + if not allowed_days_str and not time_range_str: return None temporal = workflow_pb2.TemporalAccessFilter() @@ -516,25 +533,46 @@ def build_temporal_filter(allowed_days_str, time_range_str, timezone_str): time_range.endTime = end_hhmm temporal.timeRanges.append(time_range) - if timezone_str: - temporal.timeZone = timezone_str + temporal.timeZone = WorkflowFormatter._get_local_iana_timezone() return temporal + @staticmethod + def _get_local_iana_timezone(): + """Detect local IANA timezone via TZ env var (override) or tzlocal (cross-platform).""" + import os + + tz = os.environ.get('TZ') + if tz and '/' in tz: + return tz + + try: + from tzlocal import get_localzone_name + except ImportError: + raise CommandError( + '', + 'Timezone detection requires the "tzlocal" package. ' + 'Install it with: pip install tzlocal\n' + 'Or set the TZ environment variable (e.g., TZ=Asia/Kolkata).' + ) + + try: + zone = get_localzone_name() + if zone: + return zone + except Exception as e: + import logging as _logging + _logging.debug('tzlocal lookup failed: %s', e) + + raise CommandError( + '', + 'Could not detect local IANA timezone. ' + 'Set the TZ environment variable (e.g., TZ=Asia/Kolkata).' + ) + @staticmethod def _parse_time_to_hhmm(time_str): - """Parse 'HH:MM' to the HHMM integer the server stores on - TimeOfDayRange.startTime / .endTime: hours*100 + minutes. - Examples: '00:00' -> 0, '03:00' -> 300, '09:00' -> 900, '17:30' -> 1730. - Valid range: 0..2359 with hours in 0-23 and minutes in 0-59. - - Canonical sources (all agree on HHMM): - - keeperapp-protobuf/workflow.proto:140 - `int32 startTime = 1; // HHMM format` - - ka-libs/workflow/src/main/kotlin/com/keepersecurity/workflow/handlers/WfConfigCRUD.kt::validateHHMM - `val hours = value / 100; val minutes = value % 100` - throws "Invalid : . Expected HHMM integer with HH in 0-23 and MM in 0-59" on bad input. - """ + """Parse 'HH:MM' to HHMM integer (hours*100 + minutes), e.g. '09:00' -> 900, '17:30' -> 1730.""" try: parts = time_str.split(':') h = int(parts[0]) @@ -555,7 +593,6 @@ def format_temporal_filter(at): if at.timeRanges: ranges = [] for tr in at.timeRanges: - # startTime / endTime are HHMM integers (see _parse_time_to_hhmm). sh, sm = divmod(tr.startTime, 100) eh, em = divmod(tr.endTime, 100) ranges.append(f"{sh:02d}:{sm:02d}-{eh:02d}:{em:02d}") diff --git a/keepercommander/commands/workflow/mfa.py b/keepercommander/commands/workflow/mfa.py index d9a977eac..7322325ad 100644 --- a/keepercommander/commands/workflow/mfa.py +++ b/keepercommander/commands/workflow/mfa.py @@ -24,7 +24,8 @@ ProtobufRefBuilder, WorkflowFormatter, is_gateway_online_for_record, - is_workflow_exempt, + is_on_approver_list, + is_record_owner, prompt_for_reason_ticket, sanitize_router_error, start_workflow_for_record, @@ -40,7 +41,7 @@ class WorkflowGate(NamedTuple): - """Result of the pre-launch workflow gate, consumed by pam launch / pam tunnel.""" + """Result of the pre-launch workflow gate.""" allowed: bool two_factor_value: Optional[str] = None flow_uid: Optional[bytes] = None @@ -68,24 +69,9 @@ def __init__(self, params: KeeperParams, record_uid: str): self.record_name = record.title if record else record_uid def validate(self, silent_actionable: bool = False) -> dict: - if is_workflow_exempt(self.params, self.record_uid): + if is_record_owner(self.params, self.record_uid): return dict(self._DEFAULT_RESULT) - # Workflow REST endpoints (`read_workflow_config`, - # `get_user_access_state`, `get_workflow_state`) are not yet deployed - # on every router. On a router that doesn't expose them, the call - # raises (404 / unsupported / RRC error) and `_post_request_to_router` - # bubbles it up; the wrappers below convert that into _TRANSPORT_ERROR. - # We can't tell from the wire whether the failure is "endpoint not - # deployed" (prod today) or "endpoint deployed but momentarily - # unreachable" (transient QA hiccup). Erring on the side of legacy - # compatibility: treat _TRANSPORT_ERROR as "no workflow protection - # on this record, defer to the gateway." The gateway is the - # authoritative gate on prod; on QA the workflow service still - # enforces server-side, so a flaky read just relaxes the *client* - # gate without opening a real security gap. Old behavior (block - # with a banner) was correct for QA but a hard regression on prod - # legacy launches that have never seen workflow. config = self._read_workflow_config() if config is _TRANSPORT_ERROR: logging.debug( @@ -97,6 +83,9 @@ def validate(self, silent_actionable: bool = False) -> dict: if config is None: return dict(self._DEFAULT_RESULT) + if is_on_approver_list(self.params, config): + return dict(self._DEFAULT_RESULT) + mfa_required = bool(config.parameters and config.parameters.requireMFA) if not self._check_allowed_times(config): @@ -119,10 +108,6 @@ def validate(self, silent_actionable: bool = False) -> dict: ) return dict(self._DEFAULT_RESULT) if workflow is None: - # Carry the workflow config's required-field flags back so the - # orchestrator can inline-prompt + submit the initial access - # request from no_workflow state, matching web vault's "click - # Launch -> reason/ticket dialog" first-time flow. requires_reason = bool(config.parameters and config.parameters.requireReason) requires_ticket = bool(config.parameters and config.parameters.requireTicket) approvals_needed = int(config.parameters.approvalsNeeded) if config.parameters else 0 @@ -204,9 +189,6 @@ def _check_allowed_times(self, config) -> bool: return False if at.timeRanges: - # TimeOfDayRange.startTime / .endTime are HHMM-encoded integers - # (server-validated: HH in 0-23, MM in 0-59). e.g. 03:00 -> 300, - # 17:30 -> 1730. Compare current wall-clock in the same encoding. current_hhmm = now.hour * 100 + now.minute in_range = False for r in at.timeRanges: @@ -323,8 +305,7 @@ def _print_needs_action(self, conditions, flow_uid_bytes): cond_str = WorkflowFormatter.format_conditions(conditions) print(f"Pending conditions: {cond_str}") elif flow_uid_bytes: - flow_uid_str = utils.base64_url_encode(flow_uid_bytes) - print(f"Run: {bcolors.OKBLUE}pam workflow state --flow-uid {flow_uid_str}{bcolors.ENDC} " + print(f"Run: {bcolors.OKBLUE}pam workflow state {self.record_name}{bcolors.ENDC} " f"to see details.") print() @@ -359,6 +340,8 @@ def _print_needs_start(self): class WorkflowMfaPrompt: + _NO_2FA_CONFIGURED = object() + def __init__(self, params: KeeperParams): self.params = params @@ -367,6 +350,13 @@ def prompt(self): from ... import api tfa_list = self._fetch_2fa_list(self.params, api, APIRequest_pb2) + if tfa_list is self._NO_2FA_CONFIGURED: + print(f"\n{bcolors.FAIL}This workflow requires 2FA verification{bcolors.ENDC}") + print( + "Your account does not have any 2FA methods configured. " + f"For available methods, run: {bcolors.OKBLUE}2fa add -h{bcolors.ENDC}\n" + ) + return None if tfa_list is None: try: code = getpass.getpass('2FA required. Enter TOTP code: ').strip() @@ -405,12 +395,7 @@ def _fetch_2fa_list(params, api, APIRequest_pb2): return None if not tfa_list.channels: - print(f"\n{bcolors.FAIL}This workflow requires 2FA verification{bcolors.ENDC}") - print( - "Your account does not have any 2FA methods configured. " - f"For available methods, run: {bcolors.OKBLUE}2fa add -h{bcolors.ENDC}" - ) - return None + return WorkflowMfaPrompt._NO_2FA_CONFIGURED return tfa_list @@ -546,12 +531,7 @@ def check_workflow_access(params: KeeperParams, record_uid: str) -> dict: def _poll_until_state_change(validator: 'WorkflowAccessValidator', timeout_seconds: int) -> Optional[dict]: - """Poll the workflow state at _WAIT_POLL_INTERVAL_SECONDS until the - state is no longer 'waiting' or until timeout_seconds elapses. - - Returns the new validator result dict on state change, or None on - timeout / Ctrl+C / transport error. - """ + """Poll until state is no longer 'waiting' or timeout. Returns None on timeout/cancel/error.""" import time as _time deadline = _time.time() + max(timeout_seconds, _WAIT_POLL_INTERVAL_SECONDS) print( @@ -564,8 +544,6 @@ def _poll_until_state_change(validator: 'WorkflowAccessValidator', r = validator.validate(silent_actionable=True) block_reason = r.get('block_reason') if r.get('allowed', True) or block_reason != 'waiting': - # Suppress chatty output on state change; orchestrator will - # handle the next state in its own loop iteration. return r print(f"{bcolors.WARNING}Approval not received within {timeout_seconds}s.{bcolors.ENDC}\n") return None @@ -584,32 +562,8 @@ def check_workflow_for_launch( wait: bool = False, wait_timeout: int = 600, ) -> WorkflowGate: - """Pre-launch workflow gate: validate access, optionally submit a missing - reason/ticket request and check out the record inline, prompt for MFA if - required, and return the active flow's UID and lease expiry (millis since - epoch) so callers can auto check-in and force-disconnect on lease expiry. - - Three actionable validator states are auto-handled: - - * **`'no_workflow'` / `'needs_start'`** — workflow config exists on the - record but no flow process exists yet for the user (first-time access). - If the config requires a reason / ticket, the supplied --reason / - --ticket flags are used or the user is prompted inline; the initial - access request is submitted via `submit_access_request`. Mirrors web - vault's "click Launch -> reason/ticket dialog" on first-time access. - - * **`'needs_action'`** — flow exists in WS_NEEDS_ACTION with - AC_REASON / AC_TICKET pending; same prompt+submit logic but driven off - the conditions list returned by the server. - - * **`'ready_to_start'`** — flow approved but not yet checked out; user is - prompted (or `auto_checkout=True` confirms automatically), then - `start_workflow` is called and the gate reports `started_by_launch=True` - so the caller can release the lease in its cleanup path. - - Optional `wait=True` polls past `'waiting'` until approval or - `wait_timeout` elapses. - """ + """Pre-launch workflow gate: validate, auto-handle no_workflow/needs_action/ready_to_start, + prompt for MFA, and return flow UID + lease expiry. With `wait=True`, polls past 'waiting'.""" validator = WorkflowAccessValidator(params, record_uid) started_by_launch = False handled_needs_action = False @@ -617,8 +571,7 @@ def check_workflow_for_launch( handled_waiting = False handled_no_workflow = False - # Up to 5 transitions: - # no_workflow -> needs_action -> waiting -> ready_to_start -> started. + # no_workflow -> needs_action -> waiting -> ready_to_start -> started (max 5 transitions). for _attempt in range(5): result = validator.validate(silent_actionable=True) if result.get('allowed', True): @@ -627,13 +580,6 @@ def check_workflow_for_launch( if (block_reason in ('no_workflow', 'needs_start') and not handled_no_workflow): - # First-time access on a workflow-protected record: no flow row - # exists for this user yet. Match web vault's "click Launch -> - # reason/ticket dialog" first-time flow by inline-prompting - # for the required fields (driven off the workflow config's - # require* flags carried in the validator result) and submitting - # the initial access request. Re-validate to land in waiting / - # needs_action / ready_to_start / started naturally. handled_no_workflow = True requires_reason = bool(result.get('requires_reason')) requires_ticket = bool(result.get('requires_ticket')) @@ -754,12 +700,7 @@ def check_workflow_for_launch( break continue - # Block reason we don't auto-handle (waiting w/o --wait, - # transport_error, outside_time_window, no_status, or a re-visit of - # an already-handled actionable). Validator paths print their own - # message in non-silent branches; for the silent-suppressed ones - # (needs_action, ready_to_start, waiting, no_workflow, needs_start) - # fall back to the explicit print so the user always sees something. + # Fall-through for unhandled states; print explicit message so user always sees something. if block_reason == 'needs_action': validator._print_needs_action( result.get('pending_conditions') or (), @@ -783,10 +724,7 @@ def check_workflow_for_launch( two_factor_value = None if result.get('require_mfa', False): - # Match web vault: skip the MFA prompt when the gateway is known to - # be offline. The launch will surface its own gateway-offline error - # later. is_gateway_online_for_record returns None on first launch - # (no cache yet) — in that case keep the prompt to be safe. + # Skip MFA prompt when gateway is known offline; launch surfaces its own error. if is_gateway_online_for_record(params, record_uid) is False: logging.debug("Skipping workflow MFA prompt — gateway is offline.") else: @@ -804,8 +742,6 @@ def check_workflow_for_launch( def check_workflow_and_prompt_2fa(params: KeeperParams, record_uid: str): - """Backward-compatible wrapper around check_workflow_for_launch. - Prefer check_workflow_for_launch in new code — it carries flow_uid and - expires_on_ms needed for auto check-in and lease-expiry teardown.""" + """Backward-compatible wrapper. Prefer check_workflow_for_launch (carries flow_uid + expiry).""" gate = check_workflow_for_launch(params, record_uid) return (gate.allowed, gate.two_factor_value) diff --git a/keepercommander/commands/workflow/registry.py b/keepercommander/commands/workflow/registry.py index 2ea6f31ed..5353dfa52 100644 --- a/keepercommander/commands/workflow/registry.py +++ b/keepercommander/commands/workflow/registry.py @@ -11,7 +11,8 @@ from ..base import GroupCommand, dump_report_data from ...display import bcolors -from .helpers import _ENFORCEMENT_KEY + +_ENFORCEMENT_KEY = 'allow_configure_workflow_settings' from .config_commands import ( WorkflowCreateCommand, diff --git a/keepercommander/commands/workflow/requester_commands.py b/keepercommander/commands/workflow/requester_commands.py index 5cc9e35ba..decdb83f6 100644 --- a/keepercommander/commands/workflow/requester_commands.py +++ b/keepercommander/commands/workflow/requester_commands.py @@ -27,6 +27,7 @@ is_workflow_exempt, print_exempt_message, submit_access_request, + DashUidArgsMixin, ) @@ -66,6 +67,7 @@ def execute(self, params: KeeperParams, **kwargs): @staticmethod def _request(params, **kwargs): record_uid, record = RecordResolver.resolve(params, kwargs.get('record')) + RecordResolver.validate_workflow_record_type(record) if is_workflow_exempt(params, record_uid): print_exempt_message(kwargs.get('format', 'table')) return @@ -173,7 +175,7 @@ def _cancel(params, **kwargs): raise CommandError('', f'Failed to cancel request: {sanitize_router_error(e)}') -class WorkflowStartCommand(Command): +class WorkflowStartCommand(DashUidArgsMixin, Command): parser = argparse.ArgumentParser( prog='pam workflow start', description='Start a workflow (check-out). ' @@ -225,7 +227,7 @@ def execute(self, params: KeeperParams, **kwargs): raise CommandError('', f'Failed to start workflow: {sanitize_router_error(e)}') -class WorkflowEndCommand(Command): +class WorkflowEndCommand(DashUidArgsMixin, Command): parser = argparse.ArgumentParser( prog='pam workflow end', description='End a workflow (check-in).', diff --git a/keepercommander/commands/workflow/state_commands.py b/keepercommander/commands/workflow/state_commands.py index 751b7c8c4..492d36a75 100644 --- a/keepercommander/commands/workflow/state_commands.py +++ b/keepercommander/commands/workflow/state_commands.py @@ -35,11 +35,9 @@ def _fmt_ts_or_empty(ts_ms: int) -> str: class WorkflowGetStateCommand(Command): parser = argparse.ArgumentParser( prog='pam workflow state', - description='Get workflow state for a record or flow', + description='Get workflow state for a record', ) - _state_group = parser.add_mutually_exclusive_group(required=True) - _state_group.add_argument('-r', '--record', help='Record UID or name') - _state_group.add_argument('-f', '--flow-uid', help='Flow UID of active workflow') + parser.add_argument('record', help='Record UID or name') parser.add_argument('--format', dest='format', action='store', choices=['table', 'json'], default='table', help='Output format') @@ -47,22 +45,14 @@ def get_parser(self): return WorkflowGetStateCommand.parser def execute(self, params: KeeperParams, **kwargs): - record_uid = kwargs.get('record') - flow_uid = kwargs.get('flow_uid') + record_uid, record = RecordResolver.resolve(params, kwargs.get('record')) + if is_workflow_exempt(params, record_uid): + print_exempt_message(kwargs.get('format', 'table')) + return state = workflow_pb2.WorkflowState() - if flow_uid: - try: - state.flowUid = utils.base64_url_decode(flow_uid) - except Exception: - raise CommandError('', f'Invalid flow UID: "{flow_uid}"') - else: - record_uid, record = RecordResolver.resolve(params, record_uid) - if is_workflow_exempt(params, record_uid): - print_exempt_message(kwargs.get('format', 'table')) - return - record_uid_bytes = utils.base64_url_decode(record_uid) - state.resource.CopyFrom(ProtobufRefBuilder.record_ref(record_uid_bytes, record.title)) + record_uid_bytes = utils.base64_url_decode(record_uid) + state.resource.CopyFrom(ProtobufRefBuilder.record_ref(record_uid_bytes, record.title)) try: response = _post_request_to_router( @@ -204,7 +194,6 @@ def _print_table(params, response): record_name = RecordResolver.resolve_name(params, wf.resource) record_uid = utils.base64_url_encode(wf.resource.value) if wf.resource.value else '' flow_uid = utils.base64_url_encode(wf.flowUid) if wf.flowUid else '' - conditions = WorkflowFormatter.format_conditions(wf.status.conditions) if wf.status.conditions else '' checked_out_by = wf.status.checkedOutBy or '' started = _fmt_ts_or_empty(wf.status.startedOn) expires = _fmt_ts_or_empty(wf.status.expiresOn) @@ -214,10 +203,10 @@ def _print_table(params, response): a.user if a.user else RecordResolver.resolve_user(params, a.userId) for a in wf.status.approvedBy ] - approved_by = ', '.join(approved_names) - rows.append([stage, record_name, record_uid, flow_uid, checked_out_by, approved_by, started, expires, conditions]) + approved_by = '\n'.join(approved_names) + rows.append([stage, record_name, record_uid, flow_uid, checked_out_by, approved_by, started, expires]) - headers = ['Stage', 'Record Name', 'Record UID', 'Flow UID', 'Checked Out By', 'Approved By', 'Started', 'Expires', 'Conditions'] + headers = ['Stage', 'Record Name', 'Record UID', 'Flow UID', 'Checked Out By', 'Approved By', 'Started', 'Expires'] print() dump_report_data(rows, headers=headers) print() diff --git a/requirements.txt b/requirements.txt index f67f0f6c7..b96079d0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,5 +26,6 @@ winrt-runtime; sys_platform == "win32" winrt-Windows.Foundation; sys_platform == "win32" winrt-Windows.Security.Credentials.UI; sys_platform == "win32" googleapis-common-protos +tzlocal>=5.0 keeper-mlkem; python_version>='3.11' textual; python_version>='3.9' \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 3d6801983..b634aedb4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,6 +53,7 @@ install_requires = keeper_pam_webrtc_rs>=2.1.17 pydantic>=2.6.4 fpdf2>=2.8.3 + tzlocal>=5.0 cbor2; sys_platform == "darwin" and python_version>='3.10' pyobjc-framework-LocalAuthentication; sys_platform == "darwin" and python_version>='3.10' winrt-runtime; sys_platform == "win32" and python_version>='3.10' From abebbd13f59d96c8207866cbfd7a4350490431d1 Mon Sep 17 00:00:00 2001 From: adeshmukh-ks Date: Sat, 9 May 2026 04:11:04 +0530 Subject: [PATCH 16/26] Added validations for PAM record update fix (#2032) --- keepercommander/commands/pam_import/base.py | 4 ++-- keepercommander/commands/record.py | 18 ++++++++++++++++-- keepercommander/commands/record_edit.py | 2 +- keepercommander/vault.py | 2 +- keepercommander/vault_extensions.py | 4 ++-- 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/keepercommander/commands/pam_import/base.py b/keepercommander/commands/pam_import/base.py index e5cf37835..16cbf2214 100644 --- a/keepercommander/commands/pam_import/base.py +++ b/keepercommander/commands/pam_import/base.py @@ -2226,7 +2226,7 @@ def load(cls, data: Union[str, dict]): # disable_dynamic_resizing ? "" : "display-update" val = utils.value_to_boolean(data.get("disable_dynamic_resizing", None)) - if val is not True: obj.resizeMethod = "display-update" + obj.resizeMethod = "" if val is True else "display-update" return obj @@ -2272,7 +2272,7 @@ def to_record_dict(self): kvp["enableWallpaper"] = self.enableWallpaper # populated on load - "resizeMethod": disable_dynamic_resizing ? "" : "display-update" - if str(self.resizeMethod) == "display-update": + if isinstance(self.resizeMethod, str): kvp["resizeMethod"] = self.resizeMethod if isinstance(self.sftp, SFTPConnectionSettings): diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 2091c55c5..972680f41 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -732,14 +732,28 @@ def _format_expiration(expiration_value): if total_matches == 1: # If only one match, display it directly if len(matched_records) == 1: - return self.execute(params, uid=matched_records[0].record_uid, format=fmt, unmask=kwargs.get('unmask'), legacy=kwargs.get('legacy')) + return self.execute( + params, + uid=matched_records[0].record_uid, + format=fmt, + unmask=kwargs.get('unmask'), + legacy=kwargs.get('legacy'), + include_dag=kwargs.get('include_dag') + ) elif len(matched_folders) == 1: uid = matched_folders[0].uid f = params.folder_cache[uid] sf_uid = f.uid if isinstance(f, subfolder.SharedFolderNode) else \ (f.shared_folder_uid if isinstance(f, subfolder.SharedFolderFolderNode) else None) if sf_uid and api.is_shared_folder(params, sf_uid): - return self.execute(params, uid=sf_uid, format=fmt, unmask=kwargs.get('unmask'), legacy=kwargs.get('legacy')) + return self.execute( + params, + uid=sf_uid, + format=fmt, + unmask=kwargs.get('unmask'), + legacy=kwargs.get('legacy'), + include_dag=kwargs.get('include_dag') + ) if fmt == 'json': fo = { 'folder_uid': f.uid, diff --git a/keepercommander/commands/record_edit.py b/keepercommander/commands/record_edit.py index bb94eb341..6889f83f5 100644 --- a/keepercommander/commands/record_edit.py +++ b/keepercommander/commands/record_edit.py @@ -670,7 +670,7 @@ def assign_typed_fields(self, record, fields): value = vault.TypedField.import_schedule_field(parsed_field.value) else: self.on_warning(f'Unsupported field type: {record_field.type}') - if value: + if value is not None: if isinstance(value, list): record_field.value.clear() record_field.value.extend(value) diff --git a/keepercommander/vault.py b/keepercommander/vault.py index bbe688af7..0fa854f82 100644 --- a/keepercommander/vault.py +++ b/keepercommander/vault.py @@ -327,7 +327,7 @@ def new_field(cls, field_type, field_value, field_label=None): f_type = sanitize_str_field_value(field_type) or 'text' # TODO check field value if not isinstance(field_value, list): - if field_value: + if field_value is not None: field_value = [field_value] else: field_value = [] diff --git a/keepercommander/vault_extensions.py b/keepercommander/vault_extensions.py index e974755cd..7666a9a93 100644 --- a/keepercommander/vault_extensions.py +++ b/keepercommander/vault_extensions.py @@ -401,7 +401,7 @@ def extract_typed_field(field): # type: (vault.TypedField) -> dict if rt.type in record_types.FieldTypes: ft = record_types.FieldTypes[rt.type] default_value = ft.value - if field.value: + if field.value is not None: values = field.value if isinstance(values, (str, int, dict)): values = [values] @@ -410,7 +410,7 @@ def extract_typed_field(field): # type: (vault.TypedField) -> dict if isinstance(values, list): for value in values: - if not value: + if value is None: continue if default_value is not None: if not isinstance(value, type(default_value)): From 4ddace70a46214f8dc773fb01eb5edb0248cfa47 Mon Sep 17 00:00:00 2001 From: amangalampalli-ks Date: Mon, 11 May 2026 16:14:52 +0530 Subject: [PATCH 17/26] Implement Enforcement of Record Password and Types policies (#2029) * Implement Enforcement of Record Password and Types policies * Improve password complexity and record type enforcement validation * Add --effective flag in rti for list of allowed records. --- keepercommander/commands/enterprise.py | 20 +++ keepercommander/commands/record_edit.py | 42 +++-- keepercommander/commands/recordv3.py | 44 +++++ keepercommander/enforcement.py | 205 +++++++++++++++++++++++- keepercommander/generator.py | 20 +++ 5 files changed, 320 insertions(+), 11 deletions(-) diff --git a/keepercommander/commands/enterprise.py b/keepercommander/commands/enterprise.py index 7b681fea7..7be8fe222 100644 --- a/keepercommander/commands/enterprise.py +++ b/keepercommander/commands/enterprise.py @@ -809,6 +809,10 @@ def tree_node(node): role_ids.update(team_roles[team_uid]) if column == 'role_count': row.append(len(role_ids)) + elif kwargs.get('format') == 'json': + role_info = [{'role_id': rid, 'role_name': roles[rid]['name']} + for rid in role_ids if rid in roles] + row.append(role_info) else: role_names = [roles[role_id]['name'] for role_id in role_ids if role_id in roles] row.append(role_names) @@ -987,6 +991,22 @@ def tree_node(node): enforcement_type = constants.ENFORCEMENTS.get(k) if enforcement_type == 'two_factor_duration': formatted_enforcements[k] = constants.format_two_factor_duration(v) + elif enforcement_type == 'record_types': + try: + rto = v if isinstance(v, dict) else json.loads(v) + if params.record_type_cache: + record_types = [] + for record_type_id in itertools.chain(rto.get('std') or [], rto.get('ent') or []): + if record_type_id in params.record_type_cache: + rtc = json.loads(params.record_type_cache[record_type_id]) + if '$id' in rtc: + record_types.append(rtc['$id']) + formatted_enforcements[k] = ', '.join(record_types) + else: + formatted_enforcements[k] = v + except (json.JSONDecodeError, TypeError, KeyError, ValueError) as e: + logging.debug('Failed to format record_types enforcement %s: %s', k, e) + formatted_enforcements[k] = v else: formatted_enforcements[k] = v row.append(formatted_enforcements) diff --git a/keepercommander/commands/record_edit.py b/keepercommander/commands/record_edit.py index 6889f83f5..aceb2c301 100644 --- a/keepercommander/commands/record_edit.py +++ b/keepercommander/commands/record_edit.py @@ -28,6 +28,7 @@ from .. import api, utils, vault, record_types, generator, crypto, attachment, record_facades, record_management from ..breachwatch import BreachWatch from ..commands import recordv3 +from ..enforcement import PasswordComplexityEnforcer, RecordTypeEnforcer from ..error import CommandError from ..params import KeeperParams, LAST_RECORD_UID from ..subfolder import try_resolve_path, find_folders, get_folder_path @@ -230,6 +231,7 @@ class RecordEditMixin: def __init__(self): self.warnings = [] + self._password_policy = None # type: Optional[Dict[str, Any]] def on_warning(self, message): if message: @@ -390,7 +392,7 @@ def generate_key_pair(key_type, passphrase): # type: (str, str) -> dict } @staticmethod - def generate_password(parameters=None): # type: (Optional[Sequence[str]]) -> str + def generate_password(parameters=None, policy=None): # type: (Optional[Sequence[str]], Optional[dict]) -> str if isinstance(parameters, (tuple, list, set)): algorithm = next((x for x in parameters if x in ('rand', 'dice', 'crypto')), 'rand') length = next((x for x in parameters if x.isnumeric()), None) @@ -415,14 +417,17 @@ def generate_password(parameters=None): # type: (Optional[Sequence[str]]) -> s length = 5 gen = generator.DicewarePasswordGenerator(length) else: - if isinstance(length, int): - if length < 4: - length = 4 - elif length > 200: - length = 200 + if policy: + gen = generator.KeeperPasswordGenerator.create_from_policy(policy, length_override=length) else: - length = 20 - gen = generator.KeeperPasswordGenerator(length=length) + if isinstance(length, int): + if length < 4: + length = 4 + elif length > 200: + length = 200 + else: + length = 20 + gen = generator.KeeperPasswordGenerator(length=length) return gen.generate() @staticmethod @@ -599,7 +604,7 @@ def assign_typed_fields(self, record, fields): action_params = [] if self.is_generate_value(parsed_field.value, action_params): if record_field.type == 'password': - value = self.generate_password(action_params) + value = self.generate_password(action_params, policy=self._password_policy) elif record_field.type in ('oneTimeCode', 'otp'): value = self.generate_totp_url() elif record_field.type in ('keyPair', 'privateKey'): @@ -833,6 +838,9 @@ def execute(self, params, **kwargs): if not record_type: raise CommandError('record-add', 'Record type parameter is required.') + RecordTypeEnforcer.enforce(params, record_type, 'record-add') + self._password_policy = PasswordComplexityEnforcer.get_policy(params) + fields = kwargs.get('fields', []) # Filter out empty strings that might be introduced by copy-paste or line continuation issues fields = [field.strip() for field in fields if field.strip()] @@ -879,6 +887,12 @@ def execute(self, params, **kwargs): record.title = title record.notes = self.validate_notes(kwargs.get('notes') or '') + pw_failures = PasswordComplexityEnforcer.validate_record(params, record) + for f in pw_failures: + self.on_warning(f) + if pw_failures and not kwargs.get('force'): + self.on_warning('Use --force to bypass password policy warnings.') + ignore_warnings = kwargs.get('force') is True if len(self.warnings) > 0: for warning in self.warnings: @@ -1243,12 +1257,15 @@ def execute(self, params, **kwargs): else: record_fields.append(parsed_field) + self._password_policy = PasswordComplexityEnforcer.get_policy(params) + if isinstance(record, vault.PasswordRecord): raise CommandError('record-update', 'Legacy record type is not supported. Convert the record to login record type.') # self.assign_legacy_fields(record, record_fields) elif isinstance(record, vault.TypedRecord): record_type = kwargs.get('record_type') if record_type: + RecordTypeEnforcer.enforce(params, record_type, 'record-update') record.type_name = record_type rt_fields = self.get_record_type_fields(params, record_type) if not rt_fields: @@ -1258,6 +1275,13 @@ def execute(self, params, **kwargs): else: raise CommandError('record-update', f'Record \"{record_name}\" can not be edited.') + if isinstance(record, vault.TypedRecord): + pw_failures = PasswordComplexityEnforcer.validate_record(params, record) + for f in pw_failures: + self.on_warning(f) + if pw_failures and not kwargs.get('force'): + self.on_warning('Use --force to bypass password policy warnings.') + ignore_warnings = kwargs.get('force') is True if len(self.warnings) > 0: for warning in self.warnings: diff --git a/keepercommander/commands/recordv3.py b/keepercommander/commands/recordv3.py index 0e663b79d..2ccc6fd41 100644 --- a/keepercommander/commands/recordv3.py +++ b/keepercommander/commands/recordv3.py @@ -26,6 +26,7 @@ from .. import api, crypto, generator from .. import recordv3, loginv3 from ..display import bcolors +from ..enforcement import PasswordComplexityEnforcer, RecordTypeEnforcer from ..error import CommandError from ..params import LAST_RECORD_UID from ..proto import record_pb2 as records @@ -116,6 +117,8 @@ def register_command_info(aliases, command_info): # command_group.add_argument('-lc', '--category', dest='category', action='store', default=None, const = '*', nargs='?', help='list categories or record types in a category') command_group.add_argument('-lr', '--list-record', dest='record_name', action='store', default=None, const = '*', nargs='?', help='list record type by name or use * to list all') command_group.add_argument('-lf', '--list-field', type=str, dest='field_name', action='store', default=None, help='list field type by name or use * to list all') +record_type_info_parser.add_argument('-ef', '--effective', dest='effective', action='store_true', + help='filter -lr results to record types allowed by your enterprise role policy') record_type_parser = argparse.ArgumentParser(prog='record-type', description='Add, modify or delete record type definition') @@ -223,6 +226,8 @@ def execute(self, params, **kwargs): # ' - to get list of all available record types use: record-type-info -lr' + bcolors.ENDC) raise CommandError('add', f'Record type definition not found for type: {rt} - to get list of all available record types use: record-type-info -lr') + RecordTypeEnforcer.enforce(params, rt, 'add') + data_json = str(kwargs['data']).strip() if 'data' in kwargs and kwargs['data'] else None data_file = str(kwargs['data_file']).strip() if 'data_file' in kwargs and kwargs['data_file'] else None data_opts = recordv3.RecordV3.convert_options_to_json(params, '', rt_def, kwargs) if rt_def else None @@ -405,6 +410,16 @@ def GCM_TAG_LEN(): return 16 if password: data = recordv3.RecordV3.update_password(password, data, recordv3.RecordV3.get_record_type_definition(params, data)) + pw_failures = PasswordComplexityEnforcer.validate_record(params, data) + if pw_failures: + for f in pw_failures: + logging.warning(bcolors.WARNING + f + bcolors.ENDC) + if not kwargs.get('force'): + raise CommandError( + 'add', + 'Password does not meet enterprise complexity policy. ' + 'Pass --force to bypass these warnings.') + record_uid = api.generate_record_uid() logging.debug('Generated Record UID: %s', record_uid) record = { @@ -602,6 +617,9 @@ def execute(self, params, **kwargs): ' - to get list of all available record types use: record-type-info -lr' + bcolors.ENDC) return + if rt and rt != rt_name: + RecordTypeEnforcer.enforce(params, rt, 'edit') + data_json = str(kwargs['data']).strip() if 'data' in kwargs and kwargs['data'] else None data_file = str(kwargs['data_file']).strip() if 'data_file' in kwargs and kwargs['data_file'] else None data_opts = recordv3.RecordV3.convert_options_to_json(params, record_data, rt_def, kwargs) if rt_def else None @@ -642,6 +660,16 @@ def execute(self, params, **kwargs): record.password = password data = recordv3.RecordV3.update_password(password, data, recordv3.RecordV3.get_record_type_definition(params, data)) + pw_failures = PasswordComplexityEnforcer.validate_record(params, data) + if pw_failures: + for f in pw_failures: + logging.warning(bcolors.WARNING + f + bcolors.ENDC) + if not kwargs.get('force'): + raise CommandError( + 'edit', + 'Password does not meet enterprise complexity policy. ' + 'Pass --force to bypass these warnings.') + data_dict = json.loads(data) changed = rdata_dict != data_dict # changed = json.dumps(rdata_dict, sort_keys=True) != json.dumps(data_dict, sort_keys=True) @@ -767,6 +795,17 @@ class RecordTypeInfo(Command): def get_parser(self): return record_type_info_parser + @staticmethod + def _type_name(params, rtid): + entry = params.record_type_cache.get(rtid) if params.record_type_cache else None + if not entry: + return None + try: + schema = json.loads(entry) if isinstance(entry, str) else entry + except (json.JSONDecodeError, TypeError): + return None + return schema.get('$id') if isinstance(schema, dict) else None + @staticmethod def resolve_record_type(params, record_type_id): record_type_info = {} @@ -908,6 +947,11 @@ def execute(self, params, **kwargs): has_categories_only = False row_data = RecordTypeInfo.resolve_record_types(params, lrid) + if kwargs.get('effective'): + restricted = RecordTypeEnforcer.get_restricted_record_types(params) + if restricted: + row_data = [r for r in row_data if RecordTypeInfo._type_name(params, r[2]) not in restricted] + rows = [] for count, cat, rtid, content in row_data: record_type = { diff --git a/keepercommander/enforcement.py b/keepercommander/enforcement.py index 2ab093115..6d7e802ca 100644 --- a/keepercommander/enforcement.py +++ b/keepercommander/enforcement.py @@ -9,18 +9,55 @@ # Contact: ops@keepersecurity.com # +import itertools import json import logging import getpass import threading from datetime import datetime, timedelta -from typing import Tuple +from typing import Tuple, Optional, List, Dict, Any, Set from . import api, utils, crypto from .proto import APIRequest_pb2 from .display import bcolors from .params import KeeperParams -from .error import KeeperApiError +from .error import KeeperApiError, CommandError + + +def _find_enforcement_value(enforcements, key): + # type: (Any, str) -> Any + """Return raw enforcement value for `key` across known layouts, or None.""" + if not isinstance(enforcements, dict): + return None + for bucket in ('jsons', 'strings'): + items = enforcements.get(bucket) + if isinstance(items, list): + for item in items: + if isinstance(item, dict) and item.get('key') == key: + return item.get('value') + return enforcements.get(key) if key in enforcements else None + + +def _coerce_int(value): + # type: (Any) -> Optional[int] + """Coerce enforcement values to int. + + Server-side enforcement payloads sometimes serialize numeric fields as + strings. Returns None if the value cannot be safely interpreted as int. + `bool` is intentionally rejected (it's a subclass of int). + """ + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, str): + s = value.strip() + if s.lstrip('-').isdigit(): + try: + return int(s) + except ValueError: + return None + return None class MasterPasswordReentryEnforcer: @@ -354,3 +391,167 @@ def check_and_enforce(cls, params: KeeperParams, operation: str = "record_level" # If no authentication methods are available, policy does not apply logging.info("Master Password Re-authentication skipped: Master Password or alternate SSO Master Password not available - policy does not apply") return True + + +class PasswordComplexityEnforcer: + """Enforces GENERATED_PASSWORD_COMPLEXITY policy on record passwords.""" + + _POLICY_KEY = 'generated_password_complexity' + + _CHAR_CLASSES = ( + ('lower-use', 'lower-min', 'lowercase', str.islower), + ('upper-use', 'upper-min', 'uppercase', str.isupper), + ('digit-use', 'digit-min', 'digit', str.isdigit), + ) + + @classmethod + def get_policy(cls, params): # type: (KeeperParams) -> Optional[Dict[str, Any]] + if not params or not params.enforcements: + return None + raw = _find_enforcement_value(params.enforcements, cls._POLICY_KEY) + if raw is None: + return None + try: + rules = json.loads(raw) if isinstance(raw, str) else raw + except (json.JSONDecodeError, TypeError): + logging.debug('Failed to parse %s enforcement', cls._POLICY_KEY) + return None + if isinstance(rules, list) and rules and isinstance(rules[0], dict): + return rules[0] + if isinstance(rules, dict): + return rules + return None + + @classmethod + def validate_password(cls, password, policy): # type: (str, Dict[str, Any]) -> List[str] + failures = [] # type: List[str] + if not policy or not isinstance(password, str) or not password: + return failures + + min_length = _coerce_int(policy.get('length')) + if min_length is not None and min_length > 0 and len(password) < min_length: + failures.append( + f'Password must be at least {min_length} characters (got {len(password)}).') + + for use_key, min_key, label, predicate in cls._CHAR_CLASSES: + if not policy.get(use_key): + continue + required = _coerce_int(policy.get(min_key, 1)) + if required is None or required <= 0: + continue + count = sum(1 for c in password if predicate(c)) + if count < required: + failures.append( + f'Password must contain at least {required} {label} character(s) (got {count}).') + + if policy.get('special-use'): + required = _coerce_int(policy.get('special-min', 1)) + if required is not None and required > 0: + allowed = policy.get('special') or '' + count = (sum(1 for c in password if c in allowed) if allowed + else sum(1 for c in password if not c.isalnum())) + if count < required: + failures.append( + f'Password must contain at least {required} special character(s) (got {count}).') + + return failures + + @classmethod + def validate_record(cls, params, source): # type: (KeeperParams, Any) -> List[str] + """Return policy violations across all password fields in `source`. + + `source` may be a vault.TypedRecord, a v3 record-data dict, or a JSON + string of that dict. Returns [] when no policy applies or no password + fields are present. + """ + policy = cls.get_policy(params) + if not policy: + return [] + failures = [] # type: List[str] + for pw in cls._extract_passwords(source): + failures.extend(cls.validate_password(pw, policy)) + return failures + + @staticmethod + def _extract_passwords(source): # type: (Any) -> List[str] + if source is None: + return [] + passwords = [] # type: List[str] + if hasattr(source, 'fields') and hasattr(source, 'custom'): + for fld in itertools.chain(source.fields or [], source.custom or []): + if getattr(fld, 'type', None) == 'password': + val = getattr(fld, 'value', None) + if isinstance(val, list): + passwords.extend(v for v in val if isinstance(v, str) and v) + return passwords + + data = source + if isinstance(data, str): + try: + data = json.loads(data) + except (json.JSONDecodeError, TypeError): + return [] + if not isinstance(data, dict): + return [] + for fld in data.get('fields') or []: + if isinstance(fld, dict) and fld.get('type') == 'password': + val = fld.get('value') + if isinstance(val, list): + passwords.extend(v for v in val if isinstance(v, str) and v) + return passwords + + +class RecordTypeEnforcer: + """Enforces RESTRICT_RECORD_TYPES policy on record creation/update. + + Stored value is a JSON object {"std": [, ...], "ent": [, ...]} of + record-type IDs the user is *blocked* from creating. IDs are resolved to + type names via `params.record_type_cache`. + """ + + _POLICY_KEY = 'restrict_record_types' + + @classmethod + def get_restricted_record_types(cls, params): # type: (KeeperParams) -> Optional[Set[str]] + if not params or not params.enforcements: + return None + raw = _find_enforcement_value(params.enforcements, cls._POLICY_KEY) + if raw is None: + return None + try: + policy = raw if isinstance(raw, dict) else json.loads(raw) + except (json.JSONDecodeError, TypeError): + logging.debug('Failed to parse %s enforcement', cls._POLICY_KEY) + return None + if not isinstance(policy, dict): + return None + + cache = getattr(params, 'record_type_cache', None) or {} + restricted = set() # type: Set[str] + for rt_id in list(policy.get('std') or []) + list(policy.get('ent') or []): + entry = cache.get(rt_id) + if not entry: + continue + try: + schema = json.loads(entry) if isinstance(entry, str) else entry + except (json.JSONDecodeError, TypeError): + continue + if isinstance(schema, dict): + name = schema.get('$id') + if name: + restricted.add(name) + return restricted + + @classmethod + def enforce(cls, params, record_type, command): + # type: (KeeperParams, Optional[str], str) -> None + """Raise CommandError when `record_type` is blocked by policy.""" + if not record_type: + return + restricted = cls.get_restricted_record_types(params) + if not restricted or record_type not in restricted: + return + raise CommandError( + command, + f'Record type "{record_type}" is restricted by your enterprise role policy ' + f'and cannot be created.') diff --git a/keepercommander/generator.py b/keepercommander/generator.py index 0fc6c4b1e..e566d0187 100644 --- a/keepercommander/generator.py +++ b/keepercommander/generator.py @@ -132,6 +132,26 @@ def create_from_rules(cls, rule_string: str, length: Optional[int] = None, length = sum(rule_list) if length is None else length return cls(length=length, caps=upper, lower=lower, digits=digits, symbols=symbols, special_characters=special_characters) + @classmethod + def create_from_policy(cls, policy, length_override=None): + # type: (dict, Optional[int]) -> KeeperPasswordGenerator + """Create a generator that satisfies the given password complexity enforcement policy.""" + pw_length = length_override or policy.get('length') or DEFAULT_PASSWORD_LENGTH + lower_min = policy.get('lower-min', 0) if policy.get('lower-use') else None + upper_min = policy.get('upper-min', 0) if policy.get('upper-use') else None + digit_min = policy.get('digit-min', 0) if policy.get('digit-use') else None + special_min = policy.get('special-min', 0) if policy.get('special-use') else None + special_chars = policy.get('special', PW_SPECIAL_CHARACTERS) or PW_SPECIAL_CHARACTERS + + return cls( + length=pw_length, + lower=lower_min, + caps=upper_min, + digits=digit_min, + symbols=special_min, + special_characters=special_chars + ) + class DicewarePasswordGenerator(PasswordGenerator): def __init__(self, number_of_rolls, word_list_file=None, delimiter=' '): # type: (int, Optional[str], str) -> None From 9c8ec6f852bd35a0904a6c18d923c868f9d3ccd8 Mon Sep 17 00:00:00 2001 From: pvagare-ks Date: Tue, 12 May 2026 23:01:03 +0530 Subject: [PATCH 18/26] rename "shared-manager" role label to "share-manager" (#2041) * rename "shared-manager" role label to "share-manager" (#2040) * claude review changes * added share-manager mapping --- KEEPER_DRIVE_COMMANDS.md | 12 +++++------- .../commands/keeper_drive/display_commands.py | 2 +- keepercommander/commands/keeper_drive/helpers.py | 16 ++++++++-------- keepercommander/commands/keeper_drive/parsers.py | 6 +++--- .../commands/keeper_drive/sharing_commands.py | 2 +- keepercommander/keeper_drive/permissions.py | 3 ++- unit-tests/test_keeper_drive.py | 6 +++--- 7 files changed, 23 insertions(+), 24 deletions(-) diff --git a/KEEPER_DRIVE_COMMANDS.md b/KEEPER_DRIVE_COMMANDS.md index 06bba3688..6d04d717b 100644 --- a/KEEPER_DRIVE_COMMANDS.md +++ b/KEEPER_DRIVE_COMMANDS.md @@ -12,7 +12,7 @@ To get help on a particular command, run: | Command | Description | | ------------------------ | ------------------------------------------------------------------- | | `[kd-mkdir]` | Create a new KeeperDrive folder | -| `[kd-rndir]` | Rename a folder, change its color, or update permission inheritance | +| `[kd-rndir]` | Rename a folder or change its color | | `[kd-list]` | List KeeperDrive folders and records | | `[kd-rmdir]` | Remove one or more KeeperDrive folders | | `[kd-share-folder]` | Grant or remove a user's access to a folder | @@ -34,7 +34,7 @@ To get help on a particular command, run: | Role | Description | | ----------------------- | --------------------------------------------- | | `viewer` | Read-only access | -| `shared-manager` | Can manage access grants | +| `share-manager` | Can manage access grants | | `content-manager` | Can add/edit records | | `content-share-manager` | Can add/remove/edit records and manage access | | `full-manager` | Full control | @@ -76,7 +76,7 @@ kd-mkdir "Reports//2026" **Command:** `kd-rndir` -**Detail:** Rename a folder, change its color, or update its permission-inheritance setting. At least one of `--name`, `--color`, `--inherit`, or `--no-inherit` is required. +**Detail:** Rename a folder or change its color. At least one of `--name` or `--color` is required. **Parameters:** @@ -88,9 +88,7 @@ Folder UID, name, or path `--color ` New color: `none` `red` `orange` `yellow` `green` `blue` `gray` -`--inherit` Enable permission inheritance from parent folder -`--no-inherit` Disable permission inheritance from parent folder `-q`, `--quiet` Suppress confirmation message @@ -99,13 +97,13 @@ Folder UID, name, or path ``` kd-rndir "Old Name" --name "New Name" kd-rndir abc123 --color blue -kd-rndir abc123 --name "Archive Q4" --color gray --inherit +kd-rndir abc123 --name "Archive Q4" --color gray kd-rndir abc123 --name "Finance" -q ``` 1. Rename a folder by its current name 2. Change a folder's color using its UID -3. Rename, recolor, and enable permission inheritance in one command +3. Rename and recolor a folder in one command 4. Rename a folder silently with no confirmation output --- diff --git a/keepercommander/commands/keeper_drive/display_commands.py b/keepercommander/commands/keeper_drive/display_commands.py index a0d5874b8..3599530f1 100644 --- a/keepercommander/commands/keeper_drive/display_commands.py +++ b/keepercommander/commands/keeper_drive/display_commands.py @@ -351,7 +351,7 @@ def _folder_permission_summary(accessor): Uses the server-supplied ``role`` (an ``AccessRoleType`` enum name) and renders it as a canonical KeeperDrive role label (e.g. ``full-manager``, - ``shared-manager``, ``viewer``). Falls back to permission-flag based + ``share-manager``, ``viewer``). Falls back to permission-flag based inference for legacy access rows that omit ``role``. """ if not isinstance(accessor, dict): diff --git a/keepercommander/commands/keeper_drive/helpers.py b/keepercommander/commands/keeper_drive/helpers.py index e2c2ade47..a8f30412d 100644 --- a/keepercommander/commands/keeper_drive/helpers.py +++ b/keepercommander/commands/keeper_drive/helpers.py @@ -348,14 +348,14 @@ def infer_role(access): Follows the official permission matrix:: - full-manager > content-share-manager > shared-manager > + full-manager > content-share-manager > share-manager > content-manager > viewer > contributor > requestor > navigator - The distinguishing trait between ``shared-manager`` and + The distinguishing trait between ``share-manager`` and ``content-share-manager`` is the ability to *edit* records: both roles grant ``can_update_access`` + ``can_approve_access``, but only ``content-share-manager`` also grants ``can_edit``. Without that check - every shared-manager would be reported as content-share-manager. + every share-manager would be reported as content-share-manager. """ get = access.get if get('can_change_ownership') or get('can_delete'): @@ -363,9 +363,9 @@ def infer_role(access): if get('can_update_access') and get('can_approve_access') and get('can_edit'): return 'content-share-manager' if get('can_update_access') and get('can_approve_access'): - return 'shared-manager' + return 'share-manager' if get('can_update_access'): - return 'shared-manager' + return 'share-manager' if get('can_edit'): return 'content-manager' if get('can_view') and get('can_list_access'): @@ -395,7 +395,7 @@ def role_label(access_role_type): 'NAVIGATOR': 'contributor', 'REQUESTOR': 'contributor', 'VIEWER': 'viewer', - 'SHARED_MANAGER': 'shared-manager', + 'SHARED_MANAGER': 'share-manager', 'CONTENT_MANAGER': 'content-manager', 'CONTENT_SHARE_MANAGER': 'content-share-manager', 'MANAGER': 'full-manager', @@ -408,7 +408,7 @@ def format_role_display(role): Accepts either the proto enum name (``'SHARED_MANAGER'``) or its integer value, and returns the canonical hyphenated lowercase label used across - KeeperDrive (``'shared-manager'``, ``'full-manager'``, ``'viewer'`` …). + KeeperDrive (``'share-manager'``, ``'full-manager'``, ``'viewer'`` …). Falls back to a best-effort lowercase form when the role is unknown. """ if role is None or role == '': @@ -431,7 +431,7 @@ def get_access_role_label(access): Prefers the stored ``access_role_type`` (proto enum int) when available; otherwise falls back to inferring the role from permission flags. The returned label uses the canonical hyphenated lowercase KeeperDrive form - (e.g. ``'full-manager'``, ``'shared-manager'``, ``'viewer'``). + (e.g. ``'full-manager'``, ``'share-manager'``, ``'viewer'``). """ role_int = access.get('access_role_type') if role_int is not None: diff --git a/keepercommander/commands/keeper_drive/parsers.py b/keepercommander/commands/keeper_drive/parsers.py index 92375956a..aa5ebaa33 100644 --- a/keepercommander/commands/keeper_drive/parsers.py +++ b/keepercommander/commands/keeper_drive/parsers.py @@ -89,7 +89,7 @@ def _make_parser(prog, description): keeper_drive_share_folder_parser.add_argument( '-r', '--role', dest='role', choices=[ - 'viewer', 'shared-manager', + 'viewer', 'share-manager', 'content-manager', 'content-share-manager', 'full-manager', ], default='viewer', @@ -206,7 +206,7 @@ def _make_parser(prog, description): keeper_drive_share_record_parser.add_argument( '-r', '--role', dest='role', choices=[ - 'viewer', 'shared-manager', + 'viewer', 'share-manager', 'content-manager', 'content-share-manager', 'full-manager', ], help='permission role. Required for grant/update actions') @@ -243,7 +243,7 @@ def _make_parser(prog, description): keeper_drive_record_permission_parser.add_argument( '-r', '--role', dest='role', choices=[ - 'viewer', 'shared-manager', + 'viewer', 'share-manager', 'content-manager', 'content-share-manager', 'full-manager', ], help='Permission role to grant, or filter for revoke') diff --git a/keepercommander/commands/keeper_drive/sharing_commands.py b/keepercommander/commands/keeper_drive/sharing_commands.py index a90972f90..5b4ed083e 100644 --- a/keepercommander/commands/keeper_drive/sharing_commands.py +++ b/keepercommander/commands/keeper_drive/sharing_commands.py @@ -242,7 +242,7 @@ class KeeperDriveRecordPermissionCommand(Command): """Bulk-update sharing permissions on records within a KeeperDrive folder.""" _ROLE_NAMES = [ - 'viewer', 'shared-manager', + 'viewer', 'share-manager', 'content-manager', 'content-share-manager', 'full-manager', ] diff --git a/keepercommander/keeper_drive/permissions.py b/keepercommander/keeper_drive/permissions.py index f50cff65a..8e5e84f73 100644 --- a/keepercommander/keeper_drive/permissions.py +++ b/keepercommander/keeper_drive/permissions.py @@ -115,7 +115,8 @@ class SetBooleanValue: 'requestor': 1, 'viewer': 2, 'shared_manager': 3, - 'shared-manager': 3, + 'share-manager': 3, + 'share_manager': 3, 'content_manager': 4, 'content-manager': 4, 'content_share_manager': 5, diff --git a/unit-tests/test_keeper_drive.py b/unit-tests/test_keeper_drive.py index fb429c804..961c09720 100644 --- a/unit-tests/test_keeper_drive.py +++ b/unit-tests/test_keeper_drive.py @@ -102,7 +102,7 @@ def test_infer_role(self): from keepercommander.commands.keeper_drive.helpers import infer_role self.assertEqual(infer_role({'can_change_ownership': True}), 'full-manager') # ``can_update_access`` + ``can_approve_access`` alone (no edit) is - # ``shared-manager``; promotion to ``content-share-manager`` requires + # ``share-manager``; promotion to ``content-share-manager`` requires # ``can_edit`` per the v3 permission matrix. self.assertEqual( infer_role({'can_update_access': True, 'can_approve_access': True, @@ -111,9 +111,9 @@ def test_infer_role(self): ) self.assertEqual( infer_role({'can_update_access': True, 'can_approve_access': True}), - 'shared-manager', + 'share-manager', ) - self.assertEqual(infer_role({'can_update_access': True}), 'shared-manager') + self.assertEqual(infer_role({'can_update_access': True}), 'share-manager') self.assertEqual(infer_role({'can_edit': True}), 'content-manager') self.assertEqual(infer_role({'can_view': True, 'can_list_access': True}), 'viewer') self.assertEqual(infer_role({'can_view_title': True}), 'requestor') From 33ce576c3a9cdb56ec5bb45905605000b5a04973 Mon Sep 17 00:00:00 2001 From: John Walstra Date: Tue, 12 May 2026 16:19:33 -0500 Subject: [PATCH 19/26] Add `saas` profile to `pam rotation edit` command. --- keepercommander/commands/discoveryrotation.py | 17 +++++++++++++++-- keepercommander/commands/pam_saas/set.py | 3 ++- keepercommander/sync_down.py | 16 ++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index 4d53a9d9a..621b40b05 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -394,9 +394,11 @@ class PAMCreateRecordRotationCommand(Command): parser.add_argument('--iam-aad-config', '-iac', dest='iam_aad_config_uid', action='store', help='UID of a PAM Configuration. Used for an IAM or Azure AD user in place of --resource.') parser.add_argument('--rotation-profile', '-rp', dest='rotation_profile', action='store', - choices=['general', 'iam_user', 'scripts_only'], + choices=['general', 'iam_user', 'scripts_only', 'saas'], help='Rotation profile type: general (resource-based), iam_user (IAM/Azure user), ' - 'scripts_only (run PAM scripts only)') + 'scripts_only (run PAM scripts only), saas (SaaS only)') + parser.add_argument('--saas-config-uid', dest='saas_config_uid', action='store', + help='For saas rotation profile, the SaaS configuration UID') parser.add_argument('--resource', '-rs', dest='resource', action='store', help='UID or path of the resource record.') schedule_group = parser.add_mutually_exclusive_group() @@ -1181,6 +1183,17 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None if not resource_uid: raise CommandError('', 'General rotation profile requires --resource to be specified.') config_user(tmp_dag, _record, resource_uid, config_uid, silent=kwargs.get('silent')) + elif rotation_profile == 'saas': + saas_config_uid = kwargs.get("saas_config_uid") # type: Optional[str] + if saas_config_uid is None: + raise CommandError('', 'SaaS rotation profile requires ' + '--saas-config-uid to be specified.') + saas_command = PAMActionSaasSetCommand() + saas_command.execute(params, + user_uid=_record.record_uid, + pam_config_uid=config_uid, + config_record_uid=saas_config_uid) + # NB! --folder=UID without --iam-aad-config, or --schedule-only converts to General rotation elif iam_aad_config_uid: config_iam_aad_user(tmp_dag, _record, iam_aad_config_uid) diff --git a/keepercommander/commands/pam_saas/set.py b/keepercommander/commands/pam_saas/set.py index 9d55ae63a..be9976924 100644 --- a/keepercommander/commands/pam_saas/set.py +++ b/keepercommander/commands/pam_saas/set.py @@ -91,7 +91,8 @@ def execute(self, params: KeeperParams, **kwargs): return if plugin_name not in plugins: - print(self._f("The SaaS configuration record's custom field label 'SaaS Type' is not supported by the " + print(self._f(f"The SaaS configuration record's custom field label 'SaaS Type', {plugin_name}, " + "is not supported by the " "gateway or the value is not correct.")) return diff --git a/keepercommander/sync_down.py b/keepercommander/sync_down.py index fde03562c..84b7e2205 100644 --- a/keepercommander/sync_down.py +++ b/keepercommander/sync_down.py @@ -589,6 +589,22 @@ def convert_user_folder_shared_folder(ufsf): params.user_cache[account_uid] = params.user if len(response.recordRotations) > 0: + # Stuff still uses record_rotation_cache; it cannot just be removed. + for rr in response.recordRotations: + record_uid = utils.base64_url_encode(rr.recordUid) + rr_obj = { + 'record_uid': record_uid, + 'revision': rr.revision, + 'configuration_uid': utils.base64_url_encode(rr.configurationUid), + 'schedule': rr.schedule, + 'pwd_complexity': utils.base64_url_encode(rr.pwdComplexity), + 'disabled': rr.disabled, + 'resource_uid': utils.base64_url_encode(rr.resourceUid), + 'last_rotation': rr.lastRotation, + 'last_rotation_status': rr.lastRotationStatus, + } + params.record_rotation_cache[record_uid] = rr_obj + record_rotation_items.extend(response.recordRotations) params.sync_down_token = response.continuationToken From db2c1ad79cb66a50202f37c2e0414db1e25ad02a Mon Sep 17 00:00:00 2001 From: sshrushanth-ks Date: Tue, 5 May 2026 10:49:02 +0530 Subject: [PATCH 20/26] KC-706: Updated master password grading to use zxcvbn (#1999) * Updated master password grading to use zxcvbn * Added zxcvbn to install_requires * Added zxcvbn-based master password strength grading alongside BreachWatch score. * Addressed PR review comments on master_password_score --- keepercommander/commands/utils.py | 8 ++++---- keepercommander/utils.py | 15 +++++++++++++++ requirements.txt | 1 + setup.cfg | 1 + unit-tests/test_crypto.py | 8 ++++++++ 5 files changed, 29 insertions(+), 4 deletions(-) diff --git a/keepercommander/commands/utils.py b/keepercommander/commands/utils.py index c4dfe885d..4dcb8e102 100644 --- a/keepercommander/commands/utils.py +++ b/keepercommander/commands/utils.py @@ -47,7 +47,7 @@ from ..params import KeeperParams, LAST_RECORD_UID, LAST_FOLDER_UID, LAST_SHARED_FOLDER_UID from ..proto import ssocloud_pb2, enterprise_pb2, APIRequest_pb2 from ..security_audit import needs_security_audit, update_security_audit_data -from ..utils import password_score +from ..utils import password_score, master_password_score from ..vault import KeeperRecord from ..versioning import is_binary_app, is_up_to_date_version @@ -2322,6 +2322,9 @@ def execute(self, params, **kwargs): logging.warning('Password rules:\n%s', '\n'.join((f' {x}' for x in failed_rules))) return + score = utils.master_password_score(new_password) + logging.info('Password strength: %s', 'WEAK' if score <= 25 else 'FAIR' if score == 50 else 'MEDIUM' if score == 75 else 'STRONG') + if params.breach_watch: euids = [] for result in params.breach_watch.scan_passwords(params, [new_password]): @@ -2330,9 +2333,6 @@ def execute(self, params, **kwargs): logging.info('Breachwatch password scan result: %s', 'WEAK' if result[1].breachDetected else 'GOOD') if euids: params.breach_watch.delete_euids(params, euids) - else: - score = utils.password_score(new_password) - logging.info('Password strength: %s', 'WEAK' if score < 40 else 'FAIR' if score < 60 else 'MEDIUM' if score < 80 else 'STRONG') iterations = current_salt.iterations if current_salt else constants.PBKDF2_ITERATIONS iterations = max(iterations, constants.PBKDF2_ITERATIONS) diff --git a/keepercommander/utils.py b/keepercommander/utils.py index 384ab4926..3def9f50d 100644 --- a/keepercommander/utils.py +++ b/keepercommander/utils.py @@ -27,6 +27,7 @@ from . import crypto from .constants import EMAIL_PATTERN +import zxcvbn as _zxcvbn VALID_URL_SCHEME_CHARS = '+-.:' @@ -429,6 +430,20 @@ def is_pw_strong(pw_score): # type: (int) -> bool return pw_score >= 80 +_MASTER_PASSWORD_SCORE_MAP = {0: 25, 1: 25, 2: 50, 3: 75, 4: 100} + + +def master_password_score(password): # type: (str) -> int + if not password or not isinstance(password, str): + return 0 + try: + result = _zxcvbn.zxcvbn(password) + return _MASTER_PASSWORD_SCORE_MAP.get(result.get('score'), 25) + except Exception as e: + logging.debug('zxcvbn scoring failed: %s', e) + return 25 + + def is_rec_at_risk(bw_result): # type (int) -> bool return bw_result in (2, 3) diff --git a/requirements.txt b/requirements.txt index b96079d0b..1a2b894bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ asciitree +zxcvbn bcrypt colorama prompt_toolkit diff --git a/setup.cfg b/setup.cfg index b634aedb4..5de629eb2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,6 +33,7 @@ include_package_data = True install_requires = asciitree bcrypt + zxcvbn colorama cryptography>=46.0.6 fido2>=2.0.0; python_version>='3.10' diff --git a/unit-tests/test_crypto.py b/unit-tests/test_crypto.py index b02c754fd..21750ce92 100644 --- a/unit-tests/test_crypto.py +++ b/unit-tests/test_crypto.py @@ -160,6 +160,14 @@ def test_password_score(self): self.assertEqual(utils.password_score('AAAbbbCCC11'), 38) self.assertEqual(utils.password_score('password'), 8) + def test_master_password_score(self): + # zxcvbn-based scoring: returns 25 (Weak), 50 (Fair), 75 (Medium), or 100 (Strong) + self.assertEqual(utils.master_password_score('!@#$%^&*()'), 25) # zxcvbn score 1 -> Weak + self.assertEqual(utils.master_password_score('aZkljfzsnmp4w9058dsqln5yf(&*))(*)(345'), 100) # zxcvbn score 4 -> Strong + self.assertEqual(utils.master_password_score('c3>^sxuKZ[Ndyo(OBE14'), 100) # zxcvbn score 4 -> Strong + self.assertEqual(utils.master_password_score('AAAbbbCCC11'), 75) # zxcvbn score 3 -> Medium + self.assertEqual(utils.master_password_score('password'), 25) # zxcvbn score 0 -> Weak + _test_random_data = \ 'cKGoVph_X0NKjk8jQgxyQWRElUY7IsbbIJaRcJVlnOb7AchFiY-izmTTOlgArwIqAxKDKSRAWx2Q1pX' \ From 3c94b24669c9c2139f6e60dd1d0227a2b293c657 Mon Sep 17 00:00:00 2001 From: Matthew Ford Date: Wed, 13 May 2026 13:00:10 -0700 Subject: [PATCH 21/26] Added cloud secrets import commands, These commands(aws-secrets-import, azure-secrets-import, gcp-secrets-import) allow users to perform a one-time import of their secrets from AWS, Azure, or Google Cloud into Keeper. The secret values are parsed into key/value pairs and result in records with corresponding field names and field values. This is to allow customers to be able perform a one-time import and then use Universal Secrets Sync to let Keeper be the source of truth for these secrets going forward. --- docs/aws-secrets-import.md | 272 +++++ docs/azure-secrets-import.md | 289 ++++++ docs/gcp-secrets-import.md | 288 ++++++ .../commands/_cloud_import_base.py | 390 +++++++ keepercommander/commands/aws_import.py | 199 ++++ keepercommander/commands/azure_import.py | 190 ++++ keepercommander/commands/gcp_import.py | 196 ++++ keepercommander/commands/record.py | 12 + setup.cfg | 12 + unit-tests/test_cloud_import.py | 979 ++++++++++++++++++ 10 files changed, 2827 insertions(+) create mode 100644 docs/aws-secrets-import.md create mode 100644 docs/azure-secrets-import.md create mode 100644 docs/gcp-secrets-import.md create mode 100644 keepercommander/commands/_cloud_import_base.py create mode 100644 keepercommander/commands/aws_import.py create mode 100644 keepercommander/commands/azure_import.py create mode 100644 keepercommander/commands/gcp_import.py create mode 100644 unit-tests/test_cloud_import.py diff --git a/docs/aws-secrets-import.md b/docs/aws-secrets-import.md new file mode 100644 index 000000000..3be53aa7e --- /dev/null +++ b/docs/aws-secrets-import.md @@ -0,0 +1,272 @@ +# `aws-secrets-import` — Import AWS Secrets Manager Secrets into Keeper + +The `aws-secrets-import` command reads every secret from AWS Secrets Manager and creates a corresponding Keeper record in a specified shared folder. Each secret's name becomes the record title; the secret's value is parsed into named fields on the record. + +- **Alias:** `amsi` +- **Requires:** `boto3` — install with `pip install keeper-commander[aws]` + +> **See also:** [`azure-secrets-import`](azure-secrets-import.md) for Azure Key Vault and [`gcp-secrets-import`](gcp-secrets-import.md) for Google Cloud Secret Manager. All three commands share the same secret value parsing rules, field-mapping logic, and filter flags. + +--- + +## Table of Contents + +1. [Authentication](#authentication) +2. [Basic Usage](#basic-usage) +3. [Arguments & Flags](#arguments--flags) +4. [Filtering Secrets](#filtering-secrets) +5. [Secret Value Formats](#secret-value-formats) +6. [Keeper Record Structure](#keeper-record-structure) +7. [Examples](#examples) + +--- + +## Authentication + +The command resolves AWS credentials in the following order: + +1. **Explicit flags** — `--access-key` and `--secret-key` provided directly on the command line. +2. **boto3 credential chain** — if no explicit flags are given, the standard boto3 session is used, which checks (in order): + - Environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, etc.) + - `~/.aws/credentials` and `~/.aws/config` + - IAM role attached to the running EC2 instance or ECS task + +In most production deployments you can omit the credential flags entirely and let the instance role or `~/.aws` configuration handle authentication. + +--- + +## Basic Usage + +``` +aws-secrets-import [options] +``` + +The only required argument is the **shared folder UID** — the unique identifier of the Keeper shared folder that will receive the imported records. Use `list-sf` inside Commander to find the UID for a folder: + +``` +My Vault> list-sf +``` + +--- + +## Arguments & Flags + +### Positional argument + +| Argument | Description | +|---|---| +| `folder` | **Required.** Shared folder UID to import secrets into. | + +### Credential flags + +| Flag | Description | +|---|---| +| `--access-key KEY` | AWS access key ID. Overrides the boto3 credential chain. | +| `--secret-key SECRET` | AWS secret access key. Required when `--access-key` is provided. | +| `--region REGION` | AWS region name (e.g. `us-east-1`). Uses the boto3 default if omitted. | + +### Behaviour flags + +| Flag | Description | +|---|---| +| `--record-type TYPE` | Keeper record type for imported records. Defaults to `login`. | +| `--dry-run` | List secrets that would be imported without creating any records. | + +### Filter flags + +All filter flags are optional and combine with AND logic — a secret must satisfy every provided filter to be imported. + +| Flag | Description | +|---|---| +| `--name NAME` | Import only the secret with this exact name. | +| `--name-starts-with PREFIX` | Import only secrets whose name starts with `PREFIX`. | +| `--name-ends-with SUFFIX` | Import only secrets whose name ends with `SUFFIX`. | +| `--name-contains SUBSTRING` | Import only secrets whose name contains `SUBSTRING`. | +| `--tags KEY=VALUE[,KEY=VALUE,...]` | Import only secrets tagged with **all** specified key/value pairs. | + +--- + +## Filtering Secrets + +Filters let you import a targeted subset of secrets without touching the rest. Every filter you specify must match for a secret to be imported. + +### Name filters + +Name filters operate on the full secret name as stored in AWS. + +```bash +# Exact name match +amsi xAbCdEfGhIjK --name prod/database/primary + +# All secrets under the prod/ path +amsi xAbCdEfGhIjK --name-starts-with prod/ + +# Secrets whose name ends with /credentials +amsi xAbCdEfGhIjK --name-ends-with /credentials + +# Secrets whose name contains "rds" +amsi xAbCdEfGhIjK --name-contains rds +``` + +Multiple name filters can be combined. Each one adds an additional requirement: + +```bash +# Must start with "prod/" AND contain "database" +amsi xAbCdEfGhIjK --name-starts-with prod/ --name-contains database +``` + +### Tag filter + +The `--tags` flag accepts a comma-separated list of `KEY=VALUE` pairs. A secret is included only if it carries **all** of the specified tags with the exact values given. + +```bash +# Single tag requirement +amsi xAbCdEfGhIjK --tags Env=prod + +# Multiple tag requirements (both must match) +amsi xAbCdEfGhIjK --tags Env=prod,Team=payments +``` + +Tag keys and values are case-sensitive and must match the values stored in AWS exactly. Azure Key Vault also uses the term *tags*; GCP Secret Manager uses the term *labels* — the `--tags` flag works the same way for all three providers. + +### Combining filters + +All filter types can be used together in one command: + +```bash +amsi xAbCdEfGhIjK \ + --name-starts-with prod/ \ + --name-ends-with /creds \ + --tags Env=prod,Owner=platform +``` + +A secret is imported only if it satisfies **every** filter listed. + +--- + +## Secret Value Formats + +The same parsing rules are used by all three cloud import commands (`aws-secrets-import`, `azure-secrets-import`, `gcp-secrets-import`). + +When a secret is retrieved from AWS Secrets Manager, its `SecretString` is parsed into a set of named field values using the following rules, applied in priority order: + +### 1. JSON object + +If the secret string begins with `{` and is valid JSON representing an object, each key/value pair in the object becomes a separate field on the Keeper record. + +```json +{ + "username": "admin", + "password": "s3cur3P@ss!", + "host": "db.internal.example.com" +} +``` + +Results in three fields: `username`, `password`, and `host`. + +### 2. KEY=VALUE lines (shell-style) + +If the secret string is not JSON, the command attempts to parse it as newline-separated `KEY=VALUE` pairs (the same format used by `.env` files). Lines beginning with `#` and blank lines are ignored. + +``` +# Database credentials +username=admin +password=s3cur3P@ss! +host=db.internal.example.com +``` + +Results in three fields: `username`, `password`, and `host`. + +### 3. Fallback — plain string + +If the secret string cannot be parsed as JSON or as `KEY=VALUE` lines, the entire string is stored as a single field named `value`. + +``` +s3cur3P@ss! +``` + +Results in one field: `value = s3cur3P@ss!`. + +--- + +## Keeper Record Structure + +Each imported secret produces one **TypedRecord** in the target shared folder: + +- **Title** — the original AWS secret name (e.g. `prod/database/primary`). +- **Record type** — controlled by `--record-type` (default: `login`). + +### Field placement + +Parsed key/value pairs from the secret are mapped to Keeper field types before being placed on the record: + +| Parsed key (case-insensitive) | Keeper field type | Placement | +|---|---|---| +| `username`, `user`, `login` | `login` | Typed fields | +| `password`, `pass`, `secret`, `secret_value` | `password` | Typed fields | +| `url`, `endpoint`, `host` | `url` | Typed fields | +| `email`, `mail` | `email` | Typed fields | +| `note`, `notes` | — | Record Notes section | +| anything else | `text` | Typed fields | + +The `note` and `notes` keys are written to the record's **Notes** field rather than appearing as a typed or custom field. All other keys not listed above are stored as `text` typed fields. If the same semantic type (e.g. `login`, `password`, `url`, `email`) appears more than once, the first occurrence takes the typed field slot and subsequent ones are stored as **custom fields**. + +--- + +## Examples + +### Import all secrets using ambient AWS credentials + +```bash +amsi xAbCdEfGhIjK +``` + +Uses `~/.aws` credentials or the attached EC2/ECS instance role automatically. + +### Specify credentials and region explicitly + +```bash +amsi xAbCdEfGhIjK \ + --access-key AKIAIOSFODNN7EXAMPLE \ + --secret-key wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY \ + --region us-west-2 +``` + +### Preview what would be imported (dry run) + +```bash +amsi xAbCdEfGhIjK --dry-run +``` + +Prints the name of each secret that passes all filters without creating any records. + +### Import only production secrets owned by the payments team + +```bash +amsi xAbCdEfGhIjK --name-starts-with prod/ --tags Team=payments +``` + +### Import a single known secret + +```bash +amsi xAbCdEfGhIjK --name prod/payments/stripe-api-key +``` + +### Import all RDS secrets in staging and store as `serverCredentials` records + +```bash +amsi xAbCdEfGhIjK \ + --name-contains rds \ + --tags Env=staging \ + --record-type serverCredentials +``` + +### Dry-run a complex filter before committing + +```bash +amsi xAbCdEfGhIjK \ + --name-starts-with prod/ \ + --name-ends-with /creds \ + --tags Env=prod,Owner=platform \ + --dry-run +``` diff --git a/docs/azure-secrets-import.md b/docs/azure-secrets-import.md new file mode 100644 index 000000000..91d32dcef --- /dev/null +++ b/docs/azure-secrets-import.md @@ -0,0 +1,289 @@ +# `azure-secrets-import` — Import Azure Key Vault Secrets into Keeper + +The `azure-secrets-import` command reads every enabled secret from an Azure Key Vault and creates a corresponding Keeper record in a specified shared folder. Each secret's name becomes the record title; the secret's value is parsed into named fields on the record. + +- **Alias:** `azsi` +- **Requires:** `azure-keyvault-secrets` and `azure-identity` — install with `pip install keeper-commander[azure]` + +--- + +## Table of Contents + +1. [Authentication](#authentication) +2. [Basic Usage](#basic-usage) +3. [Arguments & Flags](#arguments--flags) +4. [Filtering Secrets](#filtering-secrets) +5. [Secret Value Formats](#secret-value-formats) +6. [Keeper Record Structure](#keeper-record-structure) +7. [Examples](#examples) + +--- + +## Authentication + +The command resolves Azure credentials in the following order: + +1. **Service-principal flags** — if `--tenant-id`, `--client-id`, and `--client-secret` are all provided, a `ClientSecretCredential` is used for authentication. All three flags must be supplied together; providing only some of them is an error. +2. **`DefaultAzureCredential`** — if no explicit flags are given, the Azure SDK's `DefaultAzureCredential` chain is used, which checks (in order): + - Environment variables (`AZURE_TENANT_ID`, `AZURE_CLIENT_ID`, `AZURE_CLIENT_SECRET`, etc.) + - Workload Identity (Kubernetes) + - Managed Identity attached to the running Azure VM, App Service, or Container + - Azure CLI (`az login`) + - Azure PowerShell + - Azure Developer CLI + +In most production deployments you can omit the credential flags entirely and rely on Managed Identity or the Azure CLI. + +--- + +## Basic Usage + +``` +azure-secrets-import [options] +``` + +Two positional arguments are required: + +- **`vault-name`** — the short name of the Azure Key Vault (e.g. `my-vault`). The command constructs the full vault URL as `https://.vault.azure.net/` automatically. +- **`folder-uid`** — the unique identifier of the Keeper shared folder that will receive the imported records. Use `list-sf` inside Commander to find this value: + +``` +My Vault> list-sf +``` + +--- + +## Arguments & Flags + +### Positional arguments + +| Argument | Description | +|---|---| +| `vault_name` | **Required.** Short name of the Azure Key Vault (e.g. `my-vault`). | +| `folder` | **Required.** Shared folder UID to import secrets into. | + +### Credential flags + +| Flag | Description | +|---|---| +| `--tenant-id ID` | Azure AD tenant ID. Required together with `--client-id` and `--client-secret`. | +| `--client-id ID` | Azure AD application (client) ID. Required together with `--tenant-id` and `--client-secret`. | +| `--client-secret SECRET` | Azure AD client secret. Required together with `--tenant-id` and `--client-id`. | + +All three credential flags must be provided together for service-principal authentication. Omit all three to use `DefaultAzureCredential`. + +### Behaviour flags + +| Flag | Description | +|---|---| +| `--record-type TYPE` | Keeper record type for imported records. Defaults to `login`. | +| `--dry-run` | List secrets that would be imported without creating any records. | + +### Filter flags + +All filter flags are optional and combine with AND logic — a secret must satisfy every provided filter to be imported. + +| Flag | Description | +|---|---| +| `--name NAME` | Import only the secret with this exact name. | +| `--name-starts-with PREFIX` | Import only secrets whose name starts with `PREFIX`. | +| `--name-ends-with SUFFIX` | Import only secrets whose name ends with `SUFFIX`. | +| `--name-contains SUBSTRING` | Import only secrets whose name contains `SUBSTRING`. | +| `--tags KEY=VALUE[,KEY=VALUE,...]` | Import only secrets tagged with **all** specified key/value pairs. | + +--- + +## Filtering Secrets + +Filters let you import a targeted subset of secrets without touching the rest. Every filter you specify must match for a secret to be imported. + +Disabled secrets are always skipped regardless of any filter settings. + +### Name filters + +Name filters operate on the secret name as stored in Azure Key Vault. + +```bash +# Exact name match +azsi my-vault xAbCdEfGhIjK --name database-primary-password + +# All secrets whose name starts with "prod-" +azsi my-vault xAbCdEfGhIjK --name-starts-with prod- + +# Secrets whose name ends with "-creds" +azsi my-vault xAbCdEfGhIjK --name-ends-with -creds + +# Secrets whose name contains "postgres" +azsi my-vault xAbCdEfGhIjK --name-contains postgres +``` + +Multiple name filters can be combined. Each one adds an additional requirement: + +```bash +# Must start with "prod-" AND contain "database" +azsi my-vault xAbCdEfGhIjK --name-starts-with prod- --name-contains database +``` + +### Tag filter + +Azure Key Vault secrets support arbitrary key/value tags. The `--tags` flag accepts a comma-separated list of `KEY=VALUE` pairs. A secret is included only if it carries **all** of the specified tags with the exact values given. + +```bash +# Single tag requirement +azsi my-vault xAbCdEfGhIjK --tags Env=prod + +# Multiple tag requirements (both must match) +azsi my-vault xAbCdEfGhIjK --tags Env=prod,Team=payments +``` + +Tag keys and values are case-sensitive and must match the values stored in Azure exactly. + +### Combining filters + +All filter types can be used together in one command: + +```bash +azsi my-vault xAbCdEfGhIjK \ + --name-starts-with prod- \ + --name-ends-with -creds \ + --tags Env=prod,Owner=platform +``` + +A secret is imported only if it satisfies **every** filter listed. + +--- + +## Secret Value Formats + +When a secret is retrieved from Azure Key Vault, its value is parsed into a set of named field values using the following rules, applied in priority order: + +### 1. JSON object + +If the secret value begins with `{` and is valid JSON representing an object, each key/value pair in the object becomes a separate field on the Keeper record. + +```json +{ + "username": "admin", + "password": "s3cur3P@ss!", + "host": "db.internal.example.com" +} +``` + +Results in three fields: `username`, `password`, and `host`. + +### 2. KEY=VALUE lines (shell-style) + +If the secret value is not JSON, the command attempts to parse it as newline-separated `KEY=VALUE` pairs (the same format used by `.env` files). Lines beginning with `#` and blank lines are ignored. + +``` +# Database credentials +username=admin +password=s3cur3P@ss! +host=db.internal.example.com +``` + +Results in three fields: `username`, `password`, and `host`. + +### 3. Fallback — plain string + +If the secret value cannot be parsed as JSON or as `KEY=VALUE` lines, the entire string is stored as a single field named `value`. + +``` +s3cur3P@ss! +``` + +Results in one field: `value = s3cur3P@ss!`. + +--- + +## Keeper Record Structure + +Each imported secret produces one **TypedRecord** in the target shared folder: + +- **Title** — the original Azure Key Vault secret name (e.g. `prod-database-primary`). +- **Record type** — controlled by `--record-type` (default: `login`). + +### Field placement + +Parsed key/value pairs from the secret are mapped to Keeper field types before being placed on the record: + +| Parsed key (case-insensitive) | Keeper field type | Placement | +|---|---|---| +| `username`, `user`, `login` | `login` | Typed fields | +| `password`, `pass`, `secret`, `secret_value` | `password` | Typed fields | +| `url`, `endpoint`, `host` | `url` | Typed fields | +| `email`, `mail` | `email` | Typed fields | +| `note`, `notes` | — | Record Notes section | +| anything else | `text` | Typed fields | + +The `note` and `notes` keys are written to the record's **Notes** field rather than appearing as a typed or custom field. All other keys not listed above are stored as `text` typed fields. If the same semantic type (e.g. `login`, `password`, `url`, `email`) appears more than once, the first occurrence takes the typed field slot and subsequent ones are stored as **custom fields**. + +--- + +## Examples + +### Import all secrets using DefaultAzureCredential + +```bash +azsi my-vault xAbCdEfGhIjK +``` + +Uses Managed Identity, Azure CLI login, or environment variables automatically. + +### Authenticate with a service principal + +```bash +azsi my-vault xAbCdEfGhIjK \ + --tenant-id 00000000-0000-0000-0000-000000000000 \ + --client-id 11111111-1111-1111-1111-111111111111 \ + --client-secret "MyClientSecretValue" +``` + +### Preview what would be imported (dry run) + +```bash +azsi my-vault xAbCdEfGhIjK --dry-run +``` + +Prints the name of each secret that passes all filters without creating any records. + +### Import only production secrets owned by the payments team + +```bash +azsi my-vault xAbCdEfGhIjK --name-starts-with prod- --tags Team=payments +``` + +### Import a single known secret + +```bash +azsi my-vault xAbCdEfGhIjK --name prod-stripe-api-key +``` + +### Import all database secrets in staging stored as `serverCredentials` records + +```bash +azsi my-vault xAbCdEfGhIjK \ + --name-contains database \ + --tags Env=staging \ + --record-type serverCredentials +``` + +### Dry-run a complex filter before committing + +```bash +azsi my-vault xAbCdEfGhIjK \ + --name-starts-with prod- \ + --name-ends-with -creds \ + --tags Env=prod,Owner=platform \ + --dry-run +``` + +### Import from a vault in a different tenant using service-principal credentials + +```bash +azsi partner-vault xAbCdEfGhIjK \ + --tenant-id aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee \ + --client-id ffffffff-0000-1111-2222-333333333333 \ + --client-secret "PartnerAppSecret" \ + --tags Shared=true +``` diff --git a/docs/gcp-secrets-import.md b/docs/gcp-secrets-import.md new file mode 100644 index 000000000..96f6def4d --- /dev/null +++ b/docs/gcp-secrets-import.md @@ -0,0 +1,288 @@ +# `gcp-secrets-import` — Import GCP Secret Manager Secrets into Keeper + +The `gcp-secrets-import` command reads every accessible secret from Google Cloud Secret Manager and creates a corresponding Keeper record in a specified shared folder. Each secret's name becomes the record title; the secret's value is parsed into named fields on the record. + +- **Alias:** `gcsi` +- **Requires:** `google-cloud-secret-manager` — install with `pip install keeper-commander[gcp]` + +--- + +## Table of Contents + +1. [Authentication](#authentication) +2. [Basic Usage](#basic-usage) +3. [Arguments & Flags](#arguments--flags) +4. [Filtering Secrets](#filtering-secrets) +5. [Secret Value Formats](#secret-value-formats) +6. [Keeper Record Structure](#keeper-record-structure) +7. [Examples](#examples) + +--- + +## Authentication + +The command resolves GCP credentials in the following order: + +1. **Service account key file** — if `--service-account-file` is provided, the specified JSON key file is loaded and used for all API calls. +2. **Application Default Credentials (ADC)** — if no key file is provided, the GCP SDK's ADC chain is used, which checks (in order): + - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable (path to a service account key file) + - User credentials set via `gcloud auth application-default login` + - The service account attached to the running Compute Engine instance, Cloud Run service, GKE workload, or other GCP-hosted environment + +In most production deployments you can omit `--service-account-file` and rely on Workload Identity or the attached service account. + +--- + +## Basic Usage + +``` +gcp-secrets-import --project-id [options] +``` + +The positional `folder-uid` and the `--project-id` flag are both required: + +- **`folder-uid`** — the unique identifier of the Keeper shared folder that will receive the imported records. Use `list-sf` inside Commander to find this value: + +``` +My Vault> list-sf +``` + +- **`--project-id`** — the GCP project ID (not the project number) that owns the secrets, e.g. `my-gcp-project`. + +--- + +## Arguments & Flags + +### Positional argument + +| Argument | Description | +|---|---| +| `folder` | **Required.** Shared folder UID to import secrets into. | + +### Credential flags + +| Flag | Description | +|---|---| +| `--project-id ID` | **Required.** GCP project ID that owns the secrets (e.g. `my-gcp-project`). | +| `--service-account-file PATH` | Path to a GCP service account JSON key file. Uses Application Default Credentials when omitted. | + +### Behaviour flags + +| Flag | Description | +|---|---| +| `--record-type TYPE` | Keeper record type for imported records. Defaults to `login`. | +| `--dry-run` | List secrets that would be imported without creating any records. | + +### Filter flags + +All filter flags are optional and combine with AND logic — a secret must satisfy every provided filter to be imported. + +| Flag | Description | +|---|---| +| `--name NAME` | Import only the secret with this exact name. | +| `--name-starts-with PREFIX` | Import only secrets whose name starts with `PREFIX`. | +| `--name-ends-with SUFFIX` | Import only secrets whose name ends with `SUFFIX`. | +| `--name-contains SUBSTRING` | Import only secrets whose name contains `SUBSTRING`. | +| `--tags KEY=VALUE[,KEY=VALUE,...]` | Import only secrets whose GCP labels match **all** specified key/value pairs. | + +> **Note on GCP labels:** GCP Secret Manager uses the term *labels* rather than *tags*. The `--tags` flag maps directly to GCP labels — use it the same way you would for AWS or Azure. + +--- + +## Filtering Secrets + +Filters let you import a targeted subset of secrets without touching the rest. Every filter you specify must match for a secret to be imported. + +Secrets whose `latest` version is disabled, destroyed, or inaccessible due to permissions are always skipped with a warning regardless of any filter settings. + +### Name filters + +Name filters operate on the short secret name (the last segment of the full GCP resource name `projects/{project}/secrets/{secret-id}`). + +```bash +# Exact name match +gcsi xAbCdEfGhIjK --project-id my-project --name database-primary-password + +# All secrets whose name starts with "prod-" +gcsi xAbCdEfGhIjK --project-id my-project --name-starts-with prod- + +# Secrets whose name ends with "-creds" +gcsi xAbCdEfGhIjK --project-id my-project --name-ends-with -creds + +# Secrets whose name contains "postgres" +gcsi xAbCdEfGhIjK --project-id my-project --name-contains postgres +``` + +Multiple name filters can be combined. Each one adds an additional requirement: + +```bash +# Must start with "prod-" AND contain "database" +gcsi xAbCdEfGhIjK --project-id my-project \ + --name-starts-with prod- --name-contains database +``` + +### Label filter (`--tags`) + +GCP Secret Manager secrets support arbitrary key/value *labels*. The `--tags` flag accepts a comma-separated list of `KEY=VALUE` pairs. A secret is included only if it carries **all** of the specified labels with the exact values given. + +```bash +# Single label requirement +gcsi xAbCdEfGhIjK --project-id my-project --tags env=prod + +# Multiple label requirements (both must match) +gcsi xAbCdEfGhIjK --project-id my-project --tags env=prod,team=payments +``` + +> GCP label keys and values are lowercase by convention and are case-sensitive. Ensure the values you provide match the casing stored in GCP. + +### Combining filters + +All filter types can be used together in one command: + +```bash +gcsi xAbCdEfGhIjK --project-id my-project \ + --name-starts-with prod- \ + --name-ends-with -creds \ + --tags env=prod,owner=platform +``` + +A secret is imported only if it satisfies **every** filter listed. + +--- + +## Secret Value Formats + +When a secret is retrieved from GCP Secret Manager, the payload of the `latest` version is decoded as UTF-8 and then parsed into a set of named field values using the following rules, applied in priority order: + +### 1. JSON object + +If the secret payload begins with `{` and is valid JSON representing an object, each key/value pair in the object becomes a separate field on the Keeper record. + +```json +{ + "username": "admin", + "password": "s3cur3P@ss!", + "host": "db.internal.example.com" +} +``` + +Results in three fields: `username`, `password`, and `host`. + +### 2. KEY=VALUE lines (shell-style) + +If the payload is not JSON, the command attempts to parse it as newline-separated `KEY=VALUE` pairs (the same format used by `.env` files). Lines beginning with `#` and blank lines are ignored. + +``` +# Database credentials +username=admin +password=s3cur3P@ss! +host=db.internal.example.com +``` + +Results in three fields: `username`, `password`, and `host`. + +### 3. Fallback — plain string + +If the payload cannot be parsed as JSON or as `KEY=VALUE` lines, the entire string is stored as a single field named `value`. + +``` +s3cur3P@ss! +``` + +Results in one field: `value = s3cur3P@ss!`. + +--- + +## Keeper Record Structure + +Each imported secret produces one **TypedRecord** in the target shared folder: + +- **Title** — the short GCP secret name (e.g. `prod-database-primary`), not the full resource path. +- **Record type** — controlled by `--record-type` (default: `login`). + +### Field placement + +Parsed key/value pairs from the secret are mapped to Keeper field types before being placed on the record: + +| Parsed key (case-insensitive) | Keeper field type | Placement | +|---|---|---| +| `username`, `user`, `login` | `login` | Typed fields | +| `password`, `pass`, `secret`, `secret_value` | `password` | Typed fields | +| `url`, `endpoint`, `host` | `url` | Typed fields | +| `email`, `mail` | `email` | Typed fields | +| `note`, `notes` | — | Record Notes section | +| anything else | `text` | Typed fields | + +The `note` and `notes` keys are written to the record's **Notes** field rather than appearing as a typed or custom field. All other keys not listed above are stored as `text` typed fields. If the same semantic type (e.g. `login`, `password`, `url`, `email`) appears more than once, the first occurrence takes the typed field slot and subsequent ones are stored as **custom fields**. + +--- + +## Examples + +### Import all secrets using Application Default Credentials + +```bash +gcsi xAbCdEfGhIjK --project-id my-gcp-project +``` + +Uses the `GOOGLE_APPLICATION_CREDENTIALS` environment variable, `gcloud` credentials, or the attached service account automatically. + +### Authenticate with a service account key file + +```bash +gcsi xAbCdEfGhIjK \ + --project-id my-gcp-project \ + --service-account-file /path/to/service-account-key.json +``` + +### Preview what would be imported (dry run) + +```bash +gcsi xAbCdEfGhIjK --project-id my-gcp-project --dry-run +``` + +Prints the name of each secret that passes all filters without creating any records. + +### Import only production secrets owned by the payments team + +```bash +gcsi xAbCdEfGhIjK --project-id my-gcp-project \ + --name-starts-with prod- --tags team=payments +``` + +### Import a single known secret + +```bash +gcsi xAbCdEfGhIjK --project-id my-gcp-project --name prod-stripe-api-key +``` + +### Import all database secrets in staging stored as `serverCredentials` records + +```bash +gcsi xAbCdEfGhIjK --project-id my-gcp-project \ + --name-contains database \ + --tags env=staging \ + --record-type serverCredentials +``` + +### Dry-run a complex filter before committing + +```bash +gcsi xAbCdEfGhIjK --project-id my-gcp-project \ + --name-starts-with prod- \ + --name-ends-with -creds \ + --tags env=prod,owner=platform \ + --dry-run +``` + +### Import using a service account key stored in a CI secret + +```bash +# Decode the key from a CI environment variable and import +echo "$GCP_SA_KEY" > /tmp/sa-key.json +gcsi xAbCdEfGhIjK \ + --project-id my-gcp-project \ + --service-account-file /tmp/sa-key.json \ + --tags env=prod +rm /tmp/sa-key.json +``` diff --git a/keepercommander/commands/_cloud_import_base.py b/keepercommander/commands/_cloud_import_base.py new file mode 100644 index 000000000..8bd18ed89 --- /dev/null +++ b/keepercommander/commands/_cloud_import_base.py @@ -0,0 +1,390 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' str + """ + Convert a JSON-decoded Python value to a plain string suitable for a + Keeper text field, preserving semantics rather than using Python's str(). + + None → '' (avoids the literal string 'None') + bool → 'true'/'false' (JSON casing, not Python's 'True'/'False') + dict/list → json.dumps() (round-trippable JSON, not Python repr) + anything else → str() + """ + if v is None: + return '' + if isinstance(v, bool): + return 'true' if v else 'false' + if isinstance(v, (dict, list)): + return json.dumps(v) + return str(v) + + +def add_filter_args(parser): + """Attach the standard five filter arguments to *parser*.""" + group = parser.add_argument_group( + 'filters', + 'Restrict which secrets are imported. All provided filters must match (AND logic).' + ) + group.add_argument( + '--name', dest='filter_name', action='store', metavar='NAME', + help='Import only the secret with this exact name' + ) + group.add_argument( + '--name-starts-with', dest='filter_name_starts_with', action='store', metavar='PREFIX', + help='Import only secrets whose name starts with PREFIX' + ) + group.add_argument( + '--name-ends-with', dest='filter_name_ends_with', action='store', metavar='SUFFIX', + help='Import only secrets whose name ends with SUFFIX' + ) + group.add_argument( + '--name-contains', dest='filter_name_contains', action='store', metavar='SUBSTRING', + help='Import only secrets whose name contains SUBSTRING' + ) + group.add_argument( + '--tags', dest='filter_tags', action='store', metavar='KEY=VALUE[,KEY=VALUE,...]', + help='Import only secrets tagged/labelled with ALL specified key=value pairs ' + '(e.g. --tags Env=prod,Team=ops)' + ) + + +class CloudImportMixin: + """ + Mixin for cloud-to-Keeper secret import commands. + + Provides secret-string parsing, Keeper record building, filter evaluation, + and the main import loop. Concrete commands inherit this alongside Command + and supply cloud-specific list/fetch implementations. + """ + + # ------------------------------------------------------------------ + # Secret value parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_secret_string(secret_string): + # type: (str) -> Dict[str, str] + """ + Parse a secret string into name/value pairs. + + Supported formats (in priority order): + 1. JSON object -> {"key": "value", ...} + 2. KEY=VALUE lines — keys must be POSIX identifiers ([A-Za-z_][A-Za-z0-9_]*) + 3. Fallback: whole string stored under a field named "value" + + The POSIX identifier requirement for KEY=VALUE prevents false-positives on + base64 blobs, JWTs, connection strings, and similar opaque values. + """ + secret_string = (secret_string or '').strip() + if not secret_string: + return {} + + if secret_string.startswith('{'): + try: + obj = json.loads(secret_string) + if isinstance(obj, dict): + return {str(k): _coerce_json_value(v) for k, v in obj.items()} + except (json.JSONDecodeError, ValueError): + pass + + pairs = {} + parsed_as_kv = False + for line in secret_string.splitlines(): + line = line.strip() + if not line or line.startswith('#'): + continue + if '=' in line: + key, _, val = line.partition('=') + key = key.strip() + if not _KV_KEY_RE.match(key): + parsed_as_kv = False + break + pairs[key] = val.strip() + parsed_as_kv = True + else: + parsed_as_kv = False + break + + if parsed_as_kv and pairs: + return pairs + + return {'value': secret_string} + + # ------------------------------------------------------------------ + # Keeper record builder + # ------------------------------------------------------------------ + + @staticmethod + def _build_keeper_record(title, fields, record_type='login'): + # type: (str, Dict[str, str], str) -> vault.TypedRecord + """ + Build a TypedRecord from a title and a dict of field name/value pairs. + + Field mapping rules: + - 'note' / 'notes' keys are written to record.notes (not a typed field). + - 'username' / 'user' / 'login' → typed login field. + - 'password' / 'pass' / 'secret' / 'secret_value' → typed password field. + - 'url' / 'endpoint' / 'host' → typed url field. + - 'email' / 'mail' → typed email field. + - Everything else → text typed field. + - Duplicate semantic types (login, password, url, email) route subsequent + occurrences to custom fields to avoid server-side rejection. + """ + KNOWN_TYPED_FIELDS = {'login', 'password', 'url', 'email', 'text'} + + record = vault.TypedRecord() + record.type_name = record_type + record.title = title + + used_typed = set() # tracks which de-dup types have already been placed + + for field_name, field_value in fields.items(): + lower = field_name.lower() + + # Map 'note'/'notes' to the record's notes property, not a typed field. + if lower in ('note', 'notes'): + if not record.notes: + record.notes = field_value + continue + + keeper_type = 'text' + if lower in ('username', 'user', 'login'): + keeper_type = 'login' + elif lower in ('password', 'pass', 'secret', 'secret_value'): + keeper_type = 'password' + elif lower in ('url', 'endpoint', 'host'): + keeper_type = 'url' + elif lower in ('email', 'mail'): + keeper_type = 'email' + + if keeper_type in _DEDUP_TYPED and keeper_type in used_typed: + # Duplicate semantic type — store with original label as custom text. + record.custom.append(vault.TypedField.new_field('text', field_value, field_name)) + logging.debug('_build_keeper_record: duplicate %s field "%s" routed to custom', + keeper_type, field_name) + continue + + if keeper_type in KNOWN_TYPED_FIELDS: + record.fields.append(vault.TypedField.new_field(keeper_type, field_value, field_name)) + if keeper_type in _DEDUP_TYPED: + used_typed.add(keeper_type) + else: + record.custom.append(vault.TypedField.new_field('text', field_value, field_name)) + + return record + + # ------------------------------------------------------------------ + # Filter helpers + # ------------------------------------------------------------------ + + @staticmethod + def _parse_tag_filter(tags_str, command_name): + # type: (str, str) -> List[Tuple[str, str]] + """Parse 'Key1=Val1,Key2=Val2' into [(Key1, Val1), (Key2, Val2)].""" + pairs = [] + for token in tags_str.split(','): + token = token.strip() + if not token: + continue + if '=' not in token: + raise CommandError( + command_name, + f'Invalid --tags format: "{token}". Expected KEY=VALUE pairs separated by commas.' + ) + key, _, value = token.partition('=') + pairs.append((key.strip(), value.strip())) + return pairs + + @staticmethod + def _matches_name_filters(name, filter_name, filter_starts, filter_ends, filter_contains): + # type: (str, Optional[str], Optional[str], Optional[str], Optional[str]) -> bool + if filter_name is not None and name != filter_name: + return False + if filter_starts is not None and not name.startswith(filter_starts): + return False + if filter_ends is not None and not name.endswith(filter_ends): + return False + if filter_contains is not None and filter_contains not in name: + return False + return True + + @staticmethod + def _matches_tag_filters(secret_tags, required_tags): + # type: (Dict[str, str], List[Tuple[str, str]]) -> bool + """ + Check that *secret_tags* (a plain dict) satisfies every (key, value) + pair in *required_tags*. Azure and GCP both return tags/labels as dicts. + """ + for tag_key, tag_value in required_tags: + if secret_tags.get(tag_key) != tag_value: + return False + return True + + # ------------------------------------------------------------------ + # Folder validation + # ------------------------------------------------------------------ + + @staticmethod + def _validate_folder(params, folder_uid, command_name): + # type: (KeeperParams, str, str) -> None + if not folder_uid: + raise CommandError(command_name, 'A shared folder UID is required.') + folder = params.folder_cache.get(folder_uid) + if folder is None: + raise CommandError( + command_name, + f'Folder UID "{folder_uid}" not found in your vault. ' + 'Use "list-sf" to find the correct shared folder UID, ' + 'or run "sync-down" if the folder was recently shared.' + ) + if not isinstance(folder, (SharedFolderNode, SharedFolderFolderNode)): + raise CommandError( + command_name, + f'"{folder_uid}" is a personal folder. ' + 'Secrets must be imported into a shared folder. ' + 'Use "list-sf" to find the correct shared folder UID.' + ) + + # ------------------------------------------------------------------ + # Shared import loop + # ------------------------------------------------------------------ + + def _run_import(self, params, secrets, folder_uid, record_type, filter_name, + filter_starts, filter_ends, filter_contains, required_tags, + dry_run, command_name, value_fetcher=None): + # type: (KeeperParams, list, str, str, Optional[str], Optional[str], Optional[str], Optional[str], List[Tuple[str, str]], bool, str, Optional[Callable]) -> None + """ + Iterate *secrets* (list of dicts with at minimum 'name' and 'tags'), + apply all filters, then create Keeper records via batched vault/records_add + calls (up to BATCH_SIZE records per request). + + *value_fetcher*, when provided, is called as ``value_fetcher(name) -> str`` + for each secret that passes filters and is not a dry-run. This enables + lazy value fetching (only matched, non-dry-run secrets incur a cloud API + call). When None, falls back to ``item.get('value', '')``. + """ + + # Phase 1 – filter. Collect matching items; honour dry-run. + matched = [] + for item in secrets: + name = item.get('name') or '' + if not name: + continue + if not self._matches_name_filters(name, filter_name, filter_starts, + filter_ends, filter_contains): + logging.debug('%s: skipping "%s" (name filter mismatch)', command_name, name) + continue + if required_tags and not self._matches_tag_filters(item.get('tags') or {}, required_tags): + logging.debug('%s: skipping "%s" (tag filter mismatch)', command_name, name) + continue + if dry_run: + print(f' [dry-run] would import: {name}') + continue + matched.append(item) + + if dry_run: + return + + # Phase 2 – fetch values (if needed) and build TypedRecord + protobuf + # objects without touching the Keeper API yet. + skipped = 0 + pending = [] # type: List[Tuple[vault.TypedRecord, record_pb2.RecordAdd]] + + for item in matched: + name = item['name'] + if value_fetcher is not None: + try: + value = value_fetcher(name) + except Exception as exc: + logging.warning('%s: skipping "%s" – could not retrieve value: %s', + command_name, name, exc) + skipped += 1 + continue + else: + value = item.get('value', '') + + fields = self._parse_secret_string(value) + record = self._build_keeper_record(name, fields, record_type) + pb = record_management.add_record_to_folder(params, record, folder_uid, pb_only=True) + if pb is not None: + pending.append((record, pb)) + + if not pending: + print(f'{command_name}: 0 record(s) created, {skipped} skipped.') + return + + # Phase 3 – send in batches of up to BATCH_SIZE to vault/records_add. + created = 0 + + for batch_start in range(0, len(pending), BATCH_SIZE): + batch = pending[batch_start:batch_start + BATCH_SIZE] + batch_num = batch_start // BATCH_SIZE + 1 + logging.info('%s: sending batch %d (%d record(s))', command_name, batch_num, len(batch)) + + rq = api.get_records_add_request(params) + for _, pb in batch: + rq.records.append(pb) + + try: + rs = api.communicate_rest( + params, rq, 'vault/records_add', + rs_type=record_pb2.RecordsModifyResponse + ) + except Exception as exc: + logging.warning('%s: batch %d failed: %s', command_name, batch_num, exc) + skipped += len(batch) + continue + + rs_by_uid = {utils.base64_url_encode(r.record_uid): r for r in rs.records} + for record, _ in batch: + rs_rec = rs_by_uid.get(record.record_uid) + if rs_rec is not None and rs_rec.status == record_pb2.RS_SUCCESS: + logging.debug('%s: created record "%s"', command_name, record.title) + created += 1 + else: + if rs_rec is None: + logging.warning('%s: record "%s" absent from server response', + command_name, record.title) + else: + logging.warning('%s: failed to create record "%s": status=%s', + command_name, record.title, rs_rec.status) + skipped += 1 + + if created: + params.sync_data = True + print(f'{command_name}: {created} record(s) created, {skipped} skipped.') diff --git a/keepercommander/commands/aws_import.py b/keepercommander/commands/aws_import.py new file mode 100644 index 000000000..8cabd4c78 --- /dev/null +++ b/keepercommander/commands/aws_import.py @@ -0,0 +1,199 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' List[dict] + """ + Return normalised secret metadata: [{'name': str, 'tags': {k: v}}]. + + AWS tags are converted from the native list-of-dicts format + ([{'Key': k, 'Value': v}]) to a plain dict so filters work the same + way as for Azure and GCP. + """ + sm = self.get_client('secretsmanager', region_name) + results = [] + paginator = sm.get_paginator('list_secrets') + for page in paginator.paginate(): + for secret in page.get('SecretList', []): + name = secret.get('Name') or '' + if not name: + continue + tags = {t.get('Key'): t.get('Value') + for t in (secret.get('Tags') or [])} + results.append({'name': name, 'tags': tags}) + return results + + def _get_secret_value(self, secret_name, region_name=None): + # type: (str, Optional[str]) -> str + """Fetch and return the raw secret string for *secret_name*.""" + sm = self.get_client('secretsmanager', region_name) + response = sm.get_secret_value(SecretId=secret_name) + return response.get('SecretString') or '' + + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + + def execute(self, params, **kwargs): # type: (KeeperParams, Any) -> None + folder_uid = kwargs.get('folder') or '' + access_key = kwargs.get('access_key') + secret_key = kwargs.get('secret_key') + region = kwargs.get('region') + dry_run = kwargs.get('dry_run', False) + record_type = kwargs.get('record_type') or 'login' + + filter_name = kwargs.get('filter_name') or None + filter_starts = kwargs.get('filter_name_starts_with') or None + filter_ends = kwargs.get('filter_name_ends_with') or None + filter_contains = kwargs.get('filter_name_contains') or None + tags_str = kwargs.get('filter_tags') or '' + required_tags = [] # type: List[Tuple[str, str]] + if tags_str: + required_tags = self._parse_tag_filter(tags_str, 'aws-secrets-import') + + # Validate credential flags before mutating any instance state. + if access_key and not secret_key: + raise CommandError('aws-secrets-import', '--secret-key is required when --access-key is provided.') + if secret_key and not access_key: + raise CommandError('aws-secrets-import', '--access-key is required when --secret-key is provided.') + + self._validate_folder(params, folder_uid, 'aws-secrets-import') + + self._access_key = access_key or None + self._secret_key = secret_key or None + self._boto3_clients.clear() + + logging.info('aws-secrets-import: listing secrets from AWS Secrets Manager…') + try: + secrets = self._list_secret_metadata(region) + except Exception as exc: + raise CommandError('aws-secrets-import', f'Failed to list secrets from AWS: {exc}') + + if not secrets: + logging.warning('aws-secrets-import: no secrets found in AWS Secrets Manager.') + return + + logging.info('aws-secrets-import: found %d secret(s).', len(secrets)) + + self._run_import( + params, secrets, folder_uid, record_type, + filter_name, filter_starts, filter_ends, filter_contains, + required_tags, dry_run, 'aws-secrets-import', + value_fetcher=lambda name: self._get_secret_value(name, region) + ) diff --git a/keepercommander/commands/azure_import.py b/keepercommander/commands/azure_import.py new file mode 100644 index 000000000..fc09c4a02 --- /dev/null +++ b/keepercommander/commands/azure_import.py @@ -0,0 +1,190 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Any + try: + from azure.keyvault.secrets import SecretClient + except ImportError: + raise CommandError( + 'azure-secrets-import', + 'azure-keyvault-secrets is required. Install it with: pip install keeper-commander[azure]' + ) + vault_url = f'https://{vault_name}.vault.azure.net/' + return SecretClient(vault_url=vault_url, credential=credential) + + def _list_secret_metadata(self, client): + # type: (Any) -> List[dict] + """ + Return secret metadata only — no values fetched. + Result: [{'name': str, 'tags': Dict[str, str]}] + + Disabled secrets are excluded. Value fetching is deferred until + after filtering so that --dry-run does not trigger cloud API calls. + """ + results = [] + for prop in client.list_properties_of_secrets(): + if not prop.enabled: + logging.debug('azure-secrets-import: skipping disabled secret "%s"', prop.name) + continue + results.append({'name': prop.name, 'tags': dict(prop.tags or {})}) + return results + + def _get_secret_value(self, client, name): + # type: (Any, str) -> str + """Fetch and return the value of a single secret.""" + secret = client.get_secret(name) + return secret.value or '' + + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + + def execute(self, params, **kwargs): # type: (KeeperParams, Any) -> None + vault_name = kwargs.get('vault_name') or '' + folder_uid = kwargs.get('folder') or '' + tenant_id = kwargs.get('tenant_id') + client_id = kwargs.get('client_id') + client_secret = kwargs.get('client_secret') + dry_run = kwargs.get('dry_run', False) + record_type = kwargs.get('record_type') or 'login' + + filter_name = kwargs.get('filter_name') or None + filter_starts = kwargs.get('filter_name_starts_with') or None + filter_ends = kwargs.get('filter_name_ends_with') or None + filter_contains = kwargs.get('filter_name_contains') or None + tags_str = kwargs.get('filter_tags') or '' + required_tags = [] # type: List[Tuple[str, str]] + if tags_str: + required_tags = self._parse_tag_filter(tags_str, 'azure-secrets-import') + + if not vault_name: + raise CommandError('azure-secrets-import', 'An Azure Key Vault name is required.') + + self._validate_folder(params, folder_uid, 'azure-secrets-import') + + credential = self._get_credential(tenant_id, client_id, client_secret) + client = self._make_client(vault_name, credential) + + logging.info('azure-secrets-import: listing secrets from vault "%s"…', vault_name) + try: + secrets = self._list_secret_metadata(client) + except CommandError: + raise + except Exception as exc: + raise CommandError('azure-secrets-import', f'Failed to list secrets from Azure Key Vault: {exc}') + + if not secrets: + logging.warning('azure-secrets-import: no enabled secrets found in vault "%s".', vault_name) + return + + logging.info('azure-secrets-import: found %d secret(s).', len(secrets)) + + self._run_import( + params, secrets, folder_uid, record_type, + filter_name, filter_starts, filter_ends, filter_contains, + required_tags, dry_run, 'azure-secrets-import', + value_fetcher=lambda name: self._get_secret_value(client, name) + ) diff --git a/keepercommander/commands/gcp_import.py b/keepercommander/commands/gcp_import.py new file mode 100644 index 000000000..2c26509d1 --- /dev/null +++ b/keepercommander/commands/gcp_import.py @@ -0,0 +1,196 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Any + try: + from google.cloud import secretmanager + except ImportError: + raise CommandError( + 'gcp-secrets-import', + 'google-cloud-secret-manager is required. ' + 'Install it with: pip install keeper-commander[gcp]' + ) + + if service_account_file: + try: + from google.oauth2 import service_account + except ImportError: + raise CommandError( + 'gcp-secrets-import', + 'google-auth is required for service-account authentication. ' + 'Install it with: pip install keeper-commander[gcp]' + ) + logging.info('gcp-secrets-import: using service account key file "%s"', service_account_file) + credentials = service_account.Credentials.from_service_account_file( + service_account_file, + scopes=['https://www.googleapis.com/auth/cloud-platform'], + ) + return secretmanager.SecretManagerServiceClient(credentials=credentials) + + logging.info('gcp-secrets-import: using Application Default Credentials') + return secretmanager.SecretManagerServiceClient() + + # ------------------------------------------------------------------ + # GCP Secret Manager helpers + # ------------------------------------------------------------------ + + def _list_secret_metadata(self, client, project_id): + # type: (Any, str) -> List[dict] + """ + Return secret metadata only — no values fetched. + Result: [{'name': str, 'tags': Dict[str, str]}] + + GCP labels (the equivalent of tags) are a plain dict on each Secret. + Value fetching is deferred until after filtering so that --dry-run + does not trigger cloud API calls or generate audit-log entries. + """ + parent = f'projects/{project_id}' + results = [] + for secret in client.list_secrets(request={'parent': parent}): + short_name = secret.name.split('/')[-1] + results.append({'name': short_name, 'tags': dict(secret.labels or {})}) + return results + + def _get_secret_value(self, client, full_resource_name): + # type: (Any, str) -> str + """ + Fetch and return the payload of the latest version of a secret. + + *full_resource_name* is the GCP resource path + ``projects/{project}/secrets/{secret-id}/versions/latest``. + + Raises CommandError if the payload is binary (non-UTF-8), since Keeper + text fields cannot represent arbitrary bytes. + """ + from google.api_core.exceptions import NotFound, PermissionDenied + + secret_name = full_resource_name.split('/')[-1] + version_name = f'{full_resource_name}/versions/latest' + try: + response = client.access_secret_version(request={'name': version_name}) + except NotFound: + raise ValueError(f'no accessible version for "{secret_name}"') + except PermissionDenied: + raise PermissionError(f'permission denied accessing "{secret_name}"') + + try: + return response.payload.data.decode('utf-8') + except UnicodeDecodeError: + raise CommandError( + 'gcp-secrets-import', + f'"{secret_name}" contains binary data which is not supported. ' + 'Only text secrets can be imported.' + ) + + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + + def execute(self, params, **kwargs): # type: (KeeperParams, Any) -> None + folder_uid = kwargs.get('folder') or '' + project_id = kwargs.get('project_id') or '' + service_account_file = kwargs.get('service_account_file') + dry_run = kwargs.get('dry_run', False) + record_type = kwargs.get('record_type') or 'login' + + filter_name = kwargs.get('filter_name') or None + filter_starts = kwargs.get('filter_name_starts_with') or None + filter_ends = kwargs.get('filter_name_ends_with') or None + filter_contains = kwargs.get('filter_name_contains') or None + tags_str = kwargs.get('filter_tags') or '' + required_tags = [] # type: List[Tuple[str, str]] + if tags_str: + required_tags = self._parse_tag_filter(tags_str, 'gcp-secrets-import') + + if not project_id: + raise CommandError('gcp-secrets-import', '--project-id is required.') + + self._validate_folder(params, folder_uid, 'gcp-secrets-import') + + client = self._get_client(service_account_file) + + logging.info('gcp-secrets-import: listing secrets in project "%s"…', project_id) + try: + secrets = self._list_secret_metadata(client, project_id) + except CommandError: + raise + except Exception as exc: + raise CommandError('gcp-secrets-import', f'Failed to list secrets from GCP Secret Manager: {exc}') + + if not secrets: + logging.warning('gcp-secrets-import: no accessible secrets found in project "%s".', project_id) + return + + logging.info('gcp-secrets-import: found %d secret(s).', len(secrets)) + + def _fetch_value(name): + full_name = f'projects/{project_id}/secrets/{name}' + return self._get_secret_value(client, full_name) + + self._run_import( + params, secrets, folder_uid, record_type, + filter_name, filter_starts, filter_ends, filter_contains, + required_tags, dry_run, 'gcp-secrets-import', + value_fetcher=_fetch_value + ) diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 972680f41..17e108adb 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -63,6 +63,10 @@ def handle_empty_result(fmt, message, filename=None): return None def register_commands(commands): + from . import aws_import, azure_import, gcp_import + commands['aws-secrets-import'] = aws_import.AwsSecretsImportCommand() + commands['azure-secrets-import'] = azure_import.AzureSecretsImportCommand() + commands['gcp-secrets-import'] = gcp_import.GcpSecretsImportCommand() commands['search'] = SearchCommand() commands['get'] = RecordGetUidCommand() commands['rm'] = RecordRemoveCommand() @@ -100,6 +104,14 @@ def register_command_info(aliases, command_info): aliases['da'] = 'download-attachment' aliases['ua'] = 'upload-attachment' + from . import aws_import, azure_import, gcp_import + aliases['amsi'] = 'aws-secrets-import' + aliases['azsi'] = 'azure-secrets-import' + aliases['gcsi'] = 'gcp-secrets-import' + command_info[aws_import.aws_secrets_import_parser.prog] = aws_import.aws_secrets_import_parser.description + command_info[azure_import.azure_secrets_import_parser.prog] = azure_import.azure_secrets_import_parser.description + command_info[gcp_import.gcp_secrets_import_parser.prog] = gcp_import.gcp_secrets_import_parser.description + for p in [get_info_parser, search_parser, list_parser, list_sf_parser, list_team_parser, record_history_parser, shared_records_report_parser, record_edit.record_add_parser, record_edit.record_update_parser, record_edit.append_parser, record_edit.download_parser, diff --git a/setup.cfg b/setup.cfg index 5de629eb2..8278f112a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -70,6 +70,18 @@ keepercommander = resources/*, resources/email_templates/*, commands/pam_import/ test = pytest testfixtures +aws = + boto3>=1.26.0 +azure = + azure-keyvault-secrets>=4.7.0 + azure-identity>=1.13.0 +gcp = + google-cloud-secret-manager>=2.16.0 +cloud = + boto3>=1.26.0 + azure-keyvault-secrets>=4.7.0 + azure-identity>=1.13.0 + google-cloud-secret-manager>=2.16.0 email-sendgrid = sendgrid>=6.10.0 email-ses = diff --git a/unit-tests/test_cloud_import.py b/unit-tests/test_cloud_import.py new file mode 100644 index 000000000..eaea0db42 --- /dev/null +++ b/unit-tests/test_cloud_import.py @@ -0,0 +1,979 @@ +import json +import sys +import os +from unittest import TestCase, mock +from typing import List + +sys.path.insert(0, os.path.dirname(__file__)) + +from data_vault import get_synced_params + +from keepercommander import utils as keeper_utils, vault +from keepercommander.subfolder import SharedFolderNode, SharedFolderFolderNode, BaseFolderNode +from keepercommander.error import CommandError +from keepercommander.proto import record_pb2 +from keepercommander.commands._cloud_import_base import CloudImportMixin +from keepercommander.commands.aws_import import AwsSecretsImportCommand +from keepercommander.commands.azure_import import AzureSecretsImportCommand +from keepercommander.commands.gcp_import import GcpSecretsImportCommand + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +FOLDER_UID = 'TEST_SHARED_FOLDER_UID' +PERSONAL_FOLDER_UID = 'TEST_PERSONAL_FOLDER_UID' + + +def _make_params(): + """Return synced KeeperParams with a shared folder and a personal folder.""" + params = get_synced_params() + + shared = SharedFolderNode() + shared.uid = FOLDER_UID + params.folder_cache[FOLDER_UID] = shared + + # Personal (user) folder — should be rejected by _validate_folder + from keepercommander.subfolder import UserFolderNode + personal = UserFolderNode() + personal.uid = PERSONAL_FOLDER_UID + params.folder_cache[PERSONAL_FOLDER_UID] = personal + + return params + + +def _fake_add_record_pb(captured): + """ + Side-effect for record_management.add_record_to_folder (pb_only path). + + Appends the record to *captured* and returns a real RecordAdd whose + record_uid bytes match the record UID so the batch response matcher works. + """ + def _side_effect(params, record, folder_uid, pb_only=False): + if not record.record_uid: + record.record_uid = keeper_utils.generate_uid() + captured.append(record) + if pb_only: + pb = record_pb2.RecordAdd() + pb.record_uid = keeper_utils.base64_url_decode(record.record_uid) + return pb + return _side_effect + + +def _fake_records_add_success(params_arg, rq, endpoint, rs_type=None): + """Side-effect for api.communicate_rest: returns RS_SUCCESS for every record.""" + rs = record_pb2.RecordsModifyResponse() + rs.revision = 1 + for pb_rec in rq.records: + rec_rs = record_pb2.RecordModifyStatus() + rec_rs.record_uid = bytes(pb_rec.record_uid) + rec_rs.status = record_pb2.RS_SUCCESS + rs.records.append(rec_rs) + return rs + + +def _fake_records_add_empty_response(params_arg, rq, endpoint, rs_type=None): + """Side-effect that returns an empty records list (simulates truncated response).""" + rs = record_pb2.RecordsModifyResponse() + rs.revision = 1 + return rs + + +# --------------------------------------------------------------------------- +# Shared base-class logic (CloudImportMixin) +# --------------------------------------------------------------------------- + +class TestCloudImportBase(TestCase): + """Tests for the shared parsing/filtering/building helpers in CloudImportMixin.""" + + # --- _parse_secret_string --- + + def test_parse_json_object(self): + result = CloudImportMixin._parse_secret_string('{"username": "admin", "password": "s3cr3t"}') + self.assertEqual(result, {'username': 'admin', 'password': 's3cr3t'}) + + def test_parse_kv_lines(self): + result = CloudImportMixin._parse_secret_string('username=admin\npassword=s3cr3t') + self.assertEqual(result, {'username': 'admin', 'password': 's3cr3t'}) + + def test_parse_kv_lines_ignores_comments_and_blanks(self): + raw = '# comment\n\nusername=admin\npassword=s3cr3t\n' + result = CloudImportMixin._parse_secret_string(raw) + self.assertEqual(result, {'username': 'admin', 'password': 's3cr3t'}) + + def test_parse_kv_value_may_contain_equals(self): + result = CloudImportMixin._parse_secret_string('token=abc=def=ghi') + self.assertEqual(result, {'token': 'abc=def=ghi'}) + + def test_parse_plain_string_fallback(self): + result = CloudImportMixin._parse_secret_string('just-a-plain-secret') + self.assertEqual(result, {'value': 'just-a-plain-secret'}) + + def test_parse_empty_string_returns_empty_dict(self): + self.assertEqual(CloudImportMixin._parse_secret_string(''), {}) + self.assertEqual(CloudImportMixin._parse_secret_string(None), {}) + + def test_parse_invalid_json_falls_through_to_fallback(self): + result = CloudImportMixin._parse_secret_string('{bad json') + self.assertEqual(result, {'value': '{bad json'}) + + def test_parse_kv_non_posix_key_falls_through_to_fallback(self): + """A line whose key contains non-POSIX chars must not be parsed as KEY=VALUE.""" + # JWT-like single line: key contains dots and slashes + jwt = 'eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ1c2VyIn0.abc123==' + result = CloudImportMixin._parse_secret_string(jwt) + # '=' in the JWT triggers partition but the key contains '.' → not POSIX → fallback + self.assertEqual(result, {'value': jwt}) + + def test_parse_kv_single_line_with_posix_key_still_parsed(self): + """A single KEY=VALUE with a valid POSIX key is legitimately parsed.""" + result = CloudImportMixin._parse_secret_string('API_TOKEN=abc123') + self.assertEqual(result, {'API_TOKEN': 'abc123'}) + + def test_parse_json_null_value_becomes_empty_string(self): + """JSON null must not become the literal string 'None'.""" + result = CloudImportMixin._parse_secret_string('{"dbname": null}') + self.assertEqual(result['dbname'], '') + + def test_parse_json_bool_true_becomes_lowercase(self): + """JSON true must become 'true', not Python's 'True'.""" + result = CloudImportMixin._parse_secret_string('{"enabled": true}') + self.assertEqual(result['enabled'], 'true') + + def test_parse_json_bool_false_becomes_lowercase(self): + result = CloudImportMixin._parse_secret_string('{"enabled": false}') + self.assertEqual(result['enabled'], 'false') + + def test_parse_json_nested_object_becomes_json_string(self): + """Nested dict must be json.dumps'd, not Python repr'd.""" + result = CloudImportMixin._parse_secret_string('{"config": {"host": "x", "port": 5432}}') + import json as _json + # Must be valid JSON, not Python repr with single quotes + parsed = _json.loads(result['config']) + self.assertEqual(parsed, {'host': 'x', 'port': 5432}) + + def test_parse_json_array_becomes_json_string(self): + result = CloudImportMixin._parse_secret_string('{"hosts": ["h1", "h2"]}') + import json as _json + self.assertEqual(_json.loads(result['hosts']), ['h1', 'h2']) + + def test_parse_json_integer_becomes_string(self): + result = CloudImportMixin._parse_secret_string('{"port": 5432}') + self.assertEqual(result['port'], '5432') + + # --- _build_keeper_record --- + + def test_build_record_sets_title_and_type(self): + record = CloudImportMixin._build_keeper_record('My Secret', {}, 'login') + self.assertEqual(record.title, 'My Secret') + self.assertEqual(record.type_name, 'login') + + def test_build_record_maps_username_to_login_typed_field(self): + record = CloudImportMixin._build_keeper_record('s', {'username': 'admin'}, 'login') + self.assertIn('login', {f.type for f in record.fields}) + + def test_build_record_maps_password_to_typed_field(self): + record = CloudImportMixin._build_keeper_record('s', {'password': 'pw'}, 'login') + self.assertIn('password', {f.type for f in record.fields}) + + def test_build_record_maps_url_to_typed_field(self): + record = CloudImportMixin._build_keeper_record('s', {'url': 'https://example.com'}, 'login') + self.assertIn('url', {f.type for f in record.fields}) + + def test_build_record_unmapped_key_uses_text_typed_field(self): + record = CloudImportMixin._build_keeper_record('s', {'region': 'us-east-1'}, 'login') + self.assertEqual(len(record.fields), 1) + self.assertEqual(record.fields[0].type, 'text') + self.assertEqual(record.fields[0].label, 'region') + self.assertEqual(len(record.custom), 0) + + def test_build_record_note_key_goes_to_notes_property(self): + """'note' and 'notes' must be written to record.notes, not a typed field.""" + record = CloudImportMixin._build_keeper_record('s', {'note': 'rotate weekly'}, 'login') + self.assertEqual(record.notes, 'rotate weekly') + self.assertEqual(len(record.fields), 0) + self.assertEqual(len(record.custom), 0) + + def test_build_record_notes_key_goes_to_notes_property(self): + record = CloudImportMixin._build_keeper_record('s', {'notes': 'rotate monthly'}, 'login') + self.assertEqual(record.notes, 'rotate monthly') + + def test_build_record_first_note_wins(self): + """Only the first note/notes field populates record.notes.""" + record = CloudImportMixin._build_keeper_record( + 's', {'note': 'first', 'notes': 'second'}, 'login') + self.assertEqual(record.notes, 'first') + + def test_build_record_duplicate_login_goes_to_custom(self): + """Second 'login'-typed field should be stored in custom fields.""" + fields = {'username': 'u1', 'login': 'u2'} + record = CloudImportMixin._build_keeper_record('s', fields, 'login') + typed_types = [f.type for f in record.fields] + self.assertEqual(typed_types.count('login'), 1) + self.assertEqual(len(record.custom), 1) + self.assertEqual(record.custom[0].type, 'text') + + def test_build_record_duplicate_password_goes_to_custom(self): + """Second 'password'-typed field should be stored in custom fields.""" + fields = {'password': 'p1', 'pass': 'p2'} + record = CloudImportMixin._build_keeper_record('s', fields, 'login') + typed_types = [f.type for f in record.fields] + self.assertEqual(typed_types.count('password'), 1) + self.assertEqual(len(record.custom), 1) + + def test_build_record_email_key_maps_to_email_typed_field(self): + """'email' and 'mail' keys must produce a typed email field, not text.""" + record = CloudImportMixin._build_keeper_record('s', {'email': 'u@example.com'}, 'login') + types = [f.type for f in record.fields] + self.assertIn('email', types) + self.assertEqual(len(record.custom), 0) + + def test_build_record_mail_key_maps_to_email_typed_field(self): + record = CloudImportMixin._build_keeper_record('s', {'mail': 'u@example.com'}, 'login') + self.assertIn('email', {f.type for f in record.fields}) + + def test_build_record_duplicate_email_goes_to_custom(self): + fields = {'email': 'primary@example.com', 'mail': 'alt@example.com'} + record = CloudImportMixin._build_keeper_record('s', fields, 'login') + self.assertEqual([f.type for f in record.fields].count('email'), 1) + self.assertEqual(len(record.custom), 1) + + def test_build_record_mixed_fields(self): + fields = {'username': 'admin', 'password': 'pw', 'region': 'us-east-1'} + record = CloudImportMixin._build_keeper_record('s', fields, 'login') + typed_types = {f.type for f in record.fields} + self.assertIn('login', typed_types) + self.assertIn('password', typed_types) + self.assertIn('text', typed_types) + self.assertEqual(len(record.custom), 0) + + # --- _parse_tag_filter --- + + def test_parse_tag_filter_single(self): + self.assertEqual(CloudImportMixin._parse_tag_filter('Env=prod', 'cmd'), + [('Env', 'prod')]) + + def test_parse_tag_filter_multiple(self): + self.assertEqual(CloudImportMixin._parse_tag_filter('Env=prod,Team=ops', 'cmd'), + [('Env', 'prod'), ('Team', 'ops')]) + + def test_parse_tag_filter_invalid_raises(self): + with self.assertRaises(CommandError): + CloudImportMixin._parse_tag_filter('Env', 'cmd') + + def test_parse_tag_filter_ignores_empty_tokens(self): + self.assertEqual(CloudImportMixin._parse_tag_filter('Env=prod,', 'cmd'), + [('Env', 'prod')]) + + # --- _matches_name_filters --- + + def test_name_filter_exact_match(self): + self.assertTrue(CloudImportMixin._matches_name_filters('prod/db', 'prod/db', None, None, None)) + self.assertFalse(CloudImportMixin._matches_name_filters('prod/db', 'prod/other', None, None, None)) + + def test_name_filter_starts_with(self): + self.assertTrue(CloudImportMixin._matches_name_filters('prod/db', None, 'prod/', None, None)) + self.assertFalse(CloudImportMixin._matches_name_filters('dev/db', None, 'prod/', None, None)) + + def test_name_filter_ends_with(self): + self.assertTrue(CloudImportMixin._matches_name_filters('prod/db/creds', None, None, '/creds', None)) + self.assertFalse(CloudImportMixin._matches_name_filters('prod/db/config', None, None, '/creds', None)) + + def test_name_filter_contains(self): + self.assertTrue(CloudImportMixin._matches_name_filters('prod/rds/creds', None, None, None, 'rds')) + self.assertFalse(CloudImportMixin._matches_name_filters('prod/mysql/creds', None, None, None, 'rds')) + + def test_name_filter_no_filters_always_passes(self): + self.assertTrue(CloudImportMixin._matches_name_filters('anything', None, None, None, None)) + + def test_name_filter_combined_all_must_pass(self): + self.assertTrue(CloudImportMixin._matches_name_filters( + 'prod/rds/pw', None, 'prod/', None, 'rds')) + self.assertFalse(CloudImportMixin._matches_name_filters( + 'prod/mysql/pw', None, 'prod/', None, 'rds')) + + # --- _matches_tag_filters --- + + def test_tag_filter_all_match(self): + tags = {'Env': 'prod', 'Team': 'ops'} + self.assertTrue(CloudImportMixin._matches_tag_filters(tags, [('Env', 'prod'), ('Team', 'ops')])) + + def test_tag_filter_one_mismatch(self): + tags = {'Env': 'staging', 'Team': 'ops'} + self.assertFalse(CloudImportMixin._matches_tag_filters(tags, [('Env', 'prod'), ('Team', 'ops')])) + + def test_tag_filter_missing_key(self): + self.assertFalse(CloudImportMixin._matches_tag_filters({'Team': 'ops'}, [('Env', 'prod')])) + + def test_tag_filter_empty_required_always_passes(self): + self.assertTrue(CloudImportMixin._matches_tag_filters({}, [])) + + # --- _validate_folder --- + + def test_validate_folder_empty_uid_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + CloudImportMixin._validate_folder(params, '', 'cmd') + + def test_validate_folder_unknown_uid_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + CloudImportMixin._validate_folder(params, 'NONEXISTENT_UID', 'cmd') + + def test_validate_folder_personal_folder_raises(self): + """A personal (user) folder UID must be rejected.""" + params = _make_params() + with self.assertRaises(CommandError): + CloudImportMixin._validate_folder(params, PERSONAL_FOLDER_UID, 'cmd') + + def test_validate_folder_shared_uid_passes(self): + params = _make_params() + CloudImportMixin._validate_folder(params, FOLDER_UID, 'cmd') + + def test_validate_folder_shared_folder_folder_passes(self): + """A SharedFolderFolderNode (sub-folder inside a shared folder) is also valid.""" + params = _make_params() + from keepercommander.subfolder import SharedFolderFolderNode + sff = SharedFolderFolderNode() + sff.uid = 'SFF_UID' + params.folder_cache['SFF_UID'] = sff + CloudImportMixin._validate_folder(params, 'SFF_UID', 'cmd') + + # --- _run_import --- + + def test_run_import_dry_run_does_not_create_records(self): + mixin = CloudImportMixin() + params = _make_params() + secrets = [{'name': 'my-secret', 'value': 'pw=s3cr3t', 'tags': {}}] + + with mock.patch('keepercommander.record_management.add_record_to_folder') as add_mock, \ + mock.patch('keepercommander.api.communicate_rest') as rest_mock, \ + mock.patch('builtins.print') as print_mock: + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, None, None, None, [], True, 'cmd') + + add_mock.assert_not_called() + rest_mock.assert_not_called() + printed = ' '.join(str(a) for call in print_mock.call_args_list for a in call.args) + self.assertIn('my-secret', printed) + + def test_run_import_creates_records(self): + mixin = CloudImportMixin() + params = _make_params() + secrets = [ + {'name': 'secret-a', 'value': '{"username": "alice"}', 'tags': {}}, + {'name': 'secret-b', 'value': 'password=hunter2', 'tags': {}}, + ] + captured = [] + + with mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, None, None, None, [], False, 'cmd') + + self.assertEqual(len(captured), 2) + self.assertEqual(captured[0].title, 'secret-a') + self.assertEqual(captured[1].title, 'secret-b') + + def test_run_import_name_filter_applied(self): + mixin = CloudImportMixin() + params = _make_params() + secrets = [ + {'name': 'prod/db', 'value': 'password=pw', 'tags': {}}, + {'name': 'dev/db', 'value': 'password=pw', 'tags': {}}, + ] + captured = [] + + with mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, 'prod/', None, None, [], False, 'cmd') + + self.assertEqual(len(captured), 1) + self.assertEqual(captured[0].title, 'prod/db') + + def test_run_import_tag_filter_applied(self): + mixin = CloudImportMixin() + params = _make_params() + secrets = [ + {'name': 'secret-a', 'value': 'v=1', 'tags': {'Env': 'prod'}}, + {'name': 'secret-b', 'value': 'v=2', 'tags': {'Env': 'staging'}}, + ] + captured = [] + + with mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, None, None, None, [('Env', 'prod')], False, 'cmd') + + self.assertEqual(len(captured), 1) + self.assertEqual(captured[0].title, 'secret-a') + + def test_run_import_value_fetcher_called_after_filter(self): + """value_fetcher must only be called for secrets that pass all filters.""" + mixin = CloudImportMixin() + params = _make_params() + secrets = [ + {'name': 'prod/db', 'tags': {}}, + {'name': 'dev/db', 'tags': {}}, + ] + fetched_names = [] + + def _fetcher(name): + fetched_names.append(name) + return 'password=pw' + + captured = [] + with mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, 'prod/', None, None, [], False, 'cmd', + value_fetcher=_fetcher) + + self.assertEqual(fetched_names, ['prod/db']) + self.assertEqual(len(captured), 1) + + def test_run_import_value_fetcher_not_called_on_dry_run(self): + """value_fetcher must NOT be called during a dry run.""" + mixin = CloudImportMixin() + params = _make_params() + secrets = [{'name': 'sec', 'tags': {}}] + fetch_calls = [] + + with mock.patch('keepercommander.api.communicate_rest'), \ + mock.patch('builtins.print'): + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, None, None, None, [], True, 'cmd', + value_fetcher=lambda n: fetch_calls.append(n) or 'v=1') + + self.assertEqual(fetch_calls, []) + + def test_run_import_absent_uid_in_response_counts_as_skipped(self): + """Records missing from the server response must be counted as skipped, not created.""" + mixin = CloudImportMixin() + params = _make_params() + params.sync_data = False + secrets = [{'name': 'sec', 'value': 'v=1', 'tags': {}}] + + with mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb([])), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_empty_response), \ + mock.patch('builtins.print') as print_mock: + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, None, None, None, [], False, 'cmd') + + # sync_data must NOT be set — nothing was successfully created + self.assertFalse(params.sync_data) + printed = ' '.join(str(a) for call in print_mock.call_args_list for a in call.args) + self.assertIn('0 record(s) created', printed) + self.assertIn('1 skipped', printed) + + def test_run_import_sets_sync_data_when_records_created(self): + mixin = CloudImportMixin() + params = _make_params() + params.sync_data = False + secrets = [{'name': 'sec', 'value': 'v=1', 'tags': {}}] + + with mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb([])), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + mixin._run_import(params, secrets, FOLDER_UID, 'login', + None, None, None, None, [], False, 'cmd') + + self.assertTrue(params.sync_data) + + +# --------------------------------------------------------------------------- +# AWS Secrets Manager +# --------------------------------------------------------------------------- + +class TestAwsSecretsImport(TestCase): + + def setUp(self): + self.cmd = AwsSecretsImportCommand() + + def tearDown(self): + mock.patch.stopall() + + # --- argument / folder validation --- + + def test_execute_missing_folder_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder='') + + def test_execute_unknown_folder_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder='NONEXISTENT') + + def test_execute_personal_folder_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder=PERSONAL_FOLDER_UID) + + def test_execute_access_key_without_secret_key_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder=FOLDER_UID, access_key='AKIA123') + + def test_execute_secret_key_without_access_key_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder=FOLDER_UID, secret_key='secret') + + # --- happy path --- + + def _run_with_mocked_aws(self, params, secret_meta, value_map=None, **extra_kwargs): + """ + Run execute() with AWS internals mocked. + + *secret_meta* is a list of {'name', 'tags'} dicts (the normalised format + returned by _list_secret_metadata after this refactor). + *value_map* maps name → secret value string. + """ + value_map = value_map or {} + captured = [] + + with mock.patch.object(self.cmd, '_list_secret_metadata', return_value=secret_meta), \ + mock.patch.object(self.cmd, '_get_secret_value', + side_effect=lambda name, region: value_map.get(name, '')), \ + mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + self.cmd.execute(params, folder=FOLDER_UID, **extra_kwargs) + + return captured + + def test_execute_imports_all_secrets(self): + params = _make_params() + meta = [{'name': 'prod/db', 'tags': {}}, {'name': 'prod/api', 'tags': {}}] + values = {'prod/db': '{"username": "admin", "password": "pw"}', + 'prod/api': 'token=abc123'} + records = self._run_with_mocked_aws(params, meta, values) + self.assertEqual(len(records), 2) + self.assertEqual({r.title for r in records}, {'prod/db', 'prod/api'}) + + def test_execute_json_secret_fields_mapped_correctly(self): + params = _make_params() + meta = [{'name': 'my-cred', 'tags': {}}] + values = {'my-cred': '{"username": "alice", "password": "s3cr3t"}'} + records = self._run_with_mocked_aws(params, meta, values) + self.assertEqual(len(records), 1) + self.assertIn('login', {f.type for f in records[0].fields}) + self.assertIn('password', {f.type for f in records[0].fields}) + + def test_execute_dry_run_does_not_create_records(self): + params = _make_params() + meta = [{'name': 'prod/db', 'tags': {}}] + with mock.patch.object(self.cmd, '_list_secret_metadata', return_value=meta), \ + mock.patch.object(self.cmd, '_get_secret_value') as val_mock, \ + mock.patch('keepercommander.record_management.add_record_to_folder') as add_mock, \ + mock.patch('keepercommander.api.communicate_rest') as rest_mock, \ + mock.patch('builtins.print'): + self.cmd.execute(params, folder=FOLDER_UID, dry_run=True) + + val_mock.assert_not_called() + add_mock.assert_not_called() + rest_mock.assert_not_called() + + def test_execute_name_filter(self): + params = _make_params() + meta = [{'name': 'prod/db', 'tags': {}}, {'name': 'dev/db', 'tags': {}}] + values = {'prod/db': 'v=1'} + records = self._run_with_mocked_aws(params, meta, values, + filter_name_starts_with='prod/') + self.assertEqual(len(records), 1) + self.assertEqual(records[0].title, 'prod/db') + + def test_execute_tag_filter_normalised_format(self): + """Tags are now plain dicts (normalised in _list_secret_metadata).""" + params = _make_params() + meta = [ + {'name': 'prod/db', 'tags': {'Env': 'prod'}}, + {'name': 'dev/db', 'tags': {'Env': 'dev'}}, + ] + values = {'prod/db': 'v=1'} + records = self._run_with_mocked_aws(params, meta, values, filter_tags='Env=prod') + self.assertEqual(len(records), 1) + self.assertEqual(records[0].title, 'prod/db') + + def test_execute_no_secrets_returns_without_error(self): + params = _make_params() + with mock.patch.object(self.cmd, '_list_secret_metadata', return_value=[]): + self.cmd.execute(params, folder=FOLDER_UID) + + def test_execute_get_value_error_skips_secret(self): + params = _make_params() + meta = [{'name': 'good', 'tags': {}}, {'name': 'bad', 'tags': {}}] + + def _get_value(name, region): + if name == 'bad': + raise RuntimeError('access denied') + return 'password=pw' + + captured = [] + with mock.patch.object(self.cmd, '_list_secret_metadata', return_value=meta), \ + mock.patch.object(self.cmd, '_get_secret_value', side_effect=_get_value), \ + mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + self.cmd.execute(params, folder=FOLDER_UID) + + self.assertEqual(len(captured), 1) + self.assertEqual(captured[0].title, 'good') + + # --- _list_secret_metadata tag normalisation --- + + def test_list_secret_metadata_normalises_aws_tags(self): + """AWS list-of-dicts tags must be converted to a plain dict.""" + mock_sm = mock.MagicMock() + mock_sm.get_paginator.return_value.paginate.return_value = [{ + 'SecretList': [ + {'Name': 'prod/db', + 'Tags': [{'Key': 'Env', 'Value': 'prod'}, {'Key': 'Team', 'Value': 'ops'}]}, + ] + }] + self.cmd._access_key = None + with mock.patch.object(self.cmd, 'get_client', return_value=mock_sm): + result = self.cmd._list_secret_metadata('us-east-1') + + self.assertEqual(result, [{'name': 'prod/db', 'tags': {'Env': 'prod', 'Team': 'ops'}}]) + + +# --------------------------------------------------------------------------- +# Azure Key Vault +# --------------------------------------------------------------------------- + +class TestAzureSecretsImport(TestCase): + + def setUp(self): + self.cmd = AzureSecretsImportCommand() + + def tearDown(self): + mock.patch.stopall() + + # --- argument / folder validation --- + + def test_execute_missing_vault_name_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, vault_name='', folder=FOLDER_UID) + + def test_execute_unknown_folder_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, vault_name='my-vault', folder='NONEXISTENT') + + def test_execute_personal_folder_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, vault_name='my-vault', folder=PERSONAL_FOLDER_UID) + + def test_execute_partial_sp_credentials_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, vault_name='my-vault', folder=FOLDER_UID, + tenant_id='tid', client_id=None, client_secret=None) + with self.assertRaises(CommandError): + self.cmd.execute(params, vault_name='my-vault', folder=FOLDER_UID, + tenant_id='tid', client_id='cid', client_secret=None) + + # --- happy path --- + + def _run_with_mocked_azure(self, params, secret_meta, value_map=None, **extra_kwargs): + """ + Run execute() with Azure internals mocked. + + *secret_meta*: list of {'name', 'tags'} dicts. + *value_map*: name → secret value string. + """ + value_map = value_map or {} + mock_credential = mock.MagicMock() + mock_client = mock.MagicMock() + captured = [] + + with mock.patch.object(self.cmd, '_get_credential', return_value=mock_credential), \ + mock.patch.object(self.cmd, '_make_client', return_value=mock_client), \ + mock.patch.object(self.cmd, '_list_secret_metadata', return_value=secret_meta), \ + mock.patch.object(self.cmd, '_get_secret_value', + side_effect=lambda client, name: value_map.get(name, '')), \ + mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + self.cmd.execute(params, vault_name='my-vault', folder=FOLDER_UID, **extra_kwargs) + + return captured + + def test_execute_imports_all_secrets(self): + params = _make_params() + meta = [{'name': 'db-password', 'tags': {}}, {'name': 'api-token', 'tags': {}}] + values = {'db-password': '{"username": "admin", "password": "pw"}', + 'api-token': 'token=abc123'} + records = self._run_with_mocked_azure(params, meta, values) + self.assertEqual(len(records), 2) + self.assertEqual({r.title for r in records}, {'db-password', 'api-token'}) + + def test_execute_dry_run_does_not_create_records(self): + params = _make_params() + meta = [{'name': 'db-password', 'tags': {}}] + mock_credential = mock.MagicMock() + mock_client = mock.MagicMock() + + with mock.patch.object(self.cmd, '_get_credential', return_value=mock_credential), \ + mock.patch.object(self.cmd, '_make_client', return_value=mock_client), \ + mock.patch.object(self.cmd, '_list_secret_metadata', return_value=meta), \ + mock.patch.object(self.cmd, '_get_secret_value') as val_mock, \ + mock.patch('keepercommander.record_management.add_record_to_folder') as add_mock, \ + mock.patch('keepercommander.api.communicate_rest') as rest_mock, \ + mock.patch('builtins.print'): + self.cmd.execute(params, vault_name='my-vault', folder=FOLDER_UID, dry_run=True) + + val_mock.assert_not_called() + add_mock.assert_not_called() + rest_mock.assert_not_called() + + def test_execute_name_filter(self): + params = _make_params() + meta = [{'name': 'prod-db', 'tags': {}}, {'name': 'dev-db', 'tags': {}}] + values = {'prod-db': 'v=1'} + records = self._run_with_mocked_azure(params, meta, values, + filter_name_starts_with='prod-') + self.assertEqual(len(records), 1) + self.assertEqual(records[0].title, 'prod-db') + + def test_execute_tag_filter(self): + params = _make_params() + meta = [{'name': 'prod-secret', 'tags': {'Env': 'prod'}}, + {'name': 'dev-secret', 'tags': {'Env': 'dev'}}] + values = {'prod-secret': 'v=1'} + records = self._run_with_mocked_azure(params, meta, values, filter_tags='Env=prod') + self.assertEqual(len(records), 1) + self.assertEqual(records[0].title, 'prod-secret') + + def test_execute_no_secrets_returns_without_error(self): + params = _make_params() + mock_credential = mock.MagicMock() + mock_client = mock.MagicMock() + with mock.patch.object(self.cmd, '_get_credential', return_value=mock_credential), \ + mock.patch.object(self.cmd, '_make_client', return_value=mock_client), \ + mock.patch.object(self.cmd, '_list_secret_metadata', return_value=[]): + self.cmd.execute(params, vault_name='my-vault', folder=FOLDER_UID) + + def test_execute_uses_default_record_type(self): + params = _make_params() + records = self._run_with_mocked_azure(params, [{'name': 'sec', 'tags': {}}], + {'sec': 'password=pw'}) + self.assertEqual(records[0].type_name, 'login') + + def test_execute_custom_record_type(self): + params = _make_params() + records = self._run_with_mocked_azure(params, [{'name': 'sec', 'tags': {}}], + {'sec': 'password=pw'}, + record_type='serverCredentials') + self.assertEqual(records[0].type_name, 'serverCredentials') + + +# --------------------------------------------------------------------------- +# GCP Secret Manager +# --------------------------------------------------------------------------- + +class TestGcpSecretsImport(TestCase): + + def setUp(self): + self.cmd = GcpSecretsImportCommand() + + def tearDown(self): + mock.patch.stopall() + + # --- argument / folder validation --- + + def test_execute_missing_project_id_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder=FOLDER_UID, project_id='') + + def test_execute_unknown_folder_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder='NONEXISTENT', project_id='my-project') + + def test_execute_personal_folder_raises(self): + params = _make_params() + with self.assertRaises(CommandError): + self.cmd.execute(params, folder=PERSONAL_FOLDER_UID, project_id='my-project') + + # --- happy path --- + + def _run_with_mocked_gcp(self, params, secret_meta, value_map=None, **extra_kwargs): + """ + Run execute() with GCP internals mocked. + + *secret_meta*: list of {'name', 'tags'} dicts. + *value_map*: name → secret value string. + """ + value_map = value_map or {} + mock_client = mock.MagicMock() + captured = [] + + with mock.patch.object(self.cmd, '_get_client', return_value=mock_client), \ + mock.patch.object(self.cmd, '_list_secret_metadata', return_value=secret_meta), \ + mock.patch.object(self.cmd, '_get_secret_value', + side_effect=lambda client, full_name: value_map.get(full_name.split('/')[-1], '')), \ + mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + self.cmd.execute(params, folder=FOLDER_UID, project_id='my-project', **extra_kwargs) + + return captured + + def test_execute_imports_all_secrets(self): + params = _make_params() + meta = [{'name': 'db-password', 'tags': {}}, {'name': 'api-key', 'tags': {}}] + values = {'db-password': '{"username": "root", "password": "pw"}', + 'api-key': 'token=xyz789'} + records = self._run_with_mocked_gcp(params, meta, values) + self.assertEqual(len(records), 2) + self.assertEqual({r.title for r in records}, {'db-password', 'api-key'}) + + def test_execute_dry_run_does_not_create_records(self): + params = _make_params() + meta = [{'name': 'db-password', 'tags': {}}] + mock_client = mock.MagicMock() + + with mock.patch.object(self.cmd, '_get_client', return_value=mock_client), \ + mock.patch.object(self.cmd, '_list_secret_metadata', return_value=meta), \ + mock.patch.object(self.cmd, '_get_secret_value') as val_mock, \ + mock.patch('keepercommander.record_management.add_record_to_folder') as add_mock, \ + mock.patch('keepercommander.api.communicate_rest') as rest_mock, \ + mock.patch('builtins.print'): + self.cmd.execute(params, folder=FOLDER_UID, project_id='my-project', dry_run=True) + + val_mock.assert_not_called() + add_mock.assert_not_called() + rest_mock.assert_not_called() + + def test_execute_name_filter(self): + params = _make_params() + meta = [{'name': 'prod-db', 'tags': {}}, {'name': 'dev-db', 'tags': {}}] + values = {'prod-db': 'v=1'} + records = self._run_with_mocked_gcp(params, meta, values, + filter_name_starts_with='prod-') + self.assertEqual(len(records), 1) + self.assertEqual(records[0].title, 'prod-db') + + def test_execute_label_filter(self): + params = _make_params() + meta = [{'name': 'prod-secret', 'tags': {'env': 'prod'}}, + {'name': 'dev-secret', 'tags': {'env': 'dev'}}] + values = {'prod-secret': 'v=1'} + records = self._run_with_mocked_gcp(params, meta, values, filter_tags='env=prod') + self.assertEqual(len(records), 1) + self.assertEqual(records[0].title, 'prod-secret') + + def test_execute_no_secrets_returns_without_error(self): + params = _make_params() + mock_client = mock.MagicMock() + with mock.patch.object(self.cmd, '_get_client', return_value=mock_client), \ + mock.patch.object(self.cmd, '_list_secret_metadata', return_value=[]): + self.cmd.execute(params, folder=FOLDER_UID, project_id='my-project') + + def test_execute_uses_default_record_type(self): + params = _make_params() + records = self._run_with_mocked_gcp(params, [{'name': 'sec', 'tags': {}}], + {'sec': 'password=pw'}) + self.assertEqual(records[0].type_name, 'login') + + def test_execute_custom_record_type(self): + params = _make_params() + records = self._run_with_mocked_gcp(params, [{'name': 'sec', 'tags': {}}], + {'sec': 'password=pw'}, + record_type='serverCredentials') + self.assertEqual(records[0].type_name, 'serverCredentials') + + def test_execute_sets_sync_data_after_import(self): + params = _make_params() + params.sync_data = False + self._run_with_mocked_gcp(params, [{'name': 'sec', 'tags': {}}], {'sec': 'pw=x'}) + self.assertTrue(params.sync_data) + + def test_get_secret_value_binary_raises_command_error(self): + """Binary (non-UTF-8) payloads must raise CommandError, not propagate UnicodeDecodeError.""" + import sys + import types + + # google-cloud-secret-manager is an optional extra and may not be installed + # in the test environment. Provide a minimal mock so the import inside + # _get_secret_value succeeds without the real package. + mock_exc_module = types.ModuleType('google.api_core.exceptions') + mock_exc_module.NotFound = type('NotFound', (Exception,), {}) + mock_exc_module.PermissionDenied = type('PermissionDenied', (Exception,), {}) + + mock_client = mock.MagicMock() + binary_payload = b'\x30\x82\x03\x01\x00\x01' # DER-encoded bytes + mock_version = mock.MagicMock() + mock_version.payload.data = binary_payload + mock_client.access_secret_version.return_value = mock_version + + with mock.patch.dict(sys.modules, {'google.api_core.exceptions': mock_exc_module}), \ + self.assertRaises(CommandError) as ctx: + self.cmd._get_secret_value(mock_client, 'projects/p/secrets/tls-cert') + + self.assertIn('binary data', str(ctx.exception).lower()) + self.assertIn('tls-cert', str(ctx.exception)) + + def test_get_secret_value_binary_skipped_gracefully_in_run_import(self): + """ + A binary secret raises CommandError from _get_secret_value; execute() + must count it as skipped while still importing text secrets. + + _get_secret_value is mocked directly here — the binary-detection logic + is already covered by test_get_secret_value_binary_raises_command_error. + This test focuses on the execute() / _run_import integration. + """ + params = _make_params() + + def _get_value(client, full_name): + if 'tls-cert' in full_name: + raise CommandError('gcp-secrets-import', + '"tls-cert" contains binary data which is not supported.') + return 'password=s3cr3t' + + meta = [{'name': 'my-db', 'tags': {}}, {'name': 'tls-cert', 'tags': {}}] + captured = [] + + with mock.patch.object(self.cmd, '_get_client', return_value=mock.MagicMock()), \ + mock.patch.object(self.cmd, '_list_secret_metadata', return_value=meta), \ + mock.patch.object(self.cmd, '_get_secret_value', side_effect=_get_value), \ + mock.patch('keepercommander.record_management.add_record_to_folder', + side_effect=_fake_add_record_pb(captured)), \ + mock.patch('keepercommander.api.communicate_rest', + side_effect=_fake_records_add_success), \ + mock.patch('builtins.print'): + self.cmd.execute(params, folder=FOLDER_UID, project_id='my-project') + + self.assertEqual(len(captured), 1) + self.assertEqual(captured[0].title, 'my-db') From 631370267ad6b453c0b7aa418395027a89bf7547 Mon Sep 17 00:00:00 2001 From: tbjones-ks Date: Wed, 13 May 2026 12:06:14 -0700 Subject: [PATCH 22/26] Commander | cp-rotation-skip-feat Change List: - Make rotation: block optional in credential-provision YAML - Add rotation.on_demand flag for manual-trigger rotation mode - Add rotation.rotate_on_provision flag to defer immediate rotation - Add account.existing_password for bring-your-own-password onboarding - Validator INVARIANT-001 rejects push-to-target combinations - Skip _create_ad_user_via_gateway when existing_password is set - Defer AD group-add to after PAM user creation + rotation - Fix NameError in cp --output json --dry-run (absorbs 97dc8893) - Update inline help text; rewrite _dry_run_report --- .../commands/credential_provision.py | 382 ++++++++--- tests/test_credential_provision_dryrun.py | 471 +++++++++++++ tests/test_credential_provision_execute.py | 476 +++++++++++++ tests/test_credential_provision_validation.py | 642 ++++++++++++++++++ unit-tests/test_credential_provision.py | 10 +- 5 files changed, 1901 insertions(+), 80 deletions(-) create mode 100644 tests/test_credential_provision_dryrun.py create mode 100644 tests/test_credential_provision_execute.py create mode 100644 tests/test_credential_provision_validation.py diff --git a/keepercommander/commands/credential_provision.py b/keepercommander/commands/credential_provision.py index 69875ce04..b1a9008bd 100644 --- a/keepercommander/commands/credential_provision.py +++ b/keepercommander/commands/credential_provision.py @@ -85,14 +85,16 @@ from keepercommander.proto import pam_pb2 from keepercommander.commands.pam.config_facades import PamConfigurationRecordFacade +# Default password complexity when no rotation: block or existing_password. +DEFAULT_COMPLEXITY = "32,5,5,5,5" + # ============================================================================= # Argument Parser # ============================================================================= credential_provision_parser = argparse.ArgumentParser( prog='credential-provision', - description='Automate PAM User credential provisioning with password rotation ' - 'and email delivery' + description='Automate PAM User credential provisioning, with optional password rotation and email delivery.' ) # Config input: file path OR base64-encoded content (mutually exclusive) @@ -389,12 +391,15 @@ def execute(self, params: KeeperParams, **kwargs): logging.error(error_msg) raise CommandError('credential-provision', error_msg) - # Generate password - password = self._generate_password(config['rotation']['password_complexity']) + password, used_existing = self._determine_password(config) - # Create AD user via Gateway if AD-specific fields are present + # AD-create is a push-to-target route — gated when existing_password is set. ad_groups = config['account'].get('ad_groups', []) has_ad_config = config['account'].get('distinguished_name') or ad_groups + # ad_ops_allowed tracks whether group-add (which we defer to after PAM + # user creation + rotation, to keep rollback semantics clean) can run. + ad_ops_allowed = False + gateway_uid = None if has_ad_config: # Check Gateway version — AD operations require 1.8.0+ gateway_uid = self._get_gateway_uid_for_config( @@ -408,27 +413,26 @@ def execute(self, params: KeeperParams, **kwargs): ' Continuing with PAM User record creation only.' ) else: - try: - self._create_ad_user_via_gateway(config, password, params, state) - if output_format == 'text': - logging.info(f'✅ AD user created: {config["account"]["username"]}') - except CommandError as ad_err: - if 'already exists' in str(ad_err).lower(): + ad_ops_allowed = True + if self._should_create_ad_user(config): + try: + self._create_ad_user_via_gateway(config, password, params, state) if output_format == 'text': - logging.info(f'⚠️ AD user already exists: {config["account"]["username"]} — continuing with PAM User creation') - else: - raise - - # Add to AD groups - if ad_groups: - succeeded, failed = self._add_ad_user_to_groups_via_gateway( - config, params, state.ad_gateway_uid + logging.info(f'✅ AD user created: {config["account"]["username"]}') + except CommandError as ad_err: + if 'already exists' in str(ad_err).lower(): + if output_format == 'text': + logging.info(f'⚠️ AD user already exists: {config["account"]["username"]} — continuing with PAM User creation') + else: + raise + elif used_existing: + logging.info( + 'Skipping AD user creation: existing_password declared; ' + 'account is assumed to pre-exist.' ) - if output_format == 'text': - if succeeded: - logging.info(f'✅ Added to AD groups: {", ".join(succeeded)}') - if failed: - logging.warning(f'⚠️ Failed to add to AD groups: {", ".join(failed)}') + # NOTE: group-add deferred to after PAM user creation + rotation + # so a failure in those steps doesn't orphan AD memberships on + # the operator's pre-existing account. See review #8 / PR #2043. # Create PAM User record pam_user_uid = self._create_pam_user(config, password, params) @@ -437,19 +441,49 @@ def execute(self, params: KeeperParams, **kwargs): if output_format == 'text': logging.info(f'✅ PAM User record created: {pam_user_uid}') + if used_existing: + self._log_existing_password_use(pam_user_uid) + # Link to PAM Configuration and configure rotation self._create_dag_link(pam_user_uid, config['account']['pam_config_uid'], params) state.dag_link_created = True - self._configure_rotation(pam_user_uid, config, params) - if output_format == 'text': - logging.info('✅ Rotation configured') - - # Perform immediate rotation if configured - rotation_success = self._rotate_immediately(pam_user_uid, config, params) - - if output_format == 'text': - logging.info('✅ Password rotation submitted') + rotation_success = False + if 'rotation' in config: + # _configure_rotation returns False when it swallowed a gateway-500 + # error (rotation NOT actually configured — deferred). Gate the + # success log on the real outcome to avoid contradicting the + # warning lines emitted inside the helper. + rotation_configured = self._configure_rotation(pam_user_uid, config, params) + if rotation_configured and output_format == 'text': + logging.info('✅ Rotation configured') + + if self._should_rotate_on_provision(config): + rotation_success = self._rotate_immediately(pam_user_uid, config, params) + # Only emit the success line when _rotate_immediately actually + # succeeded; failure path already logged a warning + mode-aware hint. + if rotation_success and output_format == 'text': + logging.info('✅ Password rotation submitted') + elif output_format == 'text': + logging.info( + 'Skipping immediate rotation (rotate_on_provision: false).' + ) + elif output_format == 'text': + logging.info('No rotation: block in YAML — rotation will not be configured.') + + # AD group additions — deferred from the AD-create block above so a failure + # in _create_pam_user / _create_dag_link / _configure_rotation doesn't orphan + # group memberships on the operator's pre-existing AD account. By placing this + # AFTER rotation, only graceful (non-rollback) failures remain ahead of it. + if has_ad_config and ad_ops_allowed and ad_groups: + succeeded, failed = self._add_ad_user_to_groups_via_gateway( + config, params, gateway_uid + ) + if output_format == 'text': + if succeeded: + logging.info(f'✅ Added to AD groups: {", ".join(succeeded)}') + if failed: + logging.warning(f'⚠️ Failed to add to AD groups: {", ".join(failed)}') # Delivery: direct share or one-time share URL + email share_url = None @@ -483,12 +517,21 @@ def execute(self, params: KeeperParams, **kwargs): logging.info('✅ Record created (no delivery configured)') if output_format == 'json': + # rotation_status reflects the chosen execution path per API_CONTRACTS.md. + if 'rotation' not in config: + rotation_status = 'not_configured' + elif rotation_success: + rotation_status = 'synced' + elif config['rotation'].get('on_demand'): + rotation_status = 'on_demand' + else: + rotation_status = 'scheduled' result = { 'success': True, 'pam_user_uid': pam_user_uid, 'username': config['account']['username'], 'employee_name': f"{config['user']['first_name']} {config['user']['last_name']}", - 'rotation_status': 'synced' if rotation_success else 'scheduled', + 'rotation_status': rotation_status, 'message': 'Credential provisioning complete' } if has_delivery: @@ -672,8 +715,19 @@ def _validate_config(self, params: KeeperParams, config: Dict[str, Any]) -> List ) return errors - # Validate required top-level sections - required_sections = ['user', 'account', 'rotation'] + # Reject a bare 'rotation:' key with no value. YAML parses `rotation:` to None, + # which would crash downstream `.get(...)` calls with an unhelpful AttributeError. + # The headline "rotation is optional" invites this typo — fail loudly with guidance. + if 'rotation' in config and config['rotation'] is None: + errors.append( + 'rotation: was specified but has no value. Either remove the key entirely\n' + ' (Commander will not manage rotation for this record) or provide\n' + ' schedule (cron) or on_demand: true, plus password_complexity.' + ) + return errors + + # rotation: block is optional (cp-rotation-skip-feat). + required_sections = ['user', 'account'] for section in required_sections: if section not in config: errors.append(f'Missing required section: {section}') @@ -685,7 +739,9 @@ def _validate_config(self, params: KeeperParams, config: Dict[str, Any]) -> List # Validate each section errors.extend(self._validate_user_section(config.get('user', {}))) errors.extend(self._validate_account_section(config.get('account', {}))) - errors.extend(self._validate_rotation_section(config.get('rotation', {}))) + # rotation: optional. Empty 'rotation: {}' still validated (and rejected). + if 'rotation' in config: + errors.extend(self._validate_rotation_section(config.get('rotation', {}))) # Validate delivery section if present (vault sharing) if 'delivery' in config: @@ -758,8 +814,27 @@ def _validate_config(self, params: KeeperParams, config: Dict[str, Any]) -> List if 'managed_company' in config: errors.extend(self._validate_mc_context(params, config['managed_company'])) + # INVARIANT-001: existing_password rejected with provisioning-time + # rotation. See TECHNICAL_DECISIONS.md ADR-001 / SECURITY_ARCHITECTURE.md. + if config.get('account', {}).get('existing_password'): + if 'rotation' in config: + rop = config['rotation'].get('rotate_on_provision', True) + if rop is not False: + errors.append( + 'account.existing_password cannot be combined with a rotation:\n' + ' block unless rotation.rotate_on_provision is explicitly set\n' + ' to false. Either:\n' + ' - Omit the rotation: block (Commander stores the password\n' + ' in the vault but does not manage rotation), or\n' + ' - Set rotation.rotate_on_provision: false (rotation is\n' + ' configured but no immediate rotation fires at provisioning).' + ) + # Validate: transfer_ownership/remove_from_service_vault are incompatible with rotation - has_rotation = bool(config.get('rotation', {}).get('schedule')) + # (any rotation mode — the gateway must own the record to fire rotations whether + # cron-driven or on-demand). + rot = config.get('rotation', {}) + has_rotation = bool(rot.get('schedule')) or rot.get('on_demand') is True transfer = config.get('delivery', {}).get('transfer_ownership', False) remove = config.get('delivery', {}).get('remove_from_service_vault', False) if has_rotation and (transfer or remove): @@ -818,8 +893,8 @@ def _validate_account_section(self, account: Dict[str, Any]) -> List[str]: if not account.get('pam_config_uid'): errors.append('account.pam_config_uid is required') - # CRITICAL: Reject old 'initial_password' field (security issue) - # Per blocker resolution, passwords are generated by Commander, not provided in YAML + # CRITICAL: reject 'initial_password' (BLOCKER_KC-1007-2). Distinct from + # 'existing_password' (stored, never pushed) — see ADR-001. if 'initial_password' in account: errors.append( 'account.initial_password is NOT supported (security).\n' @@ -827,15 +902,49 @@ def _validate_account_section(self, account: Dict[str, Any]) -> List[str]: ' Remove this field from your YAML configuration.' ) + # existing_password: account's current real password. Stored, never pushed. + # Coexistence with rotation gated by INVARIANT-001 in _validate_config. + # Whitespace-only strings are rejected by .strip() — bool() of non-empty + # strings is True so ' ' would otherwise pass. + if 'existing_password' in account: + ep = account.get('existing_password') + if not isinstance(ep, str) or not ep.strip(): + errors.append( + 'account.existing_password must be a non-empty string if specified.' + ) + return errors def _validate_rotation_section(self, rotation: Dict[str, Any]) -> List[str]: - """Validate rotation section (schedule and password complexity).""" + """Validate rotation section (schedule/on_demand, complexity, flags).""" errors = [] - if not rotation.get('schedule'): - errors.append('rotation.schedule is required (CRON format)') + # Exactly one of schedule (cron) or on_demand: true — mirrors + # PAMCreateRecordRotationCommand's mutually-exclusive group. + has_schedule = bool(rotation.get('schedule')) + on_demand = rotation.get('on_demand') + + if 'on_demand' in rotation and not isinstance(on_demand, bool): + errors.append( + 'rotation.on_demand must be a boolean (true/false) if specified.' + ) + # Treat invalid type as "not set" for mode logic + on_demand_truthy = False else: + on_demand_truthy = on_demand is True + + if has_schedule and on_demand_truthy: + errors.append( + 'rotation must specify exactly one of: schedule (cron expression)\n' + ' or on_demand: true. Both are mutually exclusive.' + ) + elif not has_schedule and not on_demand_truthy: + errors.append( + 'rotation must specify exactly one of: schedule (cron expression)\n' + ' or on_demand: true.' + ) + + if has_schedule: schedule = rotation['schedule'] if validate_cron_expression and not validate_cron_expression(schedule, for_rotation=True)[0]: errors.append( @@ -855,6 +964,14 @@ def _validate_rotation_section(self, rotation: Dict[str, Any]) -> List[str]: f' Example: "32,5,5,5,5"' ) + # rotate_on_provision: optional bool, default true. False skips _rotate_immediately. + if 'rotate_on_provision' in rotation: + rop = rotation['rotate_on_provision'] + if not isinstance(rop, bool): + errors.append( + 'rotation.rotate_on_provision must be a boolean (true/false) if specified.' + ) + return errors def _validate_email_section(self, params: KeeperParams, email: Dict[str, Any]) -> List[str]: @@ -984,7 +1101,8 @@ def _dry_run_report(self, params: KeeperParams, config: Dict[str, Any], output_f """ Generate dry-run report showing what would be created. - PII (names, emails) are partially masked for security. + PII (names, emails) are partially masked for security. The value of + account.existing_password is NEVER included in any output. Args: params: KeeperParams session @@ -1005,27 +1123,72 @@ def _dry_run_report(self, params: KeeperParams, config: Dict[str, Any], output_f redacted_email = self._mask_pii(user.get('personal_email', '')) redacted_username = self._mask_pii(username) + has_rotation = 'rotation' in config + uses_existing_password = bool(account.get('existing_password')) + if has_rotation: + rotation_mode = 'on_demand' if rotation.get('on_demand') else 'scheduled' + else: + rotation_mode = 'none' + will_rotate_immediately = ( + has_rotation + and rotation.get('rotate_on_provision', True) is not False + ) + + # Build actions list to mirror execute()'s actual call sequence. + ad_groups = account.get('ad_groups', []) + has_ad_config = bool(account.get('distinguished_name') or ad_groups) + delivery_cfg = config.get('delivery', {}) + has_delivery = bool(delivery_cfg.get('share_to')) + email_cfg_name = email_config.get('config_name', '') if email_config else '' + has_email = bool(email_cfg_name and email_cfg_name.lower() not in ('none', 'null', '')) + + actions = ['Check for duplicate PAM User'] + if uses_existing_password: + actions.append('Use operator-supplied existing_password (value never echoed)') + else: + actions.append('Generate secure password (complexity requirements applied)') + # AD-create only when AD context AND not declaring an existing account + if has_ad_config and not uses_existing_password: + actions.append('Create AD user via Gateway') + actions.append(f'Create PAM User: {redacted_username}') + actions.append(f'Link PAM User to PAM Config: {account.get("pam_config_uid")}') + if has_rotation: + if rotation_mode == 'on_demand': + actions.append('Configure rotation: on-demand (manual triggering only)') + else: + actions.append(f'Configure rotation: {rotation.get("schedule")}') + if will_rotate_immediately: + actions.append('Submit immediate rotation') + # AD group-add deferred to after PAM user creation + rotation — mirrors the + # reorder in execute() (c663a127). Runs whenever ad_groups is set, independent + # of AD-create. + if ad_groups: + actions.append(f'Add to AD groups: {len(ad_groups)} group(s)') + # Direct vault share when delivery.share_to is set + if has_delivery: + actions.append(f'Share directly to: {self._mask_pii(delivery_cfg["share_to"])}') + # Share-URL generation + email run only when email is fully configured + if has_email: + actions.append( + f'Generate share URL for PAM User (expiry: {email_config.get("share_url_expiry", "7d")})' + ) + actions.append(f'Send email to: {redacted_email} (config: {email_cfg_name})') + else: + actions.append(f'Skip welcome email (email.config_name: {email_cfg_name or "(unset)"})') + if output_format == 'json': result = { 'success': True, 'dry_run': True, 'employee_name': redacted_name, - 'actions': [ - 'Check for duplicate PAM User', - 'Generate secure password (complexity requirements applied)', - f'Create PAM User: {redacted_username}', - f'Link PAM User to PAM Config: {account.get("pam_config_uid")}', - f'Configure rotation: {rotation.get("schedule")}', - 'Submit immediate rotation', - f'Generate share URL for PAM User (expiry: {email_config.get("share_url_expiry", "7d")})', - f'Send email to: {redacted_email}' - ], + 'actions': actions, 'configuration': { 'employee': redacted_name, 'username': redacted_username, 'folder': vault_config.get('folder', 'Shared Folders/PAM/{}'.format(user.get('department', 'Unknown'))), - 'rotation_schedule': pam.get('rotation', {}).get('schedule'), - 'email_recipient': redacted_email + 'rotation_schedule': rotation.get('schedule') if rotation_mode == 'scheduled' else None, + 'rotation_mode': rotation_mode, + 'email_recipient': redacted_email, } } print(json.dumps(result, indent=2)) @@ -1036,22 +1199,20 @@ def _dry_run_report(self, params: KeeperParams, config: Dict[str, Any], output_f print(f'\nEmployee: {redacted_name}') print(f'Username: {redacted_username}') print(f'Email: {redacted_email}') + if uses_existing_password: + print('Password source: operator-supplied existing_password (value not shown)') + if has_rotation: + if rotation_mode == 'on_demand': + print('Rotation mode: on-demand (manual trigger only)') + else: + print(f'Rotation mode: scheduled ({rotation.get("schedule")})') + if not will_rotate_immediately: + print('Immediate rotation at provisioning: SKIPPED (rotate_on_provision: false)') + else: + print('Rotation: not configured (no rotation: block in YAML)') print('\nPlanned Actions:') - print(' 1. Check for duplicate PAM User in folder') - print(f' 2. Generate secure password') - print(f' Complexity: requirements applied') - print(f' 3. Create PAM User record') - default_folder = '/Employees/{}'.format(user.get('department', 'Unknown')) - print(f' Folder: {vault_config.get("folder", default_folder)}') - print(f' 4. Link to PAM Config: {account.get("pam_config_uid")[:20]}...') - print(f' 5. Configure rotation') - print(f' Schedule: {rotation.get("schedule")}') - print(f' 6. Submit immediate rotation') - print(f' 7. Generate one-time share URL for PAM User') - print(f' Expiry: {email_config.get("share_url_expiry", "7d")}') - print(f' 8. Send welcome email') - print(f' To: {redacted_email}') - print(f' Config: {email_config.get("config_name")}') + for idx, action in enumerate(actions, 1): + print(f' {idx}. {action}') print('\n' + '='*60) print('✓ Validation passed - ready for actual provisioning') print(' Run without --dry-run to execute') @@ -1230,6 +1391,45 @@ def _get_folder_uid(self, folder_path: str, params: KeeperParams) -> Optional[st return None + def _determine_password(self, config: Dict[str, Any]) -> Tuple[str, bool]: + """Resolve (password, used_existing) per cp-rotation-skip-feat: + existing_password → rotation.password_complexity → DEFAULT_COMPLEXITY. + """ + existing = config.get('account', {}).get('existing_password') + if existing: + return existing, True + if 'rotation' in config: + complexity = config['rotation']['password_complexity'] + else: + complexity = DEFAULT_COMPLEXITY + return self._generate_password(complexity), False + + def _should_rotate_on_provision(self, config: Dict[str, Any]) -> bool: + """True when rotation: present and rotate_on_provision is not False.""" + if 'rotation' not in config: + return False + return config['rotation'].get('rotate_on_provision', True) is not False + + def _should_create_ad_user(self, config: Dict[str, Any]) -> bool: + """True when AD context is set AND existing_password is NOT — gates the + rm-create-user push-to-target route. See SECURITY_ARCHITECTURE.md.""" + account = config.get('account', {}) + has_ad_config = bool(account.get('distinguished_name') or account.get('ad_groups')) + if not has_ad_config: + return False + if account.get('existing_password'): + return False + return True + + def _log_existing_password_use(self, pam_user_uid: str) -> None: + """Audit log: existing_password was used. Takes only the UID by design; + signature is enforced by a structural test to prevent log leakage.""" + logging.info( + 'Provisioning with operator-supplied existing_password for ' + 'record %s; rotation immediate skipped, AD-create skipped.', + pam_user_uid, + ) + def _generate_password(self, password_complexity: str) -> str: """ Generate secure random password using Commander's built-in generator. @@ -1593,7 +1793,7 @@ def _configure_rotation( pam_user_uid: str, config: Dict[str, Any], params: KeeperParams - ) -> None: + ) -> bool: """ Configure automatic password rotation using PAM rotation command. @@ -1618,18 +1818,23 @@ def _configure_rotation( raise CommandError('credential-provision', 'Rotation requires Python 3.8+ (pydantic dependency)') try: - schedule = rotation_config['schedule'] complexity = rotation_config['password_complexity'] rotation_cmd = PAMCreateRecordRotationCommand() directory_uid = config['account'].get('directory_uid') kwargs = { 'record_name': pam_user_uid, - 'schedule_cron_data': [schedule], 'pwd_complexity': complexity, 'enable': True, 'force': True, } + + # Validator guarantees exactly one of schedule/on_demand is set. + if rotation_config.get('on_demand'): + kwargs['on_demand'] = True + else: + kwargs['schedule_cron_data'] = [rotation_config['schedule']] + # Use directory_uid as resource if available, otherwise use pam_config_uid # Cannot pass both --resource and --iam-aad-config_uid simultaneously if directory_uid: @@ -1641,11 +1846,16 @@ def _configure_rotation( # Suppress verbose output from rotation command with redirect_stdout(StringIO()), redirect_stderr(StringIO()): rotation_cmd.execute(params, **kwargs) + return True # Real success — gateway accepted the rotation configuration except Exception as rotation_error: error_msg = str(rotation_error) if '500' in error_msg or 'gateway' in error_msg.lower(): + # Transient gateway failure — swallow and report False so the + # caller can suppress the misleading success-log. Operator must + # configure rotation manually once the gateway is reachable. logging.warning('Gateway unavailable - rotation configuration deferred') logging.warning('Configure rotation manually when gateway is available') + return False else: raise @@ -1695,10 +1905,17 @@ def _rotate_immediately( return True except Exception as e: - # Non-critical failure: PAM User is created and rotation scheduled - # The scheduled rotation will eventually sync the password + # Non-critical failure: PAM User record is created; rotation is configured. + # Remediation hint depends on rotation mode — in on_demand mode there is + # no cron tick to wait for, so direct the operator to manual triggering. logging.warning(f'⚠️ Immediate rotation failed: {e}') - logging.warning(f' Password will sync on next scheduled rotation') + if config.get('rotation', {}).get('on_demand'): + logging.warning( + ' Trigger an on-demand rotation manually to sync the password ' + '(Rotate Now in the Vault UI, or `pam action rotate `).' + ) + else: + logging.warning(' Password will sync on next scheduled rotation.') return False # Graceful degradation # ========================================================================= @@ -1889,7 +2106,18 @@ def _add_ad_user_to_groups_via_gateway( Returns: Tuple of (succeeded_groups, failed_groups) + + Raises: + CommandError: if gateway_uid is None/empty. Pre-PR this fail-fast + lived inside _create_ad_user_via_gateway; that path is now skipped + when existing_password is set, so we re-assert the precondition here. """ + if not gateway_uid: + raise CommandError( + 'credential-provision', + 'No connected Gateway found for PAM Configuration — cannot add AD groups. ' + 'Ensure the Gateway associated with the PAM Configuration is online.' + ) username = config['account']['username'] pam_config_uid = config['account']['pam_config_uid'] resource_uid = config['account'].get('directory_uid') diff --git a/tests/test_credential_provision_dryrun.py b/tests/test_credential_provision_dryrun.py new file mode 100644 index 000000000..6408d5857 --- /dev/null +++ b/tests/test_credential_provision_dryrun.py @@ -0,0 +1,471 @@ +""" +cp-rotation-skip-feat — Story 4 dry-run rewrite tests + +Covers the rewrite of _dry_run_report: + - Reflects the chosen execution path (rotation present/absent, on_demand/ + schedule, rotate_on_provision, existing_password) + - Absorbs commit 97dc8893: line 1027's NameError (pam undefined) fix + - Adds rotation_mode field to JSON output + - Extends rotation_status values + +Critical regression: `cp --output json --dry-run` must NOT raise NameError +for any behavior-matrix cell. +""" + +import io +import json +import pytest +from contextlib import redirect_stdout +from unittest import TestCase +from unittest.mock import MagicMock + +from keepercommander.commands.credential_provision import CredentialProvisionCommand + + +def make_params(): + p = MagicMock() + p.key_cache = {} + return p + + +def base_config(**overrides): + """Minimal valid config; overrides merge top-level keys.""" + config = { + 'user': {'first_name': 'A', 'last_name': 'B', 'personal_email': 'a@b.com', 'department': 'Eng'}, + 'account': {'username': 'svc-test', 'pam_config_uid': 'cfg-uid'}, + 'email': {'config_name': 'none', 'send_to': 'a@b.com'}, + } + config.update(overrides) + return config + + +# ============================================================================= +# NameError absorption — every behavior-matrix cell, JSON mode +# ============================================================================= + + +@pytest.mark.unit +class TestDryRunJsonNoNameError(TestCase): + """Regression for commit 97dc8893: line 1027 referenced an undefined `pam`. + Every behavior-matrix cell must complete in JSON mode without NameError. + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _run_json_dry_run(self, config): + """Capture stdout from a JSON-mode dry-run; return parsed JSON.""" + buf = io.StringIO() + with redirect_stdout(buf): + self.cmd._dry_run_report(self.params, config, 'json') + out = buf.getvalue() + return json.loads(out) + + def test_cell_1_no_rotation_no_existing_password(self): + result = self._run_json_dry_run(base_config()) + self.assertTrue(result['success']) + self.assertTrue(result['dry_run']) + + def test_cell_2_no_rotation_with_existing_password(self): + cfg = base_config() + cfg['account']['existing_password'] = 'KnownPass!' + result = self._run_json_dry_run(cfg) + self.assertTrue(result['success']) + + def test_cell_3_schedule_default_rop(self): + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + }) + result = self._run_json_dry_run(cfg) + self.assertTrue(result['success']) + + def test_cell_4_schedule_rop_false(self): + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }) + result = self._run_json_dry_run(cfg) + self.assertTrue(result['success']) + + def test_cell_5_on_demand_default_rop(self): + cfg = base_config(rotation={ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + }) + result = self._run_json_dry_run(cfg) + self.assertTrue(result['success']) + + def test_cell_6_on_demand_rop_false(self): + cfg = base_config(rotation={ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }) + result = self._run_json_dry_run(cfg) + self.assertTrue(result['success']) + + def test_cell_7_schedule_rop_false_existing_password(self): + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }) + cfg['account']['existing_password'] = 'KnownPass!' + result = self._run_json_dry_run(cfg) + self.assertTrue(result['success']) + + def test_cell_8_on_demand_rop_false_existing_password_tandem(self): + """Tandem's exact use case — must not crash dry-run.""" + cfg = base_config(rotation={ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }) + cfg['account']['existing_password'] = 'KnownPass!' + result = self._run_json_dry_run(cfg) + self.assertTrue(result['success']) + + +# ============================================================================= +# rotation_mode reporting (new JSON field) +# ============================================================================= + + +@pytest.mark.unit +class TestRotationModeJsonField(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _run(self, config): + buf = io.StringIO() + with redirect_stdout(buf): + self.cmd._dry_run_report(self.params, config, 'json') + return json.loads(buf.getvalue()) + + def test_no_rotation_block_mode_is_none(self): + result = self._run(base_config()) + self.assertEqual(result['configuration'].get('rotation_mode'), 'none') + + def test_schedule_mode_reported(self): + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + }) + result = self._run(cfg) + self.assertEqual(result['configuration'].get('rotation_mode'), 'scheduled') + + def test_on_demand_mode_reported(self): + cfg = base_config(rotation={ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + }) + result = self._run(cfg) + self.assertEqual(result['configuration'].get('rotation_mode'), 'on_demand') + + def test_rotation_schedule_field_uses_local_binding(self): + """The actual fix from commit 97dc8893 — reads from `rotation` local, + not the undefined `pam` symbol.""" + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + }) + result = self._run(cfg) + self.assertEqual(result['configuration'].get('rotation_schedule'), '0 0 3 * * ?') + + def test_rotation_schedule_null_when_on_demand(self): + cfg = base_config(rotation={ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + }) + result = self._run(cfg) + # No cron => no schedule string in the JSON + self.assertIsNone(result['configuration'].get('rotation_schedule')) + + +# ============================================================================= +# Actions list reflects chosen path +# ============================================================================= + + +@pytest.mark.unit +class TestActionsListReflectsPath(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _actions(self, config): + buf = io.StringIO() + with redirect_stdout(buf): + self.cmd._dry_run_report(self.params, config, 'json') + return json.loads(buf.getvalue())['actions'] + + def test_no_rotation_no_configure_rotation_action(self): + actions = self._actions(base_config()) + joined = '\n'.join(actions) + self.assertNotIn('Configure rotation', joined) + self.assertNotIn('immediate rotation', joined.lower()) + + def test_rotation_present_includes_configure_action(self): + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + }) + actions = self._actions(cfg) + self.assertTrue(any('Configure rotation' in a for a in actions)) + self.assertTrue(any('immediate rotation' in a.lower() for a in actions)) + + def test_rop_false_omits_immediate_action(self): + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }) + actions = self._actions(cfg) + self.assertTrue(any('Configure rotation' in a for a in actions)) + self.assertFalse(any('immediate rotation' in a.lower() for a in actions), + f"Expected no immediate-rotation action; got: {actions}") + + def test_existing_password_omits_generate_action(self): + cfg = base_config() + cfg['account']['existing_password'] = 'KnownPass!' + actions = self._actions(cfg) + self.assertFalse(any('Generate secure password' in a for a in actions), + f"Expected no password-generation action; got: {actions}") + self.assertTrue(any('existing' in a.lower() and 'password' in a.lower() for a in actions), + f"Expected existing-password mention; got: {actions}") + + +# ============================================================================= +# Dry-run actions list must mirror execute() — regression for review finding #3 +# ============================================================================= + + +@pytest.mark.unit +class TestDryRunActionsMirrorExecute(TestCase): + """The actions list must reflect what execute() will actually do. + Pre-fix divergences: + - 'Add to AD groups' was never emitted even though execute() always runs + group-add when ad_groups is set + - 'Generate share URL' was emitted unconditionally even though execute() + gates it on has_email + - Direct-share was never represented + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _actions(self, config): + buf = io.StringIO() + with redirect_stdout(buf): + self.cmd._dry_run_report(self.params, config, 'json') + return json.loads(buf.getvalue())['actions'] + + def test_ad_groups_emits_group_add_action_independent_of_create(self): + # BYO password + ad_groups: CREATE skipped, but group-add still runs. + cfg = base_config() + cfg['account']['existing_password'] = 'KnownPass!' + cfg['account']['distinguished_name'] = 'CN=svc,OU=...,DC=corp,DC=local' + cfg['account']['ad_groups'] = ['CN=Engineers,OU=Groups,DC=corp,DC=local'] + actions = self._actions(cfg) + joined = '\n'.join(actions) + # CREATE must NOT appear (existing_password gate) + self.assertNotIn('Create AD user', joined, + f"AD-create should be skipped; got: {actions}") + # Group-add MUST appear + self.assertTrue(any('AD groups' in a for a in actions), + f"Expected AD groups action; got: {actions}") + + def test_ad_groups_emits_action_when_create_also_runs(self): + # No existing_password + ad_groups: CREATE runs, group-add ALSO runs. + cfg = base_config() + cfg['account']['distinguished_name'] = 'CN=svc,OU=...,DC=corp,DC=local' + cfg['account']['ad_groups'] = ['CN=A,DC=corp', 'CN=B,DC=corp'] + actions = self._actions(cfg) + joined = '\n'.join(actions) + self.assertIn('Create AD user', joined, + f"Expected AD-create action; got: {actions}") + self.assertTrue(any('AD groups: 2' in a for a in actions), + f"Expected AD groups count=2 action; got: {actions}") + + def test_share_url_only_when_email_configured(self): + # email.config_name: none — share-URL action must NOT appear. + cfg = base_config() # email.config_name = 'none' + actions = self._actions(cfg) + joined = '\n'.join(actions) + self.assertNotIn('Generate share URL', joined, + f"Share-URL action should be skipped when no email; got: {actions}") + # And the skip-email line should appear + self.assertTrue(any('Skip welcome email' in a for a in actions), + f"Expected skip-welcome-email action; got: {actions}") + + def test_share_url_appears_when_email_configured(self): + cfg = base_config() + cfg['email']['config_name'] = 'SMTP-Gmail' + actions = self._actions(cfg) + joined = '\n'.join(actions) + self.assertIn('Generate share URL', joined, + f"Expected share-URL action with real email config; got: {actions}") + self.assertTrue(any('Send email to' in a and 'config: SMTP-Gmail' in a for a in actions), + f"Expected send-email action with config name; got: {actions}") + + def test_direct_share_action_when_delivery_configured(self): + cfg = base_config() + cfg['delivery'] = {'share_to': 'alice@example.com'} + actions = self._actions(cfg) + joined = '\n'.join(actions) + self.assertTrue(any('Share directly to' in a for a in actions), + f"Expected direct-share action; got: {actions}") + # The recipient should be PII-redacted in dry-run output + self.assertNotIn('alice@example.com', joined, + f"Recipient email should be redacted in dry-run; got: {actions}") + + def test_group_add_action_appears_after_create_pam_user_action(self): + """Regression for review #B (round 3): dry-run actions list must mirror + execute()'s call order — `Add to AD groups` must appear AFTER `Create PAM User`, + same invariant as the structural test on execute() itself. + """ + cfg = base_config() + cfg['account']['distinguished_name'] = 'CN=svc,OU=...,DC=corp,DC=local' + cfg['account']['ad_groups'] = ['CN=Engineers,DC=corp'] + actions = self._actions(cfg) + group_idx = next((i for i, a in enumerate(actions) if 'AD groups' in a), None) + create_idx = next((i for i, a in enumerate(actions) if 'Create PAM User' in a), None) + self.assertIsNotNone(group_idx, f'Expected AD-groups action; got: {actions}') + self.assertIsNotNone(create_idx, f'Expected Create PAM User action; got: {actions}') + self.assertGreater( + group_idx, create_idx, + f'Add-to-AD-groups must come AFTER Create-PAM-User in dry-run order ' + f'(mirrors execute() reorder in c663a127). Got order: {actions}', + ) + + def test_no_direct_share_action_when_no_delivery(self): + cfg = base_config() # no delivery section + actions = self._actions(cfg) + joined = '\n'.join(actions) + self.assertNotIn('Share directly to', joined, + f"Direct-share action should not appear without delivery section; got: {actions}") + + +# ============================================================================= +# Existing-password value NEVER in dry-run output +# ============================================================================= + + +@pytest.mark.unit +class TestDryRunNeverExposesExistingPassword(TestCase): + """The actual password value must not appear in any dry-run output.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + SECRET = 'TandemS3cretP@ssw0rd-DoNotEcho' + + def _run(self, output_format): + cfg = base_config(rotation={ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }) + cfg['account']['existing_password'] = self.SECRET + buf = io.StringIO() + with redirect_stdout(buf): + self.cmd._dry_run_report(self.params, cfg, output_format) + return buf.getvalue() + + def test_json_output_never_contains_secret(self): + out = self._run('json') + self.assertNotIn(self.SECRET, out, + "existing_password value MUST NOT appear in JSON dry-run output") + + def test_text_output_never_contains_secret(self): + out = self._run('text') + self.assertNotIn(self.SECRET, out, + "existing_password value MUST NOT appear in text dry-run output") + + +# ============================================================================= +# Success-path JSON rotation_status (regression for E2E bug found 2026-05-12) +# ============================================================================= + + +@pytest.mark.unit +class TestSuccessPathRotationStatus(TestCase): + """The success-path JSON (in execute(), not _dry_run_report) must produce + the same 4-value rotation_status as the dry-run JSON. E2E run on + 2026-05-12 caught a bug where cells 1, 2 (no rotation block) and cell 8 + (on_demand + rop:false) all incorrectly reported 'scheduled'. + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + + def _compute(self, config, rotation_success: bool) -> str: + # Mirrors the logic at execute() line 508; this test isolates it. + if 'rotation' not in config: + return 'not_configured' + if rotation_success: + return 'synced' + if config['rotation'].get('on_demand'): + return 'on_demand' + return 'scheduled' + + def test_no_rotation_block_reports_not_configured(self): + config = {'account': {}} + self.assertEqual(self._compute(config, rotation_success=False), 'not_configured') + + def test_schedule_with_immediate_fired_reports_synced(self): + config = {'rotation': {'schedule': '0 0 3 * * ?'}} + self.assertEqual(self._compute(config, rotation_success=True), 'synced') + + def test_schedule_without_immediate_reports_scheduled(self): + config = {'rotation': {'schedule': '0 0 3 * * ?'}} + self.assertEqual(self._compute(config, rotation_success=False), 'scheduled') + + def test_on_demand_with_immediate_fired_reports_synced(self): + config = {'rotation': {'on_demand': True}} + self.assertEqual(self._compute(config, rotation_success=True), 'synced') + + def test_on_demand_without_immediate_reports_on_demand(self): + config = {'rotation': {'on_demand': True}} + self.assertEqual(self._compute(config, rotation_success=False), 'on_demand') + + +# ============================================================================= +# Text mode still works (regression for existing demo YAMLs) +# ============================================================================= + + +@pytest.mark.unit +class TestDryRunTextMode(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def test_text_mode_emits_header(self): + cfg = base_config(rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + }) + buf = io.StringIO() + with redirect_stdout(buf): + self.cmd._dry_run_report(self.params, cfg, 'text') + out = buf.getvalue() + self.assertIn('DRY RUN MODE', out) + self.assertIn('NO CHANGES WILL BE MADE', out) + + def test_text_mode_mentions_existing_password_use(self): + cfg = base_config() + cfg['account']['existing_password'] = 'KnownPass!' + buf = io.StringIO() + with redirect_stdout(buf): + self.cmd._dry_run_report(self.params, cfg, 'text') + out = buf.getvalue() + # mentions the fact, not the value + self.assertTrue('existing' in out.lower() and 'password' in out.lower()) + self.assertNotIn('KnownPass!', out) diff --git a/tests/test_credential_provision_execute.py b/tests/test_credential_provision_execute.py new file mode 100644 index 000000000..c1231ad71 --- /dev/null +++ b/tests/test_credential_provision_execute.py @@ -0,0 +1,476 @@ +""" +cp-rotation-skip-feat — Story 2/3 execution-flow tests + +Story 2 — Password source decision & _configure_rotation branching: + - Password source: existing_password → rotation.password_complexity → DEFAULT_COMPLEXITY + - _configure_rotation: on_demand mode vs schedule mode + - _configure_rotation is skipped when 'rotation' not in config + +Story 3 — Push-to-target guards: + - _rotate_immediately gated on rotate_on_provision (default True) + - _create_ad_user_via_gateway gated on not existing_password + - _add_ad_user_to_groups_via_gateway still runs when ad_groups set + - logging.info line when existing_password is used (no value in log) +""" + +import logging +import pytest +from unittest import TestCase +from unittest.mock import MagicMock, patch, call + +from keepercommander.commands.credential_provision import ( + CredentialProvisionCommand, + DEFAULT_COMPLEXITY, +) + + +# ============================================================================= +# Story 2 — _determine_password helper +# ============================================================================= + + +@pytest.mark.unit +class TestDeterminePassword(TestCase): + """ + The password-source decision is extracted into _determine_password for + testability. It returns (password, used_existing). + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + + def test_existing_password_takes_precedence(self): + config = { + 'account': {'existing_password': 'KnownPass123!'}, + 'rotation': {'password_complexity': '16,2,2,2,2'}, # should be ignored + } + with patch.object(self.cmd, '_generate_password') as gen: + password, used_existing = self.cmd._determine_password(config) + self.assertEqual(password, 'KnownPass123!') + self.assertTrue(used_existing) + gen.assert_not_called() + + def test_rotation_complexity_used_when_no_existing_password(self): + config = { + 'account': {}, + 'rotation': {'password_complexity': '24,3,3,3,3'}, + } + with patch.object(self.cmd, '_generate_password', return_value='GENERATED') as gen: + password, used_existing = self.cmd._determine_password(config) + self.assertEqual(password, 'GENERATED') + self.assertFalse(used_existing) + gen.assert_called_once_with('24,3,3,3,3') + + def test_default_complexity_used_when_no_rotation_and_no_existing(self): + config = {'account': {}} + with patch.object(self.cmd, '_generate_password', return_value='GENERATED') as gen: + password, used_existing = self.cmd._determine_password(config) + self.assertEqual(password, 'GENERATED') + self.assertFalse(used_existing) + gen.assert_called_once_with(DEFAULT_COMPLEXITY) + + def test_empty_existing_password_falls_through_to_generation(self): + # An empty string is rejected by the validator (Story 1), but if it + # somehow reaches _determine_password, treat as not-supplied. + config = { + 'account': {'existing_password': ''}, + 'rotation': {'password_complexity': '24,3,3,3,3'}, + } + with patch.object(self.cmd, '_generate_password', return_value='GENERATED') as gen: + password, used_existing = self.cmd._determine_password(config) + self.assertFalse(used_existing) + gen.assert_called_once_with('24,3,3,3,3') + + +# ============================================================================= +# Story 2 — _configure_rotation branching on schedule vs on_demand +# ============================================================================= + + +@pytest.mark.unit +class TestConfigureRotationBranching(TestCase): + """ + _configure_rotation must pass either schedule_cron_data (cron mode) or + on_demand=True (manual-trigger mode) to PAMCreateRecordRotationCommand, + but never both. + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = MagicMock() + + def _call_configure(self, rotation_config): + config = { + 'account': {'pam_config_uid': 'cfg-uid'}, + 'rotation': rotation_config, + } + # PAMCreateRecordRotationCommand is imported at module level in + # credential_provision.py — patch where it's used. + with patch( + 'keepercommander.commands.credential_provision.PAMCreateRecordRotationCommand' + ) as cmd_class: + instance = cmd_class.return_value + self.cmd._configure_rotation('pam-uid', config, self.params) + return cmd_class, instance + + def test_schedule_mode_passes_schedule_cron_data(self): + rotation = { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + } + _, instance = self._call_configure(rotation) + instance.execute.assert_called_once() + kwargs = instance.execute.call_args.kwargs + self.assertEqual(kwargs.get('schedule_cron_data'), ['0 0 3 * * ?']) + self.assertNotIn('on_demand', kwargs) + + def test_on_demand_mode_passes_on_demand_true(self): + rotation = { + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + } + _, instance = self._call_configure(rotation) + instance.execute.assert_called_once() + kwargs = instance.execute.call_args.kwargs + self.assertEqual(kwargs.get('on_demand'), True) + self.assertNotIn('schedule_cron_data', kwargs) + + def test_complexity_always_forwarded(self): + for rotation in ( + {'schedule': '0 0 3 * * ?', 'password_complexity': '32,5,5,5,5'}, + {'on_demand': True, 'password_complexity': '32,5,5,5,5'}, + ): + _, instance = self._call_configure(rotation) + kwargs = instance.execute.call_args.kwargs + self.assertEqual(kwargs.get('pwd_complexity'), '32,5,5,5,5') + + def test_enable_force_flags_unchanged(self): + rotation = {'on_demand': True, 'password_complexity': '32,5,5,5,5'} + _, instance = self._call_configure(rotation) + kwargs = instance.execute.call_args.kwargs + self.assertEqual(kwargs.get('enable'), True) + self.assertEqual(kwargs.get('force'), True) + + +# ============================================================================= +# Story 3 — _rotate_immediately gated on rotate_on_provision +# ============================================================================= + + +@pytest.mark.unit +class TestRotateOnProvisionGate(TestCase): + """ + _rotate_immediately at provisioning time fires only when rotate_on_provision + is True (the default). When false, the immediate rotation must be skipped. + + Since _rotate_immediately is called from execute(), this test validates the + GATE LOGIC by exercising a small helper that mirrors the gate's predicate. + The full integration is verified in TestExecuteFlow below. + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + + def test_should_rotate_default_true(self): + config = {'rotation': {'schedule': '0 0 3 * * ?'}} + self.assertTrue(self.cmd._should_rotate_on_provision(config)) + + def test_should_rotate_explicit_true(self): + config = {'rotation': {'rotate_on_provision': True}} + self.assertTrue(self.cmd._should_rotate_on_provision(config)) + + def test_should_rotate_explicit_false(self): + config = {'rotation': {'rotate_on_provision': False}} + self.assertFalse(self.cmd._should_rotate_on_provision(config)) + + def test_should_rotate_no_rotation_block(self): + # No rotation block at all => no rotation work, including no immediate + config = {} + self.assertFalse(self.cmd._should_rotate_on_provision(config)) + + +# ============================================================================= +# Story 3 — _create_ad_user_via_gateway gated on not existing_password +# ============================================================================= + + +@pytest.mark.unit +class TestADCreateGate(TestCase): + """ + The AD-create path pushes the password to the target via rm-create-user. + When existing_password is set, the operator has declared the account + pre-exists; AD-create must be skipped to prevent push-to-target. + + The gate is enforced in execute(); this test validates the predicate. + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + + def test_ad_create_allowed_without_existing_password(self): + config = {'account': {'distinguished_name': 'CN=...'}} + self.assertTrue(self.cmd._should_create_ad_user(config)) + + def test_ad_create_skipped_with_existing_password(self): + config = { + 'account': { + 'distinguished_name': 'CN=...', + 'existing_password': 'KnownPass!', + } + } + self.assertFalse(self.cmd._should_create_ad_user(config)) + + def test_ad_create_allowed_no_ad_fields_no_existing_pw(self): + # No AD fields => has_ad_config is False; gate returns False but for + # a different reason. The predicate intentionally returns False here + # because there's nothing to create. + config = {'account': {}} + self.assertFalse(self.cmd._should_create_ad_user(config)) + + def test_ad_create_with_ad_groups_only(self): + config = {'account': {'ad_groups': ['CN=Group1']}} + self.assertTrue(self.cmd._should_create_ad_user(config)) + + +# ============================================================================= +# Story 3 — logging.info when existing_password is used (NO VALUE) +# ============================================================================= + + +@pytest.mark.unit +class TestExistingPasswordLogging(TestCase): + """ + When existing_password is used, exactly one logging.info line records the + fact (with record UID, no value). This is the on-the-CLI audit trail. + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + + def test_log_line_contains_no_password_value(self): + record_uid = "rec-uid-123" + with self.assertLogs(level='INFO') as captured: + self.cmd._log_existing_password_use(record_uid) + joined = '\n'.join(captured.output) + self.assertIn(record_uid, joined, + "log line should reference the record UID") + self.assertIn('existing_password', joined, + "log line should reference the field name") + + def test_rotate_immediately_failure_warning_branches_on_mode(self): + """Regression for review #A (round 3): when _rotate_immediately fails, + the remediation hint must reflect the rotation mode. on_demand mode + has no schedule — the operator must trigger manually, not wait for cron. + """ + cmd = CredentialProvisionCommand() + params = MagicMock() + + # Force the inner rotate command to raise + with patch( + 'keepercommander.commands.credential_provision.PAMGatewayActionRotateCommand' + ) as rotate_class: + instance = rotate_class.return_value + instance.execute.side_effect = RuntimeError('simulated gateway failure') + + # on_demand mode → hint must reference manual trigger + with self.assertLogs(level='WARNING') as captured: + result = cmd._rotate_immediately( + 'pam-uid', + {'rotation': {'on_demand': True}}, + params, + ) + self.assertFalse(result) + joined = '\n'.join(captured.output) + self.assertIn('manually', joined.lower(), + f'Expected manual-trigger hint for on_demand mode; got: {captured.output}') + self.assertNotIn('scheduled rotation', joined, + f'Should NOT reference "scheduled rotation" in on_demand mode; got: {captured.output}') + + # schedule mode → hint must reference the next scheduled rotation + with patch( + 'keepercommander.commands.credential_provision.PAMGatewayActionRotateCommand' + ) as rotate_class: + instance = rotate_class.return_value + instance.execute.side_effect = RuntimeError('simulated gateway failure') + + with self.assertLogs(level='WARNING') as captured: + cmd._rotate_immediately( + 'pam-uid', + {'rotation': {'schedule': '0 0 3 * * ?'}}, + params, + ) + joined = '\n'.join(captured.output) + self.assertIn('scheduled rotation', joined, + f'Expected scheduled-rotation hint in schedule mode; got: {captured.output}') + + def test_log_function_does_not_accept_password_value(self): + """The function signature accepts ONLY the record UID — by design, + the password value is never passed to it, so there is no way for a + future maintainer to accidentally log the secret.""" + import inspect + sig = inspect.signature(self.cmd._log_existing_password_use) + params = list(sig.parameters.keys()) + self.assertEqual(params, ['pam_user_uid'], + "Signature must accept only the record UID") + + +# ============================================================================= +# Story 3 — full-source structural regression guard +# ============================================================================= + + +@pytest.mark.unit +class TestSourceCodeRegressionGuard(TestCase): + """ + Structural test: the source file must not contain any logging statement + that interpolates the value of existing_password. This catches future + maintainers who accidentally add a log call with the secret value. + + Failure of this test does NOT mean we're leaking — it means the source + pattern is suspicious and a human must verify. + """ + + SOURCE_FILE = ( + 'keepercommander/commands/credential_provision.py' + ) + + def test_add_ad_user_to_groups_raises_when_gateway_uid_is_none(self): + """Regression for review #7 (offline gateway + existing_password + ad_groups): + the helper must fail fast with CommandError, not silently per-group. + Pre-PR this fail-fast lived inside _create_ad_user_via_gateway; that path + is now skipped when existing_password is set, so we re-assert it here. + """ + from keepercommander.error import CommandError + cmd = CredentialProvisionCommand() + params = MagicMock() + config = {'account': {'pam_config_uid': 'cfg-uid', 'username': 'svc'}} + with self.assertRaises(CommandError) as ctx: + cmd._add_ad_user_to_groups_via_gateway(config, params, gateway_uid=None) + self.assertIn('Gateway', str(ctx.exception)) + + def test_group_add_call_site_appears_after_create_pam_user(self): + """Regression for review #8 (rollback gap when group-add precedes pam-user-create): + the group-add call site must appear in the source AFTER _create_pam_user so a + failure during the critical record-creation steps doesn't orphan AD memberships. + """ + import os, re + repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + src_path = os.path.join(repo_root, self.SOURCE_FILE) + with open(src_path, 'r') as f: + source = f.read() + + # Find the line number of the execute()-level _create_pam_user call. + create_pam_match = re.search( + r'pam_user_uid\s*=\s*self\._create_pam_user\(', + source, + ) + self.assertIsNotNone(create_pam_match, "Could not find _create_pam_user assignment") + create_pam_pos = create_pam_match.start() + + # Find the line number of the execute()-level _add_ad_user_to_groups_via_gateway call. + group_add_match = re.search( + r'self\._add_ad_user_to_groups_via_gateway\(', + source, + ) + self.assertIsNotNone(group_add_match, "Could not find group-add call site") + group_add_pos = group_add_match.start() + + self.assertGreater( + group_add_pos, create_pam_pos, + "_add_ad_user_to_groups_via_gateway must be called AFTER _create_pam_user " + "to avoid orphaning AD group memberships if record creation fails. " + "See PR #2043 review #8." + ) + + def test_add_ad_user_to_groups_uses_local_gateway_uid(self): + """Regression guard for review finding #2: the call to + _add_ad_user_to_groups_via_gateway must pass the local `gateway_uid` + (resolved at line ~401), NOT `state.ad_gateway_uid` (which is only set + as a side effect inside _create_ad_user_via_gateway and stays None when + the AD-create gate skips the create call for existing_password YAMLs). + """ + import os, re + repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + src_path = os.path.join(repo_root, self.SOURCE_FILE) + with open(src_path, 'r') as f: + source = f.read() + # Find any call to _add_ad_user_to_groups_via_gateway + pattern = re.compile( + r'self\._add_ad_user_to_groups_via_gateway\s*\(([^)]*)\)', + re.MULTILINE | re.DOTALL, + ) + calls = pattern.findall(source) + self.assertGreater(len(calls), 0, + "Expected to find at least one call site for _add_ad_user_to_groups_via_gateway") + for args in calls: + self.assertNotIn('state.ad_gateway_uid', args, + "_add_ad_user_to_groups_via_gateway must not be called with " + "state.ad_gateway_uid — that's None when AD-create is skipped. " + "Use the local `gateway_uid` instead.") + # And it should explicitly pass gateway_uid + self.assertIn('gateway_uid', args, + f"Call site should pass gateway_uid; got args: {args!r}") + + def test_rotation_configured_log_gated_on_real_success(self): + """Regression for round-5 review finding: the '✅ Rotation configured' + log line must only fire when _configure_rotation actually configured + rotation (returns True), not after a swallowed gateway-500 deferral + (returns False). Sibling pattern to the rotation_success gate. + """ + import os + repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + src_path = os.path.join(repo_root, self.SOURCE_FILE) + with open(src_path, 'r') as f: + source = f.read() + idx = source.find("'✅ Rotation configured'") + self.assertGreater(idx, 0, "Could not locate the rotation-configured log line") + context = source[max(0, idx - 300):idx] + # Must be gated on something more than just output_format — the actual + # return value of _configure_rotation needs to be checked. + self.assertIn('rotation_configured', context, + "The '✅ Rotation configured' log must be gated on _configure_rotation's " + "real return value, not just output_format. Otherwise a swallowed " + "gateway-500 deferral produces 'rotation deferred' warning + checkmark " + "contradiction. See PR #2043 round-5 review.") + + def test_rotation_success_log_gated_on_rotation_success(self): + """Regression for review #E (round 4): the '✅ Password rotation submitted' + log line must only fire when _rotate_immediately actually succeeded + (returns True), not unconditionally after the call. Otherwise a failing + rotation produces contradictory output (warning + success checkmark) + in the same provisioning run. + """ + import os + repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + src_path = os.path.join(repo_root, self.SOURCE_FILE) + with open(src_path, 'r') as f: + source = f.read() + idx = source.find("'✅ Password rotation submitted'") + self.assertGreater(idx, 0, "Could not locate the success-log line") + context = source[max(0, idx - 200):idx] + self.assertIn('rotation_success', context, + "The '✅ Password rotation submitted' log must be gated on rotation_success. " + "Otherwise a False return from _rotate_immediately produces a misleading " + "success line in the same provisioning run. See PR #2043 review #E.") + + def test_no_log_statement_references_existing_password_value(self): + import os + import re + # Resolve path from the worktree root + repo_root = os.path.abspath(os.path.join( + os.path.dirname(__file__), '..')) + src_path = os.path.join(repo_root, self.SOURCE_FILE) + with open(src_path, 'r') as f: + source = f.read() + # Look for patterns like logging.X(f"... {existing_password} ...") + # where existing_password is the VALUE, not the field name. + # We forbid: logging.(...{...existing_password...}...) where + # the f-string is referring to the variable. + pattern = re.compile( + r'logging\.(debug|info|warning|error|critical)\s*\([^)]*\{[^}]*existing_password[^}]*\}', + re.MULTILINE | re.DOTALL, + ) + matches = pattern.findall(source) + self.assertEqual(matches, [], + "Found logging statement that may include existing_password VALUE — " + "review for potential credential leak.") diff --git a/tests/test_credential_provision_validation.py b/tests/test_credential_provision_validation.py new file mode 100644 index 000000000..00ae04338 --- /dev/null +++ b/tests/test_credential_provision_validation.py @@ -0,0 +1,642 @@ +""" +cp-rotation-skip-feat — Story 1 validator tests + +Tests for the schema-validator changes that gate the rest of the feature: + - rotation: block becomes optional + - rotation.on_demand (boolean, mutually exclusive with rotation.schedule) + - rotation.rotate_on_provision (boolean, default true) + - account.existing_password (non-empty string) + - INVARIANT-001 cross-section invariant + - DEFAULT_COMPLEXITY module-level constant + - empty-dict rotation: {} edge case + - regression: account.initial_password still rejected (KC-1007-2 unchanged) + - regression: deprecated pam.rotation: shape still emits migration error +""" + +import pytest +from unittest import TestCase +from unittest.mock import MagicMock + +from keepercommander.commands.credential_provision import ( + CredentialProvisionCommand, + DEFAULT_COMPLEXITY, +) + + +# ============================================================================= +# Shared fixtures / helpers +# ============================================================================= + + +def make_valid_user(): + return { + 'first_name': 'Test', + 'last_name': 'User', + 'personal_email': 'test@example.com', + } + + +def make_valid_account(extra=None): + base = { + 'username': 'svc-test', + 'pam_config_uid': 'abc123', + } + if extra: + base.update(extra) + return base + + +def make_valid_rotation(extra=None): + base = { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + } + if extra: + base.update(extra) + return base + + +def make_params(): + """Minimal KeeperParams mock — these validator tests must not trigger + vault/api lookups, so the configs avoid directory_uid / delivery.share_to.""" + p = MagicMock() + p.key_cache = {} + return p + + +# ============================================================================= +# DEFAULT_COMPLEXITY module constant +# ============================================================================= + + +@pytest.mark.unit +class TestDefaultComplexityConstant(TestCase): + def test_default_complexity_is_defined(self): + self.assertIsNotNone(DEFAULT_COMPLEXITY) + + def test_default_complexity_is_a_valid_complexity_string(self): + # Format: "length,upper,lower,digit,special" + parts = DEFAULT_COMPLEXITY.split(',') + self.assertEqual(len(parts), 5) + for p in parts: + int(p) # would raise if not numeric + + def test_default_complexity_meets_minimum_strength(self): + # Sanity: length >= 16, each class >= 2 (these are the architect's stated defaults). + parts = [int(p) for p in DEFAULT_COMPLEXITY.split(',')] + self.assertGreaterEqual(parts[0], 16, "default length must be at least 16") + for cls_count in parts[1:]: + self.assertGreaterEqual(cls_count, 2, "each character class count must be at least 2") + + +# ============================================================================= +# rotation: block is now optional +# ============================================================================= + + +@pytest.mark.unit +class TestRotationBlockOptional(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def test_validate_succeeds_without_rotation_block(self): + config = { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertEqual(errors, [], f"Expected no errors, got: {errors}") + + def test_validate_succeeds_with_full_rotation_block_unchanged(self): + # Regression: existing YAMLs (with rotation block) must keep validating + config = { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'rotation': make_valid_rotation(), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertEqual(errors, [], f"Expected no errors, got: {errors}") + + +# ============================================================================= +# Empty-dict rotation: {} edge case +# ============================================================================= + + +@pytest.mark.unit +class TestEmptyRotationDict(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def test_null_rotation_value_rejected_with_helpful_error(self): + """Regression for review #C: `rotation:` with no value parses to None + in YAML and would crash downstream .get() calls. Validator must reject + with a clear message before any None-attribute access.""" + config = { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'rotation': None, # <-- bare `rotation:` key, no value + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue( + any('rotation' in e and 'no value' in e for e in errors), + f'Expected helpful "rotation: was specified but has no value" error; got: {errors}', + ) + + def test_empty_rotation_dict_does_not_silently_pass(self): + # rotation: {} — present but empty — must be rejected + # (either as "missing schedule/on_demand" or with a structural error) + config = { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'rotation': {}, + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue(len(errors) > 0, "Empty rotation: {} dict must produce a validation error") + + +# ============================================================================= +# schedule / on_demand mutual exclusivity (ROT-001) +# ============================================================================= + + +@pytest.mark.unit +class TestScheduleOnDemandMutex(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _base_config(self, rotation): + return { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'rotation': rotation, + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + + def test_only_schedule_validates(self): + cfg = self._base_config({ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + }) + errors = self.cmd._validate_config(self.params, cfg) + self.assertEqual(errors, [], f"Expected no errors, got: {errors}") + + def test_only_on_demand_validates(self): + cfg = self._base_config({ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + }) + errors = self.cmd._validate_config(self.params, cfg) + self.assertEqual(errors, [], f"Expected no errors, got: {errors}") + + def test_both_schedule_and_on_demand_rejected(self): + cfg = self._base_config({ + 'schedule': '0 0 3 * * ?', + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + }) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue( + any('schedule' in e and 'on_demand' in e for e in errors), + f"Expected mutual-exclusivity error, got: {errors}", + ) + + def test_neither_schedule_nor_on_demand_rejected(self): + cfg = self._base_config({ + 'password_complexity': '32,5,5,5,5', + }) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue(len(errors) > 0, + "Rotation block missing both schedule and on_demand must be rejected") + + +# ============================================================================= +# on_demand type check (ROT-002) +# ============================================================================= + + +@pytest.mark.unit +class TestOnDemandTypeCheck(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _config_with_on_demand(self, value): + return { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'rotation': { + 'on_demand': value, + 'password_complexity': '32,5,5,5,5', + }, + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + + def test_on_demand_true_validates(self): + errors = self.cmd._validate_config(self.params, self._config_with_on_demand(True)) + self.assertEqual(errors, [], f"Got: {errors}") + + def test_on_demand_string_rejected(self): + errors = self.cmd._validate_config(self.params, self._config_with_on_demand("true")) + self.assertTrue(any('on_demand' in e and 'boolean' in e.lower() for e in errors), + f"Expected on_demand boolean type error, got: {errors}") + + def test_on_demand_int_rejected(self): + errors = self.cmd._validate_config(self.params, self._config_with_on_demand(1)) + self.assertTrue(any('on_demand' in e and 'boolean' in e.lower() for e in errors), + f"Expected on_demand boolean type error, got: {errors}") + + +# ============================================================================= +# rotate_on_provision type check (ROT-003) +# ============================================================================= + + +@pytest.mark.unit +class TestRotateOnProvisionTypeCheck(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _config_with_rop(self, value): + return { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'rotation': { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': value, + }, + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + + def test_rotate_on_provision_true_validates(self): + errors = self.cmd._validate_config(self.params, self._config_with_rop(True)) + self.assertEqual(errors, [], f"Got: {errors}") + + def test_rotate_on_provision_false_validates(self): + errors = self.cmd._validate_config(self.params, self._config_with_rop(False)) + self.assertEqual(errors, [], f"Got: {errors}") + + def test_rotate_on_provision_string_rejected(self): + errors = self.cmd._validate_config(self.params, self._config_with_rop("false")) + self.assertTrue( + any('rotate_on_provision' in e and 'boolean' in e.lower() for e in errors), + f"Got: {errors}", + ) + + +# ============================================================================= +# account.existing_password (ACC-001) +# ============================================================================= + + +@pytest.mark.unit +class TestExistingPasswordValidation(TestCase): + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def test_existing_password_non_empty_string_validates_without_rotation(self): + config = { + 'user': make_valid_user(), + 'account': make_valid_account({'existing_password': 'KnownPass123!'}), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertEqual(errors, [], f"Got: {errors}") + + def test_existing_password_empty_string_rejected(self): + config = { + 'user': make_valid_user(), + 'account': make_valid_account({'existing_password': ''}), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue( + any('existing_password' in e for e in errors), + f"Got: {errors}", + ) + + def test_existing_password_whitespace_only_rejected(self): + """Regression for review #D (round 4): whitespace-only strings like + ' ' or '\\t' are non-empty but meaningless. Validator must reject + per its stated 'non-empty string' contract.""" + for value in [' ', '\t', '\n', ' \t\n ']: + config = { + 'user': make_valid_user(), + 'account': make_valid_account({'existing_password': value}), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue( + any('existing_password' in e for e in errors), + f'Expected rejection for whitespace-only value {value!r}; got: {errors}', + ) + + def test_existing_password_preserves_internal_whitespace(self): + """Passwords legitimately containing internal whitespace must NOT be + rejected — only purely-whitespace strings are rejected.""" + config = { + 'user': make_valid_user(), + 'account': make_valid_account({'existing_password': 'pass with spaces'}), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertEqual(errors, [], + f'Internal whitespace in passwords must be preserved; got: {errors}') + + def test_existing_password_non_string_rejected(self): + config = { + 'user': make_valid_user(), + 'account': make_valid_account({'existing_password': 12345}), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue( + any('existing_password' in e for e in errors), + f"Got: {errors}", + ) + + +# ============================================================================= +# INVARIANT-001 cross-section invariant +# ============================================================================= + + +@pytest.mark.unit +class TestInvariant001(TestCase): + """ + INVARIANT-001: existing_password is rejected if a rotation: block is + present AND rotate_on_provision is not explicitly false. + + This is the load-bearing security check that prevents the customer-supplied + password from being pushed to the target system at provisioning time. + """ + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _config(self, rotation, existing_password=None): + account = make_valid_account() + if existing_password is not None: + account['existing_password'] = existing_password + return { + 'user': make_valid_user(), + 'account': account, + 'rotation': rotation, + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + + def test_existing_password_with_schedule_and_default_rop_rejected(self): + # No rotate_on_provision => defaults to true => rejection + cfg = self._config( + rotation={'schedule': '0 0 3 * * ?', 'password_complexity': '32,5,5,5,5'}, + existing_password='KnownPass123!', + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue( + any('existing_password' in e for e in errors), + f"INVARIANT-001 must reject, got: {errors}", + ) + + def test_existing_password_with_schedule_and_explicit_rop_true_rejected(self): + cfg = self._config( + rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': True, + }, + existing_password='KnownPass123!', + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue( + any('existing_password' in e for e in errors), + f"INVARIANT-001 must reject, got: {errors}", + ) + + def test_existing_password_with_on_demand_and_default_rop_rejected(self): + # on_demand + default rop=true => still fires _rotate_immediately => rejected + cfg = self._config( + rotation={'on_demand': True, 'password_complexity': '32,5,5,5,5'}, + existing_password='KnownPass123!', + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue( + any('existing_password' in e for e in errors), + f"INVARIANT-001 must reject, got: {errors}", + ) + + def test_existing_password_with_rop_false_validates(self): + # Cell 7 of behavior matrix: schedule + rop=false + existing_password => allowed + cfg = self._config( + rotation={ + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }, + existing_password='KnownPass123!', + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertEqual(errors, [], f"Got: {errors}") + + def test_existing_password_with_on_demand_and_rop_false_validates(self): + # Cell 8 of behavior matrix (Tandem's case) + cfg = self._config( + rotation={ + 'on_demand': True, + 'password_complexity': '32,5,5,5,5', + 'rotate_on_provision': False, + }, + existing_password='KnownPass123!', + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertEqual(errors, [], f"Got: {errors}") + + def test_existing_password_with_no_rotation_block_validates(self): + # Cell 2 of behavior matrix + config = { + 'user': make_valid_user(), + 'account': make_valid_account({'existing_password': 'KnownPass123!'}), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertEqual(errors, [], f"Got: {errors}") + + def test_invariant_error_message_mentions_remediation(self): + # Error message must guide the operator to a fix + cfg = self._config( + rotation={'schedule': '0 0 3 * * ?', 'password_complexity': '32,5,5,5,5'}, + existing_password='KnownPass123!', + ) + errors = self.cmd._validate_config(self.params, cfg) + joined = '\n'.join(errors) + self.assertIn('rotate_on_provision', joined, + "Error message should mention rotate_on_provision as a remediation option") + + +# ============================================================================= +# Regression: delivery.transfer_ownership / remove_from_service_vault must be +# rejected when ANY rotation is configured (schedule OR on_demand). +# Catches the bug where the has_rotation predicate only checked schedule. +# ============================================================================= + + +@pytest.mark.unit +class TestRotationDeliveryInvariant(TestCase): + """The existing transfer_ownership/remove_from_service_vault incompatibility + with rotation must apply to BOTH schedule and on_demand modes. Pre-fix, + has_rotation = bool(rotation.schedule) missed the on_demand case.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def _config(self, rotation, delivery): + return { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'rotation': rotation, + 'delivery': delivery, + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + + def test_on_demand_with_transfer_ownership_rejected(self): + cfg = self._config( + rotation={'on_demand': True, 'password_complexity': '32,5,5,5,5'}, + delivery={'share_to': 'someone@example.com', 'transfer_ownership': True}, + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue( + any('transfer_ownership' in e and 'rotation' in e for e in errors), + f"Expected transfer_ownership-rotation incompatibility error; got: {errors}", + ) + + def test_on_demand_with_remove_from_service_vault_rejected(self): + cfg = self._config( + rotation={'on_demand': True, 'password_complexity': '32,5,5,5,5'}, + delivery={'share_to': 'someone@example.com', 'remove_from_service_vault': True}, + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue( + any('remove_from_service_vault' in e and 'rotation' in e for e in errors), + f"Expected remove_from_service_vault-rotation incompatibility error; got: {errors}", + ) + + def test_schedule_with_transfer_ownership_still_rejected(self): + # Regression: pre-existing behavior for schedule mode must still fire. + cfg = self._config( + rotation={'schedule': '0 0 3 * * ?', 'password_complexity': '32,5,5,5,5'}, + delivery={'share_to': 'someone@example.com', 'transfer_ownership': True}, + ) + errors = self.cmd._validate_config(self.params, cfg) + self.assertTrue( + any('transfer_ownership' in e for e in errors), + f"Expected rejection; got: {errors}", + ) + + +# ============================================================================= +# Regression: account.initial_password still rejected (KC-1007-2) +# ============================================================================= + + +@pytest.mark.unit +class TestInitialPasswordStillRejected(TestCase): + """initial_password rejection from KC-1007-2 must remain unchanged.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def test_initial_password_still_rejected(self): + config = { + 'user': make_valid_user(), + 'account': make_valid_account({'initial_password': 'whatever'}), + 'rotation': make_valid_rotation(), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue(any('initial_password' in e for e in errors), + f"Got: {errors}") + + +# ============================================================================= +# Regression: deprecated pam.rotation: shape still emits migration error +# ============================================================================= + + +@pytest.mark.unit +class TestDeprecatedPamRotationMigrationError(TestCase): + """The pre-existing deprecation migration error must be unchanged.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def test_deprecated_pam_rotation_emits_migration_error(self): + config = { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'pam': { + 'rotation': { + 'schedule': '0 0 3 * * ?', + 'password_complexity': '32,5,5,5,5', + } + }, + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue( + any('deprecated' in e.lower() for e in errors), + f"Expected deprecation migration error, got: {errors}", + ) + + +# ============================================================================= +# Regression: existing required-field errors still fire +# ============================================================================= + + +@pytest.mark.unit +class TestRequiredSectionsRegression(TestCase): + """user and account remain required; rotation no longer is.""" + + def setUp(self): + self.cmd = CredentialProvisionCommand() + self.params = make_params() + + def test_missing_user_section_rejected(self): + config = { + 'account': make_valid_account(), + 'rotation': make_valid_rotation(), + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue(any('user' in e for e in errors), f"Got: {errors}") + + def test_missing_account_section_rejected(self): + config = { + 'user': make_valid_user(), + 'rotation': make_valid_rotation(), + } + errors = self.cmd._validate_config(self.params, config) + self.assertTrue(any('account' in e for e in errors), f"Got: {errors}") + + def test_missing_rotation_section_no_longer_rejected(self): + # Critical regression target: pre-feature this would have failed. + config = { + 'user': make_valid_user(), + 'account': make_valid_account(), + 'email': {'config_name': 'none', 'send_to': 'test@example.com'}, + } + errors = self.cmd._validate_config(self.params, config) + self.assertEqual(errors, [], + f"rotation: should be optional now; got: {errors}") diff --git a/unit-tests/test_credential_provision.py b/unit-tests/test_credential_provision.py index fe958e2e4..c59f90fe9 100644 --- a/unit-tests/test_credential_provision.py +++ b/unit-tests/test_credential_provision.py @@ -582,16 +582,20 @@ def test_valid_complete_config(self): self.assertEqual(len(errors), 0, 'Valid config should have no errors') def test_missing_required_sections(self): - """Test detection of missing required sections.""" + """Test detection of missing required sections. + + Note: rotation is OPTIONAL as of cp-rotation-skip-feat (omitting it means + Commander will not manage rotation for the record). Only user and account + are required top-level sections. + """ config = { 'user': {'first_name': 'John'}, - # Missing account, rotation sections + # Missing account section } errors = self.cmd._validate_config(self.mock_params, config) self.assertGreater(len(errors), 0, 'Should detect missing sections') error_text = ' '.join(errors) self.assertIn('account', error_text) - self.assertIn('rotation', error_text) def test_multiple_validation_errors(self): """Test that multiple errors are collected (not fail-fast).""" From 6089346ca7137df8425b3d593f87695795463c3a Mon Sep 17 00:00:00 2001 From: pvagare-ks Date: Thu, 14 May 2026 20:57:53 +0530 Subject: [PATCH 23/26] Update the search command for keeperdrive (#2048) (#2049) --- KEEPER_DRIVE_COMMANDS.md | 2 +- .../commands/keeper_drive/__init__.py | 2 +- keepercommander/commands/record.py | 101 ++++++++++++++++-- 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/KEEPER_DRIVE_COMMANDS.md b/KEEPER_DRIVE_COMMANDS.md index 6d04d717b..9bc252937 100644 --- a/KEEPER_DRIVE_COMMANDS.md +++ b/KEEPER_DRIVE_COMMANDS.md @@ -24,7 +24,7 @@ To get help on a particular command, run: | `[kd-share-record]` | Grant, update, or revoke a user's access to a record | | `[kd-record-permission]` | Bulk-update sharing permissions across records in a folder | | `[kd-transfer-record]` | Transfer record ownership to another user | -| `[kd-record-details]` | Get metadata for one or more records | +| `[kd-record-details]` | Get metadata for records | | `[kd-get]` | Show full details for a record or folder | diff --git a/keepercommander/commands/keeper_drive/__init__.py b/keepercommander/commands/keeper_drive/__init__.py index 7ca527ea2..b7e4d97c4 100644 --- a/keepercommander/commands/keeper_drive/__init__.py +++ b/keepercommander/commands/keeper_drive/__init__.py @@ -87,7 +87,7 @@ def register_command_info(aliases, command_info): command_info['kd-rndir'] = 'Rename a KeeperDrive folder' command_info['kd-list'] = 'List Keeper Drive folders and records' command_info['kd-share-folder'] = 'Grant/update/revoke folder access' - command_info['kd-record-details'] = 'Get record metadata (title, color' + command_info['kd-record-details'] = 'Get record metadata' command_info['kd-share-record'] = 'Grant/update/revoke record sharing' command_info['kd-record-permission'] = 'Modify sharing permissions of records in a folder' command_info['kd-transfer-record'] = 'Transfer record ownership to another user' diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 17e108adb..7603e9776 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -139,7 +139,7 @@ def register_command_info(aliases, command_info): search_parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='verbose output') search_parser.add_argument('-c', '--categories', dest='categories', action='store', help='One or more of these letters for categories to search: "r" = records, ' - '"s" = shared folders, "t" = teams') + '"s" = shared folders, "t" = teams, "d" = KeeperDrive folders') search_parser.add_argument('--regex', dest='regex', action='store_true', help='treat pattern as a regular expression instead of space-separated search terms') search_parser.add_argument('--format', dest='format', action='store', choices=['table', 'json'], @@ -1301,11 +1301,13 @@ def execute(self, params, **kwargs): else: pattern = '' # Empty pattern matches all in token mode - categories = (kwargs.get('categories') or 'rst').lower() + categories = (kwargs.get('categories') or 'rstd').lower() verbose = kwargs.get('verbose') is True skip_details = not verbose fmt = kwargs.get('format', 'table') + kd_records_map = getattr(params, 'keeper_drive_records', {}) or {} + all_results = [] if 'r' in categories: @@ -1314,21 +1316,24 @@ def execute(self, params, **kwargs): if records: if fmt == 'json': for record in records: + is_kd = record.record_uid in kd_records_map result_item = { 'type': 'record', 'record_uid': record.record_uid, 'record_type': record.record_type, 'title': record.title, - 'description': vault_extensions.get_record_description(record) + 'description': vault_extensions.get_record_description(record), + 'record_category': 'KeeperDrive' if is_kd else 'Classic', } all_results.append(result_item) else: logging.info('') table = [] - headers = ['Record UID', 'Type', 'Title', 'Description'] + headers = ['Record UID', 'Type', 'Title', 'Description', 'Record Category'] for record in records: + record_category = 'KeeperDrive' if record.record_uid in kd_records_map else 'Classic' row = [record.record_uid, record.record_type, record.title, - vault_extensions.get_record_description(record)] + vault_extensions.get_record_description(record), record_category] table.append(row) table.sort(key=lambda x: (x[2] or '').lower()) @@ -1375,6 +1380,38 @@ def execute(self, params, **kwargs): logging.info('') display.formatted_teams(results, params=params, skip_details=skip_details) + + if 'd' in categories: + kd_folder_results = self._search_keeper_drive_folders(params, pattern, use_regex=use_regex) + if kd_folder_results: + if fmt == 'json': + for folder_uid, fobj in kd_folder_results: + result_item = { + 'type': 'keeper_drive_folder', + 'folder_uid': folder_uid, + 'name': fobj.get('name', ''), + 'parent_uid': self._display_parent_uid(fobj.get('parent_uid', '')), + } + all_results.append(result_item) + else: + logging.info('') + rows_with_parent = [ + (folder_uid, fobj.get('name', ''), + self._display_parent_uid(fobj.get('parent_uid', ''))) + for folder_uid, fobj in kd_folder_results + ] + any_non_root = any(parent for _, _, parent in rows_with_parent) + if any_non_root: + headers = ['KeeperDrive Folder UID', 'Name', 'Parent UID'] + table = [[fuid, name, parent] + for fuid, name, parent in rows_with_parent] + else: + headers = ['KeeperDrive Folder UID', 'Name'] + table = [[fuid, name] for fuid, name, _ in rows_with_parent] + table.sort(key=lambda x: (x[1] or '').lower()) + base.dump_report_data(table, headers, row_number=True, + column_width=None if verbose else 40) + if fmt == 'json': if all_results: table = [] @@ -1383,19 +1420,65 @@ def execute(self, params, **kwargs): for item in all_results: if item['type'] == 'record': row = [item['type'], item['record_uid'], item['title'], - f"Type: {item['record_type']}, Description: {item['description']}"] + f"Type: {item['record_type']}, Description: {item['description']}, Record Category: {item.get('record_category', 'Classic')}"] elif item['type'] == 'shared_folder': row = [item['type'], item['shared_folder_uid'], item['name'], f"Can Edit: {item['can_edit']}, Can Share: {item['can_share']}"] elif item['type'] == 'team': row = [item['type'], item['team_uid'], item['name'], f"Restrict Edit: {item['restrict_edit']}, Restrict View: {item['restrict_view']}, Restrict Share: {item['restrict_share']}"] + elif item['type'] == 'keeper_drive_folder': + details = (f"Parent UID: {item['parent_uid']}" + if item.get('parent_uid') else '') + row = [item['type'], item['folder_uid'], item['name'], details] table.append(row) return base.dump_report_data(table, headers, fmt='json') else: return base.dump_report_data([], ['type', 'uid', 'name', 'details'], fmt='json') + @staticmethod + def _display_parent_uid(parent_uid): + """Render a KeeperDrive ``parent_uid`` for display. + """ + if not parent_uid or parent_uid == 'root': + return '' + if parent_uid.startswith('AAAAAAAAAA'): + return '' + return parent_uid + + @staticmethod + def _search_keeper_drive_folders(params, search_str, use_regex=False): + """Search KeeperDrive folders by name. + """ + kd_folders = getattr(params, 'keeper_drive_folders', {}) or {} + if not kd_folders: + return [] + + if not search_str: + match_func = lambda target: True + elif use_regex: + try: + pat = re.compile(search_str, re.IGNORECASE) + except re.error: + return [] + match_func = lambda target: bool(pat.search(target)) + else: + tokens = [t.lower() for t in search_str.split() if t.strip()] + if not tokens: + match_func = lambda target: True + else: + match_func = lambda target: all(token in target for token in tokens) + + results = [] + for folder_uid, fobj in kd_folders.items(): + name = (fobj.get('name') or '') + + haystack = f"{name.lower()} {folder_uid.lower()}" + if match_func(haystack): + results.append((folder_uid, fobj)) + return results + class RecordListCommand(Command): def get_parser(self): @@ -1444,11 +1527,11 @@ def execute(self, params, **kwargs): headers = [base.field_to_title(x) for x in headers] table = [] for record in records: - # Determine if record is from Keeper Drive or Legacy + # Determine if record is from Keeper Drive or Classic is_keeper_drive = hasattr(params, 'keeper_drive_records') and record.record_uid in params.keeper_drive_records - source = 'KeeperDrive' if is_keeper_drive else 'Legacy' + record_category = 'KeeperDrive' if is_keeper_drive else 'Classic' row = [record.record_uid, record.record_type, record.title, - vault_extensions.get_record_description(record), record.shared, source] + vault_extensions.get_record_description(record), record.shared, record_category] table.append(row) table.sort(key=lambda x: (x[2] or '').lower()) if fmt != 'json': From 2d8066f4e2de66964619d4443cd3f114fc80aa7b Mon Sep 17 00:00:00 2001 From: amangalampalli-ks Date: Thu, 14 May 2026 20:58:51 +0530 Subject: [PATCH 24/26] Implementation of sso-cloud commands (#2000) (#2046) * Implement SSO Cloud Command * Fix cloud create bug * Add different steps for different IDPs * Add guide command and fix download upload azure idp issues * Refactor sso cloud code * Fix for review comments * Update copyright mark with correct mail ID * Update logo in init file * Fix review comments * Fix review comment * Update minor ui fixes * fix json issue --- keepercommander/commands/base.py | 4 + .../commands/sso_cloud/__init__.py | 60 ++++ .../commands/sso_cloud/config_commands.py | 161 +++++++++ .../commands/sso_cloud/constants.py | 228 ++++++++++++ .../commands/sso_cloud/log_commands.py | 131 +++++++ .../commands/sso_cloud/metadata_commands.py | 133 +++++++ keepercommander/commands/sso_cloud/mixin.py | 331 ++++++++++++++++++ keepercommander/commands/sso_cloud/parsers.py | 119 +++++++ .../commands/sso_cloud/sp_commands.py | 251 +++++++++++++ 9 files changed, 1418 insertions(+) create mode 100644 keepercommander/commands/sso_cloud/__init__.py create mode 100644 keepercommander/commands/sso_cloud/config_commands.py create mode 100644 keepercommander/commands/sso_cloud/constants.py create mode 100644 keepercommander/commands/sso_cloud/log_commands.py create mode 100644 keepercommander/commands/sso_cloud/metadata_commands.py create mode 100644 keepercommander/commands/sso_cloud/mixin.py create mode 100644 keepercommander/commands/sso_cloud/parsers.py create mode 100644 keepercommander/commands/sso_cloud/sp_commands.py diff --git a/keepercommander/commands/base.py b/keepercommander/commands/base.py index bc9e3c0f7..d14bbe14e 100644 --- a/keepercommander/commands/base.py +++ b/keepercommander/commands/base.py @@ -212,6 +212,10 @@ def register_enterprise_commands(commands, aliases, command_info): device_management.register_enterprise_commands(commands) device_management.register_enterprise_command_info(aliases, command_info) + from . import sso_cloud + sso_cloud.register_commands(commands) + sso_cloud.register_command_info(aliases, command_info) + if sys.version_info.major > 3 or (sys.version_info.major == 3 and sys.version_info.minor >= 9): from.pedm import pedm_admin pedm_command = pedm_admin.PedmCommand() diff --git a/keepercommander/commands/sso_cloud/__init__.py b/keepercommander/commands/sso_cloud/__init__.py new file mode 100644 index 000000000..678efc706 --- /dev/null +++ b/keepercommander/commands/sso_cloud/__init__.py @@ -0,0 +1,60 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Any + target = kwargs.get('target') + svc = self.find_sso_service(params, target) + sp_id = svc['sso_service_provider_id'] + self.ensure_cloud_sso(svc, target) + + settings_to_set = kwargs.get('setting') or [] + settings_to_reset = kwargs.get('reset') or [] + + if not settings_to_set and not settings_to_reset: + raise CommandError('sso-cloud', 'Provide at least one --set KEY=VALUE or --reset KEY argument.') + + config_rs = self.get_selected_configuration(params, sp_id, config_target=kwargs.get('config')) + config_id = config_rs.ssoSpConfigurationId + + available_settings = {} + for sv in config_rs.ssoCloudSettingValue: + available_settings[sv.settingName.lower()] = sv + + rq = ssocloud.SsoCloudConfigurationRequest() + rq.ssoServiceProviderId = sp_id + rq.ssoSpConfigurationId = config_id + + for setting_str in settings_to_set: + pos = setting_str.find('=') + if pos < 1: + raise CommandError('sso-cloud', f'Invalid setting format "{setting_str}". Expected KEY=VALUE.') + + key = setting_str[:pos].strip() + value = setting_str[pos + 1:].strip() + + existing = available_settings.get(key.lower()) + if not existing: + raise CommandError('sso-cloud', f'Unknown setting: "{key}". ' + f'Use "sso-cloud get" to see available settings.') + if not existing.isEditable: + raise CommandError('sso-cloud', f'Setting "{key}" is read-only.') + + action = ssocloud.SsoCloudSettingAction() + action.settingName = existing.settingName + action.operation = ssocloud.SET + action.value = value + rq.ssoCloudSettingAction.append(action) + + for key in settings_to_reset: + existing = available_settings.get(key.strip().lower()) + if not existing: + raise CommandError('sso-cloud', f'Unknown setting: "{key}".') + if not existing.isEditable: + raise CommandError('sso-cloud', f'Setting "{key}" is read-only.') + + action = ssocloud.SsoCloudSettingAction() + action.settingName = existing.settingName + action.operation = ssocloud.RESET_TO_DEFAULT + rq.ssoCloudSettingAction.append(action) + + updated_rs = api.communicate_rest( + params, rq, 'sso/config/sso_cloud_configuration_setting_set', + rs_type=ssocloud.SsoCloudConfigurationResponse) + + logging.info('Configuration updated successfully.') + self.dump_configuration(updated_rs) + + +class SsoCloudValidateCommand(EnterpriseCommand, SsoCloudMixin): + def get_parser(self): + return sso_cloud_validate_parser + + def execute(self, params, **kwargs): + # type: (Any, **Any) -> Any + target = kwargs.get('target') + svc = self.find_sso_service(params, target) + sp_id = svc['sso_service_provider_id'] + self.ensure_cloud_sso(svc, target) + + config_rs = self.get_selected_configuration(params, sp_id, config_target=kwargs.get('config')) + config_id = config_rs.ssoSpConfigurationId + + rq = ssocloud.SsoCloudConfigurationValidationRequest() + rq.ssoSpConfigurationId.append(config_id) + + rs = api.communicate_rest( + params, rq, 'sso/config/sso_cloud_configuration_validate', + rs_type=ssocloud.SsoCloudConfigurationValidationResponse) + + all_valid = True + for vc in rs.validationContent: + if vc.isSuccessful: + logging.info('Configuration "%s" (ID: %s) is valid.', + config_rs.name, vc.ssoSpConfigurationId) + else: + all_valid = False + logging.warning('Configuration "%s" (ID: %s) has validation errors:', + config_rs.name, vc.ssoSpConfigurationId) + for msg in vc.errorMessage: + logging.warning(' - %s', msg) + + if all_valid: + logging.info('SSO Cloud configuration is ready for use.') diff --git a/keepercommander/commands/sso_cloud/constants.py b/keepercommander/commands/sso_cloud/constants.py new file mode 100644 index 000000000..0287aa5ef --- /dev/null +++ b/keepercommander/commands/sso_cloud/constants.py @@ -0,0 +1,228 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Create Application > Regular Web App'), + ('idp', 'Enable Addons > SAML2 WEB APP'), + ('idp', 'In the Usage tab > Download IdP Metadata XML'), + ('idp', 'In the Settings tab > paste the ACS Endpoint into "Application Callback URL":'), + ('value', '{acs_endpoint}'), + ('idp', 'Replace the Settings editor JSON with the below (Entity ID pre-filled in audience):'), + ('json', '{auth0_json}'), + ('idp', 'Click "Debug" to verify, then Save'), + ('cmd', 'sso-cloud upload "{name}" --file '), + ], + }, + 'azure': { + 'portal_name': 'Azure Entra ID', + 'portal_url': 'https://portal.azure.com', + 'steps': [ + ('cmd', 'sso-cloud download "{name}" --output sp-metadata.xml'), + ('idp', 'In Azure portal, navigate to Microsoft Entra ID'), + ('idp', 'Go to Enterprise Applications > New Application'), + ('idp', 'Search "Keeper Password Manager" > Create'), + ('idp', 'Go to Set up Single sign-on > SAML'), + ('idp', 'Click "Upload metadata file" and upload sp-metadata.xml'), + ('note', 'Azure auto-fills Entity ID and Reply URL from the metadata'), + ('idp', 'Paste the IdP Initiated Login Endpoint into "Sign on URL":'), + ('value', '{idp_login_endpoint}'), + ('idp', 'Save the Basic SAML Configuration'), + ('idp', 'Click on "No, I\'ll test later" when asked for the test SSO login'), + ('idp', 'In Attributes and Claims card > Edit: delete the 4 extra Additional Claims'), + ('note', 'Verify: NameID/Email = user.userprincipalname (or user.mail)'), + ('idp', 'Reload page, under SAML Signing Certificate > Download "Federation Metadata XML"'), + ('cmd', 'sso-cloud upload "{name}" --file --force-authn'), + ], + }, + 'okta': { + 'portal_name': 'Okta', + 'portal_url': 'https://login.okta.com', + 'steps': [ + ('idp', 'Go to Applications > Create App Integration > SAML 2.0'), + ('idp', 'Paste the ACS Endpoint into "Single sign-on URL":'), + ('value', '{acs_endpoint}'), + ('idp', 'Paste the Entity ID into "Audience URI (SP Entity ID)":'), + ('value', '{entity_id}'), + ('idp', 'Set Name ID format to EmailAddress'), + ('idp', 'Add attribute statements: Email, First, Last'), + ('idp', 'Finish, then go to Sign On tab > Download IdP Metadata'), + ('cmd', 'sso-cloud upload "{name}" --file '), + ], + }, + 'google': { + 'portal_name': 'Google Workspace', + 'portal_url': 'https://admin.google.com', + 'steps': [ + ('idp', 'Go to Apps > Web and mobile apps > Add App > Add custom SAML app'), + ('idp', 'Download IdP Metadata from the Google IdP Information step'), + ('idp', 'Paste the ACS Endpoint into "ACS URL":'), + ('value', '{acs_endpoint}'), + ('idp', 'Paste the Entity ID into "Entity ID":'), + ('value', '{entity_id}'), + ('idp', 'Set Name ID format to EMAIL'), + ('idp', 'Add attribute mappings for email, first name, last name'), + ('cmd', 'sso-cloud upload "{name}" --file '), + ], + }, + 'jumpcloud': { + 'portal_name': 'JumpCloud', + 'portal_url': 'https://console.jumpcloud.com', + 'steps': [ + ('idp', 'Go to SSO Applications > Add New Application > Custom SAML App'), + ('idp', 'Paste the ACS Endpoint into "ACS URL":'), + ('value', '{acs_endpoint}'), + ('idp', 'Paste the Entity ID into "SP Entity ID":'), + ('value', '{entity_id}'), + ('idp', 'Set SAMLSubject NameID to email'), + ('idp', 'Add attribute mappings for email, first name, last name'), + ('idp', 'Activate the application, then download IdP Metadata'), + ('cmd', 'sso-cloud upload "{name}" --file '), + ], + }, +} diff --git a/keepercommander/commands/sso_cloud/log_commands.py b/keepercommander/commands/sso_cloud/log_commands.py new file mode 100644 index 000000000..fe0d861c8 --- /dev/null +++ b/keepercommander/commands/sso_cloud/log_commands.py @@ -0,0 +1,131 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Any + target = kwargs.get('target') + svc = self.find_sso_service(params, target) + sp_id = svc['sso_service_provider_id'] + + rq = ssocloud.SsoCloudSAMLLogRequest() + rq.ssoServiceProviderId = sp_id + + rs = api.communicate_rest( + params, rq, 'sso/config/sso_cloud_log_saml_get', + rs_type=ssocloud.SsoCloudSAMLLogResponse) + + fmt = kwargs.get('format') + verbose = kwargs.get('verbose', False) + + if not rs.entry and fmt != 'json': + logging.info('No SAML log entries found for SP "%s".', svc.get('name', target)) + return + + if fmt == 'json': + entries = [] + for entry in rs.entry: + e = { + 'server_time': entry.serverTime, + 'direction': entry.direction, + 'message_type': entry.messageType, + 'message_issued': entry.messageIssued, + 'from_entity_id': entry.fromEntityId, + 'saml_status': entry.samlStatus, + 'is_signed': entry.isSigned, + 'is_ok': entry.isOK, + } + if verbose: + e['relay_state'] = entry.relayState + e['saml_content'] = entry.samlContent + entries.append(e) + output = json.dumps(entries, indent=2) + output_path = kwargs.get('output') + if output_path: + try: + with open(os.path.expanduser(output_path), 'w') as f: + f.write(output) + logging.info('Log output written to %s', output_path) + except IOError as e: + raise CommandError('sso-cloud', f'Failed to write log output file "{output_path}": {e}') + else: + print(output) + return + + table = [] + headers = ['time', 'direction', 'type', 'status', 'signed', 'ok'] + if verbose: + headers.append('from_entity') + for entry in rs.entry: + row = [ + entry.serverTime, + entry.direction, + entry.messageType, + entry.samlStatus, + 'Yes' if entry.isSigned else 'No', + 'Yes' if entry.isOK else 'No', + ] + if verbose: + row.append(entry.fromEntityId) + table.append(row) + + dump_report_data(table, headers=headers, fmt=fmt, filename=kwargs.get('output')) + + if verbose: + logging.info('') + for i, entry in enumerate(rs.entry): + logging.info('--- Entry %d: %s %s ---', i + 1, entry.direction, entry.messageType) + if entry.relayState: + logging.info('Relay State: %s', entry.relayState) + if entry.samlContent: + logging.info('SAML Content:\n%s', entry.samlContent) + logging.info('') + + +class SsoCloudLogClearCommand(EnterpriseCommand, SsoCloudMixin): + def get_parser(self): + return sso_cloud_log_clear_parser + + def execute(self, params, **kwargs): + # type: (Any, **Any) -> Any + target = kwargs.get('target') + svc = self.find_sso_service(params, target) + sp_id = svc['sso_service_provider_id'] + + rq = ssocloud.SsoCloudSAMLLogRequest() + rq.ssoServiceProviderId = sp_id + + api.communicate_rest( + params, rq, 'sso/config/sso_cloud_log_saml_clear', + rs_type=ssocloud.SsoCloudSAMLLogResponse) + + logging.info('SAML log entries cleared for SP "%s".', svc.get('name', target)) diff --git a/keepercommander/commands/sso_cloud/metadata_commands.py b/keepercommander/commands/sso_cloud/metadata_commands.py new file mode 100644 index 000000000..3f7bec2cf --- /dev/null +++ b/keepercommander/commands/sso_cloud/metadata_commands.py @@ -0,0 +1,133 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Any + target = kwargs.get('target') + svc = self.find_sso_service(params, target) + sp_id = svc['sso_service_provider_id'] + self.ensure_cloud_sso(svc, target) + + filepath = kwargs.get('file', '') + filepath = os.path.expanduser(filepath) + if not os.path.isfile(filepath): + raise CommandError('sso-cloud', f'File not found: "{filepath}"') + + with open(filepath, 'rb') as f: + file_content = f.read() + + filename = os.path.basename(filepath) + + config_rs = self.get_selected_configuration(params, sp_id, config_target=kwargs.get('config')) + config_id = config_rs.ssoSpConfigurationId + + rq = ssocloud.SsoCloudIdpMetadataRequest() + rq.ssoSpConfigurationId = config_id + rq.filename = filename + rq.content = file_content + + rs = api.communicate_rest( + params, rq, 'sso/config/sso_cloud_upload_idp_metadata', + rs_type=ssocloud.SsoCloudConfigurationValidationResponse) + + has_errors = False + for vc in rs.validationContent: + if vc.isSuccessful: + logging.info('IdP metadata uploaded and validated successfully for configuration %s.', + vc.ssoSpConfigurationId) + else: + has_errors = True + logging.warning('Validation errors for configuration %s:', vc.ssoSpConfigurationId) + for msg in vc.errorMessage: + logging.warning(' - %s', msg) + + if not has_errors: + logging.info('File "%s" uploaded to configuration "%s" (ID: %s).', + filename, config_rs.name, config_id) + + if not has_errors and kwargs.get('force_authn'): + setting_rq = ssocloud.SsoCloudConfigurationRequest() + setting_rq.ssoServiceProviderId = sp_id + setting_rq.ssoSpConfigurationId = config_id + action = ssocloud.SsoCloudSettingAction() + action.settingName = 'sso_idp_force_login_mode' + action.operation = ssocloud.SET + action.value = 'true' + setting_rq.ssoCloudSettingAction.append(action) + api.communicate_rest( + params, setting_rq, 'sso/config/sso_cloud_configuration_setting_set', + rs_type=ssocloud.SsoCloudConfigurationResponse) + logging.info('ForceAuthn enabled.') + elif has_errors and kwargs.get('force_authn'): + logging.warning('Skipping --force-authn activation because metadata upload had validation errors.') + + +class SsoCloudDownloadMetadataCommand(EnterpriseCommand, SsoCloudMixin): + def get_parser(self): + return sso_cloud_download_parser + + def execute(self, params, **kwargs): + # type: (KeeperParams, **Any) -> Any + target = kwargs.get('target') + svc = self.find_sso_service(params, target) + sp_id = svc['sso_service_provider_id'] + self.ensure_cloud_sso(svc, target) + + server_base = params.rest_context.server_base + if not server_base.endswith('/'): + server_base += '/' + metadata_url = f'{server_base}sso/saml/metadata/{sp_id}' + + rs = http_requests.get( + metadata_url, + proxies=params.rest_context.proxies, + verify=params.rest_context.certificate_check, + timeout=30, + ) + if rs.status_code != 200: + raise CommandError('sso-cloud', + f'Failed to download SP metadata (HTTP {rs.status_code}): {rs.text[:200]}') + + xml_content = rs.text + output_path = kwargs.get('output') + if output_path: + output_path = os.path.expanduser(output_path) + try: + with open(output_path, 'w', encoding='utf-8') as f: + f.write(xml_content) + logging.info('SP metadata saved to: %s', output_path) + except IOError as e: + raise CommandError('sso-cloud', f'Failed to write metadata file "{output_path}": {e}') + else: + print(xml_content) diff --git a/keepercommander/commands/sso_cloud/mixin.py b/keepercommander/commands/sso_cloud/mixin.py new file mode 100644 index 000000000..1760c2e2b --- /dev/null +++ b/keepercommander/commands/sso_cloud/mixin.py @@ -0,0 +1,331 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' dict + """Resolve an SSO service provider by ID or name from enterprise data.""" + if not target: + raise CommandError('sso-cloud', 'SSO Service Provider name or ID is required.') + + sso_services = params.enterprise.get('sso_services', []) + if not sso_services: + raise CommandError('sso-cloud', 'No SSO Cloud service providers found in this enterprise.') + + try: + target_id = int(target) + for svc in sso_services: + if svc.get('sso_service_provider_id') == target_id: + return svc + except (ValueError, TypeError): + logging.debug('Target "%s" is not numeric, searching by name.', target) + + target_lower = target.lower() + matches = [s for s in sso_services if s.get('name', '').lower() == target_lower] + if len(matches) == 1: + return matches[0] + if len(matches) > 1: + raise CommandError('sso-cloud', + f'Multiple SSO service providers match "{target}". Use the SP ID instead.') + + raise CommandError('sso-cloud', + f'SSO Service Provider "{target}" not found. ' + f'Run "ed -f" to refresh enterprise data, then "sso list" to verify.') + + @staticmethod + def ensure_cloud_sso(svc, target=''): + # type: (dict, str) -> None + """Warn if the SP doesn't appear as Cloud SSO in cached enterprise data.""" + if not svc.get('is_cloud'): + logging.debug('SSO Service Provider "%s" is_cloud flag is not set in enterprise cache. ' + 'Proceeding anyway — the server will enforce if invalid.', + svc.get('name', target)) + + @staticmethod + def get_node_name(params, node_id): + # type: (KeeperParams, int) -> str + """Resolve a node ID to its display name.""" + for node in params.enterprise.get('nodes', []): + if node['node_id'] == node_id: + if node.get('parent_id', 0) > 0: + return node['data'].get('displayname') or str(node_id) + else: + return params.enterprise.get('enterprise_name', str(node_id)) + return str(node_id) + + @staticmethod + def get_selected_configuration(params, sp_id, config_target=None): + # type: (KeeperParams, int, Optional[str]) -> ssocloud.SsoCloudConfigurationResponse + """Fetch the active or specified configuration for a service provider.""" + list_rq = ssocloud.SsoCloudServiceProviderConfigurationListRequest() + list_rq.ssoServiceProviderId = sp_id + list_rs = api.communicate_rest( + params, list_rq, 'sso/config/sso_cloud_sp_configuration_get', + rs_type=ssocloud.SsoCloudServiceProviderConfigurationListResponse) + + owned = [c for c in list_rs.configurationItem + if not c.ssoServiceProviderId or sp_id in c.ssoServiceProviderId] + + if not owned: + raise CommandError('sso-cloud', f'No configurations found for SP ID {sp_id}.') + + config_item = None + if config_target: + try: + config_id = int(config_target) + config_item = next( + (c for c in owned if c.ssoSpConfigurationId == config_id), None) + except ValueError: + pass + + if not config_item: + config_lower = config_target.lower() + matches = [c for c in owned if c.name.lower() == config_lower] + if len(matches) == 1: + config_item = matches[0] + elif len(matches) > 1: + raise CommandError('sso-cloud', + f'Multiple configurations match "{config_target}". Use Configuration ID.') + + if not config_item: + raise CommandError('sso-cloud', f'Configuration "{config_target}" not found.') + else: + config_item = next((c for c in owned if c.isSelected), None) + if not config_item: + config_item = owned[0] + + get_rq = ssocloud.SsoCloudConfigurationRequest() + get_rq.ssoServiceProviderId = sp_id + get_rq.ssoSpConfigurationId = config_item.ssoSpConfigurationId + return api.communicate_rest( + params, get_rq, 'sso/config/sso_cloud_configuration_get', + rs_type=ssocloud.SsoCloudConfigurationResponse) + + @staticmethod + def format_setting_value(setting): + # type: (ssocloud.SsoCloudSettingValue) -> str + """Format a setting value for display, handling special cases.""" + value = setting.value or '' + if setting.isFromFile: + if value and len(value) > 80: + return f'[{len(value)} bytes]' + if setting.settingName == 'sso_idp_type_id': + try: + idp_type = int(value) + return IDP_TYPE_NAMES.get(idp_type, f'Unknown ({value})') + except (ValueError, TypeError): + pass + return value + + @staticmethod + def _extract_sp_values(config_rs): + # type: (ssocloud.SsoCloudConfigurationResponse) -> dict + keys = ('sso_sp_entity_id', 'sso_sp_acs_endpoint', 'sso_sp_login_endpoint', + 'sso_sp_logout_endpoint', 'sso_sp_slo_endpoint', + 'sso_idp_initiated_login_endpoint', 'sso_sp_domain') + result = {} + for sv in config_rs.ssoCloudSettingValue: + if sv.settingName in keys: + result[sv.settingName] = sv.value or '' + return result + + @staticmethod + def _get_idp_type_name(config_rs): + # type: (ssocloud.SsoCloudConfigurationResponse) -> Optional[str] + for sv in config_rs.ssoCloudSettingValue: + if sv.settingName == 'sso_idp_type_id' and sv.value: + try: + return IDP_ENUM_TO_KEY.get(int(sv.value)) + except (ValueError, TypeError): + pass + return None + + @staticmethod + def show_idp_guidance(config_rs, sp_name=''): + # type: (ssocloud.SsoCloudConfigurationResponse, str) -> None + """Show IdP-specific setup guidance with formatted output.""" + idp_type_name = SsoCloudMixin._get_idp_type_name(config_rs) + if not idp_type_name: + logging.warning('No IdP type set on this configuration. ' + 'Use "sso-cloud set --set sso_idp_type_id=" to set the IdP type, ' + 'then re-run "sso-cloud guide".') + return + guidance = IDP_SETUP_GUIDANCE.get(idp_type_name) + if not guidance: + supported = ', '.join(sorted(IDP_SETUP_GUIDANCE.keys())) + logging.warning('No setup guide is available for IdP type "%s". ' + 'Guides currently exist for: %s.', idp_type_name, supported) + return + + sp = SsoCloudMixin._extract_sp_values(config_rs) + portal = guidance['portal_name'] + display_name = sp_name or str(config_rs.ssoServiceProviderId) + + vals = { + 'name': display_name, + 'entity_id': sp.get('sso_sp_entity_id', ''), + 'acs_endpoint': sp.get('sso_sp_acs_endpoint', ''), + 'login_endpoint': sp.get('sso_sp_login_endpoint', ''), + 'idp_login_endpoint': sp.get('sso_idp_initiated_login_endpoint', ''), + 'slo_endpoint': sp.get('sso_sp_slo_endpoint', ''), + 'auth0_json': AUTH0_SAML_JSON_TEMPLATE.format( + entity_id=sp.get('sso_sp_entity_id') or ''), + } + + BAR = '\u2500' * 60 + CMD_TAG = '[Commander]' + IDP_TAG = f'[{portal}]' + + print('') + print(f'{portal} SSO Setup Guide') + print(BAR) + print(guidance.get('portal_url', '')) + print('') + + step_num = 0 + for kind, text in guidance['steps']: + filled = text.format(**vals) + + if kind == 'value': + print(f' {filled}') + print('') + elif kind == 'json': + for json_line in filled.splitlines(): + print(f' {json_line}') + print('') + elif kind == 'note': + print(f' * {filled}') + elif kind == 'cmd': + step_num += 1 + print(f'{step_num:>2}. {CMD_TAG} My Vault> {filled}') + else: + step_num += 1 + print(f'{step_num:>2}. {IDP_TAG} {filled}') + + print('') + + @staticmethod + def dump_configuration(config_rs, fmt=None, filename=None): + # type: (ssocloud.SsoCloudConfigurationResponse, Optional[str], Optional[str]) -> None + """Display configuration details.""" + logging.info('') + logging.info('{0:>40s}: {1}'.format('Service Provider ID', config_rs.ssoServiceProviderId)) + logging.info('{0:>40s}: {1}'.format('Configuration ID', config_rs.ssoSpConfigurationId)) + logging.info('{0:>40s}: {1}'.format('Configuration Name', config_rs.name)) + logging.info('{0:>40s}: {1}'.format('Protocol', config_rs.protocol)) + logging.info('{0:>40s}: {1}'.format('Last Modified', config_rs.lastModified)) + + if fmt == 'json': + settings_list = [] + for sv in config_rs.ssoCloudSettingValue: + settings_list.append({ + 'setting_id': sv.settingId, + 'setting_name': sv.settingName, + 'label': sv.label, + 'value': sv.value, + 'editable': sv.isEditable, + 'required': sv.isRequired, + 'from_file': sv.isFromFile, + 'last_modified': sv.lastModified, + }) + output = json.dumps({ + 'sso_service_provider_id': config_rs.ssoServiceProviderId, + 'sso_sp_configuration_id': config_rs.ssoSpConfigurationId, + 'name': config_rs.name, + 'protocol': config_rs.protocol, + 'last_modified': config_rs.lastModified, + 'settings': settings_list + }, indent=2) + if filename: + expanded_path = os.path.expanduser(filename) + try: + with open(expanded_path, 'w') as f: + f.write(output) + logging.info('Output written to %s', expanded_path) + except IOError as e: + raise CommandError('sso-cloud', f'Failed to write output file "{filename}": {e}') + else: + print(output) + return + + settings_by_name = {} # type: Dict[str, ssocloud.SsoCloudSettingValue] + for sv in config_rs.ssoCloudSettingValue: + settings_by_name[sv.settingName] = sv + + for group_label, setting_names in SETTING_GROUPS.items(): + group_settings = [settings_by_name.get(name) for name in setting_names] + group_settings = [s for s in group_settings if s is not None] + if not group_settings: + continue + + logging.info('') + logging.info(' --- %s ---', group_label) + for sv in group_settings: + if sv.isFromFile and sv.value and len(sv.value) > 80: + display_value = f'[{len(sv.value)} bytes]' + else: + display_value = SsoCloudMixin.format_setting_value(sv) + + editable_marker = '' if sv.isEditable else ' (read-only)' + required_marker = ' *' if sv.isRequired else '' + logging.info('{0:>40s}: {1}{2}{3}'.format( + sv.label or sv.settingName, display_value, required_marker, editable_marker)) + + ungrouped_names = set() + for group_names in SETTING_GROUPS.values(): + ungrouped_names.update(group_names) + ungrouped = [sv for name, sv in settings_by_name.items() if name not in ungrouped_names] + if ungrouped: + logging.info('') + logging.info(' --- Other Settings ---') + for sv in ungrouped: + display_value = SsoCloudMixin.format_setting_value(sv) + logging.info('{0:>40s}: {1}'.format(sv.label or sv.settingName, display_value)) + + logging.info('') + + @staticmethod + def dump_sso_services(params, fmt=None, filename=None): + # type: (KeeperParams, Optional[str], Optional[str]) -> None + """Display all SSO service providers as a table.""" + sso_services = params.enterprise.get('sso_services', []) + table = [] + headers = ['sp_id', 'name', 'node_id', 'node_name', 'active', 'is_cloud'] + if fmt and fmt != 'json': + headers = [field_to_title(x) for x in headers] + for svc in sso_services: + sp_id = svc.get('sso_service_provider_id') + name = svc.get('name', '') + node_id = svc.get('node_id', 0) + node_name = SsoCloudMixin.get_node_name(params, node_id) if node_id else 'N/A' + active = svc.get('active', False) + is_cloud = svc.get('is_cloud', False) + table.append([sp_id, name, node_id, node_name, active, is_cloud]) + return dump_report_data(table, headers=headers, fmt=fmt, filename=filename) diff --git a/keepercommander/commands/sso_cloud/parsers.py b/keepercommander/commands/sso_cloud/parsers.py new file mode 100644 index 000000000..8cdb27d9f --- /dev/null +++ b/keepercommander/commands/sso_cloud/parsers.py @@ -0,0 +1,119 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Any + name = kwargs.get('name') + if not name: + logging.warning('"--name" option is required for "create" command') + return + + node_name = kwargs.get('node') + nodes = list(self.resolve_nodes(params, node_name)) + if len(nodes) == 0: + raise CommandError('sso-cloud', f'Node "{node_name}" not found.') + if len(nodes) > 1: + raise CommandError('sso-cloud', f'Node name "{node_name}" is not unique. Use Node ID.') + target_node = nodes[0] + node_id = target_node['node_id'] + + existing = params.enterprise.get('sso_services', []) + for svc in existing: + if svc.get('node_id') == node_id: + raise CommandError('sso-cloud', + f'Node already has an SSO service provider: ' + f'"{svc.get("name")}" (ID: {svc.get("sso_service_provider_id")})') + + tree_key = params.enterprise.get('unencrypted_tree_key') + if not tree_key: + raise CommandError('sso-cloud', 'Enterprise tree key not available. Ensure enterprise data is loaded.') + + sp_data_key = crypto.get_random_bytes(32) + encrypted_sp_data_key = crypto.encrypt_aes_v1(sp_data_key, tree_key) + + rq = { + 'command': 'sso_service_provider_add', + 'sso_service_provider_id': self.get_enterprise_id(params), + 'node_id': node_id, + 'name': name, + 'sp_data_key': utils.base64_url_encode(encrypted_sp_data_key), + 'invite_new_users': True, + 'is_cloud': True, + } + rs = api.communicate(params, rq) + sp_id = rs.get('sso_service_provider_id') or rq['sso_service_provider_id'] + logging.info('SSO Service Provider created: %s (ID: %s)', name, sp_id) + + config_name = kwargs.get('config_name') or 'Default' + config_rq = ssocloud.SsoCloudConfigurationRequest() + config_rq.ssoServiceProviderId = sp_id + config_rq.name = config_name + config_rq.ssoAuthProtocolType = ssocloud.SAML2 + + config_rs = api.communicate_rest( + params, config_rq, 'sso/config/sso_cloud_configuration_add', + rs_type=ssocloud.SsoCloudConfigurationResponse) + + config_id = config_rs.ssoSpConfigurationId + logging.info('SAML2 Configuration created: "%s" (ID: %s)', config_name, config_id) + + setting_rq = ssocloud.SsoCloudConfigurationRequest() + setting_rq.ssoServiceProviderId = sp_id + setting_rq.ssoSpConfigurationId = config_id + + idp_type_name = kwargs['idp_type'] + idp_type_enum = IDP_TYPE_NAME_TO_ENUM.get(idp_type_name.lower()) + if idp_type_enum is not None: + action = ssocloud.SsoCloudSettingAction() + action.settingName = 'sso_idp_type_id' + action.operation = ssocloud.SET + action.value = str(idp_type_enum) + setting_rq.ssoCloudSettingAction.append(action) + + domain = kwargs.get('domain') + if domain: + action = ssocloud.SsoCloudSettingAction() + action.settingName = 'sso_sp_domain' + action.operation = ssocloud.SET + action.value = domain + setting_rq.ssoCloudSettingAction.append(action) + + if setting_rq.ssoCloudSettingAction: + api.communicate_rest( + params, setting_rq, 'sso/config/sso_cloud_configuration_setting_set', + rs_type=ssocloud.SsoCloudConfigurationResponse) + if idp_type_enum is not None: + logging.info('IdP type set to: %s', IDP_TYPE_NAMES.get(idp_type_enum, idp_type_name)) + if domain: + logging.info('Enterprise domain set to: %s', domain) + + api.query_enterprise(params, force=True) + + fmt = kwargs.get('format') + if fmt == 'json': + import json as json_mod + result = { + 'sso_service_provider_id': sp_id, + 'name': name, + 'node_id': node_id, + 'config_id': config_id, + 'config_name': config_name, + 'idp_type': idp_type_name, + } + if domain: + result['domain'] = domain + try: + config_rs = self.get_selected_configuration(params, sp_id) + settings = {} + for sv in config_rs.ssoCloudSettingValue: + settings[sv.settingName] = sv.value or '' + result['settings'] = settings + except Exception as e: + logging.debug('Failed to fetch settings for JSON output: %s', e) + print(json_mod.dumps(result, indent=2)) + else: + logging.info('') + logging.info('Next steps:') + logging.info(' sso-cloud guide "%s" View IdP-specific setup instructions', name) + logging.info(' sso-cloud get "%s" View configuration details & endpoints', name) + + +class SsoCloudDeleteCommand(EnterpriseCommand, SsoCloudMixin): + def get_parser(self): + return sso_cloud_delete_parser + + def execute(self, params, **kwargs): + # type: (KeeperParams, **Any) -> Any + target = kwargs.get('target') + svc = self.find_sso_service(params, target) + sp_id = svc['sso_service_provider_id'] + sp_name = svc.get('name', target) + self.ensure_cloud_sso(svc, target) + + config_target = kwargs.get('config') + if config_target: + self._delete_configuration(params, sp_id, config_target, kwargs.get('force')) + else: + self._delete_service_provider(params, sp_id, sp_name, kwargs.get('force')) + + @staticmethod + def _delete_configuration(params, sp_id, config_target, force): + # type: (KeeperParams, int, str, bool) -> None + config_rs = SsoCloudMixin.get_selected_configuration(params, sp_id, config_target=config_target) + config_id = config_rs.ssoSpConfigurationId + config_name = config_rs.name + + if not force: + answer = user_choice( + f'Are you sure you want to delete configuration "{config_name}" (ID: {config_id})?', + 'yn', default='n') + if answer.lower() != 'y': + logging.info('Delete cancelled.') + return + + rq = ssocloud.SsoCloudConfigurationRequest() + rq.ssoServiceProviderId = sp_id + rq.ssoSpConfigurationId = config_id + + api.communicate_rest( + params, rq, 'sso/config/sso_cloud_configuration_delete', + rs_type=ssocloud.SsoCloudConfigurationResponse) + + logging.info('Configuration "%s" (ID: %s) deleted.', config_name, config_id) + api.query_enterprise(params, force=True) + + @staticmethod + def _delete_service_provider(params, sp_id, sp_name, force): + # type: (KeeperParams, int, str, bool) -> None + if not force: + answer = user_choice( + f'Are you sure you want to delete SSO Service Provider "{sp_name}" (ID: {sp_id}) ' + f'and ALL its configurations?', 'yn', default='n') + if answer.lower() != 'y': + logging.info('Delete cancelled.') + return + + rq = { + 'command': 'sso_service_provider_delete', + 'sso_service_provider_id': sp_id, + } + api.communicate(params, rq) + + logging.info('SSO Service Provider "%s" (ID: %s) deleted.', sp_name, sp_id) + api.query_enterprise(params, force=True) From 529fd4cbd672be881f2916777f0661307499ef2b Mon Sep 17 00:00:00 2001 From: lthievenaz-keeper Date: Thu, 14 May 2026 14:08:07 +0100 Subject: [PATCH 25/26] Disable SO_REUSEADDR on tunnel start When starting tunnels from Windows, Keeper reuses the same port (49152). This is because of SO_REUSEADDR and Windows behavior, and is resolved when disabling it. --- keepercommander/commands/tunnel/port_forward/tunnel_helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py index 3fdfff16d..82a494ff9 100644 --- a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py +++ b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py @@ -786,8 +786,6 @@ def find_open_port(tried_ports: list, start_port=49152, end_port=65535, preferre def is_port_open(host: str, port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: - # Enable SO_REUSEADDR to allow binding to ports in TIME_WAIT state - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind((host, port)) return True except OSError: From aa5aa24449b65d73d995515f8a3add99480f9b39 Mon Sep 17 00:00:00 2001 From: Sergey Kolupaev Date: Thu, 14 May 2026 10:35:47 -0700 Subject: [PATCH 26/26] Release 18.0.1 --- keepercommander/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keepercommander/__init__.py b/keepercommander/__init__.py index f9ffffaab..d8e2d6d52 100644 --- a/keepercommander/__init__.py +++ b/keepercommander/__init__.py @@ -10,4 +10,4 @@ # Contact: commander@keepersecurity.com # -__version__ = '18.0.0' +__version__ = '18.0.1'