diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 013305399..098e1a70c 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -48,7 +48,7 @@ ## Where to add new code or tests ๐Ÿงญ -- SDK code: `nodejs/src`, `python/copilot`, `go`, `dotnet/src` -- Unit tests: `nodejs/test`, `python/*`, `go/*`, `dotnet/test` +- SDK code: `nodejs/src`, `python/copilot`, `go`, `dotnet/src`, `rust/src` +- Unit tests: `nodejs/test`, `python/*`, `go/*`, `dotnet/test`, `rust/tests` - E2E tests: `*/e2e/` folders that use the shared replay proxy and `test/snapshots/` - Generated types: update schema in `@github/copilot` then run `cd nodejs && npm run generate:session-types` and commit generated files in `src/generated` or language generated location. diff --git a/.github/skills/rust-coding-skill/SKILL.md b/.github/skills/rust-coding-skill/SKILL.md new file mode 100644 index 000000000..7e0342f06 --- /dev/null +++ b/.github/skills/rust-coding-skill/SKILL.md @@ -0,0 +1,255 @@ +--- +name: rust-coding-skill +description: "Use this skill whenever editing `*.rs` files in the `rust/` SDK in order to write idiomatic, efficient, well-structured Rust code" +--- + +# Rust Coding Skill + +Opinionated Rust rules for the Copilot Rust SDK (`rust/`). Priority order: + +1. **Readable code** โ€” every line should earn its place +2. **Correct code** โ€” especially in concurrent/async contexts +3. **Performant code** โ€” think about allocations, data structures, hot paths + +## Error handling + +The SDK's public error type is `crate::Error` (`rust/src/error.rs`). Add new +variants there rather than introducing parallel error enums per module โ€” every +public failure mode is part of the API contract and should be expressible in one +type. Internal modules can use `thiserror` enums when a richer local taxonomy +helps; convert at the boundary. + +`anyhow` is reserved for binaries and example code. Library code never returns +`anyhow::Result` โ€” callers can't pattern-match on `anyhow::Error`, so it would +prevent them from handling specific failures. + +In production code, prefer `?`, `let-else`, and `if let`. Reach for `expect("โ€ฆ")` +when an invariant cannot fail and the message would help debug a future +regression. `unwrap()` belongs in tests only โ€” Clippy enforces this in the SDK +via `#![cfg_attr(test, allow(clippy::unwrap_used))]` in `lib.rs`. + +When you need to log on the way through, prefer +`.inspect_err(|e| warn!(error = ?e, "context"))?` over a `match` that logs and +re-wraps. It reads top-to-bottom and keeps the happy path uncluttered. + +## Async and concurrency + +The default for request-scoped I/O is `async fn` plus `.await` โ€” futures +inherit cancellation from their parent task and can borrow local references. +Reach for `tokio::spawn` only when you genuinely need background work (an event +loop, a long-lived watcher) and track the `JoinHandle` so you can cancel or join +it on shutdown. Fire-and-forget spawns silently swallow panics and outlive the +session; don't. + +Blocking calls (filesystem, subprocess wait) belong in +`tokio::task::spawn_blocking`, *not* on the async runtime. The blocking pool is +bounded, so for genuinely long-lived workers (think: file watchers that run for +the lifetime of a session) prefer `std::thread::spawn` with a channel back into +async land. + +Lock choice matters. `tokio::sync::Mutex` is correct when you must hold the +guard across `.await`; `parking_lot::Mutex` (or `RwLock`) is faster on hot +synchronous paths and is what `session.rs` uses for capability state. +`std::sync::Mutex` is rarely the right answer in this crate โ€” its poisoning +semantics buy us nothing and it's slower than `parking_lot`. Never hold a +`std::sync::Mutex` guard across an `.await`; Clippy will catch this, but the +fix is to move the await out, not silence the lint. + +For lazy statics use `std::sync::LazyLock`. The `once_cell` crate is no longer +needed. + +## Traits and conversions + +Plain functions on a type beat traits for navigability. Use them as the default. + +**Trait-based extension points are different.** When a consumer must plug behaviour into the SDK, prefer one trait with one default-impl method per event over per-event `Box` callback fields. This is what `tower_lsp::LanguageServer`, `rmcp::ServerHandler`, and `notify::EventHandler` do โ€” the dominant idiom in async Rust for "wire-protocol handler" traits. Callback fields fight `Send + Sync + 'static`, fragment consumer state across closures, and skip exhaustiveness checks. + +The four extension-point traits in this crate: + +- **`SessionHandler`** (`rust/src/handler.rs`) โ€” per-event methods (`on_permission_request`, `on_user_input`, `on_external_tool`, `on_elicitation`, `on_exit_plan_mode`, `on_auto_mode_switch`, `on_session_event`) each with a default impl. The dispatcher `on_event(HandlerEvent)` is itself a default method that fans out to them; override per-event methods in normal use, override `on_event` only when you want a single exhaustive match. Concurrent invocations are possible (notification-triggered events run on spawned tasks), so `Send + Sync + 'static` is required on the trait. +- **`SessionHooks`** (`rust/src/hooks.rs`) โ€” optional lifecycle callbacks. The SDK auto-enables hooks when an impl is supplied to `create_session` / `resume_session`. +- **`SystemMessageTransform`** (`rust/src/system_message.rs`) โ€” declare `section_ids()` and return content from `transform_section()`. +- **`ToolHandler`** (`rust/src/tool.rs`) โ€” client-side tool implementations, dispatched by name via `ToolHandlerRouter`. + +`ApproveAllHandler` is the standard test handler for `SessionHandler`. + +**Don't add traits without a clear extension story.** Don't implement `From`/`Into` for SDK-internal conversions: they can't take extra parameters, can't return `Result`, and hide which conversion is happening at call sites. Prefer named methods like `to_info(&self)` or `MyType::from_record(record, ctx)`. + +Trivial field re-shaping is best inlined. Closures should stay short (under ~10 lines); extract to named functions when they grow. Visitor patterns are a closure-fest โ€” expose `iter()` and let the consumer drive. + +## Concurrency primitives + +**Channels, not callback closures, for event flow.** Closures fight `Send + Sync + 'static` and don't compose with `select!`. Channel choice by semantics: + +| Use case | Primitive | +|---|---| +| One producer โ†’ one consumer with backpressure | `tokio::sync::mpsc` (cap 1) or `tokio::sync::oneshot` for single value | +| Many producers โ†’ one consumer | `tokio::sync::mpsc` | +| One producer โ†’ many consumers, every event delivered (pub/sub) | `tokio::sync::broadcast` | +| One producer โ†’ many consumers, only the latest value matters | `tokio::sync::watch` | + +For the **public** API, prefer returning `impl Stream` (wrap a `broadcast::Receiver` in `tokio_stream::wrappers::BroadcastStream`). `Stream` composes with `select!`, `take`, `map`, `filter`, `timeout`. See `EventSubscription` and `LifecycleSubscription`. + +**Cancellation: drop is the primitive; `tokio_util::sync::CancellationToken` for SDK-internal task coordination.** + +- **Caller-owned futures** (`send_message`, subscription streams): drop / `select!` / `tokio::time::timeout`. Don't accept a token parameter โ€” it duplicates what `select!` already provides. Document cancel-safety on every `.await` in the hot path. +- **SDK-internal tasks** (event loops, subprocess readers, anything `tokio::spawn`ed by the SDK): use `CancellationToken` stored on the long-lived handle. `Drop` calls `cancel()`. `Session::cancellation_token()` exposes a child token so callers can bind external work to the session lifetime. + +Refs: [`CancellationToken`][ctoken] ยท [`tonic` example][tonic-cancel] ยท [withoutboats: async clean-up][wb-cleanup] ยท [Cybernetist: cancellation patterns][cybernetist]. + +[ctoken]: https://docs.rs/tokio-util/latest/tokio_util/sync/struct.CancellationToken.html +[tonic-cancel]: https://github.com/hyperium/tonic/blob/master/examples/src/cancellation/server.rs +[wb-cleanup]: https://without.boats/blog/asynchronous-clean-up/ +[cybernetist]: https://cybernetist.com/2024/04/19/rust-tokio-task-cancellation-patterns/ + +## Optional fields and serde + +Use `Option` for optional fields, not nullable references or sentinel values. Defaults come from `Default` impls. Pair with `#[non_exhaustive]` on public config structs and a builder so adding fields stays non-breaking. + +For required builder fields: prefer `build() -> Result` over typestate unless required-field count is tiny (1-2). + +JSON: `#[serde(rename_all = "camelCase")]` at the type level, per-field `#[serde(rename = "โ€ฆ")]` for outliers, `#[serde(skip_serializing_if = "Option::is_none")]` for output, `#[serde(default)]` for input tolerance. Reach for `serde_with` only for non-trivial transforms (durations, base64, numeric-as-string keys). + +## Tracing โ€” `#[tracing::instrument]` is banned + +Banned via `clippy.toml`. Use manual spans with `error_span!`: + +- **Almost always use `error_span!`**, not `info_span!`. Span level controls + the *minimum* filter at which the span appears. An `info_span` disappears when + the filter is `warn` or `error` โ€” taking all child events with it, even + errors. `error_span!` ensures the span is always present. +- **Spawned tasks lose parent context.** Attach a span with `.instrument()` or + events inside won't correlate. +- **Never hold `span.enter()` guards across `.await`** โ€” use `.instrument(span)` + instead (also enforced by Clippy). + +```rust +use tracing::Instrument; + +async fn send_message(&self, session_id: &str, prompt: &str) -> Result<(), Error> { + let span = tracing::error_span!("send_message", session_id = %session_id); + async { /* body */ }.instrument(span).await +} + +let span = tracing::error_span!("event_loop", session_id = %id); +tokio::spawn(async move { run_loop().await }.instrument(span)); +``` + +Log with structured fields: `info!(session_id = %id, "Session created")`. +Static messages stay greppable; dynamic data goes in named fields, not +interpolated into the message string. + +## Idioms that don't port from other languages + +When porting from Node, Python, Go, or .NET: see the **Concurrency primitives** and **Traits and conversions** sections above. The two patterns that most reliably translate poorly are (1) per-event `Box` callback fields โ€” use a trait with default-impl methods (the `tower_lsp::LanguageServer` / `rmcp::ServerHandler` / `notify::EventHandler` shape) โ€” and (2) plumbing `context.Context` / `CancellationToken` through every call site โ€” drop-cancel for caller-owned futures, `tokio_util::sync::CancellationToken` for SDK-internal tasks. + +## Code organization + +- **Public API:** every `pub` item in the crate is part of the SDK's contract. + Adding a field to a `pub struct` is a breaking change unless the struct is + `#[non_exhaustive]` or constructors hide field-by-field literals. Prefer + `Default + ..Default::default()` patterns and document new fields with + rustdoc. +- **Generated code lives in `rust/src/generated/`** and must not be + hand-edited. Regenerate with `cd scripts/codegen && npm run generate:rust`. + When a generated type lacks a field the schema doesn't yet describe (e.g. + `Tool::overrides_built_in_tool`), hand-author the user-facing type in + `rust/src/types.rs` and stop re-exporting the generated one. +- **`#[expect(dead_code)]`** instead of `#[allow(dead_code)]` on individual + fields โ€” it forces a cleanup once the field gets used. +- **`..Default::default()`** โ€” avoid in production code (be explicit about + which fields you're setting); prefer it in tests and doc examples to keep + the focus on the values that matter for the test. +- **Import grouping** โ€” three blocks separated by blank lines: + (1) `std`/`core`/`alloc`, (2) external crates, (3) + `crate::`/`super::`/`self::`. Enforced by nightly `cargo fmt` via + `rust/.rustfmt.nightly.toml`. +- **`pub(crate)` vs `pub`** โ€” most modules in `lib.rs` are private (`mod`), so + `pub` items inside them are already crate-private. Use `pub(crate)` only when + you want to be explicit that an item must not become part of the public API. + +## Testing + +- **No mock testing.** Depend on real implementations, spin up lightweight + versions (e.g. `MockServer` in tests), or restructure code so the logic + under test takes its dependency's output as input. +- `assert_eq!(actual, expected)` โ€” actual first, for readable diffs. +- Tests at end of file: `#[cfg(test)] mod tests`. Never place production code + after the test module. +- Keep tests concurrent-safe โ€” unique temp dirs (`tempfile::tempdir()`), + unique data, no global state. +- `ApproveAllHandler` is the standard test handler for sessions that don't + exercise permission logic โ€” see `rust/src/handler.rs:174`. + +## Cross-platform + +The SDK ships on macOS, Windows, and Linux; CI exercises all three. Construct +paths with `Path::join` rather than string concatenation โ€” `/` and `\` are not +interchangeable, and string equality breaks on Windows UNC paths. Log paths +with `path.display()`; serialize with `to_string_lossy()` only when you need a +`String`. + +Process spawning needs care. The SDK applies `CREATE_NO_WINDOW` on Windows +when launching the CLI (see `Client::build_command`); preserve that if you +touch process spawning. Subprocess stdout often contains `\r` on Windows โ€” strip +or split on `\r?\n` rather than assuming `\n`. + +Tests must use `tempfile::tempdir()`, never hardcoded `/tmp/`, and any test +that asserts on a path string needs to normalize separators or use +`std::path::MAIN_SEPARATOR`. + +## Build speed + +Specify Tokio features explicitly โ€” never `features = ["full"]`. Iterate with +`cargo check`; reach for `cargo build` only when you need the binary. Audit +new dependency feature flags with `cargo tree` before committing. + +## Comments + +Explain **why**, never **what**. No comments that restate code. No decorative +banners (`// โ”€โ”€ Section โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€`). + +**Never compare to other SDKs in code comments or rustdoc.** Don't write +"Mirrors Node's `Foo`", "Like Go's `Bar`", "Unlike Python's `Baz`", or include +file/line citations into other SDKs (`nodejs/src/types.ts:1592`, `go/types.go:14`). +The Rust SDK seeks parity with the Node, Python, Go, and .NET SDKs, and that +fact is stated once at the top of `rust/README.md`. Intentional divergences +live in the README's "Differences From Other SDKs" section. Repeating the +relationship per-symbol is unscalable, drifts as the other SDKs evolve, and +adds noise to consumer-facing rustdoc โ€” Rust users care about the Rust API, +not its lineage. Self-references within the Rust crate (e.g. "Mirrors +[`from_streams`] but addsโ€ฆ") are fine. + +## Toolchain + +The SDK is pinned to `rust 1.94.0` via `rust/rust-toolchain.toml`. Formatting +uses nightly (`nightly-2026-04-14`) so unstable rustfmt options like grouped +imports work โ€” see `rust/.rustfmt.nightly.toml`. CI runs: + +```bash +cd rust +cargo +nightly-2026-04-14 fmt --check +cargo clippy --all-features --all-targets -- -D warnings +cargo test --all-features +``` + +Match those exact commands locally before pushing. + +## Codegen + +JSON-RPC and session-event types are generated from the Copilot CLI schema: + +| Source | Output | +|---|---| +| `nodejs/node_modules/@github/copilot/schemas/api.schema.json` | `rust/src/generated/api_types.rs` | +| `nodejs/node_modules/@github/copilot/schemas/session-events.schema.json` | `rust/src/generated/session_events.rs` | + +Regenerate with: + +```bash +cd scripts/codegen && npm run generate:rust +``` + +Never hand-edit files under `rust/src/generated/`. If a generated type needs a +field the schema lacks, hand-author the user-facing type in `rust/src/types.rs` +and stop re-exporting the generated one. diff --git a/.github/skills/rust-coding-skill/examples.md b/.github/skills/rust-coding-skill/examples.md new file mode 100644 index 000000000..602d6ffcb --- /dev/null +++ b/.github/skills/rust-coding-skill/examples.md @@ -0,0 +1,184 @@ +# Rust Coding Skill โ€” Examples + +Patterns specific to the Rust SDK in this repo (`rust/`) that aren't obvious +from general Rust knowledge. + +## Defining a tool + +### Anti-pattern โ€” building the wire payload by hand + +```rust +let raw = serde_json::json!({ + "name": "get_weather", + "description": "...", + "parameters": { "type": "object", ... }, +}); +config.tools = Some(vec![serde_json::from_value(raw)?]); +``` + +### Preferred โ€” implement `ToolHandler`, route via `ToolHandlerRouter` + +```rust +use copilot::tool::{Tool, ToolHandler, ToolHandlerRouter, ToolInvocation, ToolResult}; +use copilot::Error; + +struct GetWeatherTool; + +#[async_trait::async_trait] +impl ToolHandler for GetWeatherTool { + fn tool(&self) -> Tool { + Tool { + name: "get_weather".to_string(), + description: "Get the current weather for a city.".to_string(), + // ..Default::default() โ€” leaves namespaced_name, instructions, + // overrides_built_in_tool, skip_permission at their defaults. + ..Default::default() + } + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + // ... + Ok(ToolResult::Text("...".into())) + } +} + +use copilot::handler::ApproveAllHandler; +use std::sync::Arc; + +let router = ToolHandlerRouter::new( + vec![Box::new(GetWeatherTool)], + Arc::new(ApproveAllHandler), +); +``` + +## Spans for spawned event loops + +The session event loop is spawned per session. Always attach a span so events +emitted inside it correlate. + +### Anti-pattern โ€” losing parent context + +```rust +tokio::spawn(async move { + while let Some(event) = rx.recv().await { + info!("event {:?}", event); // No span โ€” can't filter by session + } +}); +``` + +### Preferred โ€” `error_span!` + `.instrument()` + +```rust +use tracing::Instrument; + +let span = tracing::error_span!("session_event_loop", session_id = %id); +tokio::spawn(async move { + while let Some(event) = rx.recv().await { + info!(event_type = ?event.kind, "session event"); + } +}.instrument(span)); +``` + +## Concurrent permission handlers + +`HandlerEvent::PermissionRequest` and `HandlerEvent::ExternalTool` are dispatched +on spawned tasks (see `rust/src/session.rs:973` and `:1022`). Implementations +must be safe for concurrent invocation. + +The `SessionHandler` trait declares `Send + Sync + 'static`, so the compiler +enforces this โ€” handlers with non-`Sync` state (e.g. `RefCell`, `Cell`, +`Rc`) won't compile. The examples below make the rejection mechanism explicit. + +### Won't compile โ€” non-`Sync` state + +```rust +struct MyHandler { + last_request: std::cell::RefCell>, // RefCell: !Sync +} + +#[async_trait] +impl SessionHandler for MyHandler { +// ^^^^^^^^^^^^^^ the trait `Sync` is not implemented for `RefCell<...>` + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { /* ... */ } +} +``` + +The error surfaces at the `impl` site, not at use site, because the trait's +`Send + Sync` bound makes `RefCell` ineligible for any field of any type that +implements `SessionHandler`. + +### Preferred โ€” `parking_lot::Mutex` or atomics + +```rust +struct MyHandler { + last_request: parking_lot::Mutex>, // Mutex: Sync if T: Send +} +``` + +## Adding a field to a public struct + +Adding a field to a public, non-exhaustive struct is a breaking change because +existing callers' struct literals stop compiling. Two patterns soften this: + +### Pattern 1 โ€” `Default` + `..Default::default()` in docs + +```rust +#[derive(Default)] +pub struct Tool { + pub name: String, + pub description: String, + // new field + pub overrides_built_in_tool: bool, +} + +// In docs and examples: +let t = Tool { + name: "x".into(), + description: "y".into(), + ..Default::default() +}; +``` + +### Pattern 2 โ€” `#[non_exhaustive]` for types callers shouldn't construct + +Use sparingly โ€” only for types that are *only* meant to be received from the +SDK, never built by users. + +```rust +#[non_exhaustive] +pub struct CreateSessionResult { + pub session_id: SessionId, + // ... +} +``` + +## Test handler for non-permission scenarios + +When a test doesn't exercise the permission flow, use the SDK's built-in +`ApproveAllHandler` instead of writing a custom one: + +```rust +use copilot::handler::ApproveAllHandler; +use copilot::types::SessionConfig; +use std::sync::Arc; + +let session = client + .create_session(SessionConfig::default().with_handler(Arc::new(ApproveAllHandler))) + .await?; +``` + +## Regenerating types after a schema bump + +```bash +# 1. Update schema (usually arrives with @github/copilot package update) +cd nodejs && npm install @github/copilot@latest && cd .. + +# 2. Regenerate Rust types +cd scripts/codegen && npm run generate:rust + +# 3. Verify +cd ../../rust && cargo check --all-features +``` + +If a generated type changes shape, hand-fix any user-facing wrappers in +`rust/src/types.rs` rather than monkey-patching the generated file. diff --git a/.github/workflows/codegen-check.yml b/.github/workflows/codegen-check.yml index 9fd7f0542..d48b6a491 100644 --- a/.github/workflows/codegen-check.yml +++ b/.github/workflows/codegen-check.yml @@ -13,6 +13,7 @@ on: - 'python/copilot/generated/**' - 'go/generated_*.go' - 'go/rpc/**' + - 'rust/src/generated/**' - '.github/workflows/codegen-check.yml' workflow_dispatch: @@ -34,6 +35,24 @@ jobs: with: go-version: '1.22' + # Rust generator runs `cargo fmt` on the output, so we need a toolchain with rustfmt. + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + components: rustfmt + + # Nightly rustfmt for unstable format options (group_imports, + # imports_granularity, reorder_impl_items) โ€” pinned in + # `rust/.rustfmt.nightly.toml`. The Rust generator emits unconsolidated + # imports under stable rustfmt; nightly fmt consolidates them to match + # the canonical committed form. + - name: Install nightly rustfmt + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2026-04-14 + components: rustfmt + - name: Install nodejs SDK dependencies working-directory: ./nodejs run: npm ci @@ -46,6 +65,10 @@ jobs: working-directory: ./scripts/codegen run: npm run generate + - name: Apply nightly rustfmt to generated Rust output + working-directory: ./rust + run: cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml + - name: Check for uncommitted changes run: | if [ -n "$(git status --porcelain)" ]; then diff --git a/.github/workflows/rust-publish-release.yml b/.github/workflows/rust-publish-release.yml new file mode 100644 index 000000000..348d2acf0 --- /dev/null +++ b/.github/workflows/rust-publish-release.yml @@ -0,0 +1,54 @@ +name: "Rust SDK: Publish Release" + +# Publishes the `copilot-sdk` crate to crates.io when a release-plz +# version-bump PR is merged to `main`. See rust/RELEASING.md for the +# full release process and one-time setup (CARGO_REGISTRY_TOKEN, etc). + +on: + push: + branches: + - main + paths: + - 'rust/Cargo.toml' + - 'rust/Cargo.lock' + - 'rust/release-plz.toml' + workflow_dispatch: + +permissions: + contents: write + +concurrency: + group: rust-release-plz-publish + cancel-in-progress: false + +jobs: + publish: + name: Publish to crates.io + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./rust + steps: + - uses: actions/checkout@v6.0.2 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + + - name: Run release-plz release + uses: release-plz/action@v0.5 + with: + command: release + manifest_path: rust/Cargo.toml + config: rust/release-plz.toml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.github/workflows/rust-release-pr.yml b/.github/workflows/rust-release-pr.yml new file mode 100644 index 000000000..41420f3e4 --- /dev/null +++ b/.github/workflows/rust-release-pr.yml @@ -0,0 +1,56 @@ +name: "Rust SDK: Create Release PR" + +# release-plz opens a PR that bumps the `copilot-sdk` version in +# `rust/Cargo.toml` and updates `rust/CHANGELOG.md` based on +# conventional-commit history since the last `rust-vX.Y.Z` tag. +# +# Review and merge that PR on the maintainer's schedule. Publishing to +# crates.io happens separately in `rust-publish-release.yml` once the +# version bump lands on `main`. +# +# Runs manually only โ€” we don't want a PR to race with every push. + +on: + workflow_dispatch: + +permissions: + contents: write + pull-requests: write + +concurrency: + group: rust-release-plz-pr + cancel-in-progress: false + +jobs: + release-pr: + name: Create Release PR + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./rust + steps: + - uses: actions/checkout@v6.0.2 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + + - name: Run release-plz release-pr + uses: release-plz/action@v0.5 + with: + command: release-pr + manifest_path: rust/Cargo.toml + config: rust/release-plz.toml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # CARGO_REGISTRY_TOKEN is not required for release-pr (no publish), + # but release-plz inspects the crate on crates.io to compute the + # next version. Public crate inspection doesn't need auth. diff --git a/.github/workflows/rust-sdk-tests.yml b/.github/workflows/rust-sdk-tests.yml new file mode 100644 index 000000000..201841784 --- /dev/null +++ b/.github/workflows/rust-sdk-tests.yml @@ -0,0 +1,170 @@ +name: "Rust SDK Tests" + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + paths: + - 'rust/**' + - 'test/**' + - 'nodejs/package.json' + - '.github/workflows/rust-sdk-tests.yml' + - '.github/actions/setup-copilot/**' + - '!**/*.md' + - '!**/LICENSE*' + - '!**/.gitignore' + - '!**/.editorconfig' + - '!**/*.png' + - '!**/*.jpg' + - '!**/*.jpeg' + - '!**/*.gif' + - '!**/*.svg' + workflow_dispatch: + merge_group: + +permissions: + contents: read + +jobs: + test: + name: "Rust SDK Tests" + env: + POWERSHELL_UPDATECHECK: Off + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + working-directory: ./rust + steps: + - uses: actions/checkout@v6.0.2 + + - uses: ./.github/actions/setup-copilot + id: setup-copilot + + # rust-toolchain.toml in rust/ pins the stable channel + components. + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + components: rustfmt, clippy + + # Nightly rustfmt for unstable format options (group_imports, + # imports_granularity, reorder_impl_items) โ€” pinned in + # `.rustfmt.nightly.toml`. + - name: Install nightly rustfmt + if: runner.os == 'Linux' + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2026-04-14 + components: rustfmt + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + + - name: cargo fmt --check (nightly) + if: runner.os == 'Linux' + run: cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml --check + + - name: cargo clippy + if: runner.os == 'Linux' + run: cargo clippy --all-targets --features test-support -- --no-deps -D warnings -D clippy::unwrap_used -D clippy::disallowed_macros -D clippy::await_holding_invalid_type + + - name: cargo doc + if: runner.os == 'Linux' + env: + RUSTDOCFLAGS: "-D warnings" + run: cargo doc --no-deps --all-features + + - name: Install test harness dependencies + working-directory: ./test/harness + run: npm ci --ignore-scripts + + - name: Warm up PowerShell + if: runner.os == 'Windows' + run: pwsh.exe -Command "Write-Host 'PowerShell ready'" + + - name: cargo test + env: + COPILOT_HMAC_KEY: ${{ secrets.COPILOT_DEVELOPER_CLI_INTEGRATION_HMAC_KEY }} + COPILOT_CLI_PATH: ${{ steps.setup-copilot.outputs.cli-path }} + run: cargo test --features test-support + + # Detects accidental public-API breakage against the crate's last + # published version on crates.io. Non-blocking until the crate has + # a first published release โ€” once a 0.1.0 ships, flip + # `continue-on-error` to `false` to enforce SemVer. + - name: cargo semver-checks + if: runner.os == 'Linux' + continue-on-error: true + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + package: github-copilot-sdk + manifest-path: rust/Cargo.toml + + # Validates the `embedded-cli` build path on all three supported + # platforms. This is the only place `build.rs` actually runs (the + # default `cargo test` job above has `COPILOT_CLI_VERSION` unset, so + # `build.rs` returns immediately). Catches regressions in the + # download / verify / extract / embed pipeline before they ship to + # crates.io and before bundling consumers (e.g. github-app's + # bundled-CLI release pipeline) hit them downstream. + bundle: + name: "Rust SDK Bundled CLI Build" + env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + working-directory: ./rust + steps: + - uses: actions/checkout@v6.0.2 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + key: bundled-cli + + - name: Read pinned @github/copilot CLI version + id: cli-version + working-directory: ./nodejs + run: | + version=$(node -p "require('./package.json').dependencies['@github/copilot'].replace(/^[\^~]/, '')") + echo "version=$version" >> "$GITHUB_OUTPUT" + echo "Pinned CLI version: $version" + + # Cache the downloaded archive across runs so we don't refetch + # ~130 MB on every CI invocation. Keyed by OS + CLI version; on + # cache miss the bundle job exercises the full ureq download + + # SHA-256 + retry path, which is exactly the regression surface + # we want validated. + - name: Cache bundled CLI tarball + uses: actions/cache@v4 + with: + path: ./rust/.bundled-cli-cache + key: bundled-cli-${{ matrix.os }}-${{ steps.cli-version.outputs.version }} + + - name: cargo build --features embedded-cli + env: + COPILOT_CLI_VERSION: ${{ steps.cli-version.outputs.version }} + BUNDLED_CLI_CACHE_DIR: ${{ github.workspace }}/rust/.bundled-cli-cache + run: cargo build --features embedded-cli diff --git a/.github/workflows/scenario-builds.yml b/.github/workflows/scenario-builds.yml index ae368075c..923560aba 100644 --- a/.github/workflows/scenario-builds.yml +++ b/.github/workflows/scenario-builds.yml @@ -9,6 +9,8 @@ on: - "python/copilot/**" - "go/**/*.go" - "dotnet/src/**" + - "rust/src/**" + - "rust/Cargo.toml" - ".github/workflows/scenario-builds.yml" push: branches: @@ -185,3 +187,46 @@ jobs: echo -e "Failures:$FAILURES" exit 1 fi + + # โ”€โ”€ Rust โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + build-rust: + name: "Rust scenarios" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - uses: dtolnay/rust-toolchain@1.94.0 + + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + test/scenarios/**/rust/target + key: ${{ runner.os }}-cargo-scenarios-${{ hashFiles('rust/Cargo.toml', 'test/scenarios/**/rust/Cargo.toml') }} + restore-keys: | + ${{ runner.os }}-cargo-scenarios- + + - name: Build all Rust scenarios + run: | + PASS=0; FAIL=0; FAILURES="" + for manifest in $(find test/scenarios -path '*/rust/Cargo.toml' | sort); do + dir=$(dirname "$manifest") + scenario="${dir#test/scenarios/}" + echo "::group::$scenario" + if (cd "$dir" && cargo build --quiet 2>&1); then + echo "โœ… $scenario" + PASS=$((PASS + 1)) + else + echo "โŒ $scenario" + FAIL=$((FAIL + 1)) + FAILURES="$FAILURES\n $scenario" + fi + echo "::endgroup::" + done + echo "" + echo "Rust builds: $PASS passed, $FAIL failed" + if [ "$FAIL" -gt 0 ]; then + echo -e "Failures:$FAILURES" + exit 1 + fi diff --git a/.github/workflows/update-copilot-dependency.yml b/.github/workflows/update-copilot-dependency.yml index a39d0575e..05833bf73 100644 --- a/.github/workflows/update-copilot-dependency.yml +++ b/.github/workflows/update-copilot-dependency.yml @@ -40,6 +40,22 @@ jobs: with: dotnet-version: "10.0.x" + # Rust generator runs `cargo fmt` on its output under stable rustfmt; + # nightly rustfmt is needed for unstable format options (group_imports, + # imports_granularity, reorder_impl_items) pinned in + # `rust/.rustfmt.nightly.toml`. See codegen-check.yml for the same step. + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + components: rustfmt + + - name: Install nightly rustfmt + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2026-04-14 + components: rustfmt + - name: Update @github/copilot in nodejs env: VERSION: ${{ inputs.version }} @@ -68,6 +84,7 @@ jobs: run: | cd nodejs && npx prettier --write "src/generated/**/*.ts" cd ../dotnet && dotnet format src/GitHub.Copilot.SDK.csproj + cd ../rust && cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml - name: Create pull request env: diff --git a/.gitignore b/.gitignore index a445051c6..ba3ebfcd0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,9 @@ docs/.validation/ .DS_Store +# Rust scenario build artifacts +test/scenarios/**/rust/target/ +test/scenarios/**/rust/Cargo.lock + # Visual Studio .vs/ diff --git a/justfile b/justfile index 5bb0ce0fa..ab97c1d3d 100644 --- a/justfile +++ b/justfile @@ -3,13 +3,13 @@ default: @just --list # Format all code across all languages -format: format-go format-python format-nodejs format-dotnet +format: format-go format-python format-nodejs format-dotnet format-rust # Lint all code across all languages -lint: lint-go lint-python lint-nodejs lint-dotnet +lint: lint-go lint-python lint-nodejs lint-dotnet lint-rust # Run tests for all languages -test: test-go test-python test-nodejs test-dotnet test-corrections +test: test-go test-python test-nodejs test-dotnet test-rust test-corrections # Format Go code format-go: @@ -71,6 +71,27 @@ test-dotnet: @echo "=== Testing .NET code ===" @cd dotnet && dotnet test test/GitHub.Copilot.SDK.Test.csproj +# Format Rust code (uses nightly for unstable formatting options) +format-rust: + @echo "=== Formatting Rust code ===" + @cd rust && cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml + +# Lint Rust code +lint-rust: + @echo "=== Linting Rust code ===" + @cd rust && cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml --check + @cd rust && cargo clippy --all-targets --features test-support -- --no-deps -D warnings -D clippy::unwrap_used -D clippy::disallowed_macros -D clippy::await_holding_invalid_type + +# Test Rust code +test-rust: + @echo "=== Testing Rust code ===" + @cd rust && cargo test --features test-support + +# Generate Rust types from JSON schemas +generate-rust: + @echo "=== Generating Rust types ===" + @cd scripts/codegen && npm run generate:rust + # Test correction collection scripts test-corrections: @echo "=== Testing correction scripts ===" diff --git a/nodejs/scripts/update-protocol-version.ts b/nodejs/scripts/update-protocol-version.ts index a18a560c7..ef3ac9a2f 100644 --- a/nodejs/scripts/update-protocol-version.ts +++ b/nodejs/scripts/update-protocol-version.ts @@ -117,4 +117,22 @@ internal static class SdkProtocolVersion fs.writeFileSync(path.join(rootDir, "dotnet", "src", "SdkProtocolVersion.cs"), csharpCode); console.log(" โœ“ dotnet/src/SdkProtocolVersion.cs"); +// Generate Rust +const rustCode = `// Code generated by update-protocol-version.ts. DO NOT EDIT. + +//! The SDK protocol version. Must match the version expected by the +//! copilot-agent-runtime server. + +/// The SDK protocol version. +pub const SDK_PROTOCOL_VERSION: u32 = ${version}; + +/// Returns the SDK protocol version. +#[must_use] +pub const fn get_sdk_protocol_version() -> u32 { + SDK_PROTOCOL_VERSION +} +`; +fs.writeFileSync(path.join(rootDir, "rust", "src", "sdk_protocol_version.rs"), rustCode); +console.log(" โœ“ rust/src/sdk_protocol_version.rs"); + console.log("Done!"); diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 000000000..c17da7f58 --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock.bak diff --git a/rust/.rustfmt.nightly.toml b/rust/.rustfmt.nightly.toml new file mode 100644 index 000000000..677b79658 --- /dev/null +++ b/rust/.rustfmt.nightly.toml @@ -0,0 +1,7 @@ +# These options are only available in nightly, but it should be fine to use nightly for just formatting. +group_imports = "StdExternalCrate" +imports_granularity = "Module" +reorder_impl_items = true + +# stable options +edition = "2024" diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml new file mode 100644 index 000000000..f3fb29261 --- /dev/null +++ b/rust/.rustfmt.toml @@ -0,0 +1,15 @@ +# This is not yet in stable, so we should keep an eye on it and enable it when it is. +# https://rust-lang.github.io/rustfmt/?version=v1.4.32&search=#group_imports +# In the mean time it is commented out because it will cause warnings. +#group_imports = "StdExternalCrate" + +# This is not yet in stable, so we should keep an eye on it and enable it when it is. +# https://rust-lang.github.io/rustfmt/?version=v1.5.1&search=#imports_granularity +# In the mean time it is commented out because it will cause warnings. +#imports_granularity = "Module" + +# This is not yet in stable, so we should keep an eye on it and enable it when it is. +# https://rust-lang.github.io/rustfmt/?version=v1.4.36&search=order#reorder_impl_items +# In the mean time it is commented out because it will cause warnings. +#reorder_impl_items = true +edition = "2024" diff --git a/rust/CHANGELOG.md b/rust/CHANGELOG.md new file mode 100644 index 000000000..a9d0eb850 --- /dev/null +++ b/rust/CHANGELOG.md @@ -0,0 +1,518 @@ +# Changelog + +All notable changes to the `github-copilot-sdk` crate will be documented in this file. + +The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +After 0.1.0 ships, [release-plz](https://release-plz.dev/) will prepend new +entries from conventional-commit history. The Unreleased entry below is +hand-curated so that crates.io readers get a usable summary of the public +surface on first publish, not a flat list of merge commits โ€” release-plz +will rename `[Unreleased]` to `[0.1.0] - ` and add a fresh empty +`[Unreleased]` above it when it cuts the first release PR. + +## [Unreleased] + +Initial public release. Programmatic Rust access to the GitHub Copilot CLI +over JSON-RPC 2.0 (stdio or TCP), with handler-based event dispatch, typed +tool/permission/elicitation helpers, and runtime session management. + +This is a **technical preview**. The crate is pre-1.0 and the public API may +change in breaking ways before 1.0. The rendered docs on +[docs.rs](https://docs.rs/github-copilot-sdk) are the canonical reference for the +public surface. + +### Added + +#### Client lifecycle +- `Client::start` โ€” spawn and manage a GitHub Copilot CLI child process. +- `Client::from_streams` โ€” connect to a CLI server over caller-supplied + `AsyncRead`/`AsyncWrite` (testing, custom transports). +- `Client::stop` / `Client::force_stop` โ€” graceful and immediate shutdown. +- `Client::state` returning `ConnectionState` (`Connecting`, `Connected`, + `Disconnecting`, `Disconnected`). +- `Client::subscribe_lifecycle` returning a `LifecycleSubscription` for + runtime observation of created / destroyed / foreground / background + events. Implements `tokio_stream::Stream` and offers an inherent + `recv()`; drop the value to unsubscribe. +- `Client::ping(message)` returning typed `PingResponse` and + `Client::verify_protocol_version` for handshake validation. +- `Client::list_sessions`, `get_session_metadata`, `delete_session`, + `get_last_session_id`, `get_foreground_session_id`, + `set_foreground_session_id`. +- `Client::list_models`, `get_status` (typed `GetStatusResponse`), + `get_auth_status` (typed `GetAuthStatusResponse`), `get_quota`, + `send_telemetry`. + +#### Sessions +- `Client::create_session` and `Client::resume_session` accepting + `SessionConfig` with handler, capabilities, system message, mode, model, + permission policy, working directory, and resume parameters. +- `Session::send` returning the assigned message ID for + correlation with later events. +- `Session::send_and_wait` for synchronous prompt โ†’ final-event flows. +- `Session::subscribe` returning an `EventSubscription` for observe-only + access to the session's event stream. Implements `tokio_stream::Stream` + and offers an inherent `recv()`; drop the value to unsubscribe. +- Mode + model controls: `get_mode` / `set_mode`, `get_model` / + `set_model(model, SetModelOptions)` with `reasoning_effort` and + `model_capabilities` overrides. +- Plan helpers: `read_plan`, `delete_plan`. +- Workspace helpers: `list_workspace_files`, `read_workspace_file`, + `create_workspace_file`, `cwd`, `remote_url`. +- UI primitives: `session.ui().elicitation()`, `confirm()`, `select()`, + `input()` โ€” grouped under a `SessionUi` sub-API to mirror .NET / Python / + Go. +- `Session::log(message, LogOptions)` with optional severity and + ephemeral flag. +- `Session::send_telemetry`, `start_fleet`, `abort`, + `set_approve_all_permissions`, `set_name`. +- `Session::disconnect` (canonical) and `Session::destroy` (alias) + preserve on-disk session state for later resume. +- `Session::stop_event_loop` for shutting down the per-session loop. +- `Session::cancellation_token()` returns a [`tokio_util::sync::CancellationToken`] + child token that fires when the session shuts down (via + `stop_event_loop`, `destroy`, or `Drop`). Lets external tasks bind their + lifetime to a session via `tokio::select!` without taking a strong + reference to the session. Cancelling the returned child token does not + shut the session down โ€” only `stop_event_loop` (or dropping the session) + does. + +#### Handlers + helpers +- `SessionHandler` trait with default fallback impls for each event + (permissions, external tools, elicitation, plan-mode prompts). +- `ApproveAllHandler` / `DenyAllHandler` reference handlers. +- Permission policy helpers: `permission::approve_all`, + `permission::deny_all`, `permission::approve_if`, plus chainable + builders on `SessionConfig` (`approve_all_permissions`, + `deny_all_permissions`, `approve_if`). +- `PermissionResult` is `#[non_exhaustive]` and supports `Approved`, + `Denied`, `Deferred` (handler will resolve via + `handlePendingPermissionRequest` itself โ€” notification path only; + direct RPC falls back to `Approved`), and + `Custom(serde_json::Value)` for response shapes beyond + `{ "kind": "approve-once" | "reject" }` (e.g. allowlist payloads). +- All extension-point and protocol-evolving public enums are + `#[non_exhaustive]` so future variants are additive (non-breaking): + `Error`, `ProtocolError`, `SessionError`, `Transport`, `Attachment`, + `ToolResult`, `ElicitationMode`, `InputFormat`, `GitHubReferenceType`, + `SessionLifecycleEventType`, plus the handler/hook event/response enums. + Closed taxonomies (`LogLevel`, `ConnectionState`, `CliProgram`) remain + exhaustive so callers benefit from compile-time exhaustiveness checks. +- Tool helpers: `tool::DefineTool`, `tool::tool_schema_for`, + `tool::ToolHandlerRouter`, derive support via `derive` feature. + `ToolHandlerRouter` overrides each `SessionHandler` per-event method + directly, so callers can use the narrow-typed entry points (e.g. + `router.on_external_tool(invocation).await -> ToolResult`) instead of + unwrapping a `HandlerResponse` from `on_event`. The default `on_event` + still routes correctly through the per-event methods, so legacy + callers are unaffected. +- Hooks API for instrumenting send/receive flows (`github_copilot_sdk::hooks`). +- `SessionHandler::on_auto_mode_switch` โ€” typed handler for the CLI's + rate-limit-recovery prompt (`autoModeSwitch.request` JSON-RPC + callback, added in copilot-agent-runtime PR #7024). Returns a typed + [`AutoModeSwitchResponse`] enum with `Yes`, `YesAlways`, `No` + variants (`#[serde(rename_all = "snake_case")]`, wire values byte- + identical to the runtime's `"yes" | "yes_always" | "no"` schema). + Default impl declines (`No`); override only if your application + surfaces a UX for the prompt. `SessionConfig::request_auto_mode_switch` + and `ResumeSessionConfig::request_auto_mode_switch` default to + `Some(true)` so the CLI advertises the callback to the SDK out of the + box. **Cross-SDK divergence:** typed handler is Rust-only as of 0.1.0. + Node, Python, Go, and .NET observe the request as a raw JSON-RPC + callback today; parity ports for those SDKs are post-release follow-up + work. +- New session-event fields surfaced by the `@github/copilot ^1.0.39` + schema bump: + - `SessionErrorData.eligible_for_auto_switch: Option` โ€” set on + `errorType: "rate_limit"` to signal the runtime will follow with an + `auto_mode_switch.requested` event. UI clients can suppress + duplicate rendering of the rate-limit error when they show their + own auto-mode-switch prompt. + - `SessionErrorData.error_code: Option` โ€” fine-grained + upstream provider error code (e.g. + `"user_weekly_rate_limited"`, `"integration_rate_limited"`). + - `SessionModelChangeData.cause: Option` โ€” + `"rate_limit_auto_switch"` for changes triggered by the + auto-mode-switch recovery path. Lets UI render contextual copy. + - `AutoModeSwitchRequestedData.retry_after_seconds: Option` โ€” + seconds until the rate limit resets, when known. Clients can + render a humanized reset time alongside the prompt. (The request- + callback path's `retry_after_seconds` parameter on + [`SessionHandler::on_auto_mode_switch`](crate::handler::SessionHandler::on_auto_mode_switch) + uses `Option` for HTTP `Retry-After` `delta-seconds` + semantics.) + +#### Types +- Newtype `SessionId`, plus generated RPC types under `github_copilot_sdk::generated`. +- `LogLevel`, `LogOptions`, `SetModelOptions`, `PingResponse`, + `SessionLifecycleEvent`, `SessionLifecycleEventType`, `ConnectionState`, + `SessionTelemetryEvent`, `ServerTelemetryEvent`, `SystemMessageConfig`, + `MessageOptions`, `SectionOverride`, `Attachment`, + `InputFormat`, `InputOptions`. +- Strongly-typed `Error` and `ProtocolError` with `is_transport_failure` + classifier and `error_codes` constants. + +#### Typed RPC namespace +- `Client::rpc()` and `Session::rpc()` accessors exposing a generated, typed + view over the full GitHub Copilot CLI JSON-RPC API. Sub-namespaces mirror the + schema (e.g. `client.rpc().models().list()`, `session.rpc().workspaces() + .list_files()`, `session.rpc().agent().list()`, + `session.rpc().tasks().list()`). +- All hand-authored helpers (`list_workspace_files`, `read_plan`, `set_mode`, + `list_models`, `get_quota`, etc.) are now thin one-line delegations over + this namespace. Wire-method strings exist in exactly one place + (`generated/rpc.rs`), making typo bugs like the `session.workspace.*` + โ†’ `session.workspaces.*` regression structurally impossible. Public + helper signatures are unchanged. + +#### Configuration parity +- All remaining public configuration types are now `#[non_exhaustive]` + for forward-compatibility โ€” adding fields post-1.0 is non-breaking on + consumers that construct via `Default::default()` plus field + assignment or the `with_*` builders. Affected: `SessionConfig`, + `ResumeSessionConfig`, `ClientOptions`, `ProviderConfig`, + `McpServerConfig`, `Tool`, `CustomAgentConfig`, + `InfiniteSessionConfig`, `SystemMessageConfig`, `ConnectionState`. + (`HookEvent`, `HookOutput`, `MessageOptions`, `TelemetryConfig`, + `SessionFsConfig`, `FsError`, `FileInfo`, `DirEntry`, `ToolInvocation`, + `Error`, `Transport`, `DeliveryMode` were already marked.) Callers + using exhaustive struct literals must switch to + `let mut x = Type::default(); x.field = ...;` or the available `with_*` + builders; `..Default::default()` no longer compiles for these types + outside the defining crate. +- `MessageOptions::mode` is now typed `Option` (was + `Option`). `DeliveryMode` is `#[non_exhaustive]` and serializes + to the wire strings `"enqueue"` (default) and `"immediate"`. The prior + rustdoc incorrectly described this field as a permission mode; the + field controls how the prompt is delivered relative to in-flight work. + `MessageOptions::with_mode` now takes `DeliveryMode` directly. Callers + that previously passed `"agent"` or `"autopilot"` were already silently + no-ops at the CLI level โ€” switch to a `DeliveryMode` variant or omit + the field entirely. +- `SessionConfig::default()` and `ResumeSessionConfig::new()` now set the + four permission-flow flags (`request_user_input`, `request_permission`, + `request_exit_plan_mode`, `request_elicitation`) to `Some(true)` instead + of `None`. Mirrors Node's `client.ts` behavior of always advertising the + permission surface and deriving handler presence from the + `SessionHandler` impl. The default `DenyAllHandler` refuses all + permission requests so the wire surface is safe out-of-the-box; callers + that want the wire surface fully disabled set the flags explicitly to + `Some(false)`. +- `SessionListFilter` โ€” typed filter for `Client::list_sessions` covering + `cwd`, `git_root`, `repository`, and `branch`. Replaces the prior + `Option` parameter. +- `McpServerConfig` tagged enum (`Stdio` / `Http` / `Sse`) with + `McpStdioServerConfig` and `McpHttpServerConfig` payload structs. + `SessionConfig::mcp_servers`, `ResumeSessionConfig::mcp_servers`, and + `CustomAgentConfig::mcp_servers` are now `Option>` instead of typeless `Value` maps. Stdio configurations + serialized by older callers (no explicit `type`, or `type: "local"`) are + accepted on the deserialize path. +- `PermissionRequestData` gains typed `kind: Option` + and `tool_call_id: Option` fields covering the eight CLI + permission categories (`shell`, `write`, `read`, `url`, `mcp`, + `custom-tool`, `memory`, `hook`); unknown values fall through to + `PermissionRequestKind::Unknown` for forward compatibility. The original + params object is still available via the existing `extra: Value` flatten. +- `PermissionResult` gains `UserNotAvailable` (sent as + `{ "kind": "user-not-available" }`) and `NoResult` (sent as + `{ "kind": "no-result" }`) variants for headless agents and explicit + fall-through-to-CLI-default responses. +- `Client::stop` cooperatively shuts down active sessions before killing + the CLI child: walks every session still registered with the client, + sends `session.destroy` for each, then kills the child. Errors from + per-session destroys and the terminal child-kill are collected into a + new `StopErrors` aggregate (`Result<(), StopErrors>`) instead of + short-circuiting on the first failure, mirroring the Node SDK's + `Error[]` return shape. `StopErrors` implements `std::error::Error` + and exposes `errors()` / `into_errors()` for inspection. Callers that + previously used `client.stop().await?` should switch to + `client.stop().await.ok();` (best-effort) or match on the aggregate. +- `ResumeSessionConfig::disable_resume: Option` โ€” force-fail resume + if the session does not exist on disk, instead of silently starting a + new session. +- `SessionConfig` and `ResumeSessionConfig` gain six configuration knobs + matching the Node SDK shape (Bucket B.1): + - `session_id: Option` (SessionConfig only โ€” required on + resume, where it remains `SessionId`) โ€” supply a custom session ID + instead of letting the CLI generate one. + - `working_directory: Option` โ€” per-session cwd override, + independent of [`ClientOptions::cwd`](crate::ClientOptions::cwd). + - `config_dir: Option` โ€” override the default configuration + directory location for this session. + - `model_capabilities: Option` โ€” per-property + overrides for model capabilities, deep-merged over runtime defaults. + The same type was previously available only on + `SetModelOptions::model_capabilities`. + - `github_token: Option` โ€” per-session GitHub token. Distinct + from [`ClientOptions::github_token`], which authenticates the CLI + process; this token determines the GitHub identity used for content + exclusion, model routing, and quota checks for this session. The + field is redacted from the `Debug` output. + - `include_sub_agent_streaming_events: Option` โ€” forward streaming + delta events from sub-agents to this connection (Node default: true). +- `ClientOptions` gains the simple subset of Node's + `CopilotClientOptions` knobs (Bucket B.2): + - `log_level: Option` โ€” typed enum (`None`, `Error`, `Warning`, + `Info`, `Debug`, `All`) replacing the previously hard-coded + `--log-level info` argument. When unset, the SDK still passes + `--log-level info` for parity with prior behavior. + - `session_idle_timeout_seconds: Option` โ€” server-wide idle + timeout for sessions in seconds. When `Some(n)` with `n > 0`, the + SDK passes `--session-idle-timeout `. `None` or `Some(0)` leaves + sessions running indefinitely (the CLI default). + - The Node knob `isChildProcess` (sub-CLI parent-stdio mode) and + `autoStart` (lazy-init pattern) are intentionally **not** ported โ€” + `isChildProcess` requires a transport variant the Rust SDK does not + yet support; `autoStart` does not apply because [`Client::start`] is + a single explicit constructor rather than a deferred-init pattern. + - `on_list_models: Option>` โ€” BYOK escape + hatch matching Node's `onListModels`. When set, [`Client::list_models`] + returns the handler's result without making a `models.list` RPC. + `ListModelsHandler` is a new public `async_trait` (mirrors the shape + of `SessionHandler` / `SessionHooks`) with a single + `async fn list_models(&self) -> Result, Error>` method. + `ClientOptions` switched from `#[derive(Debug)]` to a manual `Debug` + impl that prints the handler as `` / `None` (same precedent as + `SessionConfig::handler` and `github_token`). +- `MessageOptions` gains `request_headers: Option>` + with a corresponding [`MessageOptions::with_request_headers`] builder + method, matching Node's `MessageOptions.requestHeaders` and Go's + `MessageOptions.RequestHeaders`. Custom HTTP headers are forwarded to + the CLI via the `requestHeaders` field on `session.send`. The field is + omitted from the wire when `None` or empty (matches Node's + `omitempty` semantics). +- Slash command registration: new [`CommandHandler`] async trait, + [`CommandDefinition`] (with `new`/`with_description` builders), and + [`CommandContext`] (`session_id`, `command`, `command_name`, `args`) + hand-authored in `crate::types`. `SessionConfig::commands` and + `ResumeSessionConfig::commands` accept a `Vec` via + the new `with_commands` builder, matching Node's + `SessionConfig.commands`, Python's `SessionConfig.commands`, and Go's + `SessionConfig.Commands`. The SDK serializes only `{name, description?}` + on the wire (handlers stay client-side), and dispatches incoming + `command.execute` events to the registered handler โ€” acking with no + error on success, `error: ` on `Err`, and + `error: "Unknown command: "` when the name is unregistered. + `CommandContext` and `CommandDefinition` are `#[non_exhaustive]` so + forward-compatible fields (e.g. aliases, completion providers) can land + without breaking callers. +- Custom session filesystem: new [`SessionFsProvider`] async trait, + [`SessionFsConfig`], [`FsError`], [`FileInfo`], [`DirEntry`], + [`DirEntryKind`], and [`SessionFsConventions`] in `crate::session_fs` + (also re-exported from `crate::types`). When [`ClientOptions::session_fs`] + is set, [`Client::start`] calls `sessionFs.setProvider` on the CLI to + delegate per-session filesystem operations to a provider supplied via + [`SessionConfig::with_session_fs_provider`] / + [`ResumeSessionConfig::with_session_fs_provider`]. Inbound `sessionFs.*` + requests dispatch to the provider; `FsError::NotFound` maps to the wire + `ENOENT` code and other `FsError` values map to `UNKNOWN`. + `From` is provided so handlers backed by `std::fs` / + `tokio::fs` can propagate errors with `?`. All trait methods have + default implementations returning `Err(FsError::Other("not supported"))`, + so providers only override the methods they need and forward-compatible + schema additions land without breaking existing implementations. + Diverges from Node/Python/Go's factory-closure pattern in favor of + direct `Arc` registration. +- W3C Trace Context propagation: new [`TraceContext`] struct and + [`TraceContextProvider`] async trait in `crate::trace_context` (also + re-exported from `crate::types`). Hybrid shape combines Node's + callback-based `onGetTraceContext` and Go's per-turn + `MessageOptions.Traceparent` / `Tracestate`: + [`ClientOptions::on_get_trace_context`] supplies an ambient provider that + injects `traceparent` / `tracestate` on `session.create`, + `session.resume`, and `session.send`, while + [`MessageOptions::with_traceparent`], [`MessageOptions::with_tracestate`], + and [`MessageOptions::with_trace_context`] override per-turn (override + wins; provider is not invoked when MessageOptions carries trace headers). + [`ToolInvocation`] is now `#[non_exhaustive]` and exposes inbound + `traceparent` / `tracestate` populated from `external_tool.requested` + events, plus a [`ToolInvocation::trace_context`] helper. Wire fields are + omitted when unset (matches Node/Go `omitempty` semantics). +- `ToolInvocation` and `SessionId` now derive `Default`. Production code + never constructs `ToolInvocation` literals (it's a CLI-emitted read-only + type), but downstream test scaffolding can now use + `ToolInvocation { tool_name: "...".into(), ..Default::default() }` and + absorb future `#[non_exhaustive]` field additions automatically. +- OpenTelemetry env-var passthrough: new [`TelemetryConfig`] struct and + [`OtelExporterType`] enum (both `#[non_exhaustive]`), wired on + [`ClientOptions::telemetry`]. When `Some(...)`, the SDK injects + `COPILOT_OTEL_ENABLED=true` plus `OTEL_EXPORTER_OTLP_ENDPOINT`, + `COPILOT_OTEL_FILE_EXPORTER_PATH`, `COPILOT_OTEL_EXPORTER_TYPE`, + `COPILOT_OTEL_SOURCE_NAME`, and + `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT` into the spawned CLI + process โ€” verbatim env-var names matching Node/Python/Go. Pure + passthrough: no `opentelemetry-rust` dependency; the CLI itself owns the + exporter. `exporter_type` is a typed enum (`OtlpHttp` / `File`) following + the [`LogLevel`](LogLevel) precedent for finite, enumerated CLI knobs; + serialized verbatim as `"otlp-http"` / `"file"`. User-supplied + `ClientOptions::env` continues to win over telemetry-injected values. +- `ClientOptions::copilot_home: Option` (and + `with_copilot_home`) โ€” overrides the directory where the CLI persists + its state. Exported as `COPILOT_HOME` to the spawned CLI process. + Useful for sandboxing test runs or running multiple isolated SDK + instances side-by-side. Mirrors Node `copilotHome` / + Python `copilot_home`. +- `ClientOptions::tcp_connection_token: Option` (and + `with_tcp_connection_token`) โ€” optional auth token for TCP transport. + Sent in the new `connect` JSON-RPC handshake (with backward-compat + fall-back to `ping` for legacy CLI servers) and exported as + `COPILOT_CONNECTION_TOKEN` to spawned CLI processes. When the SDK + spawns its own CLI in TCP mode and this is left unset, a UUID is + generated automatically so the loopback listener is safe by default. + Combining with `Transport::Stdio` returns + `Error::InvalidConfig` from `Client::start`. +- `SessionConfig::instruction_directories: Option>` and + `ResumeSessionConfig::instruction_directories` (plus + `with_instruction_directories` builders on both) โ€” additional + directories searched for custom instruction files. Distinct from + `skill_directories`. Forwarded to the CLI on session create / resume. +- `Error::InvalidConfig(String)` variant for client-construction errors + that surface from `Client::start` (e.g. `tcp_connection_token` paired + with `Transport::Stdio`, empty token, etc). + +### Documentation +- `README.md` with quickstart, architecture diagram, and feature matrix. +- Examples under `examples/`: `chat`, `hooks`, `tool_server`, + `lifecycle_observer`. +- `RELEASING.md` operational runbook for maintainers. + +#### Builder ergonomics +- `ClientOptions::new()` plus a chainable `with_*` builder per public + field (`with_program`, `with_prefix_args`, `with_cwd`, `with_env`, + `with_env_remove`, `with_extra_args`, `with_transport`, + `with_github_token`, `with_use_logged_in_user`, `with_log_level`, + `with_session_idle_timeout_seconds`, `with_list_models_handler`, + `with_session_fs`, `with_trace_context_provider`, `with_telemetry`). + Mirrors the existing [`MessageOptions::new`] / `with_*` shape and + closes the cross-crate ergonomics gap on `#[non_exhaustive]` โ€” + external callers no longer need to write + `let mut opts = ClientOptions::default(); opts.field = ...;` for + every field they touch. Existing `ClientOptions::default()` and + mut-let-and-assign continue to work unchanged. +- `Tool::new(name)` plus `with_namespaced_name`, `with_description`, + `with_instructions`, `with_parameters`, `with_overrides_built_in_tool`, + `with_skip_permission` for tool definitions. Same rationale โ€” + `Tool` is the most-instantiated `#[non_exhaustive]` type at consumer + call sites in real-world consumer code, where the + builder shape replaces the per-consumer `make_tool(name, desc, + params)` helper that consumers were writing to smooth over the + mut-let pattern. +- Per-field `with_*` builder methods on `SessionConfig` and + `ResumeSessionConfig` covering every public scalar, vector, and + optional-struct field (~30 new methods on each). Mirrors the + `ClientOptions` / `Tool` shape; existing closure-installing + chains (`with_handler`, `with_hooks`, `with_transform`, + `with_commands`, `with_session_fs_provider`, + `approve_all_permissions`, etc.) continue to work unchanged. The + primary win: external session-construction sites collapse from + `let mut cfg = ResumeSessionConfig::new(id); cfg.client_name = + Some("...".into()); cfg.streaming = Some(true); ...` (10-15 + lines per site) to a single fluent chain. +- Round out builder coverage on the remaining consumer-facing + config structs: `CustomAgentConfig::new(name, prompt)` plus + `with_display_name`, `with_description`, `with_tools`, + `with_mcp_servers`, `with_infer`, `with_skills`; + `InfiniteSessionConfig::new()` plus `with_enabled`, + `with_background_compaction_threshold`, + `with_buffer_exhaustion_threshold`; + `ProviderConfig::new(base_url)` plus `with_provider_type`, + `with_wire_api`, `with_api_key`, `with_bearer_token`, + `with_azure`, `with_headers`; `SystemMessageConfig::new()` plus + `with_mode`, `with_content`, `with_sections`; + `TelemetryConfig::new()` plus `with_otlp_endpoint`, + `with_file_path`, `with_exporter_type`, `with_source_name`, + `with_capture_content`. `TraceContext` also gains a symmetric + `new()` + `with_traceparent` pair alongside the existing + `from_traceparent` shorthand. +- Documented the direct-field-assignment escape hatch on + `SessionConfig` and `ResumeSessionConfig` for callers forwarding + `Option` values from upstream code (matches the + `http::request::Parts` / `hyper::Body::Builder` convention; per- + field `with_*_opt` setters intentionally omitted to keep the + primary API surface small). + +#### Build infrastructure +- `build.rs` no longer shells out to `curl` for the bundled-CLI + download. The `embedded-cli` feature now downloads the + `SHA256SUMS.txt` and platform tarball through `ureq` (rustls TLS, + pure-Rust, no system dependencies). Removes the implicit `curl`- + on-PATH requirement that previously broke the build on minimal + Windows / container environments. Includes bounded retries with + exponential backoff (1s/2s/4s) on transient failures (5xx, + connect/read timeouts, transport errors) โ€” 4xx responses still + fail fast as before. + +### Fixed +- `SessionEvent` and `TypedSessionEvent` now expose the `agentId` + envelope field added to `session-events.schema.json` upstream + (`f8cf846`, "Derive session event envelopes from schema"). Sub-agent + events were silently dropping the attribution at the deserialization + boundary; consumers had no way to distinguish events emitted by the + root agent from events emitted by a sub-agent. Other SDKs (Node, + Python, Go, .NET) all carry this field. Round-trip parity test added + in `types::tests::session_event_round_trips_agent_id_on_envelope`. +- `Session::user_input` no longer double-dispatches when the CLI sends + both a `user_input.requested` notification (for observers) and a + `userInput.request` JSON-RPC call (the actual prompt) for the same + prompt. The notification path is now a no-op; the JSON-RPC path + remains authoritative. Matches Python / Go / .NET / Node SDK + behavior, all of which only register the JSON-RPC handler. Fixes + github/github-app#4249, where consumers saw duplicate `ask_user` + and `exit_plan` widgets on every prompt. +- `SessionUi::elicitation` (and the `confirm` / `select` / `input` + convenience helpers that delegate through it) now sends the user-supplied + JSON Schema as `requestedSchema` on the wire, matching the + `session.ui.elicitation` request shape that all other SDKs ship and that + this crate's own generated `UIElicitationRequest` type expects. The + hand-authored convenience layer was sending it as `schema`, so every UI + helper call was effectively dead โ€” the CLI saw a missing required + `requestedSchema` field. The mock-server test for elicitation + round-tripped through the same misnamed field, so the bug slipped past + unit tests; the test now asserts on `requestedSchema` and explicitly + rejects a stray `schema` key. +- `Client::list_sessions` now wraps the optional filter under `params.filter` + on the wire, matching the `session.list` request shape that Node, Python, + Go, and .NET ship. The hand-authored implementation was flattening the + filter fields directly onto `params`, which the runtime silently ignored + โ€” so `list_sessions(Some(filter))` was functionally equivalent to + `list_sessions(None)` in 0.0.x. Same class of bug as the elicitation + wire fix above: the existing mock-server test asserted on the flat shape + it observed rather than the schema's wrapped shape, so the bug + round-tripped through both ends. The test now asserts the wrapped path + (`params.filter.repository`) and explicitly rejects the flattened + fallback (`params.repository`). +- `Client::get_status` and `Client::get_auth_status` now use the + correct wire method names (`status.get` and `auth.getStatus`) + matching Node, Go, Python, and .NET. The hand-authored + implementation was sending `getStatus` and `getAuthStatus` โ€” names + that aren't registered on the CLI runtime โ€” so both calls would + have returned a "method not found" error (or a misleading no-such- + method log) instead of the expected status payload. Same class of + bug as the elicitation `requestedSchema` and `list_sessions` + filter-wrapping fixes above: the mock-server test for these + methods asserted on the wrong-name strings the implementation + used, so the bugs round-tripped through both ends. The test now + asserts on the canonical wire names AND explicitly rejects the + hand-authored aliases (`assert_ne!(request["method"], "getStatus")` + / `"getAuthStatus"`). + +### Notes +- Minimum supported Rust version (MSRV): 1.94.0 (pinned via + `rust-toolchain.toml`). +- No `Client::actual_port` accessor โ€” this SDK is strictly stream-based, + so the concept doesn't apply. See `Client::from_streams` rustdoc. +- `cargo semver-checks` runs in `continue-on-error` mode for 0.1.0; will + flip to blocking once 0.1.0 is published and serves as the baseline. +- `infinite_sessions: Option` is wired on both + `SessionConfig` and `ResumeSessionConfig` and follows the same + default-omit-on-the-wire semantics as Node/Go: when `None`, the field + is skipped and the CLI applies its own default. No behavioral + divergence from the other SDKs. +- `Client::stop` returns `Result<(), StopErrors>` and now cooperatively + shuts down each active session via `session.destroy` before killing + the CLI child, aggregating all per-session and child-kill errors into + the returned `StopErrors`. See the entry under "Configuration parity" + above for the migration note. diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 000000000..1163de37e --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1777 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + +[[package]] +name = "filetime" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" +dependencies = [ + "cfg-if", + "libc", + "libredox", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "github-copilot-sdk" +version = "0.1.0" +dependencies = [ + "async-trait", + "dirs", + "flate2", + "getrandom 0.2.17", + "parking_lot", + "regex", + "schemars", + "serde", + "serde_json", + "serial_test", + "sha2", + "tar", + "tempfile", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", + "ureq", + "zip", + "zstd", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + "potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown 0.17.0", + "serde", + "serde_core", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "libredox" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" +dependencies = [ + "bitflags", + "libc", + "plain", + "redox_syscall 0.7.4", +] + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.5.18", + "smallvec", + "windows-link", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pkg-config" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + +[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_syscall" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f450ad9c3b1da563fb6948a8e0fb0fb9269711c9c73d9ea1de5058c79c8d643a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom 0.2.17", + "libredox", + "thiserror 1.0.69", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd", +] + +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serial_test" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f" +dependencies = [ + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "simd-adler32" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tar" +version = "0.4.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.52.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "typenum" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.7", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" + +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerofrom" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zip" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "flate2", + "indexmap", + "memchr", + "thiserror 2.0.18", + "zopfli", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zopfli" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05cd8797d63865425ff89b5c4a48804f35ba0ce8d125800027ad6017d2b5249" +dependencies = [ + "bumpalo", + "crc32fast", + "log", + "simd-adler32", +] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 000000000..0b90cc1ab --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,71 @@ +[package] +name = "github-copilot-sdk" +version = "0.1.0" +edition = "2024" +rust-version = "1.94.0" +description = "Rust SDK for programmatic control of the GitHub Copilot CLI via JSON-RPC. Technical preview, pre-1.0." +keywords = ["copilot", "github", "ai", "json-rpc", "sdk"] +categories = ["api-bindings", "development-tools"] +repository = "https://github.com/github/copilot-sdk" +homepage = "https://github.com/github/copilot-sdk" +documentation = "https://docs.rs/github-copilot-sdk" +readme = "README.md" +license = "MIT" +exclude = [ + "RELEASING.md", + "release-plz.toml", + "rust-toolchain.toml", + ".rustfmt.toml", + ".rustfmt.nightly.toml", + "clippy.toml", + ".gitignore", +] + +[lib] +name = "github_copilot_sdk" + +[features] +default = [] +embedded-cli = ["dep:sha2", "dep:zstd"] +derive = ["dep:schemars"] +test-support = [] + +# Build docs.rs documentation with all features so feature-gated APIs +# (e.g. `define_tool`, `schema_for`) appear and intra-doc links resolve. +# Mirror this locally with: `cargo doc --no-deps --all-features`. +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +async-trait = "0.1" +schemars = { version = "1", optional = true } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +tokio = { version = "1", features = ["io-util", "sync", "rt", "process", "net", "time", "macros"] } +tokio-stream = { version = "0.1", features = ["sync"] } +tokio-util = { version = "0.7", default-features = false } +tracing = "0.1" +dirs = "5" +parking_lot = "0.12" +regex = "1" +sha2 = { version = "0.10", optional = true } +getrandom = "0.2" +zstd = { version = "0.13", optional = true } + +[dev-dependencies] +schemars = "1" +serial_test = "3" +tempfile = "3" +sha2 = "0.10" +tokio = { version = "1", features = ["rt-multi-thread"] } +zstd = "0.13" + +[build-dependencies] +flate2 = "1" +sha2 = "0.10" +tar = "0.4" +ureq = { version = "2", default-features = false, features = ["tls"] } +zip = { version = "2", default-features = false, features = ["deflate"] } +zstd = "0.13" diff --git a/rust/LICENSE b/rust/LICENSE new file mode 120000 index 000000000..ea5b60640 --- /dev/null +++ b/rust/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 000000000..8222f8630 --- /dev/null +++ b/rust/README.md @@ -0,0 +1,800 @@ +# GitHub Copilot CLI SDK for Rust + +A Rust SDK for programmatic access to the GitHub Copilot CLI. + +> **Note:** This SDK is in technical preview and may change in breaking ways. + +See [github/copilot-sdk](https://github.com/github/copilot-sdk) for the equivalent SDKs in TypeScript, Python, Go, and .NET. The Rust SDK seeks parity with those SDKs; see [Differences From Other SDKs](#differences-from-other-sdks) below for the small set of intentional divergences. + +## Quick Start + +```rust,no_run +use std::sync::Arc; +use github_copilot_sdk::{Client, ClientOptions, SessionConfig}; +use github_copilot_sdk::handler::ApproveAllHandler; + +# async fn example() -> Result<(), github_copilot_sdk::Error> { +let client = Client::start(ClientOptions::default()).await?; +let session = client.create_session( + SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)), +).await?; +let _message_id = session.send("Hello!").await?; +session.disconnect().await?; +client.stop().await.ok(); +# Ok(()) +# } +``` + +## Architecture + +```text +Your Application + โ†“ + github_copilot_sdk::Client (manages CLI process lifecycle) + โ†“ + github_copilot_sdk::Session (per-session event loop + handler dispatch) + โ†“ JSON-RPC over stdio or TCP + copilot --server --stdio +``` + +The SDK manages the CLI process lifecycle: spawning, health-checking, and graceful shutdown. Communication uses [JSON-RPC 2.0](https://www.jsonrpc.org/specification) over stdin/stdout with `Content-Length` framing (the same protocol used by LSP). TCP transport is also supported. + +## API Reference + +### Client + +```rust,ignore +// Start a client (spawns CLI process) +let client = Client::start(options).await?; + +// Create a new session +let session = client.create_session(config.with_handler(handler)).await?; + +// Resume an existing session +let session = client.resume_session(config.with_handler(handler)).await?; + +// Low-level RPC +let result = client.call("method.name", Some(params)).await?; +let response = client.send_request("method.name", Some(params)).await?; + +// Health check (echoes message back, returns typed PingResponse) +let pong = client.ping("hello").await?; + +// Shutdown +client.stop().await?; +``` + +**`ClientOptions`:** + +| Field | Type | Description | +|---|---|---| +| `program` | `CliProgram` | `Resolve` (default: auto-detect) or `Path(PathBuf)` (explicit) | +| `prefix_args` | `Vec` | Args before `--server` (e.g. script path for node) | +| `cwd` | `PathBuf` | Working directory for CLI process | +| `env` | `Vec<(OsString, OsString)>` | Environment variables for CLI process | +| `env_remove` | `Vec` | Environment variables to remove | +| `extra_args` | `Vec` | Extra CLI flags | +| `transport` | `Transport` | `Stdio` (default), `Tcp { port }`, or `External { host, port }` | + +With the default `CliProgram::Resolve`, `Client::start()` automatically resolves the binary via `github_copilot_sdk::resolve::copilot_binary()` โ€” checking `COPILOT_CLI_PATH`, the [embedded CLI](#embedded-cli), and then the system PATH. Use `CliProgram::Path(path)` to skip resolution. + +### Session + +Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches events to the `SessionHandler`. + +```rust,ignore +use github_copilot_sdk::MessageOptions; + +// Simple send โ€” &str / String convert into MessageOptions automatically. +// Returns the assigned message ID for correlation with later events. +let _id = session.send("Fix the bug in auth.rs").await?; + +// Send with mode and attachments +let _id = session + .send( + MessageOptions::new("What's in this image?") + .with_mode("autopilot") + .with_attachments(attachments), + ) + .await?; + +// Message history +let messages = session.get_messages().await?; + +// Abort the current agent turn +session.abort().await?; + +// Model management +let model = session.get_model().await?; +session.set_model("claude-sonnet-4.5", None).await?; + +// Mode management (interactive, plan, autopilot) +let mode = session.get_mode().await?; +session.set_mode("autopilot").await?; + +// Workspace files +let files = session.list_workspace_files().await?; +let content = session.read_workspace_file("plan.md").await?; + +// Plan management +let (exists, content) = session.read_plan().await?; +session.update_plan("Updated plan content").await?; + +// Fleet (sub-agents) +session.start_fleet(Some("Implement the auth module")).await?; + +// Cleanup (preserves on-disk session state for later resume) +session.disconnect().await?; +``` + +#### Typed RPC namespace + +The ergonomic helpers above are convenience wrappers over a fully-typed +JSON-RPC namespace generated from the GitHub Copilot CLI schema. `Client::rpc()` +and `Session::rpc()` give direct access to every method on the wire, +including ones with no helper today, with strongly-typed request and +response structs. + +```rust,ignore +// Methods with helpers โ€” wire strings live in one generated place. +let files = session.rpc().workspaces().list_files().await?.files; +let models = client.rpc().models().list().await?.models; + +// Methods with no helper โ€” full schema-typed access. +let agents = session.rpc().agent().list().await?.agents; +let tasks = session.rpc().tasks().list().await?.tasks; +let forked = client + .rpc() + .sessions() + .fork(github_copilot_sdk::generated::api_types::SessionsForkRequest { + session_id: "session-id".to_string(), + from_message_id: None, + }) + .await?; +``` + +New RPCs land in the namespace immediately as the schema regenerates; +helpers are added on top only when an ergonomic story is worth the +maintenance. + +### SessionHandler + +Implement this trait to control how a session responds to CLI events. Two styles are supported: + +**1. Per-event methods (recommended).** Override only the callbacks you care about; every method has a safe default (permission โ†’ deny, user input โ†’ none, external tool โ†’ "no handler", elicitation โ†’ cancel, exit plan โ†’ default). This is the `serenity::EventHandler` pattern. + +```rust,ignore +use async_trait::async_trait; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; + +struct MyHandler; + +#[async_trait] +impl SessionHandler for MyHandler { + async fn on_permission_request( + &self, + _sid: SessionId, + _rid: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + if data.extra.get("tool").and_then(|v| v.as_str()) == Some("view") { + PermissionResult::Approved + } else { + PermissionResult::Denied + } + } + + async fn on_session_event(&self, sid: SessionId, event: github_copilot_sdk::types::SessionEvent) { + println!("[{sid}] {}", event.event_type); + } +} +``` + +**2. Single `on_event` method.** Override `on_event` directly and `match` on `HandlerEvent` โ€” useful for logging middleware, custom routing, or when you want one exhaustive dispatch point. + +```rust,ignore +use github_copilot_sdk::handler::*; +use async_trait::async_trait; + +#[async_trait] +impl SessionHandler for MyRouter { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { session_id, event } => { + println!("[{session_id}] {}", event.event_type); + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Approved) + } + HandlerEvent::UserInput { question, .. } => { + HandlerResponse::UserInput(Some(UserInputResponse { + answer: prompt_user(&question), + was_freeform: true, + })) + } + _ => HandlerResponse::Ok, + } + } +} +``` + +The default `on_event` dispatches to the per-event methods, so overriding `on_event` short-circuits them entirely โ€” pick one style per handler. + +Events are processed serially per session โ€” blocking in a handler method pauses that session's event loop (which is correct, since the CLI is also waiting for the response). Other sessions are unaffected. + +> **Note:** Notification-triggered events (`PermissionRequest` via `permission.requested`, `ExternalTool` via `external_tool.requested`) are dispatched on spawned tasks and may run concurrently with the serial event loop. See the trait-level docs on `SessionHandler` for details. + +### SessionConfig + +```rust,ignore +let config = SessionConfig { + model: Some("gpt-5".into()), + system_message: Some(SystemMessageConfig { + content: Some("Always explain your reasoning.".into()), + ..Default::default() + }), + request_elicitation: Some(true), // enable elicitation provider + ..Default::default() +}; +let session = client.create_session(config.with_handler(handler)).await?; +``` + +### Session Hooks + +Hooks intercept CLI behavior at lifecycle points โ€” tool use, prompt submission, session start/end, and errors. Install a `SessionHooks` impl with [`SessionConfig::with_hooks`] โ€” the SDK auto-enables `hooks` in `SessionConfig` when one is set. + +```rust,ignore +use std::sync::Arc; +use github_copilot_sdk::hooks::*; +use async_trait::async_trait; + +struct MyHooks; + +#[async_trait] +impl SessionHooks for MyHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::PreToolUse { input, ctx } => { + if input.tool_name == "dangerous_tool" { + HookOutput::PreToolUse(PreToolUseOutput { + permission_decision: Some("deny".to_string()), + permission_decision_reason: Some("blocked by policy".to_string()), + ..Default::default() + }) + } else { + HookOutput::None // pass through + } + } + HookEvent::SessionStart { input, .. } => { + HookOutput::SessionStart(SessionStartOutput { + additional_context: Some("Extra system context".to_string()), + ..Default::default() + }) + } + _ => HookOutput::None, + } + } +} + +let session = client + .create_session( + config + .with_handler(handler) + .with_hooks(Arc::new(MyHooks)), + ) + .await?; +``` + +**Hook events:** `PreToolUse`, `PostToolUse`, `UserPromptSubmitted`, `SessionStart`, `SessionEnd`, `ErrorOccurred`. Each carries typed input/output structs. Return `HookOutput::None` for events you don't handle. + +### System Message Transforms + +Transforms customize system message sections during session creation. The SDK injects `action: "transform"` entries for each section ID your transform handles. + +```rust,ignore +use github_copilot_sdk::transforms::*; +use async_trait::async_trait; + +struct MyTransform; + +#[async_trait] +impl SystemMessageTransform for MyTransform { + fn section_ids(&self) -> Vec { + vec!["instructions".to_string()] + } + + async fn transform_section( + &self, + _section_id: &str, + content: &str, + _ctx: TransformContext, + ) -> Option { + Some(format!("{content}\n\nAlways be concise.")) + } +} + +let session = client + .create_session( + config + .with_handler(handler) + .with_transform(Arc::new(MyTransform)), + ) + .await?; +``` + +### Tool Registration + +Define client-side tools as named types with `ToolHandler`, then route them with `ToolHandlerRouter`. Enable the `derive` feature for `schema_for::()` โ€” it generates JSON Schema from Rust types via `schemars`. + +```rust,ignore +use std::sync::Arc; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ + schema_for, tool_parameters, JsonSchema, ToolHandler, ToolHandlerRouter, +}; +use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult}; +use serde::Deserialize; +use async_trait::async_trait; + +#[derive(Deserialize, JsonSchema)] +struct GetWeatherParams { + /// City name + city: String, + /// Temperature unit + unit: Option, +} + +struct GetWeatherTool; + +#[async_trait] +impl ToolHandler for GetWeatherTool { + fn tool(&self) -> Tool { + Tool { + name: "get_weather".to_string(), + namespaced_name: None, + description: "Get weather for a city".to_string(), + parameters: tool_parameters(schema_for::()), + instructions: None, + } + } + + async fn call(&self, inv: ToolInvocation) -> Result { + let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; + Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city))) + } +} + +// Build a router that dispatches tool calls by name +let router = ToolHandlerRouter::new( + vec![Box::new(GetWeatherTool)], + Arc::new(ApproveAllHandler), +); + +let config = SessionConfig { + tools: Some(router.tools()), + ..Default::default() +} +.with_handler(Arc::new(router)); +let session = client.create_session(config).await?; +``` + +Tools are named types (not closures) โ€” visible in stack traces and navigable via "go to definition". The router implements `SessionHandler`, forwarding unrecognized tools and non-tool events to the inner handler. + +For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression: + +```rust,ignore +use github_copilot_sdk::tool::{define_tool, JsonSchema, ToolHandlerRouter}; +use github_copilot_sdk::ToolResult; +use serde::Deserialize; + +#[derive(Deserialize, JsonSchema)] +struct GetWeatherParams { city: String } + +let router = ToolHandlerRouter::new( + vec![define_tool( + "get_weather", + "Get weather for a city", + |_inv, params: GetWeatherParams| async move { + Ok(ToolResult::Text(format!("Sunny in {}", params.city))) + }, + )], + Arc::new(ApproveAllHandler), +); +``` + +The closure receives the full [`ToolInvocation`](crate::types::ToolInvocation) alongside the deserialized parameters, so handlers that need `inv.session_id` or `inv.tool_call_id` for telemetry, streaming updates, or scoped lookups can use them directly. Use `_inv` when you don't need the metadata. + +Reach for the `ToolHandler` trait directly when you need shared state across multiple methods or want a named type that shows up by name in stack traces. + +### Permission Policies + +Set a permission policy directly on `SessionConfig` with the chainable builders. They wrap whatever handler you've installed (defaulting to `DenyAllHandler` if none) so only permission requests are intercepted; every other event flows through unchanged. + +```rust,ignore +let session = client + .create_session( + SessionConfig::default() + .with_handler(Arc::new(my_handler)) + .approve_all_permissions(), + // or .deny_all_permissions() + // or .approve_permissions_if(|data| { + // data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + // }) + ) + .await?; +``` + +> Call the policy method **after** `with_handler` โ€” `with_handler` overwrites the handler field, so `approve_all_permissions().with_handler(...)` discards the wrap. + +For composing a policy onto a handler outside the builder chain (e.g. when wrapping a `ToolHandlerRouter` you've built elsewhere), the `permission` module exposes the same primitives as free functions: + +```rust,ignore +use github_copilot_sdk::permission; + +let router = ToolHandlerRouter::new(tools, Arc::new(MyHandler)); +let handler = permission::approve_all(Arc::new(router)); +// or permission::deny_all(...) / permission::approve_if(..., predicate) + +let session = client.create_session(config.with_handler(handler)).await?; +``` + +### Capabilities & Elicitation + +The SDK negotiates capabilities with the CLI after session creation. Enable elicitation to let the agent present structured UI dialogs (forms, URL prompts) to the user. + +```rust,ignore +let config = SessionConfig { + request_elicitation: Some(true), + ..Default::default() +}; +``` + +The handler receives `HandlerEvent::ElicitationRequest` with a message, optional JSON Schema for form fields, and an optional mode. Known modes include `Form` and `Url`, but the mode may be absent or an unknown future value. Return `HandlerResponse::Elicitation(result)`. + +### User Input Requests + +Some sessions ask the user free-form questions (or multiple-choice prompts) outside the elicitation flow. Implement `SessionHandler::on_user_input` and the SDK will forward `userInput.request` callbacks: + +```rust,ignore +async fn on_user_input( + &self, + _session_id: SessionId, + question: String, + choices: Option>, + _allow_freeform: Option, +) -> Option { + // Render `question` + `choices` to your UI, then: + Some(UserInputResponse { + answer: "Yes".to_string(), + was_freeform: false, + }) +} +``` + +Return `None` to signal "no answer available" (the CLI falls back to its own prompt). Enable via `SessionConfig::request_user_input` (defaults to `Some(true)`). + +### Slash Commands + +Register named commands so users can invoke them as `/name args` from the TUI: + +```rust,ignore +use github_copilot_sdk::types::{CommandContext, CommandDefinition, CommandHandler}; +use async_trait::async_trait; + +struct DeployCommand; + +#[async_trait] +impl CommandHandler for DeployCommand { + async fn on_command(&self, ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { + println!("deploy {}", ctx.args); + Ok(()) + } +} + +let mut config = SessionConfig::default(); +config.commands = Some(vec![ + CommandDefinition::new("deploy", Arc::new(DeployCommand)) + .with_description("Deploy the application"), +]); +``` + +Only `name` and `description` are sent over the wire; the handler stays in your process. Returning `Err(_)` surfaces the message back through the TUI. + +### Streaming + +Set `streaming: true` to receive incremental delta events alongside finalized messages: + +```rust,ignore +let mut config = SessionConfig::default(); +config.streaming = Some(true); + +let mut events = session.subscribe(); +while let Ok(event) = events.recv().await { + match event.event_type.as_str() { + "assistant.message_delta" | "assistant.reasoning_delta" => { + if let Some(d) = event.data.get("delta").and_then(|v| v.as_str()) { + print!("{d}"); + } + } + "assistant.message" => println!(), // final + _ => {} + } +} +``` + +When streaming is off (the default), only the final `assistant.message` and `assistant.reasoning` events fire. Delta events arrive in order; concatenating their `delta` text payloads reproduces the final message. + +### Infinite Sessions + +Enable the SDK's session-store integration so conversations persist across CLI restarts and grow beyond the model's context window via automatic compaction: + +```rust,ignore +use github_copilot_sdk::types::InfiniteSessionConfig; + +let mut infinite = InfiniteSessionConfig::default(); +infinite.workspace_path = Some("/path/to/workspace".into()); + +let mut config = SessionConfig::default(); +config.infinite_sessions = Some(infinite); +``` + +The CLI emits `session.compaction_start` / `session.compaction_complete` events around each compaction. The session id remains stable across compactions; resume with `Client::resume_session` to pick up a prior conversation. Workspace state lives under `~/.copilot/session-state/{sessionId}` by default โ€” override with `workspace_path` to relocate. + +### Custom Providers (BYOK) + +Route model traffic through your own inference endpoint instead of GitHub's hosted models: + +```rust,ignore +use github_copilot_sdk::types::ProviderConfig; + +let mut provider = ProviderConfig::default(); +provider.provider_type = Some("openai".to_string()); +provider.base_url = "https://my-proxy.example.com/v1".to_string(); +provider.bearer_token = Some(std::env::var("OPENAI_API_KEY")?); + +let mut config = SessionConfig::default(); +config.provider = Some(provider); +``` + +Provider types include `"openai"`, `"azure"`, and `"anthropic"`. Set `wire_api` to `"completions"` or `"responses"` (OpenAI/Azure only). Custom headers go in `provider.headers`. The SDK forwards the configuration to the CLI verbatim โ€” the CLI handles the upstream call, including authentication. + +### Telemetry + +Forward OpenTelemetry signals from the spawned CLI process to your collector: + +```rust,ignore +use github_copilot_sdk::{ClientOptions, OtelExporterType, TelemetryConfig}; + +let mut telem = TelemetryConfig::default(); +telem.exporter_type = Some(OtelExporterType::OtlpHttp); +telem.otlp_endpoint = Some("http://localhost:4318".to_string()); +telem.source_name = Some("my-app".to_string()); + +let mut opts = ClientOptions::default(); +opts.telemetry = Some(telem); +let client = Client::start(opts).await?; +``` + +The SDK injects the appropriate environment variables (`COPILOT_OTEL_EXPORTER_TYPE`, `OTEL_EXPORTER_OTLP_ENDPOINT`, ...) into the spawned CLI process. The SDK takes no OpenTelemetry dependency; the CLI itself owns the exporter pipeline. Caller-supplied `ClientOptions::env` entries override telemetry-injected values. + +### Progress Reporting (`send_and_wait`) + +For fire-and-forget messaging where you need to block until the agent finishes: + +```rust,ignore +use std::time::Duration; +use github_copilot_sdk::MessageOptions; + +// Sends a message and blocks until session.idle or session.error +session + .send_and_wait( + MessageOptions::new("Fix the bug").with_wait_timeout(Duration::from_secs(120)), + ) + .await?; +``` + +Default timeout is 60 seconds. Only one `send_and_wait` can be active per session โ€” concurrent calls return an error. + +### Newtypes + +**`SessionId`** โ€” a newtype wrapper around `String` that prevents accidentally passing workspace IDs or request IDs where session IDs are expected. Transparent serialization (`#[serde(transparent)]`), zero-cost `Deref`, and ergonomic comparisons with `&str` and `String`. + +```rust,ignore +use github_copilot_sdk::SessionId; + +let id = SessionId::new("sess-abc123"); +assert_eq!(id, "sess-abc123"); // compare with &str +let raw: String = id.into_inner(); // unwrap when needed +``` + +### Error Handling + +The SDK uses a typed error enum: + +```rust,ignore +pub enum Error { + Protocol(ProtocolError), // JSON-RPC framing, CLI startup, version mismatch + Rpc { code: i32, message: String }, // CLI returned an error response + Session(SessionError), // Session not found, agent error, timeout, conflicts + Io(std::io::Error), // Transport I/O error + Json(serde_json::Error), // Serialization error + BinaryNotFound { name, hint }, // CLI binary not found +} + +// Check if the transport is broken (caller should discard the client) +if err.is_transport_failure() { + client = Client::start(options).await?; +} +``` + +## Differences From Other SDKs + +The Rust SDK aligns closely with the Node, Python, and Go SDKs but diverges +in a few places where Rust idiom or the type system gives a clearly better +shape, and exposes a small additional surface where the language affords +ergonomics the dynamically-typed SDKs don't. + +### Shape divergence + +- **`SessionFsProvider` registration is direct, not factory-closure.** Where + Node/Python/Go accept a closure that the runtime calls on each + session-create to build a fresh provider, the Rust SDK takes + `Arc` directly via + [`SessionConfig::with_session_fs_provider`]. The factory pattern doesn't + cleanly express in Rust at the session-config call site โ€” there is no + `Session` value to thread in, and the SDK already prefers traits over + boxed closures for handler-shaped APIs (`SessionHandler`, `SessionHooks`, + `ToolHandler`). + +```rust,ignore +use std::sync::Arc; +use github_copilot_sdk::session_fs::{SessionFsConfig, SessionFsConventions}; + +let mut options = ClientOptions::default(); +options.session_fs = Some(SessionFsConfig::new( + "/workspace", + "/workspace/.copilot", + SessionFsConventions::Posix, +)); +let client = Client::start(options).await?; + +let session = client + .create_session( + SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .with_session_fs_provider(Arc::new(MyProvider::new())), + ) + .await?; +``` + +See [`examples/session_fs.rs`](examples/session_fs.rs) for a complete +in-memory provider implementation. + +### Rust-only API + +A handful of conveniences exist only on the Rust SDK as of 0.1.0. These +are surface areas where Rust idiom (newtypes, enums, trait objects) +gives a clearly nicer shape than Node/Python/Go currently expose. Rust +gets to be Rust here โ€” cross-SDK parity for these is a post-release +conversation, not a release blocker. None of these are deprecated and +none of them are scheduled for removal. + +- **`Client::get_quota`** โ€” top-level convenience wrapper for fetching + account-level request quota snapshots. Rust-only as of 0.1.0; the other + SDKs do not expose a client-level shortcut. The underlying + `account.getQuota` JSON-RPC endpoint itself is available cross-SDK via + each SDK's typed `rpc()` namespace (Node + `client.rpc().account().getQuota()`, Python + `client.rpc().account.get_quota()`, Go + `client.Rpc().Account().GetQuota()`, .NET + `client.Rpc().Account().GetQuotaAsync()`), including in Rust at + `client.rpc().account().get_quota()`. +- **First-class `Session` convenience methods** โ€” `set_mode` / `get_mode`, + `set_name` / `get_name`, `get_model`, `read_plan` / `update_plan` / + `delete_plan`, `start_fleet`, `list_workspace_files` / + `read_workspace_file` / `create_workspace_file`. The other SDKs require + the consumer to drive the typed JSON-RPC namespace directly for these. +- **`Client::send_telemetry` / `Session::send_telemetry`** โ€” top-level + and session-scoped telemetry emission via `sendTelemetry` / + `session.sendTelemetry`. Other SDKs do not currently expose these RPC + endpoints in their public APIs (not even via the typed namespace). +- **Typed newtypes** โ€” `SessionId` and `RequestId` are `#[serde(transparent)]` + newtypes around `String`, so the type system distinguishes a session + identifier from an arbitrary `String` at compile time. Node/Python/Go + use bare strings. +- **Permission policy builders** โ€” `permission::approve_all`, + `permission::deny_all`, and `permission::approve_if(handler, predicate)` + in `crate::permission` provide composable, no-handler-needed permission + shortcuts that wrap an existing `SessionHandler`. Other SDKs require a + full handler implementation for these patterns. +- **`Client::from_streams`** โ€” connect to a CLI server over arbitrary + caller-supplied `AsyncRead` / `AsyncWrite`. Useful for testing, + in-process embedding, or custom transports. Other SDKs are spawn-only + or fixed-stdio. +- **`enum Transport { Stdio, Tcp, External }`** โ€” explicit, exhaustive + transport selector on `ClientOptions::transport`. Node/Python/Go rely + on conditional config field combinations instead. +- **Split `prefix_args` / `extra_args`** on `ClientOptions` โ€” separate + arg vectors for "prepend before subcommand" vs "append after the + built-in flags", giving precise control over CLI invocation order + without string-splicing. +- **`SessionHandler::on_auto_mode_switch`** โ€” typed handler for the CLI's + rate-limit-recovery prompt (CLI's `autoModeSwitch.request` callback, + added in copilot-agent-runtime PR #7024). Returns + `AutoModeSwitchResponse::{Yes, YesAlways, No}`. Default impl declines. + Cross-SDK parity is post-release follow-up โ€” Node / Python / Go / .NET + consumers currently observe the request as a raw event and must drive + the wire response themselves. + +## Layout + +| File | Description | +|---|---| +| `lib.rs` | `Client`, `ClientOptions`, `CliProgram`, `Transport`, `Error` | +| `session.rs` | `Session` struct, event loop, `send`/`send_and_wait`, `Client::create_session`/`resume_session` | +| `subscription.rs` | `EventSubscription` / `LifecycleSubscription` (`Stream`-able observer handles for `subscribe()` / `subscribe_lifecycle()`) | +| `handler.rs` | `SessionHandler` trait, `HandlerEvent`/`HandlerResponse` enums, `ApproveAllHandler` | +| `hooks.rs` | `SessionHooks` trait, `HookEvent`/`HookOutput` enums, typed hook inputs/outputs | +| `transforms.rs` | `SystemMessageTransform` trait, section-level system message customization | +| `tool.rs` | `ToolHandler` trait, `ToolHandlerRouter`, `schema_for::()` (with `derive` feature) | +| `types.rs` | CLI protocol types (`SessionId`, `SessionEvent`, `SessionConfig`, `Tool`, etc.) | +| `resolve.rs` | Binary resolution (`copilot_binary`, `node_binary`, `extended_path`) | +| `embeddedcli.rs` | Embedded CLI extraction (`embedded-cli` feature) | +| `router.rs` | Internal per-session event demux | +| `jsonrpc.rs` | Internal Content-Length framed JSON-RPC transport | + +## Embedded CLI + +By default, `copilot_binary()` searches `COPILOT_CLI_PATH`, the system PATH, and common install locations. To **ship with a specific CLI version** embedded in the binary, set `COPILOT_CLI_VERSION` at build time: + +```bash +COPILOT_CLI_VERSION=1.0.15 cargo build +``` + +### How it works + +1. **Build time:** The SDK's `build.rs` detects `COPILOT_CLI_VERSION`, downloads the platform-appropriate archive from the [`github/copilot-cli` GitHub Releases](https://github.com/github/copilot-cli/releases) (`copilot-{platform}.tar.gz` on macOS/Linux, `.zip` on Windows), verifies the archive's SHA-256 against the release's `SHA256SUMS.txt`, extracts the `copilot` binary, compresses it with zstd, and embeds via `include_bytes!()`. No extra steps or tools needed โ€” just the env var. + +2. **Runtime:** On the first call to `github_copilot_sdk::resolve::copilot_binary()`, the embedded binary is lazily extracted to `~/.cache/github-copilot-sdk-{version}/copilot` (or `copilot.exe` on Windows), SHA-256 verified, and cached. Subsequent calls return the cached path. + +3. **Dev builds:** Without the env var, `build.rs` does nothing. The binary is resolved from PATH as usual โ€” zero friction. + +### Resolution priority + +`copilot_binary()` checks these sources in order: + +1. `COPILOT_CLI_PATH` environment variable +2. Embedded CLI (build-time, via `COPILOT_CLI_VERSION`) +3. System PATH + common install locations + +### Platforms + +Supported: `darwin-arm64`, `darwin-x64`, `linux-x64`, `linux-arm64`, `win32-x64`, `win32-arm64`. The target platform is auto-detected from `CARGO_CFG_TARGET_OS` and `CARGO_CFG_TARGET_ARCH` (cross-compilation works). + +## Features + +No features are enabled by default โ€” the bare SDK resolves the CLI from `COPILOT_CLI_PATH` or the system PATH without pulling in additional feature-gated dependencies. + +| Feature | Default | Description | +|---|---|---| +| `embedded-cli` | โ€” | Build-time CLI embedding via `COPILOT_CLI_VERSION` (adds `sha2`, `zstd`). Enable when you need to ship a self-contained binary with a pinned CLI version. | +| `derive` | โ€” | `schema_for::()` for generating JSON Schema from Rust types (adds `schemars`). Enable when defining [tool parameters](#tool-registration). | + +```toml +# These examples use registry syntax for illustration; until the crate is +# published, use a path or git dependency instead. + +# Minimal โ€” resolve CLI from PATH +github-copilot-sdk = "0.1" + +# Ship a pinned CLI version in your binary +github-copilot-sdk = { version = "0.1", features = ["embedded-cli"] } + +# Derive JSON Schema for tool parameters +github-copilot-sdk = { version = "0.1", features = ["derive"] } + +# Both +github-copilot-sdk = { version = "0.1", features = ["embedded-cli", "derive"] } +``` diff --git a/rust/RELEASING.md b/rust/RELEASING.md new file mode 100644 index 000000000..5361591d2 --- /dev/null +++ b/rust/RELEASING.md @@ -0,0 +1,192 @@ +# Releasing `github-copilot-sdk` + +This document describes how to cut a release of the `github-copilot-sdk` Rust crate +and publish it to [crates.io]. It is the operational counterpart to the +workflow files under `../.github/workflows/rust-*.yml` (which run the actual +mechanics). + +If you are adding code to the SDK, you do not need to read this. This is for +maintainers cutting a release. + +[crates.io]: https://crates.io/crates/github-copilot-sdk + +--- + +## TL;DR + +1. Land your changes on `main` using conventional-commit messages. +2. Trigger the **Rust SDK: Create Release PR** workflow manually + (`workflow_dispatch`). +3. Review and merge the PR that release-plz opens. +4. The **Rust SDK: Publish Release** workflow runs automatically when that + PR merges, publishes to crates.io, tags `rust-vX.Y.Z`, and creates a + GitHub Release. + +The first 0.1.0 publish requires a one-time `CARGO_REGISTRY_TOKEN` secret +setup โ€” see [First-time setup](#first-time-setup) below. + +--- + +## How releases are cut + +The crate uses [release-plz] in a two-PR workflow. Both PRs run unattended +through GitHub Actions; the only manual step is reviewing and merging. + +[release-plz]: https://release-plz.dev/ + +### Step 1 โ€” `release-plz release-pr` + +Workflow: `.github/workflows/rust-release-pr.yml` (`workflow_dispatch` only). + +When you trigger it, release-plz: + +- Reads conventional-commit history since the last `rust-vX.Y.Z` tag. +- Decides the next version (patch / minor / major) per SemVer rules. +- Bumps `rust/Cargo.toml`'s `version` field. +- Renames `## [Unreleased]` in `rust/CHANGELOG.md` to `## [X.Y.Z] - + ` and prepends a fresh empty `## [Unreleased]` above it. +- Opens a PR with those changes. + +Review the PR. The CHANGELOG entry is the one users see on crates.io and on +the GitHub Release page, so make sure it reads well. Edit the PR directly if +the auto-generated entry needs tweaking. + +> **First-publish note.** The hand-curated 0.1.0 entry currently lives +> under `## [Unreleased]` so release-plz will rename it cleanly on the +> first run. If release-plz instead generates a *second* entry from +> conventional commits and prepends it above the curated one (depends on +> the configured `body` template), delete the auto-generated stub in the +> PR and keep the curated entry โ€” you only want one 0.1.0 section. + +### Step 2 โ€” `release-plz release` (publish) + +Workflow: `.github/workflows/rust-publish-release.yml` (auto-runs on push +to `main` when `rust/Cargo.toml`, `rust/Cargo.lock`, or `rust/release-plz.toml` +changes). + +When the release-PR from step 1 merges, this workflow detects that +`rust/Cargo.toml`'s version is newer than the latest `rust-vX.Y.Z` tag and: + +- Runs `cargo publish` to upload to crates.io. +- Creates a `rust-vX.Y.Z` git tag. +- Creates a GitHub Release with the CHANGELOG entry as the body. + +The workflow is a no-op on non-release commits, so it's safe to run on every +push. + +--- + +## First-time setup + +Before the first 0.1.0 publish, complete this checklist exactly once: + +1. **Reserve the crate name.** Have a maintainer with crates.io 2FA log in + to crates.io and run `cargo publish` for an empty stub OR claim the name + via the "New Crate" form. The owner account should be a service account + (preferred) or a senior maintainer. +2. **Generate a scoped API token.** crates.io โ†’ Account Settings โ†’ API + Tokens โ†’ New Token. Scope it to publish `github-copilot-sdk` *only* โ€” do not + issue an unscoped token. +3. **Add the secret.** GitHub repo Settings โ†’ Secrets and variables โ†’ + Actions โ†’ New repository secret named `CARGO_REGISTRY_TOKEN`, value = + the token from step 2. +4. **Rotation.** Rotate the token annually and whenever the maintainer set + changes. There's no automated reminder for this โ€” set a calendar event. + +Until this checklist is complete, `cargo publish` in the workflow will fail. +That's intentional: it keeps accidental publishes from happening before the +repo is ready. + +--- + +## Versioning policy + +The crate follows [SemVer]. Pre-1.0 we treat **0.x.0** as breaking and +**0.x.y** as additive โ€” same as the Rust ecosystem convention. + +[SemVer]: https://semver.org/ + +Two CI checks defend the API surface: + +- **`cargo semver-checks`** (`.github/workflows/rust-sdk-tests.yml`) โ€” + detects breaking changes against the latest *published* version on + crates.io. Currently `continue-on-error: true` because there's no + baseline yet. **Flip it to `false` after 0.1.0 ships** to make SemVer + enforcement blocking. + +For ad-hoc public-surface inspection, `cargo public-api -sss --features +derive,test-support` is handy โ€” but the surface is not snapshotted in the +repo. The rendered docs on [docs.rs](https://docs.rs/github-copilot-sdk) are the +canonical reference; `cargo-semver-checks` is the gate. + +For 0.x โ†’ 1.0, do an explicit API review pass (compare against the +language siblings under `../{nodejs,python,go,dotnet}/`), +remove anything `#[doc(hidden)]` you don't intend to keep public, and +write out the 1.0 commitment in the CHANGELOG. + +--- + +## Public-disclosure gate + +The Rust SDK release-prep work happens on `tclem/rust-sdk-release-prep` +and is held *unpushed* until product/comms gives explicit OK. Do not push +the branch, open a PR, or otherwise expose the work without that signal โ€” +even if CI looks ready. + +Ways to keep moving without pushing: + +- Land work in local commits on the prep branch. +- Use `cargo publish --dry-run --allow-dirty` to validate package contents. +- Use `cargo public-api -sss --features derive,test-support` for ad-hoc + surface inspection. + +When the gate opens: + +1. Push `tclem/rust-sdk-release-prep`. +2. Open a PR titled "Rust SDK: prepare for 0.1.0 release" (or similar). +3. Once it merges, trigger the **Rust SDK: Create Release PR** workflow and + proceed with the publish flow above. + +--- + +## Manual publish (emergency only) + +If GitHub Actions is unavailable, a maintainer with crates.io credentials +can publish locally: + +```sh +cd rust + +# Verify the package contents first. +cargo publish --dry-run + +# Publish for real. +cargo publish + +# Tag and push. +git tag rust-v$(cargo metadata --no-deps --format-version=1 \ + | jq -r '.packages[] | select(.name=="github-copilot-sdk") | .version' | head -1) +git push origin --tags +``` + +Manual publishes skip the release-PR review step, so write the CHANGELOG +entry by hand before publishing and commit it on `main` first. + +--- + +## Yanking a release + +If a published version contains a critical bug (security, data loss, panic +on common input), yank it from crates.io to prevent new installs: + +```sh +cargo yank --version X.Y.Z github-copilot-sdk +``` + +Yanking does *not* delete the version โ€” existing `Cargo.lock` files keep +working โ€” but it stops new resolutions from picking it. Follow up with a +patch release that fixes the bug, and add a note to the yanked version's +GitHub Release explaining why. + +Reverse with `cargo yank --undo --version X.Y.Z github-copilot-sdk` if the yank +was a mistake. diff --git a/rust/build.rs b/rust/build.rs new file mode 100644 index 000000000..22463c9a9 --- /dev/null +++ b/rust/build.rs @@ -0,0 +1,340 @@ +use std::io::Read; +use std::path::Path; +use std::time::Duration; + +use sha2::Digest; + +fn main() { + println!("cargo:rerun-if-env-changed=COPILOT_CLI_VERSION"); + println!("cargo:rerun-if-env-changed=BUNDLED_CLI_CACHE_DIR"); + println!("cargo::rustc-check-cfg=cfg(has_bundled_cli)"); + + let Ok(version) = std::env::var("COPILOT_CLI_VERSION") else { + return; + }; + + let Some(platform) = target_platform() else { + println!( + "cargo:warning=COPILOT_CLI_VERSION set but unsupported target platform, skipping CLI bundling" + ); + return; + }; + + let out_dir = std::env::var("OUT_DIR").expect("OUT_DIR is always set by cargo"); + let out = Path::new(&out_dir); + + let base_url = format!("https://github.com/github/copilot-cli/releases/download/v{version}"); + let cache_dir = std::env::var("BUNDLED_CLI_CACHE_DIR") + .ok() + .map(std::path::PathBuf::from); + + // Download SHA256SUMS and find the expected hash for our platform's tarball. + let asset_name = platform.asset_name; + println!("cargo:warning=Bundling GitHub Copilot CLI v{version} ({asset_name})"); + // Download checksums and find the expected hash for our platform's archive. + let checksums_url = format!("{base_url}/SHA256SUMS.txt"); + let checksums = download_with_retry(&checksums_url); + let checksums_text = + std::str::from_utf8(&checksums).expect("checksums file is not valid UTF-8"); + let expected_hash = find_sha256_for_asset(checksums_text, asset_name); + + // Use a versioned cache key since copilot asset names don't include the version. + let cache_key = format!("v{version}-{asset_name}"); + + // Download the archive (or read from cache) and verify integrity. + let archive = cached_download( + &format!("{base_url}/{asset_name}"), + &cache_key, + &expected_hash, + &cache_dir, + ); + println!("cargo:warning=SHA-256 verified ({expected_hash})"); + + // Extract the binary from the archive. + let binary = extract_binary(&archive, platform.binary_name, platform.is_zip); + println!( + "cargo:warning=Extracted {} ({} bytes)", + platform.binary_name, + binary.len() + ); + + // Compress and embed. + let hash = sha256(&binary); + let compressed = zstd::encode_all(&binary[..], 19).expect("zstd compression failed"); + println!( + "cargo:warning=Compressed to {} bytes ({:.1}%)", + compressed.len(), + compressed.len() as f64 / binary.len() as f64 * 100.0 + ); + + std::fs::write(out.join("copilot_cli.zst"), &compressed) + .expect("failed to write copilot_cli.zst"); + + let hash_tokens: Vec = hash.iter().map(|b| format!("0x{b:02x}")).collect(); + let generated = format!( + r#"// Auto-generated by github-copilot-sdk build.rs. Do not edit. +pub(super) static CLI_BYTES: &[u8] = include_bytes!("copilot_cli.zst"); +pub(super) static CLI_HASH: [u8; 32] = [{}]; +pub(super) static CLI_VERSION: &str = "{version}"; +"#, + hash_tokens.join(", ") + ); + + std::fs::write(out.join("bundled_cli.rs"), generated).expect("failed to write bundled_cli.rs"); + + println!("cargo:rustc-cfg=has_bundled_cli"); +} + +struct Platform { + asset_name: &'static str, + binary_name: &'static str, + is_zip: bool, +} + +fn target_platform() -> Option { + let os = std::env::var("CARGO_CFG_TARGET_OS").ok()?; + let arch = std::env::var("CARGO_CFG_TARGET_ARCH").ok()?; + + match (os.as_str(), arch.as_str()) { + ("macos", "aarch64") => Some(Platform { + asset_name: "copilot-darwin-arm64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("macos", "x86_64") => Some(Platform { + asset_name: "copilot-darwin-x64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("linux", "x86_64") => Some(Platform { + asset_name: "copilot-linux-x64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("linux", "aarch64") => Some(Platform { + asset_name: "copilot-linux-arm64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("windows", "x86_64") => Some(Platform { + asset_name: "copilot-win32-x64.zip", + binary_name: "copilot.exe", + is_zip: true, + }), + ("windows", "aarch64") => Some(Platform { + asset_name: "copilot-win32-arm64.zip", + binary_name: "copilot.exe", + is_zip: true, + }), + _ => None, + } +} + +/// Read a file from the download cache, or download it (with retries) and save +/// to cache. Verifies SHA-256 on every path. Evicts stale/corrupt cache entries +/// automatically. Cache I/O failures are treated as cache misses โ€” they never +/// break the build. +fn cached_download( + url: &str, + cache_key: &str, + expected_hash: &str, + cache_dir: &Option, +) -> Vec { + if let Some(dir) = cache_dir { + let cached_path = dir.join(cache_key); + if cached_path.is_file() { + match std::fs::read(&cached_path) { + Ok(data) if hex_sha256(&data) == expected_hash => { + println!( + "cargo:warning=Using cached archive: {}", + cached_path.display() + ); + return data; + } + Ok(_) => { + println!("cargo:warning=Cached archive hash mismatch, re-downloading"); + let _ = std::fs::remove_file(&cached_path); + } + Err(e) => { + println!( + "cargo:warning=Failed to read cache {}, re-downloading: {e}", + cached_path.display() + ); + } + } + } + } + + let data = download_with_retry(url); + let actual_hash = hex_sha256(&data); + if actual_hash != expected_hash { + panic!( + "Archive integrity check failed for {url}!\n expected: {expected_hash}\n actual: {actual_hash}\n \ + This could indicate a corrupted download or a supply-chain attack." + ); + } + + if let Some(dir) = cache_dir { + if let Err(e) = std::fs::create_dir_all(dir) { + println!( + "cargo:warning=Failed to create cache directory {}: {e}", + dir.display() + ); + } else { + let cached_path = dir.join(cache_key); + if let Err(e) = std::fs::write(&cached_path, &data) { + println!( + "cargo:warning=Failed to write cache file {}: {e}", + cached_path.display() + ); + } else { + println!("cargo:warning=Cached archive to: {}", cached_path.display()); + } + } + } + + data +} + +/// Maximum number of HTTP attempts (one initial + this many retries on transient errors). +const MAX_RETRIES: u32 = 3; + +/// Download `url` with bounded retries on transient network errors. Backoff is +/// exponential starting at 1s. 4xx responses fail fast; 5xx and connect/read +/// errors are retried. +fn download_with_retry(url: &str) -> Vec { + let mut attempt = 0u32; + loop { + attempt += 1; + match try_download(url) { + Ok(bytes) => return bytes, + Err(err) if err.transient && attempt <= MAX_RETRIES => { + let backoff = Duration::from_secs(1u64 << (attempt - 1)); + println!( + "cargo:warning=Transient download failure for {url} (attempt {attempt}/{}): {} โ€” retrying in {}s", + MAX_RETRIES + 1, + err.message, + backoff.as_secs(), + ); + std::thread::sleep(backoff); + } + Err(err) => panic!("Failed to download {url}: {}", err.message), + } + } +} + +struct DownloadError { + message: String, + transient: bool, +} + +fn try_download(url: &str) -> Result, DownloadError> { + let agent = ureq::AgentBuilder::new() + .timeout_connect(Duration::from_secs(30)) + .timeout_read(Duration::from_secs(120)) + .build(); + + match agent.get(url).call() { + Ok(response) => { + let mut bytes = Vec::new(); + response + .into_reader() + .read_to_end(&mut bytes) + .map_err(|e| DownloadError { + message: format!("read error: {e}"), + transient: true, + })?; + Ok(bytes) + } + // 5xx โ€” server-side, treat as transient. + Err(ureq::Error::Status(code, response)) if (500..600).contains(&code) => { + Err(DownloadError { + message: format!("HTTP {code} {}", response.status_text()), + transient: true, + }) + } + // 4xx โ€” client-side, fail fast. + Err(ureq::Error::Status(code, response)) => Err(DownloadError { + message: format!("HTTP {code} {}", response.status_text()), + transient: false, + }), + // Transport-layer (DNS, connect, TLS, read timeout) โ€” treat as transient. + Err(ureq::Error::Transport(t)) => Err(DownloadError { + message: format!("transport error: {t}"), + transient: true, + }), + } +} + +fn find_sha256_for_asset(sums: &str, asset_name: &str) -> String { + for line in sums.lines() { + // Format: " " (two spaces) + if let Some((hash, name)) = line.split_once(" ") + && name.trim() == asset_name + { + return hash.trim().to_string(); + } + } + panic!("SHA256SUMS.txt does not contain an entry for {asset_name}"); +} + +fn extract_binary(archive_bytes: &[u8], binary_name: &str, is_zip: bool) -> Vec { + if is_zip { + extract_from_zip(archive_bytes, binary_name) + } else { + extract_from_tarball(archive_bytes, binary_name) + } +} + +fn extract_from_tarball(tarball: &[u8], binary_name: &str) -> Vec { + let gz = flate2::read::GzDecoder::new(tarball); + let mut archive = tar::Archive::new(gz); + + for entry in archive.entries().expect("failed to read tarball entries") { + let mut entry = entry.expect("failed to read tarball entry"); + let path = entry + .path() + .expect("entry has no path") + .to_string_lossy() + .to_string(); + if path == binary_name || path.ends_with(&format!("/{binary_name}")) { + let mut bytes = Vec::new(); + entry + .read_to_end(&mut bytes) + .expect("failed to read binary from tarball"); + return bytes; + } + } + + panic!("'{binary_name}' not found in tarball"); +} + +fn extract_from_zip(zip_bytes: &[u8], binary_name: &str) -> Vec { + // Minimal zip extraction โ€” find the binary by name. + // The Windows assets are .zip files with just copilot.exe at the root. + let cursor = std::io::Cursor::new(zip_bytes); + let mut archive = zip::ZipArchive::new(cursor).expect("failed to read zip archive"); + + for i in 0..archive.len() { + let mut file = archive.by_index(i).expect("failed to read zip entry"); + let name = file.name().to_string(); + if name == binary_name || name.ends_with(&format!("/{binary_name}")) { + let mut bytes = Vec::new(); + file.read_to_end(&mut bytes) + .expect("failed to read binary from zip"); + return bytes; + } + } + + panic!("'{binary_name}' not found in zip"); +} + +fn sha256(data: &[u8]) -> [u8; 32] { + let mut hasher = sha2::Sha256::new(); + hasher.update(data); + hasher.finalize().into() +} + +fn hex_sha256(data: &[u8]) -> String { + sha256(data).iter().map(|b| format!("{b:02x}")).collect() +} diff --git a/rust/clippy.toml b/rust/clippy.toml new file mode 100644 index 000000000..22781c472 --- /dev/null +++ b/rust/clippy.toml @@ -0,0 +1,8 @@ +await-holding-invalid-types = [ + { path = "tracing::span::Entered", reason = "generates incorrect spans when held across 'await' points" }, + { path = "tracing::span::EnteredSpan", reason = "generates incorrect spans when held across 'await' points" }, +] + +disallowed-macros = [ + { path = "tracing::instrument", reason = "tracing::instrument is error-prone. Use tracing::error_span! in the method body instead." }, +] diff --git a/rust/examples/chat.rs b/rust/examples/chat.rs new file mode 100644 index 000000000..37293c6bc --- /dev/null +++ b/rust/examples/chat.rs @@ -0,0 +1,122 @@ +//! Interactive chat with GitHub Copilot. +//! +//! Starts a GitHub Copilot CLI server, creates a session, and enters a read-eval-print +//! loop where each line you type is sent to the agent. Streaming is enabled so +//! response tokens print to stdout incrementally as they arrive. +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example chat +//! ``` + +use std::io::{self, BufRead, Write}; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{ + HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, UserInputResponse, +}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionEvent}; +use github_copilot_sdk::{Client, ClientOptions}; + +/// Handler that prints assistant message deltas as they stream in +/// and auto-approves permissions. +struct ChatHandler; + +#[async_trait] +impl SessionHandler for ChatHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { event, .. } => { + print_event(&event); + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Approved) + } + HandlerEvent::UserInput { question, .. } => { + // Prompt the user on behalf of the agent. + print!("\n[agent asks] {question}\n> "); + io::stdout().flush().ok(); + let answer = read_line().unwrap_or_default(); + HandlerResponse::UserInput(Some(UserInputResponse { + answer, + was_freeform: true, + })) + } + _ => HandlerResponse::Ok, + } + } +} + +fn print_event(event: &SessionEvent) { + match event.event_type.as_str() { + "assistant.message_delta" => { + let text = event + .data + .get("deltaContent") + .and_then(|c| c.as_str()) + .unwrap_or(""); + print!("{text}"); + io::stdout().flush().ok(); + } + "assistant.message" => { + // Final message โ€” print a newline to terminate the streamed output. + println!(); + } + "session.error" => { + let msg = event + .data + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("unknown error"); + eprintln!("\n[error] {msg}"); + } + _ => {} + } +} + +fn read_line() -> Option { + let stdin = io::stdin(); + let mut line = String::new(); + stdin.lock().read_line(&mut line).ok()?; + if line.is_empty() { + return None; // EOF + } + Some(line.trim_end_matches(&['\n', '\r'][..]).to_string()) +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let client = Client::start(ClientOptions::default()).await?; + + let config = { + let mut cfg = SessionConfig::default(); + cfg.streaming = Some(true); + cfg.with_handler(Arc::new(ChatHandler)) + }; + let session = client.create_session(config).await?; + + println!( + "Session {} started. Type a message (Ctrl-D to quit).\n", + session.id() + ); + + loop { + print!("> "); + io::stdout().flush().ok(); + + let Some(line) = read_line() else { break }; + if line.is_empty() { + continue; + } + + session + .send_and_wait(MessageOptions::new(line).with_wait_timeout(Duration::from_secs(120))) + .await?; + } + + println!("\nGoodbye."); + session.destroy().await?; + Ok(()) +} diff --git a/rust/examples/hooks.rs b/rust/examples/hooks.rs new file mode 100644 index 000000000..86f6ceadc --- /dev/null +++ b/rust/examples/hooks.rs @@ -0,0 +1,133 @@ +//! Session hooks for logging and auditing. +//! +//! Demonstrates `SessionHooks` to intercept lifecycle events โ€” logging every +//! tool invocation, summarizing prompts, and recording session start/end +//! for audit purposes. +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example hooks +//! ``` + +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::hooks::{ + HookEvent, HookOutput, PostToolUseOutput, PreToolUseOutput, SessionEndOutput, SessionHooks, + SessionStartOutput, +}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +/// Hooks implementation that logs lifecycle events to stdout. +struct AuditHooks; + +#[async_trait] +impl SessionHooks for AuditHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::SessionStart { input, ctx } => { + println!( + "[audit] session {} started (source={}, cwd={})", + ctx.session_id, + input.source, + input.cwd.display(), + ); + HookOutput::SessionStart(SessionStartOutput { + additional_context: Some("You are being audited. Be concise.".to_string()), + ..Default::default() + }) + } + + HookEvent::PreToolUse { input, ctx } => { + println!( + "[audit] session {} โ€” pre tool use: {} (args: {})", + ctx.session_id, input.tool_name, input.tool_args, + ); + // Example: deny a specific tool by name. + if input.tool_name == "dangerous_tool" { + return HookOutput::PreToolUse(PreToolUseOutput { + permission_decision: Some("deny".to_string()), + permission_decision_reason: Some("blocked by audit policy".to_string()), + ..Default::default() + }); + } + HookOutput::None + } + + HookEvent::PostToolUse { input, ctx } => { + println!( + "[audit] session {} โ€” post tool use: {} (result: {})", + ctx.session_id, input.tool_name, input.tool_result, + ); + HookOutput::PostToolUse(PostToolUseOutput::default()) + } + + HookEvent::UserPromptSubmitted { input, ctx } => { + println!( + "[audit] session {} โ€” user prompt ({} chars)", + ctx.session_id, + input.prompt.len(), + ); + HookOutput::None + } + + HookEvent::SessionEnd { input, ctx } => { + println!( + "[audit] session {} ended (reason={})", + ctx.session_id, input.reason, + ); + HookOutput::SessionEnd(SessionEndOutput { + session_summary: Some("Audited session complete.".to_string()), + ..Default::default() + }) + } + + HookEvent::ErrorOccurred { input, ctx } => { + eprintln!( + "[audit] session {} โ€” error in {}: {} (recoverable={})", + ctx.session_id, input.error_context, input.error, input.recoverable, + ); + HookOutput::None + } + + _ => HookOutput::None, + } + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let client = Client::start(ClientOptions::default()).await?; + + // hooks: true is set automatically when a hooks handler is provided. + let config = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .with_hooks(Arc::new(AuditHooks)); + let session = client.create_session(config).await?; + + println!( + "Session {} with audit hooks. Sending a message...\n", + session.id() + ); + + let response = session + .send_and_wait( + MessageOptions::new("Say hello in three languages.") + .with_wait_timeout(Duration::from_secs(60)), + ) + .await?; + + if let Some(event) = response { + let text = event + .data + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + println!("\n{text}"); + } + + session.destroy().await?; + Ok(()) +} diff --git a/rust/examples/lifecycle_observer.rs b/rust/examples/lifecycle_observer.rs new file mode 100644 index 000000000..612792073 --- /dev/null +++ b/rust/examples/lifecycle_observer.rs @@ -0,0 +1,120 @@ +//! Observe lifecycle and event traffic without owning permission decisions. +//! +//! Demonstrates the channel-based observer APIs: +//! +//! - [`Client::subscribe_lifecycle`] โ€” `tokio::sync::broadcast::Receiver` of +//! every `session.lifecycle` notification (created / destroyed / errored / +//! foreground / background). Filter by matching on `event.event_type` in +//! the consumer. +//! - [`Session::subscribe`] โ€” receiver for the per-session `session.event` +//! stream (assistant messages, tool calls, permission prompts, etc.). +//! Observe-only โ€” the constructor handler still owns permission decisions. +//! - [`Client::state`] โ€” current connection state without polling. +//! - [`Client::get_session_metadata`] โ€” inspect a session without resuming +//! it. +//! - [`Client::force_stop`] โ€” synchronous shutdown for cleanup paths. +//! +//! Drop the receiver to unsubscribe โ€” there is no separate cancel handle. +//! Slow consumers receive `RecvError::Lagged(n)` and resync on the next +//! event; they do not block the producer. +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example lifecycle_observer +//! ``` +//! +//! [`Client::subscribe_lifecycle`]: github_copilot_sdk::Client::subscribe_lifecycle +//! [`Session::subscribe`]: github_copilot_sdk::session::Session::subscribe +//! [`Client::state`]: github_copilot_sdk::Client::state +//! [`Client::get_session_metadata`]: github_copilot_sdk::Client::get_session_metadata +//! [`Client::force_stop`]: github_copilot_sdk::Client::force_stop + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionLifecycleEventType}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let client = Client::start(ClientOptions::default()).await?; + println!("[client] state: {:?}", client.state()); + + // Wildcard lifecycle subscriber: see every session.lifecycle event, + // counting deletions inline by filtering on event_type. + let mut lifecycle_rx = client.subscribe_lifecycle(); + let deleted = Arc::new(AtomicUsize::new(0)); + let deleted_clone = Arc::clone(&deleted); + let lifecycle_task = tokio::spawn(async move { + while let Ok(event) = lifecycle_rx.recv().await { + let summary = event + .metadata + .as_ref() + .and_then(|m| m.summary.as_deref()) + .unwrap_or(""); + println!( + "[lifecycle:*] {:?} session={} summary={}", + event.event_type, event.session_id, summary, + ); + if event.event_type == SessionLifecycleEventType::Deleted { + deleted_clone.fetch_add(1, Ordering::Relaxed); + } + } + }); + + let config = SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + println!("[client] state after create: {:?}", client.state()); + + // Per-session observer: see every assistant message, tool call, etc. + // Subscribers fire alongside the constructor handler; they're great for + // logging or metrics that should run regardless of how the handler + // decides to respond. + let mut session_rx = session.subscribe(); + let session_events = Arc::new(AtomicUsize::new(0)); + let session_events_clone = Arc::clone(&session_events); + let session_task = tokio::spawn(async move { + while let Ok(event) = session_rx.recv().await { + session_events_clone.fetch_add(1, Ordering::Relaxed); + println!("[session-event] {}", event.event_type); + } + }); + + if let Some(metadata) = client.get_session_metadata(session.id()).await? { + println!( + "[metadata] id={} modified={} summary={}", + metadata.session_id, + metadata.modified_time, + metadata.summary.as_deref().unwrap_or(""), + ); + } + + session + .send_and_wait( + MessageOptions::new("Say hello in five words or fewer.") + .with_wait_timeout(Duration::from_secs(60)), + ) + .await?; + + session.destroy().await?; + + // Synchronous shutdown โ€” useful in panicking-cleanup paths or tests + // where you don't have an async runtime available to await `stop()`. + // For graceful shutdown in normal flow, prefer `client.stop().await`. + client.force_stop(); + println!("[client] state after force_stop: {:?}", client.state()); + + // Stopping the client closes the broadcast senders, so the consumer + // tasks observe `RecvError::Closed` and exit cleanly. + let _ = lifecycle_task.await; + let _ = session_task.await; + + println!( + "\n[summary] session_events={} sessions_deleted={}", + session_events.load(Ordering::Relaxed), + deleted.load(Ordering::Relaxed), + ); + + Ok(()) +} diff --git a/rust/examples/session_fs.rs b/rust/examples/session_fs.rs new file mode 100644 index 000000000..0dbbb3414 --- /dev/null +++ b/rust/examples/session_fs.rs @@ -0,0 +1,139 @@ +//! Custom `SessionFsProvider` backed by an in-memory map. +//! +//! Demonstrates registering a [`SessionFsProvider`] so the CLI delegates all +//! per-session filesystem operations to your code. Useful for sandboxed +//! sessions, projecting files into virtual storage, or applying permission +//! policies before bytes are read or written. +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example session_fs +//! ``` + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::session_fs::{ + DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConfig, SessionFsConventions, + SessionFsProvider, +}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig}; +use github_copilot_sdk::{Client, ClientOptions}; +use parking_lot::Mutex; + +struct InMemoryProvider { + files: Mutex>, +} + +impl InMemoryProvider { + fn new() -> Self { + let mut seed = HashMap::new(); + seed.insert( + "/workspace/README.md".to_string(), + "# Demo project\n\nThis file lives in memory.\n".to_string(), + ); + Self { + files: Mutex::new(seed), + } + } +} + +#[async_trait] +impl SessionFsProvider for InMemoryProvider { + async fn read_file(&self, path: &str) -> Result { + self.files + .lock() + .get(path) + .cloned() + .ok_or_else(|| FsError::NotFound(path.to_string())) + } + + async fn write_file( + &self, + path: &str, + content: &str, + _mode: Option, + ) -> Result<(), FsError> { + self.files + .lock() + .insert(path.to_string(), content.to_string()); + Ok(()) + } + + async fn exists(&self, path: &str) -> Result { + Ok(self.files.lock().contains_key(path)) + } + + async fn stat(&self, path: &str) -> Result { + let files = self.files.lock(); + let content = files + .get(path) + .ok_or_else(|| FsError::NotFound(path.to_string()))?; + Ok(FileInfo::new( + true, + false, + content.len() as i64, + "2025-01-01T00:00:00Z", + "2025-01-01T00:00:00Z", + )) + } + + async fn readdir_with_types(&self, path: &str) -> Result, FsError> { + let prefix = if path.ends_with('/') { + path.to_string() + } else { + format!("{path}/") + }; + let names: Vec = self + .files + .lock() + .keys() + .filter_map(|k| k.strip_prefix(&prefix)) + .filter(|rest| !rest.is_empty()) + .map(|rest| { + let name = rest.split('/').next().unwrap_or(rest); + DirEntry::new(name, DirEntryKind::File) + }) + .collect(); + Ok(names) + } + + async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> { + if self.files.lock().remove(path).is_none() && !force { + return Err(FsError::NotFound(path.to_string())); + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let provider: Arc = Arc::new(InMemoryProvider::new()); + + let options = { + let mut opts = ClientOptions::default(); + opts.session_fs = Some(SessionFsConfig::new( + "/workspace", + "/workspace/.copilot", + SessionFsConventions::Posix, + )); + opts + }; + + let client = Client::start(options).await?; + let session = client + .create_session( + SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .with_session_fs_provider(provider), + ) + .await?; + + let response = session + .send(MessageOptions::new("Summarize README.md.")) + .await?; + println!("Assistant: {response}"); + + Ok(()) +} diff --git a/rust/examples/tool_server.rs b/rust/examples/tool_server.rs new file mode 100644 index 000000000..55bacbbe6 --- /dev/null +++ b/rust/examples/tool_server.rs @@ -0,0 +1,187 @@ +//! Define custom tools and expose them to the Copilot agent. +//! +//! Registers two tools โ€” `get_weather` (typed params via schemars) and +//! `roll_dice` (manual schema) โ€” then asks the agent a question that +//! triggers tool use. +//! +//! Requires the `derive` feature for typed parameter schemas: +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example tool_server --features derive +//! ``` + +// Gate the entire example behind the `derive` feature so it compiles +// (as a stub that prints the required feature flag) when clippy/check +// runs without the feature. +#[cfg(not(feature = "derive"))] +fn main() { + eprintln!("This example requires the `derive` feature:"); + eprintln!(" cargo run -p github-copilot-sdk --example tool_server --features derive"); + std::process::exit(1); +} + +#[cfg(feature = "derive")] +use std::sync::Arc; +#[cfg(feature = "derive")] +use std::time::Duration; + +#[cfg(feature = "derive")] +use async_trait::async_trait; +#[cfg(feature = "derive")] +use github_copilot_sdk::handler::ApproveAllHandler; +#[cfg(feature = "derive")] +use github_copilot_sdk::tool::{ + JsonSchema, ToolHandler, ToolHandlerRouter, schema_for, tool_parameters, +}; +#[cfg(feature = "derive")] +use github_copilot_sdk::types::{MessageOptions, SessionConfig, Tool, ToolInvocation, ToolResult}; +#[cfg(feature = "derive")] +use github_copilot_sdk::{Client, ClientOptions, Error}; +#[cfg(feature = "derive")] +use serde::Deserialize; + +// --------------------------------------------------------------------------- +// Tool 1: get_weather โ€” typed parameters derived from a Rust struct +// --------------------------------------------------------------------------- + +#[cfg(feature = "derive")] +#[derive(Deserialize, JsonSchema)] +struct GetWeatherParams { + /// City name (e.g. "Seattle"). + city: String, + /// Temperature unit: "celsius" or "fahrenheit". + unit: Option, +} + +#[cfg(feature = "derive")] +struct GetWeatherTool; + +#[cfg(feature = "derive")] +#[async_trait] +impl ToolHandler for GetWeatherTool { + fn tool(&self) -> Tool { + let mut tool = Tool::default(); + tool.name = "get_weather".to_string(); + tool.description = "Get the current weather for a city.".to_string(); + tool.parameters = tool_parameters(schema_for::()); + tool + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let params: GetWeatherParams = serde_json::from_value(invocation.arguments)?; + let unit = params.unit.as_deref().unwrap_or("celsius"); + // Stub response โ€” a real implementation would call a weather API. + let reply = format!( + "Weather in {}: 18ยฐ{}, partly cloudy", + params.city, + if unit == "fahrenheit" { "F" } else { "C" }, + ); + Ok(ToolResult::Text(reply)) + } +} + +// --------------------------------------------------------------------------- +// Tool 2: roll_dice โ€” manual JSON Schema +// --------------------------------------------------------------------------- + +#[cfg(feature = "derive")] +struct RollDiceTool; + +#[cfg(feature = "derive")] +#[async_trait] +impl ToolHandler for RollDiceTool { + fn tool(&self) -> Tool { + let mut tool = Tool::default(); + tool.name = "roll_dice".to_string(); + tool.description = "Roll one or more dice and return the total.".to_string(); + tool.parameters = tool_parameters(serde_json::json!({ + "type": "object", + "properties": { + "sides": { "type": "integer", "description": "Number of sides per die (default 6, max 1000)." }, + "count": { "type": "integer", "description": "Number of dice to roll (default 1, max 100)." } + } + })); + tool + } + + async fn call(&self, invocation: ToolInvocation) -> Result { + let sides = invocation + .arguments + .get("sides") + .and_then(|v| v.as_u64()) + .unwrap_or(6) + .clamp(1, 1000) as u32; + let count = invocation + .arguments + .get("count") + .and_then(|v| v.as_u64()) + .unwrap_or(1) + .clamp(1, 100) as u32; + + let mut total = 0u32; + let mut rolls = Vec::with_capacity(count as usize); + for _ in 0..count { + // Simple deterministic "random" for the example. + let roll = (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .subsec_nanos() + % sides) + + 1; + rolls.push(roll); + total += roll; + } + + Ok(ToolResult::Text(format!( + "Rolled {count}d{sides}: {rolls:?} = {total}" + ))) + } +} + +// --------------------------------------------------------------------------- +// Main +// --------------------------------------------------------------------------- + +#[cfg(feature = "derive")] +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let router = ToolHandlerRouter::new( + vec![Box::new(GetWeatherTool), Box::new(RollDiceTool)], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let handler = Arc::new(router); + + let client = Client::start(ClientOptions::default()).await?; + + let config = { + let mut cfg = SessionConfig::default(); + cfg.tools = Some(tools); + cfg.with_handler(handler) + }; + let session = client.create_session(config).await?; + + println!( + "Session {} โ€” asking about weather + dice...\n", + session.id() + ); + + let response = session + .send_and_wait( + MessageOptions::new("What's the weather in Seattle? Also roll 3d20 for me.") + .with_wait_timeout(Duration::from_secs(60)), + ) + .await?; + + if let Some(event) = response { + let text = event + .data + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + println!("{text}"); + } + + session.destroy().await?; + Ok(()) +} diff --git a/rust/release-plz.toml b/rust/release-plz.toml new file mode 100644 index 000000000..82c38ffd7 --- /dev/null +++ b/rust/release-plz.toml @@ -0,0 +1,35 @@ +[workspace] +# release-plz config for the Rust github-copilot-sdk crate. +# +# The crate lives in the `rust/` subdirectory of the monorepo, so +# invoke release-plz from this directory (via the release-plz workflows +# under `.github/workflows/`). release-plz will: +# +# 1. `release-plz release-pr`: open a PR updating `rust/Cargo.toml`'s +# version and `rust/CHANGELOG.md` based on conventional-commit +# history on `tclem/rust-sdk-release-prep`-style branches. +# 2. `release-plz release`: after that PR is merged to main, publish +# the tagged version to crates.io and create a `rust-vX.Y.Z` git +# tag. +# +# Publishing requires a `CARGO_REGISTRY_TOKEN` repository secret scoped +# to the `github-copilot-sdk` crate owner account. See +# `.github/workflows/rust-publish-release.yml` for the setup checklist. +# +# Reference: https://release-plz.dev/docs/config +changelog_update = true +dependencies_update = false +git_release_enable = true +# Prefix crate git tags so they don't collide with the monorepo's +# top-level `vX.Y.Z` tags used by the other SDKs. +git_tag_name = "rust-v{{ version }}" +git_release_name = "rust-v{{ version }}" + +[[package]] +name = "github-copilot-sdk" +changelog_path = "CHANGELOG.md" +# Mark pre-1.0 publishes as prereleases on the GitHub release page so +# consumers don't pick them up as "stable" by default. Maintainers +# should flip this (or remove it) when cutting 1.0. +git_release_type = "auto" + diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml new file mode 100644 index 000000000..2259b2c8a --- /dev/null +++ b/rust/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.94.0" +components = ["clippy", "rust-analyzer", "rustfmt"] +profile = "default" diff --git a/rust/src/embeddedcli.rs b/rust/src/embeddedcli.rs new file mode 100644 index 000000000..d0e5ea9ff --- /dev/null +++ b/rust/src/embeddedcli.rs @@ -0,0 +1,278 @@ +#[cfg(any(has_bundled_cli, test))] +use std::fs; +#[cfg(any(has_bundled_cli, test))] +use std::io::{self, Read, Write}; +#[cfg(any(has_bundled_cli, test))] +use std::path::Path; +use std::path::PathBuf; +use std::sync::OnceLock; + +#[cfg(has_bundled_cli)] +use tracing::{info, warn}; + +// When the SDK is built with COPILOT_CLI_VERSION set, build.rs generates +// bundled_cli.rs with the compressed binary bytes, hash, and version. +#[cfg(has_bundled_cli)] +mod build_time { + include!(concat!(env!("OUT_DIR"), "/bundled_cli.rs")); +} + +static INSTALLED_PATH: OnceLock> = OnceLock::new(); + +/// Returns the bundled CLI version string, if one was embedded at build time. +pub fn bundled_version() -> Option<&'static str> { + #[cfg(has_bundled_cli)] + { + Some(build_time::CLI_VERSION) + } + #[cfg(not(has_bundled_cli))] + { + None + } +} + +/// Returns the path to the installed CLI binary, lazily extracting on first call. +/// +/// When the SDK was built with `COPILOT_CLI_VERSION` set, this extracts the +/// embedded binary to `~/.cache/github-copilot-sdk-{version}/copilot` (or +/// `copilot.exe` on Windows), verifies the SHA-256 hash, and returns the +/// path. Subsequent calls return the cached result. +/// +/// Returns `None` if no CLI was embedded at build time. +pub fn path() -> Option { + INSTALLED_PATH + .get_or_init(|| { + #[cfg(has_bundled_cli)] + { + match install( + build_time::CLI_BYTES, + build_time::CLI_HASH, + build_time::CLI_VERSION, + ) { + Ok(path) => { + info!(path = %path.display(), version = build_time::CLI_VERSION, "embedded CLI installed"); + return Some(path); + } + Err(e) => { + warn!(error = %e, "embedded CLI installation failed"); + } + } + } + None + }) + .clone() +} + +#[cfg(has_bundled_cli)] +fn install( + compressed: &[u8], + expected_hash: [u8; 32], + version: &str, +) -> Result { + let verbose = std::env::var("COPILOT_CLI_INSTALL_VERBOSE").ok().as_deref() == Some("1"); + + let cache = dirs::cache_dir().unwrap_or_else(std::env::temp_dir); + // Use a versioned directory so multiple versions can coexist, + // but keep the binary named `copilot` โ€” the CLI checks argv[0] + // for this exact name. + let install_dir = if version.is_empty() { + cache.join("github-copilot-sdk") + } else { + cache.join(format!("github-copilot-sdk-{}", sanitize_version(version))) + }; + fs::create_dir_all(&install_dir).map_err(EmbeddedCliError::CreateDir)?; + + let binary_name = binary_name(); + let final_path = install_dir.join(&binary_name); + + // If the binary already exists and hash matches, skip extraction. + if final_path.is_file() { + let existing_hash = hash_file(&final_path)?; + if existing_hash == expected_hash { + if verbose { + eprintln!("embedded CLI already installed at {}", final_path.display()); + } + return Ok(final_path); + } + if verbose { + eprintln!("embedded CLI hash mismatch, reinstalling"); + } + } + + let start = std::time::Instant::now(); + let decompressed = decompress(compressed)?; + + let actual_hash = sha256(&decompressed); + if actual_hash != expected_hash { + return Err(EmbeddedCliError::HashMismatch); + } + + write_binary(&final_path, &decompressed)?; + + if verbose { + eprintln!( + "embedded CLI installed at {} in {:?}", + final_path.display(), + start.elapsed() + ); + } + + Ok(final_path) +} + +#[cfg(any(has_bundled_cli, test))] +fn binary_name() -> String { + if cfg!(target_os = "windows") { + "copilot.exe".to_string() + } else { + "copilot".to_string() + } +} + +#[cfg(has_bundled_cli)] +fn sanitize_version(version: &str) -> String { + version + .chars() + .map(|c| match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '.' | '-' | '_' => c, + _ => '_', + }) + .collect() +} + +#[cfg(any(has_bundled_cli, test))] +fn decompress(data: &[u8]) -> Result, EmbeddedCliError> { + let mut decoder = zstd::Decoder::new(data).map_err(EmbeddedCliError::Decompress)?; + let mut out = Vec::new(); + decoder + .read_to_end(&mut out) + .map_err(EmbeddedCliError::Decompress)?; + Ok(out) +} + +#[cfg(any(has_bundled_cli, test))] +fn sha256(data: &[u8]) -> [u8; 32] { + use sha2::Digest; + let mut hasher = sha2::Sha256::new(); + hasher.update(data); + hasher.finalize().into() +} + +#[cfg(has_bundled_cli)] +fn hash_file(path: &Path) -> Result<[u8; 32], EmbeddedCliError> { + use sha2::Digest; + let mut file = fs::File::open(path).map_err(EmbeddedCliError::Io)?; + let mut hasher = sha2::Sha256::new(); + let mut buf = [0u8; 8192]; + loop { + let n = file.read(&mut buf).map_err(EmbeddedCliError::Io)?; + if n == 0 { + break; + } + hasher.update(&buf[..n]); + } + Ok(hasher.finalize().into()) +} + +#[cfg(any(has_bundled_cli, test))] +fn write_binary(path: &Path, data: &[u8]) -> Result<(), EmbeddedCliError> { + let mut file = fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .map_err(EmbeddedCliError::Io)?; + + file.write_all(data).map_err(EmbeddedCliError::Io)?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + fs::set_permissions(path, fs::Permissions::from_mode(0o755)) + .map_err(EmbeddedCliError::Io)?; + } + + Ok(()) +} + +#[cfg(any(has_bundled_cli, test))] +#[derive(Debug, thiserror::Error)] +#[allow(dead_code)] +enum EmbeddedCliError { + #[error("failed to create install directory: {0}")] + CreateDir(io::Error), + + #[error("decompression failed: {0}")] + Decompress(io::Error), + + #[error("SHA-256 hash of decompressed binary does not match expected hash")] + HashMismatch, + + #[error("I/O error: {0}")] + Io(io::Error), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn install_extracts_to_cache_dir() { + let temp = tempfile::tempdir().expect("should create temp dir"); + let original = b"fake copilot binary"; + let hash = sha256(original); + let compressed = zstd::encode_all(&original[..], 3).expect("compression should succeed"); + + // Override cache dir via env for test isolation. + let path = install_to_dir(&temp, &compressed, hash); + let expected_name = binary_name(); + assert!(path.is_file()); + assert_eq!( + path.file_name().and_then(|s| s.to_str()), + Some(expected_name.as_str()) + ); + + let installed_content = fs::read(&path).expect("should read installed binary"); + assert_eq!(installed_content, original); + + // Second install should be idempotent (hash matches, skips extraction). + let path2 = install_to_dir(&temp, &compressed, hash); + assert_eq!(path, path2); + } + + #[test] + fn install_rejects_hash_mismatch() { + let temp = tempfile::tempdir().expect("should create temp dir"); + let original = b"fake copilot binary"; + let wrong_hash = [0u8; 32]; + let compressed = zstd::encode_all(&original[..], 3).expect("compression should succeed"); + + let result = install_to_dir_result(&temp, &compressed, wrong_hash); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("SHA-256"),); + } + + // Test helpers that install to a specific directory instead of the global cache. + fn install_to_dir(temp: &tempfile::TempDir, compressed: &[u8], hash: [u8; 32]) -> PathBuf { + install_to_dir_result(temp, compressed, hash).expect("install should succeed") + } + + fn install_to_dir_result( + temp: &tempfile::TempDir, + compressed: &[u8], + hash: [u8; 32], + ) -> Result { + let install_dir = temp.path().to_path_buf(); + fs::create_dir_all(&install_dir).expect("create dir"); + let binary_name = binary_name(); + let final_path = install_dir.join(&binary_name); + + let decompressed = decompress(compressed)?; + let actual_hash = sha256(&decompressed); + if actual_hash != hash { + return Err(EmbeddedCliError::HashMismatch); + } + write_binary(&final_path, &decompressed)?; + Ok(final_path) + } +} diff --git a/rust/src/generated/api_types.rs b/rust/src/generated/api_types.rs new file mode 100644 index 000000000..2b78fbae5 --- /dev/null +++ b/rust/src/generated/api_types.rs @@ -0,0 +1,3460 @@ +//! Auto-generated from api.schema.json โ€” do not edit manually. + +#![allow(clippy::large_enum_variant)] + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::types::{RequestId, SessionId}; + +/// JSON-RPC method name constants. +pub mod rpc_methods { + /// `ping` + pub const PING: &str = "ping"; + /// `connect` + pub const CONNECT: &str = "connect"; + /// `models.list` + pub const MODELS_LIST: &str = "models.list"; + /// `tools.list` + pub const TOOLS_LIST: &str = "tools.list"; + /// `account.getQuota` + pub const ACCOUNT_GETQUOTA: &str = "account.getQuota"; + /// `mcp.config.list` + pub const MCP_CONFIG_LIST: &str = "mcp.config.list"; + /// `mcp.config.add` + pub const MCP_CONFIG_ADD: &str = "mcp.config.add"; + /// `mcp.config.update` + pub const MCP_CONFIG_UPDATE: &str = "mcp.config.update"; + /// `mcp.config.remove` + pub const MCP_CONFIG_REMOVE: &str = "mcp.config.remove"; + /// `mcp.config.enable` + pub const MCP_CONFIG_ENABLE: &str = "mcp.config.enable"; + /// `mcp.config.disable` + pub const MCP_CONFIG_DISABLE: &str = "mcp.config.disable"; + /// `mcp.discover` + pub const MCP_DISCOVER: &str = "mcp.discover"; + /// `skills.config.setDisabledSkills` + pub const SKILLS_CONFIG_SETDISABLEDSKILLS: &str = "skills.config.setDisabledSkills"; + /// `skills.discover` + pub const SKILLS_DISCOVER: &str = "skills.discover"; + /// `sessionFs.setProvider` + pub const SESSIONFS_SETPROVIDER: &str = "sessionFs.setProvider"; + /// `sessions.fork` + pub const SESSIONS_FORK: &str = "sessions.fork"; + /// `session.suspend` + pub const SESSION_SUSPEND: &str = "session.suspend"; + /// `session.auth.getStatus` + pub const SESSION_AUTH_GETSTATUS: &str = "session.auth.getStatus"; + /// `session.model.getCurrent` + pub const SESSION_MODEL_GETCURRENT: &str = "session.model.getCurrent"; + /// `session.model.switchTo` + pub const SESSION_MODEL_SWITCHTO: &str = "session.model.switchTo"; + /// `session.mode.get` + pub const SESSION_MODE_GET: &str = "session.mode.get"; + /// `session.mode.set` + pub const SESSION_MODE_SET: &str = "session.mode.set"; + /// `session.name.get` + pub const SESSION_NAME_GET: &str = "session.name.get"; + /// `session.name.set` + pub const SESSION_NAME_SET: &str = "session.name.set"; + /// `session.plan.read` + pub const SESSION_PLAN_READ: &str = "session.plan.read"; + /// `session.plan.update` + pub const SESSION_PLAN_UPDATE: &str = "session.plan.update"; + /// `session.plan.delete` + pub const SESSION_PLAN_DELETE: &str = "session.plan.delete"; + /// `session.workspaces.getWorkspace` + pub const SESSION_WORKSPACES_GETWORKSPACE: &str = "session.workspaces.getWorkspace"; + /// `session.workspaces.listFiles` + pub const SESSION_WORKSPACES_LISTFILES: &str = "session.workspaces.listFiles"; + /// `session.workspaces.readFile` + pub const SESSION_WORKSPACES_READFILE: &str = "session.workspaces.readFile"; + /// `session.workspaces.createFile` + pub const SESSION_WORKSPACES_CREATEFILE: &str = "session.workspaces.createFile"; + /// `session.instructions.getSources` + pub const SESSION_INSTRUCTIONS_GETSOURCES: &str = "session.instructions.getSources"; + /// `session.fleet.start` + pub const SESSION_FLEET_START: &str = "session.fleet.start"; + /// `session.agent.list` + pub const SESSION_AGENT_LIST: &str = "session.agent.list"; + /// `session.agent.getCurrent` + pub const SESSION_AGENT_GETCURRENT: &str = "session.agent.getCurrent"; + /// `session.agent.select` + pub const SESSION_AGENT_SELECT: &str = "session.agent.select"; + /// `session.agent.deselect` + pub const SESSION_AGENT_DESELECT: &str = "session.agent.deselect"; + /// `session.agent.reload` + pub const SESSION_AGENT_RELOAD: &str = "session.agent.reload"; + /// `session.tasks.startAgent` + pub const SESSION_TASKS_STARTAGENT: &str = "session.tasks.startAgent"; + /// `session.tasks.list` + pub const SESSION_TASKS_LIST: &str = "session.tasks.list"; + /// `session.tasks.promoteToBackground` + pub const SESSION_TASKS_PROMOTETOBACKGROUND: &str = "session.tasks.promoteToBackground"; + /// `session.tasks.cancel` + pub const SESSION_TASKS_CANCEL: &str = "session.tasks.cancel"; + /// `session.tasks.remove` + pub const SESSION_TASKS_REMOVE: &str = "session.tasks.remove"; + /// `session.skills.list` + pub const SESSION_SKILLS_LIST: &str = "session.skills.list"; + /// `session.skills.enable` + pub const SESSION_SKILLS_ENABLE: &str = "session.skills.enable"; + /// `session.skills.disable` + pub const SESSION_SKILLS_DISABLE: &str = "session.skills.disable"; + /// `session.skills.reload` + pub const SESSION_SKILLS_RELOAD: &str = "session.skills.reload"; + /// `session.mcp.list` + pub const SESSION_MCP_LIST: &str = "session.mcp.list"; + /// `session.mcp.enable` + pub const SESSION_MCP_ENABLE: &str = "session.mcp.enable"; + /// `session.mcp.disable` + pub const SESSION_MCP_DISABLE: &str = "session.mcp.disable"; + /// `session.mcp.reload` + pub const SESSION_MCP_RELOAD: &str = "session.mcp.reload"; + /// `session.mcp.oauth.login` + pub const SESSION_MCP_OAUTH_LOGIN: &str = "session.mcp.oauth.login"; + /// `session.plugins.list` + pub const SESSION_PLUGINS_LIST: &str = "session.plugins.list"; + /// `session.extensions.list` + pub const SESSION_EXTENSIONS_LIST: &str = "session.extensions.list"; + /// `session.extensions.enable` + pub const SESSION_EXTENSIONS_ENABLE: &str = "session.extensions.enable"; + /// `session.extensions.disable` + pub const SESSION_EXTENSIONS_DISABLE: &str = "session.extensions.disable"; + /// `session.extensions.reload` + pub const SESSION_EXTENSIONS_RELOAD: &str = "session.extensions.reload"; + /// `session.tools.handlePendingToolCall` + pub const SESSION_TOOLS_HANDLEPENDINGTOOLCALL: &str = "session.tools.handlePendingToolCall"; + /// `session.commands.handlePendingCommand` + pub const SESSION_COMMANDS_HANDLEPENDINGCOMMAND: &str = "session.commands.handlePendingCommand"; + /// `session.ui.elicitation` + pub const SESSION_UI_ELICITATION: &str = "session.ui.elicitation"; + /// `session.ui.handlePendingElicitation` + pub const SESSION_UI_HANDLEPENDINGELICITATION: &str = "session.ui.handlePendingElicitation"; + /// `session.permissions.handlePendingPermissionRequest` + pub const SESSION_PERMISSIONS_HANDLEPENDINGPERMISSIONREQUEST: &str = + "session.permissions.handlePendingPermissionRequest"; + /// `session.permissions.setApproveAll` + pub const SESSION_PERMISSIONS_SETAPPROVEALL: &str = "session.permissions.setApproveAll"; + /// `session.permissions.resetSessionApprovals` + pub const SESSION_PERMISSIONS_RESETSESSIONAPPROVALS: &str = + "session.permissions.resetSessionApprovals"; + /// `session.log` + pub const SESSION_LOG: &str = "session.log"; + /// `session.shell.exec` + pub const SESSION_SHELL_EXEC: &str = "session.shell.exec"; + /// `session.shell.kill` + pub const SESSION_SHELL_KILL: &str = "session.shell.kill"; + /// `session.history.compact` + pub const SESSION_HISTORY_COMPACT: &str = "session.history.compact"; + /// `session.history.truncate` + pub const SESSION_HISTORY_TRUNCATE: &str = "session.history.truncate"; + /// `session.usage.getMetrics` + pub const SESSION_USAGE_GETMETRICS: &str = "session.usage.getMetrics"; + /// `sessionFs.readFile` + pub const SESSIONFS_READFILE: &str = "sessionFs.readFile"; + /// `sessionFs.writeFile` + pub const SESSIONFS_WRITEFILE: &str = "sessionFs.writeFile"; + /// `sessionFs.appendFile` + pub const SESSIONFS_APPENDFILE: &str = "sessionFs.appendFile"; + /// `sessionFs.exists` + pub const SESSIONFS_EXISTS: &str = "sessionFs.exists"; + /// `sessionFs.stat` + pub const SESSIONFS_STAT: &str = "sessionFs.stat"; + /// `sessionFs.mkdir` + pub const SESSIONFS_MKDIR: &str = "sessionFs.mkdir"; + /// `sessionFs.readdir` + pub const SESSIONFS_READDIR: &str = "sessionFs.readdir"; + /// `sessionFs.readdirWithTypes` + pub const SESSIONFS_READDIRWITHTYPES: &str = "sessionFs.readdirWithTypes"; + /// `sessionFs.rm` + pub const SESSIONFS_RM: &str = "sessionFs.rm"; + /// `sessionFs.rename` + pub const SESSIONFS_RENAME: &str = "sessionFs.rename"; +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AccountGetQuotaRequest { + /// GitHub token for per-user quota lookup. When provided, resolves this token to determine the user's quota instead of using the global auth. + #[serde(skip_serializing_if = "Option::is_none")] + pub git_hub_token: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AccountQuotaSnapshot { + /// Number of requests included in the entitlement + pub entitlement_requests: i64, + /// Whether the user has an unlimited usage entitlement + pub is_unlimited_entitlement: bool, + /// Number of overage requests made this period + pub overage: f64, + /// Whether overage is allowed when quota is exhausted + pub overage_allowed_with_exhausted_quota: bool, + /// Percentage of entitlement remaining + pub remaining_percentage: f64, + /// Date when the quota resets (ISO 8601 string) + #[serde(skip_serializing_if = "Option::is_none")] + pub reset_date: Option, + /// Whether usage is still permitted after quota exhaustion + pub usage_allowed_with_exhausted_quota: bool, + /// Number of requests used so far this period + pub used_requests: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AccountGetQuotaResult { + /// Quota snapshots keyed by type (e.g., chat, completions, premium_interactions) + pub quota_snapshots: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentInfo { + /// Description of the agent's purpose + pub description: String, + /// Human-readable display name + pub display_name: String, + /// Unique identifier of the custom agent + pub name: String, + /// Absolute local file path of the agent definition. Only set for file-based agents loaded from disk; remote agents do not have a path. + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentGetCurrentResult { + /// Currently selected custom agent, or null if using the default agent + pub agent: AgentInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentList { + /// Available custom agents + pub agents: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentReloadResult { + /// Reloaded custom agents + pub agents: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentSelectRequest { + /// Name of the custom agent to select + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentSelectResult { + /// The newly selected custom agent + pub agent: AgentInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CommandsHandlePendingCommandRequest { + /// Error message if the command handler failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Request ID from the command invocation event + pub request_id: RequestId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CommandsHandlePendingCommandResult { + /// Whether the command was handled successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ConnectRequest { + /// Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN + #[serde(skip_serializing_if = "Option::is_none")] + pub token: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ConnectResult { + /// Always true on success + pub ok: bool, + /// Server protocol version number + pub protocol_version: i64, + /// Server package version + pub version: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CurrentModel { + /// Currently active model identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DiscoveredMcpServer { + /// Whether the server is enabled (not in the disabled list) + pub enabled: bool, + /// Server name (config key) + pub name: String, + /// Configuration source + pub source: DiscoveredMcpServerSource, + /// Server transport type: stdio, http, sse, or memory (local configs are normalized to stdio) + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EmbeddedBlobResourceContents { + /// Base64-encoded binary content of the resource + pub blob: String, + /// MIME type of the blob content + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + /// URI identifying the resource + pub uri: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EmbeddedTextResourceContents { + /// MIME type of the text content + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + /// Text content of the resource + pub text: String, + /// URI identifying the resource + pub uri: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Extension { + /// Source-qualified ID (e.g., 'project:my-ext', 'user:auth-helper') + pub id: String, + /// Extension name (directory name) + pub name: String, + /// Process ID if the extension is running + #[serde(skip_serializing_if = "Option::is_none")] + pub pid: Option, + /// Discovery source: project (.github/extensions/) or user (~/.copilot/extensions/) + pub source: ExtensionSource, + /// Current status: running, disabled, failed, or starting + pub status: ExtensionStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExtensionList { + /// Discovered extensions and their current status + pub extensions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExtensionsDisableRequest { + /// Source-qualified extension ID to disable + pub id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExtensionsEnableRequest { + /// Source-qualified extension ID to enable + pub id: String, +} + +/// Expanded external tool result payload +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlm { + /// Structured content blocks from the tool + #[serde(default)] + pub contents: Vec, + /// Optional error message for failed executions + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Execution outcome classification. Optional for back-compat; normalized to 'success' (or 'failure' when error is present) when missing or unrecognized. + #[serde(skip_serializing_if = "Option::is_none")] + pub result_type: Option, + /// Detailed log content for timeline display + #[serde(skip_serializing_if = "Option::is_none")] + pub session_log: Option, + /// Text result returned to the model + pub text_result_for_llm: String, + /// Optional tool-specific telemetry + #[serde(default)] + pub tool_telemetry: HashMap, +} + +/// Audio content block with base64-encoded data +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlmContentAudio { + /// Base64-encoded audio data + pub data: String, + /// MIME type of the audio (e.g., audio/wav, audio/mpeg) + pub mime_type: String, + /// Content block type discriminator + pub r#type: ExternalToolTextResultForLlmContentAudioType, +} + +/// Image content block with base64-encoded data +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlmContentImage { + /// Base64-encoded image data + pub data: String, + /// MIME type of the image (e.g., image/png, image/jpeg) + pub mime_type: String, + /// Content block type discriminator + pub r#type: ExternalToolTextResultForLlmContentImageType, +} + +/// Embedded resource content block with inline text or binary data +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlmContentResource { + /// The embedded resource contents, either text or base64-encoded binary + pub resource: serde_json::Value, + /// Content block type discriminator + pub r#type: ExternalToolTextResultForLlmContentResourceType, +} + +/// Icon image for a resource +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlmContentResourceLinkIcon { + /// MIME type of the icon image + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + /// Available icon sizes (e.g., ['16x16', '32x32']) + #[serde(default)] + pub sizes: Vec, + /// URL or path to the icon image + pub src: String, + /// Theme variant this icon is intended for + #[serde(skip_serializing_if = "Option::is_none")] + pub theme: Option, +} + +/// Resource link content block referencing an external resource +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlmContentResourceLink { + /// Human-readable description of the resource + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Icons associated with this resource + #[serde(default)] + pub icons: Vec, + /// MIME type of the resource content + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + /// Resource name identifier + pub name: String, + /// Size of the resource in bytes + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, + /// Human-readable display title for the resource + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + /// Content block type discriminator + pub r#type: ExternalToolTextResultForLlmContentResourceLinkType, + /// URI identifying the resource + pub uri: String, +} + +/// Terminal/shell output content block with optional exit code and working directory +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlmContentTerminal { + /// Working directory where the command was executed + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + /// Process exit code, if the command has completed + #[serde(skip_serializing_if = "Option::is_none")] + pub exit_code: Option, + /// Terminal/shell output text + pub text: String, + /// Content block type discriminator + pub r#type: ExternalToolTextResultForLlmContentTerminalType, +} + +/// Plain text content block +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolTextResultForLlmContentText { + /// The text content + pub text: String, + /// Content block type discriminator + pub r#type: ExternalToolTextResultForLlmContentTextType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FleetStartRequest { + /// Optional user prompt to combine with fleet instructions + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FleetStartResult { + /// Whether fleet mode was successfully activated + pub started: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HandlePendingToolCallRequest { + /// Error message if the tool call failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Request ID of the pending tool call + pub request_id: RequestId, + /// Tool call result (string or expanded result object) + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HandlePendingToolCallResult { + /// Whether the tool call result was handled successfully + pub success: bool, +} + +/// Post-compaction context window usage breakdown +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HistoryCompactContextWindow { + /// Token count from non-system messages (user, assistant, tool) + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_tokens: Option, + /// Current total tokens in the context window (system + conversation + tool definitions) + pub current_tokens: i64, + /// Current number of messages in the conversation + pub messages_length: i64, + /// Token count from system message(s) + #[serde(skip_serializing_if = "Option::is_none")] + pub system_tokens: Option, + /// Maximum token count for the model's context window + pub token_limit: i64, + /// Token count from tool definitions + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_definitions_tokens: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HistoryCompactResult { + /// Post-compaction context window usage breakdown + #[serde(skip_serializing_if = "Option::is_none")] + pub context_window: Option, + /// Number of messages removed during compaction + pub messages_removed: i64, + /// Whether compaction completed successfully + pub success: bool, + /// Number of tokens freed by compaction + pub tokens_removed: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HistoryTruncateRequest { + /// Event ID to truncate to. This event and all events after it are removed from the session. + pub event_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HistoryTruncateResult { + /// Number of events that were removed + pub events_removed: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InstructionsSources { + /// Glob pattern from frontmatter โ€” when set, this instruction applies only to matching files + #[serde(skip_serializing_if = "Option::is_none")] + pub apply_to: Option, + /// Raw content of the instruction file + pub content: String, + /// Short description (body after frontmatter) for use in instruction tables + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Unique identifier for this source (used for toggling) + pub id: String, + /// Human-readable label + pub label: String, + /// Where this source lives โ€” used for UI grouping + pub location: InstructionsSourcesLocation, + /// File path relative to repo or absolute for home + pub source_path: String, + /// Category of instruction source โ€” used for merge logic + pub r#type: InstructionsSourcesType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InstructionsGetSourcesResult { + /// Instruction sources for the session + pub sources: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LogRequest { + /// When true, the message is transient and not persisted to the session event log on disk + #[serde(skip_serializing_if = "Option::is_none")] + pub ephemeral: Option, + /// Log severity level. Determines how the message is displayed in the timeline. Defaults to "info". + #[serde(skip_serializing_if = "Option::is_none")] + pub level: Option, + /// Human-readable message + pub message: String, + /// Optional URL the user can open in their browser for more details + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LogResult { + /// The unique identifier of the emitted session event + pub event_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfigAddRequest { + /// MCP server configuration (local/stdio or remote/http) + pub config: serde_json::Value, + /// Unique name for the MCP server + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfigDisableRequest { + /// Names of MCP servers to disable. Each server is added to the persisted disabled list so new sessions skip it. Already-disabled names are ignored. Active sessions keep their current connections until they end. + pub names: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfigEnableRequest { + /// Names of MCP servers to enable. Each server is removed from the persisted disabled list so new sessions spawn it. Unknown or already-enabled names are ignored. + pub names: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfigList { + /// All MCP servers from user config, keyed by name + pub servers: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfigRemoveRequest { + /// Name of the MCP server to remove + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfigUpdateRequest { + /// MCP server configuration (local/stdio or remote/http) + pub config: serde_json::Value, + /// Name of the MCP server to update + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpDisableRequest { + /// Name of the MCP server to disable + pub server_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpDiscoverRequest { + /// Working directory used as context for discovery (e.g., plugin resolution) + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpDiscoverResult { + /// MCP servers discovered from all sources + pub servers: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpEnableRequest { + /// Name of the MCP server to enable + pub server_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthLoginRequest { + /// Optional override for the body text shown on the OAuth loopback callback success page. When omitted, the runtime applies a neutral fallback; callers driving interactive auth should pass surface-specific copy telling the user where to return. + #[serde(skip_serializing_if = "Option::is_none")] + pub callback_success_message: Option, + /// Optional override for the OAuth client display name shown on the consent screen. Applies to newly registered dynamic clients only โ€” existing registrations keep the name they were created with. When omitted, the runtime applies a neutral fallback; callers driving interactive auth should pass their own surface-specific label so the consent screen matches the product the user sees. + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + /// When true, clears any cached OAuth token for the server and runs a full new authorization. Use when the user explicitly wants to switch accounts or believes their session is stuck. + #[serde(skip_serializing_if = "Option::is_none")] + pub force_reauth: Option, + /// Name of the remote MCP server to authenticate + pub server_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthLoginResult { + /// URL the caller should open in a browser to complete OAuth. Omitted when cached tokens were still valid and no browser interaction was needed โ€” the server is already reconnected in that case. When present, the runtime starts the callback listener before returning and continues the flow in the background; completion is signaled via session.mcp_server_status_changed. + #[serde(skip_serializing_if = "Option::is_none")] + pub authorization_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpServer { + /// Error message if the server failed to connect + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Server name (config key) + pub name: String, + /// Configuration source: user, workspace, plugin, or builtin + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + /// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured + pub status: McpServerStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfigHttp { + #[serde(skip_serializing_if = "Option::is_none")] + pub filter_mapping: Option, + #[serde(default)] + pub headers: HashMap, + #[serde(skip_serializing_if = "Option::is_none")] + pub is_default_server: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_client_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_grant_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_public_client: Option, + /// Timeout in milliseconds for tool calls to this server. + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, + /// Tools to include. Defaults to all tools if not specified. + #[serde(default)] + pub tools: Vec, + /// Remote transport type. Defaults to "http" when omitted. + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, + pub url: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfigLocal { + pub args: Vec, + pub command: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(default)] + pub env: HashMap, + #[serde(skip_serializing_if = "Option::is_none")] + pub filter_mapping: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub is_default_server: Option, + /// Timeout in milliseconds for tool calls to this server. + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, + /// Tools to include. Defaults to all tools if not specified. + #[serde(default)] + pub tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpServerList { + /// Configured MCP servers + pub servers: Vec, +} + +/// Billing information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelBilling { + /// Billing cost multiplier relative to the base rate + pub multiplier: f64, +} + +/// Vision-specific limits +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilitiesLimitsVision { + /// Maximum image size in bytes + #[serde(rename = "max_prompt_image_size")] + pub max_prompt_image_size: i64, + /// Maximum number of images per prompt + #[serde(rename = "max_prompt_images")] + pub max_prompt_images: i64, + /// MIME types the model accepts + #[serde(rename = "supported_media_types")] + pub supported_media_types: Vec, +} + +/// Token limits for prompts, outputs, and context window +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilitiesLimits { + /// Maximum total context window size in tokens + #[serde( + rename = "max_context_window_tokens", + skip_serializing_if = "Option::is_none" + )] + pub max_context_window_tokens: Option, + /// Maximum number of output/completion tokens + #[serde(rename = "max_output_tokens", skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + /// Maximum number of prompt/input tokens + #[serde(rename = "max_prompt_tokens", skip_serializing_if = "Option::is_none")] + pub max_prompt_tokens: Option, + /// Vision-specific limits + #[serde(skip_serializing_if = "Option::is_none")] + pub vision: Option, +} + +/// Feature flags indicating what the model supports +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilitiesSupports { + /// Whether this model supports reasoning effort configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + /// Whether this model supports vision/image input + #[serde(skip_serializing_if = "Option::is_none")] + pub vision: Option, +} + +/// Model capabilities and limits +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilities { + /// Token limits for prompts, outputs, and context window + #[serde(skip_serializing_if = "Option::is_none")] + pub limits: Option, + /// Feature flags indicating what the model supports + #[serde(skip_serializing_if = "Option::is_none")] + pub supports: Option, +} + +/// Policy state (if applicable) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelPolicy { + /// Current policy state for this model + pub state: String, + /// Usage terms or conditions for this model + #[serde(skip_serializing_if = "Option::is_none")] + pub terms: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Model { + /// Billing information + #[serde(skip_serializing_if = "Option::is_none")] + pub billing: Option, + /// Model capabilities and limits + pub capabilities: ModelCapabilities, + /// Default reasoning effort level (only present if model supports reasoning effort) + #[serde(skip_serializing_if = "Option::is_none")] + pub default_reasoning_effort: Option, + /// Model identifier (e.g., "claude-sonnet-4.5") + pub id: String, + /// Display name + pub name: String, + /// Policy state (if applicable) + #[serde(skip_serializing_if = "Option::is_none")] + pub policy: Option, + /// Supported reasoning effort levels (only present if model supports reasoning effort) + #[serde(default)] + pub supported_reasoning_efforts: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilitiesOverrideLimitsVision { + /// Maximum image size in bytes + #[serde( + rename = "max_prompt_image_size", + skip_serializing_if = "Option::is_none" + )] + pub max_prompt_image_size: Option, + /// Maximum number of images per prompt + #[serde(rename = "max_prompt_images", skip_serializing_if = "Option::is_none")] + pub max_prompt_images: Option, + /// MIME types the model accepts + #[serde(rename = "supported_media_types", default)] + pub supported_media_types: Vec, +} + +/// Token limits for prompts, outputs, and context window +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilitiesOverrideLimits { + /// Maximum total context window size in tokens + #[serde( + rename = "max_context_window_tokens", + skip_serializing_if = "Option::is_none" + )] + pub max_context_window_tokens: Option, + #[serde(rename = "max_output_tokens", skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(rename = "max_prompt_tokens", skip_serializing_if = "Option::is_none")] + pub max_prompt_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub vision: Option, +} + +/// Feature flags indicating what the model supports +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilitiesOverrideSupports { + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub vision: Option, +} + +/// Override individual model capabilities resolved by the runtime +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCapabilitiesOverride { + /// Token limits for prompts, outputs, and context window + #[serde(skip_serializing_if = "Option::is_none")] + pub limits: Option, + /// Feature flags indicating what the model supports + #[serde(skip_serializing_if = "Option::is_none")] + pub supports: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelList { + /// List of available models with full metadata + pub models: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelsListRequest { + /// GitHub token for per-user model listing. When provided, resolves this token to determine the user's Copilot plan and available models instead of using the global auth. + #[serde(skip_serializing_if = "Option::is_none")] + pub git_hub_token: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelSwitchToRequest { + /// Override individual model capabilities resolved by the runtime + #[serde(skip_serializing_if = "Option::is_none")] + pub model_capabilities: Option, + /// Model identifier to switch to + pub model_id: String, + /// Reasoning effort level to use for the model + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelSwitchToResult { + /// Currently active model identifier after the switch + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModeSetRequest { + /// The agent mode. Valid values: "interactive", "plan", "autopilot". + pub mode: SessionMode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NameGetResult { + /// The session name (user-set or auto-generated), or null if not yet set + pub name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NameSetRequest { + /// New session name (1โ€“100 characters, trimmed of leading/trailing whitespace) + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalCommands { + pub command_identifiers: Vec, + pub kind: PermissionDecisionApproveForLocationApprovalCommandsKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalRead { + pub kind: PermissionDecisionApproveForLocationApprovalReadKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalWrite { + pub kind: PermissionDecisionApproveForLocationApprovalWriteKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalMcp { + pub kind: PermissionDecisionApproveForLocationApprovalMcpKind, + pub server_name: String, + pub tool_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalMcpSampling { + pub kind: PermissionDecisionApproveForLocationApprovalMcpSamplingKind, + pub server_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalMemory { + pub kind: PermissionDecisionApproveForLocationApprovalMemoryKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalCustomTool { + pub kind: PermissionDecisionApproveForLocationApprovalCustomToolKind, + pub tool_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocation { + /// The approval to persist for this location + pub approval: PermissionDecisionApproveForLocationApproval, + /// Approved and persisted for this project location + pub kind: PermissionDecisionApproveForLocationKind, + /// The location key (git root or cwd) to persist the approval to + pub location_key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalCommands { + pub command_identifiers: Vec, + pub kind: PermissionDecisionApproveForSessionApprovalCommandsKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalRead { + pub kind: PermissionDecisionApproveForSessionApprovalReadKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalWrite { + pub kind: PermissionDecisionApproveForSessionApprovalWriteKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalMcp { + pub kind: PermissionDecisionApproveForSessionApprovalMcpKind, + pub server_name: String, + pub tool_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalMcpSampling { + pub kind: PermissionDecisionApproveForSessionApprovalMcpSamplingKind, + pub server_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalMemory { + pub kind: PermissionDecisionApproveForSessionApprovalMemoryKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalCustomTool { + pub kind: PermissionDecisionApproveForSessionApprovalCustomToolKind, + pub tool_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSession { + /// The approval to add as a session-scoped rule + #[serde(skip_serializing_if = "Option::is_none")] + pub approval: Option, + /// The URL domain to approve for this session + #[serde(skip_serializing_if = "Option::is_none")] + pub domain: Option, + /// Approved and remembered for the rest of the session + pub kind: PermissionDecisionApproveForSessionKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveOnce { + /// The permission request was approved for this one instance + pub kind: PermissionDecisionApproveOnceKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApprovePermanently { + /// The URL domain to approve permanently + pub domain: String, + /// Approved and persisted across sessions + pub kind: PermissionDecisionApprovePermanentlyKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionReject { + /// Optional feedback from the user explaining the denial + #[serde(skip_serializing_if = "Option::is_none")] + pub feedback: Option, + /// Denied by the user during an interactive prompt + pub kind: PermissionDecisionRejectKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionUserNotAvailable { + /// Denied because user confirmation was unavailable + pub kind: PermissionDecisionUserNotAvailableKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionRequest { + /// Request ID of the pending permission request + pub request_id: RequestId, + pub result: PermissionDecision, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestResult { + /// Whether the permission request was handled successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsResetSessionApprovalsRequest {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsResetSessionApprovalsResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsSetApproveAllRequest { + /// Whether to auto-approve all tool permission requests + pub enabled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsSetApproveAllResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PingRequest { + /// Optional message to echo back + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PingResult { + /// Echoed message (or default greeting) + pub message: String, + /// Server protocol version number + pub protocol_version: i64, + /// Server timestamp in milliseconds + pub timestamp: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PlanReadResult { + /// The content of the plan file, or null if it does not exist + pub content: Option, + /// Whether the plan file exists in the workspace + pub exists: bool, + /// Absolute file path of the plan file, or null if workspace is not enabled + pub path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PlanUpdateRequest { + /// The new content for the plan file + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Plugin { + /// Whether the plugin is currently enabled + pub enabled: bool, + /// Marketplace the plugin came from + pub marketplace: String, + /// Plugin name + pub name: String, + /// Installed version + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PluginList { + /// Installed plugins + pub plugins: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ServerSkill { + /// Description of what the skill does + pub description: String, + /// Whether the skill is currently enabled (based on global config) + pub enabled: bool, + /// Unique identifier for the skill + pub name: String, + /// Absolute path to the skill file + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// The project path this skill belongs to (only for project/inherited skills) + #[serde(skip_serializing_if = "Option::is_none")] + pub project_path: Option, + /// Source location type (e.g., project, personal-copilot, plugin, builtin) + pub source: String, + /// Whether the skill can be invoked by the user as a slash command + pub user_invocable: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ServerSkillList { + /// All discovered skills across all sources + pub skills: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAuthStatus { + /// Authentication type + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_type: Option, + /// Copilot plan tier (e.g., individual_pro, business) + #[serde(skip_serializing_if = "Option::is_none")] + pub copilot_plan: Option, + /// Authentication host URL + #[serde(skip_serializing_if = "Option::is_none")] + pub host: Option, + /// Whether the session has resolved authentication + pub is_authenticated: bool, + /// Authenticated login/username, if available + #[serde(skip_serializing_if = "Option::is_none")] + pub login: Option, + /// Human-readable authentication status description + #[serde(skip_serializing_if = "Option::is_none")] + pub status_message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsAppendFileRequest { + /// Content to append + pub content: String, + /// Optional POSIX-style mode for newly created files + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + /// Path using SessionFs conventions + pub path: String, +} + +/// Describes a filesystem error. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsError { + /// Error classification + pub code: SessionFsErrorCode, + /// Free-form detail about the error, for logging/diagnostics + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsExistsRequest { + /// Path using SessionFs conventions + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsExistsResult { + /// Whether the path exists + pub exists: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsMkdirRequest { + /// Optional POSIX-style mode for newly created directories + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + /// Path using SessionFs conventions + pub path: String, + /// Create parent directories as needed + #[serde(skip_serializing_if = "Option::is_none")] + pub recursive: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsReaddirRequest { + /// Path using SessionFs conventions + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsReaddirResult { + /// Entry names in the directory + pub entries: Vec, + /// Describes a filesystem error. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsReaddirWithTypesEntry { + /// Entry name + pub name: String, + /// Entry type + pub r#type: SessionFsReaddirWithTypesEntryType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsReaddirWithTypesRequest { + /// Path using SessionFs conventions + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsReaddirWithTypesResult { + /// Directory entries with type information + pub entries: Vec, + /// Describes a filesystem error. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsReadFileRequest { + /// Path using SessionFs conventions + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsReadFileResult { + /// File content as UTF-8 string + pub content: String, + /// Describes a filesystem error. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsRenameRequest { + /// Destination path using SessionFs conventions + pub dest: String, + /// Source path using SessionFs conventions + pub src: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsRmRequest { + /// Ignore errors if the path does not exist + #[serde(skip_serializing_if = "Option::is_none")] + pub force: Option, + /// Path using SessionFs conventions + pub path: String, + /// Remove directories and their contents recursively + #[serde(skip_serializing_if = "Option::is_none")] + pub recursive: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsSetProviderRequest { + /// Path conventions used by this filesystem + pub conventions: SessionFsSetProviderConventions, + /// Initial working directory for sessions + pub initial_cwd: String, + /// Path within each session's SessionFs where the runtime stores files for that session + pub session_state_path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsSetProviderResult { + /// Whether the provider was set successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsStatRequest { + /// Path using SessionFs conventions + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsStatResult { + /// ISO 8601 timestamp of creation + pub birthtime: String, + /// Describes a filesystem error. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Whether the path is a directory + pub is_directory: bool, + /// Whether the path is a file + pub is_file: bool, + /// ISO 8601 timestamp of last modification + pub mtime: String, + /// File size in bytes + pub size: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFsWriteFileRequest { + /// Content to write + pub content: String, + /// Optional POSIX-style mode for newly created files + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + /// Path using SessionFs conventions + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionsForkRequest { + /// Source session ID to fork from + pub session_id: SessionId, + /// Optional event ID boundary. When provided, the fork includes only events before this ID (exclusive). When omitted, all events are included. + #[serde(skip_serializing_if = "Option::is_none")] + pub to_event_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionsForkResult { + /// The new forked session's ID + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShellExecRequest { + /// Shell command to execute + pub command: String, + /// Working directory (defaults to session working directory) + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + /// Timeout in milliseconds (default: 30000) + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShellExecResult { + /// Unique identifier for tracking streamed output + pub process_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShellKillRequest { + /// Process identifier returned by shell.exec + pub process_id: String, + /// Signal to send (default: SIGTERM) + #[serde(skip_serializing_if = "Option::is_none")] + pub signal: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShellKillResult { + /// Whether the signal was sent successfully + pub killed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Skill { + /// Description of what the skill does + pub description: String, + /// Whether the skill is currently enabled + pub enabled: bool, + /// Unique identifier for the skill + pub name: String, + /// Absolute path to the skill file + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// Source location type (e.g., project, personal, plugin) + pub source: String, + /// Whether the skill can be invoked by the user as a slash command + pub user_invocable: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillList { + /// Available skills + pub skills: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillsConfigSetDisabledSkillsRequest { + /// List of skill names to disable + pub disabled_skills: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillsDisableRequest { + /// Name of the skill to disable + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillsDiscoverRequest { + /// Optional list of project directory paths to scan for project-scoped skills + #[serde(default)] + pub project_paths: Vec, + /// Optional list of additional skill directory paths to include + #[serde(default)] + pub skill_directories: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillsEnableRequest { + /// Name of the skill to enable + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskAgentInfo { + /// ISO 8601 timestamp when the current active period began + #[serde(skip_serializing_if = "Option::is_none")] + pub active_started_at: Option, + /// Accumulated active execution time in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub active_time_ms: Option, + /// Type of agent running this task + pub agent_type: String, + /// Whether the task is currently in the original sync wait and can be moved to background mode. False once it is already backgrounded, idle, finished, or no longer has a promotable sync waiter. + #[serde(skip_serializing_if = "Option::is_none")] + pub can_promote_to_background: Option, + /// ISO 8601 timestamp when the task finished + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, + /// Short description of the task + pub description: String, + /// Error message when the task failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// How the agent is currently being managed by the runtime + #[serde(skip_serializing_if = "Option::is_none")] + pub execution_mode: Option, + /// Unique task identifier + pub id: String, + /// ISO 8601 timestamp when the agent entered idle state + #[serde(skip_serializing_if = "Option::is_none")] + pub idle_since: Option, + /// Most recent response text from the agent + #[serde(skip_serializing_if = "Option::is_none")] + pub latest_response: Option, + /// Model used for the task when specified + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Prompt passed to the agent + pub prompt: String, + /// Result text from the task when available + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + /// ISO 8601 timestamp when the task was started + pub started_at: String, + /// Current lifecycle status of the task + pub status: TaskAgentInfoStatus, + /// Tool call ID associated with this agent task + pub tool_call_id: String, + /// Task kind + pub r#type: TaskAgentInfoType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskList { + /// Currently tracked tasks + pub tasks: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksCancelRequest { + /// Task identifier + pub id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksCancelResult { + /// Whether the task was successfully cancelled + pub cancelled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskShellInfo { + /// Whether the shell runs inside a managed PTY session or as an independent background process + pub attachment_mode: TaskShellInfoAttachmentMode, + /// Whether this shell task can be promoted to background mode + #[serde(skip_serializing_if = "Option::is_none")] + pub can_promote_to_background: Option, + /// Command being executed + pub command: String, + /// ISO 8601 timestamp when the task finished + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, + /// Short description of the task + pub description: String, + /// Whether the shell command is currently sync-waited or background-managed + #[serde(skip_serializing_if = "Option::is_none")] + pub execution_mode: Option, + /// Unique task identifier + pub id: String, + /// Path to the detached shell log, when available + #[serde(skip_serializing_if = "Option::is_none")] + pub log_path: Option, + /// Process ID when available + #[serde(skip_serializing_if = "Option::is_none")] + pub pid: Option, + /// ISO 8601 timestamp when the task was started + pub started_at: String, + /// Current lifecycle status of the task + pub status: TaskShellInfoStatus, + /// Task kind + pub r#type: TaskShellInfoType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksPromoteToBackgroundRequest { + /// Task identifier + pub id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksPromoteToBackgroundResult { + /// Whether the task was successfully promoted to background mode + pub promoted: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksRemoveRequest { + /// Task identifier + pub id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksRemoveResult { + /// Whether the task was removed. Returns false if the task does not exist or is still running/idle (cancel it first). + pub removed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksStartAgentRequest { + /// Type of agent to start (e.g., 'explore', 'task', 'general-purpose') + pub agent_type: String, + /// Short description of the task + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Optional model override + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Short name for the agent, used to generate a human-readable ID + pub name: String, + /// Task prompt for the agent + pub prompt: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TasksStartAgentResult { + /// Generated agent ID for the background task + pub agent_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Tool { + /// Description of what the tool does + pub description: String, + /// Optional instructions for how to use this tool effectively + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + /// Tool identifier (e.g., "bash", "grep", "str_replace_editor") + pub name: String, + /// Optional namespaced name for declarative filtering (e.g., "playwright/navigate" for MCP tools) + #[serde(skip_serializing_if = "Option::is_none")] + pub namespaced_name: Option, + /// JSON Schema for the tool's input parameters + #[serde(default)] + pub parameters: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolList { + /// List of available built-in tools with metadata + pub tools: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolsListRequest { + /// Optional model ID โ€” when provided, the returned tool list reflects model-specific overrides + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationArrayAnyOfFieldItemsAnyOf { + pub r#const: String, + pub title: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationArrayAnyOfFieldItems { + pub any_of: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationArrayAnyOfField { + #[serde(default)] + pub default: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub items: UIElicitationArrayAnyOfFieldItems, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_items: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_items: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: UIElicitationArrayAnyOfFieldType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationArrayEnumFieldItems { + pub r#enum: Vec, + pub r#type: UIElicitationArrayEnumFieldItemsType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationArrayEnumField { + #[serde(default)] + pub default: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub items: UIElicitationArrayEnumFieldItems, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_items: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_items: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: UIElicitationArrayEnumFieldType, +} + +/// JSON Schema describing the form fields to present to the user +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationSchema { + /// Form field definitions, keyed by field name + pub properties: HashMap, + /// List of required field names + #[serde(default)] + pub required: Vec, + /// Schema type indicator (always 'object') + pub r#type: UIElicitationSchemaType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationRequest { + /// Message describing what information is needed from the user + pub message: String, + /// JSON Schema describing the form fields to present to the user + pub requested_schema: UIElicitationSchema, +} + +/// The elicitation response (accept with form values, decline, or cancel) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationResponse { + /// The user's response: accept (submitted), decline (rejected), or cancel (dismissed) + pub action: UIElicitationResponseAction, + /// The form values submitted by the user (present when action is 'accept') + #[serde(default)] + pub content: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationResult { + /// Whether the response was accepted. False if the request was already resolved by another client. + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationSchemaPropertyBoolean { + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: UIElicitationSchemaPropertyBooleanType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationSchemaPropertyNumber { + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub maximum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub minimum: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: UIElicitationSchemaPropertyNumberType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationSchemaPropertyString { + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_length: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_length: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: UIElicitationSchemaPropertyStringType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationStringEnumField { + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub r#enum: Vec, + #[serde(default)] + pub enum_names: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: UIElicitationStringEnumFieldType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationStringOneOfFieldOneOf { + pub r#const: String, + pub title: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIElicitationStringOneOfField { + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub one_of: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + pub r#type: UIElicitationStringOneOfFieldType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UIHandlePendingElicitationRequest { + /// The unique request ID from the elicitation.requested event + pub request_id: RequestId, + /// The elicitation response (accept with form values, decline, or cancel) + pub result: UIElicitationResponse, +} + +/// Aggregated code change metrics +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetricsCodeChanges { + /// Number of distinct files modified + pub files_modified_count: i64, + /// Total lines of code added + pub lines_added: i64, + /// Total lines of code removed + pub lines_removed: i64, +} + +/// Request count and cost metrics for this model +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetricsModelMetricRequests { + /// User-initiated premium request cost (with multiplier applied) + pub cost: f64, + /// Number of API requests made with this model + pub count: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetricsModelMetricTokenDetail { + /// Accumulated token count for this token type + pub token_count: i64, +} + +/// Token usage metrics for this model +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetricsModelMetricUsage { + /// Total tokens read from prompt cache + pub cache_read_tokens: i64, + /// Total tokens written to prompt cache + pub cache_write_tokens: i64, + /// Total input tokens consumed + pub input_tokens: i64, + /// Total output tokens produced + pub output_tokens: i64, + /// Total output tokens used for reasoning + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetricsModelMetric { + /// Request count and cost metrics for this model + pub requests: UsageMetricsModelMetricRequests, + /// Token count details per type + #[serde(default)] + pub token_details: HashMap, + /// Accumulated nano-AI units cost for this model + #[serde(skip_serializing_if = "Option::is_none")] + pub total_nano_aiu: Option, + /// Token usage metrics for this model + pub usage: UsageMetricsModelMetricUsage, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetricsTokenDetail { + /// Accumulated token count for this token type + pub token_count: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageGetMetricsResult { + /// Aggregated code change metrics + pub code_changes: UsageMetricsCodeChanges, + /// Currently active model identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub current_model: Option, + /// Input tokens from the most recent main-agent API call + pub last_call_input_tokens: i64, + /// Output tokens from the most recent main-agent API call + pub last_call_output_tokens: i64, + /// Per-model token and request metrics, keyed by model identifier + pub model_metrics: HashMap, + /// Session start timestamp (epoch milliseconds) + pub session_start_time: i64, + /// Session-wide per-token-type accumulated token counts + #[serde(default)] + pub token_details: HashMap, + /// Total time spent in model API calls (milliseconds) + pub total_api_duration_ms: f64, + /// Session-wide accumulated nano-AI units cost + #[serde(skip_serializing_if = "Option::is_none")] + pub total_nano_aiu: Option, + /// Total user-initiated premium request cost across all models (may be fractional due to multipliers) + pub total_premium_request_cost: f64, + /// Raw count of user-initiated API requests + pub total_user_requests: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WorkspacesCreateFileRequest { + /// File content to write as a UTF-8 string + pub content: String, + /// Relative path within the workspace files directory + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WorkspacesGetWorkspaceResultWorkspace { + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + #[serde( + rename = "chronicle_sync_dismissed", + skip_serializing_if = "Option::is_none" + )] + pub chronicle_sync_dismissed: Option, + #[serde(rename = "created_at", skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(rename = "git_root", skip_serializing_if = "Option::is_none")] + pub git_root: Option, + #[serde(rename = "host_type", skip_serializing_if = "Option::is_none")] + pub host_type: Option, + pub id: String, + #[serde(rename = "mc_last_event_id", skip_serializing_if = "Option::is_none")] + pub mc_last_event_id: Option, + #[serde(rename = "mc_session_id", skip_serializing_if = "Option::is_none")] + pub mc_session_id: Option, + #[serde(rename = "mc_task_id", skip_serializing_if = "Option::is_none")] + pub mc_task_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(rename = "remote_steerable", skip_serializing_if = "Option::is_none")] + pub remote_steerable: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, + #[serde(rename = "session_sync_level", skip_serializing_if = "Option::is_none")] + pub session_sync_level: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, + #[serde(rename = "summary_count", skip_serializing_if = "Option::is_none")] + pub summary_count: Option, + #[serde(rename = "updated_at", skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + #[serde(rename = "user_named", skip_serializing_if = "Option::is_none")] + pub user_named: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WorkspacesGetWorkspaceResult { + /// Current workspace metadata, or null if not available + pub workspace: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WorkspacesListFilesResult { + /// Relative file paths in the workspace files directory + pub files: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WorkspacesReadFileRequest { + /// Relative path within the workspace files directory + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WorkspacesReadFileResult { + /// File content as a UTF-8 string + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelsListResult { + /// List of available models with full metadata + pub models: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolsListResult { + /// List of available built-in tools with metadata + pub tools: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpConfigListResult { + /// All MCP servers from user config, keyed by name + pub servers: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillsDiscoverResult { + /// All discovered skills across all sources + pub skills: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSuspendParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAuthGetStatusParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAuthGetStatusResult { + /// Authentication type + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_type: Option, + /// Copilot plan tier (e.g., individual_pro, business) + #[serde(skip_serializing_if = "Option::is_none")] + pub copilot_plan: Option, + /// Authentication host URL + #[serde(skip_serializing_if = "Option::is_none")] + pub host: Option, + /// Whether the session has resolved authentication + pub is_authenticated: bool, + /// Authenticated login/username, if available + #[serde(skip_serializing_if = "Option::is_none")] + pub login: Option, + /// Human-readable authentication status description + #[serde(skip_serializing_if = "Option::is_none")] + pub status_message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionModelGetCurrentParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionModelGetCurrentResult { + /// Currently active model identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionModelSwitchToResult { + /// Currently active model identifier after the switch + #[serde(skip_serializing_if = "Option::is_none")] + pub model_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionModeGetParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionNameGetParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionNameGetResult { + /// The session name (user-set or auto-generated), or null if not yet set + pub name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPlanReadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPlanReadResult { + /// The content of the plan file, or null if it does not exist + pub content: Option, + /// Whether the plan file exists in the workspace + pub exists: bool, + /// Absolute file path of the plan file, or null if workspace is not enabled + pub path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPlanDeleteParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesGetWorkspaceParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesGetWorkspaceResultWorkspace { + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + #[serde( + rename = "chronicle_sync_dismissed", + skip_serializing_if = "Option::is_none" + )] + pub chronicle_sync_dismissed: Option, + #[serde(rename = "created_at", skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(rename = "git_root", skip_serializing_if = "Option::is_none")] + pub git_root: Option, + #[serde(rename = "host_type", skip_serializing_if = "Option::is_none")] + pub host_type: Option, + pub id: String, + #[serde(rename = "mc_last_event_id", skip_serializing_if = "Option::is_none")] + pub mc_last_event_id: Option, + #[serde(rename = "mc_session_id", skip_serializing_if = "Option::is_none")] + pub mc_session_id: Option, + #[serde(rename = "mc_task_id", skip_serializing_if = "Option::is_none")] + pub mc_task_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(rename = "remote_steerable", skip_serializing_if = "Option::is_none")] + pub remote_steerable: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, + #[serde(rename = "session_sync_level", skip_serializing_if = "Option::is_none")] + pub session_sync_level: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, + #[serde(rename = "summary_count", skip_serializing_if = "Option::is_none")] + pub summary_count: Option, + #[serde(rename = "updated_at", skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + #[serde(rename = "user_named", skip_serializing_if = "Option::is_none")] + pub user_named: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesGetWorkspaceResult { + /// Current workspace metadata, or null if not available + pub workspace: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesListFilesParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesListFilesResult { + /// Relative file paths in the workspace files directory + pub files: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesReadFileResult { + /// File content as a UTF-8 string + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionInstructionsGetSourcesParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionInstructionsGetSourcesResult { + /// Instruction sources for the session + pub sources: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFleetStartResult { + /// Whether fleet mode was successfully activated + pub started: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentListResult { + /// Available custom agents + pub agents: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentGetCurrentParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentGetCurrentResult { + /// Currently selected custom agent, or null if using the default agent + pub agent: AgentInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentSelectResult { + /// The newly selected custom agent + pub agent: AgentInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentDeselectParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentReloadResult { + /// Reloaded custom agents + pub agents: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksStartAgentResult { + /// Generated agent ID for the background task + pub agent_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksListResult { + /// Currently tracked tasks + pub tasks: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksPromoteToBackgroundResult { + /// Whether the task was successfully promoted to background mode + pub promoted: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksCancelResult { + /// Whether the task was successfully cancelled + pub cancelled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksRemoveResult { + /// Whether the task was removed. Returns false if the task does not exist or is still running/idle (cancel it first). + pub removed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSkillsListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSkillsListResult { + /// Available skills + pub skills: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSkillsReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpListResult { + /// Configured MCP servers + pub servers: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpOauthLoginResult { + /// URL the caller should open in a browser to complete OAuth. Omitted when cached tokens were still valid and no browser interaction was needed โ€” the server is already reconnected in that case. When present, the runtime starts the callback listener before returning and continues the flow in the background; completion is signaled via session.mcp_server_status_changed. + #[serde(skip_serializing_if = "Option::is_none")] + pub authorization_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPluginsListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPluginsListResult { + /// Installed plugins + pub plugins: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsListResult { + /// Discovered extensions and their current status + pub extensions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionToolsHandlePendingToolCallResult { + /// Whether the tool call result was handled successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCommandsHandlePendingCommandResult { + /// Whether the command was handled successfully + pub success: bool, +} + +/// The elicitation response (accept with form values, decline, or cancel) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUiElicitationResult { + /// The user's response: accept (submitted), decline (rejected), or cancel (dismissed) + pub action: UIElicitationResponseAction, + /// The form values submitted by the user (present when action is 'accept') + #[serde(default)] + pub content: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUiHandlePendingElicitationResult { + /// Whether the response was accepted. False if the request was already resolved by another client. + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPermissionsHandlePendingPermissionRequestResult { + /// Whether the permission request was handled successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPermissionsSetApproveAllResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPermissionsResetSessionApprovalsResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionLogResult { + /// The unique identifier of the emitted session event + pub event_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionShellExecResult { + /// Unique identifier for tracking streamed output + pub process_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionShellKillResult { + /// Whether the signal was sent successfully + pub killed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionHistoryCompactParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionHistoryCompactResult { + /// Post-compaction context window usage breakdown + #[serde(skip_serializing_if = "Option::is_none")] + pub context_window: Option, + /// Number of messages removed during compaction + pub messages_removed: i64, + /// Whether compaction completed successfully + pub success: bool, + /// Number of tokens freed by compaction + pub tokens_removed: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionHistoryTruncateResult { + /// Number of events that were removed + pub events_removed: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUsageGetMetricsParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUsageGetMetricsResult { + /// Aggregated code change metrics + pub code_changes: UsageMetricsCodeChanges, + /// Currently active model identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub current_model: Option, + /// Input tokens from the most recent main-agent API call + pub last_call_input_tokens: i64, + /// Output tokens from the most recent main-agent API call + pub last_call_output_tokens: i64, + /// Per-model token and request metrics, keyed by model identifier + pub model_metrics: HashMap, + /// Session start timestamp (epoch milliseconds) + pub session_start_time: i64, + /// Session-wide per-token-type accumulated token counts + #[serde(default)] + pub token_details: HashMap, + /// Total time spent in model API calls (milliseconds) + pub total_api_duration_ms: f64, + /// Session-wide accumulated nano-AI units cost + #[serde(skip_serializing_if = "Option::is_none")] + pub total_nano_aiu: Option, + /// Total user-initiated premium request cost across all models (may be fractional due to multipliers) + pub total_premium_request_cost: f64, + /// Raw count of user-initiated API requests + pub total_user_requests: i64, +} + +/// Authentication type +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AuthInfoType { + #[serde(rename = "hmac")] + Hmac, + #[serde(rename = "env")] + Env, + #[serde(rename = "user")] + User, + #[serde(rename = "gh-cli")] + GhCli, + #[serde(rename = "api-key")] + ApiKey, + #[serde(rename = "token")] + Token, + #[serde(rename = "copilot-api-token")] + CopilotApiToken, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Configuration source +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DiscoveredMcpServerSource { + #[serde(rename = "user")] + User, + #[serde(rename = "workspace")] + Workspace, + #[serde(rename = "plugin")] + Plugin, + #[serde(rename = "builtin")] + Builtin, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Server transport type: stdio, http, sse, or memory (local configs are normalized to stdio) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DiscoveredMcpServerType { + #[serde(rename = "stdio")] + Stdio, + #[serde(rename = "http")] + Http, + #[serde(rename = "sse")] + Sse, + #[serde(rename = "memory")] + Memory, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Discovery source: project (.github/extensions/) or user (~/.copilot/extensions/) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionSource { + #[serde(rename = "project")] + Project, + #[serde(rename = "user")] + User, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current status: running, disabled, failed, or starting +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "starting")] + Starting, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Content block type discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExternalToolTextResultForLlmContentAudioType { + #[serde(rename = "audio")] + Audio, +} + +/// Content block type discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExternalToolTextResultForLlmContentImageType { + #[serde(rename = "image")] + Image, +} + +/// Content block type discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExternalToolTextResultForLlmContentResourceType { + #[serde(rename = "resource")] + Resource, +} + +/// Theme variant this icon is intended for +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExternalToolTextResultForLlmContentResourceLinkIconTheme { + #[serde(rename = "light")] + Light, + #[serde(rename = "dark")] + Dark, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Content block type discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExternalToolTextResultForLlmContentResourceLinkType { + #[serde(rename = "resource_link")] + ResourceLink, +} + +/// Content block type discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExternalToolTextResultForLlmContentTerminalType { + #[serde(rename = "terminal")] + Terminal, +} + +/// Content block type discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExternalToolTextResultForLlmContentTextType { + #[serde(rename = "text")] + Text, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum FilterMappingString { + #[serde(rename = "none")] + None, + #[serde(rename = "markdown")] + Markdown, + #[serde(rename = "hidden_characters")] + HiddenCharacters, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum FilterMappingValue { + #[serde(rename = "none")] + None, + #[serde(rename = "markdown")] + Markdown, + #[serde(rename = "hidden_characters")] + HiddenCharacters, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Where this source lives โ€” used for UI grouping +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum InstructionsSourcesLocation { + #[serde(rename = "user")] + User, + #[serde(rename = "repository")] + Repository, + #[serde(rename = "working-directory")] + WorkingDirectory, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Category of instruction source โ€” used for merge logic +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum InstructionsSourcesType { + #[serde(rename = "home")] + Home, + #[serde(rename = "repo")] + Repo, + #[serde(rename = "model")] + Model, + #[serde(rename = "vscode")] + Vscode, + #[serde(rename = "nested-agents")] + NestedAgents, + #[serde(rename = "child-instructions")] + ChildInstructions, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Log severity level. Determines how the message is displayed in the timeline. Defaults to "info". +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionLogLevel { + #[serde(rename = "info")] + Info, + #[serde(rename = "warning")] + Warning, + #[serde(rename = "error")] + Error, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Configuration source: user, workspace, plugin, or builtin +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerSource { + #[serde(rename = "user")] + User, + #[serde(rename = "workspace")] + Workspace, + #[serde(rename = "plugin")] + Plugin, + #[serde(rename = "builtin")] + Builtin, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerStatus { + #[serde(rename = "connected")] + Connected, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "needs-auth")] + NeedsAuth, + #[serde(rename = "pending")] + Pending, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "not_configured")] + NotConfigured, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerConfigHttpOauthGrantType { + #[serde(rename = "authorization_code")] + AuthorizationCode, + #[serde(rename = "client_credentials")] + ClientCredentials, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Remote transport type. Defaults to "http" when omitted. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerConfigHttpType { + #[serde(rename = "http")] + Http, + #[serde(rename = "sse")] + Sse, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerConfigLocalType { + #[serde(rename = "local")] + Local, + #[serde(rename = "stdio")] + Stdio, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// The agent mode. Valid values: "interactive", "plan", "autopilot". +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionMode { + #[serde(rename = "interactive")] + Interactive, + #[serde(rename = "plan")] + Plan, + #[serde(rename = "autopilot")] + Autopilot, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalCommandsKind { + #[serde(rename = "commands")] + Commands, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalReadKind { + #[serde(rename = "read")] + Read, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalWriteKind { + #[serde(rename = "write")] + Write, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalMcpSamplingKind { + #[serde(rename = "mcp-sampling")] + McpSampling, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// The approval to persist for this location +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionDecisionApproveForLocationApproval { + Commands(PermissionDecisionApproveForLocationApprovalCommands), + Read(PermissionDecisionApproveForLocationApprovalRead), + Write(PermissionDecisionApproveForLocationApprovalWrite), + Mcp(PermissionDecisionApproveForLocationApprovalMcp), + McpSampling(PermissionDecisionApproveForLocationApprovalMcpSampling), + Memory(PermissionDecisionApproveForLocationApprovalMemory), + CustomTool(PermissionDecisionApproveForLocationApprovalCustomTool), +} + +/// Approved and persisted for this project location +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationKind { + #[serde(rename = "approve-for-location")] + ApproveForLocation, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalCommandsKind { + #[serde(rename = "commands")] + Commands, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalReadKind { + #[serde(rename = "read")] + Read, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalWriteKind { + #[serde(rename = "write")] + Write, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalMcpSamplingKind { + #[serde(rename = "mcp-sampling")] + McpSampling, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// The approval to add as a session-scoped rule +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionDecisionApproveForSessionApproval { + Commands(PermissionDecisionApproveForSessionApprovalCommands), + Read(PermissionDecisionApproveForSessionApprovalRead), + Write(PermissionDecisionApproveForSessionApprovalWrite), + Mcp(PermissionDecisionApproveForSessionApprovalMcp), + McpSampling(PermissionDecisionApproveForSessionApprovalMcpSampling), + Memory(PermissionDecisionApproveForSessionApprovalMemory), + CustomTool(PermissionDecisionApproveForSessionApprovalCustomTool), +} + +/// Approved and remembered for the rest of the session +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionKind { + #[serde(rename = "approve-for-session")] + ApproveForSession, +} + +/// The permission request was approved for this one instance +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveOnceKind { + #[serde(rename = "approve-once")] + ApproveOnce, +} + +/// Approved and persisted across sessions +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApprovePermanentlyKind { + #[serde(rename = "approve-permanently")] + ApprovePermanently, +} + +/// Denied by the user during an interactive prompt +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionRejectKind { + #[serde(rename = "reject")] + Reject, +} + +/// Denied because user confirmation was unavailable +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionUserNotAvailableKind { + #[serde(rename = "user-not-available")] + UserNotAvailable, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionDecision { + ApproveOnce(PermissionDecisionApproveOnce), + ApproveForSession(PermissionDecisionApproveForSession), + ApproveForLocation(PermissionDecisionApproveForLocation), + ApprovePermanently(PermissionDecisionApprovePermanently), + Reject(PermissionDecisionReject), + UserNotAvailable(PermissionDecisionUserNotAvailable), +} + +/// Error classification +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionFsErrorCode { + ENOENT, + UNKNOWN, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Entry type +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionFsReaddirWithTypesEntryType { + #[serde(rename = "file")] + File, + #[serde(rename = "directory")] + Directory, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Path conventions used by this filesystem +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionFsSetProviderConventions { + #[serde(rename = "windows")] + Windows, + #[serde(rename = "posix")] + Posix, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Signal to send (default: SIGTERM) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ShellKillSignal { + SIGTERM, + SIGKILL, + SIGINT, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// How the agent is currently being managed by the runtime +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskAgentInfoExecutionMode { + #[serde(rename = "sync")] + Sync, + #[serde(rename = "background")] + Background, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current lifecycle status of the task +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskAgentInfoStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "idle")] + Idle, + #[serde(rename = "completed")] + Completed, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "cancelled")] + Cancelled, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Task kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskAgentInfoType { + #[serde(rename = "agent")] + Agent, +} + +/// Whether the shell runs inside a managed PTY session or as an independent background process +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoAttachmentMode { + #[serde(rename = "attached")] + Attached, + #[serde(rename = "detached")] + Detached, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Whether the shell command is currently sync-waited or background-managed +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoExecutionMode { + #[serde(rename = "sync")] + Sync, + #[serde(rename = "background")] + Background, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current lifecycle status of the task +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "idle")] + Idle, + #[serde(rename = "completed")] + Completed, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "cancelled")] + Cancelled, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Task kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoType { + #[serde(rename = "shell")] + Shell, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationArrayAnyOfFieldType { + #[serde(rename = "array")] + Array, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationArrayEnumFieldItemsType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationArrayEnumFieldType { + #[serde(rename = "array")] + Array, +} + +/// Schema type indicator (always 'object') +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaType { + #[serde(rename = "object")] + Object, +} + +/// The user's response: accept (submitted), decline (rejected), or cancel (dismissed) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationResponseAction { + #[serde(rename = "accept")] + Accept, + #[serde(rename = "decline")] + Decline, + #[serde(rename = "cancel")] + Cancel, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyBooleanType { + #[serde(rename = "boolean")] + Boolean, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyNumberType { + #[serde(rename = "number")] + Number, + #[serde(rename = "integer")] + Integer, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyStringFormat { + #[serde(rename = "email")] + Email, + #[serde(rename = "uri")] + Uri, + #[serde(rename = "date")] + Date, + #[serde(rename = "date-time")] + DateTime, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyStringType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationStringEnumFieldType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationStringOneOfFieldType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum WorkspacesGetWorkspaceResultWorkspaceHostType { + #[serde(rename = "github")] + Github, + #[serde(rename = "ado")] + Ado, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum WorkspacesGetWorkspaceResultWorkspaceSessionSyncLevel { + #[serde(rename = "local")] + Local, + #[serde(rename = "user")] + User, + #[serde(rename = "repo_and_user")] + RepoAndUser, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionWorkspacesGetWorkspaceResultWorkspaceHostType { + #[serde(rename = "github")] + Github, + #[serde(rename = "ado")] + Ado, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionWorkspacesGetWorkspaceResultWorkspaceSessionSyncLevel { + #[serde(rename = "local")] + Local, + #[serde(rename = "user")] + User, + #[serde(rename = "repo_and_user")] + RepoAndUser, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} diff --git a/rust/src/generated/mod.rs b/rust/src/generated/mod.rs new file mode 100644 index 000000000..5466a5e35 --- /dev/null +++ b/rust/src/generated/mod.rs @@ -0,0 +1,15 @@ +//! Auto-generated protocol types โ€” do not edit manually. +//! +//! Generated from the Copilot protocol JSON Schemas by `scripts/codegen/rust.ts`. +#![allow(missing_docs)] +#![allow(rustdoc::bare_urls)] + +pub mod api_types; +pub mod rpc; +pub mod session_events; + +// Re-export session event types at the module root โ€” no conflicts with +// hand-written types. API types are kept namespaced under `api_types::` +// because some names (Tool, ModelCapabilities, etc.) overlap with the +// hand-written SDK API types in `types.rs`. +pub use session_events::*; diff --git a/rust/src/generated/rpc.rs b/rust/src/generated/rpc.rs new file mode 100644 index 000000000..ee38f27a5 --- /dev/null +++ b/rust/src/generated/rpc.rs @@ -0,0 +1,1583 @@ +//! Auto-generated typed JSON-RPC namespace โ€” do not edit manually. +//! +//! Generated from `api.schema.json` by `scripts/codegen/rust.ts`. The +//! [`ClientRpc`] and [`SessionRpc`] view structs let callers reach every +//! protocol method through a typed namespace tree, so wire method names +//! and request/response shapes live in exactly one place โ€” this file. + +#![allow(missing_docs)] +#![allow(clippy::too_many_arguments)] + +use super::api_types::{rpc_methods, *}; +use crate::session::Session; +use crate::{Client, Error}; + +/// Typed view over the [`Client`]'s server-level RPC namespace. +#[derive(Clone, Copy)] +pub struct ClientRpc<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpc<'a> { + /// `account.*` sub-namespace. + pub fn account(&self) -> ClientRpcAccount<'a> { + ClientRpcAccount { + client: self.client, + } + } + + /// `mcp.*` sub-namespace. + pub fn mcp(&self) -> ClientRpcMcp<'a> { + ClientRpcMcp { + client: self.client, + } + } + + /// `models.*` sub-namespace. + pub fn models(&self) -> ClientRpcModels<'a> { + ClientRpcModels { + client: self.client, + } + } + + /// `sessionFs.*` sub-namespace. + pub fn session_fs(&self) -> ClientRpcSessionFs<'a> { + ClientRpcSessionFs { + client: self.client, + } + } + + /// `sessions.*` sub-namespace. + pub fn sessions(&self) -> ClientRpcSessions<'a> { + ClientRpcSessions { + client: self.client, + } + } + + /// `skills.*` sub-namespace. + pub fn skills(&self) -> ClientRpcSkills<'a> { + ClientRpcSkills { + client: self.client, + } + } + + /// `tools.*` sub-namespace. + pub fn tools(&self) -> ClientRpcTools<'a> { + ClientRpcTools { + client: self.client, + } + } + + /// Wire method: `ping`. + pub async fn ping(&self, params: PingRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::PING, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `connect`. + pub async fn connect(&self, params: ConnectRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::CONNECT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `account.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcAccount<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcAccount<'a> { + /// Wire method: `account.getQuota`. + pub async fn get_quota(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::ACCOUNT_GETQUOTA, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `mcp.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcMcp<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcMcp<'a> { + /// `mcp.config.*` sub-namespace. + pub fn config(&self) -> ClientRpcMcpConfig<'a> { + ClientRpcMcpConfig { + client: self.client, + } + } + + /// Wire method: `mcp.discover`. + pub async fn discover(&self, params: McpDiscoverRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_DISCOVER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `mcp.config.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcMcpConfig<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcMcpConfig<'a> { + /// Wire method: `mcp.config.list`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `mcp.config.add`. + pub async fn add(&self, params: McpConfigAddRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_ADD, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.update`. + pub async fn update(&self, params: McpConfigUpdateRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_UPDATE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.remove`. + pub async fn remove(&self, params: McpConfigRemoveRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_REMOVE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.enable`. + pub async fn enable(&self, params: McpConfigEnableRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_ENABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.disable`. + pub async fn disable(&self, params: McpConfigDisableRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_DISABLE, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `models.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcModels<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcModels<'a> { + /// Wire method: `models.list`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::MODELS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `sessionFs.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcSessionFs<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSessionFs<'a> { + /// Wire method: `sessionFs.setProvider`. + pub async fn set_provider( + &self, + params: SessionFsSetProviderRequest, + ) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::SESSIONFS_SETPROVIDER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `sessions.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcSessions<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSessions<'a> { + /// Wire method: `sessions.fork`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn fork(&self, params: SessionsForkRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::SESSIONS_FORK, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `skills.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcSkills<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSkills<'a> { + /// `skills.config.*` sub-namespace. + pub fn config(&self) -> ClientRpcSkillsConfig<'a> { + ClientRpcSkillsConfig { + client: self.client, + } + } + + /// Wire method: `skills.discover`. + pub async fn discover(&self, params: SkillsDiscoverRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::SKILLS_DISCOVER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `skills.config.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcSkillsConfig<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSkillsConfig<'a> { + /// Wire method: `skills.config.setDisabledSkills`. + pub async fn set_disabled_skills( + &self, + params: SkillsConfigSetDisabledSkillsRequest, + ) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call( + rpc_methods::SKILLS_CONFIG_SETDISABLEDSKILLS, + Some(wire_params), + ) + .await?; + Ok(()) + } +} + +/// `tools.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcTools<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcTools<'a> { + /// Wire method: `tools.list`. + pub async fn list(&self, params: ToolsListRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::TOOLS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// Typed view over a [`Session`]'s RPC namespace. +#[derive(Clone, Copy)] +pub struct SessionRpc<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpc<'a> { + /// `session.agent.*` sub-namespace. + pub fn agent(&self) -> SessionRpcAgent<'a> { + SessionRpcAgent { + session: self.session, + } + } + + /// `session.auth.*` sub-namespace. + pub fn auth(&self) -> SessionRpcAuth<'a> { + SessionRpcAuth { + session: self.session, + } + } + + /// `session.commands.*` sub-namespace. + pub fn commands(&self) -> SessionRpcCommands<'a> { + SessionRpcCommands { + session: self.session, + } + } + + /// `session.extensions.*` sub-namespace. + pub fn extensions(&self) -> SessionRpcExtensions<'a> { + SessionRpcExtensions { + session: self.session, + } + } + + /// `session.fleet.*` sub-namespace. + pub fn fleet(&self) -> SessionRpcFleet<'a> { + SessionRpcFleet { + session: self.session, + } + } + + /// `session.history.*` sub-namespace. + pub fn history(&self) -> SessionRpcHistory<'a> { + SessionRpcHistory { + session: self.session, + } + } + + /// `session.instructions.*` sub-namespace. + pub fn instructions(&self) -> SessionRpcInstructions<'a> { + SessionRpcInstructions { + session: self.session, + } + } + + /// `session.mcp.*` sub-namespace. + pub fn mcp(&self) -> SessionRpcMcp<'a> { + SessionRpcMcp { + session: self.session, + } + } + + /// `session.mode.*` sub-namespace. + pub fn mode(&self) -> SessionRpcMode<'a> { + SessionRpcMode { + session: self.session, + } + } + + /// `session.model.*` sub-namespace. + pub fn model(&self) -> SessionRpcModel<'a> { + SessionRpcModel { + session: self.session, + } + } + + /// `session.name.*` sub-namespace. + pub fn name(&self) -> SessionRpcName<'a> { + SessionRpcName { + session: self.session, + } + } + + /// `session.permissions.*` sub-namespace. + pub fn permissions(&self) -> SessionRpcPermissions<'a> { + SessionRpcPermissions { + session: self.session, + } + } + + /// `session.plan.*` sub-namespace. + pub fn plan(&self) -> SessionRpcPlan<'a> { + SessionRpcPlan { + session: self.session, + } + } + + /// `session.plugins.*` sub-namespace. + pub fn plugins(&self) -> SessionRpcPlugins<'a> { + SessionRpcPlugins { + session: self.session, + } + } + + /// `session.shell.*` sub-namespace. + pub fn shell(&self) -> SessionRpcShell<'a> { + SessionRpcShell { + session: self.session, + } + } + + /// `session.skills.*` sub-namespace. + pub fn skills(&self) -> SessionRpcSkills<'a> { + SessionRpcSkills { + session: self.session, + } + } + + /// `session.tasks.*` sub-namespace. + pub fn tasks(&self) -> SessionRpcTasks<'a> { + SessionRpcTasks { + session: self.session, + } + } + + /// `session.tools.*` sub-namespace. + pub fn tools(&self) -> SessionRpcTools<'a> { + SessionRpcTools { + session: self.session, + } + } + + /// `session.ui.*` sub-namespace. + pub fn ui(&self) -> SessionRpcUi<'a> { + SessionRpcUi { + session: self.session, + } + } + + /// `session.usage.*` sub-namespace. + pub fn usage(&self) -> SessionRpcUsage<'a> { + SessionRpcUsage { + session: self.session, + } + } + + /// `session.workspaces.*` sub-namespace. + pub fn workspaces(&self) -> SessionRpcWorkspaces<'a> { + SessionRpcWorkspaces { + session: self.session, + } + } + + /// Wire method: `session.suspend`. + pub async fn suspend(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SUSPEND, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.log`. + pub async fn log(&self, params: LogRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_LOG, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.agent.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcAgent<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcAgent<'a> { + /// Wire method: `session.agent.list`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.agent.getCurrent`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn get_current(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_GETCURRENT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.agent.select`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn select(&self, params: AgentSelectRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_SELECT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.agent.deselect`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn deselect(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_DESELECT, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.agent.reload`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn reload(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_RELOAD, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.auth.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcAuth<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcAuth<'a> { + /// Wire method: `session.auth.getStatus`. + pub async fn get_status(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AUTH_GETSTATUS, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.commands.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcCommands<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcCommands<'a> { + /// Wire method: `session.commands.handlePendingCommand`. + pub async fn handle_pending_command( + &self, + params: CommandsHandlePendingCommandRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_COMMANDS_HANDLEPENDINGCOMMAND, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.extensions.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcExtensions<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcExtensions<'a> { + /// Wire method: `session.extensions.list`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.extensions.enable`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn enable(&self, params: ExtensionsEnableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_ENABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.extensions.disable`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn disable(&self, params: ExtensionsDisableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_DISABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.extensions.reload`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn reload(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_RELOAD, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.fleet.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcFleet<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcFleet<'a> { + /// Wire method: `session.fleet.start`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn start(&self, params: FleetStartRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_FLEET_START, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.history.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcHistory<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcHistory<'a> { + /// Wire method: `session.history.compact`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn compact(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_HISTORY_COMPACT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.history.truncate`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn truncate( + &self, + params: HistoryTruncateRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_HISTORY_TRUNCATE, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.instructions.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcInstructions<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcInstructions<'a> { + /// Wire method: `session.instructions.getSources`. + pub async fn get_sources(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_INSTRUCTIONS_GETSOURCES, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.mcp.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcMcp<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcMcp<'a> { + /// `session.mcp.oauth.*` sub-namespace. + pub fn oauth(&self) -> SessionRpcMcpOauth<'a> { + SessionRpcMcpOauth { + session: self.session, + } + } + + /// Wire method: `session.mcp.list`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.mcp.enable`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn enable(&self, params: McpEnableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_ENABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.mcp.disable`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn disable(&self, params: McpDisableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_DISABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.mcp.reload`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn reload(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_RELOAD, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.mcp.oauth.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcMcpOauth<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcMcpOauth<'a> { + /// Wire method: `session.mcp.oauth.login`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn login(&self, params: McpOauthLoginRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_OAUTH_LOGIN, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.mode.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcMode<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcMode<'a> { + /// Wire method: `session.mode.get`. + pub async fn get(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODE_GET, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.mode.set`. + pub async fn set(&self, params: ModeSetRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODE_SET, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.model.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcModel<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcModel<'a> { + /// Wire method: `session.model.getCurrent`. + pub async fn get_current(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODEL_GETCURRENT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.model.switchTo`. + pub async fn switch_to( + &self, + params: ModelSwitchToRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODEL_SWITCHTO, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.name.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcName<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcName<'a> { + /// Wire method: `session.name.get`. + pub async fn get(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_NAME_GET, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.name.set`. + pub async fn set(&self, params: NameSetRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_NAME_SET, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.permissions.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcPermissions<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcPermissions<'a> { + /// Wire method: `session.permissions.handlePendingPermissionRequest`. + pub async fn handle_pending_permission_request( + &self, + params: PermissionDecisionRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_PERMISSIONS_HANDLEPENDINGPERMISSIONREQUEST, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.permissions.setApproveAll`. + pub async fn set_approve_all( + &self, + params: PermissionsSetApproveAllRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_PERMISSIONS_SETAPPROVEALL, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.permissions.resetSessionApprovals`. + pub async fn reset_session_approvals( + &self, + ) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_PERMISSIONS_RESETSESSIONAPPROVALS, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.plan.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcPlan<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcPlan<'a> { + /// Wire method: `session.plan.read`. + pub async fn read(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLAN_READ, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.plan.update`. + pub async fn update(&self, params: PlanUpdateRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLAN_UPDATE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.plan.delete`. + pub async fn delete(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLAN_DELETE, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.plugins.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcPlugins<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcPlugins<'a> { + /// Wire method: `session.plugins.list`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLUGINS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.shell.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcShell<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcShell<'a> { + /// Wire method: `session.shell.exec`. + pub async fn exec(&self, params: ShellExecRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SHELL_EXEC, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.shell.kill`. + pub async fn kill(&self, params: ShellKillRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SHELL_KILL, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.skills.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcSkills<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcSkills<'a> { + /// Wire method: `session.skills.list`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SKILLS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.skills.enable`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn enable(&self, params: SkillsEnableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SKILLS_ENABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.skills.disable`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn disable(&self, params: SkillsDisableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SKILLS_DISABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.skills.reload`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn reload(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SKILLS_RELOAD, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.tasks.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcTasks<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcTasks<'a> { + /// Wire method: `session.tasks.startAgent`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn start_agent( + &self, + params: TasksStartAgentRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_TASKS_STARTAGENT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.tasks.list`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_TASKS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.tasks.promoteToBackground`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn promote_to_background( + &self, + params: TasksPromoteToBackgroundRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_TASKS_PROMOTETOBACKGROUND, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.tasks.cancel`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn cancel(&self, params: TasksCancelRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_TASKS_CANCEL, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.tasks.remove`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn remove(&self, params: TasksRemoveRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_TASKS_REMOVE, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.tools.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcTools<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcTools<'a> { + /// Wire method: `session.tools.handlePendingToolCall`. + pub async fn handle_pending_tool_call( + &self, + params: HandlePendingToolCallRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_TOOLS_HANDLEPENDINGTOOLCALL, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.ui.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcUi<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcUi<'a> { + /// Wire method: `session.ui.elicitation`. + pub async fn elicitation( + &self, + params: UIElicitationRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_UI_ELICITATION, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.ui.handlePendingElicitation`. + pub async fn handle_pending_elicitation( + &self, + params: UIHandlePendingElicitationRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_UI_HANDLEPENDINGELICITATION, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.usage.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcUsage<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcUsage<'a> { + /// Wire method: `session.usage.getMetrics`. + /// + ///
+ /// + /// **Experimental.** This API is part of an experimental wire-protocol surface + /// and may change or be removed in future SDK or CLI releases. Pin both the + /// SDK and CLI versions if your code depends on it. + /// + ///
+ pub async fn get_metrics(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_USAGE_GETMETRICS, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.workspaces.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcWorkspaces<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcWorkspaces<'a> { + /// Wire method: `session.workspaces.getWorkspace`. + pub async fn get_workspace(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_WORKSPACES_GETWORKSPACE, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.workspaces.listFiles`. + pub async fn list_files(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_WORKSPACES_LISTFILES, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.workspaces.readFile`. + pub async fn read_file( + &self, + params: WorkspacesReadFileRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_WORKSPACES_READFILE, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.workspaces.createFile`. + pub async fn create_file(&self, params: WorkspacesCreateFileRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_WORKSPACES_CREATEFILE, + Some(wire_params), + ) + .await?; + Ok(()) + } +} diff --git a/rust/src/generated/session_events.rs b/rust/src/generated/session_events.rs new file mode 100644 index 000000000..f3200925e --- /dev/null +++ b/rust/src/generated/session_events.rs @@ -0,0 +1,3091 @@ +//! Auto-generated from session-events.schema.json โ€” do not edit manually. + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::types::{RequestId, SessionId}; + +/// Identifies the kind of session event. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum SessionEventType { + #[serde(rename = "session.start")] + SessionStart, + #[serde(rename = "session.resume")] + SessionResume, + #[serde(rename = "session.remote_steerable_changed")] + SessionRemoteSteerableChanged, + #[serde(rename = "session.error")] + SessionError, + #[serde(rename = "session.idle")] + SessionIdle, + #[serde(rename = "session.title_changed")] + SessionTitleChanged, + #[serde(rename = "session.info")] + SessionInfo, + #[serde(rename = "session.warning")] + SessionWarning, + #[serde(rename = "session.model_change")] + SessionModelChange, + #[serde(rename = "session.mode_changed")] + SessionModeChanged, + #[serde(rename = "session.plan_changed")] + SessionPlanChanged, + #[serde(rename = "session.workspace_file_changed")] + SessionWorkspaceFileChanged, + #[serde(rename = "session.handoff")] + SessionHandoff, + #[serde(rename = "session.truncation")] + SessionTruncation, + #[serde(rename = "session.snapshot_rewind")] + SessionSnapshotRewind, + #[serde(rename = "session.shutdown")] + SessionShutdown, + #[serde(rename = "session.context_changed")] + SessionContextChanged, + #[serde(rename = "session.usage_info")] + SessionUsageInfo, + #[serde(rename = "session.compaction_start")] + SessionCompactionStart, + #[serde(rename = "session.compaction_complete")] + SessionCompactionComplete, + #[serde(rename = "session.task_complete")] + SessionTaskComplete, + #[serde(rename = "user.message")] + UserMessage, + #[serde(rename = "pending_messages.modified")] + PendingMessagesModified, + #[serde(rename = "assistant.turn_start")] + AssistantTurnStart, + #[serde(rename = "assistant.intent")] + AssistantIntent, + #[serde(rename = "assistant.reasoning")] + AssistantReasoning, + #[serde(rename = "assistant.reasoning_delta")] + AssistantReasoningDelta, + #[serde(rename = "assistant.streaming_delta")] + AssistantStreamingDelta, + #[serde(rename = "assistant.message")] + AssistantMessage, + #[serde(rename = "assistant.message_start")] + AssistantMessageStart, + #[serde(rename = "assistant.message_delta")] + AssistantMessageDelta, + #[serde(rename = "assistant.turn_end")] + AssistantTurnEnd, + #[serde(rename = "assistant.usage")] + AssistantUsage, + #[serde(rename = "model.call_failure")] + ModelCallFailure, + #[serde(rename = "abort")] + Abort, + #[serde(rename = "tool.user_requested")] + ToolUserRequested, + #[serde(rename = "tool.execution_start")] + ToolExecutionStart, + #[serde(rename = "tool.execution_partial_result")] + ToolExecutionPartialResult, + #[serde(rename = "tool.execution_progress")] + ToolExecutionProgress, + #[serde(rename = "tool.execution_complete")] + ToolExecutionComplete, + #[serde(rename = "skill.invoked")] + SkillInvoked, + #[serde(rename = "subagent.started")] + SubagentStarted, + #[serde(rename = "subagent.completed")] + SubagentCompleted, + #[serde(rename = "subagent.failed")] + SubagentFailed, + #[serde(rename = "subagent.selected")] + SubagentSelected, + #[serde(rename = "subagent.deselected")] + SubagentDeselected, + #[serde(rename = "hook.start")] + HookStart, + #[serde(rename = "hook.end")] + HookEnd, + #[serde(rename = "system.message")] + SystemMessage, + #[serde(rename = "system.notification")] + SystemNotification, + #[serde(rename = "permission.requested")] + PermissionRequested, + #[serde(rename = "permission.completed")] + PermissionCompleted, + #[serde(rename = "user_input.requested")] + UserInputRequested, + #[serde(rename = "user_input.completed")] + UserInputCompleted, + #[serde(rename = "elicitation.requested")] + ElicitationRequested, + #[serde(rename = "elicitation.completed")] + ElicitationCompleted, + #[serde(rename = "sampling.requested")] + SamplingRequested, + #[serde(rename = "sampling.completed")] + SamplingCompleted, + #[serde(rename = "mcp.oauth_required")] + McpOauthRequired, + #[serde(rename = "mcp.oauth_completed")] + McpOauthCompleted, + #[serde(rename = "external_tool.requested")] + ExternalToolRequested, + #[serde(rename = "external_tool.completed")] + ExternalToolCompleted, + #[serde(rename = "command.queued")] + CommandQueued, + #[serde(rename = "command.execute")] + CommandExecute, + #[serde(rename = "command.completed")] + CommandCompleted, + #[serde(rename = "auto_mode_switch.requested")] + AutoModeSwitchRequested, + #[serde(rename = "auto_mode_switch.completed")] + AutoModeSwitchCompleted, + #[serde(rename = "commands.changed")] + CommandsChanged, + #[serde(rename = "capabilities.changed")] + CapabilitiesChanged, + #[serde(rename = "exit_plan_mode.requested")] + ExitPlanModeRequested, + #[serde(rename = "exit_plan_mode.completed")] + ExitPlanModeCompleted, + #[serde(rename = "session.tools_updated")] + SessionToolsUpdated, + #[serde(rename = "session.background_tasks_changed")] + SessionBackgroundTasksChanged, + #[serde(rename = "session.skills_loaded")] + SessionSkillsLoaded, + #[serde(rename = "session.custom_agents_updated")] + SessionCustomAgentsUpdated, + #[serde(rename = "session.mcp_servers_loaded")] + SessionMcpServersLoaded, + #[serde(rename = "session.mcp_server_status_changed")] + SessionMcpServerStatusChanged, + #[serde(rename = "session.extensions_loaded")] + SessionExtensionsLoaded, + /// Unknown event type for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Typed session event data, discriminated by the event `type` field. +/// +/// Use with [`TypedSessionEvent`] for fully typed event handling. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum SessionEventData { + #[serde(rename = "session.start")] + SessionStart(SessionStartData), + #[serde(rename = "session.resume")] + SessionResume(SessionResumeData), + #[serde(rename = "session.remote_steerable_changed")] + SessionRemoteSteerableChanged(SessionRemoteSteerableChangedData), + #[serde(rename = "session.error")] + SessionError(SessionErrorData), + #[serde(rename = "session.idle")] + SessionIdle(SessionIdleData), + #[serde(rename = "session.title_changed")] + SessionTitleChanged(SessionTitleChangedData), + #[serde(rename = "session.info")] + SessionInfo(SessionInfoData), + #[serde(rename = "session.warning")] + SessionWarning(SessionWarningData), + #[serde(rename = "session.model_change")] + SessionModelChange(SessionModelChangeData), + #[serde(rename = "session.mode_changed")] + SessionModeChanged(SessionModeChangedData), + #[serde(rename = "session.plan_changed")] + SessionPlanChanged(SessionPlanChangedData), + #[serde(rename = "session.workspace_file_changed")] + SessionWorkspaceFileChanged(SessionWorkspaceFileChangedData), + #[serde(rename = "session.handoff")] + SessionHandoff(SessionHandoffData), + #[serde(rename = "session.truncation")] + SessionTruncation(SessionTruncationData), + #[serde(rename = "session.snapshot_rewind")] + SessionSnapshotRewind(SessionSnapshotRewindData), + #[serde(rename = "session.shutdown")] + SessionShutdown(SessionShutdownData), + #[serde(rename = "session.context_changed")] + SessionContextChanged(SessionContextChangedData), + #[serde(rename = "session.usage_info")] + SessionUsageInfo(SessionUsageInfoData), + #[serde(rename = "session.compaction_start")] + SessionCompactionStart(SessionCompactionStartData), + #[serde(rename = "session.compaction_complete")] + SessionCompactionComplete(SessionCompactionCompleteData), + #[serde(rename = "session.task_complete")] + SessionTaskComplete(SessionTaskCompleteData), + #[serde(rename = "user.message")] + UserMessage(UserMessageData), + #[serde(rename = "pending_messages.modified")] + PendingMessagesModified(PendingMessagesModifiedData), + #[serde(rename = "assistant.turn_start")] + AssistantTurnStart(AssistantTurnStartData), + #[serde(rename = "assistant.intent")] + AssistantIntent(AssistantIntentData), + #[serde(rename = "assistant.reasoning")] + AssistantReasoning(AssistantReasoningData), + #[serde(rename = "assistant.reasoning_delta")] + AssistantReasoningDelta(AssistantReasoningDeltaData), + #[serde(rename = "assistant.streaming_delta")] + AssistantStreamingDelta(AssistantStreamingDeltaData), + #[serde(rename = "assistant.message")] + AssistantMessage(AssistantMessageData), + #[serde(rename = "assistant.message_start")] + AssistantMessageStart(AssistantMessageStartData), + #[serde(rename = "assistant.message_delta")] + AssistantMessageDelta(AssistantMessageDeltaData), + #[serde(rename = "assistant.turn_end")] + AssistantTurnEnd(AssistantTurnEndData), + #[serde(rename = "assistant.usage")] + AssistantUsage(AssistantUsageData), + #[serde(rename = "model.call_failure")] + ModelCallFailure(ModelCallFailureData), + #[serde(rename = "abort")] + Abort(AbortData), + #[serde(rename = "tool.user_requested")] + ToolUserRequested(ToolUserRequestedData), + #[serde(rename = "tool.execution_start")] + ToolExecutionStart(ToolExecutionStartData), + #[serde(rename = "tool.execution_partial_result")] + ToolExecutionPartialResult(ToolExecutionPartialResultData), + #[serde(rename = "tool.execution_progress")] + ToolExecutionProgress(ToolExecutionProgressData), + #[serde(rename = "tool.execution_complete")] + ToolExecutionComplete(ToolExecutionCompleteData), + #[serde(rename = "skill.invoked")] + SkillInvoked(SkillInvokedData), + #[serde(rename = "subagent.started")] + SubagentStarted(SubagentStartedData), + #[serde(rename = "subagent.completed")] + SubagentCompleted(SubagentCompletedData), + #[serde(rename = "subagent.failed")] + SubagentFailed(SubagentFailedData), + #[serde(rename = "subagent.selected")] + SubagentSelected(SubagentSelectedData), + #[serde(rename = "subagent.deselected")] + SubagentDeselected(SubagentDeselectedData), + #[serde(rename = "hook.start")] + HookStart(HookStartData), + #[serde(rename = "hook.end")] + HookEnd(HookEndData), + #[serde(rename = "system.message")] + SystemMessage(SystemMessageData), + #[serde(rename = "system.notification")] + SystemNotification(SystemNotificationData), + #[serde(rename = "permission.requested")] + PermissionRequested(PermissionRequestedData), + #[serde(rename = "permission.completed")] + PermissionCompleted(PermissionCompletedData), + #[serde(rename = "user_input.requested")] + UserInputRequested(UserInputRequestedData), + #[serde(rename = "user_input.completed")] + UserInputCompleted(UserInputCompletedData), + #[serde(rename = "elicitation.requested")] + ElicitationRequested(ElicitationRequestedData), + #[serde(rename = "elicitation.completed")] + ElicitationCompleted(ElicitationCompletedData), + #[serde(rename = "sampling.requested")] + SamplingRequested(SamplingRequestedData), + #[serde(rename = "sampling.completed")] + SamplingCompleted(SamplingCompletedData), + #[serde(rename = "mcp.oauth_required")] + McpOauthRequired(McpOauthRequiredData), + #[serde(rename = "mcp.oauth_completed")] + McpOauthCompleted(McpOauthCompletedData), + #[serde(rename = "external_tool.requested")] + ExternalToolRequested(ExternalToolRequestedData), + #[serde(rename = "external_tool.completed")] + ExternalToolCompleted(ExternalToolCompletedData), + #[serde(rename = "command.queued")] + CommandQueued(CommandQueuedData), + #[serde(rename = "command.execute")] + CommandExecute(CommandExecuteData), + #[serde(rename = "command.completed")] + CommandCompleted(CommandCompletedData), + #[serde(rename = "auto_mode_switch.requested")] + AutoModeSwitchRequested(AutoModeSwitchRequestedData), + #[serde(rename = "auto_mode_switch.completed")] + AutoModeSwitchCompleted(AutoModeSwitchCompletedData), + #[serde(rename = "commands.changed")] + CommandsChanged(CommandsChangedData), + #[serde(rename = "capabilities.changed")] + CapabilitiesChanged(CapabilitiesChangedData), + #[serde(rename = "exit_plan_mode.requested")] + ExitPlanModeRequested(ExitPlanModeRequestedData), + #[serde(rename = "exit_plan_mode.completed")] + ExitPlanModeCompleted(ExitPlanModeCompletedData), + #[serde(rename = "session.tools_updated")] + SessionToolsUpdated(SessionToolsUpdatedData), + #[serde(rename = "session.background_tasks_changed")] + SessionBackgroundTasksChanged(SessionBackgroundTasksChangedData), + #[serde(rename = "session.skills_loaded")] + SessionSkillsLoaded(SessionSkillsLoadedData), + #[serde(rename = "session.custom_agents_updated")] + SessionCustomAgentsUpdated(SessionCustomAgentsUpdatedData), + #[serde(rename = "session.mcp_servers_loaded")] + SessionMcpServersLoaded(SessionMcpServersLoadedData), + #[serde(rename = "session.mcp_server_status_changed")] + SessionMcpServerStatusChanged(SessionMcpServerStatusChangedData), + #[serde(rename = "session.extensions_loaded")] + SessionExtensionsLoaded(SessionExtensionsLoadedData), +} + +/// A session event with typed data payload. +/// +/// The common event fields (id, timestamp, parentId, ephemeral, agentId) +/// are available directly. The event-specific data is in the `payload` +/// field as a [`SessionEventData`] enum. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TypedSessionEvent { + /// Unique event identifier (UUID v4). + pub id: String, + /// ISO 8601 timestamp when the event was created. + pub timestamp: String, + /// ID of the preceding event in the chain. + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_id: Option, + /// When true, the event is transient and not persisted. + #[serde(skip_serializing_if = "Option::is_none")] + pub ephemeral: Option, + /// Sub-agent instance identifier. Absent for events from the root / + /// main agent and session-level events. + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_id: Option, + /// The typed event payload (discriminated by event type). + #[serde(flatten)] + pub payload: SessionEventData, +} + +/// Working directory and git context at session start +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct WorkingDirectoryContext { + /// Base commit of current git branch at session start time + #[serde(skip_serializing_if = "Option::is_none")] + pub base_commit: Option, + /// Current git branch name + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + /// Current working directory path + pub cwd: String, + /// Root directory of the git repository, resolved via git rev-parse + #[serde(skip_serializing_if = "Option::is_none")] + pub git_root: Option, + /// Head commit of current git branch at session start time + #[serde(skip_serializing_if = "Option::is_none")] + pub head_commit: Option, + /// Hosting platform type of the repository (github or ado) + #[serde(skip_serializing_if = "Option::is_none")] + pub host_type: Option, + /// Repository identifier derived from the git remote URL ("owner/name" for GitHub, "org/project/repo" for Azure DevOps) + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, + /// Raw host string from the git remote URL (e.g. "github.com", "mycompany.ghe.com", "dev.azure.com") + #[serde(skip_serializing_if = "Option::is_none")] + pub repository_host: Option, +} + +/// Session initialization metadata including context and configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionStartData { + /// Whether the session was already in use by another client at start time + #[serde(skip_serializing_if = "Option::is_none")] + pub already_in_use: Option, + /// Working directory and git context at session start + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, + /// Version string of the Copilot application + pub copilot_version: String, + /// Identifier of the software producing the events (e.g., "copilot-agent") + pub producer: String, + /// Reasoning effort level used for model calls, if applicable (e.g. "low", "medium", "high", "xhigh") + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + /// Whether this session supports remote steering via Mission Control + #[serde(skip_serializing_if = "Option::is_none")] + pub remote_steerable: Option, + /// Model selected at session creation time, if any + #[serde(skip_serializing_if = "Option::is_none")] + pub selected_model: Option, + /// Unique identifier for the session + pub session_id: SessionId, + /// ISO 8601 timestamp when the session was created + pub start_time: String, + /// Schema version number for the session event format + pub version: f64, +} + +/// Session resume metadata including current context and event count +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionResumeData { + /// Whether the session was already in use by another client at resume time + #[serde(skip_serializing_if = "Option::is_none")] + pub already_in_use: Option, + /// Updated working directory and git context at resume time + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, + /// When true, tool calls and permission requests left in flight by the previous session lifetime remain pending after resume and the agentic loop awaits their results. User sends are queued behind the pending work until all such requests reach a terminal state. When false (the default), any such tool calls and permission requests are immediately marked as interrupted on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub continue_pending_work: Option, + /// Total number of persisted events in the session at the time of resume + pub event_count: f64, + /// Reasoning effort level used for model calls, if applicable (e.g. "low", "medium", "high", "xhigh") + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + /// Whether this session supports remote steering via Mission Control + #[serde(skip_serializing_if = "Option::is_none")] + pub remote_steerable: Option, + /// ISO 8601 timestamp when the session was resumed + pub resume_time: String, + /// Model currently selected at resume time + #[serde(skip_serializing_if = "Option::is_none")] + pub selected_model: Option, + /// True when this resume attached to a session that the runtime already had running in-memory (for example, an extension joining a session another client was actively driving). False (or omitted) for cold resumes โ€” the runtime had to reconstitute the session from its persisted event log. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_was_active: Option, +} + +/// Notifies Mission Control that the session's remote steering capability has changed +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionRemoteSteerableChangedData { + /// Whether this session now supports remote steering via Mission Control + pub remote_steerable: bool, +} + +/// Error details for timeline display including message and optional diagnostic information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionErrorData { + /// Only set on `errorType: "rate_limit"`. When `true`, the runtime will follow this error with an `auto_mode_switch.requested` event (or silently switch if `continueOnAutoMode` is enabled). UI clients can use this flag to suppress duplicate rendering of the rate-limit error when they show their own auto-mode-switch prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub eligible_for_auto_switch: Option, + /// Fine-grained error code from the upstream provider, when available. For `errorType: "rate_limit"`, this is one of the `RateLimitErrorCode` values (e.g., `"user_weekly_rate_limited"`, `"user_global_rate_limited"`, `"rate_limited"`, `"user_model_rate_limited"`, `"integration_rate_limited"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub error_code: Option, + /// Category of error (e.g., "authentication", "authorization", "quota", "rate_limit", "context_limit", "query") + pub error_type: String, + /// Human-readable error message + pub message: String, + /// GitHub request tracing ID (x-github-request-id header) for correlating with server-side logs + #[serde(skip_serializing_if = "Option::is_none")] + pub provider_call_id: Option, + /// Error stack trace, when available + #[serde(skip_serializing_if = "Option::is_none")] + pub stack: Option, + /// HTTP status code from the upstream request, if applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub status_code: Option, + /// Optional URL associated with this error that the user can open in a browser + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, +} + +/// Payload indicating the session is idle with no background agents in flight +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionIdleData { + /// True when the preceding agentic loop was cancelled via abort signal + #[serde(skip_serializing_if = "Option::is_none")] + pub aborted: Option, +} + +/// Session title change payload containing the new display title +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTitleChangedData { + /// The new display title for the session + pub title: String, +} + +/// Informational message for timeline display with categorization +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionInfoData { + /// Category of informational message (e.g., "notification", "timing", "context_window", "mcp", "snapshot", "configuration", "authentication", "model") + pub info_type: String, + /// Human-readable informational message for display in the timeline + pub message: String, + /// Optional actionable tip displayed with this message + #[serde(skip_serializing_if = "Option::is_none")] + pub tip: Option, + /// Optional URL associated with this message that the user can open in a browser + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, +} + +/// Warning message for timeline display with categorization +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWarningData { + /// Human-readable warning message for display in the timeline + pub message: String, + /// Optional URL associated with this warning that the user can open in a browser + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, + /// Category of warning (e.g., "subscription", "policy", "mcp") + pub warning_type: String, +} + +/// Model change details including previous and new model identifiers +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionModelChangeData { + /// Reason the change happened, when not user-initiated. Currently `"rate_limit_auto_switch"` for changes triggered by the auto-mode-switch rate-limit recovery path. UI clients can use this to render contextual copy. + #[serde(skip_serializing_if = "Option::is_none")] + pub cause: Option, + /// Newly selected model identifier + pub new_model: String, + /// Model that was previously selected, if any + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_model: Option, + /// Reasoning effort level before the model change, if applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_reasoning_effort: Option, + /// Reasoning effort level after the model change, if applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, +} + +/// Agent mode change details including previous and new modes +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionModeChangedData { + /// Agent mode after the change (e.g., "interactive", "plan", "autopilot") + pub new_mode: String, + /// Agent mode before the change (e.g., "interactive", "plan", "autopilot") + pub previous_mode: String, +} + +/// Plan file operation details indicating what changed +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPlanChangedData { + /// The type of operation performed on the plan file + pub operation: PlanChangedOperation, +} + +/// Workspace file change details including path and operation type +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspaceFileChangedData { + /// Whether the file was newly created or updated + pub operation: WorkspaceFileChangedOperation, + /// Relative path within the session workspace files directory + pub path: String, +} + +/// Repository context for the handed-off session +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HandoffRepository { + /// Git branch name, if applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + /// Repository name + pub name: String, + /// Repository owner (user or organization) + pub owner: String, +} + +/// Session handoff metadata including source, context, and repository information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionHandoffData { + /// Additional context information for the handoff + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, + /// ISO 8601 timestamp when the handoff occurred + pub handoff_time: String, + /// GitHub host URL for the source session (e.g., https://github.com or https://tenant.ghe.com) + #[serde(skip_serializing_if = "Option::is_none")] + pub host: Option, + /// Session ID of the remote session being handed off + #[serde(skip_serializing_if = "Option::is_none")] + pub remote_session_id: Option, + /// Repository context for the handed-off session + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, + /// Origin type of the session being handed off + pub source_type: HandoffSourceType, + /// Summary of the work done in the source session + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +/// Conversation truncation statistics including token counts and removed content metrics +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTruncationData { + /// Number of messages removed by truncation + pub messages_removed_during_truncation: f64, + /// Identifier of the component that performed truncation (e.g., "BasicTruncator") + pub performed_by: String, + /// Number of conversation messages after truncation + pub post_truncation_messages_length: f64, + /// Total tokens in conversation messages after truncation + pub post_truncation_tokens_in_messages: f64, + /// Number of conversation messages before truncation + pub pre_truncation_messages_length: f64, + /// Total tokens in conversation messages before truncation + pub pre_truncation_tokens_in_messages: f64, + /// Maximum token count for the model's context window + pub token_limit: f64, + /// Number of tokens removed by truncation + pub tokens_removed_during_truncation: f64, +} + +/// Session rewind details including target event and count of removed events +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSnapshotRewindData { + /// Number of events that were removed by the rewind + pub events_removed: f64, + /// Event ID that was rewound to; this event and all after it were removed + pub up_to_event_id: String, +} + +/// Aggregate code change metrics for the session +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShutdownCodeChanges { + /// List of file paths that were modified during the session + pub files_modified: Vec, + /// Total number of lines added during the session + pub lines_added: f64, + /// Total number of lines removed during the session + pub lines_removed: f64, +} + +/// Request count and cost metrics +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShutdownModelMetricRequests { + /// Cumulative cost multiplier for requests to this model + pub cost: f64, + /// Total number of API requests made to this model + pub count: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShutdownModelMetricTokenDetail { + /// Accumulated token count for this token type + pub token_count: f64, +} + +/// Token usage breakdown +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShutdownModelMetricUsage { + /// Total tokens read from prompt cache across all requests + pub cache_read_tokens: f64, + /// Total tokens written to prompt cache across all requests + pub cache_write_tokens: f64, + /// Total input tokens consumed across all requests to this model + pub input_tokens: f64, + /// Total output tokens produced across all requests to this model + pub output_tokens: f64, + /// Total reasoning tokens produced across all requests to this model + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShutdownModelMetric { + /// Request count and cost metrics + pub requests: ShutdownModelMetricRequests, + /// Token count details per type + #[serde(default)] + pub token_details: HashMap, + /// Accumulated nano-AI units cost for this model + #[serde(skip_serializing_if = "Option::is_none")] + pub total_nano_aiu: Option, + /// Token usage breakdown + pub usage: ShutdownModelMetricUsage, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShutdownTokenDetail { + /// Accumulated token count for this token type + pub token_count: f64, +} + +/// Session termination metrics including usage statistics, code changes, and shutdown reason +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionShutdownData { + /// Aggregate code change metrics for the session + pub code_changes: ShutdownCodeChanges, + /// Non-system message token count at shutdown + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_tokens: Option, + /// Model that was selected at the time of shutdown + #[serde(skip_serializing_if = "Option::is_none")] + pub current_model: Option, + /// Total tokens in context window at shutdown + #[serde(skip_serializing_if = "Option::is_none")] + pub current_tokens: Option, + /// Error description when shutdownType is "error" + #[serde(skip_serializing_if = "Option::is_none")] + pub error_reason: Option, + /// Per-model usage breakdown, keyed by model identifier + pub model_metrics: HashMap, + /// Unix timestamp (milliseconds) when the session started + pub session_start_time: f64, + /// Whether the session ended normally ("routine") or due to a crash/fatal error ("error") + pub shutdown_type: ShutdownType, + /// System message token count at shutdown + #[serde(skip_serializing_if = "Option::is_none")] + pub system_tokens: Option, + /// Session-wide per-token-type accumulated token counts + #[serde(default)] + pub token_details: HashMap, + /// Tool definitions token count at shutdown + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_definitions_tokens: Option, + /// Cumulative time spent in API calls during the session, in milliseconds + pub total_api_duration_ms: f64, + /// Session-wide accumulated nano-AI units cost + #[serde(skip_serializing_if = "Option::is_none")] + pub total_nano_aiu: Option, + /// Total number of premium API requests used during the session + pub total_premium_requests: f64, +} + +/// Working directory and git context at session start +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionContextChangedData { + /// Base commit of current git branch at session start time + #[serde(skip_serializing_if = "Option::is_none")] + pub base_commit: Option, + /// Current git branch name + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + /// Current working directory path + pub cwd: String, + /// Root directory of the git repository, resolved via git rev-parse + #[serde(skip_serializing_if = "Option::is_none")] + pub git_root: Option, + /// Head commit of current git branch at session start time + #[serde(skip_serializing_if = "Option::is_none")] + pub head_commit: Option, + /// Hosting platform type of the repository (github or ado) + #[serde(skip_serializing_if = "Option::is_none")] + pub host_type: Option, + /// Repository identifier derived from the git remote URL ("owner/name" for GitHub, "org/project/repo" for Azure DevOps) + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, + /// Raw host string from the git remote URL (e.g. "github.com", "mycompany.ghe.com", "dev.azure.com") + #[serde(skip_serializing_if = "Option::is_none")] + pub repository_host: Option, +} + +/// Current context window usage statistics including token and message counts +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUsageInfoData { + /// Token count from non-system messages (user, assistant, tool) + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_tokens: Option, + /// Current number of tokens in the context window + pub current_tokens: f64, + /// Whether this is the first usage_info event emitted in this session + #[serde(skip_serializing_if = "Option::is_none")] + pub is_initial: Option, + /// Current number of messages in the conversation + pub messages_length: f64, + /// Token count from system message(s) + #[serde(skip_serializing_if = "Option::is_none")] + pub system_tokens: Option, + /// Maximum token count for the model's context window + pub token_limit: f64, + /// Token count from tool definitions + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_definitions_tokens: Option, +} + +/// Context window breakdown at the start of LLM-powered conversation compaction +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCompactionStartData { + /// Token count from non-system messages (user, assistant, tool) at compaction start + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_tokens: Option, + /// Token count from system message(s) at compaction start + #[serde(skip_serializing_if = "Option::is_none")] + pub system_tokens: Option, + /// Token count from tool definitions at compaction start + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_definitions_tokens: Option, +} + +/// Token usage detail for a single billing category +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CompactionCompleteCompactionTokensUsedCopilotUsageTokenDetail { + /// Number of tokens in this billing batch + pub batch_size: f64, + /// Cost per batch of tokens + pub cost_per_batch: f64, + /// Total token count for this entry + pub token_count: f64, + /// Token category (e.g., "input", "output") + pub token_type: String, +} + +/// Per-request cost and usage data from the CAPI copilot_usage response field +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CompactionCompleteCompactionTokensUsedCopilotUsage { + /// Itemized token usage breakdown + pub token_details: Vec, + /// Total cost in nano-AI units for this request + pub total_nano_aiu: f64, +} + +/// Token usage breakdown for the compaction LLM call (aligned with assistant.usage format) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CompactionCompleteCompactionTokensUsed { + /// Cached input tokens reused in the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_read_tokens: Option, + /// Tokens written to prompt cache in the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_tokens: Option, + /// Per-request cost and usage data from the CAPI copilot_usage response field + #[serde(skip_serializing_if = "Option::is_none")] + pub copilot_usage: Option, + /// Duration of the compaction LLM call in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + /// Input tokens consumed by the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + /// Model identifier used for the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Output tokens produced by the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, +} + +/// Conversation compaction results including success status, metrics, and optional error details +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCompactionCompleteData { + /// Checkpoint snapshot number created for recovery + #[serde(skip_serializing_if = "Option::is_none")] + pub checkpoint_number: Option, + /// File path where the checkpoint was stored + #[serde(skip_serializing_if = "Option::is_none")] + pub checkpoint_path: Option, + /// Token usage breakdown for the compaction LLM call (aligned with assistant.usage format) + #[serde(skip_serializing_if = "Option::is_none")] + pub compaction_tokens_used: Option, + /// Token count from non-system messages (user, assistant, tool) after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_tokens: Option, + /// Error message if compaction failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Number of messages removed during compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub messages_removed: Option, + /// Total tokens in conversation after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub post_compaction_tokens: Option, + /// Number of messages before compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub pre_compaction_messages_length: Option, + /// Total tokens in conversation before compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub pre_compaction_tokens: Option, + /// GitHub request tracing ID (x-github-request-id header) for the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Whether compaction completed successfully + pub success: bool, + /// LLM-generated summary of the compacted conversation history + #[serde(skip_serializing_if = "Option::is_none")] + pub summary_content: Option, + /// Token count from system message(s) after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub system_tokens: Option, + /// Number of tokens removed during compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens_removed: Option, + /// Token count from tool definitions after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_definitions_tokens: Option, +} + +/// Task completion notification with summary from the agent +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTaskCompleteData { + /// Whether the tool call succeeded. False when validation failed (e.g., invalid arguments) + #[serde(skip_serializing_if = "Option::is_none")] + pub success: Option, + /// Summary of the completed task, provided by the agent + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserMessageData { + /// The agent mode that was active when this message was sent + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_mode: Option, + /// Files, selections, or GitHub references attached to the message + #[serde(default)] + pub attachments: Vec, + /// The user's message text as displayed in the timeline + pub content: String, + /// CAPI interaction ID for correlating this user message with its turn + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Path-backed native document attachments that stayed on the tagged_files path flow because native upload would exceed the request size limit + #[serde(default)] + pub native_document_path_fallback_paths: Vec, + /// Parent agent task ID for background telemetry correlated to this user turn + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_agent_task_id: Option, + /// Origin of this message, used for timeline filtering (e.g., "skill-pdf" for skill-injected messages that should be hidden from the user) + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + /// Normalized document MIME types that were sent natively instead of through tagged_files XML + #[serde(default)] + pub supported_native_document_mime_types: Vec, + /// Transformed version of the message sent to the model, with XML wrapping, timestamps, and other augmentations for prompt caching + #[serde(skip_serializing_if = "Option::is_none")] + pub transformed_content: Option, +} + +/// Empty payload; the event signals that the pending message queue has changed +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PendingMessagesModifiedData {} + +/// Turn initialization metadata including identifier and interaction tracking +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantTurnStartData { + /// CAPI interaction ID for correlating this turn with upstream telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Identifier for this turn within the agentic loop, typically a stringified turn number + pub turn_id: String, +} + +/// Agent intent description for current activity or plan +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantIntentData { + /// Short description of what the agent is currently doing or planning to do + pub intent: String, +} + +/// Assistant reasoning content for timeline display with complete thinking text +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantReasoningData { + /// The complete extended thinking text from the model + pub content: String, + /// Unique identifier for this reasoning block + pub reasoning_id: String, +} + +/// Streaming reasoning delta for incremental extended thinking updates +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantReasoningDeltaData { + /// Incremental text chunk to append to the reasoning content + pub delta_content: String, + /// Reasoning block ID this delta belongs to, matching the corresponding assistant.reasoning event + pub reasoning_id: String, +} + +/// Streaming response progress with cumulative byte count +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantStreamingDeltaData { + /// Cumulative total bytes received from the streaming response so far + pub total_response_size_bytes: f64, +} + +/// A tool invocation request from the assistant +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantMessageToolRequest { + /// Arguments to pass to the tool, format depends on the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + /// Resolved intention summary describing what this specific call does + #[serde(skip_serializing_if = "Option::is_none")] + pub intention_summary: Option, + /// Name of the MCP server hosting this tool, when the tool is an MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_server_name: Option, + /// Name of the tool being invoked + pub name: String, + /// Unique identifier for this tool call + pub tool_call_id: String, + /// Human-readable display title for the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_title: Option, + /// Tool call type: "function" for standard tool calls, "custom" for grammar-based tool calls. Defaults to "function" when absent. + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, +} + +/// Assistant response containing text content, optional tool requests, and interaction metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantMessageData { + /// The assistant's text response content + pub content: String, + /// Encrypted reasoning content from OpenAI models. Session-bound and stripped on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub encrypted_content: Option, + /// CAPI interaction ID for correlating this message with upstream telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Unique identifier for this assistant message + pub message_id: String, + /// Actual output token count from the API response (completion_tokens), used for accurate token accounting + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// Generation phase for phased-output models (e.g., thinking vs. response phases) + #[serde(skip_serializing_if = "Option::is_none")] + pub phase: Option, + /// Opaque/encrypted extended thinking data from Anthropic models. Session-bound and stripped on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_opaque: Option, + /// Readable reasoning text from the model's extended thinking + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_text: Option, + /// GitHub request tracing ID (x-github-request-id header) for correlating with server-side logs + #[serde(skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Tool invocations requested by the assistant in this message + #[serde(default)] + pub tool_requests: Vec, + /// Identifier for the agent loop turn that produced this message, matching the corresponding assistant.turn_start event + #[serde(skip_serializing_if = "Option::is_none")] + pub turn_id: Option, +} + +/// Streaming assistant message start metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantMessageStartData { + /// Message ID this start event belongs to, matching subsequent deltas and assistant.message + pub message_id: String, + /// Generation phase this message belongs to for phased-output models + #[serde(skip_serializing_if = "Option::is_none")] + pub phase: Option, +} + +/// Streaming assistant message delta for incremental response updates +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantMessageDeltaData { + /// Incremental text chunk to append to the message content + pub delta_content: String, + /// Message ID this delta belongs to, matching the corresponding assistant.message event + pub message_id: String, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, +} + +/// Turn completion metadata including the turn identifier +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantTurnEndData { + /// Identifier of the turn that has ended, matching the corresponding assistant.turn_start event + pub turn_id: String, +} + +/// Token usage detail for a single billing category +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageCopilotUsageTokenDetail { + /// Number of tokens in this billing batch + pub batch_size: f64, + /// Cost per batch of tokens + pub cost_per_batch: f64, + /// Total token count for this entry + pub token_count: f64, + /// Token category (e.g., "input", "output") + pub token_type: String, +} + +/// Per-request cost and usage data from the CAPI copilot_usage response field +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageCopilotUsage { + /// Itemized token usage breakdown + pub token_details: Vec, + /// Total cost in nano-AI units for this request + pub total_nano_aiu: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageQuotaSnapshot { + /// Total requests allowed by the entitlement + pub entitlement_requests: f64, + /// Whether the user has an unlimited usage entitlement + pub is_unlimited_entitlement: bool, + /// Number of requests over the entitlement limit + pub overage: f64, + /// Whether overage is allowed when quota is exhausted + pub overage_allowed_with_exhausted_quota: bool, + /// Percentage of quota remaining (0.0 to 1.0) + pub remaining_percentage: f64, + /// Date when the quota resets + #[serde(skip_serializing_if = "Option::is_none")] + pub reset_date: Option, + /// Whether usage is still permitted after quota exhaustion + pub usage_allowed_with_exhausted_quota: bool, + /// Number of requests already consumed + pub used_requests: f64, +} + +/// LLM API call usage metrics including tokens, costs, quotas, and billing information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageData { + /// Completion ID from the model provider (e.g., chatcmpl-abc123) + #[serde(skip_serializing_if = "Option::is_none")] + pub api_call_id: Option, + /// Number of tokens read from prompt cache + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_read_tokens: Option, + /// Number of tokens written to prompt cache + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_tokens: Option, + /// Per-request cost and usage data from the CAPI copilot_usage response field + #[serde(skip_serializing_if = "Option::is_none")] + pub copilot_usage: Option, + /// Model multiplier cost for billing purposes + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, + /// Duration of the API call in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls + #[serde(skip_serializing_if = "Option::is_none")] + pub initiator: Option, + /// Number of input tokens consumed + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + /// Average inter-token latency in milliseconds. Only available for streaming requests + #[serde(skip_serializing_if = "Option::is_none")] + pub inter_token_latency_ms: Option, + /// Model identifier used for this API call + pub model: String, + /// Number of output tokens produced + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + /// Parent tool call ID when this usage originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// GitHub request tracing ID (x-github-request-id header) for server-side log correlation + #[serde(skip_serializing_if = "Option::is_none")] + pub provider_call_id: Option, + /// Per-quota resource usage snapshots, keyed by quota identifier + #[serde(default)] + pub quota_snapshots: HashMap, + /// Reasoning effort level used for model calls, if applicable (e.g. "low", "medium", "high", "xhigh") + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + /// Number of output tokens used for reasoning (e.g., chain-of-thought) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, + /// Time to first token in milliseconds. Only available for streaming requests + #[serde(skip_serializing_if = "Option::is_none")] + pub ttft_ms: Option, +} + +/// Failed LLM API call metadata for telemetry +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCallFailureData { + /// Completion ID from the model provider (e.g., chatcmpl-abc123) + #[serde(skip_serializing_if = "Option::is_none")] + pub api_call_id: Option, + /// Duration of the failed API call in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + /// Raw provider/runtime error message for restricted telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, + /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls + #[serde(skip_serializing_if = "Option::is_none")] + pub initiator: Option, + /// Model identifier used for the failed API call + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// GitHub request tracing ID (x-github-request-id header) for server-side log correlation + #[serde(skip_serializing_if = "Option::is_none")] + pub provider_call_id: Option, + /// Where the failed model call originated + pub source: ModelCallFailureSource, + /// HTTP status code from the failed request + #[serde(skip_serializing_if = "Option::is_none")] + pub status_code: Option, +} + +/// Turn abort information including the reason for termination +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AbortData { + /// Reason the current turn was aborted (e.g., "user initiated") + pub reason: String, +} + +/// User-initiated tool invocation request with tool name and arguments +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolUserRequestedData { + /// Arguments for the tool invocation + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + /// Unique identifier for this tool call + pub tool_call_id: String, + /// Name of the tool the user wants to invoke + pub tool_name: String, +} + +/// Tool execution startup details including MCP server information when applicable +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionStartData { + /// Arguments passed to the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + /// Name of the MCP server hosting this tool, when the tool is an MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_server_name: Option, + /// Original tool name on the MCP server, when the tool is an MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_tool_name: Option, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// Unique identifier for this tool call + pub tool_call_id: String, + /// Name of the tool being executed + pub tool_name: String, + /// Identifier for the agent loop turn this tool was invoked in, matching the corresponding assistant.turn_start event + #[serde(skip_serializing_if = "Option::is_none")] + pub turn_id: Option, +} + +/// Streaming tool execution output for incremental result display +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionPartialResultData { + /// Incremental output chunk from the running tool + pub partial_output: String, + /// Tool call ID this partial result belongs to + pub tool_call_id: String, +} + +/// Tool execution progress notification with status message +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionProgressData { + /// Human-readable progress status message (e.g., from an MCP server) + pub progress_message: String, + /// Tool call ID this progress notification belongs to + pub tool_call_id: String, +} + +/// Error details when the tool execution failed +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionCompleteError { + /// Machine-readable error code + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + /// Human-readable error message + pub message: String, +} + +/// Tool execution result on success +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionCompleteResult { + /// Concise tool result text sent to the LLM for chat completion, potentially truncated for token efficiency + pub content: String, + /// Structured content blocks (text, images, audio, resources) returned by the tool in their native format + #[serde(default)] + pub contents: Vec, + /// Full detailed tool result for UI/timeline display, preserving complete content such as diffs. Falls back to content when absent. + #[serde(skip_serializing_if = "Option::is_none")] + pub detailed_content: Option, +} + +/// Tool execution completion results including success status, detailed output, and error information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionCompleteData { + /// Error details when the tool execution failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// CAPI interaction ID for correlating this tool execution with upstream telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Whether this tool call was explicitly requested by the user rather than the assistant + #[serde(skip_serializing_if = "Option::is_none")] + pub is_user_requested: Option, + /// Model identifier that generated this tool call + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// Tool execution result on success + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + /// Whether the tool execution completed successfully + pub success: bool, + /// Unique identifier for the completed tool call + pub tool_call_id: String, + /// Tool-specific telemetry data (e.g., CodeQL check counts, grep match counts) + #[serde(default)] + pub tool_telemetry: HashMap, + /// Identifier for the agent loop turn this tool was invoked in, matching the corresponding assistant.turn_start event + #[serde(skip_serializing_if = "Option::is_none")] + pub turn_id: Option, +} + +/// Skill invocation details including content, allowed tools, and plugin metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillInvokedData { + /// Tool names that should be auto-approved when this skill is active + #[serde(default)] + pub allowed_tools: Vec, + /// Full content of the skill file, injected into the conversation for the model + pub content: String, + /// Description of the skill from its SKILL.md frontmatter + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Name of the invoked skill + pub name: String, + /// File path to the SKILL.md definition + pub path: String, + /// Name of the plugin this skill originated from, when applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub plugin_name: Option, + /// Version of the plugin this skill originated from, when applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub plugin_version: Option, +} + +/// Sub-agent startup details including parent tool call and agent information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubagentStartedData { + /// Description of what the sub-agent does + pub agent_description: String, + /// Human-readable display name of the sub-agent + pub agent_display_name: String, + /// Internal name of the sub-agent + pub agent_name: String, + /// Tool call ID of the parent tool invocation that spawned this sub-agent + pub tool_call_id: String, +} + +/// Sub-agent completion details for successful execution +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubagentCompletedData { + /// Human-readable display name of the sub-agent + pub agent_display_name: String, + /// Internal name of the sub-agent + pub agent_name: String, + /// Wall-clock duration of the sub-agent execution in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + /// Model used by the sub-agent + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Tool call ID of the parent tool invocation that spawned this sub-agent + pub tool_call_id: String, + /// Total tokens (input + output) consumed by the sub-agent + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, + /// Total number of tool calls made by the sub-agent + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tool_calls: Option, +} + +/// Sub-agent failure details including error message and agent information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubagentFailedData { + /// Human-readable display name of the sub-agent + pub agent_display_name: String, + /// Internal name of the sub-agent + pub agent_name: String, + /// Wall-clock duration of the sub-agent execution in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + /// Error message describing why the sub-agent failed + pub error: String, + /// Model used by the sub-agent (if any model calls succeeded before failure) + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Tool call ID of the parent tool invocation that spawned this sub-agent + pub tool_call_id: String, + /// Total tokens (input + output) consumed before the sub-agent failed + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, + /// Total number of tool calls made before the sub-agent failed + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tool_calls: Option, +} + +/// Custom agent selection details including name and available tools +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubagentSelectedData { + /// Human-readable display name of the selected custom agent + pub agent_display_name: String, + /// Internal name of the selected custom agent + pub agent_name: String, + /// List of tool names available to this agent, or null for all tools + pub tools: Vec, +} + +/// Empty payload; the event signals that the custom agent was deselected, returning to the default agent +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubagentDeselectedData {} + +/// Hook invocation start details including type and input data +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HookStartData { + /// Unique identifier for this hook invocation + pub hook_invocation_id: String, + /// Type of hook being invoked (e.g., "preToolUse", "postToolUse", "sessionStart") + pub hook_type: String, + /// Input data passed to the hook + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, +} + +/// Error details when the hook failed +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HookEndError { + /// Human-readable error message + pub message: String, + /// Error stack trace, when available + #[serde(skip_serializing_if = "Option::is_none")] + pub stack: Option, +} + +/// Hook invocation completion details including output, success status, and error information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HookEndData { + /// Error details when the hook failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Identifier matching the corresponding hook.start event + pub hook_invocation_id: String, + /// Type of hook that was invoked (e.g., "preToolUse", "postToolUse", "sessionStart") + pub hook_type: String, + /// Output data produced by the hook + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option, + /// Whether the hook completed successfully + pub success: bool, +} + +/// Metadata about the prompt template and its construction +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SystemMessageMetadata { + /// Version identifier of the prompt template used + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_version: Option, + /// Template variables used when constructing the prompt + #[serde(default)] + pub variables: HashMap, +} + +/// System/developer instruction content with role and optional template metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SystemMessageData { + /// The system or developer prompt text sent as model input + pub content: String, + /// Metadata about the prompt template and its construction + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + /// Optional name identifier for the message source + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Message role: "system" for system prompts, "developer" for developer-injected instructions + pub role: SystemMessageRole, +} + +/// System-generated notification for runtime events like background task completion +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SystemNotificationData { + /// The notification text, typically wrapped in XML tags + pub content: String, + /// Structured metadata identifying what triggered this notification + pub kind: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestShellCommand { + /// Command identifier (e.g., executable name) + pub identifier: String, + /// Whether this command is read-only (no side effects) + pub read_only: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestShellPossibleUrl { + /// URL that may be accessed by the command + pub url: String, +} + +/// Shell command permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestShell { + /// Whether the UI can offer session-wide approval for this command pattern + pub can_offer_session_approval: bool, + /// Parsed command identifiers found in the command text + pub commands: Vec, + /// The complete shell command text to be executed + pub full_command_text: String, + /// Whether the command includes a file write redirection (e.g., > or >>) + pub has_write_file_redirection: bool, + /// Human-readable description of what the command intends to do + pub intention: String, + /// Permission kind discriminator + pub kind: PermissionRequestShellKind, + /// File paths that may be read or written by the command + pub possible_paths: Vec, + /// URLs that may be accessed by the command + pub possible_urls: Vec, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Optional warning message about risks of running this command + #[serde(skip_serializing_if = "Option::is_none")] + pub warning: Option, +} + +/// File write permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestWrite { + /// Whether the UI can offer session-wide approval for file write operations + pub can_offer_session_approval: bool, + /// Unified diff showing the proposed changes + pub diff: String, + /// Path of the file being written to + pub file_name: String, + /// Human-readable description of the intended file change + pub intention: String, + /// Permission kind discriminator + pub kind: PermissionRequestWriteKind, + /// Complete new file contents for newly created files + #[serde(skip_serializing_if = "Option::is_none")] + pub new_file_contents: Option, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// File or directory read permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestRead { + /// Human-readable description of why the file is being read + pub intention: String, + /// Permission kind discriminator + pub kind: PermissionRequestReadKind, + /// Path of the file or directory being read + pub path: String, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// MCP tool invocation permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestMcp { + /// Arguments to pass to the MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub args: Option, + /// Permission kind discriminator + pub kind: PermissionRequestMcpKind, + /// Whether this MCP tool is read-only (no side effects) + pub read_only: bool, + /// Name of the MCP server providing the tool + pub server_name: String, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Internal name of the MCP tool + pub tool_name: String, + /// Human-readable title of the MCP tool + pub tool_title: String, +} + +/// URL access permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestUrl { + /// Human-readable description of why the URL is being accessed + pub intention: String, + /// Permission kind discriminator + pub kind: PermissionRequestUrlKind, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// URL to be fetched + pub url: String, +} + +/// Memory operation permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestMemory { + /// Whether this is a store or vote memory operation + #[serde(skip_serializing_if = "Option::is_none")] + pub action: Option, + /// Source references for the stored fact (store only) + #[serde(skip_serializing_if = "Option::is_none")] + pub citations: Option, + /// Vote direction (vote only) + #[serde(skip_serializing_if = "Option::is_none")] + pub direction: Option, + /// The fact being stored or voted on + pub fact: String, + /// Permission kind discriminator + pub kind: PermissionRequestMemoryKind, + /// Reason for the vote (vote only) + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, + /// Topic or subject of the memory (store only) + #[serde(skip_serializing_if = "Option::is_none")] + pub subject: Option, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// Custom tool invocation permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestCustomTool { + /// Arguments to pass to the custom tool + #[serde(skip_serializing_if = "Option::is_none")] + pub args: Option, + /// Permission kind discriminator + pub kind: PermissionRequestCustomToolKind, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Description of what the custom tool does + pub tool_description: String, + /// Name of the custom tool + pub tool_name: String, +} + +/// Hook confirmation permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestHook { + /// Optional message from the hook explaining why confirmation is needed + #[serde(skip_serializing_if = "Option::is_none")] + pub hook_message: Option, + /// Permission kind discriminator + pub kind: PermissionRequestHookKind, + /// Arguments of the tool call being gated + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_args: Option, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Name of the tool the hook is gating + pub tool_name: String, +} + +/// Shell command permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestCommands { + /// Whether the UI can offer session-wide approval for this command pattern + pub can_offer_session_approval: bool, + /// Command identifiers covered by this approval prompt + pub command_identifiers: Vec, + /// The complete shell command text to be executed + pub full_command_text: String, + /// Human-readable description of what the command intends to do + pub intention: String, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestCommandsKind, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Optional warning message about risks of running this command + #[serde(skip_serializing_if = "Option::is_none")] + pub warning: Option, +} + +/// File write permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestWrite { + /// Whether the UI can offer session-wide approval for file write operations + pub can_offer_session_approval: bool, + /// Unified diff showing the proposed changes + pub diff: String, + /// Path of the file being written to + pub file_name: String, + /// Human-readable description of the intended file change + pub intention: String, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestWriteKind, + /// Complete new file contents for newly created files + #[serde(skip_serializing_if = "Option::is_none")] + pub new_file_contents: Option, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// File read permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestRead { + /// Human-readable description of why the file is being read + pub intention: String, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestReadKind, + /// Path of the file or directory being read + pub path: String, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// MCP tool invocation permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestMcp { + /// Arguments to pass to the MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub args: Option, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestMcpKind, + /// Name of the MCP server providing the tool + pub server_name: String, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Internal name of the MCP tool + pub tool_name: String, + /// Human-readable title of the MCP tool + pub tool_title: String, +} + +/// URL access permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestUrl { + /// Human-readable description of why the URL is being accessed + pub intention: String, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestUrlKind, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// URL to be fetched + pub url: String, +} + +/// Memory operation permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestMemory { + /// Whether this is a store or vote memory operation + #[serde(skip_serializing_if = "Option::is_none")] + pub action: Option, + /// Source references for the stored fact (store only) + #[serde(skip_serializing_if = "Option::is_none")] + pub citations: Option, + /// Vote direction (vote only) + #[serde(skip_serializing_if = "Option::is_none")] + pub direction: Option, + /// The fact being stored or voted on + pub fact: String, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestMemoryKind, + /// Reason for the vote (vote only) + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, + /// Topic or subject of the memory (store only) + #[serde(skip_serializing_if = "Option::is_none")] + pub subject: Option, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// Custom tool invocation permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestCustomTool { + /// Arguments to pass to the custom tool + #[serde(skip_serializing_if = "Option::is_none")] + pub args: Option, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestCustomToolKind, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Description of what the custom tool does + pub tool_description: String, + /// Name of the custom tool + pub tool_name: String, +} + +/// Path access permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestPath { + /// Underlying permission kind that needs path approval + pub access_kind: PermissionPromptRequestPathAccessKind, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestPathKind, + /// File paths that require explicit approval + pub paths: Vec, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// Hook confirmation permission prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionPromptRequestHook { + /// Optional message from the hook explaining why confirmation is needed + #[serde(skip_serializing_if = "Option::is_none")] + pub hook_message: Option, + /// Prompt kind discriminator + pub kind: PermissionPromptRequestHookKind, + /// Arguments of the tool call being gated + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_args: Option, + /// Tool call ID that triggered this permission request + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Name of the tool the hook is gating + pub tool_name: String, +} + +/// Permission request notification requiring client approval with request details +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestedData { + /// Details of the permission being requested + pub permission_request: PermissionRequest, + /// Derived user-facing permission prompt details for UI consumers + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_request: Option, + /// Unique identifier for this permission request; used to respond via session.respondToPermission() + pub request_id: RequestId, + /// When true, this permission was already resolved by a permissionRequest hook and requires no client action + #[serde(skip_serializing_if = "Option::is_none")] + pub resolved_by_hook: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionApproved { + /// The permission request was approved + pub kind: PermissionApprovedKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserToolSessionApprovalCommands { + /// Command identifiers approved by the user + pub command_identifiers: Vec, + /// Command approval kind + pub kind: UserToolSessionApprovalCommandsKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserToolSessionApprovalRead { + /// Read approval kind + pub kind: UserToolSessionApprovalReadKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserToolSessionApprovalWrite { + /// Write approval kind + pub kind: UserToolSessionApprovalWriteKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserToolSessionApprovalMcp { + /// MCP tool approval kind + pub kind: UserToolSessionApprovalMcpKind, + /// MCP server name + pub server_name: String, + /// Optional MCP tool name, or null for all tools on the server + pub tool_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserToolSessionApprovalMemory { + /// Memory approval kind + pub kind: UserToolSessionApprovalMemoryKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserToolSessionApprovalCustomTool { + /// Custom tool approval kind + pub kind: UserToolSessionApprovalCustomToolKind, + /// Custom tool name + pub tool_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionApprovedForSession { + /// The approval to add as a session-scoped rule + pub approval: UserToolSessionApproval, + /// Approved and remembered for the rest of the session + pub kind: PermissionApprovedForSessionKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionApprovedForLocation { + /// The approval to persist for this location + pub approval: UserToolSessionApproval, + /// Approved and persisted for this project location + pub kind: PermissionApprovedForLocationKind, + /// The location key (git root or cwd) to persist the approval to + pub location_key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionCancelled { + /// The permission request was cancelled before a response was used + pub kind: PermissionCancelledKind, + /// Optional explanation of why the request was cancelled + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRule { + /// Optional rule argument matched against the request + pub argument: Option, + /// The rule kind, such as Shell or GitHubMCP + pub kind: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDeniedByRules { + /// Denied because approval rules explicitly blocked it + pub kind: PermissionDeniedByRulesKind, + /// Rules that denied the request + pub rules: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser { + /// Denied because no approval rule matched and user confirmation was unavailable + pub kind: PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUserKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDeniedInteractivelyByUser { + /// Optional feedback from the user explaining the denial + #[serde(skip_serializing_if = "Option::is_none")] + pub feedback: Option, + /// Whether to force-reject the current agent turn + #[serde(skip_serializing_if = "Option::is_none")] + pub force_reject: Option, + /// Denied by the user during an interactive prompt + pub kind: PermissionDeniedInteractivelyByUserKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDeniedByContentExclusionPolicy { + /// Denied by the organization's content exclusion policy + pub kind: PermissionDeniedByContentExclusionPolicyKind, + /// Human-readable explanation of why the path was excluded + pub message: String, + /// File path that triggered the exclusion + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDeniedByPermissionRequestHook { + /// Whether to interrupt the current agent turn + #[serde(skip_serializing_if = "Option::is_none")] + pub interrupt: Option, + /// Denied by a permission request hook registered by an extension or plugin + pub kind: PermissionDeniedByPermissionRequestHookKind, + /// Optional message from the hook explaining the denial + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +/// Permission request completion notification signaling UI dismissal +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionCompletedData { + /// Request ID of the resolved permission request; clients should dismiss any UI for this request + pub request_id: RequestId, + /// The result of the permission request + pub result: PermissionResult, + /// Optional tool call ID associated with this permission prompt; clients may use it to correlate UI created from tool-scoped prompts + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// User input request notification with question and optional predefined choices +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserInputRequestedData { + /// Whether the user can provide a free-form text response in addition to predefined choices + #[serde(skip_serializing_if = "Option::is_none")] + pub allow_freeform: Option, + /// Predefined choices for the user to select from, if applicable + #[serde(default)] + pub choices: Vec, + /// The question or prompt to present to the user + pub question: String, + /// Unique identifier for this input request; used to respond via session.respondToUserInput() + pub request_id: RequestId, + /// The LLM-assigned tool call ID that triggered this request; used by remote UIs to correlate responses + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +/// User input request completion with the user's response +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserInputCompletedData { + /// The user's answer to the input request + #[serde(skip_serializing_if = "Option::is_none")] + pub answer: Option, + /// Request ID of the resolved user input request; clients should dismiss any UI for this request + pub request_id: RequestId, + /// Whether the answer was typed as free-form text rather than selected from choices + #[serde(skip_serializing_if = "Option::is_none")] + pub was_freeform: Option, +} + +/// JSON Schema describing the form fields to present to the user (form mode only) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationRequestedSchema { + /// Form field definitions, keyed by field name + pub properties: HashMap, + /// List of required field names + #[serde(default)] + pub required: Vec, + /// Schema type indicator (always 'object') + pub r#type: ElicitationRequestedSchemaType, +} + +/// Elicitation request; may be form-based (structured input) or URL-based (browser redirect) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationRequestedData { + /// The source that initiated the request (MCP server name, or absent for agent-initiated) + #[serde(skip_serializing_if = "Option::is_none")] + pub elicitation_source: Option, + /// Message describing what information is needed from the user + pub message: String, + /// Elicitation mode; "form" for structured input, "url" for browser-based. Defaults to "form" when absent. + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + /// JSON Schema describing the form fields to present to the user (form mode only) + #[serde(skip_serializing_if = "Option::is_none")] + pub requested_schema: Option, + /// Unique identifier for this elicitation request; used to respond via session.respondToElicitation() + pub request_id: RequestId, + /// Tool call ID from the LLM completion; used to correlate with CompletionChunk.toolCall.id for remote UIs + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// URL to open in the user's browser (url mode only) + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, +} + +/// Elicitation request completion with the user's response +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationCompletedData { + /// The user action: "accept" (submitted form), "decline" (explicitly refused), or "cancel" (dismissed) + #[serde(skip_serializing_if = "Option::is_none")] + pub action: Option, + /// The submitted form data when action is 'accept'; keys match the requested schema fields + #[serde(default)] + pub content: HashMap, + /// Request ID of the resolved elicitation request; clients should dismiss any UI for this request + pub request_id: RequestId, +} + +/// Sampling request from an MCP server; contains the server name and a requestId for correlation +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SamplingRequestedData { + /// The JSON-RPC request ID from the MCP protocol + pub mcp_request_id: serde_json::Value, + /// Unique identifier for this sampling request; used to respond via session.respondToSampling() + pub request_id: RequestId, + /// Name of the MCP server that initiated the sampling request + pub server_name: String, +} + +/// Sampling request completion notification signaling UI dismissal +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SamplingCompletedData { + /// Request ID of the resolved sampling request; clients should dismiss any UI for this request + pub request_id: RequestId, +} + +/// Static OAuth client configuration, if the server specifies one +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthRequiredStaticClientConfig { + /// OAuth client ID for the server + pub client_id: String, + /// Optional non-default OAuth grant type. When set to 'client_credentials', the OAuth flow runs headlessly using the client_id + keychain-stored secret (no browser, no callback server). + #[serde(skip_serializing_if = "Option::is_none")] + pub grant_type: Option, + /// Whether this is a public OAuth client + #[serde(skip_serializing_if = "Option::is_none")] + pub public_client: Option, +} + +/// OAuth authentication request for an MCP server +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthRequiredData { + /// Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth() + pub request_id: RequestId, + /// Display name of the MCP server that requires OAuth + pub server_name: String, + /// URL of the MCP server that requires OAuth + pub server_url: String, + /// Static OAuth client configuration, if the server specifies one + #[serde(skip_serializing_if = "Option::is_none")] + pub static_client_config: Option, +} + +/// MCP OAuth request completion notification +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpOauthCompletedData { + /// Request ID of the resolved OAuth request + pub request_id: RequestId, +} + +/// External tool invocation request for client-side tool execution +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolRequestedData { + /// Arguments to pass to the external tool + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + /// Unique identifier for this request; used to respond via session.respondToExternalTool() + pub request_id: RequestId, + /// Session ID that this external tool request belongs to + pub session_id: SessionId, + /// Tool call ID assigned to this external tool invocation + pub tool_call_id: String, + /// Name of the external tool to invoke + pub tool_name: String, + /// W3C Trace Context traceparent header for the execute_tool span + #[serde(skip_serializing_if = "Option::is_none")] + pub traceparent: Option, + /// W3C Trace Context tracestate header for the execute_tool span + #[serde(skip_serializing_if = "Option::is_none")] + pub tracestate: Option, +} + +/// External tool completion notification signaling UI dismissal +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalToolCompletedData { + /// Request ID of the resolved external tool request; clients should dismiss any UI for this request + pub request_id: RequestId, +} + +/// Queued slash command dispatch request for client execution +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CommandQueuedData { + /// The slash command text to be executed (e.g., /help, /clear) + pub command: String, + /// Unique identifier for this request; used to respond via session.respondToQueuedCommand() + pub request_id: RequestId, +} + +/// Registered command dispatch request routed to the owning client +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CommandExecuteData { + /// Raw argument string after the command name + pub args: String, + /// The full command text (e.g., /deploy production) + pub command: String, + /// Command name without leading / + pub command_name: String, + /// Unique identifier; used to respond via session.commands.handlePendingCommand() + pub request_id: RequestId, +} + +/// Queued command completion notification signaling UI dismissal +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CommandCompletedData { + /// Request ID of the resolved command request; clients should dismiss any UI for this request + pub request_id: RequestId, +} + +/// Auto mode switch request notification requiring user approval +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AutoModeSwitchRequestedData { + /// The rate limit error code that triggered this request + #[serde(skip_serializing_if = "Option::is_none")] + pub error_code: Option, + /// Unique identifier for this request; used to respond via session.respondToAutoModeSwitch() + pub request_id: RequestId, + /// Seconds until the rate limit resets, when known. Lets clients render a humanized reset time alongside the prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub retry_after_seconds: Option, +} + +/// Auto mode switch completion notification +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AutoModeSwitchCompletedData { + /// Request ID of the resolved request; clients should dismiss any UI for this request + pub request_id: RequestId, + /// The user's choice: 'yes', 'yes_always', or 'no' + pub response: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CommandsChangedCommand { + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub name: String, +} + +/// SDK command registration change notification +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CommandsChangedData { + /// Current list of registered SDK commands + pub commands: Vec, +} + +/// UI capability changes +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CapabilitiesChangedUI { + /// Whether elicitation is now supported + #[serde(skip_serializing_if = "Option::is_none")] + pub elicitation: Option, +} + +/// Session capability change notification +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CapabilitiesChangedData { + /// UI capability changes + #[serde(skip_serializing_if = "Option::is_none")] + pub ui: Option, +} + +/// Plan approval request with plan content and available user actions +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExitPlanModeRequestedData { + /// Available actions the user can take (e.g., approve, edit, reject) + pub actions: Vec, + /// Full content of the plan file + pub plan_content: String, + /// The recommended action for the user to take + pub recommended_action: String, + /// Unique identifier for this request; used to respond via session.respondToExitPlanMode() + pub request_id: RequestId, + /// Summary of the plan that was created + pub summary: String, +} + +/// Plan mode exit completion with the user's approval decision and optional feedback +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExitPlanModeCompletedData { + /// Whether the plan was approved by the user + #[serde(skip_serializing_if = "Option::is_none")] + pub approved: Option, + /// Whether edits should be auto-approved without confirmation + #[serde(skip_serializing_if = "Option::is_none")] + pub auto_approve_edits: Option, + /// Free-form feedback from the user if they requested changes to the plan + #[serde(skip_serializing_if = "Option::is_none")] + pub feedback: Option, + /// Request ID of the resolved exit plan mode request; clients should dismiss any UI for this request + pub request_id: RequestId, + /// Which action the user selected (e.g. 'autopilot', 'interactive', 'exit_only') + #[serde(skip_serializing_if = "Option::is_none")] + pub selected_action: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionToolsUpdatedData { + pub model: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionBackgroundTasksChangedData {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillsLoadedSkill { + /// Description of what the skill does + pub description: String, + /// Whether the skill is currently enabled + pub enabled: bool, + /// Unique identifier for the skill + pub name: String, + /// Absolute path to the skill file, if available + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// Source location type of the skill (e.g., project, personal, plugin) + pub source: String, + /// Whether the skill can be invoked by the user as a slash command + pub user_invocable: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSkillsLoadedData { + /// Array of resolved skill metadata + pub skills: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CustomAgentsUpdatedAgent { + /// Description of what the agent does + pub description: String, + /// Human-readable display name + pub display_name: String, + /// Unique identifier for the agent + pub id: String, + /// Model override for this agent, if set + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Internal name of the agent + pub name: String, + /// Source location: user, project, inherited, remote, or plugin + pub source: String, + /// List of tool names available to this agent + pub tools: Vec, + /// Whether the agent can be selected by the user + pub user_invocable: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCustomAgentsUpdatedData { + /// Array of loaded custom agent metadata + pub agents: Vec, + /// Fatal errors from agent loading + pub errors: Vec, + /// Non-fatal warnings from agent loading + pub warnings: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpServersLoadedServer { + /// Error message if the server failed to connect + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Server name (config key) + pub name: String, + /// Configuration source: user, workspace, plugin, or builtin + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + /// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured + pub status: McpServersLoadedServerStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpServersLoadedData { + /// Array of MCP server status summaries + pub servers: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpServerStatusChangedData { + /// Name of the MCP server whose status changed + pub server_name: String, + /// New connection status: connected, failed, needs-auth, pending, disabled, or not_configured + pub status: McpServerStatusChangedStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExtensionsLoadedExtension { + /// Source-qualified extension ID (e.g., 'project:my-ext', 'user:auth-helper') + pub id: String, + /// Extension name (directory name) + pub name: String, + /// Discovery source + pub source: ExtensionsLoadedExtensionSource, + /// Current status: running, disabled, failed, or starting + pub status: ExtensionsLoadedExtensionStatus, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsLoadedData { + /// Array of discovered extensions and their status + pub extensions: Vec, +} + +/// Hosting platform type of the repository (github or ado) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum WorkingDirectoryContextHostType { + #[serde(rename = "github")] + Github, + #[serde(rename = "ado")] + Ado, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// The type of operation performed on the plan file +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PlanChangedOperation { + #[serde(rename = "create")] + Create, + #[serde(rename = "update")] + Update, + #[serde(rename = "delete")] + Delete, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Whether the file was newly created or updated +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum WorkspaceFileChangedOperation { + #[serde(rename = "create")] + Create, + #[serde(rename = "update")] + Update, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Origin type of the session being handed off +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum HandoffSourceType { + #[serde(rename = "remote")] + Remote, + #[serde(rename = "local")] + Local, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Whether the session ended normally ("routine") or due to a crash/fatal error ("error") +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ShutdownType { + #[serde(rename = "routine")] + Routine, + #[serde(rename = "error")] + Error, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// The agent mode that was active when this message was sent +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserMessageAgentMode { + #[serde(rename = "interactive")] + Interactive, + #[serde(rename = "plan")] + Plan, + #[serde(rename = "autopilot")] + Autopilot, + #[serde(rename = "shell")] + Shell, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Tool call type: "function" for standard tool calls, "custom" for grammar-based tool calls. Defaults to "function" when absent. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AssistantMessageToolRequestType { + #[serde(rename = "function")] + Function, + #[serde(rename = "custom")] + Custom, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Where the failed model call originated +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ModelCallFailureSource { + #[serde(rename = "top_level")] + TopLevel, + #[serde(rename = "subagent")] + Subagent, + #[serde(rename = "mcp_sampling")] + McpSampling, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Message role: "system" for system prompts, "developer" for developer-injected instructions +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SystemMessageRole { + #[serde(rename = "system")] + System, + #[serde(rename = "developer")] + Developer, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestShellKind { + #[serde(rename = "shell")] + Shell, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestWriteKind { + #[serde(rename = "write")] + Write, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestReadKind { + #[serde(rename = "read")] + Read, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestUrlKind { + #[serde(rename = "url")] + Url, +} + +/// Whether this is a store or vote memory operation +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestMemoryAction { + #[serde(rename = "store")] + Store, + #[serde(rename = "vote")] + Vote, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Vote direction (vote only) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestMemoryDirection { + #[serde(rename = "upvote")] + Upvote, + #[serde(rename = "downvote")] + Downvote, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestHookKind { + #[serde(rename = "hook")] + Hook, +} + +/// Details of the permission being requested +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionRequest { + Shell(PermissionRequestShell), + Write(PermissionRequestWrite), + Read(PermissionRequestRead), + Mcp(PermissionRequestMcp), + Url(PermissionRequestUrl), + Memory(PermissionRequestMemory), + CustomTool(PermissionRequestCustomTool), + Hook(PermissionRequestHook), +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestCommandsKind { + #[serde(rename = "commands")] + Commands, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestWriteKind { + #[serde(rename = "write")] + Write, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestReadKind { + #[serde(rename = "read")] + Read, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestUrlKind { + #[serde(rename = "url")] + Url, +} + +/// Whether this is a store or vote memory operation +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMemoryAction { + #[serde(rename = "store")] + Store, + #[serde(rename = "vote")] + Vote, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Vote direction (vote only) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMemoryDirection { + #[serde(rename = "upvote")] + Upvote, + #[serde(rename = "downvote")] + Downvote, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// Underlying permission kind that needs path approval +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestPathAccessKind { + #[serde(rename = "read")] + Read, + #[serde(rename = "shell")] + Shell, + #[serde(rename = "write")] + Write, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestPathKind { + #[serde(rename = "path")] + Path, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestHookKind { + #[serde(rename = "hook")] + Hook, +} + +/// Derived user-facing permission prompt details for UI consumers +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionPromptRequest { + Commands(PermissionPromptRequestCommands), + Write(PermissionPromptRequestWrite), + Read(PermissionPromptRequestRead), + Mcp(PermissionPromptRequestMcp), + Url(PermissionPromptRequestUrl), + Memory(PermissionPromptRequestMemory), + CustomTool(PermissionPromptRequestCustomTool), + Path(PermissionPromptRequestPath), + Hook(PermissionPromptRequestHook), +} + +/// The permission request was approved +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionApprovedKind { + #[serde(rename = "approved")] + Approved, +} + +/// Command approval kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserToolSessionApprovalCommandsKind { + #[serde(rename = "commands")] + Commands, +} + +/// Read approval kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserToolSessionApprovalReadKind { + #[serde(rename = "read")] + Read, +} + +/// Write approval kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserToolSessionApprovalWriteKind { + #[serde(rename = "write")] + Write, +} + +/// MCP tool approval kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserToolSessionApprovalMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +/// Memory approval kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserToolSessionApprovalMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +/// Custom tool approval kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserToolSessionApprovalCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// The approval to add as a session-scoped rule +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum UserToolSessionApproval { + Commands(UserToolSessionApprovalCommands), + Read(UserToolSessionApprovalRead), + Write(UserToolSessionApprovalWrite), + Mcp(UserToolSessionApprovalMcp), + Memory(UserToolSessionApprovalMemory), + CustomTool(UserToolSessionApprovalCustomTool), +} + +/// Approved and remembered for the rest of the session +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionApprovedForSessionKind { + #[serde(rename = "approved-for-session")] + ApprovedForSession, +} + +/// Approved and persisted for this project location +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionApprovedForLocationKind { + #[serde(rename = "approved-for-location")] + ApprovedForLocation, +} + +/// The permission request was cancelled before a response was used +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionCancelledKind { + #[serde(rename = "cancelled")] + Cancelled, +} + +/// Denied because approval rules explicitly blocked it +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDeniedByRulesKind { + #[serde(rename = "denied-by-rules")] + DeniedByRules, +} + +/// Denied because no approval rule matched and user confirmation was unavailable +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUserKind { + #[serde(rename = "denied-no-approval-rule-and-could-not-request-from-user")] + DeniedNoApprovalRuleAndCouldNotRequestFromUser, +} + +/// Denied by the user during an interactive prompt +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDeniedInteractivelyByUserKind { + #[serde(rename = "denied-interactively-by-user")] + DeniedInteractivelyByUser, +} + +/// Denied by the organization's content exclusion policy +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDeniedByContentExclusionPolicyKind { + #[serde(rename = "denied-by-content-exclusion-policy")] + DeniedByContentExclusionPolicy, +} + +/// Denied by a permission request hook registered by an extension or plugin +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDeniedByPermissionRequestHookKind { + #[serde(rename = "denied-by-permission-request-hook")] + DeniedByPermissionRequestHook, +} + +/// The result of the permission request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionResult { + Approved(PermissionApproved), + ApprovedForSession(PermissionApprovedForSession), + ApprovedForLocation(PermissionApprovedForLocation), + Cancelled(PermissionCancelled), + DeniedByRules(PermissionDeniedByRules), + DeniedNoApprovalRuleAndCouldNotRequestFromUser( + PermissionDeniedNoApprovalRuleAndCouldNotRequestFromUser, + ), + DeniedInteractivelyByUser(PermissionDeniedInteractivelyByUser), + DeniedByContentExclusionPolicy(PermissionDeniedByContentExclusionPolicy), + DeniedByPermissionRequestHook(PermissionDeniedByPermissionRequestHook), +} + +/// Elicitation mode; "form" for structured input, "url" for browser-based. Defaults to "form" when absent. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ElicitationRequestedMode { + #[serde(rename = "form")] + Form, + #[serde(rename = "url")] + Url, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Schema type indicator (always 'object') +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ElicitationRequestedSchemaType { + #[serde(rename = "object")] + Object, +} + +/// The user action: "accept" (submitted form), "decline" (explicitly refused), or "cancel" (dismissed) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ElicitationCompletedAction { + #[serde(rename = "accept")] + Accept, + #[serde(rename = "decline")] + Decline, + #[serde(rename = "cancel")] + Cancel, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Optional non-default OAuth grant type. When set to 'client_credentials', the OAuth flow runs headlessly using the client_id + keychain-stored secret (no browser, no callback server). +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpOauthRequiredStaticClientConfigGrantType { + #[serde(rename = "client_credentials")] + ClientCredentials, +} + +/// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServersLoadedServerStatus { + #[serde(rename = "connected")] + Connected, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "needs-auth")] + NeedsAuth, + #[serde(rename = "pending")] + Pending, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "not_configured")] + NotConfigured, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// New connection status: connected, failed, needs-auth, pending, disabled, or not_configured +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerStatusChangedStatus { + #[serde(rename = "connected")] + Connected, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "needs-auth")] + NeedsAuth, + #[serde(rename = "pending")] + Pending, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "not_configured")] + NotConfigured, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Discovery source +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionsLoadedExtensionSource { + #[serde(rename = "project")] + Project, + #[serde(rename = "user")] + User, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current status: running, disabled, failed, or starting +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionsLoadedExtensionStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "starting")] + Starting, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} diff --git a/rust/src/handler.rs b/rust/src/handler.rs new file mode 100644 index 000000000..79c7d381d --- /dev/null +++ b/rust/src/handler.rs @@ -0,0 +1,608 @@ +//! Event handler traits for session lifecycle. +//! +//! The [`SessionHandler`](crate::handler::SessionHandler) trait is the primary extension point โ€” implement +//! [`on_event`](crate::handler::SessionHandler::on_event) to control how sessions respond to +//! CLI events, permission requests, tool calls, and user input prompts. + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::types::{ + ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, + SessionEvent, SessionId, ToolInvocation, ToolResult, +}; + +/// Events dispatched by the SDK session event loop to the handler. +/// +/// The handler returns a [`HandlerResponse`] indicating how the SDK should +/// respond to the CLI. For fire-and-forget events (`SessionEvent`), the +/// response is ignored. +#[non_exhaustive] +#[derive(Debug)] +pub enum HandlerEvent { + /// Informational session event from the timeline (e.g. assistant.message_delta, + /// session.idle, tool.execution_start). Fire-and-forget โ€” return `HandlerResponse::Ok`. + SessionEvent { + /// The session that emitted this event. + session_id: SessionId, + /// The event payload. + event: SessionEvent, + }, + + /// The CLI requests permission for an action. Return `HandlerResponse::Permission(..)`. + PermissionRequest { + /// The requesting session. + session_id: SessionId, + /// Unique ID to correlate the response. + request_id: RequestId, + /// Permission request payload. + data: PermissionRequestData, + }, + + /// The CLI requests user input. Return `HandlerResponse::UserInput(..)`. + /// The handler may block (e.g. awaiting a UI dialog) โ€” this is expected. + UserInput { + /// The requesting session. + session_id: SessionId, + /// The question text to present. + question: String, + /// Optional multiple-choice options. + choices: Option>, + /// Whether free-form text input is allowed. + allow_freeform: Option, + }, + + /// The CLI requests execution of a client-defined tool. + /// Return `HandlerResponse::ToolResult(..)`. + ExternalTool { + /// The tool call to execute. + invocation: ToolInvocation, + }, + + /// The CLI broadcasts an elicitation request for the provider to handle. + /// Return `HandlerResponse::Elicitation(..)`. + ElicitationRequest { + /// The requesting session. + session_id: SessionId, + /// Unique ID to correlate the response. + request_id: RequestId, + /// The elicitation request payload. + request: ElicitationRequest, + }, + + /// The CLI requests exiting plan mode. Return `HandlerResponse::ExitPlanMode(..)`. + ExitPlanMode { + /// The requesting session. + session_id: SessionId, + /// Plan mode exit payload. + data: ExitPlanModeData, + }, + + /// The CLI asks whether to switch to auto model when an eligible rate + /// limit is hit. Return [`HandlerResponse::AutoModeSwitch`]. + AutoModeSwitch { + /// The requesting session. + session_id: SessionId, + /// The specific rate-limit error code that triggered the request, + /// if known (e.g. `user_weekly_rate_limited`, `user_global_rate_limited`). + error_code: Option, + /// Seconds until the rate limit resets, when known. Per RFC 9110's + /// `Retry-After` `delta-seconds` form, this is an integer count of + /// seconds. Handlers can use it to render a humanized reset time + /// alongside the prompt. + retry_after_seconds: Option, + }, +} + +/// Response from the handler back to the SDK, used to construct the +/// JSON-RPC reply sent to the CLI. +#[non_exhaustive] +#[derive(Debug)] +pub enum HandlerResponse { + /// No response needed (used for fire-and-forget `SessionEvent`s). + Ok, + /// Permission decision. + Permission(PermissionResult), + /// User input response (or `None` to signal no input available). + UserInput(Option), + /// Result of a tool execution. + ToolResult(ToolResult), + /// Elicitation result (accept/decline/cancel with optional form data). + Elicitation(ElicitationResult), + /// Exit plan mode decision. + ExitPlanMode(ExitPlanModeResult), + /// Auto-mode-switch decision. + AutoModeSwitch(AutoModeSwitchResponse), +} + +/// Result of a permission request. +/// +/// `#[non_exhaustive]` so future variants can be added without a major +/// version bump. Match arms must include a `_` fallback. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum PermissionResult { + /// Permission granted. + Approved, + /// Permission denied. + Denied, + /// Defer the response. The handler will resolve this request itself + /// later โ€” typically after a UI prompt โ€” by calling + /// `session.permissions.handlePendingPermissionRequest` directly. The + /// SDK will not send a response for this request. + /// + /// **Notification path only** (`permission.requested`). On the direct + /// RPC path (`permission.request`), `Deferred` falls back to + /// [`Approved`](Self::Approved) because that path must return a value + /// to satisfy the JSON-RPC reply contract. + Deferred, + /// Provide the full response payload. The SDK passes the value as-is + /// in the `result` field of `handlePendingPermissionRequest` + /// (notification path) or as the JSON-RPC `result` directly (direct + /// RPC path). + /// + /// Use this for response shapes beyond `{ "kind": "approve-once" }` + /// or `{ "kind": "reject" }` โ€” for example, "approve and remember" + /// with allowlist data. + Custom(serde_json::Value), + /// No user is available to respond โ€” for example, headless agents + /// without an interactive session. Sent as + /// `{ "kind": "user-not-available" }`. + UserNotAvailable, + /// The handler has no result to provide and the CLI should fall back + /// to its default policy. Sent as `{ "kind": "no-result" }`. Distinct + /// from [`Deferred`](Self::Deferred), which suppresses the reply + /// entirely so the handler can resolve later out-of-band. + NoResult, +} + +/// Response to a user input request. +#[derive(Debug, Clone)] +pub struct UserInputResponse { + /// The user's answer text. + pub answer: String, + /// Whether the answer was free-form (not a preset choice). + pub was_freeform: bool, +} + +/// Result of an exit-plan-mode request. +#[derive(Debug, Clone)] +pub struct ExitPlanModeResult { + /// Whether the user approved exiting plan mode. + pub approved: bool, + /// The action the user selected (if any). + pub selected_action: Option, + /// Optional feedback text from the user. + pub feedback: Option, +} + +impl Default for ExitPlanModeResult { + fn default() -> Self { + Self { + approved: true, + selected_action: None, + feedback: None, + } + } +} + +/// Response to a [`HandlerEvent::AutoModeSwitch`] request. +/// +/// Wire serialization matches the CLI's `autoModeSwitch.request` response +/// schema: `"yes"`, `"yes_always"`, or `"no"`. +#[non_exhaustive] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AutoModeSwitchResponse { + /// Approve the auto-mode switch for this rate-limit cycle only. + Yes, + /// Approve and remember โ€” auto-accept future auto-mode switches in this + /// session without prompting. + YesAlways, + /// Decline the auto-mode switch. The session stays on the current model + /// and surfaces the rate-limit error. + No, +} + +/// Callback trait for session events. +/// +/// Implement this trait to control how a session responds to CLI events, +/// permission requests, tool calls, user input prompts, elicitations, and +/// plan-mode exits. There are two styles of implementation โ€” pick whichever +/// fits your use case: +/// +/// 1. **Per-event methods (recommended for most handlers).** Override the +/// specific `on_*` methods you care about; every method has a safe +/// default so you only write what you need. This is the pattern used by +/// [`serenity::EventHandler`][serenity], `lapin`, and most Rust SDKs +/// that dispatch broker/client callbacks. +/// 2. **Single [`on_event`](Self::on_event) method.** Override this one +/// method and `match` on [`HandlerEvent`] yourself. Useful for logging +/// middleware, custom routing, or when you want an exhaustiveness check +/// across all variants. +/// +/// When you override [`on_event`](Self::on_event) directly, the per-event methods are not +/// called โ€” your implementation is entirely responsible for dispatch. The +/// default [`on_event`](Self::on_event) fans out to the per-event methods. +/// +/// [serenity]: https://docs.rs/serenity/latest/serenity/client/trait.EventHandler.html +/// +/// # Default behavior +/// +/// - Permission requests โ†’ **denied** (safe default). +/// - User input โ†’ `None` (no answer available). +/// - External tool calls โ†’ failure result with "no handler registered". +/// - Elicitation โ†’ `"cancel"`. +/// - Exit plan mode โ†’ [`ExitPlanModeResult::default`]. +/// - Auto-mode-switch โ†’ [`AutoModeSwitchResponse::No`] (decline by default; the +/// session stays on its current model and surfaces the rate-limit error). +/// - Session events โ†’ ignored (fire-and-forget). +/// +/// # Concurrency +/// +/// **Request-triggered events** (`UserInput`, `ExternalTool` via `tool.call`, +/// `ExitPlanMode`, `PermissionRequest` via `permission.request`) are awaited +/// inline in the event loop and therefore processed **serially** per session. +/// Blocking here pauses that session's event loop โ€” which is correct, since +/// the CLI is also blocked waiting for the response. +/// +/// **Notification-triggered events** (`PermissionRequest` via +/// `permission.requested`, `ExternalTool` via `external_tool.requested`) are +/// dispatched on spawned tasks and may run **concurrently** with each other +/// and with the serial event loop. Implementations must be safe for +/// concurrent invocation. +/// +/// # Example +/// +/// ```no_run +/// use async_trait::async_trait; +/// use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +/// use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; +/// +/// struct ApproveReadsOnly; +/// +/// #[async_trait] +/// impl SessionHandler for ApproveReadsOnly { +/// async fn on_permission_request( +/// &self, +/// _sid: SessionId, +/// _rid: RequestId, +/// data: PermissionRequestData, +/// ) -> PermissionResult { +/// match data.extra.get("tool").and_then(|v| v.as_str()) { +/// Some("view") | Some("ls") | Some("grep") => PermissionResult::Approved, +/// _ => PermissionResult::Denied, +/// } +/// } +/// } +/// ``` +#[async_trait] +pub trait SessionHandler: Send + Sync + 'static { + /// Handle an event from the session. + /// + /// The default implementation destructures `event` and calls the + /// matching per-event method (e.g. [`on_permission_request`](Self::on_permission_request) + /// for [`HandlerEvent::PermissionRequest`]). Override this method only + /// if you want a single dispatch point with exhaustive matching โ€” most + /// handlers should override the per-event methods instead. + /// + /// See the [trait-level docs](SessionHandler#concurrency) for details on + /// which events may be dispatched concurrently. + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { session_id, event } => { + self.on_session_event(session_id, event).await; + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { + session_id, + request_id, + data, + } => HandlerResponse::Permission( + self.on_permission_request(session_id, request_id, data) + .await, + ), + HandlerEvent::UserInput { + session_id, + question, + choices, + allow_freeform, + } => HandlerResponse::UserInput( + self.on_user_input(session_id, question, choices, allow_freeform) + .await, + ), + HandlerEvent::ExternalTool { invocation } => { + HandlerResponse::ToolResult(self.on_external_tool(invocation).await) + } + HandlerEvent::ElicitationRequest { + session_id, + request_id, + request, + } => HandlerResponse::Elicitation( + self.on_elicitation(session_id, request_id, request).await, + ), + HandlerEvent::ExitPlanMode { session_id, data } => { + HandlerResponse::ExitPlanMode(self.on_exit_plan_mode(session_id, data).await) + } + HandlerEvent::AutoModeSwitch { + session_id, + error_code, + retry_after_seconds, + } => HandlerResponse::AutoModeSwitch( + self.on_auto_mode_switch(session_id, error_code, retry_after_seconds) + .await, + ), + } + } + + /// Informational timeline event (assistant messages, tool execution + /// markers, session idle, etc.). Fire-and-forget โ€” the return value is + /// ignored. + /// + /// Default: do nothing. + async fn on_session_event(&self, _session_id: SessionId, _event: SessionEvent) {} + + /// The CLI is asking whether the agent may perform a privileged action. + /// + /// Default: [`PermissionResult::Denied`]. The default-deny posture + /// matches the CLI's safety model; override to implement your own + /// policy (see the [`permission`](crate::permission) module for common + /// wrappers like `approve_all` / `approve_if`). + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied + } + + /// The CLI is asking the user a question (optionally with a list of + /// choices). + /// + /// Default: `None` โ€” the CLI interprets this as "no answer available" + /// and falls back to its own prompt behavior. + async fn on_user_input( + &self, + _session_id: SessionId, + _question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + None + } + + /// The CLI wants to invoke a client-defined ("external") tool. + /// + /// Default: a failure [`ToolResult`] indicating no tool handler is + /// registered. Typical implementations route to a + /// [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) which + /// dispatches to tools registered via + /// [`define_tool`](crate::tool::define_tool) or custom + /// [`ToolHandler`](crate::tool::ToolHandler) impls. + async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { + let msg = format!("No handler registered for tool '{}'", invocation.tool_name); + ToolResult::Expanded(crate::types::ToolResultExpanded { + text_result_for_llm: msg.clone(), + result_type: "failure".to_string(), + session_log: None, + error: Some(msg), + }) + } + + /// The CLI is requesting an elicitation (structured form / URL prompt). + /// + /// Default: cancel. + async fn on_elicitation( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult { + action: "cancel".to_string(), + content: None, + } + } + + /// The CLI is asking the user whether to exit plan mode. + /// + /// Default: [`ExitPlanModeResult::default`] (approved with no action). + async fn on_exit_plan_mode( + &self, + _session_id: SessionId, + _data: ExitPlanModeData, + ) -> ExitPlanModeResult { + ExitPlanModeResult::default() + } + + /// The CLI is asking whether to switch to auto model after an eligible + /// rate limit. + /// + /// `retry_after_seconds`, when present, is the number of seconds until the + /// rate limit resets (RFC 9110 `Retry-After` `delta-seconds`). Handlers + /// can use it to render a humanized reset time alongside the prompt. + /// + /// Default: [`AutoModeSwitchResponse::No`] โ€” decline. Override only if + /// your application surfaces a UX for the rate-limit-recovery prompt. + async fn on_auto_mode_switch( + &self, + _session_id: SessionId, + _error_code: Option, + _retry_after_seconds: Option, + ) -> AutoModeSwitchResponse { + AutoModeSwitchResponse::No + } +} + +/// A [`SessionHandler`] that auto-approves all permissions and ignores all events. +/// +/// Useful for CLI tools, scripts, and tests that don't need interactive +/// permission prompts or custom tool handling. +#[derive(Debug, Clone)] +pub struct ApproveAllHandler; + +#[async_trait] +impl SessionHandler for ApproveAllHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } +} + +/// A [`SessionHandler`] that denies all permission requests and otherwise +/// relies on the trait's default fallback responses for every other event +/// (e.g. tool invocations return "unhandled", elicitations cancel, plan-mode +/// prompts decline). This is the safe default used when no handler is set on +/// [`SessionConfig::handler`](crate::types::SessionConfig::handler) โ€” sessions +/// will not stall on permission prompts (they're denied immediately) but no +/// privileged actions will be taken without an explicit opt-in. +#[derive(Debug, Clone)] +pub struct DenyAllHandler; + +#[async_trait] +impl SessionHandler for DenyAllHandler { + // All defaults are already safe: permissions deny, everything else is a + // sensible fallback. We just reuse them here for clarity. +} + +#[cfg(test)] +mod tests { + use serde_json::Value; + + use super::*; + use crate::types::{PermissionRequestData, RequestId, SessionId}; + + fn perm_data() -> PermissionRequestData { + PermissionRequestData::default() + } + + // A handler that overrides only `on_permission_request` (per-method style). + struct ApproveViaPerMethod; + + #[async_trait] + impl SessionHandler for ApproveViaPerMethod { + async fn on_permission_request( + &self, + _: SessionId, + _: RequestId, + _: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } + } + + // A handler that overrides `on_event` directly (legacy / routing style). + struct ApproveViaOnEvent; + + #[async_trait] + impl SessionHandler for ApproveViaOnEvent { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Approved) + } + _ => HandlerResponse::Ok, + } + } + } + + #[tokio::test] + async fn per_method_override_dispatches_via_default_on_event() { + let h = ApproveViaPerMethod; + let resp = h + .on_event(HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + data: perm_data(), + }) + .await; + assert!(matches!( + resp, + HandlerResponse::Permission(PermissionResult::Approved) + )); + } + + #[tokio::test] + async fn on_event_override_short_circuits_per_method_defaults() { + let h = ApproveViaOnEvent; + let resp = h + .on_event(HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + data: perm_data(), + }) + .await; + assert!(matches!( + resp, + HandlerResponse::Permission(PermissionResult::Approved) + )); + } + + #[tokio::test] + async fn deny_all_handler_uses_default_permission_deny() { + let h = DenyAllHandler; + let resp = h + .on_event(HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + data: perm_data(), + }) + .await; + assert!(matches!( + resp, + HandlerResponse::Permission(PermissionResult::Denied) + )); + } + + #[tokio::test] + async fn default_on_external_tool_returns_failure() { + let h = DenyAllHandler; + let resp = h + .on_event(HandlerEvent::ExternalTool { + invocation: crate::types::ToolInvocation { + session_id: SessionId::from("s1".to_string()), + tool_call_id: "tc1".to_string(), + tool_name: "missing".to_string(), + arguments: Value::Null, + traceparent: None, + tracestate: None, + }, + }) + .await; + match resp { + HandlerResponse::ToolResult(crate::types::ToolResult::Expanded(exp)) => { + assert_eq!(exp.result_type, "failure"); + assert!(exp.text_result_for_llm.contains("missing")); + assert_eq!(exp.error.as_deref(), Some(exp.text_result_for_llm.as_str())); + } + other => panic!("unexpected response: {other:?}"), + } + } + + #[tokio::test] + async fn default_on_elicitation_returns_cancel() { + let h = DenyAllHandler; + let resp = h + .on_event(HandlerEvent::ElicitationRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + request: crate::types::ElicitationRequest { + message: "test".to_string(), + requested_schema: None, + mode: Some(crate::types::ElicitationMode::Form), + elicitation_source: None, + url: None, + }, + }) + .await; + match resp { + HandlerResponse::Elicitation(r) => assert_eq!(r.action, "cancel"), + other => panic!("unexpected response: {other:?}"), + } + } +} diff --git a/rust/src/hooks.rs b/rust/src/hooks.rs new file mode 100644 index 000000000..ca755c6f9 --- /dev/null +++ b/rust/src/hooks.rs @@ -0,0 +1,715 @@ +//! Lifecycle hook callbacks invoked at key session points. +//! +//! Hooks let you intercept and modify CLI behavior โ€” approve or deny tool +//! use, rewrite user prompts, inject context at session start, and handle +//! errors. Implement [`SessionHooks`](crate::hooks::SessionHooks) and pass it to +//! [`Client::create_session`](crate::Client::create_session). + +use std::path::PathBuf; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::types::SessionId; + +/// Context provided to every hook invocation. +#[derive(Debug, Clone)] +pub struct HookContext { + /// The session this hook was triggered in. + pub session_id: SessionId, +} + +/// Input for the `preToolUse` hook โ€” received before a tool executes. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PreToolUseInput { + /// Unix timestamp (ms). + pub timestamp: i64, + /// Working directory. + pub cwd: PathBuf, + /// Name of the tool about to execute. + pub tool_name: String, + /// Arguments passed to the tool. + pub tool_args: Value, +} + +/// Output for the `preToolUse` hook. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PreToolUseOutput { + /// "allow" or "deny". + #[serde(skip_serializing_if = "Option::is_none")] + pub permission_decision: Option, + /// Reason for the decision (shown to the agent). + #[serde(skip_serializing_if = "Option::is_none")] + pub permission_decision_reason: Option, + /// Replacement arguments for the tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub modified_args: Option, + /// Extra context injected into the agent's prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_context: Option, + /// Suppress the hook's output from the session log. + #[serde(skip_serializing_if = "Option::is_none")] + pub suppress_output: Option, +} + +/// Input for the `postToolUse` hook โ€” received after a tool executes. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostToolUseInput { + /// Unix timestamp (ms). + pub timestamp: i64, + /// Working directory. + pub cwd: PathBuf, + /// Name of the tool that executed. + pub tool_name: String, + /// Arguments that were passed to the tool. + pub tool_args: Value, + /// Result returned by the tool. + pub tool_result: Value, +} + +/// Output for the `postToolUse` hook. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PostToolUseOutput { + /// Replacement result for the tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub modified_result: Option, + /// Extra context injected into the agent's prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_context: Option, + /// Suppress the hook's output from the session log. + #[serde(skip_serializing_if = "Option::is_none")] + pub suppress_output: Option, +} + +/// Input for the `userPromptSubmitted` hook โ€” received when the user sends a message. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserPromptSubmittedInput { + /// Unix timestamp (ms). + pub timestamp: i64, + /// Working directory. + pub cwd: PathBuf, + /// The user's message text. + pub prompt: String, +} + +/// Output for the `userPromptSubmitted` hook. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct UserPromptSubmittedOutput { + /// Replacement prompt text. + #[serde(skip_serializing_if = "Option::is_none")] + pub modified_prompt: Option, + /// Extra context injected into the agent's prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_context: Option, + /// Suppress the hook's output from the session log. + #[serde(skip_serializing_if = "Option::is_none")] + pub suppress_output: Option, +} + +/// Input for the `sessionStart` hook. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionStartInput { + /// Unix timestamp (ms). + pub timestamp: i64, + /// Working directory. + pub cwd: PathBuf, + /// How the session was started: `"startup"`, `"resume"`, or `"new"`. + pub source: String, + /// The first user message, if any. + #[serde(default)] + pub initial_prompt: Option, +} + +/// Output for the `sessionStart` hook. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionStartOutput { + /// Extra context injected at session start. + #[serde(skip_serializing_if = "Option::is_none")] + pub additional_context: Option, + /// Config overrides applied to the session. + #[serde(skip_serializing_if = "Option::is_none")] + pub modified_config: Option, +} + +/// Input for the `sessionEnd` hook. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionEndInput { + /// Unix timestamp (ms). + pub timestamp: i64, + /// Working directory. + pub cwd: PathBuf, + /// Why the session ended: `"complete"`, `"error"`, `"abort"`, `"timeout"`, `"user_exit"`. + pub reason: String, + /// The last assistant message. + #[serde(default)] + pub final_message: Option, + /// Error message, if the session ended due to an error. + #[serde(default)] + pub error: Option, +} + +/// Output for the `sessionEnd` hook. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionEndOutput { + /// Suppress the hook's output from the session log. + #[serde(skip_serializing_if = "Option::is_none")] + pub suppress_output: Option, + /// Actions to run during cleanup. + #[serde(skip_serializing_if = "Option::is_none")] + pub cleanup_actions: Option>, + /// Summary text for the session. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_summary: Option, +} + +/// Input for the `errorOccurred` hook. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ErrorOccurredInput { + /// Unix timestamp (ms). + pub timestamp: i64, + /// Working directory. + pub cwd: PathBuf, + /// The error message. + pub error: String, + /// Context where the error occurred: `"model_call"`, `"tool_execution"`, `"system"`, `"user_input"`. + pub error_context: String, + /// Whether the error is recoverable. + pub recoverable: bool, +} + +/// Output for the `errorOccurred` hook. +#[derive(Debug, Clone, Default, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ErrorOccurredOutput { + /// Suppress the hook's output from the session log. + #[serde(skip_serializing_if = "Option::is_none")] + pub suppress_output: Option, + /// How to handle the error: `"retry"`, `"skip"`, or `"abort"`. + #[serde(skip_serializing_if = "Option::is_none")] + pub error_handling: Option, + /// Number of retries to attempt. + #[serde(skip_serializing_if = "Option::is_none")] + pub retry_count: Option, + /// Message to show the user. + #[serde(skip_serializing_if = "Option::is_none")] + pub user_notification: Option, +} + +/// Events dispatched to [`SessionHooks::on_hook`] at CLI lifecycle points. +/// +/// Each variant carries the typed input for that hook plus the shared +/// [`HookContext`]. The handler returns a matching [`HookOutput`] variant +/// (or [`HookOutput::None`] to signal "no hook registered"). +#[non_exhaustive] +#[derive(Debug)] +pub enum HookEvent { + /// Fired before a tool executes. + PreToolUse { + /// Typed input data. + input: PreToolUseInput, + /// Session context. + ctx: HookContext, + }, + /// Fired after a tool executes. + PostToolUse { + /// Typed input data. + input: PostToolUseInput, + /// Session context. + ctx: HookContext, + }, + /// Fired when the user sends a message. + UserPromptSubmitted { + /// Typed input data. + input: UserPromptSubmittedInput, + /// Session context. + ctx: HookContext, + }, + /// Fired at session creation or resume. + SessionStart { + /// Typed input data. + input: SessionStartInput, + /// Session context. + ctx: HookContext, + }, + /// Fired when the session ends. + SessionEnd { + /// Typed input data. + input: SessionEndInput, + /// Session context. + ctx: HookContext, + }, + /// Fired when an error occurs. + ErrorOccurred { + /// Typed input data. + input: ErrorOccurredInput, + /// Session context. + ctx: HookContext, + }, +} + +/// Response from [`SessionHooks::on_hook`] back to the SDK. +/// +/// Return the variant matching the [`HookEvent`] you received, or +/// [`HookOutput::None`] to indicate no hook is registered for that event. +#[non_exhaustive] +#[derive(Debug)] +pub enum HookOutput { + /// No hook registered โ€” the SDK returns an empty output object to the CLI. + None, + /// Response for a pre-tool-use hook. + PreToolUse(PreToolUseOutput), + /// Response for a post-tool-use hook. + PostToolUse(PostToolUseOutput), + /// Response for a user-prompt-submitted hook. + UserPromptSubmitted(UserPromptSubmittedOutput), + /// Response for a session-start hook. + SessionStart(SessionStartOutput), + /// Response for a session-end hook. + SessionEnd(SessionEndOutput), + /// Response for an error-occurred hook. + ErrorOccurred(ErrorOccurredOutput), +} + +impl HookOutput { + fn variant_name(&self) -> &'static str { + match self { + Self::None => "None", + Self::PreToolUse(_) => "PreToolUse", + Self::PostToolUse(_) => "PostToolUse", + Self::UserPromptSubmitted(_) => "UserPromptSubmitted", + Self::SessionStart(_) => "SessionStart", + Self::SessionEnd(_) => "SessionEnd", + Self::ErrorOccurred(_) => "ErrorOccurred", + } + } +} + +/// Callback trait for session hooks โ€” invoked by the CLI at key lifecycle +/// points (tool use, prompt submission, session start/end, errors). +/// +/// Implement this trait to intercept and modify CLI behavior at hook points. +/// There are two styles of implementation โ€” pick whichever fits: +/// +/// 1. **Per-hook methods (recommended).** Override the specific `on_*` hook +/// methods you care about; every hook has a default that returns `None` +/// (meaning "no hook registered, use CLI default behavior"). +/// 2. **Single [`on_hook`](Self::on_hook) method.** Override this one and +/// `match` on [`HookEvent`] yourself โ€” useful for logging middleware or +/// shared dispatch logic. +/// +/// Hooks only fire when hooks are enabled on the session (via +/// [`SessionConfig::hooks = Some(true)`](crate::types::SessionConfig::hooks), +/// which [`SessionConfig::with_hooks`](crate::types::SessionConfig::with_hooks) +/// sets automatically). +#[async_trait] +pub trait SessionHooks: Send + Sync + 'static { + /// Top-level dispatch. The default implementation fans out to the + /// per-hook methods below; override this only if you want a single + /// matching point across all hook types. + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::PreToolUse { input, ctx } => self + .on_pre_tool_use(input, ctx) + .await + .map(HookOutput::PreToolUse) + .unwrap_or(HookOutput::None), + HookEvent::PostToolUse { input, ctx } => self + .on_post_tool_use(input, ctx) + .await + .map(HookOutput::PostToolUse) + .unwrap_or(HookOutput::None), + HookEvent::UserPromptSubmitted { input, ctx } => self + .on_user_prompt_submitted(input, ctx) + .await + .map(HookOutput::UserPromptSubmitted) + .unwrap_or(HookOutput::None), + HookEvent::SessionStart { input, ctx } => self + .on_session_start(input, ctx) + .await + .map(HookOutput::SessionStart) + .unwrap_or(HookOutput::None), + HookEvent::SessionEnd { input, ctx } => self + .on_session_end(input, ctx) + .await + .map(HookOutput::SessionEnd) + .unwrap_or(HookOutput::None), + HookEvent::ErrorOccurred { input, ctx } => self + .on_error_occurred(input, ctx) + .await + .map(HookOutput::ErrorOccurred) + .unwrap_or(HookOutput::None), + } + } + + /// Called before a tool executes. Return `Some(output)` to approve/deny + /// or modify the call, or `None` (default) to pass through unchanged. + async fn on_pre_tool_use( + &self, + _input: PreToolUseInput, + _ctx: HookContext, + ) -> Option { + None + } + + /// Called after a tool executes. Return `Some(output)` to inject + /// additional context or signal post-processing decisions; `None` + /// (default) means no follow-up. + async fn on_post_tool_use( + &self, + _input: PostToolUseInput, + _ctx: HookContext, + ) -> Option { + None + } + + /// Called when the user submits a prompt. Return `Some(output)` to + /// rewrite the prompt or inject extra context; `None` (default) passes + /// through unchanged. + async fn on_user_prompt_submitted( + &self, + _input: UserPromptSubmittedInput, + _ctx: HookContext, + ) -> Option { + None + } + + /// Called at session creation or resume. Return `Some(output)` to + /// inject startup context. + async fn on_session_start( + &self, + _input: SessionStartInput, + _ctx: HookContext, + ) -> Option { + None + } + + /// Called when the session ends. Return `Some(output)` if your hook + /// needs to signal cleanup behavior. + async fn on_session_end( + &self, + _input: SessionEndInput, + _ctx: HookContext, + ) -> Option { + None + } + + /// Called when the CLI reports an error. Return `Some(output)` to + /// influence retry behavior or surface a user-facing notification. + async fn on_error_occurred( + &self, + _input: ErrorOccurredInput, + _ctx: HookContext, + ) -> Option { + None + } +} + +/// Dispatches a `hooks.invoke` request to [`SessionHooks::on_hook`]. +/// +/// Returns `Ok(Value)` shaped like `{ "output": ... }` on success. +/// If no hook is registered ([`HookOutput::None`]), the output is an empty +/// object: `{ "output": {} }`. +pub(crate) async fn dispatch_hook( + hooks: &dyn SessionHooks, + session_id: &SessionId, + hook_type: &str, + raw_input: Value, +) -> Result { + let ctx = HookContext { + session_id: session_id.clone(), + }; + + let event = match hook_type { + "preToolUse" => { + let input: PreToolUseInput = serde_json::from_value(raw_input)?; + HookEvent::PreToolUse { input, ctx } + } + "postToolUse" => { + let input: PostToolUseInput = serde_json::from_value(raw_input)?; + HookEvent::PostToolUse { input, ctx } + } + "userPromptSubmitted" => { + let input: UserPromptSubmittedInput = serde_json::from_value(raw_input)?; + HookEvent::UserPromptSubmitted { input, ctx } + } + "sessionStart" => { + let input: SessionStartInput = serde_json::from_value(raw_input)?; + HookEvent::SessionStart { input, ctx } + } + "sessionEnd" => { + let input: SessionEndInput = serde_json::from_value(raw_input)?; + HookEvent::SessionEnd { input, ctx } + } + "errorOccurred" => { + let input: ErrorOccurredInput = serde_json::from_value(raw_input)?; + HookEvent::ErrorOccurred { input, ctx } + } + _ => { + tracing::warn!( + hook_type = hook_type, + session_id = %session_id, + "unknown hook type" + ); + return Ok(serde_json::json!({ "output": {} })); + } + }; + + let output = hooks.on_hook(event).await; + + // Validate that the output variant matches the dispatched hook type. + // A mismatched return (e.g. HookOutput::SessionEnd for a preToolUse + // event) is treated as "no hook registered" to avoid sending the CLI + // a semantically wrong response. + let output_value = match (hook_type, &output) { + (_, HookOutput::None) => None, + ("preToolUse", HookOutput::PreToolUse(o)) => Some(serde_json::to_value(o)?), + ("postToolUse", HookOutput::PostToolUse(o)) => Some(serde_json::to_value(o)?), + ("userPromptSubmitted", HookOutput::UserPromptSubmitted(o)) => { + Some(serde_json::to_value(o)?) + } + ("sessionStart", HookOutput::SessionStart(o)) => Some(serde_json::to_value(o)?), + ("sessionEnd", HookOutput::SessionEnd(o)) => Some(serde_json::to_value(o)?), + ("errorOccurred", HookOutput::ErrorOccurred(o)) => Some(serde_json::to_value(o)?), + _ => { + tracing::warn!( + hook_type = hook_type, + session_id = %session_id, + output_variant = output.variant_name(), + "hook returned mismatched output variant, treating as unregistered" + ); + None + } + }; + + Ok(serde_json::json!({ "output": output_value.unwrap_or(Value::Object(Default::default())) })) +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestHooks; + + #[async_trait] + impl SessionHooks for TestHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::PreToolUse { input, .. } => { + if input.tool_name == "dangerous_tool" { + HookOutput::PreToolUse(PreToolUseOutput { + permission_decision: Some("deny".to_string()), + permission_decision_reason: Some("blocked by policy".to_string()), + ..Default::default() + }) + } else { + HookOutput::None + } + } + HookEvent::UserPromptSubmitted { input, .. } => { + HookOutput::UserPromptSubmitted(UserPromptSubmittedOutput { + modified_prompt: Some(format!("[prefixed] {}", input.prompt)), + ..Default::default() + }) + } + _ => HookOutput::None, + } + } + } + + #[tokio::test] + async fn dispatch_pre_tool_use_deny() { + let hooks = TestHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "toolName": "dangerous_tool", + "toolArgs": {} + }); + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input) + .await + .unwrap(); + let output = &result["output"]; + assert_eq!(output["permissionDecision"], "deny"); + assert_eq!(output["permissionDecisionReason"], "blocked by policy"); + } + + #[tokio::test] + async fn dispatch_pre_tool_use_passthrough() { + let hooks = TestHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "toolName": "safe_tool", + "toolArgs": {"key": "value"} + }); + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input) + .await + .unwrap(); + // No hook registered for this tool โ€” output should be empty object + assert_eq!(result["output"], serde_json::json!({})); + } + + #[tokio::test] + async fn dispatch_user_prompt_submitted() { + let hooks = TestHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "prompt": "hello world" + }); + let result = dispatch_hook( + &hooks, + &SessionId::new("sess-1"), + "userPromptSubmitted", + input, + ) + .await + .unwrap(); + assert_eq!(result["output"]["modifiedPrompt"], "[prefixed] hello world"); + } + + #[tokio::test] + async fn dispatch_unregistered_hook_returns_empty() { + let hooks = TestHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "reason": "complete" + }); + // TestHooks doesn't handle SessionEnd + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", input) + .await + .unwrap(); + assert_eq!(result["output"], serde_json::json!({})); + } + + #[tokio::test] + async fn dispatch_unknown_hook_type() { + let hooks = TestHooks; + let input = serde_json::json!({}); + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "unknownHook", input) + .await + .unwrap(); + assert_eq!(result["output"], serde_json::json!({})); + } + + #[tokio::test] + async fn dispatch_mismatched_output_returns_empty() { + struct MismatchHooks; + #[async_trait] + impl SessionHooks for MismatchHooks { + async fn on_hook(&self, _event: HookEvent) -> HookOutput { + // Always return SessionEnd output regardless of event type + HookOutput::SessionEnd(SessionEndOutput { + session_summary: Some("oops".to_string()), + ..Default::default() + }) + } + } + + let hooks = MismatchHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "toolName": "some_tool", + "toolArgs": {} + }); + // preToolUse event gets a SessionEnd output โ€” should be treated as empty + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input) + .await + .unwrap(); + assert_eq!(result["output"], serde_json::json!({})); + } + + #[tokio::test] + async fn dispatch_post_tool_use_default() { + let hooks = TestHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "toolName": "some_tool", + "toolArgs": {}, + "toolResult": "success" + }); + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input) + .await + .unwrap(); + assert_eq!(result["output"], serde_json::json!({})); + } + + #[tokio::test] + async fn dispatch_session_start() { + struct StartHooks; + #[async_trait] + impl SessionHooks for StartHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::SessionStart { .. } => { + HookOutput::SessionStart(SessionStartOutput { + additional_context: Some("extra context".to_string()), + ..Default::default() + }) + } + _ => HookOutput::None, + } + } + } + + let hooks = StartHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "source": "new" + }); + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionStart", input) + .await + .unwrap(); + assert_eq!(result["output"]["additionalContext"], "extra context"); + } + + #[tokio::test] + async fn dispatch_error_occurred() { + struct ErrorHooks; + #[async_trait] + impl SessionHooks for ErrorHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::ErrorOccurred { .. } => { + HookOutput::ErrorOccurred(ErrorOccurredOutput { + error_handling: Some("retry".to_string()), + retry_count: Some(3), + ..Default::default() + }) + } + _ => HookOutput::None, + } + } + } + + let hooks = ErrorHooks; + let input = serde_json::json!({ + "timestamp": 1234567890, + "cwd": "/tmp", + "error": "model timeout", + "errorContext": "model_call", + "recoverable": true + }); + let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "errorOccurred", input) + .await + .unwrap(); + assert_eq!(result["output"]["errorHandling"], "retry"); + assert_eq!(result["output"]["retryCount"], 3); + } +} diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs new file mode 100644 index 000000000..5f6d95612 --- /dev/null +++ b/rust/src/jsonrpc.rs @@ -0,0 +1,549 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::sync::{broadcast, mpsc, oneshot}; +use tracing::{Instrument, error, warn}; + +use crate::{Error, ProtocolError}; + +/// A JSON-RPC 2.0 request message. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct JsonRpcRequest { + /// Protocol version (always `"2.0"`). + pub jsonrpc: String, + /// Request ID for correlating responses. + pub id: u64, + /// RPC method name. + pub method: String, + /// Optional method parameters. + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +/// A JSON-RPC 2.0 response message. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct JsonRpcResponse { + /// Protocol version (always `"2.0"`). + pub jsonrpc: String, + /// Request ID this response correlates to. + pub id: u64, + /// Success payload (mutually exclusive with `error`). + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + /// Error payload (mutually exclusive with `result`). + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// A JSON-RPC 2.0 error object. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcError { + /// Numeric error code. + pub code: i32, + /// Human-readable error description. + pub message: String, + /// Optional structured error data. + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +/// Standard JSON-RPC 2.0 error codes. +pub mod error_codes { + /// Method not found (-32601). + pub const METHOD_NOT_FOUND: i32 = -32601; + /// Invalid method parameters (-32602). + pub const INVALID_PARAMS: i32 = -32602; + /// Internal server error (-32603). + #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")] + pub const INTERNAL_ERROR: i32 = -32603; +} + +/// A JSON-RPC 2.0 notification (no `id`, no response expected). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct JsonRpcNotification { + /// Protocol version (always `"2.0"`). + pub jsonrpc: String, + /// Notification method name. + pub method: String, + /// Optional notification parameters. + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +/// A parsed JSON-RPC 2.0 message โ€” request, response, or notification. +#[derive(Debug, Clone, Serialize)] +pub enum JsonRpcMessage { + /// An incoming or outgoing request. + Request(JsonRpcRequest), + /// A response to a previous request. + Response(JsonRpcResponse), + /// A fire-and-forget notification. + Notification(JsonRpcNotification), +} + +/// Custom deserializer that dispatches based on field presence instead of +/// `#[serde(untagged)]` which tries each variant sequentially (3ร— parse +/// attempts for Notification โ€” the hot-path streaming variant). +/// +/// Dispatch logic: +/// - has `id` + has `method` โ†’ Request +/// - has `id` + no `method` โ†’ Response +/// - no `id` โ†’ Notification +impl<'de> Deserialize<'de> for JsonRpcMessage { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + let obj = value + .as_object() + .ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?; + + let has_id = obj.contains_key("id"); + let has_method = obj.contains_key("method"); + + if has_id && has_method { + JsonRpcRequest::deserialize(value) + .map(JsonRpcMessage::Request) + .map_err(serde::de::Error::custom) + } else if has_id { + JsonRpcResponse::deserialize(value) + .map(JsonRpcMessage::Response) + .map_err(serde::de::Error::custom) + } else { + JsonRpcNotification::deserialize(value) + .map(JsonRpcMessage::Notification) + .map_err(serde::de::Error::custom) + } + } +} + +impl JsonRpcRequest { + /// Create a new JSON-RPC request with the given ID, method, and params. + pub fn new(id: u64, method: &str, params: Option) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id, + method: method.to_string(), + params, + } + } +} + +impl JsonRpcResponse { + /// Returns `true` if this response contains an error. + #[allow(dead_code)] + pub fn is_error(&self) -> bool { + self.error.is_some() + } +} + +const CONTENT_LENGTH_HEADER: &str = "Content-Length: "; + +/// One framed JSON-RPC message handed to the writer actor. +/// +/// `frame` is the fully serialized bytes (header + body); the caller pays +/// the serde cost synchronously before enqueueing so the actor never sees a +/// `Result` from JSON encoding. `ack` resolves once the bytes have been +/// fully written and flushed (or the underlying I/O reports an error). If +/// the caller drops the `oneshot::Receiver`, the actor still completes the +/// frame โ€” caller cancellation cannot desync the wire. +struct WriteCommand { + frame: Vec, + ack: oneshot::Sender>, +} + +/// Low-level JSON-RPC 2.0 client over Content-Length-framed streams. +/// +/// # Cancel safety +/// +/// All public methods (`write`, `send_request`) are **cancel-safe**: the +/// actual bytes hit the wire on a dedicated background actor task, so +/// dropping the caller's future after `await` returns `Pending` cannot +/// produce a partial frame on the wire. Frames either land atomically or +/// the underlying I/O fails. See `cancel-safety review` artifact for the +/// full RFD-400 reasoning. +pub struct JsonRpcClient { + request_id: AtomicU64, + /// Sender side of the writer actor's command queue. Public methods + /// pre-serialize their frames and enqueue here; the background actor + /// drains the queue and serializes writes onto the underlying + /// `AsyncWrite`. Unbounded by design โ€” RFD 400 explicitly permits this + /// for cancel-safety, and JSON-RPC frames are small relative to the + /// natural request/response back-pressure of the wire. + write_tx: mpsc::UnboundedSender, + pending_requests: Arc>>>, + notification_tx: broadcast::Sender, + request_tx: mpsc::UnboundedSender, +} + +impl JsonRpcClient { + /// Create a new client from async read/write streams. + /// + /// Spawns two background tasks: a reader that dispatches incoming + /// messages to pending request channels, the notification broadcast, + /// or the request-forwarding channel; and a writer actor that owns the + /// underlying `AsyncWrite` and serializes frames atomically. + pub fn new( + writer: impl AsyncWrite + Unpin + Send + 'static, + reader: impl AsyncRead + Unpin + Send + 'static, + notification_tx: broadcast::Sender, + request_tx: mpsc::UnboundedSender, + ) -> Self { + let (write_tx, write_rx) = mpsc::unbounded_channel::(); + + let writer_span = tracing::error_span!("jsonrpc_write_loop"); + tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span)); + + let client = Self { + request_id: AtomicU64::new(1), + write_tx, + pending_requests: Arc::new(RwLock::new(HashMap::new())), + notification_tx, + request_tx, + }; + + let pending_requests = client.pending_requests.clone(); + let notification_tx_clone = client.notification_tx.clone(); + let request_tx_clone = client.request_tx.clone(); + let reader_span = tracing::error_span!("jsonrpc_read_loop"); + + tokio::spawn( + async move { + Self::read_loop( + reader, + pending_requests, + notification_tx_clone, + request_tx_clone, + ) + .await; + } + .instrument(reader_span), + ); + + client + } + + /// Writer-actor task. Owns the `AsyncWrite`, drains the command queue, + /// and writes each frame atomically (header + body + flush) before + /// signaling the ack. + /// + /// Caller-side cancellation cannot interrupt a write in progress: + /// dropping the ack `oneshot::Receiver` does not cancel the in-flight + /// I/O. Once `WriteCommand` is enqueued the frame is committed to land + /// on the wire (or surface an `io::Error` to the ack receiver if the + /// transport is broken). + /// + /// Exits cleanly when all senders drop (channel closes), flushing any + /// final buffered bytes. + async fn write_loop( + mut writer: impl AsyncWrite + Unpin + Send + 'static, + mut rx: mpsc::UnboundedReceiver, + ) { + while let Some(WriteCommand { frame, ack }) = rx.recv().await { + let result = async { + writer.write_all(&frame).await?; + writer.flush().await?; + Ok::<_, std::io::Error>(()) + } + .await; + + // Caller may have dropped the ack receiver (e.g. their + // `await` was cancelled); that's fine โ€” we still completed + // the write, which was the whole point. + let _ = ack.send(result); + } + } + + async fn read_loop( + reader: impl AsyncRead + Unpin + Send, + pending_requests: Arc>>>, + notification_tx: broadcast::Sender, + request_tx: mpsc::UnboundedSender, + ) { + let mut reader = BufReader::new(reader); + + loop { + match Self::read_message(&mut reader).await { + Ok(Some(message)) => match message { + JsonRpcMessage::Response(response) => { + let id = response.id; + let tx = pending_requests.write().remove(&id); + if let Some(tx) = tx { + if tx.send(response).is_err() { + warn!(request_id = %id, "failed to send response for request"); + } + } else { + warn!(request_id = %id, "received response for unknown request id"); + } + } + JsonRpcMessage::Notification(notification) => { + let _ = notification_tx.send(notification); + } + JsonRpcMessage::Request(request) => { + if request_tx.send(request).is_err() { + warn!("failed to forward JSON-RPC request, channel closed"); + } + } + }, + Ok(None) => { + break; + } + Err(e) => { + error!(error = %e, "error reading from CLI"); + break; + } + } + } + + // Drain in-flight requests so callers observe cancellation + // instead of hanging on a oneshot receiver. + let mut pending = pending_requests.write(); + if !pending.is_empty() { + warn!( + count = pending.len(), + "draining pending requests after read loop exit" + ); + pending.clear(); + } + } + + async fn read_message( + reader: &mut BufReader, + ) -> Result, Error> { + let mut line = String::new(); + let mut content_length = None; + + loop { + line.clear(); + if reader.read_line(&mut line).await? == 0 { + return Ok(None); + } + + let trimmed = line.trim(); + if trimmed.is_empty() { + break; + } + + if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) { + content_length = Some(value.trim().parse::().map_err(|_| { + Error::Protocol(ProtocolError::InvalidContentLength( + value.trim().to_string(), + )) + })?); + } + } + + let Some(length) = content_length else { + return Err(Error::Protocol(ProtocolError::MissingContentLength)); + }; + + let mut body = vec![0u8; length]; + reader.read_exact(&mut body).await?; + + let message: JsonRpcMessage = serde_json::from_slice(&body)?; + Ok(Some(message)) + } + + /// Send a JSON-RPC request and wait for the matching response. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** The frame is committed to the wire via the writer + /// actor before this future yields; cancelling the await drops the + /// response oneshot but does not desync the transport. The pending- + /// requests map is cleaned up automatically (the `PendingGuard` drop + /// removes the entry, and the read loop's response handling tolerates + /// a missing entry). + pub async fn send_request( + &self, + method: &str, + params: Option, + ) -> Result { + let id = self.request_id.fetch_add(1, Ordering::SeqCst); + let request = JsonRpcRequest::new(id, method, params); + + let (tx, rx) = oneshot::channel(); + self.pending_requests.write().insert(id, tx); + + // RAII guard that removes the pending entry if this future is + // dropped before the response arrives. Disarmed below before the + // success return so the read loop owns the cleanup on the happy + // path. + let mut guard = PendingGuard { + map: &self.pending_requests, + id, + armed: true, + }; + + // The PendingGuard's drop removes the entry on every error path + // and on cancellation; disarmed below before the success return so + // the read loop owns the cleanup on the happy path. + self.write(&request).await?; + + let response = rx + .await + .map_err(|_| Error::Protocol(ProtocolError::RequestCancelled))?; + guard.disarm(); + Ok(response) + } + + /// Write a Content-Length-framed JSON-RPC message to the transport. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** Pre-serializes the body, enqueues it on the writer + /// actor's command channel, and awaits an ack. Caller cancellation + /// drops the ack receiver; the actor still completes the frame and + /// flushes. A partial frame can never appear on the wire. + pub async fn write(&self, message: &T) -> Result<(), Error> { + let body = serde_json::to_vec(message)?; + let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4); + frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes()); + frame.extend_from_slice(body.len().to_string().as_bytes()); + frame.extend_from_slice(b"\r\n\r\n"); + frame.extend_from_slice(&body); + + let (ack_tx, ack_rx) = oneshot::channel(); + self.write_tx + .send(WriteCommand { frame, ack: ack_tx }) + .map_err(|_| { + Error::Io(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "writer actor has shut down", + )) + })?; + + match ack_rx.await { + Ok(Ok(())) => Ok(()), + Ok(Err(e)) => Err(Error::Io(e)), + Err(_) => Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "writer actor dropped ack without responding", + ))), + } + } +} + +/// RAII guard that removes a pending-request entry from the map if the +/// owning future is dropped before the response arrives. Disarmed on the +/// happy path so the read loop's response handling owns the cleanup. +struct PendingGuard<'a> { + map: &'a RwLock>>, + id: u64, + armed: bool, +} + +impl PendingGuard<'_> { + fn disarm(&mut self) { + self.armed = false; + } +} + +impl Drop for PendingGuard<'_> { + fn drop(&mut self) { + if self.armed { + self.map.write().remove(&self.id); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserialize_notification() { + let json = r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#; + let msg: JsonRpcMessage = serde_json::from_str(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Notification(n) if n.method == "session.event")); + } + + #[test] + fn deserialize_request() { + let json = + r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#; + let msg: JsonRpcMessage = serde_json::from_str(json).unwrap(); + assert!( + matches!(msg, JsonRpcMessage::Request(r) if r.id == 5 && r.method == "permission.request") + ); + } + + #[test] + fn deserialize_response_with_result() { + let json = r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#; + let msg: JsonRpcMessage = serde_json::from_str(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(r) if r.id == 3 && !r.is_error())); + } + + #[test] + fn deserialize_error_response() { + let json = + r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#; + let msg: JsonRpcMessage = serde_json::from_str(json).unwrap(); + match msg { + JsonRpcMessage::Response(r) => { + assert!(r.is_error()); + let err = r.error.unwrap(); + assert_eq!(err.code, -32600); + assert_eq!(err.message, "Invalid Request"); + } + other => panic!("expected Response, got {other:?}"), + } + } + + #[test] + fn deserialize_rejects_non_object() { + let result = serde_json::from_str::(r#""not an object""#); + assert!(result.is_err()); + } + + #[test] + fn request_new_sets_version() { + let req = JsonRpcRequest::new(42, "test.method", None); + assert_eq!(req.jsonrpc, "2.0"); + assert_eq!(req.id, 42); + assert_eq!(req.method, "test.method"); + assert!(req.params.is_none()); + } + + #[test] + fn request_serializes_camel_case() { + let req = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({}))); + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains(r#""jsonrpc":"2.0""#)); + assert!(json.contains(r#""id":1"#)); + assert!(json.contains(r#""method":"ping""#)); + } + + #[test] + fn notification_without_params_omits_field() { + let n = JsonRpcNotification { + jsonrpc: "2.0".into(), + method: "ping".into(), + params: None, + }; + let json = serde_json::to_string(&n).unwrap(); + assert!(!json.contains("params")); + } + + #[test] + fn response_without_error_omits_field() { + let r = JsonRpcResponse { + jsonrpc: "2.0".into(), + id: 1, + result: Some(serde_json::json!(true)), + error: None, + }; + let json = serde_json::to_string(&r).unwrap(); + assert!(!json.contains("error")); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 000000000..0c07e0470 --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,2425 @@ +#![doc = include_str!("../README.md")] +#![warn(missing_docs)] +#![deny(rustdoc::broken_intra_doc_links)] +#![cfg_attr(test, allow(clippy::unwrap_used))] + +/// Bundled CLI binary extraction and caching. +pub mod embeddedcli; +/// Event handler traits for session lifecycle. +pub mod handler; +/// Lifecycle hook callbacks (pre/post tool use, prompt submission, session start/end). +pub mod hooks; +mod jsonrpc; +/// Permission-policy helpers that wrap an existing [`handler::SessionHandler`]. +pub mod permission; +/// GitHub Copilot CLI binary resolution (env var, embedded, PATH search). +pub mod resolve; +mod router; +/// Session management โ€” create, resume, send messages, and interact with the agent. +pub mod session; +/// Custom session filesystem provider (virtualizable filesystem layer). +pub mod session_fs; +mod session_fs_dispatch; +/// Event subscription handles returned by `subscribe()` methods. +pub mod subscription; +/// Typed tool definition framework and dispatch router. +pub mod tool; +/// W3C Trace Context propagation for distributed tracing. +pub mod trace_context; +/// System message transform callbacks for customizing agent prompts. +pub mod transforms; +/// Protocol types shared between the SDK and the GitHub Copilot CLI. +pub mod types; + +/// Auto-generated protocol types from Copilot JSON Schemas. +pub mod generated; + +use std::ffi::OsString; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::sync::{Arc, OnceLock}; + +use async_trait::async_trait; +// JSON-RPC wire types are internal transport details (like Go SDK's internal/jsonrpc2/). +// External callers interact via Client/Session methods, not raw RPC. +pub(crate) use jsonrpc::{ + JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes, +}; + +/// Re-exported JSON-RPC internals for integration tests (requires `test-support` feature). +#[cfg(feature = "test-support")] +pub mod test_support { + pub use crate::jsonrpc::{ + JsonRpcClient, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + error_codes, + }; +} +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader}; +use tokio::net::TcpStream; +use tokio::process::{Child, Command}; +use tokio::sync::{broadcast, mpsc, oneshot}; +use tracing::{Instrument, debug, error, info, warn}; +pub use types::*; + +mod sdk_protocol_version; +pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version}; +pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError}; + +/// Minimum protocol version this SDK can communicate with. +const MIN_PROTOCOL_VERSION: u32 = 2; + +/// Errors returned by the SDK. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum Error { + /// JSON-RPC transport or protocol violation. + #[error("protocol error: {0}")] + Protocol(ProtocolError), + + /// The CLI returned a JSON-RPC error response. + #[error("RPC error {code}: {message}")] + Rpc { + /// JSON-RPC error code. + code: i32, + /// Human-readable error message. + message: String, + }, + + /// Session-scoped error (not found, agent error, timeout, etc.). + #[error("session error: {0}")] + Session(SessionError), + + /// I/O error on the stdio transport or during process spawn. + #[error(transparent)] + Io(#[from] std::io::Error), + + /// Failed to serialize or deserialize a JSON-RPC message. + #[error(transparent)] + Json(#[from] serde_json::Error), + + /// A required binary was not found on the system. + #[error("binary not found: {name} ({hint})")] + BinaryNotFound { + /// Binary name that was searched for. + name: &'static str, + /// Guidance on how to install or configure the binary. + hint: &'static str, + }, + + /// Invalid combination of [`ClientOptions`] supplied to [`Client::start`]. + /// Surfaces consumer-side configuration errors that would otherwise + /// produce confusing runtime failures (e.g. a connection token paired + /// with stdio transport). + #[error("invalid client configuration: {0}")] + InvalidConfig(String), +} + +impl Error { + /// Returns true if this error indicates the transport is broken โ€” the CLI + /// process exited, the connection was lost, or an I/O failure occurred. + /// Callers should discard the client and create a fresh one. + pub fn is_transport_failure(&self) -> bool { + matches!( + self, + Error::Protocol(ProtocolError::RequestCancelled) | Error::Io(_) + ) + } +} + +/// Aggregate of errors collected during [`Client::stop`]. +/// +/// `Client::stop` performs cooperative shutdown across every active +/// session before killing the CLI child process. Errors from any +/// per-session `session.destroy` RPC and from the terminal child-kill +/// step are collected here rather than short-circuiting on the first +/// failure, so callers see the full picture of what went wrong during +/// teardown. +/// +/// Implements [`std::error::Error`] and forwards to `Display` for the +/// first error, with a count suffix when there are more. +#[derive(Debug)] +pub struct StopErrors(Vec); + +impl StopErrors { + /// Borrow the collected errors as a slice, in the order they + /// occurred (per-session destroys first, then child-kill last). + pub fn errors(&self) -> &[Error] { + &self.0 + } + + /// Consume the aggregate and return the underlying error vector. + pub fn into_errors(self) -> Vec { + self.0 + } +} + +impl std::fmt::Display for StopErrors { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0.as_slice() { + [] => write!(f, "stop completed with no errors"), + [only] => write!(f, "stop failed: {only}"), + [first, rest @ ..] => write!( + f, + "stop failed with {n} errors; first: {first}", + n = 1 + rest.len(), + ), + } + } +} + +impl std::error::Error for StopErrors { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0 + .first() + .map(|e| e as &(dyn std::error::Error + 'static)) + } +} + +/// Specific protocol-level errors in the JSON-RPC transport or CLI lifecycle. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum ProtocolError { + /// Missing `Content-Length` header in a JSON-RPC message. + #[error("missing Content-Length header")] + MissingContentLength, + + /// Invalid `Content-Length` header value. + #[error("invalid Content-Length value: \"{0}\"")] + InvalidContentLength(String), + + /// A pending JSON-RPC request was cancelled (e.g. the response channel was dropped). + #[error("request cancelled")] + RequestCancelled, + + /// The CLI process did not report a listening port within the timeout. + #[error("timed out waiting for CLI to report listening port")] + CliStartupTimeout, + + /// The CLI process exited before reporting a listening port. + #[error("CLI exited before reporting listening port")] + CliStartupFailed, + + /// The CLI server's protocol version is outside the SDK's supported range. + #[error("version mismatch: server={server}, supported={min}โ€“{max}")] + VersionMismatch { + /// Version reported by the server. + server: u32, + /// Minimum version supported by this SDK. + min: u32, + /// Maximum version supported by this SDK. + max: u32, + }, + + /// The CLI server's protocol version changed between calls. + #[error("version changed: was {previous}, now {current}")] + VersionChanged { + /// Previously negotiated version. + previous: u32, + /// Newly reported version. + current: u32, + }, +} + +/// Session-scoped errors. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum SessionError { + /// The CLI could not find the requested session. + #[error("session not found: {0}")] + NotFound(SessionId), + + /// The CLI reported an error during agent execution (via `session.error` event). + #[error("{0}")] + AgentError(String), + + /// A `send_and_wait` call exceeded its timeout. + #[error("timed out after {0:?}")] + Timeout(std::time::Duration), + + /// `send` was called while a `send_and_wait` is in flight. + #[error("cannot send while send_and_wait is in flight")] + SendWhileWaiting, + + /// The session event loop exited before a pending `send_and_wait` completed. + #[error("event loop closed before session reached idle")] + EventLoopClosed, + + /// Elicitation is not supported by the host. + /// Check `session.capabilities().ui.elicitation` before calling UI methods. + #[error( + "elicitation not supported by host โ€” check session.capabilities().ui.elicitation first" + )] + ElicitationNotSupported, + + /// The client was started with [`ClientOptions::session_fs`] but this + /// session was created without a [`SessionFsProvider`]. Set one via + /// [`SessionConfig::with_session_fs_provider`] (or + /// [`ResumeSessionConfig::with_session_fs_provider`]). + #[error( + "session was created on a client with session_fs configured but no SessionFsProvider was supplied" + )] + SessionFsProviderRequired, + + /// [`ClientOptions::session_fs`] was provided with empty or invalid + /// fields. All of `initial_cwd` and `session_state_path` must be + /// non-empty. + #[error("invalid SessionFsConfig: {0}")] + InvalidSessionFsConfig(String), +} + +/// How the SDK communicates with the CLI server. +#[derive(Debug, Default)] +#[non_exhaustive] +pub enum Transport { + /// Communicate over stdin/stdout pipes (default). + #[default] + Stdio, + /// Spawn the CLI with `--port` and connect via TCP. + Tcp { + /// Port to listen on (0 for OS-assigned). + port: u16, + }, + /// Connect to an already-running CLI server (no process spawning). + External { + /// Hostname or IP of the running server. + host: String, + /// Port of the running server. + port: u16, + }, +} + +/// How the SDK locates the GitHub Copilot CLI binary. +#[derive(Debug, Clone, Default)] +pub enum CliProgram { + /// Auto-resolve: `COPILOT_CLI_PATH` โ†’ embedded CLI โ†’ PATH + common locations. + /// This is the default. + #[default] + Resolve, + /// Use an explicit binary path (skips resolution). + Path(PathBuf), +} + +impl From for CliProgram { + fn from(path: PathBuf) -> Self { + Self::Path(path) + } +} + +/// Options for starting a [`Client`]. +/// +/// When `program` is [`CliProgram::Resolve`] (the default), +/// [`Client::start`] automatically resolves the binary via +/// [`resolve::copilot_binary()`] โ€” checking `COPILOT_CLI_PATH`, the +/// embedded CLI, and then the system PATH and common install locations. +/// +/// Set `program` to [`CliProgram::Path`] to use an explicit binary. +#[non_exhaustive] +pub struct ClientOptions { + /// How to locate the CLI binary. + pub program: CliProgram, + /// Arguments prepended before `--server` (e.g. the script path for node). + pub prefix_args: Vec, + /// Working directory for the CLI process. + pub cwd: PathBuf, + /// Environment variables set on the child process. + pub env: Vec<(OsString, OsString)>, + /// Environment variable names to remove from the child process. + pub env_remove: Vec, + /// Extra CLI flags appended after the transport-specific arguments. + pub extra_args: Vec, + /// Transport mode used to communicate with the CLI server. + pub transport: Transport, + /// GitHub token for authentication. When set, the SDK passes the token + /// to the CLI via `--auth-token-env COPILOT_SDK_AUTH_TOKEN` and exports + /// the token in that env var. When set, the CLI defaults to *not* + /// using the logged-in user (override with [`Self::use_logged_in_user`]). + pub github_token: Option, + /// Whether the CLI should fall back to the logged-in `gh` user when no + /// token is provided. `None` means use the runtime default (true unless + /// [`Self::github_token`] is set, in which case false). + pub use_logged_in_user: Option, + /// Log level passed to the CLI server via `--log-level`. When `None`, + /// the SDK uses [`LogLevel::Info`]. + pub log_level: Option, + /// Server-wide idle timeout for sessions, in seconds. When set to a + /// positive value, the SDK passes `--session-idle-timeout ` to + /// the CLI; sessions without activity for this duration are + /// automatically cleaned up. `None` or `Some(0)` leaves sessions + /// running indefinitely (the CLI default). + pub session_idle_timeout_seconds: Option, + /// Optional override for [`Client::list_models`]. + /// + /// When set, [`Client::list_models`] returns the handler's result + /// without making a `models.list` RPC. This is the BYOK escape hatch + /// for environments where the model catalog is provisioned separately + /// from the GitHub Copilot CLI (e.g. external inference servers selected via + /// [`Transport::External`]). + pub on_list_models: Option>, + /// Custom session filesystem provider configuration. + /// + /// When set, the SDK calls `sessionFs.setProvider` during + /// [`Client::start`] to register a virtualizable filesystem layer with + /// the CLI. Each session created on this client must supply its own + /// [`SessionFsProvider`] via + /// [`SessionConfig::with_session_fs_provider`](crate::SessionConfig::with_session_fs_provider). + pub session_fs: Option, + /// Optional [`TraceContextProvider`] used to inject W3C Trace Context + /// headers (`traceparent` / `tracestate`) on outbound `session.create`, + /// `session.resume`, and `session.send` requests. + /// + /// When [`MessageOptions`] carries a per-turn override (set via + /// [`MessageOptions::with_trace_context`](crate::types::MessageOptions::with_trace_context) + /// or the underlying fields), it takes precedence over this provider. + /// + /// [`MessageOptions`]: crate::types::MessageOptions + pub on_get_trace_context: Option>, + /// OpenTelemetry config forwarded to the spawned CLI process. See + /// [`TelemetryConfig`] for the env-var mapping. The SDK takes no + /// OpenTelemetry dependency โ€” this is pure spawn-time env injection. + pub telemetry: Option, + /// Override the directory where the CLI persists its state (sessions, + /// auth, telemetry buffers). When set, exported as `COPILOT_HOME` to + /// the spawned CLI process. Useful for sandboxing test runs or + /// running multiple isolated SDK instances side-by-side. + pub copilot_home: Option, + /// Optional connection token for TCP transport. Sent to the CLI in + /// the `connect` handshake and exported as `COPILOT_CONNECTION_TOKEN` + /// to spawned CLI processes. Required when the CLI server was started + /// with a token, ignored otherwise. + /// + /// When the SDK spawns its own CLI in TCP mode and this is left + /// `None`, a UUID is generated automatically so the loopback listener + /// is safe by default. Combining with [`Transport::Stdio`] is invalid + /// and surfaces as an error from [`Client::start`]. + pub tcp_connection_token: Option, +} + +impl std::fmt::Debug for ClientOptions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClientOptions") + .field("program", &self.program) + .field("prefix_args", &self.prefix_args) + .field("cwd", &self.cwd) + .field("env", &self.env) + .field("env_remove", &self.env_remove) + .field("extra_args", &self.extra_args) + .field("transport", &self.transport) + .field( + "github_token", + &self.github_token.as_ref().map(|_| ""), + ) + .field("use_logged_in_user", &self.use_logged_in_user) + .field("log_level", &self.log_level) + .field( + "session_idle_timeout_seconds", + &self.session_idle_timeout_seconds, + ) + .field( + "on_list_models", + &self.on_list_models.as_ref().map(|_| ""), + ) + .field("session_fs", &self.session_fs) + .field( + "on_get_trace_context", + &self.on_get_trace_context.as_ref().map(|_| ""), + ) + .field("telemetry", &self.telemetry) + .field("copilot_home", &self.copilot_home) + .field( + "tcp_connection_token", + &self.tcp_connection_token.as_ref().map(|_| ""), + ) + .finish() + } +} + +/// Custom handler for [`Client::list_models`]. +/// +/// Implementations override the default `models.list` RPC, returning a +/// caller-supplied catalog of models. Set via [`ClientOptions::on_list_models`]. +/// +/// Implementations must be `Send + Sync` because [`Client`] is shared across +/// tasks. Errors returned by [`list_models`](Self::list_models) are propagated +/// from [`Client::list_models`] unchanged. +#[async_trait] +pub trait ListModelsHandler: Send + Sync + 'static { + /// Return the list of available models. + async fn list_models(&self) -> Result, Error>; +} + +/// Log verbosity for the CLI server (passed via `--log-level`). +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// Suppress all CLI logs. + None, + /// Errors only. + Error, + /// Warnings and errors. + Warning, + /// Default. Info and above. + Info, + /// Debug, info, warnings, errors. + Debug, + /// Everything, including trace output. + All, +} + +impl LogLevel { + /// CLI argument value (e.g. `"info"`, `"debug"`). + pub fn as_str(self) -> &'static str { + match self { + Self::None => "none", + Self::Error => "error", + Self::Warning => "warning", + Self::Info => "info", + Self::Debug => "debug", + Self::All => "all", + } + } +} + +impl std::fmt::Display for LogLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Backend exporter for the CLI's OpenTelemetry pipeline. +/// +/// Maps to the `COPILOT_OTEL_EXPORTER_TYPE` environment variable on the +/// spawned CLI process. Wire values are `"otlp-http"` and `"file"`. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +#[non_exhaustive] +pub enum OtelExporterType { + /// Export via OTLP HTTP to the endpoint configured by + /// [`TelemetryConfig::otlp_endpoint`]. + OtlpHttp, + /// Export to a JSON-lines file at the path configured by + /// [`TelemetryConfig::file_path`]. + File, +} + +impl OtelExporterType { + /// Environment-variable value (`"otlp-http"` or `"file"`). + pub fn as_str(self) -> &'static str { + match self { + Self::OtlpHttp => "otlp-http", + Self::File => "file", + } + } +} + +/// OpenTelemetry configuration forwarded to the spawned GitHub Copilot CLI +/// process. +/// +/// When [`ClientOptions::telemetry`] is `Some(...)`, the SDK sets +/// `COPILOT_OTEL_ENABLED=true` plus any populated fields below as the +/// corresponding `OTEL_*` / `COPILOT_OTEL_*` environment variables. The +/// CLI's built-in OpenTelemetry exporter consumes these at startup. The +/// SDK itself takes no OpenTelemetry dependency. +/// +/// Environment-variable mapping: +/// +/// | Field | Variable | +/// |----------------------|-------------------------------------------------------| +/// | (any field set) | `COPILOT_OTEL_ENABLED=true` | +/// | [`otlp_endpoint`] | `OTEL_EXPORTER_OTLP_ENDPOINT` | +/// | [`file_path`] | `COPILOT_OTEL_FILE_EXPORTER_PATH` | +/// | [`exporter_type`] | `COPILOT_OTEL_EXPORTER_TYPE` | +/// | [`source_name`] | `COPILOT_OTEL_SOURCE_NAME` | +/// | [`capture_content`] | `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT` | +/// +/// Caller-supplied entries in [`ClientOptions::env`] override these, so a +/// developer can pin any individual variable to a different value while +/// keeping the rest of the config managed by [`TelemetryConfig`]. +/// +/// Marked `#[non_exhaustive]` so future CLI-side telemetry knobs can be +/// added without breaking callers. +/// +/// [`otlp_endpoint`]: Self::otlp_endpoint +/// [`file_path`]: Self::file_path +/// [`exporter_type`]: Self::exporter_type +/// [`source_name`]: Self::source_name +/// [`capture_content`]: Self::capture_content +#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct TelemetryConfig { + /// OTLP HTTP endpoint URL for trace/metric export. + pub otlp_endpoint: Option, + /// File path for JSON-lines trace output. + pub file_path: Option, + /// Exporter backend type. Typically [`OtelExporterType::OtlpHttp`] or + /// [`OtelExporterType::File`]. + pub exporter_type: Option, + /// Instrumentation scope name. Useful for distinguishing this + /// embedder's traces from other Copilot-CLI consumers exporting to the + /// same backend. + pub source_name: Option, + /// Whether the CLI captures GenAI message content (prompts and + /// responses) on emitted spans. `Some(true)` opts in; `Some(false)` + /// opts out; `None` leaves the CLI default (typically off). + pub capture_content: Option, +} + +impl TelemetryConfig { + /// Construct an empty [`TelemetryConfig`]; all fields default to + /// unset (`is_empty()` returns `true`). + pub fn new() -> Self { + Self::default() + } + + /// Set the OTLP HTTP endpoint URL for trace/metric export. + pub fn with_otlp_endpoint(mut self, endpoint: impl Into) -> Self { + self.otlp_endpoint = Some(endpoint.into()); + self + } + + /// Set the file path for JSON-lines trace output. + pub fn with_file_path(mut self, path: impl Into) -> Self { + self.file_path = Some(path.into()); + self + } + + /// Set the exporter backend type. + pub fn with_exporter_type(mut self, exporter_type: OtelExporterType) -> Self { + self.exporter_type = Some(exporter_type); + self + } + + /// Set the instrumentation scope name. Useful for distinguishing + /// this embedder's traces from other Copilot-CLI consumers + /// exporting to the same backend. + pub fn with_source_name(mut self, source_name: impl Into) -> Self { + self.source_name = Some(source_name.into()); + self + } + + /// Opt in or out of GenAI message content capture on emitted spans. + /// `true` opts in; `false` opts out. Leaving this unset preserves + /// the CLI default (typically off). + pub fn with_capture_content(mut self, capture: bool) -> Self { + self.capture_content = Some(capture); + self + } + + /// Returns `true` if all fields are unset. Used by [`Client::start`] + /// to decide whether to set `COPILOT_OTEL_ENABLED`. + pub fn is_empty(&self) -> bool { + self.otlp_endpoint.is_none() + && self.file_path.is_none() + && self.exporter_type.is_none() + && self.source_name.is_none() + && self.capture_content.is_none() + } +} + +impl Default for ClientOptions { + fn default() -> Self { + Self { + program: CliProgram::Resolve, + prefix_args: Vec::new(), + cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + env: Vec::new(), + env_remove: Vec::new(), + extra_args: Vec::new(), + transport: Transport::default(), + github_token: None, + use_logged_in_user: None, + log_level: None, + session_idle_timeout_seconds: None, + on_list_models: None, + session_fs: None, + on_get_trace_context: None, + telemetry: None, + copilot_home: None, + tcp_connection_token: None, + } + } +} + +impl ClientOptions { + /// Construct a new [`ClientOptions`] with default values. + /// + /// Equivalent to [`ClientOptions::default`]; provided as a documented + /// construction entry point for the builder chain. The struct is + /// `#[non_exhaustive]`, so external callers cannot use struct-literal + /// syntax โ€” use this builder or [`Default::default`] plus mut-let. + /// + /// # Example + /// + /// ``` + /// # use github_copilot_sdk::{ClientOptions, LogLevel}; + /// let opts = ClientOptions::new() + /// .with_log_level(LogLevel::Debug) + /// .with_github_token("ghp_โ€ฆ"); + /// ``` + pub fn new() -> Self { + Self::default() + } + + /// How to locate the CLI binary. See [`CliProgram`]. + pub fn with_program(mut self, program: impl Into) -> Self { + self.program = program.into(); + self + } + + /// Arguments prepended before `--server` (e.g. the script path for node). + pub fn with_prefix_args(mut self, args: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.prefix_args = args.into_iter().map(Into::into).collect(); + self + } + + /// Working directory for the CLI process. + pub fn with_cwd(mut self, cwd: impl Into) -> Self { + self.cwd = cwd.into(); + self + } + + /// Environment variables to set on the child process. + pub fn with_env(mut self, env: I) -> Self + where + I: IntoIterator, + K: Into, + V: Into, + { + self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect(); + self + } + + /// Environment variable names to remove from the child process. + pub fn with_env_remove(mut self, names: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.env_remove = names.into_iter().map(Into::into).collect(); + self + } + + /// Extra CLI flags appended after the transport-specific arguments. + pub fn with_extra_args(mut self, args: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.extra_args = args.into_iter().map(Into::into).collect(); + self + } + + /// Transport mode used to communicate with the CLI server. See [`Transport`]. + pub fn with_transport(mut self, transport: Transport) -> Self { + self.transport = transport; + self + } + + /// GitHub token for authentication. The SDK passes the token to the + /// CLI via `--auth-token-env COPILOT_SDK_AUTH_TOKEN`. + pub fn with_github_token(mut self, token: impl Into) -> Self { + self.github_token = Some(token.into()); + self + } + + /// Whether the CLI should fall back to the logged-in `gh` user when + /// no token is provided. See the field docs for default semantics. + pub fn with_use_logged_in_user(mut self, use_logged_in: bool) -> Self { + self.use_logged_in_user = Some(use_logged_in); + self + } + + /// Log level passed to the CLI server via `--log-level`. + pub fn with_log_level(mut self, level: LogLevel) -> Self { + self.log_level = Some(level); + self + } + + /// Server-wide idle timeout for sessions (seconds). Pass `0` to leave + /// sessions running indefinitely (the CLI default). + pub fn with_session_idle_timeout_seconds(mut self, seconds: u64) -> Self { + self.session_idle_timeout_seconds = Some(seconds); + self + } + + /// Override [`Client::list_models`] with a caller-supplied handler. + /// The handler is wrapped in `Arc` internally. + pub fn with_list_models_handler(mut self, handler: H) -> Self + where + H: ListModelsHandler + 'static, + { + self.on_list_models = Some(Arc::new(handler)); + self + } + + /// Custom session filesystem provider configuration. + pub fn with_session_fs(mut self, config: SessionFsConfig) -> Self { + self.session_fs = Some(config); + self + } + + /// Set the [`TraceContextProvider`] used to inject W3C Trace Context + /// headers on outbound `session.create` / `session.resume` / + /// `session.send` requests. The provider is wrapped in `Arc` internally. + pub fn with_trace_context_provider

(mut self, provider: P) -> Self + where + P: TraceContextProvider + 'static, + { + self.on_get_trace_context = Some(Arc::new(provider)); + self + } + + /// OpenTelemetry config forwarded to the spawned CLI process. + pub fn with_telemetry(mut self, config: TelemetryConfig) -> Self { + self.telemetry = Some(config); + self + } + + /// Override the directory where the CLI persists its state. Set as + /// `COPILOT_HOME` on the spawned CLI process. + pub fn with_copilot_home(mut self, home: impl Into) -> Self { + self.copilot_home = Some(home.into()); + self + } + + /// Set the connection token for TCP transport. Sent in the `connect` + /// handshake and exported as `COPILOT_CONNECTION_TOKEN` to spawned + /// CLI processes. + pub fn with_tcp_connection_token(mut self, token: impl Into) -> Self { + self.tcp_connection_token = Some(token.into()); + self + } +} + +/// Validate a [`SessionFsConfig`] before sending `sessionFs.setProvider`. +fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<(), Error> { + if cfg.initial_cwd.trim().is_empty() { + return Err(Error::Session(SessionError::InvalidSessionFsConfig( + "initial_cwd must not be empty".to_string(), + ))); + } + if cfg.session_state_path.trim().is_empty() { + return Err(Error::Session(SessionError::InvalidSessionFsConfig( + "session_state_path must not be empty".to_string(), + ))); + } + Ok(()) +} + +/// Generate a fresh CSPRNG-backed token for authenticating an SDK-spawned +/// loopback CLI server. 128 bits of entropy, lowercase-hex encoded โ€” not +/// a UUID (the schema-shaped IDs in this crate stay `String` per the +/// pre-1.0 review consensus, so adopting a `Uuid` type just for SDK- +/// generated secrets would be inconsistent and semantically misleading; +/// this is opaque random data, not an identifier). +fn generate_connection_token() -> String { + let mut bytes = [0u8; 16]; + getrandom::getrandom(&mut bytes) + .expect("OS CSPRNG (getrandom) is unavailable; cannot generate connection token"); + let mut hex = String::with_capacity(32); + for byte in bytes { + use std::fmt::Write; + let _ = write!(hex, "{byte:02x}"); + } + hex +} + +/// Connection to a GitHub Copilot CLI server (stdio, TCP, or external). +/// +/// Cheaply cloneable โ€” cloning shares the underlying connection. +/// The child process (if any) is killed when the last clone drops. +#[derive(Clone)] +pub struct Client { + inner: Arc, +} + +impl std::fmt::Debug for Client { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Client") + .field("cwd", &self.inner.cwd) + .field("pid", &self.pid()) + .finish() + } +} + +struct ClientInner { + child: parking_lot::Mutex>, + rpc: JsonRpcClient, + cwd: PathBuf, + request_rx: parking_lot::Mutex>>, + notification_tx: broadcast::Sender, + router: router::SessionRouter, + negotiated_protocol_version: OnceLock, + server_telemetry_method: parking_lot::Mutex>, + state: parking_lot::Mutex, + lifecycle_tx: broadcast::Sender, + on_list_models: Option>, + session_fs_configured: bool, + on_get_trace_context: Option>, + /// Token sent in the `connect` handshake. Auto-generated when the + /// SDK spawns its own CLI in TCP mode and no explicit token is set; + /// `None` for stdio and for external-server transport without an + /// explicit token. + effective_connection_token: Option, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ServerTelemetryRpcMethod { + SendTelemetry, + NamespacedSendTelemetry, +} + +impl ServerTelemetryRpcMethod { + fn as_str(self) -> &'static str { + match self { + Self::SendTelemetry => "sendTelemetry", + Self::NamespacedSendTelemetry => "server.sendTelemetry", + } + } +} + +impl Client { + /// Start a CLI server process with the given options. + /// + /// For [`Transport::Stdio`], spawns the CLI with `--stdio` and communicates + /// over stdin/stdout pipes. For [`Transport::Tcp`], spawns with `--port` + /// and connects via TCP once the server reports it is listening. For + /// [`Transport::External`], connects to an already-running server. + /// + /// After establishing the connection, calls [`verify_protocol_version`](Self::verify_protocol_version) + /// to ensure the CLI server speaks a compatible protocol version. + /// When [`ClientOptions::session_fs`] is set, also calls + /// `sessionFs.setProvider` to register the SDK as the filesystem + /// backend. + pub async fn start(options: ClientOptions) -> Result { + if let Some(cfg) = &options.session_fs { + validate_session_fs_config(cfg)?; + } + // Validate token + transport combination. Stdio cannot use a + // connection token; auto-generate a UUID when the SDK spawns + // its own CLI in TCP mode and no explicit token was set. + if let Some(token) = &options.tcp_connection_token { + if token.is_empty() { + return Err(Error::InvalidConfig( + "tcp_connection_token must be a non-empty string".to_string(), + )); + } + if matches!(options.transport, Transport::Stdio) { + return Err(Error::InvalidConfig( + "tcp_connection_token cannot be used with Transport::Stdio".to_string(), + )); + } + } + let effective_connection_token: Option = match &options.transport { + Transport::Stdio => None, + Transport::Tcp { .. } => Some( + options + .tcp_connection_token + .clone() + .unwrap_or_else(generate_connection_token), + ), + Transport::External { .. } => options.tcp_connection_token.clone(), + }; + let mut options = options; + if matches!(options.transport, Transport::Tcp { .. }) + && options.tcp_connection_token.is_none() + { + // Auto-generated tokens flow to the spawned CLI via env, so + // make the field reflect what we'll actually send. + options.tcp_connection_token = effective_connection_token.clone(); + } + let session_fs_config = options.session_fs.clone(); + let program = match &options.program { + CliProgram::Path(path) => { + info!(path = %path.display(), "using explicit copilot CLI path"); + path.clone() + } + CliProgram::Resolve => { + let resolved = resolve::copilot_binary()?; + info!(path = %resolved.display(), "resolved copilot CLI"); + #[cfg(windows)] + { + if let Some(ext) = resolved.extension().and_then(|e| e.to_str()) { + if ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat") { + warn!( + path = %resolved.display(), + ext = %ext, + "resolved copilot CLI is a .cmd/.bat wrapper; \ + this may cause console window flashes on Windows" + ); + } + } + } + resolved + } + }; + + let client = match options.transport { + Transport::External { ref host, port } => { + info!(host = %host, port = %port, "connecting to external CLI server"); + let stream = TcpStream::connect((host.as_str(), port)).await?; + let (reader, writer) = tokio::io::split(stream); + Self::from_transport( + reader, + writer, + None, + options.cwd, + options.on_list_models, + session_fs_config.is_some(), + options.on_get_trace_context, + effective_connection_token.clone(), + )? + } + Transport::Tcp { port } => { + let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?; + let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?; + let (reader, writer) = tokio::io::split(stream); + Self::drain_stderr(&mut child); + Self::from_transport( + reader, + writer, + Some(child), + options.cwd, + options.on_list_models, + session_fs_config.is_some(), + options.on_get_trace_context, + effective_connection_token.clone(), + )? + } + Transport::Stdio => { + let mut child = Self::spawn_stdio(&program, &options)?; + let stdin = child.stdin.take().expect("stdin is piped"); + let stdout = child.stdout.take().expect("stdout is piped"); + Self::drain_stderr(&mut child); + Self::from_transport( + stdout, + stdin, + Some(child), + options.cwd, + options.on_list_models, + session_fs_config.is_some(), + options.on_get_trace_context, + effective_connection_token.clone(), + )? + } + }; + + client.verify_protocol_version().await?; + if let Some(cfg) = session_fs_config { + let request = crate::generated::api_types::SessionFsSetProviderRequest { + conventions: cfg.conventions.into_wire(), + initial_cwd: cfg.initial_cwd, + session_state_path: cfg.session_state_path, + }; + client.rpc().session_fs().set_provider(request).await?; + } + Ok(client) + } + + /// Create a Client from raw async streams (no child process). + /// + /// Useful for testing or connecting to a server over a custom transport. + pub fn from_streams( + reader: impl AsyncRead + Unpin + Send + 'static, + writer: impl AsyncWrite + Unpin + Send + 'static, + cwd: PathBuf, + ) -> Result { + Self::from_transport(reader, writer, None, cwd, None, false, None, None) + } + + /// Construct a [`Client`] from raw streams with a + /// [`TraceContextProvider`] preset, for integration testing. + /// + /// Mirrors [`from_streams`](Self::from_streams) but exposes the + /// `on_get_trace_context` plumbing so tests can verify outbound + /// `traceparent` / `tracestate` injection on `session.create`, + /// `session.resume`, and `session.send`. + #[cfg(any(test, feature = "test-support"))] + pub fn from_streams_with_trace_provider( + reader: impl AsyncRead + Unpin + Send + 'static, + writer: impl AsyncWrite + Unpin + Send + 'static, + cwd: PathBuf, + provider: Arc, + ) -> Result { + Self::from_transport(reader, writer, None, cwd, None, false, Some(provider), None) + } + + /// Construct a [`Client`] from raw streams with a preset + /// `effective_connection_token`, for integration testing the + /// `connect` handshake's token-forwarding path. + #[cfg(any(test, feature = "test-support"))] + pub fn from_streams_with_connection_token( + reader: impl AsyncRead + Unpin + Send + 'static, + writer: impl AsyncWrite + Unpin + Send + 'static, + cwd: PathBuf, + token: Option, + ) -> Result { + Self::from_transport(reader, writer, None, cwd, None, false, None, token) + } + + /// Public test-only wrapper around the random connection-token + /// generator used by [`Client::start`] when the SDK spawns a TCP + /// server without an explicit token. Lets integration tests + /// validate the token shape (32-char lowercase hex, 128 bits of + /// entropy) without re-implementing the helper. + #[cfg(any(test, feature = "test-support"))] + pub fn generate_connection_token_for_test() -> String { + generate_connection_token() + } + + #[allow(clippy::too_many_arguments)] + fn from_transport( + reader: impl AsyncRead + Unpin + Send + 'static, + writer: impl AsyncWrite + Unpin + Send + 'static, + child: Option, + cwd: PathBuf, + on_list_models: Option>, + session_fs_configured: bool, + on_get_trace_context: Option>, + effective_connection_token: Option, + ) -> Result { + let (request_tx, request_rx) = mpsc::unbounded_channel::(); + let (notification_broadcast_tx, _) = broadcast::channel::(1024); + let rpc = JsonRpcClient::new( + writer, + reader, + notification_broadcast_tx.clone(), + request_tx, + ); + + let pid = child.as_ref().and_then(|c| c.id()); + info!(pid = ?pid, "copilot CLI client ready"); + + let client = Self { + inner: Arc::new(ClientInner { + child: parking_lot::Mutex::new(child), + rpc, + cwd, + request_rx: parking_lot::Mutex::new(Some(request_rx)), + notification_tx: notification_broadcast_tx, + router: router::SessionRouter::new(), + negotiated_protocol_version: OnceLock::new(), + server_telemetry_method: parking_lot::Mutex::new(None), + state: parking_lot::Mutex::new(ConnectionState::Connected), + lifecycle_tx: broadcast::channel(256).0, + on_list_models, + session_fs_configured, + on_get_trace_context, + effective_connection_token, + }), + }; + client.spawn_lifecycle_dispatcher(); + Ok(client) + } + + /// Spawn the background task that re-broadcasts `session.lifecycle` + /// notifications via [`ClientInner::lifecycle_tx`] to subscribers + /// returned by [`Self::subscribe_lifecycle`]. + fn spawn_lifecycle_dispatcher(&self) { + let inner = Arc::clone(&self.inner); + let mut notif_rx = inner.notification_tx.subscribe(); + tokio::spawn(async move { + loop { + match notif_rx.recv().await { + Ok(notification) => { + if notification.method != "session.lifecycle" { + continue; + } + let Some(params) = notification.params.as_ref() else { + continue; + }; + let event: SessionLifecycleEvent = + match serde_json::from_value(params.clone()) { + Ok(e) => e, + Err(e) => { + warn!( + error = %e, + "failed to deserialize session.lifecycle notification" + ); + continue; + } + }; + // `send` only errors when there are no subscribers โ€” that's + // the normal case before any consumer calls subscribe_lifecycle. + let _ = inner.lifecycle_tx.send(event); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + warn!(missed = n, "lifecycle dispatcher lagged"); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + }); + } + + fn build_command(program: &Path, options: &ClientOptions) -> Command { + let mut command = Command::new(program); + for arg in &options.prefix_args { + command.arg(arg); + } + // Inject the SDK auth token first so explicit `env` / `env_remove` + // entries can override or strip it. + if let Some(token) = &options.github_token { + command.env("COPILOT_SDK_AUTH_TOKEN", token); + } + // Inject telemetry env vars before user env so callers can still + // override individual variables via `options.env`. + if let Some(telemetry) = &options.telemetry { + command.env("COPILOT_OTEL_ENABLED", "true"); + if let Some(endpoint) = &telemetry.otlp_endpoint { + command.env("OTEL_EXPORTER_OTLP_ENDPOINT", endpoint); + } + if let Some(path) = &telemetry.file_path { + command.env("COPILOT_OTEL_FILE_EXPORTER_PATH", path); + } + if let Some(exporter) = telemetry.exporter_type { + command.env("COPILOT_OTEL_EXPORTER_TYPE", exporter.as_str()); + } + if let Some(source) = &telemetry.source_name { + command.env("COPILOT_OTEL_SOURCE_NAME", source); + } + if let Some(capture) = telemetry.capture_content { + command.env( + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", + if capture { "true" } else { "false" }, + ); + } + } + if let Some(home) = &options.copilot_home { + command.env("COPILOT_HOME", home); + } + if let Some(token) = &options.tcp_connection_token { + command.env("COPILOT_CONNECTION_TOKEN", token); + } + for (key, value) in &options.env { + command.env(key, value); + } + for key in &options.env_remove { + command.env_remove(key); + } + command + .current_dir(&options.cwd) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + #[cfg(windows)] + { + use std::os::windows::process::CommandExt; + const CREATE_NO_WINDOW: u32 = 0x08000000; + command.as_std_mut().creation_flags(CREATE_NO_WINDOW); + } + + command + } + + /// Returns the CLI auth flags derived from [`ClientOptions::github_token`] + /// and [`ClientOptions::use_logged_in_user`]. + /// + /// When a token is set, adds `--auth-token-env COPILOT_SDK_AUTH_TOKEN`. + /// When the effective `use_logged_in_user` is `false` (either explicitly + /// or because a token was provided without an override), adds + /// `--no-auto-login`. + fn auth_args(options: &ClientOptions) -> Vec<&'static str> { + let mut args: Vec<&'static str> = Vec::new(); + if options.github_token.is_some() { + args.push("--auth-token-env"); + args.push("COPILOT_SDK_AUTH_TOKEN"); + } + let use_logged_in = options + .use_logged_in_user + .unwrap_or(options.github_token.is_none()); + if !use_logged_in { + args.push("--no-auto-login"); + } + args + } + + /// Returns `--session-idle-timeout ` when + /// [`ClientOptions::session_idle_timeout_seconds`] is `Some(n)` with + /// `n > 0`. Otherwise returns an empty vector. + fn session_idle_timeout_args(options: &ClientOptions) -> Vec { + match options.session_idle_timeout_seconds { + Some(secs) if secs > 0 => { + vec!["--session-idle-timeout".to_string(), secs.to_string()] + } + _ => Vec::new(), + } + } + + fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result { + info!(cwd = ?options.cwd, program = %program.display(), "spawning copilot CLI (stdio)"); + let mut command = Self::build_command(program, options); + let log_level = options.log_level.unwrap_or(LogLevel::Info); + command + .args([ + "--server", + "--stdio", + "--no-auto-update", + "--log-level", + log_level.as_str(), + ]) + .args(Self::auth_args(options)) + .args(Self::session_idle_timeout_args(options)) + .args(&options.extra_args) + .stdin(Stdio::piped()); + Ok(command.spawn()?) + } + + async fn spawn_tcp( + program: &Path, + options: &ClientOptions, + port: u16, + ) -> Result<(Child, u16), Error> { + info!(cwd = ?options.cwd, program = %program.display(), port = %port, "spawning copilot CLI (tcp)"); + let mut command = Self::build_command(program, options); + let log_level = options.log_level.unwrap_or(LogLevel::Info); + command + .args([ + "--server", + "--port", + &port.to_string(), + "--no-auto-update", + "--log-level", + log_level.as_str(), + ]) + .args(Self::auth_args(options)) + .args(Self::session_idle_timeout_args(options)) + .args(&options.extra_args) + .stdin(Stdio::null()); + let mut child = command.spawn()?; + let stdout = child.stdout.take().expect("stdout is piped"); + + let (port_tx, port_rx) = oneshot::channel::(); + let span = tracing::error_span!("copilot_cli_port_scan"); + tokio::spawn( + async move { + // Scan stdout for the port announcement. + let port_re = regex::Regex::new(r"listening on port (\d+)").expect("valid regex"); + let mut lines = BufReader::new(stdout).lines(); + let mut port_tx = Some(port_tx); + while let Ok(Some(line)) = lines.next_line().await { + debug!(line = %line, "CLI stdout"); + if let Some(tx) = port_tx.take() { + if let Some(caps) = port_re.captures(&line) + && let Some(p) = + caps.get(1).and_then(|m| m.as_str().parse::().ok()) + { + let _ = tx.send(p); + continue; + } + // Not the port line โ€” put tx back + port_tx = Some(tx); + } + } + } + .instrument(span), + ); + + let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx) + .await + .map_err(|_| Error::Protocol(ProtocolError::CliStartupTimeout))? + .map_err(|_| Error::Protocol(ProtocolError::CliStartupFailed))?; + + info!(port = %actual_port, "CLI server listening"); + Ok((child, actual_port)) + } + + fn drain_stderr(child: &mut Child) { + if let Some(stderr) = child.stderr.take() { + let span = tracing::error_span!("copilot_cli"); + tokio::spawn( + async move { + let mut reader = BufReader::new(stderr).lines(); + while let Ok(Some(line)) = reader.next_line().await { + warn!(line = %line, "CLI stderr"); + } + } + .instrument(span), + ); + } + } + + /// Returns the working directory of the CLI process. + pub fn cwd(&self) -> &PathBuf { + &self.inner.cwd + } + + /// Typed RPC namespace for server-level methods. + /// + /// Every protocol method lives here under its schema-aligned path โ€” + /// e.g. `client.rpc().models().list()`. Wire method names and request/ + /// response types are generated from the protocol schema, so the typed + /// namespace can't drift from the wire contract. + /// + /// The hand-authored helpers on [`Client`] delegate to this namespace + /// and remain the recommended entry point for everyday use; reach for + /// `rpc()` when you want a method without a hand-written wrapper. + pub fn rpc(&self) -> crate::generated::rpc::ClientRpc<'_> { + crate::generated::rpc::ClientRpc { client: self } + } + + /// Send a JSON-RPC request and wait for the response. + pub(crate) async fn send_request( + &self, + method: &str, + params: Option, + ) -> Result { + self.inner.rpc.send_request(method, params).await + } + + /// Send a JSON-RPC request, check for errors, and return the result value. + /// + /// This is the primary method for session-level RPC calls. It wraps + /// the internal send/receive cycle with error checking so callers + /// don't need to inspect the response manually. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** The frame is committed to the wire via the + /// writer-actor task before the future yields; cancelling the await + /// (via `tokio::time::timeout`, `select!`, or dropped JoinHandle) + /// drops the response oneshot but does not desync the transport. + /// The pending-requests entry is cleaned up by an RAII guard. + /// However, the call's *side effect* on the CLI may still occur โ€” + /// the CLI receives the request and processes it; the caller just + /// won't see the response. For idempotent methods this is fine; for + /// non-idempotent methods (e.g. `session.create`) the caller should + /// avoid wrapping the call in a timeout shorter than the expected + /// CLI processing window. + pub async fn call( + &self, + method: &str, + params: Option, + ) -> Result { + let session_id: Option = params + .as_ref() + .and_then(|p| p.get("sessionId")) + .and_then(|v| v.as_str()) + .map(SessionId::from); + let response = self.send_request(method, params).await?; + if let Some(err) = response.error { + if err.message.contains("Session not found") { + return Err(Error::Session(SessionError::NotFound( + session_id.unwrap_or_else(|| "unknown".into()), + ))); + } + return Err(Error::Rpc { + code: err.code, + message: err.message, + }); + } + Ok(response.result.unwrap_or(serde_json::Value::Null)) + } + + /// Send a JSON-RPC response back to the CLI (e.g. for permission or tool call requests). + pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<(), Error> { + self.inner.rpc.write(response).await + } + + /// Take the receiver for incoming JSON-RPC requests from the CLI. + /// + /// Can only be called once โ€” subsequent calls return `None`. + #[expect(dead_code, reason = "reserved for future pub(crate) use")] + pub(crate) fn take_request_rx(&self) -> Option> { + self.inner.request_rx.lock().take() + } + + /// Register a session to receive filtered events and requests. + /// + /// Returns per-session channels for notifications and requests, routed + /// by `sessionId`. Starts the internal router on first call. + /// + /// When done, call [`unregister_session`](Self::unregister_session) to + /// clean up (typically on session destroy). + pub(crate) fn register_session( + &self, + session_id: &SessionId, + ) -> crate::router::SessionChannels { + self.inner + .router + .ensure_started(&self.inner.notification_tx, &self.inner.request_rx); + self.inner.router.register(session_id) + } + + /// Unregister a session, dropping its per-session channels. + pub(crate) fn unregister_session(&self, session_id: &SessionId) { + self.inner.router.unregister(session_id); + } + + /// Returns the protocol version negotiated with the CLI server, if any. + /// + /// Set during [`start`](Self::start). Returns `None` if the server didn't + /// report a version, or if the client was created via + /// [`from_streams`](Self::from_streams) without calling + /// [`verify_protocol_version`](Self::verify_protocol_version). + pub fn protocol_version(&self) -> Option { + self.inner.negotiated_protocol_version.get().copied() + } + + /// Verify the CLI server's protocol version is within the supported range. + /// + /// Called automatically by [`start`](Self::start). Call manually after + /// [`from_streams`](Self::from_streams) if you need version verification + /// on a custom transport. + /// + /// # Handshake sequence + /// + /// 1. Sends the `connect` JSON-RPC method, forwarding + /// [`ClientOptions::tcp_connection_token`] (or the auto-generated + /// token for SDK-spawned TCP servers) as the `token` param. This + /// is the canonical handshake used by all SDK languages and is + /// what the CLI uses to enforce loopback authentication when + /// started with `COPILOT_CONNECTION_TOKEN`. + /// 2. If the server returns `-32601` (`MethodNotFound`), falls back + /// to the legacy `ping` RPC. This preserves compatibility with + /// older CLI versions that predate `connect`. + /// + /// # Result + /// + /// Returns an error if the negotiated `protocolVersion` is outside + /// `MIN_PROTOCOL_VERSION`..=[`SDK_PROTOCOL_VERSION`]. If the server + /// doesn't report a version, logs a warning and succeeds. + pub async fn verify_protocol_version(&self) -> Result<(), Error> { + // Try the new `connect` handshake first (sends the connection + // token, if any). Fall back to `ping` for legacy CLI servers + // that don't expose `connect` (-32601 MethodNotFound). Matches + // the Node SDK's verify-version sequence. + let server_version = match self.connect_handshake().await { + Ok(v) => v, + Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => { + self.ping(None).await?.protocol_version + } + Err(e) => return Err(e), + }; + + match server_version { + None => { + warn!("CLI server did not report protocolVersion; skipping version check"); + } + Some(v) if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&v) => { + return Err(Error::Protocol(ProtocolError::VersionMismatch { + server: v, + min: MIN_PROTOCOL_VERSION, + max: SDK_PROTOCOL_VERSION, + })); + } + Some(v) => { + if let Some(&existing) = self.inner.negotiated_protocol_version.get() { + if existing != v { + return Err(Error::Protocol(ProtocolError::VersionChanged { + previous: existing, + current: v, + })); + } + } else { + let _ = self.inner.negotiated_protocol_version.set(v); + } + } + } + + Ok(()) + } + + /// Send the `connect` JSON-RPC handshake. Returns the server's + /// reported protocol version, or `None` if the server omits it. + /// Forwards [`ClientOptions::tcp_connection_token`] (or the + /// auto-generated token for SDK-spawned TCP servers) as the `token` + /// param. Server-side, the token is required when the server was + /// started with `COPILOT_CONNECTION_TOKEN`. + async fn connect_handshake(&self) -> Result, Error> { + let result = self + .rpc() + .connect(crate::generated::api_types::ConnectRequest { + token: self.inner.effective_connection_token.clone(), + }) + .await?; + Ok(u32::try_from(result.protocol_version).ok()) + } + + /// Send a `ping` RPC and return the typed [`PingResponse`]. + /// + /// Pass `Some(message)` to have the server echo it back; pass `None` for + /// a bare health check. The response includes a `protocolVersion` when + /// the CLI reports one. + /// + /// [`PingResponse`]: crate::types::PingResponse + pub async fn ping(&self, message: Option<&str>) -> Result { + let params = match message { + Some(m) => serde_json::json!({ "message": m }), + None => serde_json::json!({}), + }; + let value = self + .call(generated::api_types::rpc_methods::PING, Some(params)) + .await?; + Ok(serde_json::from_value(value)?) + } + + /// List persisted sessions, optionally filtered by working directory, + /// repository, or git context. + pub async fn list_sessions( + &self, + filter: Option, + ) -> Result, Error> { + let params = match filter { + Some(f) => serde_json::json!({ "filter": f }), + None => serde_json::json!({}), + }; + let result = self.call("session.list", Some(params)).await?; + let response: ListSessionsResponse = serde_json::from_value(result)?; + Ok(response.sessions) + } + + /// Fetch metadata for a specific persisted session by ID. + /// + /// Returns `Ok(None)` if no session with the given ID exists. More + /// efficient than calling [`list_sessions`](Self::list_sessions) and + /// filtering when you only need data for a single session. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(client: &github_copilot_sdk::Client) -> Result<(), github_copilot_sdk::Error> { + /// use github_copilot_sdk::types::SessionId; + /// if let Some(metadata) = client.get_session_metadata(&SessionId::new("session-123")).await? { + /// println!("Session started at: {}", metadata.start_time); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn get_session_metadata( + &self, + session_id: &SessionId, + ) -> Result, Error> { + let result = self + .call( + "session.getMetadata", + Some(serde_json::json!({ "sessionId": session_id })), + ) + .await?; + let response: GetSessionMetadataResponse = serde_json::from_value(result)?; + Ok(response.session) + } + + /// Delete a persisted session by ID. + pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), Error> { + self.call( + "session.delete", + Some(serde_json::json!({ "sessionId": session_id })), + ) + .await?; + Ok(()) + } + + /// Return the ID of the most recently updated session, if any. + /// + /// Useful for resuming the last conversation when the session ID was + /// not stored. Returns `Ok(None)` if no sessions exist. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(client: &github_copilot_sdk::Client) -> Result<(), github_copilot_sdk::Error> { + /// if let Some(last_id) = client.get_last_session_id().await? { + /// println!("Last session: {last_id}"); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn get_last_session_id(&self) -> Result, Error> { + let result = self + .call("session.getLastId", Some(serde_json::json!({}))) + .await?; + let response: GetLastSessionIdResponse = serde_json::from_value(result)?; + Ok(response.session_id) + } + + /// Return the ID of the session currently displayed in the TUI, if any. + /// + /// Only meaningful when connected to a server running in TUI+server mode + /// (`--ui-server`). Returns `Ok(None)` if no foreground session is set. + pub async fn get_foreground_session_id(&self) -> Result, Error> { + let result = self + .call("session.getForeground", Some(serde_json::json!({}))) + .await?; + let response: GetForegroundSessionResponse = serde_json::from_value(result)?; + Ok(response.session_id) + } + + /// Request that the TUI switch to displaying the specified session. + /// + /// Only meaningful when connected to a server running in TUI+server mode + /// (`--ui-server`). + pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<(), Error> { + self.call( + "session.setForeground", + Some(serde_json::json!({ "sessionId": session_id })), + ) + .await?; + Ok(()) + } + + /// Get the CLI server status. + pub async fn get_status(&self) -> Result { + let result = self.call("status.get", Some(serde_json::json!({}))).await?; + Ok(serde_json::from_value(result)?) + } + + /// Get authentication status. + pub async fn get_auth_status(&self) -> Result { + let result = self + .call("auth.getStatus", Some(serde_json::json!({}))) + .await?; + Ok(serde_json::from_value(result)?) + } + + /// List available models. + /// + /// When [`ClientOptions::on_list_models`] is set, returns the handler's + /// result without making a `models.list` RPC. Otherwise queries the CLI. + pub async fn list_models(&self) -> Result, Error> { + if let Some(handler) = &self.inner.on_list_models { + return handler.list_models().await; + } + Ok(self.rpc().models().list().await?.models) + } + + /// Invoke [`ClientOptions::on_get_trace_context`] when configured, + /// otherwise return [`TraceContext::default()`]. + pub(crate) async fn resolve_trace_context(&self) -> TraceContext { + if let Some(provider) = &self.inner.on_get_trace_context { + provider.get_trace_context().await + } else { + TraceContext::default() + } + } + + /// Send a top-level telemetry event via `sendTelemetry`. + pub async fn send_telemetry(&self, event: ServerTelemetryEvent) -> Result<(), Error> { + let params = serde_json::to_value(event)?; + let cached_method = { *self.inner.server_telemetry_method.lock() }; + if let Some(method) = cached_method { + match self.call(method.as_str(), Some(params.clone())).await { + Ok(_) => return Ok(()), + Err(Error::Rpc { code, .. }) + if code == error_codes::METHOD_NOT_FOUND + && method == ServerTelemetryRpcMethod::SendTelemetry => + { + self.call( + ServerTelemetryRpcMethod::NamespacedSendTelemetry.as_str(), + Some(params), + ) + .await?; + *self.inner.server_telemetry_method.lock() = + Some(ServerTelemetryRpcMethod::NamespacedSendTelemetry); + return Ok(()); + } + Err(error) => return Err(error), + } + } + + match self + .call( + ServerTelemetryRpcMethod::SendTelemetry.as_str(), + Some(params.clone()), + ) + .await + { + Ok(_) => { + *self.inner.server_telemetry_method.lock() = + Some(ServerTelemetryRpcMethod::SendTelemetry); + Ok(()) + } + Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => { + self.call( + ServerTelemetryRpcMethod::NamespacedSendTelemetry.as_str(), + Some(params), + ) + .await?; + *self.inner.server_telemetry_method.lock() = + Some(ServerTelemetryRpcMethod::NamespacedSendTelemetry); + Ok(()) + } + Err(error) => Err(error), + } + } + + /// Fetch account-level quota snapshots (request-based usage). + /// + /// This top-level convenience wrapper is Rust-only as of 0.1.0; the Node, + /// Python, Go, and .NET SDKs do not expose a client-level shortcut for + /// quota lookup. The underlying `account.getQuota` JSON-RPC endpoint is + /// itself available cross-SDK via each SDK's typed `rpc()` namespace + /// (Node `client.rpc().account().getQuota()`, Python + /// `client.rpc().account.get_quota()`, Go `client.Rpc().Account().GetQuota()`, + /// .NET `client.Rpc().Account().GetQuotaAsync()`), and in Rust at + /// `client.rpc().account().get_quota()`. This wrapper is a thin shortcut + /// for that same call. + pub async fn get_quota(&self) -> Result { + self.rpc().account().get_quota().await + } + + /// Return the OS process ID of the CLI child process, if one was spawned. + pub fn pid(&self) -> Option { + self.inner.child.lock().as_ref().and_then(|c| c.id()) + } + + /// Cooperatively shut down the client and the CLI child process. + /// + /// Walks every still-registered session and sends `session.destroy` + /// for each one, then kills the CLI child. Errors from per-session + /// destroys and the final child-kill are collected into + /// [`StopErrors`] rather than short-circuiting on the first failure + /// โ€” so callers see the full picture of teardown. + /// + /// If you have already called [`Session::disconnect`] on every + /// session this client created, the per-session destroy step is a + /// no-op (the router map is empty); only the child-kill remains. + /// + /// [`Session::disconnect`]: crate::session::Session::disconnect + /// + /// # Cancel safety + /// + /// **Cancel-unsafe but recoverable.** The body sequentially destroys + /// every registered session (each via [`Client::call`](Self::call), + /// individually cancel-safe) before killing the child. Cancelling + /// `stop()` mid-loop leaves some sessions still in the router map + /// and the child still running. Recovery: call [`force_stop`](Self::force_stop) + /// (sync, kills the child unconditionally and clears router state) + /// or call `stop()` again with a fresh future. The documented + /// `tokio::time::timeout(..., client.stop())` pattern in the example + /// below uses `force_stop` as the fallback for exactly this case. + pub async fn stop(&self) -> Result<(), StopErrors> { + let pid = self.pid(); + info!(pid = ?pid, "stopping CLI process"); + let mut errors: Vec = Vec::new(); + + // Snapshot the registered session IDs without holding the router + // lock across the destroy RPCs. + for session_id in self.inner.router.session_ids() { + match self + .call( + "session.destroy", + Some(serde_json::json!({ "sessionId": session_id })), + ) + .await + { + Ok(_) => {} + Err(e) => { + warn!( + session_id = %session_id, + error = %e, + "session.destroy failed during Client::stop", + ); + errors.push(e); + } + } + self.inner.router.unregister(&session_id); + } + + let child = self.inner.child.lock().take(); + *self.inner.state.lock() = ConnectionState::Disconnected; + if let Some(mut child) = child + && let Err(e) = child.kill().await + { + errors.push(Error::Io(e)); + } + + info!(pid = ?pid, errors = errors.len(), "CLI process stopped"); + if errors.is_empty() { + Ok(()) + } else { + Err(StopErrors(errors)) + } + } + + /// Forcibly stop the CLI process without waiting for it to exit. + /// + /// Synchronous fallback when [`stop`](Self::stop) is unsuitable โ€” for + /// example when the awaiting tokio runtime is shutting down or the + /// process is wedged on I/O. Sends a kill signal without awaiting + /// reaper completion and immediately drops all per-session router + /// state so dependent tasks observe a closed channel rather than a + /// hang. + /// + /// # Cancel safety + /// + /// **Synchronous and infallible by construction.** Not async; cannot + /// be cancelled. Designed as the recovery path when [`stop`](Self::stop) + /// is wrapped in a timeout that elapses. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(client: github_copilot_sdk::Client) { + /// // Try graceful shutdown first; fall back to force_stop if hung. + /// match tokio::time::timeout( + /// std::time::Duration::from_secs(5), + /// client.stop(), + /// ).await { + /// Ok(_) => {} + /// Err(_) => client.force_stop(), + /// } + /// # } + /// ``` + pub fn force_stop(&self) { + let pid = self.pid(); + info!(pid = ?pid, "force-stopping CLI process"); + if let Some(mut child) = self.inner.child.lock().take() + && let Err(e) = child.start_kill() + { + error!(pid = ?pid, error = %e, "failed to send kill signal"); + } + // Drop all session channels so any awaiters see a closed channel + // instead of waiting for responses that will never arrive. + self.inner.router.clear(); + *self.inner.state.lock() = ConnectionState::Disconnected; + } + + /// Subscribe to lifecycle events. + /// + /// Returns a [`LifecycleSubscription`] that yields every + /// [`SessionLifecycleEvent`] sent by the CLI. Drop the value to + /// unsubscribe; there is no separate cancel handle. + /// + /// The returned handle implements both an inherent + /// [`recv`](LifecycleSubscription::recv) method and [`Stream`](tokio_stream::Stream), + /// so callers can use a `while let` loop or any combinator from + /// `tokio_stream::StreamExt` / `futures::StreamExt`. + /// + /// Each subscriber maintains its own queue. If a consumer cannot keep + /// up, the oldest events are dropped and `recv` returns + /// [`RecvError::Lagged`] with the count of skipped events; consumers + /// should match on it and continue. Slow consumers do not block the + /// producer. + /// + /// To filter by event type, match on `event.event_type` in the + /// consumer task. There is no built-in typed filter โ€” `match` is more + /// flexible and keeps the API surface small. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(client: github_copilot_sdk::Client) { + /// let mut events = client.subscribe_lifecycle(); + /// tokio::spawn(async move { + /// while let Ok(event) = events.recv().await { + /// println!("session {} -> {:?}", event.session_id, event.event_type); + /// } + /// }); + /// # } + /// ``` + pub fn subscribe_lifecycle(&self) -> LifecycleSubscription { + LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe()) + } + + /// Return the current [`ConnectionState`]. + /// + /// The state advances to [`Connected`](ConnectionState::Connected) once + /// [`Client::start`] / [`Client::from_streams`] returns successfully and + /// drops to [`Disconnected`](ConnectionState::Disconnected) after + /// [`stop`](Self::stop) or [`force_stop`](Self::force_stop). + pub fn state(&self) -> ConnectionState { + *self.inner.state.lock() + } +} + +impl Drop for ClientInner { + fn drop(&mut self) { + if let Some(ref mut child) = *self.child.lock() { + let pid = child.id(); + if let Err(e) = child.start_kill() { + error!(pid = ?pid, error = %e, "failed to kill CLI process on drop"); + } else { + info!(pid = ?pid, "kill signal sent for CLI process on drop"); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn is_transport_failure_matches_request_cancelled() { + let err = Error::Protocol(ProtocolError::RequestCancelled); + assert!(err.is_transport_failure()); + } + + #[test] + fn is_transport_failure_matches_io_error() { + let err = Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone")); + assert!(err.is_transport_failure()); + } + + #[test] + fn is_transport_failure_rejects_rpc_error() { + let err = Error::Rpc { + code: -1, + message: "bad".into(), + }; + assert!(!err.is_transport_failure()); + } + + #[test] + fn is_transport_failure_rejects_session_error() { + let err = Error::Session(SessionError::NotFound("s1".into())); + assert!(!err.is_transport_failure()); + } + + #[test] + fn client_options_builder_composes() { + let opts = ClientOptions::new() + .with_program(CliProgram::Path(PathBuf::from("/usr/local/bin/copilot"))) + .with_prefix_args(["node"]) + .with_cwd(PathBuf::from("/tmp")) + .with_env([("KEY", "value")]) + .with_env_remove(["UNWANTED"]) + .with_extra_args(["--quiet"]) + .with_github_token("ghp_test") + .with_use_logged_in_user(false) + .with_log_level(LogLevel::Debug) + .with_session_idle_timeout_seconds(120); + assert!(matches!(opts.program, CliProgram::Path(_))); + assert_eq!(opts.prefix_args, vec![std::ffi::OsString::from("node")]); + assert_eq!(opts.cwd, PathBuf::from("/tmp")); + assert_eq!( + opts.env, + vec![( + std::ffi::OsString::from("KEY"), + std::ffi::OsString::from("value") + )] + ); + assert_eq!(opts.env_remove, vec![std::ffi::OsString::from("UNWANTED")]); + assert_eq!(opts.extra_args, vec!["--quiet".to_string()]); + assert_eq!(opts.github_token.as_deref(), Some("ghp_test")); + assert_eq!(opts.use_logged_in_user, Some(false)); + assert!(matches!(opts.log_level, Some(LogLevel::Debug))); + assert_eq!(opts.session_idle_timeout_seconds, Some(120)); + } + + #[test] + fn is_transport_failure_rejects_other_protocol_errors() { + let err = Error::Protocol(ProtocolError::CliStartupTimeout); + assert!(!err.is_transport_failure()); + } + + #[test] + fn build_command_lets_env_remove_strip_injected_token() { + let opts = ClientOptions { + github_token: Some("secret".to_string()), + env_remove: vec![std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN")], + ..Default::default() + }; + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + // get_envs() iter yields the latest action per key โ€” None means removed. + let action = cmd + .as_std() + .get_envs() + .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN")) + .map(|(_, v)| v); + assert_eq!( + action, + Some(None), + "env_remove should win over github_token" + ); + } + + #[test] + fn build_command_lets_env_override_injected_token() { + let opts = ClientOptions { + github_token: Some("from-options".to_string()), + env: vec![( + std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN"), + std::ffi::OsString::from("from-env"), + )], + ..Default::default() + }; + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + let value = cmd + .as_std() + .get_envs() + .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN")) + .and_then(|(_, v)| v); + assert_eq!(value, Some(std::ffi::OsStr::new("from-env"))); + } + + #[test] + fn build_command_injects_github_token_by_default() { + let opts = ClientOptions { + github_token: Some("just-the-token".to_string()), + ..Default::default() + }; + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + let value = cmd + .as_std() + .get_envs() + .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN")) + .and_then(|(_, v)| v); + assert_eq!(value, Some(std::ffi::OsStr::new("just-the-token"))); + } + + fn env_value<'a>(cmd: &'a tokio::process::Command, key: &str) -> Option<&'a std::ffi::OsStr> { + cmd.as_std() + .get_envs() + .find(|(k, _)| *k == std::ffi::OsStr::new(key)) + .and_then(|(_, v)| v) + } + + #[test] + fn telemetry_config_builder_composes() { + let cfg = TelemetryConfig::new() + .with_otlp_endpoint("http://collector:4318") + .with_file_path(PathBuf::from("/var/log/copilot.jsonl")) + .with_exporter_type(OtelExporterType::OtlpHttp) + .with_source_name("my-app") + .with_capture_content(true); + + assert_eq!(cfg.otlp_endpoint.as_deref(), Some("http://collector:4318")); + assert_eq!( + cfg.file_path.as_deref(), + Some(Path::new("/var/log/copilot.jsonl")), + ); + assert_eq!(cfg.exporter_type, Some(OtelExporterType::OtlpHttp)); + assert_eq!(cfg.source_name.as_deref(), Some("my-app")); + assert_eq!(cfg.capture_content, Some(true)); + assert!(!cfg.is_empty()); + assert!(TelemetryConfig::new().is_empty()); + } + + #[test] + fn build_command_sets_otel_env_when_telemetry_enabled() { + let opts = ClientOptions { + telemetry: Some(TelemetryConfig { + otlp_endpoint: Some("http://collector:4318".to_string()), + file_path: Some(PathBuf::from("/var/log/copilot.jsonl")), + exporter_type: Some(OtelExporterType::OtlpHttp), + source_name: Some("my-app".to_string()), + capture_content: Some(true), + }), + ..Default::default() + }; + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + assert_eq!( + env_value(&cmd, "COPILOT_OTEL_ENABLED"), + Some(std::ffi::OsStr::new("true")), + ); + assert_eq!( + env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"), + Some(std::ffi::OsStr::new("http://collector:4318")), + ); + assert_eq!( + env_value(&cmd, "COPILOT_OTEL_FILE_EXPORTER_PATH"), + Some(std::ffi::OsStr::new("/var/log/copilot.jsonl")), + ); + assert_eq!( + env_value(&cmd, "COPILOT_OTEL_EXPORTER_TYPE"), + Some(std::ffi::OsStr::new("otlp-http")), + ); + assert_eq!( + env_value(&cmd, "COPILOT_OTEL_SOURCE_NAME"), + Some(std::ffi::OsStr::new("my-app")), + ); + assert_eq!( + env_value(&cmd, "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"), + Some(std::ffi::OsStr::new("true")), + ); + } + + #[test] + fn build_command_omits_otel_env_when_telemetry_none() { + let opts = ClientOptions::default(); + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + for key in [ + "COPILOT_OTEL_ENABLED", + "OTEL_EXPORTER_OTLP_ENDPOINT", + "COPILOT_OTEL_FILE_EXPORTER_PATH", + "COPILOT_OTEL_EXPORTER_TYPE", + "COPILOT_OTEL_SOURCE_NAME", + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", + ] { + assert!( + env_value(&cmd, key).is_none(), + "expected {key} to be unset when telemetry is None", + ); + } + } + + #[test] + fn build_command_omits_unset_telemetry_fields() { + let opts = ClientOptions { + telemetry: Some(TelemetryConfig { + otlp_endpoint: Some("http://collector:4318".to_string()), + ..Default::default() + }), + ..Default::default() + }; + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + // The one set field plus the implicit enabled flag should propagate. + assert_eq!( + env_value(&cmd, "COPILOT_OTEL_ENABLED"), + Some(std::ffi::OsStr::new("true")), + ); + assert_eq!( + env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"), + Some(std::ffi::OsStr::new("http://collector:4318")), + ); + // None of the other fields should leak as env vars. + for key in [ + "COPILOT_OTEL_FILE_EXPORTER_PATH", + "COPILOT_OTEL_EXPORTER_TYPE", + "COPILOT_OTEL_SOURCE_NAME", + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", + ] { + assert!(env_value(&cmd, key).is_none(), "{key} should be unset"); + } + } + + #[test] + fn build_command_lets_user_env_override_telemetry() { + let opts = ClientOptions { + telemetry: Some(TelemetryConfig { + otlp_endpoint: Some("http://from-config:4318".to_string()), + ..Default::default() + }), + env: vec![( + std::ffi::OsString::from("OTEL_EXPORTER_OTLP_ENDPOINT"), + std::ffi::OsString::from("http://from-user-env:4318"), + )], + ..Default::default() + }; + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + assert_eq!( + env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"), + Some(std::ffi::OsStr::new("http://from-user-env:4318")), + "user-supplied options.env should override telemetry config", + ); + } + + #[test] + fn build_command_sets_copilot_home_env_when_configured() { + let opts = ClientOptions::new().with_copilot_home(PathBuf::from("/custom/copilot")); + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + assert_eq!( + env_value(&cmd, "COPILOT_HOME"), + Some(std::ffi::OsStr::new("/custom/copilot")), + ); + + let opts = ClientOptions::default(); + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + assert!(env_value(&cmd, "COPILOT_HOME").is_none()); + } + + #[test] + fn build_command_sets_connection_token_env_when_configured() { + let opts = ClientOptions::new().with_tcp_connection_token("secret-token"); + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + assert_eq!( + env_value(&cmd, "COPILOT_CONNECTION_TOKEN"), + Some(std::ffi::OsStr::new("secret-token")), + ); + + let opts = ClientOptions::default(); + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + assert!(env_value(&cmd, "COPILOT_CONNECTION_TOKEN").is_none()); + } + + #[tokio::test] + async fn start_rejects_token_with_stdio_transport() { + let opts = ClientOptions::new() + .with_tcp_connection_token("token-123") + .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); + let err = Client::start(opts).await.unwrap_err(); + assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); + let Error::InvalidConfig(msg) = err else { + unreachable!() + }; + assert!( + msg.contains("Stdio"), + "error should explain the stdio incompatibility: {msg}" + ); + } + + #[tokio::test] + async fn start_rejects_empty_connection_token() { + let opts = ClientOptions::new() + .with_tcp_connection_token("") + .with_transport(Transport::Tcp { port: 0 }) + .with_program(CliProgram::Path(PathBuf::from("/bin/echo"))); + let err = Client::start(opts).await.unwrap_err(); + assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}"); + } + + #[test] + fn telemetry_config_capture_content_serializes_as_lowercase_bool() { + let opts_true = ClientOptions { + telemetry: Some(TelemetryConfig { + capture_content: Some(true), + ..Default::default() + }), + ..Default::default() + }; + let opts_false = ClientOptions { + telemetry: Some(TelemetryConfig { + capture_content: Some(false), + ..Default::default() + }), + ..Default::default() + }; + let cmd_true = Client::build_command(Path::new("/bin/echo"), &opts_true); + let cmd_false = Client::build_command(Path::new("/bin/echo"), &opts_false); + assert_eq!( + env_value( + &cmd_true, + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT" + ), + Some(std::ffi::OsStr::new("true")), + ); + assert_eq!( + env_value( + &cmd_false, + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT" + ), + Some(std::ffi::OsStr::new("false")), + ); + } + + #[test] + fn session_idle_timeout_args_are_omitted_by_default() { + let opts = ClientOptions::default(); + assert!(Client::session_idle_timeout_args(&opts).is_empty()); + } + + #[test] + fn session_idle_timeout_args_omitted_for_zero() { + let opts = ClientOptions { + session_idle_timeout_seconds: Some(0), + ..Default::default() + }; + assert!(Client::session_idle_timeout_args(&opts).is_empty()); + } + + #[test] + fn session_idle_timeout_args_emit_flag_for_positive_value() { + let opts = ClientOptions { + session_idle_timeout_seconds: Some(300), + ..Default::default() + }; + assert_eq!( + Client::session_idle_timeout_args(&opts), + vec!["--session-idle-timeout".to_string(), "300".to_string()] + ); + } + + #[test] + fn log_level_str_round_trips() { + for level in [ + LogLevel::None, + LogLevel::Error, + LogLevel::Warning, + LogLevel::Info, + LogLevel::Debug, + LogLevel::All, + ] { + let s = level.as_str(); + let json = serde_json::to_string(&level).unwrap(); + assert_eq!(json, format!("\"{s}\"")); + let parsed: LogLevel = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, level); + } + } + + #[test] + fn client_options_debug_redacts_handler() { + struct StubHandler; + #[async_trait] + impl ListModelsHandler for StubHandler { + async fn list_models(&self) -> Result, Error> { + Ok(vec![]) + } + } + let opts = ClientOptions { + on_list_models: Some(Arc::new(StubHandler)), + github_token: Some("secret-token".into()), + ..Default::default() + }; + let debug = format!("{opts:?}"); + assert!(debug.contains("on_list_models: Some(\"\")")); + assert!(debug.contains("github_token: Some(\"\")")); + assert!(!debug.contains("secret-token")); + } + + #[tokio::test] + async fn list_models_uses_on_list_models_handler_when_set() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + struct CountingHandler { + calls: Arc, + models: Vec, + } + #[async_trait] + impl ListModelsHandler for CountingHandler { + async fn list_models(&self) -> Result, Error> { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(self.models.clone()) + } + } + + let calls = Arc::new(AtomicUsize::new(0)); + let model = Model { + billing: None, + capabilities: ModelCapabilities { + limits: None, + supports: None, + }, + default_reasoning_effort: None, + id: "byok-gpt-4".into(), + name: "BYOK GPT-4".into(), + policy: None, + supported_reasoning_efforts: Vec::new(), + }; + let handler = Arc::new(CountingHandler { + calls: Arc::clone(&calls), + models: vec![model.clone()], + }); + + // We can't call list_models() through Client::start without a CLI, but we + // can exercise the override path by directly constructing a Client whose + // inner has the handler set. This is the same dispatch path as the real + // call; from_streams's None default is replaced via inner construction. + let inner = ClientInner { + child: parking_lot::Mutex::new(None), + rpc: { + let (req_tx, _req_rx) = mpsc::unbounded_channel(); + let (notif_tx, _notif_rx) = broadcast::channel(16); + let (read_pipe, _write_pipe) = tokio::io::duplex(64); + let (_unused_read, write_pipe) = tokio::io::duplex(64); + JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx) + }, + cwd: PathBuf::from("."), + request_rx: parking_lot::Mutex::new(None), + notification_tx: broadcast::channel(16).0, + router: router::SessionRouter::new(), + negotiated_protocol_version: OnceLock::new(), + server_telemetry_method: parking_lot::Mutex::new(None), + state: parking_lot::Mutex::new(ConnectionState::Connected), + lifecycle_tx: broadcast::channel(16).0, + on_list_models: Some(handler), + session_fs_configured: false, + on_get_trace_context: None, + effective_connection_token: None, + }; + let client = Client { + inner: Arc::new(inner), + }; + + let result = client.list_models().await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].id, "byok-gpt-4"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } +} diff --git a/rust/src/permission.rs b/rust/src/permission.rs new file mode 100644 index 000000000..02db23e06 --- /dev/null +++ b/rust/src/permission.rs @@ -0,0 +1,166 @@ +//! Permission-policy helpers that compose with an existing +//! [`SessionHandler`](crate::handler::SessionHandler). +//! +//! These wrap an inner handler and override **only** permission requests, +//! forwarding every other event (tool calls, user input, elicitation, +//! session events) to the inner handler. Use them when you have a custom +//! tool handler โ€” typically a [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) โ€” +//! but want a one-line policy for permission prompts. +//! +//! For a full handler that approves or denies everything, see +//! [`ApproveAllHandler`](crate::handler::ApproveAllHandler) and +//! [`DenyAllHandler`](crate::handler::DenyAllHandler). +//! +//! # Example +//! +//! ```rust,no_run +//! # use std::sync::Arc; +//! # use github_copilot_sdk::handler::ApproveAllHandler; +//! # use github_copilot_sdk::permission; +//! # use github_copilot_sdk::tool::ToolHandlerRouter; +//! let router = ToolHandlerRouter::new(vec![], Arc::new(ApproveAllHandler)); +//! // Inherit the router's tool dispatch but auto-approve all permission prompts: +//! let handler = permission::approve_all(Arc::new(router)); +//! ``` + +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler}; +use crate::types::PermissionRequestData; + +/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is +/// auto-approved. All other events are forwarded to `inner`. +pub fn approve_all(inner: Arc) -> Arc { + Arc::new(PermissionOverrideHandler { + inner, + policy: Policy::ApproveAll, + }) +} + +/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is +/// auto-denied. All other events are forwarded to `inner`. +pub fn deny_all(inner: Arc) -> Arc { + Arc::new(PermissionOverrideHandler { + inner, + policy: Policy::DenyAll, + }) +} + +/// Wrap `inner` with a closure-based policy: `predicate` is called for each +/// permission request; `true` approves, `false` denies. All other events +/// are forwarded to `inner`. +/// +/// ```rust,no_run +/// # use std::sync::Arc; +/// # use github_copilot_sdk::handler::ApproveAllHandler; +/// # use github_copilot_sdk::permission; +/// let inner = Arc::new(ApproveAllHandler); +/// let handler = permission::approve_if(inner, |data| { +/// // Inspect data.extra (the raw JSON payload) for custom policy. +/// data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") +/// }); +/// # let _ = handler; +/// ``` +pub fn approve_if(inner: Arc, predicate: F) -> Arc +where + F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static, +{ + Arc::new(PermissionOverrideHandler { + inner, + policy: Policy::Predicate(Arc::new(predicate)), + }) +} + +enum Policy { + ApproveAll, + DenyAll, + Predicate(Arc bool + Send + Sync>), +} + +struct PermissionOverrideHandler { + inner: Arc, + policy: Policy, +} + +#[async_trait] +impl SessionHandler for PermissionOverrideHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::PermissionRequest { ref data, .. } => { + let approved = match &self.policy { + Policy::ApproveAll => true, + Policy::DenyAll => false, + Policy::Predicate(f) => f(data), + }; + HandlerResponse::Permission(if approved { + PermissionResult::Approved + } else { + PermissionResult::Denied + }) + } + other => self.inner.on_event(other).await, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::handler::ApproveAllHandler; + use crate::types::{RequestId, SessionId}; + + fn request() -> HandlerEvent { + HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1"), + request_id: RequestId::new("1"), + data: PermissionRequestData { + extra: serde_json::json!({"tool": "shell"}), + ..Default::default() + }, + } + } + + #[tokio::test] + async fn approve_all_approves_permission_requests() { + let h = approve_all(Arc::new(ApproveAllHandler)); + match h.on_event(request()).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("expected Approved, got {other:?}"), + } + } + + #[tokio::test] + async fn deny_all_denies_permission_requests() { + let h = deny_all(Arc::new(ApproveAllHandler)); + match h.on_event(request()).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("expected Denied, got {other:?}"), + } + } + + #[tokio::test] + async fn approve_if_consults_predicate() { + let h = approve_if(Arc::new(ApproveAllHandler), |data| { + data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + }); + match h.on_event(request()).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("expected Denied for shell, got {other:?}"), + } + } + + #[tokio::test] + async fn non_permission_events_forward_to_inner() { + let h = deny_all(Arc::new(ApproveAllHandler)); + let event = HandlerEvent::ExitPlanMode { + session_id: SessionId::from("s1"), + data: crate::types::ExitPlanModeData::default(), + }; + match h.on_event(event).await { + HandlerResponse::ExitPlanMode(_) => {} + other => panic!("expected ExitPlanMode forwarded, got {other:?}"), + } + } +} diff --git a/rust/src/resolve.rs b/rust/src/resolve.rs new file mode 100644 index 000000000..8521a4b55 --- /dev/null +++ b/rust/src/resolve.rs @@ -0,0 +1,677 @@ +use std::collections::HashSet; +use std::env; +use std::ffi::OsStr; +use std::path::{Path, PathBuf}; + +use serde::Serialize; +use tracing::warn; + +use crate::Error; + +/// How the copilot binary was resolved. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum BinarySource { + /// Extracted from the build-time embedded binary. + Bundled, + /// Set via `COPILOT_CLI_PATH` environment variable. + EnvOverride, + /// Found on PATH or standard search locations. + Local, +} + +/// Find the `copilot` CLI binary on the system. +/// +/// Checks `COPILOT_CLI_PATH` env var first, then searches PATH and common +/// install locations (homebrew, nvm, nodenv, fnm, volta, cargo, etc.). +/// Use `COPILOT_CLI_NAME` to override the binary name (default: `copilot`). +pub fn copilot_binary() -> Result { + copilot_binary_with_source().map(|(path, _)| path) +} + +/// Like [`copilot_binary`] but also reports how the binary was resolved. +pub fn copilot_binary_with_source() -> Result<(PathBuf, BinarySource), Error> { + if let Ok(value) = env::var("COPILOT_CLI_PATH") { + let candidate = PathBuf::from(value); + if candidate.is_file() { + return Ok((candidate, BinarySource::EnvOverride)); + } + if candidate.is_dir() + && let Some(found) = find_copilot_in_dir(&candidate) + { + return Ok((found, BinarySource::EnvOverride)); + } + warn!(path = %candidate.display(), "COPILOT_CLI_PATH set but not usable"); + } + + if let Some(path) = crate::embeddedcli::path() { + return Ok((path, BinarySource::Bundled)); + } + + for dir in standard_search_paths() { + if let Some(found) = find_copilot_in_dir(&dir) { + return Ok((found, BinarySource::Local)); + } + } + + Err(Error::BinaryNotFound { + name: "copilot", + hint: "ensure the GitHub Copilot CLI is installed and on PATH, or set COPILOT_CLI_PATH. use COPILOT_CLI_NAME to override the binary name (default: copilot)", + }) +} + +/// Find the `copilot` CLI binary using only the current PATH entries. +/// +/// This is intentionally narrower than [`copilot_binary`]: it does not honor +/// override env vars and does not search inferred install locations. +pub fn copilot_binary_on_path() -> Result { + if let Some(found) = find_executable_in_path( + env::var_os("PATH").as_deref(), + &literal_copilot_executable_names(), + ) { + return Ok(found); + } + + Err(Error::BinaryNotFound { + name: "copilot", + hint: "ensure the `copilot` command is installed and available on PATH", + }) +} + +/// Build an extended `PATH` by prepending `extra` dirs to the standard +/// search paths (current PATH + common install locations). +pub fn extended_path(extra: &[PathBuf]) -> Option { + let mut paths = SearchPaths::new(); + for p in extra { + paths.push(p.clone()); + } + paths.append_standard(); + if paths.is_empty() { + return None; + } + env::join_paths(paths).ok() +} + +fn copilot_executable_names() -> Vec { + let base = env::var("COPILOT_CLI_NAME").unwrap_or_else(|_| "copilot".to_string()); + executable_names_for_base(&base) +} + +fn literal_copilot_executable_names() -> Vec { + executable_names_for_base("copilot") +} + +fn executable_names_for_base(base: &str) -> Vec { + #[cfg(target_os = "windows")] + { + vec![ + format!("{}.exe", base), + format!("{}.cmd", base), + format!("{}.bat", base), + ] + } + #[cfg(not(target_os = "windows"))] + { + vec![base.to_string()] + } +} + +fn find_executable(dir: &Path, names: &[impl AsRef]) -> Option { + if dir.as_os_str().is_empty() { + return None; + } + names + .iter() + .map(|n| dir.join(n.as_ref())) + .find(|c| c.is_file()) +} + +fn find_copilot_in_dir(dir: &Path) -> Option { + find_executable(dir, &copilot_executable_names()) +} + +fn find_executable_in_path( + path_env: Option<&OsStr>, + names: &[impl AsRef], +) -> Option { + let path_env = path_env?; + for dir in env::split_paths(path_env) { + if let Some(found) = find_executable(&dir, names) { + return Some(found); + } + } + None +} + +/// Ordered, deduplicated collection of directory paths to search for binaries. +/// +/// Paths are stored in insertion order. Duplicates and empty paths are +/// silently dropped on `push`. Implements `Iterator` so it can be passed +/// directly to `env::join_paths` or used in a `for` loop. +struct SearchPaths { + seen: HashSet, + paths: Vec, +} + +impl SearchPaths { + fn new() -> Self { + Self { + seen: HashSet::new(), + paths: Vec::new(), + } + } + + /// Add a path if it hasn't been seen before. Empty paths are ignored. + fn push(&mut self, path: PathBuf) { + if !path.as_os_str().is_empty() && self.seen.insert(path.clone()) { + self.paths.push(path); + } + } + + fn is_empty(&self) -> bool { + self.paths.is_empty() + } + + /// Append the standard search paths: current PATH, home-relative dirs, + /// version manager paths (nvm, nodenv, fnm), and platform-specific dirs. + fn append_standard(&mut self) { + if let Some(existing) = env::var_os("PATH") { + for p in env::split_paths(&existing) { + self.push(p); + } + } + + if let Some(home) = dirs::home_dir() { + self.push(home.join(".local/bin")); + self.push(home.join(".cargo/bin")); + self.push(home.join(".bun/bin")); + self.push(home.join(".npm-global/bin")); + self.push(home.join(".yarn/bin")); + self.push(home.join(".volta/bin")); + self.push(home.join(".asdf/shims")); + self.push(home.join("bin")); + } + + // Platform-specific standard dirs come before version-manager paths + // so that the system-installed node (e.g. /opt/homebrew/bin/node) + // takes precedence over arbitrary old versions found under + // ~/.nvm/versions, ~/.nodenv/versions, etc. + #[cfg(target_os = "macos")] + { + self.push(PathBuf::from("/opt/homebrew/bin")); + self.push(PathBuf::from("/usr/local/bin")); + self.push(PathBuf::from("/usr/bin")); + self.push(PathBuf::from("/bin")); + self.push(PathBuf::from("/usr/sbin")); + self.push(PathBuf::from("/sbin")); + } + + #[cfg(target_os = "linux")] + { + self.push(PathBuf::from("/usr/local/bin")); + self.push(PathBuf::from("/usr/bin")); + self.push(PathBuf::from("/bin")); + self.push(PathBuf::from("/snap/bin")); + } + + #[cfg(target_os = "windows")] + { + if let Some(appdata) = env::var_os("APPDATA") { + self.push(PathBuf::from(appdata).join("npm")); + } + if let Some(local) = env::var_os("LOCALAPPDATA") { + let local = PathBuf::from(local); + self.push(local.join("Programs")); + // User-scope winget install of Git for Windows. + self.push(local.join("Programs").join("Git").join("cmd")); + self.push(local.join("Programs").join("Git").join("bin")); + } + // Git for Windows standard machine-scope install locations. + for env_var in ["ProgramFiles", "ProgramW6432", "ProgramFiles(x86)"] { + if let Some(program_files) = env::var_os(env_var) { + let program_files = PathBuf::from(program_files); + self.push(program_files.join("Git").join("cmd")); + self.push(program_files.join("Git").join("bin")); + } + } + } + + // Version manager paths are a fallback for binary discovery โ€” + // they enumerate every installed version, so an arbitrary old + // node/copilot can appear first if filesystem ordering is unlucky. + for p in collect_nvm_paths() { + self.push(p); + } + for p in collect_nodenv_paths() { + self.push(p); + } + for p in collect_fnm_paths() { + self.push(p); + } + } +} + +impl IntoIterator for SearchPaths { + type IntoIter = std::vec::IntoIter; + type Item = PathBuf; + + fn into_iter(self) -> Self::IntoIter { + self.paths.into_iter() + } +} + +/// Collect standard search paths for binary resolution. +fn standard_search_paths() -> SearchPaths { + let mut paths = SearchPaths::new(); + paths.append_standard(); + paths +} + +fn collect_nvm_paths() -> Vec { + let mut paths = Vec::new(); + let nvm_dir = env::var_os("NVM_DIR") + .map(PathBuf::from) + .or_else(|| dirs::home_dir().map(|home| home.join(".nvm"))); + let Some(nvm_dir) = nvm_dir else { + return paths; + }; + let versions_dir = nvm_dir.join("versions").join("node"); + let entries = match std::fs::read_dir(&versions_dir) { + Ok(entries) => entries, + Err(_) => return paths, + }; + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + paths.push(path.join("bin")); + } + } + paths +} + +fn collect_nodenv_paths() -> Vec { + let mut paths = Vec::new(); + let root = env::var_os("NODENV_ROOT") + .map(PathBuf::from) + .or_else(|| dirs::home_dir().map(|home| home.join(".nodenv"))); + let Some(root) = root else { + return paths; + }; + let versions_dir = root.join("versions"); + let entries = match std::fs::read_dir(&versions_dir) { + Ok(entries) => entries, + Err(_) => return paths, + }; + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + paths.push(path.join("bin")); + } + } + paths +} + +fn fnm_root_candidates_from( + fnm_dir: Option, + xdg_data_home: Option, + home: Option, +) -> Vec { + let mut roots = SearchPaths::new(); + + if let Some(fnm_dir) = fnm_dir.filter(|path| !path.as_os_str().is_empty()) { + roots.push(fnm_dir); + } + + if let Some(xdg_data_home) = xdg_data_home.filter(|path| !path.as_os_str().is_empty()) { + roots.push(xdg_data_home.join("fnm")); + } + + if let Some(home) = home { + roots.push(home.join(".local").join("share").join("fnm")); + roots.push(home.join(".fnm")); + } + + roots.paths +} + +fn collect_fnm_paths() -> Vec { + let roots = fnm_root_candidates_from( + env::var_os("FNM_DIR").map(PathBuf::from), + env::var_os("XDG_DATA_HOME").map(PathBuf::from), + dirs::home_dir(), + ); + + let mut paths = SearchPaths::new(); + for root in &roots { + paths.push(root.join("aliases").join("default").join("bin")); + + let versions_dir = root.join("node-versions"); + let entries = match std::fs::read_dir(&versions_dir) { + Ok(entries) => entries, + Err(_) => continue, + }; + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + paths.push(path.join("installation").join("bin")); + } + } + } + + paths.paths +} + +#[cfg(test)] +mod tests { + use std::path::{Path, PathBuf}; + use std::{env, fs}; + + use serial_test::serial; + use tempfile::tempdir; + + use super::{ + copilot_binary_on_path, find_executable_in_path, fnm_root_candidates_from, + literal_copilot_executable_names, + }; + + #[test] + fn fnm_root_candidates_include_xdg_and_legacy_locations() { + let home = PathBuf::from("/tmp/copilot-home"); + + let roots = fnm_root_candidates_from(None, None, Some(home.clone())); + + assert_eq!( + roots, + vec![ + home.join(".local").join("share").join("fnm"), + home.join(".fnm"), + ] + ); + } + + #[test] + fn fnm_root_candidates_prefer_explicit_locations_first() { + let home = PathBuf::from("/tmp/copilot-home"); + let explicit_fnm_dir = PathBuf::from("/tmp/custom-fnm"); + let xdg_data_home = PathBuf::from("/tmp/xdg-data"); + + let roots = fnm_root_candidates_from( + Some(explicit_fnm_dir.clone()), + Some(xdg_data_home.clone()), + Some(home.clone()), + ); + + assert_eq!( + roots, + vec![ + explicit_fnm_dir, + xdg_data_home.join("fnm"), + home.join(".local").join("share").join("fnm"), + home.join(".fnm"), + ] + ); + } + + #[test] + fn fnm_root_candidates_ignore_empty_xdg_data_home() { + let home = PathBuf::from("/tmp/copilot-home"); + + let roots = fnm_root_candidates_from(None, Some(PathBuf::new()), Some(home.clone())); + + assert_eq!( + roots, + vec![ + home.join(".local").join("share").join("fnm"), + home.join(".fnm"), + ] + ); + assert!(!roots.iter().any(|path| path == &PathBuf::from("fnm"))); + } + + #[test] + fn fnm_root_produces_expected_bin_paths() { + let temp_dir = tempdir().expect("should create temp dir"); + let root = temp_dir.path().join("fnm-root"); + let alias_bin = root.join("aliases").join("default").join("bin"); + let version_bin = root + .join("node-versions") + .join("v22.18.0") + .join("installation") + .join("bin"); + + fs::create_dir_all(&alias_bin).expect("should create fnm alias bin"); + fs::create_dir_all(&version_bin).expect("should create fnm version bin"); + + let roots = fnm_root_candidates_from(Some(root.clone()), None, None); + assert_eq!(roots, vec![root.clone()]); + + // Verify the expected bin paths exist under the root structure + assert!(alias_bin.is_dir()); + assert!(version_bin.is_dir()); + } + + #[test] + fn find_copilot_in_path_finds_binary_in_path_entries() { + let temp_dir = tempdir().expect("should create temp dir"); + let bin_dir = temp_dir.path().join("bin"); + fs::create_dir_all(&bin_dir).expect("should create bin dir"); + + let executable_name = literal_copilot_executable_names() + .into_iter() + .next() + .expect("should provide a copilot executable name"); + let executable_path = bin_dir.join(&executable_name); + fs::write(&executable_path, "#!/bin/sh\n").expect("should create fake binary"); + + let path_env = + env::join_paths([Path::new("/missing"), bin_dir.as_path()]).expect("should build PATH"); + + assert_eq!( + find_executable_in_path( + Some(path_env.as_os_str()), + &literal_copilot_executable_names() + ), + Some(executable_path) + ); + } + + #[test] + fn find_copilot_in_path_ignores_missing_entries() { + let path_env = env::join_paths([Path::new("/missing-one"), Path::new("/missing-two")]) + .expect("should build PATH"); + + assert_eq!( + find_executable_in_path( + Some(path_env.as_os_str()), + &literal_copilot_executable_names() + ), + None + ); + } + + #[test] + #[serial] + #[cfg(target_os = "macos")] + fn platform_dirs_precede_version_manager_dirs() { + let temp = tempdir().expect("should create temp dir"); + let fake_home = temp.path().join("home"); + + // Create fake nvm version dirs so collect_nvm_paths() returns entries. + let nvm_dir = fake_home.join(".nvm"); + let nvm_version_bin = nvm_dir + .join("versions") + .join("node") + .join("v18.0.0") + .join("bin"); + fs::create_dir_all(&nvm_version_bin).expect("should create nvm version bin"); + + // Create fake nodenv version dirs. + let nodenv_root = fake_home.join(".nodenv"); + let nodenv_version_bin = nodenv_root.join("versions").join("20.0.0").join("bin"); + fs::create_dir_all(&nodenv_version_bin).expect("should create nodenv version bin"); + + // Create fake fnm version dirs. + let fnm_root = fake_home.join(".local").join("share").join("fnm"); + let fnm_version_bin = fnm_root + .join("node-versions") + .join("v22.0.0") + .join("installation") + .join("bin"); + fs::create_dir_all(&fnm_version_bin).expect("should create fnm version bin"); + + // Save env vars. + let prev_path = env::var_os("PATH"); + let prev_home = env::var_os("HOME"); + let prev_nvm_dir = env::var_os("NVM_DIR"); + let prev_nodenv_root = env::var_os("NODENV_ROOT"); + let prev_fnm_dir = env::var_os("FNM_DIR"); + let prev_xdg_data_home = env::var_os("XDG_DATA_HOME"); + + // Set env: empty PATH so only append_standard() dirs appear, + // HOME to our fake home, and explicit version-manager roots. + // Safety: test-only, single-threaded via #[serial]. + unsafe { + env::set_var("PATH", ""); + env::set_var("HOME", &fake_home); + env::set_var("NVM_DIR", &nvm_dir); + env::set_var("NODENV_ROOT", &nodenv_root); + env::remove_var("FNM_DIR"); + env::remove_var("XDG_DATA_HOME"); + } + + let paths: Vec = super::standard_search_paths().into_iter().collect(); + + // Restore env vars. + // Safety: test-only, single-threaded via #[serial]. + unsafe { + match prev_path { + Some(v) => env::set_var("PATH", v), + None => env::remove_var("PATH"), + } + match prev_home { + Some(v) => env::set_var("HOME", v), + None => env::remove_var("HOME"), + } + match prev_nvm_dir { + Some(v) => env::set_var("NVM_DIR", v), + None => env::remove_var("NVM_DIR"), + } + match prev_nodenv_root { + Some(v) => env::set_var("NODENV_ROOT", v), + None => env::remove_var("NODENV_ROOT"), + } + match prev_fnm_dir { + Some(v) => env::set_var("FNM_DIR", v), + None => env::remove_var("FNM_DIR"), + } + match prev_xdg_data_home { + Some(v) => env::set_var("XDG_DATA_HOME", v), + None => env::remove_var("XDG_DATA_HOME"), + } + } + + let platform_dirs: Vec = vec![ + PathBuf::from("/opt/homebrew/bin"), + PathBuf::from("/usr/local/bin"), + PathBuf::from("/usr/bin"), + PathBuf::from("/bin"), + PathBuf::from("/usr/sbin"), + PathBuf::from("/sbin"), + ]; + + // Find the last platform dir index and the first version-manager dir index. + let last_platform_idx = platform_dirs + .iter() + .filter_map(|d| paths.iter().position(|p| p == d)) + .max() + .expect("at least one platform dir should be present"); + + let version_manager_prefixes = [ + nvm_version_bin.parent().unwrap().parent().unwrap(), // .nvm/versions/node + nodenv_version_bin.parent().unwrap().parent().unwrap(), // .nodenv/versions + fnm_version_bin + .parent() + .unwrap() + .parent() + .unwrap() + .parent() + .unwrap() + .parent() + .unwrap(), // .local/share/fnm + ]; + + let first_version_mgr_idx = paths + .iter() + .position(|p| { + version_manager_prefixes + .iter() + .any(|prefix| p.starts_with(prefix)) + }) + .expect("at least one version-manager dir should be present"); + + assert!( + last_platform_idx < first_version_mgr_idx, + "Platform dirs (last at index {last_platform_idx}) must precede \ + version-manager dirs (first at index {first_version_mgr_idx}).\n\ + Full path list: {paths:#?}" + ); + } + + #[test] + #[serial] + fn find_executable_in_path_can_ignore_copilot_name_override() { + let temp_dir = tempdir().expect("should create temp dir"); + let bin_dir = temp_dir.path().join("bin"); + fs::create_dir_all(&bin_dir).expect("should create bin dir"); + + let path_executable_name = literal_copilot_executable_names() + .into_iter() + .next() + .expect("should provide a literal copilot executable name"); + #[cfg(target_os = "windows")] + let overridden_executable_name = "my-copilot.exe"; + + #[cfg(not(target_os = "windows"))] + let overridden_executable_name = "my-copilot"; + + let path_executable_path = bin_dir.join(&path_executable_name); + let overridden_executable_path = bin_dir.join(overridden_executable_name); + + fs::write(&path_executable_path, "#!/bin/sh\n").expect("should create literal fake binary"); + fs::write(&overridden_executable_path, "#!/bin/sh\n") + .expect("should create overridden fake binary"); + + let path_env = + env::join_paths([Path::new("/missing"), bin_dir.as_path()]).expect("should build PATH"); + + let previous_path = env::var_os("PATH"); + let previous_copilot_cli_name = env::var_os("COPILOT_CLI_NAME"); + // Safety: test-only, single-threaded via #[serial]. + unsafe { + env::set_var("PATH", &path_env); + env::set_var("COPILOT_CLI_NAME", "my-copilot"); + } + + let resolved_path = copilot_binary_on_path(); + + // Safety: test-only, single-threaded via #[serial]. + unsafe { + if let Some(previous_path) = previous_path { + env::set_var("PATH", previous_path); + } else { + env::remove_var("PATH"); + } + + if let Some(previous_copilot_cli_name) = previous_copilot_cli_name { + env::set_var("COPILOT_CLI_NAME", previous_copilot_cli_name); + } else { + env::remove_var("COPILOT_CLI_NAME"); + } + } + + assert_eq!( + resolved_path.expect("should find the literal copilot binary on PATH"), + path_executable_path + ); + } +} diff --git a/rust/src/router.rs b/rust/src/router.rs new file mode 100644 index 000000000..e14630e03 --- /dev/null +++ b/rust/src/router.rs @@ -0,0 +1,178 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use parking_lot::Mutex; +use tokio::sync::{broadcast, mpsc}; +use tracing::warn; + +use crate::jsonrpc::{JsonRpcNotification, JsonRpcRequest}; +use crate::types::{SessionEventNotification, SessionId}; + +/// Per-session channels created by the router during session registration. +pub(crate) struct SessionChannels { + /// Filtered `session.event` notifications for this session. + pub(crate) notifications: mpsc::UnboundedReceiver, + /// Filtered JSON-RPC requests (tool.call, userInput.request, etc.) for this session. + pub(crate) requests: mpsc::UnboundedReceiver, +} + +struct SessionSenders { + notifications: mpsc::UnboundedSender, + requests: mpsc::UnboundedSender, +} + +/// Routes notifications and requests by sessionId to per-session channels. +/// +/// Internal to the SDK โ€” consumers interact via `Client::register_session()`. +pub(crate) struct SessionRouter { + sessions: Arc>>, + started: Mutex, +} + +impl SessionRouter { + pub(crate) fn new() -> Self { + Self { + sessions: Arc::new(Mutex::new(HashMap::new())), + started: Mutex::new(false), + } + } + + /// Register a session to receive filtered events and requests. + pub(crate) fn register(&self, session_id: &SessionId) -> SessionChannels { + let (notif_tx, notif_rx) = mpsc::unbounded_channel(); + let (req_tx, req_rx) = mpsc::unbounded_channel(); + self.sessions.lock().insert( + session_id.clone(), + SessionSenders { + notifications: notif_tx, + requests: req_tx, + }, + ); + SessionChannels { + notifications: notif_rx, + requests: req_rx, + } + } + + /// Unregister a session, dropping its channels. + pub(crate) fn unregister(&self, session_id: &SessionId) { + self.sessions.lock().remove(session_id.as_str()); + } + + /// Snapshot every currently-registered session ID. + /// + /// Used by [`Client::stop`](crate::Client::stop) to iterate active + /// sessions for cooperative shutdown without holding the router lock + /// across `.await`. + pub(crate) fn session_ids(&self) -> Vec { + self.sessions.lock().keys().cloned().collect() + } + + /// Drop all registered session channels. + /// + /// Used by [`Client::force_stop`](crate::Client::force_stop) to release + /// per-session state without waiting for graceful unregistration. + pub(crate) fn clear(&self) { + self.sessions.lock().clear(); + } + + /// Start the router tasks if not already running. + /// + /// Takes the notification broadcast and request channel from the Client. + /// If `request_rx` is `None` (already taken by `take_request_rx()`), + /// only notification routing is available. + pub(crate) fn ensure_started( + &self, + notification_tx: &broadcast::Sender, + request_rx: &Mutex>>, + ) { + let mut started = self.started.lock(); + if *started { + return; + } + *started = true; + + // Notification routing task + let sessions = self.sessions.clone(); + let mut notif_rx = notification_tx.subscribe(); + tokio::spawn(async move { + loop { + match notif_rx.recv().await { + Ok(notification) => { + if notification.method != "session.event" { + continue; + } + let Some(ref params) = notification.params else { + continue; + }; + let Some(session_id) = params.get("sessionId").and_then(|v| v.as_str()) + else { + continue; + }; + + let sender = { + let guard = sessions.lock(); + guard.get(session_id).map(|s| s.notifications.clone()) + }; + if let Some(sender) = sender { + match serde_json::from_value::(params.clone()) + { + Ok(event_notification) => { + let _ = sender.send(event_notification); + } + Err(e) => { + warn!( + error = %e, + session_id = session_id, + "failed to deserialize session event notification" + ); + } + } + } + // Unknown session IDs are silently dropped โ€” the session + // may have been unregistered between dispatch and delivery. + } + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(missed = n, "notification router lagged"); + } + Err(broadcast::error::RecvError::Closed) => break, + } + } + }); + + // Request routing task (if request_rx is available) + if let Some(mut rx) = request_rx.lock().take() { + let sessions = self.sessions.clone(); + tokio::spawn(async move { + while let Some(request) = rx.recv().await { + let session_id = request + .params + .as_ref() + .and_then(|p| p.get("sessionId")) + .and_then(|v| v.as_str()); + + if let Some(sid) = session_id { + let sender = { + let guard = sessions.lock(); + guard.get(sid).map(|s| s.requests.clone()) + }; + if let Some(sender) = sender { + let _ = sender.send(request); + } else { + warn!( + session_id = sid, + method = %request.method, + "request for unregistered session" + ); + } + } else { + warn!( + method = %request.method, + "request missing sessionId" + ); + } + } + }); + } + } +} diff --git a/rust/src/sdk_protocol_version.rs b/rust/src/sdk_protocol_version.rs new file mode 100644 index 000000000..21089f99e --- /dev/null +++ b/rust/src/sdk_protocol_version.rs @@ -0,0 +1,13 @@ +// Code generated by update-protocol-version.ts. DO NOT EDIT. + +//! The SDK protocol version. Must match the version expected by the +//! copilot-agent-runtime server. + +/// The SDK protocol version. +pub const SDK_PROTOCOL_VERSION: u32 = 3; + +/// Returns the SDK protocol version. +#[must_use] +pub const fn get_sdk_protocol_version() -> u32 { + SDK_PROTOCOL_VERSION +} diff --git a/rust/src/session.rs b/rust/src/session.rs new file mode 100644 index 000000000..e20f24114 --- /dev/null +++ b/rust/src/session.rs @@ -0,0 +1,1991 @@ +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use parking_lot::Mutex as ParkingLotMutex; +use serde_json::Value; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; +use tracing::{Instrument, warn}; + +use crate::generated::api_types::{ + LogRequest, ModeSetRequest, ModelSwitchToRequest, NameSetRequest, PermissionDecision, + PermissionDecisionApproveOnce, PermissionDecisionApproveOnceKind, PermissionDecisionReject, + PermissionDecisionRejectKind, PlanUpdateRequest, SessionMode, WorkspacesCreateFileRequest, + WorkspacesReadFileRequest, +}; +use crate::generated::session_events::{ + CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, SessionErrorData, + SessionEventType, +}; +use crate::handler::{ + AutoModeSwitchResponse, ExitPlanModeResult, HandlerEvent, HandlerResponse, PermissionResult, + SessionHandler, UserInputResponse, +}; +use crate::hooks::SessionHooks; +use crate::session_fs::SessionFsProvider; +use crate::trace_context::inject_trace_context; +use crate::transforms::SystemMessageTransform; +use crate::types::{ + CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest, + ElicitationResult, ExitPlanModeData, GetMessagesResponse, InputOptions, MessageOptions, + PermissionRequestData, RequestId, ResumeSessionConfig, SectionOverride, SessionCapabilities, + SessionConfig, SessionEvent, SessionId, SessionTelemetryEvent, SetModelOptions, + SystemMessageConfig, ToolInvocation, ToolResult, ToolResultResponse, TraceContext, + ensure_attachment_display_names, +}; +use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes}; + +/// Shared state between a [`Session`] and its event loop, used by [`Session::send_and_wait`]. +struct IdleWaiter { + tx: oneshot::Sender, Error>>, + last_assistant_message: Option, +} + +/// RAII guard that clears the [`Session::idle_waiter`] slot on drop. Used +/// by [`Session::send_and_wait`] to ensure the slot doesn't leak if the +/// caller's future is cancelled (outer `tokio::time::timeout` / `select!` +/// / dropped JoinHandle). Synchronous clear via `parking_lot::Mutex` โ€” +/// no async drop needed. +/// +/// Without this, an outer cancellation between "install waiter" and +/// "drain channel" would leave the slot occupied, causing all subsequent +/// `send` and `send_and_wait` calls on the session to return +/// [`SendWhileWaiting`](SessionError::SendWhileWaiting). Closes RFD-400 +/// review finding #2. +struct WaiterGuard { + slot: Arc>>, +} + +impl Drop for WaiterGuard { + fn drop(&mut self) { + self.slot.lock().take(); + } +} + +/// A session on a GitHub Copilot CLI server. +/// +/// Created via [`Client::create_session`] or [`Client::resume_session`]. +/// Owns an internal event loop that dispatches events to the [`SessionHandler`]. +/// +/// Protocol methods (`send`, `get_messages`, `abort`, etc.) automatically +/// inject the session ID into RPC params. +/// +/// Call [`destroy`](Self::destroy) for graceful cleanup (RPC + local). If dropped +/// without calling `destroy`, the `Drop` impl aborts the event loop and +/// unregisters from the router as a best-effort safety net. +pub struct Session { + id: SessionId, + cwd: PathBuf, + workspace_path: Option, + remote_url: Option, + client: Client, + /// Handle to the spawned event-loop task. Sync `parking_lot::Mutex` + /// because the lock is never held across an `.await` and the `Drop` + /// impl needs to take the handle synchronously without `try_lock` + /// fallibility. + event_loop: ParkingLotMutex>>, + /// Cooperative shutdown signal for the event loop. The loop selects + /// on [`shutdown.cancelled()`](CancellationToken::cancelled) alongside + /// its inbound channels; [`Session::stop_event_loop`] and [`Drop`] + /// both call [`cancel()`](CancellationToken::cancel) to ask the loop + /// to exit between iterations rather than aborting the task (which + /// can land at any await point and leave the session mid-protocol). + /// See RFD-400 review finding #3. + /// + /// `CancellationToken` is the canonical signalling primitive in + /// `tokio_util`; it is what `tonic` uses for the equivalent task- + /// coordination case. Advanced consumers can obtain a child token + /// via [`Session::cancellation_token`] to bind their own work to + /// the session lifetime. + shutdown: CancellationToken, + /// Only populated while a `send_and_wait` call is in flight. + /// + /// Sync `parking_lot::Mutex` because the lock is never held across an + /// `.await`, and synchronous access lets the `WaiterGuard` RAII helper + /// in `send_and_wait` clear the slot from a `Drop` impl on caller-side + /// cancellation. See RFD-400 review (cancel-safety hardening). + idle_waiter: Arc>>, + /// Capabilities negotiated with the CLI, updated on `capabilities.changed` events. + capabilities: Arc>, + /// Broadcast channel for runtime event subscribers โ€” see [`Session::subscribe`]. + event_tx: tokio::sync::broadcast::Sender, +} + +impl Session { + /// Session ID assigned by the CLI. + pub fn id(&self) -> &SessionId { + &self.id + } + + /// Working directory of the CLI process. + pub fn cwd(&self) -> &PathBuf { + &self.cwd + } + + /// Workspace directory for the session (if using infinite sessions). + pub fn workspace_path(&self) -> Option<&Path> { + self.workspace_path.as_deref() + } + + /// Remote session URL, if the session is running remotely. + pub fn remote_url(&self) -> Option<&str> { + self.remote_url.as_deref() + } + + /// Session capabilities negotiated with the CLI. + /// + /// Capabilities are set during session creation and updated at runtime + /// via `capabilities.changed` events. + pub fn capabilities(&self) -> SessionCapabilities { + self.capabilities.read().clone() + } + + /// Returns a [`CancellationToken`] that fires when this session shuts + /// down (via [`Session::stop_event_loop`], [`Session::destroy`], or + /// [`Drop`]). + /// + /// Use this to bind an external task's lifetime to the session โ€” when + /// the session shuts down, awaiting [`cancelled()`](CancellationToken::cancelled) + /// resolves so cooperative consumers can stop cleanly. + /// + /// The returned handle is a *child* token: calling + /// [`cancel()`](CancellationToken::cancel) on it cancels only the + /// caller's child, not the session itself. To cancel the session, call + /// [`Session::stop_event_loop`]. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(session: github_copilot_sdk::session::Session) { + /// let token = session.cancellation_token(); + /// tokio::select! { + /// _ = token.cancelled() => println!("session shut down"), + /// _ = tokio::time::sleep(std::time::Duration::from_secs(60)) => { + /// println!("60s elapsed, session still alive"); + /// } + /// } + /// # } + /// ``` + pub fn cancellation_token(&self) -> CancellationToken { + self.shutdown.child_token() + } + + /// Subscribe to events for this session. + /// + /// Returns an [`EventSubscription`](crate::subscription::EventSubscription) + /// that yields every [`SessionEvent`] dispatched on this session's + /// event loop. Drop the value to unsubscribe; there is no separate + /// cancel handle. + /// + /// **Observe-only.** Subscribers receive a clone of every + /// [`SessionEvent`] but cannot influence permission decisions, tool + /// results, or anything else that requires returning a + /// [`HandlerResponse`]. Those remain + /// the responsibility of the [`SessionHandler`] passed via + /// [`SessionConfig::handler`](crate::types::SessionConfig::handler). + /// + /// The returned handle implements both an inherent + /// [`recv`](crate::subscription::EventSubscription::recv) method and + /// [`Stream`](tokio_stream::Stream), so callers can use a `while let` + /// loop or any combinator from `tokio_stream::StreamExt` / + /// `futures::StreamExt`. + /// + /// Each subscriber maintains its own queue. If a consumer cannot keep + /// up, the oldest events are dropped and `recv` returns + /// [`RecvError::Lagged`](crate::subscription::RecvError::Lagged) + /// reporting the count of skipped events. Slow consumers do not block + /// the session's event loop. + /// + /// # Example + /// + /// ```no_run + /// # async fn example(session: github_copilot_sdk::session::Session) { + /// let mut events = session.subscribe(); + /// tokio::spawn(async move { + /// while let Ok(event) = events.recv().await { + /// println!("[{}] event {}", event.id, event.event_type); + /// } + /// }); + /// # } + /// ``` + pub fn subscribe(&self) -> crate::subscription::EventSubscription { + crate::subscription::EventSubscription::new(self.event_tx.subscribe()) + } + + /// The underlying Client (for advanced use cases). + pub fn client(&self) -> &Client { + &self.client + } + + /// Typed RPC namespace for this session. + /// + /// Every protocol method lives here under its schema-aligned path โ€” + /// e.g. `session.rpc().workspaces().list_files()`. Wire method names + /// and request/response types are generated from the protocol schema, + /// so the typed namespace can't drift from the wire contract. + /// + /// The hand-authored helpers on [`Session`] delegate to this namespace + /// and remain the recommended entry point for everyday use; reach for + /// `rpc()` when you want a method without a hand-written wrapper. + pub fn rpc(&self) -> crate::generated::rpc::SessionRpc<'_> { + crate::generated::rpc::SessionRpc { session: self } + } + + /// Stop the internal event loop. Called automatically on [`destroy`](Self::destroy). + /// + /// Cooperative: signals shutdown via the session's [`CancellationToken`] + /// and awaits the loop's natural exit rather than aborting the task. + /// Any in-flight handler (permission callback, tool call, elicitation + /// response) completes before the loop exits, so the CLI never sees a + /// half-handled request. See RFD-400 review finding #3. + pub async fn stop_event_loop(&self) { + self.shutdown.cancel(); + let handle = self.event_loop.lock().take(); + if let Some(handle) = handle { + let _ = handle.await; + } + // Fail any pending send_and_wait so it returns immediately. + if let Some(waiter) = self.idle_waiter.lock().take() { + let _ = waiter + .tx + .send(Err(Error::Session(SessionError::EventLoopClosed))); + } + } + + /// Send a user message to the agent. + /// + /// Accepts anything convertible to [`MessageOptions`] โ€” pass a `&str` for the + /// trivial case, or build a `MessageOptions` for mode/attachments. The + /// `wait_timeout` field on `MessageOptions` is ignored here (use + /// [`send_and_wait`](Self::send_and_wait) if you need to wait). + /// + /// Returns the assigned message ID, which can be used to correlate the + /// send with later [`SessionEvent`]s emitted in + /// response (assistant messages, tool requests, etc.). + /// + /// Returns an error if a [`send_and_wait`](Self::send_and_wait) call is + /// currently in flight, since the plain send would race with the waiter. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** The underlying `session.send` RPC is dispatched + /// through the writer-actor (see [`Client::call`](crate::Client::call)), + /// so dropping this future after the actor has committed to writing + /// will not produce a partial frame on the wire. If the caller's + /// future is dropped between "frame enqueued" and "response received", + /// the message has already landed on the wire โ€” the agent will process + /// it and emit events normally; the caller just won't see the returned + /// message ID. + pub async fn send(&self, opts: impl Into) -> Result { + if self.idle_waiter.lock().is_some() { + return Err(Error::Session(SessionError::SendWhileWaiting)); + } + self.send_inner(opts.into()).await + } + + async fn send_inner(&self, opts: MessageOptions) -> Result { + let mut params = serde_json::json!({ + "sessionId": self.id, + "prompt": opts.prompt, + }); + if let Some(m) = opts.mode { + params["mode"] = serde_json::to_value(m)?; + } + if let Some(mut a) = opts.attachments { + ensure_attachment_display_names(&mut a); + params["attachments"] = serde_json::to_value(a)?; + } + if let Some(headers) = opts.request_headers + && !headers.is_empty() + { + params["requestHeaders"] = serde_json::to_value(headers)?; + } + let trace_ctx = if opts.traceparent.is_some() || opts.tracestate.is_some() { + TraceContext { + traceparent: opts.traceparent, + tracestate: opts.tracestate, + } + } else { + self.client.resolve_trace_context().await + }; + inject_trace_context(&mut params, &trace_ctx); + let result = self.client.call("session.send", Some(params)).await?; + let message_id = result + .get("messageId") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .unwrap_or_default(); + Ok(message_id) + } + + /// Enable or disable session-wide auto-approval for tool permission requests. + pub async fn set_approve_all_permissions(&self, enabled: bool) -> Result<(), Error> { + self.rpc() + .permissions() + .set_approve_all( + crate::generated::api_types::PermissionsSetApproveAllRequest { enabled }, + ) + .await?; + Ok(()) + } + + /// Send a user message and wait for the agent to finish processing. + /// + /// Accepts anything convertible to [`MessageOptions`] โ€” pass a `&str` for the + /// trivial case, or build a `MessageOptions` for mode/attachments/timeout. + /// Blocks until `session.idle` (success) or `session.error` (failure), + /// returning the last `assistant.message` event captured during streaming. + /// Times out after `MessageOptions::wait_timeout` (default 60 seconds). + /// + /// Only one `send_and_wait` call may be active per session at a time. + /// Calling [`send`](Self::send) while a `send_and_wait` + /// is in flight will also return an error. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** A `WaiterGuard` clears the in-flight slot on every + /// exit path (success, internal failure, internal timeout, *and* + /// external cancellation via `tokio::time::timeout` / `select!` / + /// dropped JoinHandle). Subsequent `send` and `send_and_wait` calls on + /// this session will succeed normally โ€” the slot is never leaked. + pub async fn send_and_wait( + &self, + opts: impl Into, + ) -> Result, Error> { + let opts = opts.into(); + let timeout_duration = opts.wait_timeout.unwrap_or(Duration::from_secs(60)); + let (tx, rx) = oneshot::channel(); + + { + let mut guard = self.idle_waiter.lock(); + if guard.is_some() { + return Err(Error::Session(SessionError::SendWhileWaiting)); + } + *guard = Some(IdleWaiter { + tx, + last_assistant_message: None, + }); + } + + // RAII: clears the idle_waiter slot on every exit path, including + // external cancellation (caller's outer `select!` / `timeout` / + // dropped future). Without this, an outer cancellation would leak + // the slot and brick subsequent `send`/`send_and_wait` calls. + let _waiter_guard = WaiterGuard { + slot: self.idle_waiter.clone(), + }; + + let result = tokio::time::timeout(timeout_duration, async { + self.send_inner(opts).await?; + match rx.await { + Ok(result) => result, + Err(_) => Err(Error::Session(SessionError::EventLoopClosed)), + } + }) + .await; + + match result { + Ok(inner) => inner, + Err(_) => Err(Error::Session(SessionError::Timeout(timeout_duration))), + } + } + + /// Retrieve the session's message history. + pub async fn get_messages(&self) -> Result, Error> { + let result = self + .client + .call( + "session.getMessages", + Some(serde_json::json!({ "sessionId": self.id })), + ) + .await?; + let response: GetMessagesResponse = serde_json::from_value(result)?; + Ok(response.events) + } + + /// Abort the current agent turn. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** Single `session.abort` RPC; the underlying + /// [`Client::call`](crate::Client::call) is cancel-safe via the + /// writer-actor. + pub async fn abort(&self) -> Result<(), Error> { + self.client + .call( + "session.abort", + Some(serde_json::json!({ "sessionId": self.id })), + ) + .await?; + Ok(()) + } + + /// Switch to a different model. + /// + /// Pass `None` for `opts` if no extra configuration is needed. + pub async fn set_model(&self, model: &str, opts: Option) -> Result<(), Error> { + let opts = opts.unwrap_or_default(); + let request = ModelSwitchToRequest { + model_id: model.to_string(), + reasoning_effort: opts.reasoning_effort, + model_capabilities: opts.model_capabilities, + }; + self.rpc().model().switch_to(request).await?; + Ok(()) + } + + /// Get the current model. + pub async fn get_model(&self) -> Result, Error> { + Ok(self.rpc().model().get_current().await?.model_id) + } + + /// Set the session mode (e.g. "interactive", "plan", "autopilot"). + pub async fn set_mode(&self, mode: &str) -> Result { + let parsed: SessionMode = serde_json::from_value(Value::String(mode.to_string()))?; + self.rpc() + .mode() + .set(ModeSetRequest { mode: parsed }) + .await?; + Ok(mode.to_string()) + } + + /// Get the current session mode. + pub async fn get_mode(&self) -> Result { + let mode = self.rpc().mode().get().await?; + Ok(serde_json::to_value(mode)? + .as_str() + .unwrap_or("interactive") + .to_string()) + } + + /// Get the current session name. + pub async fn get_name(&self) -> Result, Error> { + Ok(self.rpc().name().get().await?.name) + } + + /// Set the current session name. + pub async fn set_name(&self, name: &str) -> Result<(), Error> { + self.rpc() + .name() + .set(NameSetRequest { + name: name.to_string(), + }) + .await + } + + /// Disconnect this session from the CLI. + /// + /// Sends the `session.destroy` RPC, stops the event loop, and unregisters + /// the session from the client. **Session state on disk** (conversation + /// history, planning state, artifacts) is **preserved**, so the + /// conversation can be resumed later via [`Client::resume_session`] + /// using this session's ID. To permanently remove all on-disk session + /// data, use [`Client::delete_session`] instead. + /// + /// The caller should ensure the session is idle (e.g. [`send_and_wait`] + /// has returned) before disconnecting; in-flight tool or event handlers + /// may otherwise observe failures. + /// + /// [`Client::resume_session`]: crate::Client::resume_session + /// [`Client::delete_session`]: crate::Client::delete_session + /// [`send_and_wait`]: Self::send_and_wait + pub async fn disconnect(&self) -> Result<(), Error> { + self.client + .call( + "session.destroy", + Some(serde_json::json!({ "sessionId": self.id })), + ) + .await?; + self.stop_event_loop().await; + self.client.unregister_session(&self.id); + Ok(()) + } + + /// Alias for [`disconnect`](Self::disconnect). + /// + /// Named after the `session.destroy` wire RPC. Prefer `disconnect` in + /// new code โ€” the wire-level "destroy" is misleading because on-disk + /// state is preserved. + pub async fn destroy(&self) -> Result<(), Error> { + self.disconnect().await + } + + /// List files in the session workspace. + pub async fn list_workspace_files(&self) -> Result, Error> { + Ok(self.rpc().workspaces().list_files().await?.files) + } + + /// Read a file from the session workspace. + pub async fn read_workspace_file(&self, path: &Path) -> Result { + Ok(self + .rpc() + .workspaces() + .read_file(WorkspacesReadFileRequest { + path: path.to_string_lossy().into_owned(), + }) + .await? + .content) + } + + /// Create a file in the session workspace. + pub async fn create_workspace_file(&self, path: &Path, content: &str) -> Result<(), Error> { + self.rpc() + .workspaces() + .create_file(WorkspacesCreateFileRequest { + path: path.to_string_lossy().into_owned(), + content: content.to_string(), + }) + .await + } + + /// Read the session plan. + pub async fn read_plan(&self) -> Result<(bool, Option), Error> { + let r = self.rpc().plan().read().await?; + Ok((r.exists, r.content)) + } + + /// Update the session plan. + pub async fn update_plan(&self, content: &str) -> Result<(), Error> { + self.rpc() + .plan() + .update(PlanUpdateRequest { + content: content.to_string(), + }) + .await + } + + /// Delete the session plan. + pub async fn delete_plan(&self) -> Result<(), Error> { + self.rpc().plan().delete().await + } + + /// Write a log message to the session. + /// + /// Pass `None` for `opts` to use defaults (info level, persisted). + pub async fn log( + &self, + message: &str, + opts: Option, + ) -> Result<(), Error> { + let opts = opts.unwrap_or_default(); + let level = match opts.level { + Some(level) => Some(serde_json::from_value(serde_json::to_value(level)?)?), + None => None, + }; + let request = LogRequest { + message: message.to_string(), + level, + ephemeral: opts.ephemeral, + url: None, + }; + self.rpc().log(request).await?; + Ok(()) + } + + /// Send a telemetry event through the session's internal shared API. + pub async fn send_telemetry(&self, event: SessionTelemetryEvent) -> Result<(), Error> { + let mut params = serde_json::to_value(event)?; + let params_object = params + .as_object_mut() + .expect("SessionTelemetryEvent always serializes to an object"); + params_object.insert("sessionId".to_string(), serde_json::to_value(&self.id)?); + + self.client + .call("session.sendTelemetry", Some(params)) + .await?; + Ok(()) + } + + /// Returns the UI sub-API for elicitation, confirmation, selection, and + /// free-form input. + /// + /// All UI methods route through `session.ui.*` RPCs and require host + /// support โ€” check `session.capabilities().ui.elicitation` before use. + pub fn ui(&self) -> SessionUi<'_> { + SessionUi { session: self } + } + + /// Returns an error if the host doesn't support elicitation. + fn assert_elicitation(&self) -> Result<(), Error> { + if self + .capabilities + .read() + .ui + .as_ref() + .and_then(|u| u.elicitation) + != Some(true) + { + return Err(Error::Session(SessionError::ElicitationNotSupported)); + } + Ok(()) + } + + /// Start a fleet of sub-agents. + pub async fn start_fleet(&self, prompt: Option<&str>) -> Result { + Ok(self + .rpc() + .fleet() + .start(crate::generated::api_types::FleetStartRequest { + prompt: prompt.map(|s| s.to_string()), + }) + .await? + .started) + } + + /// Generic RPC forwarder โ€” auto-injects sessionId into params. + pub async fn call_rpc( + &self, + method: &str, + extra_params: Option, + ) -> Result { + let mut params = serde_json::json!({ "sessionId": self.id }); + let extra_obj = extra_params.as_ref().and_then(Value::as_object); + if let (Some(base), Some(extra_obj)) = (params.as_object_mut(), extra_obj) { + for (k, v) in extra_obj { + base.insert(k.clone(), v.clone()); + } + } + self.client.call(method, Some(params)).await + } +} + +impl Drop for Session { + fn drop(&mut self) { + // Cooperative shutdown: cancel the event loop's token to signal + // exit between iterations. The loop will see the cancellation on + // its next select poll and break cleanly without interrupting an + // in-flight handler. We do NOT abort the JoinHandle โ€” that would + // land at any await point in the loop body, potentially leaving + // the CLI with an unanswered request id. RFD-400 review finding + // #3. + // + // The handle itself is left in `event_loop` to be reaped by the + // tokio runtime when it next polls; we intentionally don't await + // it here because Drop is sync. + self.shutdown.cancel(); + self.client.unregister_session(&self.id); + } +} + +/// UI sub-API for a [`Session`] โ€” elicitation, confirmation, selection, +/// and free-form input. +/// +/// Acquired via [`Session::ui`]. Methods route to `session.ui.*` RPCs and +/// require host elicitation support โ€” check +/// `session.capabilities().ui.elicitation` before use. +pub struct SessionUi<'a> { + session: &'a Session, +} + +impl<'a> SessionUi<'a> { + /// Request user input via an interactive UI form (elicitation). + /// + /// Sends a JSON Schema describing form fields to the CLI host. The host + /// renders a form dialog and returns the user's response. + /// + /// Prefer the typed convenience methods [`confirm`](Self::confirm), + /// [`select`](Self::select), and [`input`](Self::input) for common cases. + pub async fn elicitation( + &self, + message: &str, + schema: Value, + ) -> Result { + self.session.assert_elicitation()?; + let result = self + .session + .client + .call( + "session.ui.elicitation", + Some(serde_json::json!({ + "sessionId": self.session.id, + "message": message, + "requestedSchema": schema, + })), + ) + .await?; + let elicitation: ElicitationResult = serde_json::from_value(result)?; + Ok(elicitation) + } + + /// Ask the user a yes/no confirmation question. + /// + /// Returns `true` if the user accepted and confirmed, `false` otherwise. + pub async fn confirm(&self, message: &str) -> Result { + self.session.assert_elicitation()?; + let schema = serde_json::json!({ + "type": "object", + "properties": { + "confirmed": { + "type": "boolean", + "default": true, + } + }, + "required": ["confirmed"] + }); + let result = self.elicitation(message, schema).await?; + Ok(result.action == "accept" + && result + .content + .and_then(|c| c.get("confirmed").and_then(|v| v.as_bool())) + == Some(true)) + } + + /// Ask the user to select from a list of options. + /// + /// Returns the selected option string on accept, or `None` on decline/cancel. + pub async fn select(&self, message: &str, options: &[&str]) -> Result, Error> { + self.session.assert_elicitation()?; + let schema = serde_json::json!({ + "type": "object", + "properties": { + "selection": { + "type": "string", + "enum": options, + } + }, + "required": ["selection"] + }); + let result = self.elicitation(message, schema).await?; + if result.action != "accept" { + return Ok(None); + } + let selection = result.content.and_then(|c| { + c.get("selection") + .and_then(|v| v.as_str()) + .map(String::from) + }); + Ok(selection) + } + + /// Ask the user for free-form text input. + /// + /// Returns the input string on accept, or `None` on decline/cancel. + /// Use [`InputOptions`] to set validation constraints and field metadata. + pub async fn input( + &self, + message: &str, + options: Option<&InputOptions<'_>>, + ) -> Result, Error> { + self.session.assert_elicitation()?; + let mut field = serde_json::json!({ "type": "string" }); + if let Some(opts) = options { + if let Some(title) = opts.title { + field["title"] = Value::String(title.to_string()); + } + if let Some(desc) = opts.description { + field["description"] = Value::String(desc.to_string()); + } + if let Some(min) = opts.min_length { + field["minLength"] = Value::Number(min.into()); + } + if let Some(max) = opts.max_length { + field["maxLength"] = Value::Number(max.into()); + } + if let Some(fmt) = &opts.format { + field["format"] = Value::String(fmt.as_str().to_string()); + } + if let Some(default) = opts.default { + field["default"] = Value::String(default.to_string()); + } + } + let schema = serde_json::json!({ + "type": "object", + "properties": { "value": field }, + "required": ["value"] + }); + let result = self.elicitation(message, schema).await?; + if result.action != "accept" { + return Ok(None); + } + let value = result + .content + .and_then(|c| c.get("value").and_then(|v| v.as_str()).map(String::from)); + Ok(value) + } +} + +impl Client { + /// Create a new session on the CLI. + /// + /// Sends `session.create`, registers the session on the router, + /// and spawns an internal event loop that dispatches to the handler. + /// + /// All callbacks (event handler, hooks, transform) are configured + /// via [`SessionConfig`] using [`with_handler`](SessionConfig::with_handler), + /// [`with_hooks`](SessionConfig::with_hooks), and + /// [`with_transform`](SessionConfig::with_transform). + /// + /// If [`hooks_handler`](SessionConfig::hooks_handler) is set, the + /// wire-level `hooks` flag is automatically enabled. + /// + /// If [`transform`](SessionConfig::transform) is set, the SDK injects + /// `action: "transform"` sections into the [`SystemMessageConfig`] wire + /// format and handles `systemMessage.transform` RPC callbacks during + /// the session. + /// + /// If [`handler`](SessionConfig::handler) is `None`, the session uses + /// [`DenyAllHandler`](crate::handler::DenyAllHandler) โ€” permission + /// requests are denied; other events are no-ops. + pub async fn create_session(&self, mut config: SessionConfig) -> Result { + let handler = config + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + let hooks = config.hooks_handler.take(); + let transforms = config.transform.take(); + let command_handlers = build_command_handler_map(config.commands.as_deref()); + let session_fs_provider = config.session_fs_provider.take(); + if self.inner.session_fs_configured && session_fs_provider.is_none() { + return Err(Error::Session(SessionError::SessionFsProviderRequired)); + } + + if hooks.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(ref transforms) = transforms { + inject_transform_sections(&mut config, transforms.as_ref()); + } + let mut params = serde_json::to_value(&config)?; + let trace_ctx = self.resolve_trace_context().await; + inject_trace_context(&mut params, &trace_ctx); + let result = self.call("session.create", Some(params)).await?; + let create_result: CreateSessionResult = serde_json::from_value(result)?; + + let session_id = create_result.session_id; + let capabilities = Arc::new(parking_lot::RwLock::new( + create_result.capabilities.unwrap_or_default(), + )); + let channels = self.register_session(&session_id); + + let idle_waiter = Arc::new(ParkingLotMutex::new(None)); + let shutdown = CancellationToken::new(); + let (event_tx, _) = tokio::sync::broadcast::channel(512); + let event_loop = spawn_event_loop( + session_id.clone(), + self.clone(), + handler, + hooks, + transforms, + command_handlers, + session_fs_provider, + channels, + idle_waiter.clone(), + capabilities.clone(), + event_tx.clone(), + shutdown.clone(), + ); + + Ok(Session { + id: session_id, + cwd: self.cwd().clone(), + workspace_path: create_result.workspace_path, + remote_url: create_result.remote_url, + client: self.clone(), + event_loop: ParkingLotMutex::new(Some(event_loop)), + shutdown, + idle_waiter, + capabilities, + event_tx, + }) + } + + /// Resume an existing session on the CLI. + /// + /// Sends `session.resume` and `session.skills.reload`, registers the + /// session on the router, and spawns the event loop. + /// + /// All callbacks (event handler, hooks, transform) are configured + /// via [`ResumeSessionConfig`] using its `with_*` builder methods. + /// + /// See [`Self::create_session`] for the defaults applied when callback + /// fields are unset. + pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result { + let handler = config + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + let hooks = config.hooks_handler.take(); + let transforms = config.transform.take(); + let command_handlers = build_command_handler_map(config.commands.as_deref()); + let session_fs_provider = config.session_fs_provider.take(); + if self.inner.session_fs_configured && session_fs_provider.is_none() { + return Err(Error::Session(SessionError::SessionFsProviderRequired)); + } + + if hooks.is_some() && config.hooks.is_none() { + config.hooks = Some(true); + } + if let Some(ref transforms) = transforms { + inject_transform_sections_resume(&mut config, transforms.as_ref()); + } + let session_id = config.session_id.clone(); + let mut params = serde_json::to_value(&config)?; + let trace_ctx = self.resolve_trace_context().await; + inject_trace_context(&mut params, &trace_ctx); + let result = self.call("session.resume", Some(params)).await?; + + // The CLI may reassign the session ID on resume. + let cli_session_id: SessionId = result + .get("sessionId") + .and_then(|v| v.as_str()) + .unwrap_or(&session_id) + .into(); + + let resume_capabilities: Option = result + .get("capabilities") + .and_then(|v| { + serde_json::from_value(v.clone()) + .map_err(|e| warn!(error = %e, "failed to deserialize capabilities from resume response")) + .ok() + }); + let remote_url = result + .get("remoteUrl") + .or_else(|| result.get("remote_url")) + .and_then(|value| value.as_str()) + .map(ToString::to_string); + + // Reload skills after resume (best-effort). + if let Err(e) = self + .call( + "session.skills.reload", + Some(serde_json::json!({ "sessionId": cli_session_id })), + ) + .await + { + warn!(error = %e, "failed to reload skills after resume"); + } + + let capabilities = Arc::new(parking_lot::RwLock::new( + resume_capabilities.unwrap_or_default(), + )); + let channels = self.register_session(&cli_session_id); + + let idle_waiter = Arc::new(ParkingLotMutex::new(None)); + let shutdown = CancellationToken::new(); + let (event_tx, _) = tokio::sync::broadcast::channel(512); + let event_loop = spawn_event_loop( + cli_session_id.clone(), + self.clone(), + handler, + hooks, + transforms, + command_handlers, + session_fs_provider, + channels, + idle_waiter.clone(), + capabilities.clone(), + event_tx.clone(), + shutdown.clone(), + ); + + Ok(Session { + id: cli_session_id, + cwd: self.cwd().clone(), + workspace_path: None, + remote_url, + client: self.clone(), + event_loop: ParkingLotMutex::new(Some(event_loop)), + shutdown, + idle_waiter, + capabilities, + event_tx, + }) + } +} + +type CommandHandlerMap = HashMap>; + +fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc { + let map = match commands { + Some(commands) => commands + .iter() + .filter(|cmd| !cmd.name.is_empty()) + .map(|cmd| (cmd.name.clone(), cmd.handler.clone())) + .collect(), + None => HashMap::new(), + }; + Arc::new(map) +} + +#[allow(clippy::too_many_arguments)] +fn spawn_event_loop( + session_id: SessionId, + client: Client, + handler: Arc, + hooks: Option>, + transforms: Option>, + command_handlers: Arc, + session_fs_provider: Option>, + channels: crate::router::SessionChannels, + idle_waiter: Arc>>, + capabilities: Arc>, + event_tx: tokio::sync::broadcast::Sender, + shutdown: CancellationToken, +) -> JoinHandle<()> { + let crate::router::SessionChannels { + mut notifications, + mut requests, + } = channels; + + let span = tracing::error_span!("session_event_loop", session_id = %session_id); + tokio::spawn( + async move { + loop { + // `mpsc::UnboundedReceiver::recv` and + // `CancellationToken::cancelled` are both cancel-safe per + // RFD 400. The selected branch's `await`'d handler is + // *not* mid-cancelled by the select โ€” once a branch fires + // it runs to completion within the loop's iteration. + // Spawned child tasks inside `handle_notification` + // (permission/tool/elicitation callbacks) intentionally + // outlive the parent loop and own their own cleanup; + // this is RFD 400's "spawn background tasks to perform + // cancel-unsafe operations" pattern and is correct as-is. + tokio::select! { + _ = shutdown.cancelled() => break, + Some(notification) = notifications.recv() => { + handle_notification( + &session_id, &client, &handler, &command_handlers, notification, &idle_waiter, &capabilities, &event_tx, + ).await; + } + Some(request) = requests.recv() => { + handle_request( + &session_id, &client, &handler, hooks.as_deref(), transforms.as_deref(), session_fs_provider.as_ref(), request, + ).await; + } + else => break, + } + } + // Channels closed or shutdown signaled โ€” fail any pending + // send_and_wait so the caller observes a clean error. + if let Some(waiter) = idle_waiter.lock().take() { + let _ = waiter + .tx + .send(Err(Error::Session(SessionError::EventLoopClosed))); + } + } + .instrument(span), + ) +} + +fn extract_request_id(data: &Value) -> Option { + data.get("requestId") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + .map(RequestId::new) +} + +fn pending_permission_result_kind(response: &HandlerResponse) -> &'static str { + match response { + HandlerResponse::Permission(PermissionResult::Approved) => "approve-once", + HandlerResponse::Permission(PermissionResult::Denied) => "reject", + HandlerResponse::Permission(PermissionResult::NoResult) => "no-result", + // Fallback to "user-not-available" for UserNotAvailable, Deferred (when + // forced through this path), Custom (handled separately upstream), and + // any non-permission HandlerResponse that gets here defensively. + _ => "user-not-available", + } +} + +fn permission_request_response(response: &HandlerResponse) -> PermissionDecision { + match response { + HandlerResponse::Permission(PermissionResult::Approved) => { + PermissionDecision::ApproveOnce(PermissionDecisionApproveOnce { + kind: PermissionDecisionApproveOnceKind::ApproveOnce, + }) + } + _ => PermissionDecision::Reject(PermissionDecisionReject { + kind: PermissionDecisionRejectKind::Reject, + feedback: None, + }), + } +} + +/// Map a handler response into the `result` payload for the notification +/// path (`session.permissions.handlePendingPermissionRequest`). +/// +/// Returns `None` when the SDK must not respond โ€” currently only the +/// [`PermissionResult::Deferred`] case, where the handler takes over +/// responsibility for the round-trip itself. +fn notification_permission_payload(response: &HandlerResponse) -> Option { + match response { + HandlerResponse::Permission(PermissionResult::Deferred) => None, + HandlerResponse::Permission(PermissionResult::Custom(value)) => Some(value.clone()), + _ => Some(serde_json::json!({ + "kind": pending_permission_result_kind(response), + })), + } +} + +/// Map a handler response into the JSON-RPC `result` payload for the +/// direct-RPC path (`permission.request`). +/// +/// Always returns a value. [`PermissionResult::Deferred`] is treated as +/// [`PermissionResult::Approved`] here because the JSON-RPC contract +/// requires a reply โ€” see the variant's doc comment. +fn direct_permission_payload(response: &HandlerResponse) -> Value { + match response { + HandlerResponse::Permission(PermissionResult::Custom(value)) => value.clone(), + HandlerResponse::Permission(PermissionResult::Deferred) => serde_json::to_value( + permission_request_response(&HandlerResponse::Permission(PermissionResult::Approved)), + ) + .expect("serializing direct permission response should succeed"), + HandlerResponse::Permission(PermissionResult::NoResult) + | HandlerResponse::Permission(PermissionResult::UserNotAvailable) => serde_json::json!({ + "kind": pending_permission_result_kind(response), + }), + _ => serde_json::to_value(permission_request_response(response)) + .expect("serializing direct permission response should succeed"), + } +} + +/// Process a notification from the CLI's broadcast channel. +#[allow(clippy::too_many_arguments)] +async fn handle_notification( + session_id: &SessionId, + client: &Client, + handler: &Arc, + command_handlers: &Arc, + notification: SessionEventNotification, + idle_waiter: &Arc>>, + capabilities: &Arc>, + event_tx: &tokio::sync::broadcast::Sender, +) { + let event = notification.event.clone(); + let event_type = event.parsed_type(); + + // Signal send_and_wait if active. The lock is only contended when + // a send_and_wait call is in flight (idle_waiter is Some). + match event_type { + SessionEventType::AssistantMessage + | SessionEventType::SessionIdle + | SessionEventType::SessionError => { + let mut guard = idle_waiter.lock(); + if let Some(waiter) = guard.as_mut() { + match event_type { + SessionEventType::AssistantMessage => { + waiter.last_assistant_message = Some(event.clone()); + } + SessionEventType::SessionIdle | SessionEventType::SessionError => { + if let Some(waiter) = guard.take() { + if event_type == SessionEventType::SessionIdle { + let _ = waiter.tx.send(Ok(waiter.last_assistant_message)); + } else { + let error_msg = event + .typed_data::() + .map(|d| d.message) + .or_else(|| { + event + .data + .get("message") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| "session error".to_string()); + let _ = waiter + .tx + .send(Err(Error::Session(SessionError::AgentError(error_msg)))); + } + } + } + _ => {} + } + } + } + _ => {} + } + + // Fan out the event to runtime subscribers (`Session::subscribe`). `send` + // only errors when there are no receivers, which is the normal case + // before any consumer subscribes. + let _ = event_tx.send(event.clone()); + + // Fire-and-forget dispatch for the general event. + handler + .on_event(HandlerEvent::SessionEvent { + session_id: session_id.clone(), + event, + }) + .await; + + // Update capabilities when the CLI reports changes. The CLI sends + // the full updated capabilities object โ€” replace wholesale so removals + // and new subfields are handled correctly. + if event_type == SessionEventType::CapabilitiesChanged { + match serde_json::from_value::(notification.event.data.clone()) { + Ok(changed) => *capabilities.write() = changed, + Err(e) => warn!(error = %e, "failed to deserialize capabilities.changed payload"), + } + } + + // Notification-based permission/tool/elicitation requests require a + // separate RPC callback. Spawn concurrently since the CLI doesn't block. + match event_type { + SessionEventType::PermissionRequested => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let client = client.clone(); + let handler = handler.clone(); + let sid = session_id.clone(); + let data: PermissionRequestData = + serde_json::from_value(notification.event.data.clone()).unwrap_or_else(|_| { + PermissionRequestData { + kind: None, + tool_call_id: None, + extra: notification.event.data.clone(), + } + }); + tokio::spawn(async move { + let response = handler + .on_event(HandlerEvent::PermissionRequest { + session_id: sid.clone(), + request_id: request_id.clone(), + data, + }) + .await; + let Some(result_value) = notification_permission_payload(&response) else { + // Handler returned Deferred โ€” it will call + // handlePendingPermissionRequest itself. + return; + }; + let _ = client + .call( + "session.permissions.handlePendingPermissionRequest", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result_value, + })), + ) + .await; + }); + } + SessionEventType::ExternalToolRequested => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let data: ExternalToolRequestedData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize external_tool.requested"); + let client = client.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + let _ = client + .call( + "session.tools.handlePendingToolCall", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "error": format!("Failed to deserialize tool request: {e}"), + })), + ) + .await; + }); + return; + } + }; + let client = client.clone(); + let handler = handler.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + if data.tool_call_id.is_empty() || data.tool_name.is_empty() { + let error_msg = if data.tool_call_id.is_empty() { + "Missing toolCallId" + } else { + "Missing toolName" + }; + let _ = client + .call( + "session.tools.handlePendingToolCall", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "error": error_msg, + })), + ) + .await; + return; + } + let invocation = ToolInvocation { + session_id: sid.clone(), + tool_call_id: data.tool_call_id, + tool_name: data.tool_name, + arguments: data + .arguments + .unwrap_or(Value::Object(serde_json::Map::new())), + traceparent: data.traceparent, + tracestate: data.tracestate, + }; + let response = handler + .on_event(HandlerEvent::ExternalTool { invocation }) + .await; + let tool_result = match response { + HandlerResponse::ToolResult(r) => r, + _ => ToolResult::Text("Unexpected handler response".to_string()), + }; + let result_value = serde_json::to_value(&tool_result).unwrap_or(Value::Null); + let _ = client + .call( + "session.tools.handlePendingToolCall", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result_value, + })), + ) + .await; + }); + } + SessionEventType::UserInputRequested => { + // Notification-only signal for observers (UI, telemetry). + // The CLI follows up with a `userInput.request` JSON-RPC call + // that drives `HandlerEvent::UserInput` dispatch โ€” handling + // the notification here too would double-fire the handler + // and produce duplicate prompts on the consumer side. See + // github/github-app#4249. + } + SessionEventType::ElicitationRequested => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let elicitation_data: ElicitationRequestedData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize elicitation request"); + return; + } + }; + let request = ElicitationRequest { + message: elicitation_data.message, + requested_schema: elicitation_data + .requested_schema + .map(|s| serde_json::to_value(s).unwrap_or(Value::Null)), + mode: elicitation_data.mode.map(|m| match m { + crate::generated::session_events::ElicitationRequestedMode::Form => { + crate::types::ElicitationMode::Form + } + crate::generated::session_events::ElicitationRequestedMode::Url => { + crate::types::ElicitationMode::Url + } + _ => crate::types::ElicitationMode::Unknown, + }), + elicitation_source: elicitation_data.elicitation_source, + url: elicitation_data.url, + }; + let client = client.clone(); + let handler = handler.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + let cancel = ElicitationResult { + action: "cancel".to_string(), + content: None, + }; + // Dispatch to handler inside a nested task so panics are + // caught as JoinErrors (matches Node SDK's try/catch pattern). + let handler_task = tokio::spawn({ + let sid = sid.clone(); + let request_id = request_id.clone(); + async move { + handler + .on_event(HandlerEvent::ElicitationRequest { + session_id: sid, + request_id, + request, + }) + .await + } + }); + let result = match handler_task.await { + Ok(HandlerResponse::Elicitation(r)) => r, + _ => cancel.clone(), + }; + if let Err(e) = client + .call( + "session.ui.handlePendingElicitation", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result, + })), + ) + .await + { + // RPC failed โ€” attempt cancel as last resort + warn!(error = %e, "handlePendingElicitation failed, sending cancel"); + let _ = client + .call( + "session.ui.handlePendingElicitation", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": cancel, + })), + ) + .await; + } + }); + } + SessionEventType::CommandExecute => { + let data: CommandExecuteData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize command.execute"); + return; + } + }; + let client = client.clone(); + let command_handlers = command_handlers.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + let request_id = data.request_id; + let ack_error = match command_handlers.get(&data.command_name).cloned() { + None => Some(format!("Unknown command: {}", data.command_name)), + Some(handler) => { + let ctx = CommandContext { + session_id: sid.clone(), + command: data.command, + command_name: data.command_name, + args: data.args, + }; + match handler.on_command(ctx).await { + Ok(()) => None, + Err(e) => Some(e.to_string()), + } + } + }; + let mut params = serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + }); + if let Some(error_msg) = ack_error { + params["error"] = serde_json::Value::String(error_msg); + } + let _ = client + .call("session.commands.handlePendingCommand", Some(params)) + .await; + }); + } + _ => {} + } +} + +/// Process a JSON-RPC request from the CLI. +async fn handle_request( + session_id: &SessionId, + client: &Client, + handler: &Arc, + hooks: Option<&dyn SessionHooks>, + transforms: Option<&dyn SystemMessageTransform>, + session_fs_provider: Option<&Arc>, + request: crate::JsonRpcRequest, +) { + let sid = session_id.clone(); + + if request.method.starts_with("sessionFs.") { + crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await; + return; + } + + match request.method.as_str() { + "hooks.invoke" => { + let params = request.params.as_ref(); + let hook_type = params + .and_then(|p| p.get("hookType")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + let input = params + .and_then(|p| p.get("input")) + .cloned() + .unwrap_or(Value::Object(Default::default())); + + let rpc_result = if let Some(hooks) = hooks { + match crate::hooks::dispatch_hook(hooks, &sid, hook_type, input).await { + Ok(output) => output, + Err(e) => { + warn!(error = %e, hook_type = hook_type, "hook dispatch failed"); + serde_json::json!({ "output": {} }) + } + } + } else { + serde_json::json!({ "output": {} }) + }; + + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "tool.call" => { + let invocation: ToolInvocation = match request + .params + .as_ref() + .and_then(|p| serde_json::from_value::(p.clone()).ok()) + { + Some(inv) => inv, + None => { + let _ = send_error_response( + client, + request.id, + error_codes::INVALID_PARAMS, + "invalid tool.call params", + ) + .await; + return; + } + }; + let response = handler + .on_event(HandlerEvent::ExternalTool { invocation }) + .await; + let tool_result = match response { + HandlerResponse::ToolResult(r) => r, + _ => ToolResult::Text("Unexpected handler response".to_string()), + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(serde_json::json!(ToolResultResponse { + result: tool_result + })), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "userInput.request" => { + let params = request.params.as_ref(); + let Some(question) = params + .and_then(|p| p.get("question")) + .and_then(|v| v.as_str()) + else { + warn!("userInput.request missing 'question' field"); + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(crate::JsonRpcError { + code: error_codes::INVALID_PARAMS, + message: "missing required field: question".to_string(), + data: None, + }), + }; + let _ = client.send_response(&rpc_response).await; + return; + }; + let question = question.to_string(); + let choices = params + .and_then(|p| p.get("choices")) + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }); + let allow_freeform = params + .and_then(|p| p.get("allowFreeform")) + .and_then(|v| v.as_bool()); + + let response = handler + .on_event(HandlerEvent::UserInput { + session_id: sid, + question, + choices, + allow_freeform, + }) + .await; + + let rpc_result = match response { + HandlerResponse::UserInput(Some(UserInputResponse { + answer, + was_freeform, + })) => serde_json::json!({ + "answer": answer, + "wasFreeform": was_freeform, + }), + _ => serde_json::json!({ "noResponse": true }), + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "exitPlanMode.request" => { + let params = request + .params + .as_ref() + .cloned() + .unwrap_or(Value::Object(serde_json::Map::new())); + let data: ExitPlanModeData = match serde_json::from_value(params) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize exitPlanMode.request params, using defaults"); + ExitPlanModeData::default() + } + }; + + let response = handler + .on_event(HandlerEvent::ExitPlanMode { + session_id: sid, + data, + }) + .await; + + let rpc_result = match response { + HandlerResponse::ExitPlanMode(ExitPlanModeResult { + approved, + selected_action, + feedback, + }) => serde_json::json!({ + "approved": approved, + "selectedAction": selected_action, + "feedback": feedback, + }), + _ => serde_json::json!({ "approved": true }), + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "autoModeSwitch.request" => { + let error_code = request + .params + .as_ref() + .and_then(|p| p.get("errorCode")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let retry_after_seconds = request + .params + .as_ref() + .and_then(|p| p.get("retryAfterSeconds")) + .and_then(|v| v.as_u64()); + + let response = handler + .on_event(HandlerEvent::AutoModeSwitch { + session_id: sid, + error_code, + retry_after_seconds, + }) + .await; + + let answer = match response { + HandlerResponse::AutoModeSwitch(answer) => answer, + _ => AutoModeSwitchResponse::No, + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(serde_json::json!({ "response": answer })), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "permission.request" => { + let Some(request_id) = request + .params + .as_ref() + .and_then(|p| p.get("requestId")) + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + else { + warn!("permission.request missing 'requestId' field"); + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(crate::JsonRpcError { + code: error_codes::INVALID_PARAMS, + message: "missing required field: requestId".to_string(), + data: None, + }), + }; + let _ = client.send_response(&rpc_response).await; + return; + }; + let request_id = RequestId::new(request_id); + let raw_params = request + .params + .as_ref() + .cloned() + .unwrap_or(Value::Object(serde_json::Map::new())); + let data: PermissionRequestData = + serde_json::from_value(raw_params.clone()).unwrap_or(PermissionRequestData { + kind: None, + tool_call_id: None, + extra: raw_params, + }); + + let response = handler + .on_event(HandlerEvent::PermissionRequest { + session_id: sid, + request_id, + data, + }) + .await; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(direct_permission_payload(&response)), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "systemMessage.transform" => { + let params = request.params.as_ref(); + let sections: HashMap = + match params.and_then(|p| p.get("sections")) { + Some(v) => match serde_json::from_value(v.clone()) { + Ok(s) => s, + Err(e) => { + let _ = send_error_response( + client, + request.id, + error_codes::INVALID_PARAMS, + &format!("invalid sections: {e}"), + ) + .await; + return; + } + }, + None => { + let _ = send_error_response( + client, + request.id, + error_codes::INVALID_PARAMS, + "missing sections parameter", + ) + .await; + return; + } + }; + + let rpc_result = if let Some(transforms) = transforms { + let response = + crate::transforms::dispatch_transform(transforms, &sid, sections).await; + match serde_json::to_value(response) { + Ok(v) => v, + Err(e) => { + warn!(error = %e, "failed to serialize transform response"); + serde_json::json!({ "sections": {} }) + } + } + } else { + // No transforms registered โ€” pass through all sections unchanged. + let passthrough: HashMap = sections; + serde_json::json!({ "sections": passthrough }) + }; + + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + method => { + warn!( + method = method, + "unhandled request method in session event loop" + ); + let _ = send_error_response( + client, + request.id, + error_codes::METHOD_NOT_FOUND, + &format!("unknown method: {method}"), + ) + .await; + } + } +} + +async fn send_error_response( + client: &Client, + id: u64, + code: i32, + message: &str, +) -> Result<(), Error> { + let response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(crate::JsonRpcError { + code, + message: message.to_string(), + data: None, + }), + }; + client.send_response(&response).await +} + +/// Inject `action: "transform"` sections into a `SystemMessageConfig`, +/// forcing `mode: "customize"` (required by the CLI for transforms to fire). +/// Preserves any existing caller-provided section overrides. +fn apply_transform_sections( + sys_msg: &mut SystemMessageConfig, + transforms: &dyn SystemMessageTransform, +) { + sys_msg.mode = Some("customize".to_string()); + let sections = sys_msg.sections.get_or_insert_with(HashMap::new); + for id in transforms.section_ids() { + sections.entry(id).or_insert_with(|| SectionOverride { + action: Some("transform".to_string()), + content: None, + }); + } +} + +fn inject_transform_sections(config: &mut SessionConfig, transforms: &dyn SystemMessageTransform) { + let sys_msg = config.system_message.get_or_insert_with(Default::default); + apply_transform_sections(sys_msg, transforms); +} + +fn inject_transform_sections_resume( + config: &mut ResumeSessionConfig, + transforms: &dyn SystemMessageTransform, +) { + let sys_msg = config.system_message.get_or_insert_with(Default::default); + apply_transform_sections(sys_msg, transforms); +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::{ + direct_permission_payload, notification_permission_payload, pending_permission_result_kind, + permission_request_response, + }; + use crate::handler::{HandlerResponse, PermissionResult}; + + #[test] + fn pending_permission_requests_use_decision_kinds() { + assert_eq!( + pending_permission_result_kind(&HandlerResponse::Permission( + PermissionResult::Approved, + )), + "approve-once" + ); + assert_eq!( + pending_permission_result_kind(&HandlerResponse::Permission(PermissionResult::Denied)), + "reject" + ); + assert_eq!( + pending_permission_result_kind(&HandlerResponse::Ok), + "user-not-available" + ); + } + + #[test] + fn direct_permission_requests_use_decision_response_kinds() { + assert_eq!( + serde_json::to_value(permission_request_response(&HandlerResponse::Permission( + PermissionResult::Approved + ),)) + .expect("serializing approved permission response should succeed"), + json!({ "kind": "approve-once" }) + ); + assert_eq!( + serde_json::to_value(permission_request_response(&HandlerResponse::Permission( + PermissionResult::Denied + ),)) + .expect("serializing denied permission response should succeed"), + json!({ "kind": "reject" }) + ); + assert_eq!( + serde_json::to_value(permission_request_response(&HandlerResponse::Ok)) + .expect("serializing fallback permission response should succeed"), + json!({ "kind": "reject" }) + ); + } + + #[test] + fn notification_payload_handles_deferred_and_custom() { + // Deferred โ†’ no payload, SDK must not respond. + assert!( + notification_permission_payload(&HandlerResponse::Permission( + PermissionResult::Deferred, + )) + .is_none() + ); + + // Custom โ†’ handler-supplied value passed through verbatim. + let custom = json!({ + "kind": "approve-and-remember", + "allowlist": ["ls", "grep"], + }); + assert_eq!( + notification_permission_payload(&HandlerResponse::Permission( + PermissionResult::Custom(custom.clone()), + )), + Some(custom) + ); + + // Approved/Denied โ†’ existing kind-only shape. + assert_eq!( + notification_permission_payload(&HandlerResponse::Permission( + PermissionResult::Approved, + )), + Some(json!({ "kind": "approve-once" })) + ); + assert_eq!( + notification_permission_payload( + &HandlerResponse::Permission(PermissionResult::Denied,) + ), + Some(json!({ "kind": "reject" })) + ); + } + + #[test] + fn direct_payload_handles_deferred_and_custom() { + // Custom โ†’ handler-supplied value passed through verbatim. + let custom = json!({ + "kind": "approve-and-remember", + "allowlist": ["ls", "grep"], + }); + assert_eq!( + direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Custom( + custom.clone(), + ))), + custom + ); + + // Deferred โ†’ falls back to Approved because the direct RPC must reply. + assert_eq!( + direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Deferred)), + json!({ "kind": "approve-once" }) + ); + + // Approved/Denied โ†’ existing kind-only shape. + assert_eq!( + direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Approved)), + json!({ "kind": "approve-once" }) + ); + assert_eq!( + direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Denied)), + json!({ "kind": "reject" }) + ); + } +} diff --git a/rust/src/session_fs.rs b/rust/src/session_fs.rs new file mode 100644 index 000000000..e675760a1 --- /dev/null +++ b/rust/src/session_fs.rs @@ -0,0 +1,394 @@ +//! Session filesystem provider โ€” virtualizable filesystem layer over JSON-RPC. +//! +//! When [`ClientOptions::session_fs`] is set, the SDK tells the CLI to delegate +//! all per-session filesystem operations (`readFile`, `writeFile`, `stat`, ...) +//! to a [`SessionFsProvider`] registered on each session. This lets host +//! applications sandbox sessions, project files into in-memory or remote +//! storage, and apply permission policies before bytes move. +//! +//! # Concurrency +//! +//! Each inbound `sessionFs.*` request is dispatched on its own spawned task, +//! matching Node's behavior. Provider implementations MUST be safe for +//! concurrent invocation across distinct paths. Use internal synchronization +//! (e.g. [`tokio::sync::Mutex`] keyed by path) if your backing store needs +//! ordering. +//! +//! # Errors +//! +//! Provider methods return [`Result`]. The SDK adapts these into +//! the schema's `{ ..., error: Option }` payload, mapping +//! [`FsError::NotFound`] to the wire's `ENOENT` and everything else to +//! `UNKNOWN`. A [`From`] conversion is provided so handlers +//! backed by [`tokio::fs`](https://docs.rs/tokio/latest/tokio/fs/index.html) +//! can propagate `io::Error` with `?`. +//! +//! # Example +//! +//! ```no_run +//! use std::sync::Arc; +//! use async_trait::async_trait; +//! use github_copilot_sdk::types::{SessionFsProvider, FsError, FileInfo, DirEntry}; +//! +//! struct MyProvider; +//! +//! #[async_trait] +//! impl SessionFsProvider for MyProvider { +//! async fn read_file(&self, path: &str) -> Result { +//! std::fs::read_to_string(path) +//! .map_err(FsError::from) +//! } +//! } +//! ``` + +use async_trait::async_trait; + +use crate::generated::api_types::{ + SessionFsError, SessionFsErrorCode, SessionFsReaddirWithTypesEntry, + SessionFsReaddirWithTypesEntryType, SessionFsSetProviderConventions, SessionFsStatResult, +}; + +/// Configuration for a custom session filesystem provider. +/// +/// When set on [`ClientOptions::session_fs`](crate::ClientOptions::session_fs), +/// the SDK calls `sessionFs.setProvider` during [`Client::start`](crate::Client::start) +/// to tell the CLI to route per-session filesystem operations to the SDK. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct SessionFsConfig { + /// Initial working directory for sessions (the user's project directory). + pub initial_cwd: String, + /// Path within each session's SessionFs where the runtime stores + /// session-scoped files (events, workspace, checkpoints, etc.). + pub session_state_path: String, + /// Path conventions used by this filesystem provider. + pub conventions: SessionFsConventions, +} + +impl SessionFsConfig { + /// Build a new config with the required fields. + pub fn new( + initial_cwd: impl Into, + session_state_path: impl Into, + conventions: SessionFsConventions, + ) -> Self { + Self { + initial_cwd: initial_cwd.into(), + session_state_path: session_state_path.into(), + conventions, + } + } +} + +/// Path conventions used by a session filesystem provider. +/// +/// Hand-authored consumer-facing enum (rather than reusing +/// [`SessionFsSetProviderConventions`]) to avoid exposing the generated +/// catch-all `Unknown` variant on the input side. The SDK rejects unknown +/// conventions at validation time with a typed error. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionFsConventions { + /// POSIX-style paths (`/foo/bar`). + Posix, + /// Windows-style paths (`C:\foo\bar`). + Windows, +} + +impl SessionFsConventions { + pub(crate) fn into_wire(self) -> SessionFsSetProviderConventions { + match self { + Self::Posix => SessionFsSetProviderConventions::Posix, + Self::Windows => SessionFsSetProviderConventions::Windows, + } + } +} + +/// Error returned by a [`SessionFsProvider`] method. +/// +/// The SDK maps this onto the wire schema's [`SessionFsError`]: +/// [`FsError::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`. +#[non_exhaustive] +#[derive(Debug, Clone, thiserror::Error)] +pub enum FsError { + /// File or directory does not exist. + #[error("not found: {0}")] + NotFound(String), + + /// Any other filesystem error (permission denied, I/O error, etc.). + /// + /// The wire mapping always uses `UNKNOWN` as the code; the message is + /// preserved for diagnostics. + #[error("{0}")] + Other(String), +} + +impl FsError { + pub(crate) fn into_wire(self) -> SessionFsError { + match self { + Self::NotFound(message) => SessionFsError { + code: SessionFsErrorCode::ENOENT, + message: Some(message), + }, + Self::Other(message) => SessionFsError { + code: SessionFsErrorCode::UNKNOWN, + message: Some(message), + }, + } + } +} + +impl From for FsError { + fn from(err: std::io::Error) -> Self { + match err.kind() { + std::io::ErrorKind::NotFound => Self::NotFound(err.to_string()), + _ => Self::Other(err.to_string()), + } + } +} + +/// File or directory metadata returned by [`SessionFsProvider::stat`]. +/// +/// The SDK adapts this into the wire's [`SessionFsStatResult`]. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct FileInfo { + /// Whether the path is a regular file. + pub is_file: bool, + /// Whether the path is a directory. + pub is_directory: bool, + /// File size in bytes. + pub size: i64, + /// ISO 8601 timestamp of last modification. + pub mtime: String, + /// ISO 8601 timestamp of creation. + pub birthtime: String, +} + +impl FileInfo { + /// Build a metadata record. The mtime/birthtime arguments are caller- + /// supplied ISO 8601 strings โ€” the SDK does not format timestamps for + /// you. + pub fn new( + is_file: bool, + is_directory: bool, + size: i64, + mtime: impl Into, + birthtime: impl Into, + ) -> Self { + Self { + is_file, + is_directory, + size, + mtime: mtime.into(), + birthtime: birthtime.into(), + } + } + + pub(crate) fn into_wire(self) -> SessionFsStatResult { + SessionFsStatResult { + is_file: self.is_file, + is_directory: self.is_directory, + size: self.size, + mtime: self.mtime, + birthtime: self.birthtime, + error: None, + } + } +} + +/// Kind of entry returned by [`SessionFsProvider::readdir_with_types`]. +/// +/// The wire schema's `Unknown` forward-compat variant is intentionally absent +/// from this consumer-facing enum โ€” providers must classify each entry as +/// either a file or a directory. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DirEntryKind { + /// Regular file. + File, + /// Directory. + Directory, +} + +impl DirEntryKind { + fn into_wire(self) -> SessionFsReaddirWithTypesEntryType { + match self { + Self::File => SessionFsReaddirWithTypesEntryType::File, + Self::Directory => SessionFsReaddirWithTypesEntryType::Directory, + } + } +} + +/// Single entry in a directory listing returned by +/// [`SessionFsProvider::readdir_with_types`]. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct DirEntry { + /// Entry name (basename, not full path). + pub name: String, + /// Whether the entry is a file or a directory. + pub kind: DirEntryKind, +} + +impl DirEntry { + /// Build a new directory entry. + pub fn new(name: impl Into, kind: DirEntryKind) -> Self { + Self { + name: name.into(), + kind, + } + } + + pub(crate) fn into_wire(self) -> SessionFsReaddirWithTypesEntry { + SessionFsReaddirWithTypesEntry { + name: self.name, + r#type: self.kind.into_wire(), + } + } +} + +/// Implementor-supplied filesystem backing for a session. +/// +/// Each method takes a path using the conventions declared in +/// [`SessionFsConfig::conventions`] and returns the operation's result. The +/// SDK adapts every `Result<_, FsError>` into the JSON-RPC response shape +/// expected by the GitHub Copilot CLI. +/// +/// # Concurrency +/// +/// Implementations MUST be `Send + Sync` and safe for concurrent invocation +/// across distinct paths. The SDK dispatches each inbound `sessionFs.*` +/// request on its own spawned task. Use internal synchronization (e.g. +/// [`tokio::sync::Mutex`] keyed by path) if your backing store requires +/// ordering. +/// +/// # Forward compatibility +/// +/// Methods on this trait have default implementations that return +/// `Err(FsError::Other("operation not supported".into()))`. When the CLI +/// schema grows new `sessionFs.*` methods, the SDK adds them to this trait +/// with default impls so existing implementations continue to compile. +/// Override only the methods relevant to your backing store. +#[async_trait] +pub trait SessionFsProvider: Send + Sync + 'static { + /// Read the full contents of a file as UTF-8. + async fn read_file(&self, path: &str) -> Result { + let _ = path; + Err(FsError::Other("read_file not supported".to_string())) + } + + /// Write content to a file, creating parent directories if needed. + async fn write_file( + &self, + path: &str, + content: &str, + mode: Option, + ) -> Result<(), FsError> { + let _ = (path, content, mode); + Err(FsError::Other("write_file not supported".to_string())) + } + + /// Append content to a file, creating parent directories if needed. + async fn append_file( + &self, + path: &str, + content: &str, + mode: Option, + ) -> Result<(), FsError> { + let _ = (path, content, mode); + Err(FsError::Other("append_file not supported".to_string())) + } + + /// Check whether a path exists. + /// + /// Returns `Ok(false)` for non-existent paths, not [`FsError::NotFound`]. + async fn exists(&self, path: &str) -> Result { + let _ = path; + Err(FsError::Other("exists not supported".to_string())) + } + + /// Get metadata about a file or directory. + async fn stat(&self, path: &str) -> Result { + let _ = path; + Err(FsError::Other("stat not supported".to_string())) + } + + /// Create a directory. When `recursive`, missing parents are also created. + async fn mkdir(&self, path: &str, recursive: bool, mode: Option) -> Result<(), FsError> { + let _ = (path, recursive, mode); + Err(FsError::Other("mkdir not supported".to_string())) + } + + /// List entry names in a directory. + async fn readdir(&self, path: &str) -> Result, FsError> { + let _ = path; + Err(FsError::Other("readdir not supported".to_string())) + } + + /// List directory entries with type information. + async fn readdir_with_types(&self, path: &str) -> Result, FsError> { + let _ = path; + Err(FsError::Other( + "readdir_with_types not supported".to_string(), + )) + } + + /// Remove a file or directory. When `force`, missing paths are not an + /// error. When `recursive`, directory contents are removed as well. + async fn rm(&self, path: &str, recursive: bool, force: bool) -> Result<(), FsError> { + let _ = (path, recursive, force); + Err(FsError::Other("rm not supported".to_string())) + } + + /// Rename or move a file or directory. + async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { + let _ = (src, dest); + Err(FsError::Other("rename not supported".to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn fs_error_maps_io_not_found_to_enoent() { + let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing.txt"); + let fs_err: FsError = io_err.into(); + assert!(matches!(fs_err, FsError::NotFound(_))); + let wire = fs_err.into_wire(); + assert_eq!(wire.code, SessionFsErrorCode::ENOENT); + } + + #[test] + fn fs_error_maps_other_io_to_unknown() { + let io_err = std::io::Error::other("disk full"); + let fs_err: FsError = io_err.into(); + assert!(matches!(fs_err, FsError::Other(_))); + let wire = fs_err.into_wire(); + assert_eq!(wire.code, SessionFsErrorCode::UNKNOWN); + assert!(wire.message.unwrap().contains("disk full")); + } + + #[test] + fn conventions_maps_to_wire() { + assert_eq!( + SessionFsConventions::Posix.into_wire(), + SessionFsSetProviderConventions::Posix + ); + assert_eq!( + SessionFsConventions::Windows.into_wire(), + SessionFsSetProviderConventions::Windows + ); + } + + struct DefaultProvider; + #[async_trait] + impl SessionFsProvider for DefaultProvider {} + + #[tokio::test] + async fn default_impls_return_unsupported() { + let p = DefaultProvider; + let err = p.read_file("/x").await.unwrap_err(); + assert!(matches!(err, FsError::Other(ref m) if m.contains("not supported"))); + } +} diff --git a/rust/src/session_fs_dispatch.rs b/rust/src/session_fs_dispatch.rs new file mode 100644 index 000000000..7b2ae49fd --- /dev/null +++ b/rust/src/session_fs_dispatch.rs @@ -0,0 +1,351 @@ +//! Inbound `sessionFs.*` JSON-RPC request dispatch helpers. +//! +//! Internal โ€” public-facing trait lives in `crate::session_fs`. Each helper +//! deserializes the typed request, calls the [`SessionFsProvider`] method, +//! and serializes the schema response with `FsError` mapped onto the wire's +//! `SessionFsError` variant. + +use std::sync::Arc; + +use serde::Serialize; +use serde_json::Value; +use tracing::warn; + +use crate::generated::api_types::{ + SessionFsAppendFileRequest, SessionFsExistsRequest, SessionFsExistsResult, + SessionFsMkdirRequest, SessionFsReadFileRequest, SessionFsReadFileResult, + SessionFsReaddirRequest, SessionFsReaddirResult, SessionFsReaddirWithTypesRequest, + SessionFsReaddirWithTypesResult, SessionFsRenameRequest, SessionFsRmRequest, + SessionFsStatRequest, SessionFsStatResult, SessionFsWriteFileRequest, +}; +use crate::session_fs::{FsError, SessionFsProvider}; +use crate::{Client, JsonRpcRequest, JsonRpcResponse, error_codes}; + +/// Helper: serialize a typed result, send the response. +async fn respond(client: &Client, request_id: u64, result: T) { + let value = match serde_json::to_value(&result) { + Ok(v) => v, + Err(e) => { + warn!(error = %e, "failed to serialize sessionFs response"); + send_error(client, request_id, "serialization failure").await; + return; + } + }; + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id, + result: Some(value), + error: None, + }) + .await; +} + +async fn send_error(client: &Client, request_id: u64, message: &str) { + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id, + result: None, + error: Some(crate::JsonRpcError { + code: error_codes::INTERNAL_ERROR, + message: message.to_string(), + data: None, + }), + }) + .await; +} + +fn parse_params(request: &JsonRpcRequest) -> Option { + request + .params + .as_ref() + .and_then(|p| serde_json::from_value(p.clone()).ok()) +} + +pub(crate) async fn read_file( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsReadFileRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.readFile params").await; + return; + } + }; + let id = request.id; + let result = match provider.read_file(¶ms.path).await { + Ok(content) => SessionFsReadFileResult { + content, + error: None, + }, + Err(e) => SessionFsReadFileResult { + content: String::new(), + error: Some(e.into_wire()), + }, + }; + respond(client, id, result).await; +} + +pub(crate) async fn write_file( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsWriteFileRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.writeFile params").await; + return; + } + }; + let id = request.id; + match provider + .write_file(¶ms.path, ¶ms.content, params.mode) + .await + { + Ok(()) => respond(client, id, Value::Null).await, + Err(e) => respond(client, id, e.into_wire()).await, + } +} + +pub(crate) async fn append_file( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsAppendFileRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.appendFile params").await; + return; + } + }; + let id = request.id; + match provider + .append_file(¶ms.path, ¶ms.content, params.mode) + .await + { + Ok(()) => respond(client, id, Value::Null).await, + Err(e) => respond(client, id, e.into_wire()).await, + } +} + +pub(crate) async fn exists( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsExistsRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.exists params").await; + return; + } + }; + let id = request.id; + // Match Node's `createSessionFsAdapter`: errors collapse to `exists: false`. + let exists_value = provider.exists(¶ms.path).await.unwrap_or(false); + respond( + client, + id, + SessionFsExistsResult { + exists: exists_value, + }, + ) + .await; +} + +pub(crate) async fn stat( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsStatRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.stat params").await; + return; + } + }; + let id = request.id; + let result = match provider.stat(¶ms.path).await { + Ok(info) => info.into_wire(), + Err(e) => SessionFsStatResult { + is_file: false, + is_directory: false, + size: 0, + mtime: String::new(), + birthtime: String::new(), + error: Some(e.into_wire()), + }, + }; + respond(client, id, result).await; +} + +pub(crate) async fn mkdir( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsMkdirRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.mkdir params").await; + return; + } + }; + let id = request.id; + let recursive = params.recursive.unwrap_or(false); + match provider.mkdir(¶ms.path, recursive, params.mode).await { + Ok(()) => respond(client, id, Value::Null).await, + Err(e) => respond(client, id, e.into_wire()).await, + } +} + +pub(crate) async fn readdir( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsReaddirRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.readdir params").await; + return; + } + }; + let id = request.id; + let result = match provider.readdir(¶ms.path).await { + Ok(entries) => SessionFsReaddirResult { + entries, + error: None, + }, + Err(e) => SessionFsReaddirResult { + entries: Vec::new(), + error: Some(e.into_wire()), + }, + }; + respond(client, id, result).await; +} + +pub(crate) async fn readdir_with_types( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsReaddirWithTypesRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error( + client, + request.id, + "invalid sessionFs.readdirWithTypes params", + ) + .await; + return; + } + }; + let id = request.id; + let result = match provider.readdir_with_types(¶ms.path).await { + Ok(entries) => SessionFsReaddirWithTypesResult { + entries: entries.into_iter().map(|e| e.into_wire()).collect(), + error: None, + }, + Err(e) => SessionFsReaddirWithTypesResult { + entries: Vec::new(), + error: Some(e.into_wire()), + }, + }; + respond(client, id, result).await; +} + +pub(crate) async fn rm( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsRmRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.rm params").await; + return; + } + }; + let id = request.id; + let recursive = params.recursive.unwrap_or(false); + let force = params.force.unwrap_or(false); + match provider.rm(¶ms.path, recursive, force).await { + Ok(()) => respond(client, id, Value::Null).await, + Err(e) => respond(client, id, e.into_wire()).await, + } +} + +pub(crate) async fn rename( + client: &Client, + provider: &Arc, + request: JsonRpcRequest, +) { + let params: SessionFsRenameRequest = match parse_params(&request) { + Some(p) => p, + None => { + send_error(client, request.id, "invalid sessionFs.rename params").await; + return; + } + }; + let id = request.id; + match provider.rename(¶ms.src, ¶ms.dest).await { + Ok(()) => respond(client, id, Value::Null).await, + Err(e) => respond(client, id, e.into_wire()).await, + } +} + +/// Dispatch a `sessionFs.*` request to the appropriate handler. Returns +/// `true` if the request was a session-fs method (whether or not a provider +/// was registered), `false` otherwise (caller should continue matching). +pub(crate) async fn dispatch( + client: &Client, + provider: Option<&Arc>, + request: JsonRpcRequest, +) -> bool { + let method = request.method.as_str(); + if !method.starts_with("sessionFs.") { + return false; + } + let provider = match provider { + Some(p) => p.clone(), + None => { + warn!(method = %method, "sessionFs request without registered provider"); + send_error( + client, + request.id, + "no sessionFs provider registered for this session", + ) + .await; + return true; + } + }; + match method { + "sessionFs.readFile" => read_file(client, &provider, request).await, + "sessionFs.writeFile" => write_file(client, &provider, request).await, + "sessionFs.appendFile" => append_file(client, &provider, request).await, + "sessionFs.exists" => exists(client, &provider, request).await, + "sessionFs.stat" => stat(client, &provider, request).await, + "sessionFs.mkdir" => mkdir(client, &provider, request).await, + "sessionFs.readdir" => readdir(client, &provider, request).await, + "sessionFs.readdirWithTypes" => readdir_with_types(client, &provider, request).await, + "sessionFs.rm" => rm(client, &provider, request).await, + "sessionFs.rename" => rename(client, &provider, request).await, + _ => { + warn!(method = %method, "unknown sessionFs.* method"); + send_error(client, request.id, "unknown sessionFs method").await; + } + } + true +} + +// FsError is used through `into_wire()` calls above. +#[allow(dead_code)] +fn _ensure_fs_error_used(_e: FsError) {} diff --git a/rust/src/subscription.rs b/rust/src/subscription.rs new file mode 100644 index 000000000..52c15b2eb --- /dev/null +++ b/rust/src/subscription.rs @@ -0,0 +1,218 @@ +//! Subscription handles for observing session and lifecycle events. +//! +//! Returned by [`Session::subscribe`](crate::session::Session::subscribe) and +//! [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle). +//! +//! Each subscription is an opt-in **observer** of events that are also +//! delivered to the [`SessionHandler`](crate::handler::SessionHandler). +//! Subscribers receive a clone of every event but cannot influence +//! permission decisions, tool results, or anything else that requires +//! returning a [`HandlerResponse`](crate::handler::HandlerResponse). +//! +//! # Async iteration +//! +//! The subscription types implement [`tokio_stream::Stream`], so consumers +//! can use adapter combinators from [`tokio_stream::StreamExt`] or +//! `futures::StreamExt` (filtering, mapping, batching, racing with +//! `tokio::select!`, etc.) without learning the SDK's internal channel +//! choice. A simple `while let Ok(event) = sub.recv().await { ... }` loop +//! also works for callers who don't need the [`Stream`](tokio_stream::Stream) +//! surface. +//! +//! # Lag policy +//! +//! Each subscriber maintains its own internal queue. If a consumer cannot +//! keep up, the oldest events are dropped and the next call yields +//! [`Lagged`] reporting how many events were skipped. Slow subscribers do +//! not block the producer. + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::sync::broadcast::Receiver; +use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; +use tokio_stream::{Stream, StreamExt as _}; + +use crate::types::{SessionEvent, SessionLifecycleEvent}; + +/// The subscription fell behind the producer. +/// +/// Reports the number of events that were dropped from this subscriber's +/// queue because the consumer didn't keep up. The subscription continues +/// after this error, starting from the next live event โ€” callers who care +/// about lag should match on it and decide whether to resync, re-fetch, or +/// log and continue. +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[error("subscription lagged behind by {0} events")] +pub struct Lagged(u64); + +impl Lagged { + /// Number of events skipped before this consumer could read them. + pub fn skipped(&self) -> u64 { + self.0 + } +} + +/// Error returned by [`EventSubscription::recv`] and +/// [`LifecycleSubscription::recv`]. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum RecvError { + /// The producer is gone โ€” the session has shut down or the client has + /// stopped. No further events will be delivered. + #[error("subscription closed")] + Closed, + + /// The subscriber fell behind. See [`Lagged`]. + #[error(transparent)] + Lagged(#[from] Lagged), +} + +macro_rules! define_subscription { + ( + $(#[$meta:meta])* + $name:ident, $item:ty $(,)? + ) => { + $(#[$meta])* + #[must_use = "subscriptions are inert until polled"] + pub struct $name { + inner: BroadcastStream<$item>, + } + + impl $name { + pub(crate) fn new(rx: Receiver<$item>) -> Self { + Self { + inner: BroadcastStream::new(rx), + } + } + + /// Receive the next event. + /// + /// Returns: + /// + /// - `Ok(event)` for the next delivered event. + /// - `Err(`[`RecvError::Lagged`]`)` if the subscriber fell behind; + /// call `recv` again to continue from the next live event. + /// - `Err(`[`RecvError::Closed`]`)` once the producer is gone. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** Wraps a `tokio::sync::broadcast::Receiver` + /// via `BroadcastStream`; both are cancel-safe by design. + /// Dropping the future before completion is harmless โ€” events + /// already buffered for this subscriber remain available on + /// the next `recv` call. + pub async fn recv(&mut self) -> Result<$item, RecvError> { + match self.inner.next().await { + Some(Ok(event)) => Ok(event), + Some(Err(BroadcastStreamRecvError::Lagged(n))) => { + Err(RecvError::Lagged(Lagged(n))) + } + None => Err(RecvError::Closed), + } + } + } + + impl Stream for $name { + type Item = Result<$item, Lagged>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(event))), + Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n)))) => { + Poll::Ready(Some(Err(Lagged(n)))) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + } + }; +} + +define_subscription! { + /// Subscription to runtime events for a single + /// [`Session`](crate::session::Session). + /// + /// Created by [`Session::subscribe`](crate::session::Session::subscribe). + /// Implements [`Stream`] yielding `Result`. + /// Drop the value to unsubscribe; there is no separate cancel handle. + EventSubscription, SessionEvent +} + +define_subscription! { + /// Subscription to lifecycle events on a [`Client`](crate::Client). + /// + /// Created by + /// [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle). + /// Implements [`Stream`] yielding `Result`. + /// Drop the value to unsubscribe; there is no separate cancel handle. + LifecycleSubscription, SessionLifecycleEvent +} + +#[cfg(test)] +mod tests { + use tokio::sync::broadcast; + + use super::*; + + fn make_event(id: &str) -> SessionEvent { + SessionEvent { + id: id.into(), + timestamp: "2025-01-01T00:00:00Z".into(), + parent_id: None, + ephemeral: None, + agent_id: None, + debug_cli_received_at_ms: None, + debug_ws_forwarded_at_ms: None, + event_type: "noop".into(), + data: serde_json::json!({}), + } + } + + #[tokio::test] + async fn recv_yields_then_closes_on_drop_sender() { + let (tx, rx) = broadcast::channel(8); + let mut sub = EventSubscription::new(rx); + tx.send(make_event("a")).unwrap(); + tx.send(make_event("b")).unwrap(); + drop(tx); + + assert_eq!(sub.recv().await.unwrap().id, "a"); + assert_eq!(sub.recv().await.unwrap().id, "b"); + assert!(matches!(sub.recv().await, Err(RecvError::Closed))); + } + + #[tokio::test] + async fn recv_surfaces_lag() { + let (tx, rx) = broadcast::channel(2); + let mut sub = EventSubscription::new(rx); + for id in ["a", "b", "c", "d"] { + tx.send(make_event(id)).unwrap(); + } + match sub.recv().await { + Err(RecvError::Lagged(l)) => assert_eq!(l.skipped(), 2), + other => panic!("expected Lagged, got {other:?}"), + } + // Subscription continues with the live tail. + assert_eq!(sub.recv().await.unwrap().id, "c"); + assert_eq!(sub.recv().await.unwrap().id, "d"); + } + + #[tokio::test] + async fn stream_impl_matches_recv_semantics() { + let (tx, rx) = broadcast::channel(8); + let mut sub = EventSubscription::new(rx); + tx.send(make_event("a")).unwrap(); + drop(tx); + + // poll_next path + let next = sub.next().await; + assert_eq!(next.unwrap().unwrap().id, "a"); + assert!(sub.next().await.is_none()); + } +} diff --git a/rust/src/tool.rs b/rust/src/tool.rs new file mode 100644 index 000000000..cccdad486 --- /dev/null +++ b/rust/src/tool.rs @@ -0,0 +1,828 @@ +//! Typed tool definition framework. +//! +//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for implementing tools as named types, +//! and [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) for automatic dispatch of tool calls within a +//! [`SessionHandler`](crate::handler::SessionHandler). +//! +//! Enable the `derive` feature for `schema_for`, which generates JSON +//! Schema from Rust types via `schemars`. + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +/// Re-export of [`schemars::JsonSchema`] for deriving tool parameter schemas. +#[cfg(feature = "derive")] +pub use schemars::JsonSchema; + +use crate::Error; +use crate::handler::{ExitPlanModeResult, PermissionResult, SessionHandler, UserInputResponse}; +use crate::types::{ + ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, + SessionEvent, SessionId, Tool, ToolInvocation, ToolResult, ToolResultExpanded, +}; + +/// Generate a JSON Schema [`Value`](serde_json::Value) from a Rust type. +/// +/// Strips `$schema` and `title` root-level metadata so the output is ready +/// to use as [`Tool::parameters`]. +/// +/// # Example +/// +/// ```rust +/// use github_copilot_sdk::tool::{schema_for, JsonSchema}; +/// +/// #[derive(JsonSchema)] +/// struct Params { +/// /// City name +/// city: String, +/// } +/// +/// let schema = schema_for::(); +/// assert_eq!(schema["type"], "object"); +/// assert!(schema["properties"]["city"].is_object()); +/// ``` +#[cfg(feature = "derive")] +pub fn schema_for() -> serde_json::Value { + let schema = schemars::schema_for!(T); + let mut value = serde_json::to_value(schema).expect("JSON Schema serialization cannot fail"); + if let Some(obj) = value.as_object_mut() { + obj.remove("$schema"); + obj.remove("title"); + } + value +} + +/// Convert a JSON Schema [`Value`](serde_json::Value) into the +/// [`Tool::parameters`] map shape expected by the protocol. +/// +/// Panics if the input is not a JSON object โ€” tool parameter schemas +/// are always top-level objects (`{"type": "object", ...}`). Pair with +/// [`schema_for`] or a `serde_json::json!(...)` literal. +/// +/// Use [`try_tool_parameters`] when the schema comes from dynamic input and +/// should return a recoverable error instead of panicking. +/// +/// # Example +/// +/// ```rust +/// use github_copilot_sdk::tool::tool_parameters; +/// use github_copilot_sdk::Tool; +/// +/// let mut tool = Tool::default(); +/// tool.name = "ping".to_string(); +/// tool.description = "ping the server".to_string(); +/// tool.parameters = tool_parameters(serde_json::json!({"type": "object"})); +/// # let _ = tool; +/// ``` +pub fn tool_parameters(schema: serde_json::Value) -> HashMap { + try_tool_parameters(schema).expect("tool parameter schema must be a JSON object") +} + +/// Fallible variant of [`tool_parameters`] for callers handling dynamic schema input. +pub fn try_tool_parameters( + schema: serde_json::Value, +) -> Result, serde_json::Error> { + serde_json::from_value(schema) +} + +/// A client-defined tool with its handler logic. +/// +/// Implement this trait for each tool you expose to the Copilot agent. +/// The struct is a named type โ€” visible in stack traces and navigable +/// via "go to definition" โ€” unlike closure-based alternatives. +/// +/// # Example +/// +/// ```rust,ignore +/// use github_copilot_sdk::tool::{schema_for, tool_parameters, JsonSchema, ToolHandler}; +/// use github_copilot_sdk::{Error, Tool, ToolInvocation, ToolResult}; +/// use serde::Deserialize; +/// use async_trait::async_trait; +/// +/// #[derive(Deserialize, JsonSchema)] +/// struct GetWeatherParams { +/// /// City name +/// city: String, +/// /// Temperature unit +/// unit: Option, +/// } +/// +/// struct GetWeatherTool; +/// +/// #[async_trait] +/// impl ToolHandler for GetWeatherTool { +/// fn tool(&self) -> Tool { +/// Tool { +/// name: "get_weather".to_string(), +/// namespaced_name: None, +/// description: "Get weather for a city".to_string(), +/// parameters: tool_parameters(schema_for::()), +/// instructions: None, +/// ..Default::default() +/// } +/// } +/// +/// async fn call(&self, inv: ToolInvocation) -> Result { +/// let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; +/// Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city))) +/// } +/// } +/// ``` +#[async_trait] +pub trait ToolHandler: Send + Sync { + /// The tool definition sent to the CLI during session creation. + fn tool(&self) -> Tool; + + /// Handle a tool invocation from the agent. + async fn call(&self, invocation: ToolInvocation) -> Result; +} + +/// Define a tool from an async function (or closure) that takes a typed, +/// `JsonSchema`-derived parameter struct. +/// +/// The returned `Box` plugs directly into +/// [`ToolHandlerRouter::new`]. JSON Schema for the parameter type is generated +/// via [`schema_for`] at construction time. +/// +/// The handler bound (`Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static`) +/// accepts both bare `async fn` items and closures โ€” the same shape as +/// [`tower::service_fn`][tower-service-fn] and +/// [`hyper::service::service_fn`][hyper-service-fn]. Prefer a free `async fn` +/// for non-trivial tools so it shows up in stack traces by name. +/// +/// The closure receives the full [`ToolInvocation`] alongside the deserialized +/// parameters so handlers can use `inv.session_id`, `inv.tool_call_id`, or +/// other invocation metadata. Handlers that don't need that metadata can +/// destructure with `|_inv, params|`. +/// +/// # Example +/// +/// ```rust,no_run +/// use github_copilot_sdk::tool::{define_tool, JsonSchema}; +/// use github_copilot_sdk::types::ToolInvocation; +/// use github_copilot_sdk::{Error, ToolResult}; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize, JsonSchema)] +/// struct GetWeatherParams { +/// /// City name +/// city: String, +/// } +/// +/// async fn get_weather( +/// inv: ToolInvocation, +/// params: GetWeatherParams, +/// ) -> Result { +/// // `inv.session_id` and `inv.tool_call_id` are available for telemetry, +/// // streaming updates, scoping DB lookups, etc. +/// let _ = inv.session_id; +/// Ok(ToolResult::Text(format!("Sunny in {}", params.city))) +/// } +/// +/// // Pass a free async fn โ€” preferred for non-trivial tools. +/// let tool = define_tool("get_weather", "Get weather for a city", get_weather); +/// +/// // ...or an inline closure when the body is trivial. +/// let tool = define_tool( +/// "echo", +/// "Echo the input", +/// |_inv, params: GetWeatherParams| async move { +/// Ok(ToolResult::Text(params.city)) +/// }, +/// ); +/// # let _ = tool; +/// ``` +/// +/// [tower-service-fn]: https://docs.rs/tower/latest/tower/fn.service_fn.html +/// [hyper-service-fn]: https://docs.rs/hyper/latest/hyper/service/fn.service_fn.html +#[cfg(feature = "derive")] +pub fn define_tool( + name: impl Into, + description: impl Into, + handler: F, +) -> Box +where + P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static, + F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, +{ + struct FnTool { + name: String, + description: String, + parameters: HashMap, + handler: F, + _marker: std::marker::PhantomData, + } + + #[async_trait] + impl ToolHandler for FnTool + where + P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static, + F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, + { + fn tool(&self) -> Tool { + Tool { + name: self.name.clone(), + description: self.description.clone(), + parameters: self.parameters.clone(), + ..Default::default() + } + } + + async fn call(&self, mut invocation: ToolInvocation) -> Result { + let arguments = std::mem::take(&mut invocation.arguments); + let params: P = serde_json::from_value(arguments)?; + (self.handler)(invocation, params).await + } + } + + Box::new(FnTool { + name: name.into(), + description: description.into(), + parameters: tool_parameters(schema_for::

()), + handler, + _marker: std::marker::PhantomData, + }) +} + +/// A [`SessionHandler`] that dispatches tool calls to registered +/// [`ToolHandler`] implementations by name. +/// +/// For tool calls matching a registered handler, the handler is invoked +/// directly. All other events (permissions, user input, unrecognized tools) +/// are forwarded to the inner handler. +/// +/// # Example +/// +/// ```rust,no_run +/// use std::sync::Arc; +/// use github_copilot_sdk::handler::ApproveAllHandler; +/// use github_copilot_sdk::tool::ToolHandlerRouter; +/// +/// let router = ToolHandlerRouter::new( +/// vec![/* Box::new(MyTool), ... */], +/// Arc::new(ApproveAllHandler), +/// ); +/// +/// // Use router.tools() in SessionConfig +/// // Use Arc::new(router) as the session handler +/// ``` +pub struct ToolHandlerRouter { + handlers: HashMap>, + inner: Arc, +} + +impl std::fmt::Debug for ToolHandlerRouter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut tools: Vec<_> = self.handlers.keys().collect(); + tools.sort(); + f.debug_struct("ToolHandlerRouter") + .field("tool_count", &self.handlers.len()) + .field("tools", &tools) + .finish() + } +} + +impl ToolHandlerRouter { + /// Create a router from tool handler impls and a fallback handler. + /// + /// Call [`tools()`](Self::tools) to get the tool definitions for + /// [`SessionConfig::tools`](crate::SessionConfig::tools). + pub fn new(tools: Vec>, inner: Arc) -> Self { + let mut handlers = HashMap::new(); + for tool in tools { + handlers.insert(tool.tool().name.clone(), tool); + } + Self { handlers, inner } + } + + /// Tool definitions for [`SessionConfig::tools`](crate::SessionConfig::tools). + pub fn tools(&self) -> Vec { + self.handlers.values().map(|h| h.tool()).collect() + } +} + +#[async_trait] +impl SessionHandler for ToolHandlerRouter { + async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { + let Some(handler) = self.handlers.get(&invocation.tool_name) else { + return self.inner.on_external_tool(invocation).await; + }; + match handler.call(invocation).await { + Ok(result) => result, + Err(e) => { + let msg = e.to_string(); + ToolResult::Expanded(ToolResultExpanded { + text_result_for_llm: msg.clone(), + result_type: "failure".to_string(), + session_log: None, + error: Some(msg), + }) + } + } + } + + async fn on_session_event(&self, session_id: SessionId, event: SessionEvent) { + self.inner.on_session_event(session_id, event).await + } + + async fn on_permission_request( + &self, + session_id: SessionId, + request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + self.inner + .on_permission_request(session_id, request_id, data) + .await + } + + async fn on_user_input( + &self, + session_id: SessionId, + question: String, + choices: Option>, + allow_freeform: Option, + ) -> Option { + self.inner + .on_user_input(session_id, question, choices, allow_freeform) + .await + } + + async fn on_elicitation( + &self, + session_id: SessionId, + request_id: RequestId, + request: ElicitationRequest, + ) -> ElicitationResult { + self.inner + .on_elicitation(session_id, request_id, request) + .await + } + + async fn on_exit_plan_mode( + &self, + session_id: SessionId, + data: ExitPlanModeData, + ) -> ExitPlanModeResult { + self.inner.on_exit_plan_mode(session_id, data).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{PermissionRequestData, RequestId, SessionId}; + + struct EchoTool; + + #[async_trait] + impl ToolHandler for EchoTool { + fn tool(&self) -> Tool { + Tool { + name: "echo".to_string(), + namespaced_name: None, + description: "Echo the input".to_string(), + parameters: tool_parameters(serde_json::json!({"type": "object"})), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, inv: ToolInvocation) -> Result { + Ok(ToolResult::Text(inv.arguments.to_string())) + } + } + + #[test] + fn tool_handler_returns_tool_definition() { + let tool = EchoTool; + let def = tool.tool(); + assert_eq!(def.name, "echo"); + assert_eq!(def.description, "Echo the input"); + assert!(def.parameters.contains_key("type")); + } + + #[test] + fn try_tool_parameters_rejects_non_object_schema() { + let err = try_tool_parameters(serde_json::json!(["not", "an", "object"])) + .expect_err("non-object schemas should be rejected"); + + assert!(err.is_data()); + } + + #[tokio::test] + async fn tool_handler_call_returns_result() { + let tool = EchoTool; + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "echo".to_string(), + arguments: serde_json::json!({"msg": "hello"}), + traceparent: None, + tracestate: None, + }; + + let result = tool.call(inv).await.unwrap(); + match result { + ToolResult::Text(s) => assert!(s.contains("hello")), + _ => panic!("expected Text result"), + } + } + + #[cfg(feature = "derive")] + #[tokio::test] + async fn define_tool_builds_schema_and_dispatches() { + use serde::Deserialize; + + #[derive(Deserialize, schemars::JsonSchema)] + struct Params { + city: String, + } + + let tool = define_tool( + "weather", + "Get the weather for a city", + |_inv, params: Params| async move { + Ok(ToolResult::Text(format!("sunny in {}", params.city))) + }, + ); + + let def = tool.tool(); + assert_eq!(def.name, "weather"); + assert_eq!(def.description, "Get the weather for a city"); + assert_eq!(def.parameters["type"], "object"); + assert!(def.parameters["properties"]["city"].is_object()); + + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "weather".to_string(), + arguments: serde_json::json!({"city": "Seattle"}), + traceparent: None, + tracestate: None, + }; + match tool.call(inv).await.unwrap() { + ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"), + _ => panic!("expected Text result"), + } + } + + #[tokio::test] + async fn router_dispatches_to_correct_handler() { + struct ToolA; + #[async_trait] + impl ToolHandler for ToolA { + fn tool(&self) -> Tool { + Tool { + name: "tool_a".to_string(), + namespaced_name: None, + description: "A".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Ok(ToolResult::Text("a_result".to_string())) + } + } + + struct ToolB; + #[async_trait] + impl ToolHandler for ToolB { + fn tool(&self) -> Tool { + Tool { + name: "tool_b".to_string(), + namespaced_name: None, + description: "B".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Ok(ToolResult::Text("b_result".to_string())) + } + } + + let router = ToolHandlerRouter::new( + vec![Box::new(ToolA), Box::new(ToolB)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let tools = router.tools(); + assert_eq!(tools.len(), 2); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "tool_b".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }) + .await; + match response { + ToolResult::Text(s) => assert_eq!(s, "b_result"), + _ => panic!("expected ToolResult::Text"), + } + } + + #[tokio::test] + async fn router_falls_through_for_unknown_tool() { + use std::sync::atomic::{AtomicBool, Ordering}; + + struct FallbackHandler { + called: AtomicBool, + } + #[async_trait] + impl SessionHandler for FallbackHandler { + async fn on_external_tool(&self, _inv: ToolInvocation) -> ToolResult { + self.called.store(true, Ordering::Relaxed); + ToolResult::Text("fallback".to_string()) + } + } + + let fallback = Arc::new(FallbackHandler { + called: AtomicBool::new(false), + }); + let router = ToolHandlerRouter::new(vec![], fallback.clone()); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "unknown".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }) + .await; + assert!(fallback.called.load(Ordering::Relaxed)); + match response { + ToolResult::Text(s) => assert_eq!(s, "fallback"), + _ => panic!("expected fallback result"), + } + } + + #[tokio::test] + async fn router_returns_failure_on_handler_error() { + struct FailTool; + #[async_trait] + impl ToolHandler for FailTool { + fn tool(&self) -> Tool { + Tool { + name: "bad_tool".to_string(), + namespaced_name: None, + description: "Always fails".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Err(Error::Rpc { + code: -1, + message: "intentional failure".to_string(), + }) + } + } + + let router = ToolHandlerRouter::new( + vec![Box::new(FailTool)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "bad_tool".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }) + .await; + match response { + ToolResult::Expanded(exp) => { + assert_eq!(exp.result_type, "failure"); + assert!(exp.error.unwrap().contains("intentional failure")); + } + _ => panic!("expected expanded failure result"), + } + } + + #[tokio::test] + async fn router_forwards_non_tool_events() { + struct PermHandler; + #[async_trait] + impl SessionHandler for PermHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied + } + } + + let router = ToolHandlerRouter::new(vec![], Arc::new(PermHandler)); + + let response = router + .on_permission_request( + SessionId::from("s1"), + RequestId::new("r1"), + PermissionRequestData { + extra: serde_json::json!({}), + ..Default::default() + }, + ) + .await; + assert!(matches!(response, PermissionResult::Denied)); + } + + #[tokio::test] + async fn router_default_on_event_dispatches_via_per_event_methods() { + // Regression: callers using the legacy on_event entry point should + // still get correct dispatch through the inherited default impl. + use crate::handler::{HandlerEvent, HandlerResponse}; + + struct OkTool; + #[async_trait] + impl ToolHandler for OkTool { + fn tool(&self) -> Tool { + Tool { + name: "ok_tool".to_string(), + namespaced_name: None, + description: "ok".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Ok(ToolResult::Text("ok".to_string())) + } + } + + let router = ToolHandlerRouter::new( + vec![Box::new(OkTool)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let response = router + .on_event(HandlerEvent::ExternalTool { + invocation: ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "ok_tool".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }, + }) + .await; + match response { + HandlerResponse::ToolResult(ToolResult::Text(s)) => assert_eq!(s, "ok"), + _ => panic!("expected ToolResult via default on_event"), + } + } + + // Tests requiring `schemars` (the `derive` feature). + #[cfg(feature = "derive")] + mod derive_tests { + use serde::Deserialize; + + use super::super::*; + use crate::SessionId; + + #[derive(Deserialize, schemars::JsonSchema)] + struct GetWeatherParams { + /// City name to get weather for. + city: String, + /// Temperature unit (celsius or fahrenheit). + unit: Option, + } + + #[test] + fn schema_for_generates_clean_schema() { + let schema = schema_for::(); + assert_eq!(schema["type"], "object"); + assert!(schema["properties"]["city"].is_object()); + assert!(schema["properties"]["unit"].is_object()); + // city is required (non-Option), unit is not + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::json!("city"))); + assert!(!required.contains(&serde_json::json!("unit"))); + // Root-level metadata stripped + assert!(schema.get("$schema").is_none()); + assert!(schema.get("title").is_none()); + } + + struct GetWeatherTool; + + #[async_trait] + impl ToolHandler for GetWeatherTool { + fn tool(&self) -> Tool { + Tool { + name: "get_weather".to_string(), + namespaced_name: None, + description: "Get weather for a city".to_string(), + parameters: tool_parameters(schema_for::()), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, inv: ToolInvocation) -> Result { + let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; + Ok(ToolResult::Text(format!( + "{} {}", + params.city, + params.unit.unwrap_or_default() + ))) + } + } + + #[test] + fn tool_handler_with_schema_for() { + let tool = GetWeatherTool; + let def = tool.tool(); + assert_eq!(def.name, "get_weather"); + let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters"); + assert_eq!(schema["type"], "object"); + assert!(schema["properties"]["city"].is_object()); + } + + #[tokio::test] + async fn tool_handler_deserializes_typed_params() { + let tool = GetWeatherTool; + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "get_weather".to_string(), + arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}), + traceparent: None, + tracestate: None, + }; + + let result = tool.call(inv).await.unwrap(); + match result { + ToolResult::Text(s) => assert_eq!(s, "Seattle celsius"), + _ => panic!("expected Text result"), + } + } + + #[tokio::test] + async fn tool_handler_returns_error_on_bad_params() { + let tool = GetWeatherTool; + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "get_weather".to_string(), + arguments: serde_json::json!({"wrong_field": 42}), + traceparent: None, + tracestate: None, + }; + + let err = tool.call(inv).await.unwrap_err(); + assert!(matches!(err, Error::Json(_))); + } + + #[tokio::test] + async fn router_with_schema_for_tools() { + let router = ToolHandlerRouter::new( + vec![Box::new(GetWeatherTool)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let tools = router.tools(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].name, "get_weather"); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "get_weather".to_string(), + arguments: serde_json::json!({"city": "Portland"}), + traceparent: None, + tracestate: None, + }) + .await; + match response { + ToolResult::Text(s) => assert!(s.contains("Portland")), + _ => panic!("expected ToolResult::Text"), + } + } + } +} diff --git a/rust/src/trace_context.rs b/rust/src/trace_context.rs new file mode 100644 index 000000000..287c87cbd --- /dev/null +++ b/rust/src/trace_context.rs @@ -0,0 +1,132 @@ +//! W3C Trace Context propagation for distributed tracing. +//! +//! The GitHub Copilot CLI propagates [W3C Trace Context] headers (`traceparent` +//! and `tracestate`) so SDK consumers can correlate spans created by the +//! CLI with their own observability pipelines. +//! +//! Two injection paths are supported: +//! +//! - **Per-turn override** via [`MessageOptions::traceparent`] / +//! [`MessageOptions::tracestate`](crate::types::MessageOptions::tracestate), +//! which take precedence when set. +//! - **Ambient callback** via +//! [`ClientOptions::on_get_trace_context`](crate::ClientOptions::on_get_trace_context), +//! which the SDK invokes before `session.create`, `session.resume`, and +//! `session.send` whenever the per-turn override is absent. +//! +//! [W3C Trace Context]: https://www.w3.org/TR/trace-context/ +//! [`MessageOptions::traceparent`]: crate::types::MessageOptions::traceparent + +use async_trait::async_trait; + +/// W3C Trace Context headers propagated to and from the GitHub Copilot CLI. +/// +/// `traceparent` carries the trace and parent-span identifiers; `tracestate` +/// carries vendor-specific extensions. Either field may be `None` when the +/// caller has nothing to propagate; in that case the corresponding wire +/// field is omitted. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[non_exhaustive] +pub struct TraceContext { + /// `traceparent` HTTP header value. + pub traceparent: Option, + /// `tracestate` HTTP header value. + pub tracestate: Option, +} + +impl TraceContext { + /// Construct an empty [`TraceContext`]; both fields default to unset + /// (the SDK skips trace-context injection on the wire). + pub fn new() -> Self { + Self::default() + } + + /// Construct a [`TraceContext`] from a `traceparent` header value, with + /// no `tracestate`. + /// + /// Equivalent to `TraceContext::new().with_traceparent(value)`; kept + /// for ergonomics in the common single-header case. + pub fn from_traceparent(traceparent: impl Into) -> Self { + Self::new().with_traceparent(traceparent) + } + + /// Set or replace the `traceparent` header value, returning `self` for + /// chaining. + pub fn with_traceparent(mut self, traceparent: impl Into) -> Self { + self.traceparent = Some(traceparent.into()); + self + } + + /// Set or replace the `tracestate` header value, returning `self` for + /// chaining. + pub fn with_tracestate(mut self, tracestate: impl Into) -> Self { + self.tracestate = Some(tracestate.into()); + self + } + + /// Returns `true` when neither `traceparent` nor `tracestate` is set. + pub fn is_empty(&self) -> bool { + self.traceparent.is_none() && self.tracestate.is_none() + } +} + +/// Async provider that returns the current [`TraceContext`] for outbound +/// session RPCs. +/// +/// Set via +/// [`ClientOptions::on_get_trace_context`](crate::ClientOptions::on_get_trace_context). +/// The SDK invokes [`get_trace_context`](Self::get_trace_context) before +/// each `session.create`, `session.resume`, and `session.send` whenever +/// the call site does not carry a per-turn override. +/// +/// Implementations should handle errors internally and return +/// [`TraceContext::default()`] to skip injection โ€” no `Result` return type +/// is exposed because trace propagation is a best-effort observability +/// feature, not a correctness-critical RPC parameter. +#[async_trait] +pub trait TraceContextProvider: Send + Sync + 'static { + /// Return the current trace context, or [`TraceContext::default()`] to + /// skip injection. + async fn get_trace_context(&self) -> TraceContext; +} + +/// Inject `traceparent` / `tracestate` from `ctx` into the JSON `params` +/// object if either field is set. No-op when both are `None`. +pub(crate) fn inject_trace_context(params: &mut serde_json::Value, ctx: &TraceContext) { + if let Some(tp) = &ctx.traceparent { + params["traceparent"] = serde_json::Value::String(tp.clone()); + } + if let Some(ts) = &ctx.tracestate { + params["tracestate"] = serde_json::Value::String(ts.clone()); + } +} + +#[cfg(test)] +mod tests { + use super::TraceContext; + + #[test] + fn new_yields_empty_context() { + let ctx = TraceContext::new(); + assert!(ctx.is_empty()); + assert!(ctx.traceparent.is_none()); + assert!(ctx.tracestate.is_none()); + } + + #[test] + fn builder_composes_traceparent_and_tracestate() { + let ctx = TraceContext::new() + .with_traceparent("00-trace-span-01") + .with_tracestate("vendor=key"); + assert_eq!(ctx.traceparent.as_deref(), Some("00-trace-span-01")); + assert_eq!(ctx.tracestate.as_deref(), Some("vendor=key")); + assert!(!ctx.is_empty()); + } + + #[test] + fn from_traceparent_matches_builder() { + let direct = TraceContext::from_traceparent("00-trace-span-01"); + let chained = TraceContext::new().with_traceparent("00-trace-span-01"); + assert_eq!(direct, chained); + } +} diff --git a/rust/src/transforms.rs b/rust/src/transforms.rs new file mode 100644 index 000000000..a090bc649 --- /dev/null +++ b/rust/src/transforms.rs @@ -0,0 +1,223 @@ +//! System message transform callbacks for customizing agent prompts. +//! +//! Implement [`SystemMessageTransform`](crate::transforms::SystemMessageTransform) to intercept and modify system prompt +//! sections during session creation. The CLI sends the current content for +//! each section the transform registered, and the SDK returns the modified +//! content. + +use std::collections::HashMap; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::types::SessionId; + +/// Context provided to every transform invocation. +#[derive(Debug, Clone)] +pub struct TransformContext { + /// The session being created or resumed. + pub session_id: SessionId, +} + +/// Handles `systemMessage.transform` RPC requests from the CLI. +/// +/// The CLI sends these during session creation/resumption when the session's +/// `SystemMessageConfig` contains sections with `action: "transform"`. For each +/// such section, the CLI provides the current content and expects the SDK to +/// return the (possibly modified) content. +/// +/// Implement this trait and pass it to [`Client::create_session`](crate::Client::create_session) / +/// [`Client::resume_session`](crate::Client::resume_session) to participate in system message customization. +/// +/// # Example +/// +/// ```ignore +/// struct MyTransform; +/// +/// #[async_trait::async_trait] +/// impl SystemMessageTransform for MyTransform { +/// fn section_ids(&self) -> Vec { +/// vec!["instructions".to_string()] +/// } +/// +/// async fn transform_section( +/// &self, +/// _section_id: &str, +/// content: &str, +/// _ctx: TransformContext, +/// ) -> Option { +/// Some(format!("{content}\n\nAlways be concise.")) +/// } +/// } +/// ``` +#[async_trait] +pub trait SystemMessageTransform: Send + Sync + 'static { + /// Section IDs this transform handles. + /// + /// The SDK injects `action: "transform"` entries into the + /// [`SystemMessageConfig`](crate::types::SystemMessageConfig) wire format + /// for each returned ID. + fn section_ids(&self) -> Vec; + + /// Transform a section's content. Return `Some(new_content)` to modify the + /// section, or `None` to pass through unchanged. + async fn transform_section( + &self, + section_id: &str, + content: &str, + ctx: TransformContext, + ) -> Option; +} + +/// Wire format for a single section in the transform request/response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct TransformSection { + pub(crate) content: String, +} + +/// Wire format for the `systemMessage.transform` response. +#[derive(Debug, Clone, Serialize)] +pub(crate) struct TransformResponse { + pub(crate) sections: HashMap, +} + +/// Apply transforms to the incoming sections map, returning the response. +/// +/// For each section, calls the matching transform if the implementor returns +/// `Some`; otherwise passes through the original content. +pub(crate) async fn dispatch_transform( + transform: &dyn SystemMessageTransform, + session_id: &SessionId, + sections: HashMap, +) -> TransformResponse { + let ctx = TransformContext { + session_id: session_id.clone(), + }; + + let mut result = HashMap::with_capacity(sections.len()); + for (section_id, data) in sections { + let content = match transform + .transform_section(§ion_id, &data.content, ctx.clone()) + .await + { + Some(transformed) => transformed, + None => data.content, + }; + result.insert(section_id, TransformSection { content }); + } + + TransformResponse { sections: result } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestTransform; + + #[async_trait] + impl SystemMessageTransform for TestTransform { + fn section_ids(&self) -> Vec { + vec!["instructions".to_string(), "context".to_string()] + } + + async fn transform_section( + &self, + section_id: &str, + content: &str, + _ctx: TransformContext, + ) -> Option { + match section_id { + "instructions" => Some(format!("[modified] {content}")), + _ => None, + } + } + } + + #[tokio::test] + async fn dispatch_applies_matching_transform() { + let transform = TestTransform; + let mut sections = HashMap::new(); + sections.insert( + "instructions".to_string(), + TransformSection { + content: "be helpful".to_string(), + }, + ); + + let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await; + assert_eq!( + response.sections["instructions"].content, + "[modified] be helpful" + ); + } + + #[tokio::test] + async fn dispatch_passes_through_unhandled_section() { + let transform = TestTransform; + let mut sections = HashMap::new(); + sections.insert( + "context".to_string(), + TransformSection { + content: "original context".to_string(), + }, + ); + + let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await; + assert_eq!(response.sections["context"].content, "original context"); + } + + #[tokio::test] + async fn dispatch_unknown_section_passes_through() { + let transform = TestTransform; + let mut sections = HashMap::new(); + sections.insert( + "unknown".to_string(), + TransformSection { + content: "mystery".to_string(), + }, + ); + + let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await; + assert_eq!(response.sections["unknown"].content, "mystery"); + } + + #[tokio::test] + async fn dispatch_mixed_sections() { + let transform = TestTransform; + let mut sections = HashMap::new(); + sections.insert( + "instructions".to_string(), + TransformSection { + content: "help me".to_string(), + }, + ); + sections.insert( + "context".to_string(), + TransformSection { + content: "some context".to_string(), + }, + ); + sections.insert( + "other".to_string(), + TransformSection { + content: "other stuff".to_string(), + }, + ); + + let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await; + assert_eq!( + response.sections["instructions"].content, + "[modified] help me" + ); + assert_eq!(response.sections["context"].content, "some context"); + assert_eq!(response.sections["other"].content, "other stuff"); + } + + #[tokio::test] + async fn section_ids_returns_registered_sections() { + let transform = TestTransform; + let ids = transform.section_ids(); + assert_eq!(ids, vec!["instructions", "context"]); + } +} diff --git a/rust/src/types.rs b/rust/src/types.rs new file mode 100644 index 000000000..afb1d71f7 --- /dev/null +++ b/rust/src/types.rs @@ -0,0 +1,3672 @@ +//! Protocol types shared between the SDK and the GitHub Copilot CLI. +//! +//! These types map directly to the JSON-RPC request/response payloads +//! defined by the GitHub Copilot CLI protocol. They are used for session +//! configuration, event handling, tool invocations, and model queries. + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::handler::SessionHandler; +use crate::hooks::SessionHooks; +pub use crate::session_fs::{ + DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConfig, SessionFsConventions, + SessionFsProvider, +}; +pub use crate::trace_context::{TraceContext, TraceContextProvider}; +use crate::transforms::SystemMessageTransform; + +/// Lifecycle state of a [`Client`](crate::Client) connection to the CLI. +/// +/// The state advances from `Connecting` โ†’ `Connected` during construction, +/// transitions to `Disconnected` after [`Client::stop`](crate::Client::stop) or +/// [`Client::force_stop`](crate::Client::force_stop), and lands in +/// `Error` if startup fails or the underlying transport tears down +/// unexpectedly. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[non_exhaustive] +pub enum ConnectionState { + /// No CLI process is attached or the process has exited cleanly. + Disconnected, + /// The client is starting up (spawning the CLI, negotiating protocol). + Connecting, + /// The client is connected and ready to handle RPC traffic. + Connected, + /// Startup failed or the connection encountered an unrecoverable error. + Error, +} + +/// Type of [`SessionLifecycleEvent`] received via [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle). +/// +/// Values serialize as the dotted JSON strings the CLI sends (e.g. +/// `"session.created"`). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[non_exhaustive] +pub enum SessionLifecycleEventType { + /// A new session was created. + #[serde(rename = "session.created")] + Created, + /// A session was deleted. + #[serde(rename = "session.deleted")] + Deleted, + /// A session's metadata was updated (e.g. summary regenerated). + #[serde(rename = "session.updated")] + Updated, + /// A session moved into the foreground. + #[serde(rename = "session.foreground")] + Foreground, + /// A session moved into the background. + #[serde(rename = "session.background")] + Background, +} + +/// Optional metadata attached to a [`SessionLifecycleEvent`]. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SessionLifecycleEventMetadata { + /// ISO-8601 timestamp the session was created. + #[serde(rename = "startTime")] + pub start_time: String, + /// ISO-8601 timestamp the session was last modified. + #[serde(rename = "modifiedTime")] + pub modified_time: String, + /// Optional generated summary of the session conversation so far. + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +/// A `session.lifecycle` notification dispatched to subscribers obtained via +/// [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle). +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SessionLifecycleEvent { + /// The kind of lifecycle change this event represents. + #[serde(rename = "type")] + pub event_type: SessionLifecycleEventType, + /// Identifier of the session this event refers to. + #[serde(rename = "sessionId")] + pub session_id: SessionId, + /// Optional metadata describing the session at the time of the event. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// Opaque session identifier assigned by the CLI. +/// +/// A newtype wrapper around `String` that provides type safety โ€” prevents +/// accidentally passing a workspace ID or request ID where a session ID +/// is expected. Derefs to `str` for zero-friction borrowing. +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct SessionId(String); + +impl SessionId { + /// Create a new session ID from any string-like value. + pub fn new(id: impl Into) -> Self { + Self(id.into()) + } + + /// Borrow the inner string. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Consume the wrapper, returning the inner string. + pub fn into_inner(self) -> String { + self.0 + } +} + +impl std::ops::Deref for SessionId { + type Target = str; + + fn deref(&self) -> &str { + &self.0 + } +} + +impl std::fmt::Display for SessionId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl From for SessionId { + fn from(s: String) -> Self { + Self(s) + } +} + +impl From<&str> for SessionId { + fn from(s: &str) -> Self { + Self(s.to_owned()) + } +} + +impl AsRef for SessionId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl std::borrow::Borrow for SessionId { + fn borrow(&self) -> &str { + &self.0 + } +} + +impl From for String { + fn from(id: SessionId) -> String { + id.0 + } +} + +impl PartialEq for SessionId { + fn eq(&self, other: &str) -> bool { + self.0 == other + } +} + +impl PartialEq for SessionId { + fn eq(&self, other: &String) -> bool { + &self.0 == other + } +} + +impl PartialEq for String { + fn eq(&self, other: &SessionId) -> bool { + self == &other.0 + } +} + +impl PartialEq<&str> for SessionId { + fn eq(&self, other: &&str) -> bool { + self.0 == *other + } +} + +impl PartialEq<&SessionId> for SessionId { + fn eq(&self, other: &&SessionId) -> bool { + self.0 == other.0 + } +} + +impl PartialEq for &SessionId { + fn eq(&self, other: &SessionId) -> bool { + self.0 == other.0 + } +} + +/// Opaque request identifier for pending CLI requests (permission, user-input, etc.). +/// +/// A newtype wrapper around `String` that provides type safety โ€” prevents +/// accidentally passing a session ID or workspace ID where a request ID +/// is expected. Derefs to `str` for zero-friction borrowing. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct RequestId(String); + +impl RequestId { + /// Create a new request ID from any string-like value. + pub fn new(id: impl Into) -> Self { + Self(id.into()) + } + + /// Consume the wrapper, returning the inner string. + pub fn into_inner(self) -> String { + self.0 + } +} + +impl std::ops::Deref for RequestId { + type Target = str; + + fn deref(&self) -> &str { + &self.0 + } +} + +impl std::fmt::Display for RequestId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl From for RequestId { + fn from(s: String) -> Self { + Self(s) + } +} + +impl From<&str> for RequestId { + fn from(s: &str) -> Self { + Self(s.to_owned()) + } +} + +impl AsRef for RequestId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl std::borrow::Borrow for RequestId { + fn borrow(&self) -> &str { + &self.0 + } +} + +impl From for String { + fn from(id: RequestId) -> String { + id.0 + } +} + +impl PartialEq for RequestId { + fn eq(&self, other: &str) -> bool { + self.0 == other + } +} + +impl PartialEq for RequestId { + fn eq(&self, other: &String) -> bool { + &self.0 == other + } +} + +impl PartialEq for String { + fn eq(&self, other: &RequestId) -> bool { + self == &other.0 + } +} + +impl PartialEq<&str> for RequestId { + fn eq(&self, other: &&str) -> bool { + self.0 == *other + } +} + +/// A tool that the client exposes to the Copilot agent. +/// +/// Sent to the CLI as part of [`SessionConfig::tools`] / [`ResumeSessionConfig::tools`] +/// at session creation/resume time. The Rust SDK hand-authors this struct +/// (rather than using the schema-generated form) so it can carry runtime +/// hints โ€” `overrides_built_in_tool`, `skip_permission` โ€” that don't appear +/// in the wire schema but are honored by the CLI. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct Tool { + /// Tool identifier (e.g., `"bash"`, `"grep"`, `"str_replace_editor"`). + pub name: String, + /// Optional namespaced name for declarative filtering (e.g., `"playwright/navigate"` + /// for MCP tools). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub namespaced_name: Option, + /// Description of what the tool does. + #[serde(default)] + pub description: String, + /// Optional instructions for how to use this tool effectively. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub instructions: Option, + /// JSON Schema for the tool's input parameters. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub parameters: HashMap, + /// When `true`, this tool replaces a built-in tool of the same name + /// (e.g. supplying a custom `grep` that the agent uses in place of the + /// CLI's built-in implementation). + #[serde(default, skip_serializing_if = "is_false")] + pub overrides_built_in_tool: bool, + /// When `true`, the CLI does not request permission before invoking + /// this tool. Use with caution โ€” the tool is responsible for any + /// access control. + #[serde(default, skip_serializing_if = "is_false")] + pub skip_permission: bool, +} + +#[inline] +fn is_false(b: &bool) -> bool { + !*b +} + +impl Tool { + /// Construct a new [`Tool`] with the given name and otherwise default + /// values. The struct is `#[non_exhaustive]`, so external callers + /// cannot use struct-literal syntax โ€” use this builder or + /// [`Default::default`] plus mut-let. + /// + /// # Example + /// + /// ``` + /// # use github_copilot_sdk::types::Tool; + /// # use serde_json::json; + /// let tool = Tool::new("greet") + /// .with_description("Say hello to a user") + /// .with_parameters(json!({ + /// "type": "object", + /// "properties": { "name": { "type": "string" } }, + /// "required": ["name"] + /// })); + /// # let _ = tool; + /// ``` + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + ..Default::default() + } + } + + /// Set the namespaced name for declarative filtering (e.g. + /// `"playwright/navigate"` for MCP tools). + pub fn with_namespaced_name(mut self, namespaced_name: impl Into) -> Self { + self.namespaced_name = Some(namespaced_name.into()); + self + } + + /// Set the human-readable description of what the tool does. + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = description.into(); + self + } + + /// Set optional instructions for how to use this tool effectively. + pub fn with_instructions(mut self, instructions: impl Into) -> Self { + self.instructions = Some(instructions.into()); + self + } + + /// Set the JSON Schema for the tool's input parameters. + /// + /// Accepts anything that converts into a JSON object, including a + /// `serde_json::Value` produced by `json!({...})`. Non-object values + /// are stored as an empty parameter map; callers that need direct + /// control over the field can construct a `HashMap` + /// and assign it to [`Tool::parameters`] via [`Default::default`]. + pub fn with_parameters(mut self, parameters: Value) -> Self { + self.parameters = parameters + .as_object() + .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()) + .unwrap_or_default(); + self + } + + /// Mark this tool as overriding a built-in tool of the same name. + /// E.g. supplying a custom `grep` that the agent uses in place of the + /// CLI's built-in implementation. + pub fn with_overrides_built_in_tool(mut self, overrides: bool) -> Self { + self.overrides_built_in_tool = overrides; + self + } + + /// When `true`, the CLI will not request permission before invoking + /// this tool. Use with caution โ€” the tool is responsible for any + /// access control. + pub fn with_skip_permission(mut self, skip: bool) -> Self { + self.skip_permission = skip; + self + } +} + +/// Context passed to a [`CommandHandler`] when a registered slash command +/// is executed by the user. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct CommandContext { + /// Session ID where the command was invoked. + pub session_id: SessionId, + /// The full command text (e.g. `"/deploy production"`). + pub command: String, + /// Command name without the leading `/` (e.g. `"deploy"`). + pub command_name: String, + /// Raw argument string after the command name (e.g. `"production"`). + pub args: String, +} + +/// Handler invoked when a registered slash command is executed. +/// +/// Returning `Err(_)` causes the SDK to forward the error message back to +/// the CLI via `session.commands.handlePendingCommand` so the TUI can +/// surface it. Returning `Ok(())` reports success. +#[async_trait::async_trait] +pub trait CommandHandler: Send + Sync { + /// Called when the user invokes the command this handler is registered for. + async fn on_command(&self, ctx: CommandContext) -> Result<(), crate::Error>; +} + +/// Definition of a slash command registered with the session. +/// +/// When the CLI is running with a TUI, registered commands appear as +/// `/name` for the user to invoke. Only `name` and `description` are sent +/// over the wire โ€” the handler is local to this SDK process. +#[non_exhaustive] +#[derive(Clone)] +pub struct CommandDefinition { + /// Command name (without leading `/`). + pub name: String, + /// Human-readable description shown in command-completion UI. + pub description: Option, + /// Handler invoked when the command is executed. + pub handler: Arc, +} + +impl CommandDefinition { + /// Construct a new command definition. Use [`with_description`](Self::with_description) + /// to add a description. + pub fn new(name: impl Into, handler: Arc) -> Self { + Self { + name: name.into(), + description: None, + handler, + } + } + + /// Set the human-readable description shown in the CLI's command-completion UI. + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } +} + +impl std::fmt::Debug for CommandDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CommandDefinition") + .field("name", &self.name) + .field("description", &self.description) + .field("handler", &"") + .finish() + } +} + +impl Serialize for CommandDefinition { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeStruct; + let len = if self.description.is_some() { 2 } else { 1 }; + let mut state = serializer.serialize_struct("CommandDefinition", len)?; + state.serialize_field("name", &self.name)?; + if let Some(description) = &self.description { + state.serialize_field("description", description)?; + } + state.end() + } +} + +/// Configures a custom agent (sub-agent) for the session. +/// +/// Custom agents have their own prompt, tool allowlist, and optionally +/// their own MCP servers and skill set. The agent named in +/// [`SessionConfig::agent`] (or the runtime default) is the active one +/// when the session starts. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct CustomAgentConfig { + /// Unique name of the custom agent. + pub name: String, + /// Display name for UI purposes. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub display_name: Option, + /// Description of what the agent does. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub description: Option, + /// List of tool names the agent can use. `None` means all tools. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, + /// Prompt content for the agent. + pub prompt: String, + /// MCP servers specific to this agent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + /// Whether the agent is available for model inference. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub infer: Option, + /// Skill names to preload into this agent's context at startup. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub skills: Option>, +} + +impl CustomAgentConfig { + /// Construct a custom agent configuration with the required `name` + /// and `prompt` fields populated. + /// + /// All other fields default to unset; use the `with_*` chain to + /// customize them. Fields are also `pub` if direct assignment is + /// preferred for `Option` pass-through. + pub fn new(name: impl Into, prompt: impl Into) -> Self { + Self { + name: name.into(), + prompt: prompt.into(), + ..Self::default() + } + } + + /// Set the display name shown in the CLI's agent-selection UI. + pub fn with_display_name(mut self, display_name: impl Into) -> Self { + self.display_name = Some(display_name.into()); + self + } + + /// Set the description of what the agent does. + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + /// Restrict the agent to a specific tool allowlist. When unset, the + /// agent inherits the parent session's tool set. + pub fn with_tools(mut self, tools: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.tools = Some(tools.into_iter().map(Into::into).collect()); + self + } + + /// Configure agent-specific MCP servers. + pub fn with_mcp_servers(mut self, mcp_servers: HashMap) -> Self { + self.mcp_servers = Some(mcp_servers); + self + } + + /// Whether the agent participates in model inference. + pub fn with_infer(mut self, infer: bool) -> Self { + self.infer = Some(infer); + self + } + + /// Set the skills preloaded into the agent's context at startup. + pub fn with_skills(mut self, skills: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.skills = Some(skills.into_iter().map(Into::into).collect()); + self + } +} + +/// Configures the default (built-in) agent that handles turns when no +/// custom agent is selected. +/// +/// Use [`Self::excluded_tools`] to hide tools from the default agent +/// while keeping them available to custom sub-agents that list them in +/// their [`CustomAgentConfig::tools`]. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DefaultAgentConfig { + /// Tool names to exclude from the default agent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, +} + +/// Configures infinite sessions: persistent workspaces with automatic +/// context-window compaction. +/// +/// When enabled (default), sessions automatically manage context limits +/// through background compaction and persist state to a workspace +/// directory. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct InfiniteSessionConfig { + /// Whether infinite sessions are enabled. Defaults to `true` on the CLI. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub enabled: Option, + /// Context utilization (0.0โ€“1.0) at which background compaction starts. + /// Default: 0.80. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub background_compaction_threshold: Option, + /// Context utilization (0.0โ€“1.0) at which the session blocks until + /// compaction completes. Default: 0.95. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub buffer_exhaustion_threshold: Option, +} + +impl InfiniteSessionConfig { + /// Construct an empty [`InfiniteSessionConfig`]; all fields default to + /// unset (the CLI applies its own defaults). + pub fn new() -> Self { + Self::default() + } + + /// Toggle infinite sessions on or off. Defaults to `true` on the CLI + /// when unset. + pub fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = Some(enabled); + self + } + + /// Set the context utilization (0.0โ€“1.0) at which background + /// compaction starts. + pub fn with_background_compaction_threshold(mut self, threshold: f64) -> Self { + self.background_compaction_threshold = Some(threshold); + self + } + + /// Set the context utilization (0.0โ€“1.0) at which the session blocks + /// until compaction completes. + pub fn with_buffer_exhaustion_threshold(mut self, threshold: f64) -> Self { + self.buffer_exhaustion_threshold = Some(threshold); + self + } +} + +/// Configuration for a single MCP server. +/// +/// MCP (Model Context Protocol) servers expose external tools to the +/// agent. Local servers run as a subprocess over stdio; remote servers +/// speak HTTP or Server-Sent Events. +/// +/// Serialized as a JSON object with a `type` discriminator (`"stdio"` | +/// `"http"` | `"sse"`). +/// +/// # Example +/// +/// ``` +/// # use github_copilot_sdk::types::{McpServerConfig, McpStdioServerConfig, McpHttpServerConfig}; +/// # use std::collections::HashMap; +/// let mut servers = HashMap::new(); +/// servers.insert( +/// "playwright".to_string(), +/// McpServerConfig::Stdio(McpStdioServerConfig { +/// tools: vec!["*".to_string()], +/// command: "npx".to_string(), +/// args: vec!["-y".to_string(), "@playwright/mcp".to_string()], +/// ..Default::default() +/// }), +/// ); +/// servers.insert( +/// "weather".to_string(), +/// McpServerConfig::Http(McpHttpServerConfig { +/// tools: vec!["forecast".to_string()], +/// url: "https://example.com/mcp".to_string(), +/// ..Default::default() +/// }), +/// ); +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +#[non_exhaustive] +pub enum McpServerConfig { + /// Local MCP server launched as a subprocess and addressed over stdio. + /// On the wire this serializes as `{"type": "stdio", ...}`. The CLI + /// also accepts `"local"` as an alias on input. + #[serde(alias = "local")] + Stdio(McpStdioServerConfig), + /// Remote MCP server addressed over HTTP. + Http(McpHttpServerConfig), + /// Remote MCP server addressed over Server-Sent Events. + Sse(McpHttpServerConfig), +} + +/// Configuration for a local/stdio MCP server. +/// +/// See [`McpServerConfig::Stdio`]. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpStdioServerConfig { + /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none. + #[serde(default)] + pub tools: Vec, + /// Optional timeout in milliseconds for tool calls to this server. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, + /// Subprocess executable. + pub command: String, + /// Arguments to pass to the subprocess. + #[serde(default)] + pub args: Vec, + /// Environment variables to set on the subprocess. + /// + /// Interpretation depends on the parent session's + /// `env_value_mode`: `"direct"` (default) treats values as literals; + /// `"indirect"` treats them as env-var names to look up at start time. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub env: HashMap, + /// Working directory for the subprocess. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cwd: Option, +} + +/// Configuration for a remote MCP server (HTTP or SSE). +/// +/// See [`McpServerConfig::Http`] and [`McpServerConfig::Sse`]. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpHttpServerConfig { + /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none. + #[serde(default)] + pub tools: Vec, + /// Optional timeout in milliseconds for tool calls to this server. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, + /// Server URL. + pub url: String, + /// Optional HTTP headers to include on every request. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub headers: HashMap, +} + +/// Configures a custom inference provider (BYOK โ€” Bring Your Own Key). +/// +/// Routes session requests through an alternative model provider +/// (OpenAI-compatible, Azure, Anthropic, or local) instead of GitHub +/// Copilot's default routing. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct ProviderConfig { + /// Provider type: `"openai"`, `"azure"`, or `"anthropic"`. Defaults to + /// `"openai"` on the CLI. + #[serde(default, skip_serializing_if = "Option::is_none", rename = "type")] + pub provider_type: Option, + /// API format (openai/azure only): `"completions"` or `"responses"`. + /// Defaults to `"completions"`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub wire_api: Option, + /// API endpoint URL. + pub base_url: String, + /// API key. Optional for local providers like Ollama. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_key: Option, + /// Bearer token for authentication. Sets the `Authorization` header + /// directly. Use for services requiring bearer-token auth instead of + /// API key. Takes precedence over `api_key` when both are set. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub bearer_token: Option, + /// Azure-specific options. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub azure: Option, + /// Custom HTTP headers included in outbound provider requests. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub headers: Option>, +} + +impl ProviderConfig { + /// Construct a [`ProviderConfig`] with the required `base_url` set; + /// all other fields default to unset. + pub fn new(base_url: impl Into) -> Self { + Self { + base_url: base_url.into(), + ..Self::default() + } + } + + /// Set the provider type (`"openai"`, `"azure"`, or `"anthropic"`). + pub fn with_provider_type(mut self, provider_type: impl Into) -> Self { + self.provider_type = Some(provider_type.into()); + self + } + + /// Set the API format (`"completions"` or `"responses"`; openai/azure only). + pub fn with_wire_api(mut self, wire_api: impl Into) -> Self { + self.wire_api = Some(wire_api.into()); + self + } + + /// Set the API key. Optional for local providers like Ollama. + pub fn with_api_key(mut self, api_key: impl Into) -> Self { + self.api_key = Some(api_key.into()); + self + } + + /// Set the bearer token used to populate the `Authorization` header. + /// Takes precedence over `api_key` when both are set. + pub fn with_bearer_token(mut self, bearer_token: impl Into) -> Self { + self.bearer_token = Some(bearer_token.into()); + self + } + + /// Set Azure-specific options. + pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self { + self.azure = Some(azure); + self + } + + /// Set the custom HTTP headers attached to outbound provider requests. + pub fn with_headers(mut self, headers: HashMap) -> Self { + self.headers = Some(headers); + self + } +} + +/// Azure-specific provider options. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AzureProviderOptions { + /// Azure API version. Defaults to `"2024-10-21"`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_version: Option, +} + +/// Configuration for creating a new session via the `session.create` RPC. +/// +/// All fields are optional โ€” the CLI applies sensible defaults. +/// +/// # Construction +/// +/// Two equivalent shapes are supported: +/// +/// 1. **Chained builder** (preferred for compile-time-known values): +/// +/// ``` +/// # use github_copilot_sdk::types::SessionConfig; +/// let cfg = SessionConfig::default() +/// .with_client_name("my-app") +/// .with_streaming(true) +/// .with_enable_config_discovery(true); +/// ``` +/// +/// 2. **Direct field assignment** (preferred when forwarding `Option` +/// from upstream code, since `with_` setters take the inner +/// `T`, not `Option`): +/// +/// ``` +/// # use github_copilot_sdk::types::SessionConfig; +/// # let upstream_model: Option = None; +/// # let upstream_system_message: Option = None; +/// let mut cfg = SessionConfig::default() +/// .with_client_name("my-app") +/// .with_streaming(true); +/// cfg.model = upstream_model; +/// cfg.system_message = upstream_system_message; +/// ``` +/// +/// Mixing the two is fine: chain the fields you know at compile time, +/// then assign the `Option` pass-through fields directly. All +/// fields on this struct are `pub`. This pattern matches the +/// `http::request::Parts` / `hyper::Body::Builder` convention in the +/// wider Rust ecosystem. +/// +/// # Field naming across SDKs +/// +/// Rust field names are snake_case (`available_tools`, `system_message`); +/// they round-trip to the camelCase wire protocol via `#[serde(rename_all = +/// "camelCase")]`. When porting code from the TypeScript, Go, Python, or +/// .NET SDKs โ€” or reading the raw JSON-RPC traces โ€” fields appear as +/// `availableTools`, `systemMessage`, etc. +#[derive(Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct SessionConfig { + /// Custom session ID. When unset, the CLI generates one. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, + /// Model to use (e.g. `"gpt-4"`, `"claude-sonnet-4"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Application name sent as `User-Agent` context. + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + /// Reasoning effort level (e.g. `"low"`, `"medium"`, `"high"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + /// Enable streaming token deltas via `assistant.message_delta` events. + #[serde(skip_serializing_if = "Option::is_none")] + pub streaming: Option, + /// Custom system message configuration. + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option, + /// Client-defined tools to expose to the agent. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + /// Allowlist of built-in tool names the agent may use. + #[serde(skip_serializing_if = "Option::is_none")] + pub available_tools: Option>, + /// Blocklist of built-in tool names the agent must not use. + #[serde(skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, + /// MCP server configurations passed through to the CLI. + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + /// How the CLI interprets env values in MCP server configs. + /// `"direct"` = literal values; `"indirect"` = env var names to look up. + #[serde(skip_serializing_if = "Option::is_none")] + pub env_value_mode: Option, + /// When true, the CLI runs config discovery (MCP config files, skills, plugins). + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_config_discovery: Option, + /// Enable the `ask_user` tool for interactive user input. Defaults to + /// `Some(true)` via [`SessionConfig::default`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_user_input: Option, + /// Enable `permission.request` JSON-RPC calls from the CLI. Defaults + /// to `Some(true)` via [`SessionConfig::default`]; the default + /// [`DenyAllHandler`](crate::handler::DenyAllHandler) refuses all + /// requests so the wire surface is safe out-of-the-box. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_permission: Option, + /// Enable `exitPlanMode.request` JSON-RPC calls for plan approval. + /// Defaults to `Some(true)` via [`SessionConfig::default`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_exit_plan_mode: Option, + /// Enable `autoModeSwitch.request` JSON-RPC calls. When `true`, the CLI + /// asks the handler whether to switch to auto model when an eligible + /// rate limit is hit. Defaults to `Some(true)` via + /// [`SessionConfig::default`]. Without this flag, the CLI surfaces the + /// rate-limit error directly without offering the auto-mode switch. + /// + /// Currently a Rust-only typed handler; cross-SDK parity (Node / + /// Python / Go / .NET) is post-release follow-up work โ€” see + /// [`SessionHandler::on_auto_mode_switch`]. + /// + /// [`SessionHandler::on_auto_mode_switch`]: crate::handler::SessionHandler::on_auto_mode_switch + #[serde(skip_serializing_if = "Option::is_none")] + pub request_auto_mode_switch: Option, + /// Advertise elicitation provider capability. When true, the CLI sends + /// `elicitation.requested` events that the handler can respond to. + /// Defaults to `Some(true)` via [`SessionConfig::default`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_elicitation: Option, + /// Skill directory paths passed through to the GitHub Copilot CLI. + #[serde(skip_serializing_if = "Option::is_none")] + pub skill_directories: Option>, + /// Additional directories to search for custom instruction files. + /// Forwarded to the CLI; not the same as [`skill_directories`](Self::skill_directories). + #[serde(skip_serializing_if = "Option::is_none")] + pub instruction_directories: Option>, + /// Skill names to disable. Skills in this set will not be available + /// even if found in skill directories. + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_skills: Option>, + /// MCP server names to disable. Servers in this set will not be + /// started or connected. + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_mcp_servers: Option>, + /// Enable session hooks. When `true`, the CLI sends `hooks.invoke` + /// RPC requests at key lifecycle points (pre/post tool use, prompt + /// submission, session start/end, errors). + #[serde(skip_serializing_if = "Option::is_none")] + pub hooks: Option, + /// Custom agents (sub-agents) configured for this session. + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_agents: Option>, + /// Configures the built-in default agent. Use `excluded_tools` to + /// hide tools from the default agent while keeping them available + /// to custom sub-agents that reference them in their `tools` list. + #[serde(skip_serializing_if = "Option::is_none")] + pub default_agent: Option, + /// Name of the custom agent to activate when the session starts. + /// Must match the `name` of one of the agents in [`Self::custom_agents`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub agent: Option, + /// Configures infinite sessions: persistent workspace + automatic + /// context-window compaction. Enabled by default on the CLI. + #[serde(skip_serializing_if = "Option::is_none")] + pub infinite_sessions: Option, + /// Custom model provider (BYOK). When set, the session routes + /// requests through this provider instead of the default Copilot + /// routing. + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, + /// Per-property overrides for model capabilities, deep-merged over + /// runtime defaults. + #[serde(skip_serializing_if = "Option::is_none")] + pub model_capabilities: Option, + /// Override the default configuration directory location. When set, + /// the session uses this directory for storing config and state. + #[serde(skip_serializing_if = "Option::is_none")] + pub config_dir: Option, + /// Working directory for the session. Tool operations resolve + /// relative paths against this directory. + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + /// Per-session GitHub token. Distinct from + /// [`ClientOptions::github_token`](crate::ClientOptions::github_token), + /// which authenticates the CLI process itself; this token determines + /// the GitHub identity used for content exclusion, model routing, and + /// quota checks for *this session*. + #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")] + pub github_token: Option, + /// Forward sub-agent streaming events to this connection. When false, + /// only non-streaming sub-agent events and `subagent.*` lifecycle events + /// are delivered. Defaults to true on the CLI. + #[serde(skip_serializing_if = "Option::is_none")] + pub include_sub_agent_streaming_events: Option, + /// Slash commands registered for this session. When the CLI has a TUI, + /// each command appears as `/name` for the user to invoke and the + /// associated [`CommandHandler`] is called when executed. + #[serde(skip_serializing_if = "Option::is_none", skip_deserializing)] + pub commands: Option>, + /// Custom session filesystem provider for this session. Required when + /// the [`Client`](crate::Client) was started with + /// [`ClientOptions::session_fs`](crate::ClientOptions::session_fs) set. + /// See [`SessionFsProvider`]. + #[serde(skip)] + pub session_fs_provider: Option>, + /// Session-level event handler. The default is + /// [`DenyAllHandler`](crate::handler::DenyAllHandler) โ€” permission + /// requests are denied; other events are no-ops. Use + /// [`with_handler`](Self::with_handler) to install a custom handler. + #[serde(skip)] + pub handler: Option>, + /// Session lifecycle hook handler (pre/post tool use, session + /// start/end, etc.). When set, the SDK auto-enables the wire-level + /// `hooks` flag. Use [`with_hooks`](Self::with_hooks) to install one. + #[serde(skip)] + pub hooks_handler: Option>, + /// System-message transform. When set, the SDK injects the matching + /// `action: "transform"` sections into the system message and routes + /// `systemMessage.transform` RPC callbacks to it during the session. + /// Use [`with_transform`](Self::with_transform) to install one. + #[serde(skip)] + pub transform: Option>, +} + +impl std::fmt::Debug for SessionConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionConfig") + .field("session_id", &self.session_id) + .field("model", &self.model) + .field("client_name", &self.client_name) + .field("reasoning_effort", &self.reasoning_effort) + .field("streaming", &self.streaming) + .field("system_message", &self.system_message) + .field("tools", &self.tools) + .field("available_tools", &self.available_tools) + .field("excluded_tools", &self.excluded_tools) + .field("mcp_servers", &self.mcp_servers) + .field("env_value_mode", &self.env_value_mode) + .field("enable_config_discovery", &self.enable_config_discovery) + .field("request_user_input", &self.request_user_input) + .field("request_permission", &self.request_permission) + .field("request_exit_plan_mode", &self.request_exit_plan_mode) + .field("request_auto_mode_switch", &self.request_auto_mode_switch) + .field("request_elicitation", &self.request_elicitation) + .field("skill_directories", &self.skill_directories) + .field("instruction_directories", &self.instruction_directories) + .field("disabled_skills", &self.disabled_skills) + .field("disabled_mcp_servers", &self.disabled_mcp_servers) + .field("hooks", &self.hooks) + .field("custom_agents", &self.custom_agents) + .field("default_agent", &self.default_agent) + .field("agent", &self.agent) + .field("infinite_sessions", &self.infinite_sessions) + .field("provider", &self.provider) + .field("model_capabilities", &self.model_capabilities) + .field("config_dir", &self.config_dir) + .field("working_directory", &self.working_directory) + .field( + "github_token", + &self.github_token.as_ref().map(|_| ""), + ) + .field( + "include_sub_agent_streaming_events", + &self.include_sub_agent_streaming_events, + ) + .field("commands", &self.commands) + .field( + "session_fs_provider", + &self.session_fs_provider.as_ref().map(|_| ""), + ) + .field("handler", &self.handler.as_ref().map(|_| "")) + .field( + "hooks_handler", + &self.hooks_handler.as_ref().map(|_| ""), + ) + .field("transform", &self.transform.as_ref().map(|_| "")) + .finish() + } +} + +impl Default for SessionConfig { + /// Permission and elicitation flows are enabled by default. With + /// Rust's trait-based handlers, the SDK installs `DenyAllHandler` when + /// no handler is provided, so these flags being `Some(true)` means the + /// wire surface advertises the capabilities โ€” and the default handler + /// safely refuses requests. Callers that want the wire surface fully + /// disabled set these explicitly to `Some(false)`. + fn default() -> Self { + Self { + session_id: None, + model: None, + client_name: None, + reasoning_effort: None, + streaming: None, + system_message: None, + tools: None, + available_tools: None, + excluded_tools: None, + mcp_servers: None, + env_value_mode: None, + enable_config_discovery: None, + request_user_input: Some(true), + request_permission: Some(true), + request_exit_plan_mode: Some(true), + request_auto_mode_switch: Some(true), + request_elicitation: Some(true), + skill_directories: None, + instruction_directories: None, + disabled_skills: None, + disabled_mcp_servers: None, + hooks: None, + custom_agents: None, + default_agent: None, + agent: None, + infinite_sessions: None, + provider: None, + model_capabilities: None, + config_dir: None, + working_directory: None, + github_token: None, + include_sub_agent_streaming_events: None, + commands: None, + session_fs_provider: None, + handler: None, + hooks_handler: None, + transform: None, + } + } +} + +impl SessionConfig { + /// Install a custom [`SessionHandler`] for this session. + pub fn with_handler(mut self, handler: Arc) -> Self { + self.handler = Some(handler); + self + } + + /// Register slash commands for this session. Each command appears as + /// `/name` in the CLI's TUI; the handler is invoked when the user + /// executes the command. Replaces any commands previously set on this + /// config. See [`CommandDefinition`]. + pub fn with_commands(mut self, commands: Vec) -> Self { + self.commands = Some(commands); + self + } + + /// Install a [`SessionFsProvider`] backing the session's filesystem. + /// Required when the [`Client`](crate::Client) was started with + /// [`ClientOptions::session_fs`](crate::ClientOptions::session_fs). + pub fn with_session_fs_provider(mut self, provider: Arc) -> Self { + self.session_fs_provider = Some(provider); + self + } + + /// Install a [`SessionHooks`] handler. Automatically enables the + /// wire-level `hooks` flag on session creation. + pub fn with_hooks(mut self, hooks: Arc) -> Self { + self.hooks_handler = Some(hooks); + self + } + + /// Install a [`SystemMessageTransform`]. The SDK injects the matching + /// `action: "transform"` sections into the system message and routes + /// `systemMessage.transform` RPC callbacks to it during the session. + pub fn with_transform(mut self, transform: Arc) -> Self { + self.transform = Some(transform); + self + } + + /// Wrap the configured handler so every permission request is + /// auto-approved. Forwards every non-permission event to the inner + /// handler unchanged. + /// + /// If no handler has been installed via [`with_handler`](Self::with_handler), + /// wraps a [`DenyAllHandler`](crate::handler::DenyAllHandler) โ€” useful + /// when you only care about permission policy and want the trait + /// fallback responses for everything else. + /// + /// Order-independent: `with_handler(...).approve_all_permissions()` and + /// `approve_all_permissions().with_handler(...)` are NOT equivalent โ€” + /// the second form discards the wrap because `with_handler` overwrites + /// the handler field. Always call `approve_all_permissions` *after* + /// `with_handler`. + pub fn approve_all_permissions(mut self) -> Self { + let inner = self + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + self.handler = Some(crate::permission::approve_all(inner)); + self + } + + /// Wrap the configured handler so every permission request is + /// auto-denied. See [`approve_all_permissions`](Self::approve_all_permissions) + /// for ordering and default-handler semantics. + pub fn deny_all_permissions(mut self) -> Self { + let inner = self + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + self.handler = Some(crate::permission::deny_all(inner)); + self + } + + /// Wrap the configured handler with a closure-based permission policy: + /// `predicate` is called for each permission request; `true` approves, + /// `false` denies. See + /// [`approve_all_permissions`](Self::approve_all_permissions) for + /// ordering and default-handler semantics. + pub fn approve_permissions_if(mut self, predicate: F) -> Self + where + F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static, + { + let inner = self + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + self.handler = Some(crate::permission::approve_if(inner, predicate)); + self + } + + /// Set a custom session ID (when unset, the CLI generates one). + pub fn with_session_id(mut self, id: impl Into) -> Self { + self.session_id = Some(id.into()); + self + } + + /// Set the model identifier (e.g. `"claude-sonnet-4"`). + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = Some(model.into()); + self + } + + /// Set the application name sent as `User-Agent` context. + pub fn with_client_name(mut self, name: impl Into) -> Self { + self.client_name = Some(name.into()); + self + } + + /// Set the reasoning effort level (e.g. `"low"`, `"medium"`, `"high"`). + pub fn with_reasoning_effort(mut self, effort: impl Into) -> Self { + self.reasoning_effort = Some(effort.into()); + self + } + + /// Enable streaming token deltas via `assistant.message_delta` events. + pub fn with_streaming(mut self, streaming: bool) -> Self { + self.streaming = Some(streaming); + self + } + + /// Set a custom system message configuration. + pub fn with_system_message(mut self, system_message: SystemMessageConfig) -> Self { + self.system_message = Some(system_message); + self + } + + /// Set the client-defined tools to expose to the agent. + pub fn with_tools>(mut self, tools: I) -> Self { + self.tools = Some(tools.into_iter().collect()); + self + } + + /// Set the allowlist of built-in tool names the agent may use. + pub fn with_available_tools(mut self, tools: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.available_tools = Some(tools.into_iter().map(Into::into).collect()); + self + } + + /// Set the blocklist of built-in tool names the agent must not use. + pub fn with_excluded_tools(mut self, tools: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.excluded_tools = Some(tools.into_iter().map(Into::into).collect()); + self + } + + /// Set MCP server configurations passed through to the CLI. + pub fn with_mcp_servers(mut self, servers: HashMap) -> Self { + self.mcp_servers = Some(servers); + self + } + + /// Set how the CLI interprets env values in MCP server configs + /// (`"direct"` literal vs `"indirect"` env var name lookup). + pub fn with_env_value_mode(mut self, mode: impl Into) -> Self { + self.env_value_mode = Some(mode.into()); + self + } + + /// Enable or disable CLI config discovery (MCP config files, skills, plugins). + pub fn with_enable_config_discovery(mut self, enable: bool) -> Self { + self.enable_config_discovery = Some(enable); + self + } + + /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::default`]. + pub fn with_request_user_input(mut self, enable: bool) -> Self { + self.request_user_input = Some(enable); + self + } + + /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`. + pub fn with_request_permission(mut self, enable: bool) -> Self { + self.request_permission = Some(enable); + self + } + + /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. + pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { + self.request_exit_plan_mode = Some(enable); + self + } + + /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`. + pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self { + self.request_auto_mode_switch = Some(enable); + self + } + + /// Advertise elicitation provider capability. Defaults to `Some(true)`. + pub fn with_request_elicitation(mut self, enable: bool) -> Self { + self.request_elicitation = Some(enable); + self + } + + /// Set skill directory paths passed through to the CLI. + pub fn with_skill_directories(mut self, paths: I) -> Self + where + I: IntoIterator, + P: Into, + { + self.skill_directories = Some(paths.into_iter().map(Into::into).collect()); + self + } + + /// Set additional directories to search for custom instruction files. + /// Forwarded to the CLI on session create; not the same as + /// [`with_skill_directories`](Self::with_skill_directories). + pub fn with_instruction_directories(mut self, paths: I) -> Self + where + I: IntoIterator, + P: Into, + { + self.instruction_directories = Some(paths.into_iter().map(Into::into).collect()); + self + } + + /// Set the names of skills to disable (overrides skill discovery). + pub fn with_disabled_skills(mut self, names: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.disabled_skills = Some(names.into_iter().map(Into::into).collect()); + self + } + + /// Set the names of MCP servers to disable. + pub fn with_disabled_mcp_servers(mut self, names: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.disabled_mcp_servers = Some(names.into_iter().map(Into::into).collect()); + self + } + + /// Set the custom agents (sub-agents) configured for this session. + pub fn with_custom_agents>( + mut self, + agents: I, + ) -> Self { + self.custom_agents = Some(agents.into_iter().collect()); + self + } + + /// Configure the built-in default agent. + pub fn with_default_agent(mut self, agent: DefaultAgentConfig) -> Self { + self.default_agent = Some(agent); + self + } + + /// Activate a named custom agent on session start. Must match the + /// `name` of one of the agents in [`Self::custom_agents`]. + pub fn with_agent(mut self, name: impl Into) -> Self { + self.agent = Some(name.into()); + self + } + + /// Configure infinite sessions (persistent workspace + automatic + /// context-window compaction). + pub fn with_infinite_sessions(mut self, config: InfiniteSessionConfig) -> Self { + self.infinite_sessions = Some(config); + self + } + + /// Configure a custom model provider (BYOK). + pub fn with_provider(mut self, provider: ProviderConfig) -> Self { + self.provider = Some(provider); + self + } + + /// Set per-property overrides for model capabilities. + pub fn with_model_capabilities( + mut self, + capabilities: crate::generated::api_types::ModelCapabilitiesOverride, + ) -> Self { + self.model_capabilities = Some(capabilities); + self + } + + /// Override the default configuration directory location. + pub fn with_config_dir(mut self, dir: impl Into) -> Self { + self.config_dir = Some(dir.into()); + self + } + + /// Set the per-session working directory. Tool operations resolve + /// relative paths against this directory. + pub fn with_working_directory(mut self, dir: impl Into) -> Self { + self.working_directory = Some(dir.into()); + self + } + + /// Set the per-session GitHub token. Distinct from + /// [`ClientOptions::github_token`](crate::ClientOptions::github_token); + /// this token determines the GitHub identity used for content exclusion, + /// model routing, and quota checks for this session only. + pub fn with_github_token(mut self, token: impl Into) -> Self { + self.github_token = Some(token.into()); + self + } + + /// Forward sub-agent streaming events to this connection. Defaults + /// to true on the CLI when unset. + pub fn with_include_sub_agent_streaming_events(mut self, include: bool) -> Self { + self.include_sub_agent_streaming_events = Some(include); + self + } +} + +/// Configuration for resuming an existing session via the `session.resume` RPC. +/// +/// See [`SessionConfig`] for the construction patterns (chained `with_*` +/// builder vs. direct field assignment for `Option` pass-through) and +/// the note on snake_case vs. camelCase field naming. +#[derive(Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct ResumeSessionConfig { + /// ID of the session to resume. + pub session_id: SessionId, + /// Application name sent as User-Agent context. + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + /// Enable streaming token deltas. + #[serde(skip_serializing_if = "Option::is_none")] + pub streaming: Option, + /// Re-supply the system message so the agent retains workspace context + /// across CLI process restarts. + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option, + /// Client-defined tools to re-supply on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + /// Blocklist of built-in tool names. + #[serde(skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, + /// Re-supply MCP servers so they remain available after app restart. + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + /// How the CLI interprets env values in MCP configs. + #[serde(skip_serializing_if = "Option::is_none")] + pub env_value_mode: Option, + /// Enable config discovery on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_config_discovery: Option, + /// Enable the ask_user tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_user_input: Option, + /// Enable permission request RPCs. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_permission: Option, + /// Enable exit-plan-mode request RPCs. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_exit_plan_mode: Option, + /// Enable auto-mode-switch request RPCs on resume. Defaults to + /// `Some(true)` via [`ResumeSessionConfig::new`]. See + /// [`SessionConfig::request_auto_mode_switch`] for details. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_auto_mode_switch: Option, + /// Advertise elicitation provider capability on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_elicitation: Option, + /// Skill directory paths passed through to the GitHub Copilot CLI on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub skill_directories: Option>, + /// Additional directories to search for custom instruction files on + /// resume. Forwarded to the CLI; not the same as [`skill_directories`](Self::skill_directories). + #[serde(skip_serializing_if = "Option::is_none")] + pub instruction_directories: Option>, + /// Enable session hooks on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub hooks: Option, + /// Custom agents to re-supply on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_agents: Option>, + /// Configures the built-in default agent on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub default_agent: Option, + /// Name of the custom agent to activate. + #[serde(skip_serializing_if = "Option::is_none")] + pub agent: Option, + /// Re-supply infinite session configuration on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub infinite_sessions: Option, + /// Re-supply BYOK provider configuration on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, + /// Per-property model capability overrides on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub model_capabilities: Option, + /// Override the default configuration directory location on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub config_dir: Option, + /// Per-session working directory on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + /// Per-session GitHub token on resume. See + /// [`SessionConfig::github_token`]. + #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")] + pub github_token: Option, + /// Forward sub-agent streaming events to this connection on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub include_sub_agent_streaming_events: Option, + /// Slash commands registered for this session on resume. See + /// [`SessionConfig::commands`] โ€” commands are not persisted server-side, + /// so the resume payload re-supplies the registration. + #[serde(skip_serializing_if = "Option::is_none", skip_deserializing)] + pub commands: Option>, + /// Custom session filesystem provider. Required on resume when the + /// [`Client`](crate::Client) was started with + /// [`ClientOptions::session_fs`](crate::ClientOptions::session_fs). + /// See [`SessionConfig::session_fs_provider`]. + #[serde(skip)] + pub session_fs_provider: Option>, + /// Force-fail resume if the session does not exist on disk, instead of + /// silently starting a new session. + #[serde(skip_serializing_if = "Option::is_none")] + pub disable_resume: Option, + /// When `true`, instructs the runtime to continue any tool calls or + /// permission requests that were pending when the previous connection + /// was dropped. Use this together with [`Client::force_stop`] to hand + /// off a session from one process to another without losing in-flight + /// work. + /// + /// [`Client::force_stop`]: crate::Client::force_stop + #[serde(skip_serializing_if = "Option::is_none")] + pub continue_pending_work: Option, + /// Session-level event handler. See [`SessionConfig::handler`]. + #[serde(skip)] + pub handler: Option>, + /// Session hook handler. See [`SessionConfig::hooks_handler`]. + #[serde(skip)] + pub hooks_handler: Option>, + /// System-message transform. See [`SessionConfig::transform`]. + #[serde(skip)] + pub transform: Option>, +} + +impl std::fmt::Debug for ResumeSessionConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResumeSessionConfig") + .field("session_id", &self.session_id) + .field("client_name", &self.client_name) + .field("streaming", &self.streaming) + .field("system_message", &self.system_message) + .field("tools", &self.tools) + .field("excluded_tools", &self.excluded_tools) + .field("mcp_servers", &self.mcp_servers) + .field("env_value_mode", &self.env_value_mode) + .field("enable_config_discovery", &self.enable_config_discovery) + .field("request_user_input", &self.request_user_input) + .field("request_permission", &self.request_permission) + .field("request_exit_plan_mode", &self.request_exit_plan_mode) + .field("request_auto_mode_switch", &self.request_auto_mode_switch) + .field("request_elicitation", &self.request_elicitation) + .field("skill_directories", &self.skill_directories) + .field("instruction_directories", &self.instruction_directories) + .field("hooks", &self.hooks) + .field("custom_agents", &self.custom_agents) + .field("default_agent", &self.default_agent) + .field("agent", &self.agent) + .field("infinite_sessions", &self.infinite_sessions) + .field("provider", &self.provider) + .field("model_capabilities", &self.model_capabilities) + .field("config_dir", &self.config_dir) + .field("working_directory", &self.working_directory) + .field( + "github_token", + &self.github_token.as_ref().map(|_| ""), + ) + .field( + "include_sub_agent_streaming_events", + &self.include_sub_agent_streaming_events, + ) + .field("commands", &self.commands) + .field( + "session_fs_provider", + &self.session_fs_provider.as_ref().map(|_| ""), + ) + .field("handler", &self.handler.as_ref().map(|_| "")) + .field( + "hooks_handler", + &self.hooks_handler.as_ref().map(|_| ""), + ) + .field("transform", &self.transform.as_ref().map(|_| "")) + .field("disable_resume", &self.disable_resume) + .field("continue_pending_work", &self.continue_pending_work) + .finish() + } +} + +impl ResumeSessionConfig { + /// Construct a `ResumeSessionConfig` with the given session ID and all + /// other fields left unset. Combine with `.with_*` builders or struct + /// update syntax (`..ResumeSessionConfig::new(id)`) to populate the + /// fields you need. + pub fn new(session_id: SessionId) -> Self { + Self { + session_id, + client_name: None, + streaming: None, + system_message: None, + tools: None, + excluded_tools: None, + mcp_servers: None, + env_value_mode: None, + enable_config_discovery: None, + request_user_input: Some(true), + request_permission: Some(true), + request_exit_plan_mode: Some(true), + request_auto_mode_switch: Some(true), + request_elicitation: Some(true), + skill_directories: None, + instruction_directories: None, + hooks: None, + custom_agents: None, + default_agent: None, + agent: None, + infinite_sessions: None, + provider: None, + model_capabilities: None, + config_dir: None, + working_directory: None, + github_token: None, + include_sub_agent_streaming_events: None, + commands: None, + session_fs_provider: None, + disable_resume: None, + continue_pending_work: None, + handler: None, + hooks_handler: None, + transform: None, + } + } + + /// Install a custom [`SessionHandler`] for this session. + pub fn with_handler(mut self, handler: Arc) -> Self { + self.handler = Some(handler); + self + } + + /// Install a [`SessionHooks`] handler. Automatically enables the + /// wire-level `hooks` flag on session resumption. + pub fn with_hooks(mut self, hooks: Arc) -> Self { + self.hooks_handler = Some(hooks); + self + } + + /// Install a [`SystemMessageTransform`]. + pub fn with_transform(mut self, transform: Arc) -> Self { + self.transform = Some(transform); + self + } + + /// Register slash commands for the resumed session. See + /// [`SessionConfig::with_commands`] โ€” commands are not persisted + /// server-side, so the resume payload re-supplies the registration. + pub fn with_commands(mut self, commands: Vec) -> Self { + self.commands = Some(commands); + self + } + + /// Install a [`SessionFsProvider`] backing the resumed session's + /// filesystem. See [`SessionConfig::with_session_fs_provider`]. + pub fn with_session_fs_provider(mut self, provider: Arc) -> Self { + self.session_fs_provider = Some(provider); + self + } + + /// Wrap the configured handler so every permission request is + /// auto-approved. See + /// [`SessionConfig::approve_all_permissions`] for semantics. + pub fn approve_all_permissions(mut self) -> Self { + let inner = self + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + self.handler = Some(crate::permission::approve_all(inner)); + self + } + + /// Wrap the configured handler so every permission request is + /// auto-denied. See + /// [`SessionConfig::deny_all_permissions`] for semantics. + pub fn deny_all_permissions(mut self) -> Self { + let inner = self + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + self.handler = Some(crate::permission::deny_all(inner)); + self + } + + /// Wrap the configured handler with a predicate-based permission policy. + /// See [`SessionConfig::approve_permissions_if`] for semantics. + pub fn approve_permissions_if(mut self, predicate: F) -> Self + where + F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static, + { + let inner = self + .handler + .take() + .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler)); + self.handler = Some(crate::permission::approve_if(inner, predicate)); + self + } + + /// Set the application name sent as `User-Agent` context. + pub fn with_client_name(mut self, name: impl Into) -> Self { + self.client_name = Some(name.into()); + self + } + + /// Enable streaming token deltas via `assistant.message_delta` events. + pub fn with_streaming(mut self, streaming: bool) -> Self { + self.streaming = Some(streaming); + self + } + + /// Re-supply the system message so the agent retains workspace context + /// across CLI process restarts. + pub fn with_system_message(mut self, system_message: SystemMessageConfig) -> Self { + self.system_message = Some(system_message); + self + } + + /// Re-supply client-defined tools on resume. + pub fn with_tools>(mut self, tools: I) -> Self { + self.tools = Some(tools.into_iter().collect()); + self + } + + /// Set the blocklist of built-in tool names the agent must not use. + pub fn with_excluded_tools(mut self, tools: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.excluded_tools = Some(tools.into_iter().map(Into::into).collect()); + self + } + + /// Re-supply MCP server configurations on resume. + pub fn with_mcp_servers(mut self, servers: HashMap) -> Self { + self.mcp_servers = Some(servers); + self + } + + /// Set how the CLI interprets env values in MCP configs (`"direct"` / + /// `"indirect"`). + pub fn with_env_value_mode(mut self, mode: impl Into) -> Self { + self.env_value_mode = Some(mode.into()); + self + } + + /// Enable or disable CLI config discovery on resume. + pub fn with_enable_config_discovery(mut self, enable: bool) -> Self { + self.enable_config_discovery = Some(enable); + self + } + + /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::new`]. + pub fn with_request_user_input(mut self, enable: bool) -> Self { + self.request_user_input = Some(enable); + self + } + + /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`. + pub fn with_request_permission(mut self, enable: bool) -> Self { + self.request_permission = Some(enable); + self + } + + /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`. + pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self { + self.request_exit_plan_mode = Some(enable); + self + } + + /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`. + pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self { + self.request_auto_mode_switch = Some(enable); + self + } + + /// Advertise elicitation provider capability on resume. Defaults to `Some(true)`. + pub fn with_request_elicitation(mut self, enable: bool) -> Self { + self.request_elicitation = Some(enable); + self + } + + /// Set skill directory paths passed through to the CLI on resume. + pub fn with_skill_directories(mut self, paths: I) -> Self + where + I: IntoIterator, + P: Into, + { + self.skill_directories = Some(paths.into_iter().map(Into::into).collect()); + self + } + + /// Set additional directories to search for custom instruction files + /// on resume. Forwarded to the CLI; not the same as + /// [`with_skill_directories`](Self::with_skill_directories). + pub fn with_instruction_directories(mut self, paths: I) -> Self + where + I: IntoIterator, + P: Into, + { + self.instruction_directories = Some(paths.into_iter().map(Into::into).collect()); + self + } + + /// Re-supply custom agents on resume. + pub fn with_custom_agents>( + mut self, + agents: I, + ) -> Self { + self.custom_agents = Some(agents.into_iter().collect()); + self + } + + /// Configure the built-in default agent on resume. + pub fn with_default_agent(mut self, agent: DefaultAgentConfig) -> Self { + self.default_agent = Some(agent); + self + } + + /// Activate a named custom agent on resume. + pub fn with_agent(mut self, name: impl Into) -> Self { + self.agent = Some(name.into()); + self + } + + /// Re-supply infinite session configuration on resume. + pub fn with_infinite_sessions(mut self, config: InfiniteSessionConfig) -> Self { + self.infinite_sessions = Some(config); + self + } + + /// Re-supply BYOK provider configuration on resume. + pub fn with_provider(mut self, provider: ProviderConfig) -> Self { + self.provider = Some(provider); + self + } + + /// Set per-property model capability overrides on resume. + pub fn with_model_capabilities( + mut self, + capabilities: crate::generated::api_types::ModelCapabilitiesOverride, + ) -> Self { + self.model_capabilities = Some(capabilities); + self + } + + /// Override the default configuration directory location on resume. + pub fn with_config_dir(mut self, dir: impl Into) -> Self { + self.config_dir = Some(dir.into()); + self + } + + /// Set the per-session working directory on resume. + pub fn with_working_directory(mut self, dir: impl Into) -> Self { + self.working_directory = Some(dir.into()); + self + } + + /// Set the per-session GitHub token on resume. See + /// [`SessionConfig::github_token`] for distinction from the + /// client-level token. + pub fn with_github_token(mut self, token: impl Into) -> Self { + self.github_token = Some(token.into()); + self + } + + /// Forward sub-agent streaming events to this connection on resume. + pub fn with_include_sub_agent_streaming_events(mut self, include: bool) -> Self { + self.include_sub_agent_streaming_events = Some(include); + self + } + + /// Force-fail resume if the session does not exist on disk, instead + /// of silently starting a new session. + pub fn with_disable_resume(mut self, disable: bool) -> Self { + self.disable_resume = Some(disable); + self + } + + /// When `true`, instructs the runtime to continue any tool calls or + /// permission requests that were pending when the previous connection + /// was dropped. Use this together with + /// [`Client::force_stop`](crate::Client::force_stop) to hand off a + /// session from one process to another without losing in-flight work. + pub fn with_continue_pending_work(mut self, continue_pending: bool) -> Self { + self.continue_pending_work = Some(continue_pending); + self + } +} + +/// Controls how the system message is constructed. +/// +/// Use `mode: "append"` (default) to add content after the built-in system +/// message, `"replace"` to substitute it entirely, or `"customize"` for +/// section-level overrides. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct SystemMessageConfig { + /// How content is applied: `"append"` (default), `"replace"`, or `"customize"`. + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + /// Content string to append or replace. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// Section-level overrides (used with `mode: "customize"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub sections: Option>, +} + +impl SystemMessageConfig { + /// Construct an empty [`SystemMessageConfig`]; all fields default to + /// unset. + pub fn new() -> Self { + Self::default() + } + + /// Set the application mode: `"append"` (default), `"replace"`, or + /// `"customize"`. + pub fn with_mode(mut self, mode: impl Into) -> Self { + self.mode = Some(mode.into()); + self + } + + /// Set the system message content (used by `"append"` and `"replace"` + /// modes). + pub fn with_content(mut self, content: impl Into) -> Self { + self.content = Some(content.into()); + self + } + + /// Set the section-level overrides (used with `mode: "customize"`). + pub fn with_sections(mut self, sections: HashMap) -> Self { + self.sections = Some(sections); + self + } +} + +/// An override operation for a single system prompt section. +/// +/// Used within [`SystemMessageConfig::sections`] when `mode` is `"customize"`. +/// The `action` field determines the operation: `"replace"`, `"remove"`, +/// `"append"`, `"prepend"`, or `"transform"`. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SectionOverride { + /// Override action: `"replace"`, `"remove"`, `"append"`, `"prepend"`, or `"transform"`. + #[serde(skip_serializing_if = "Option::is_none")] + pub action: Option, + /// Content for the override operation. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +/// Response from `session.create`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateSessionResult { + /// The CLI-assigned session ID. + pub session_id: SessionId, + /// Workspace directory for the session (infinite sessions). + #[serde(skip_serializing_if = "Option::is_none")] + pub workspace_path: Option, + /// Remote session URL, if the session is running remotely. + #[serde(default, alias = "remote_url")] + pub remote_url: Option, + /// Capabilities negotiated with the CLI for this session. + #[serde(skip_serializing_if = "Option::is_none")] + pub capabilities: Option, +} + +/// Parameters for the `session.sendTelemetry` RPC. +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTelemetryEvent { + /// Telemetry event kind (for example, `"session_shutdown"`). + pub kind: String, + /// Non-restricted string properties to include with the telemetry event. + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option>, + /// Restricted string properties that may contain sensitive data. + #[serde(skip_serializing_if = "Option::is_none")] + pub restricted_properties: Option>, + /// Numeric metrics to include with the telemetry event. + #[serde(skip_serializing_if = "Option::is_none")] + pub metrics: Option>, +} + +/// Severity level for [`Session::log`](crate::session::Session::log) messages. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// Informational message (default). + #[default] + Info, + /// Warning message. + Warning, + /// Error message. + Error, +} + +/// Options for [`Session::log`](crate::session::Session::log). +/// +/// Pass `None` to `log` for defaults (info level, persisted to the session +/// event log on disk). +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LogOptions { + /// Log severity. `None` lets the server pick (defaults to `info`). + #[serde(skip_serializing_if = "Option::is_none")] + pub level: Option, + /// When `Some(true)`, the message is transient and not persisted to the + /// session event log on disk. `None` lets the server pick. + #[serde(skip_serializing_if = "Option::is_none")] + pub ephemeral: Option, +} + +impl LogOptions { + /// Set [`level`](Self::level). + pub fn with_level(mut self, level: LogLevel) -> Self { + self.level = Some(level); + self + } + + /// Set [`ephemeral`](Self::ephemeral). + pub fn with_ephemeral(mut self, ephemeral: bool) -> Self { + self.ephemeral = Some(ephemeral); + self + } +} + +/// Options for [`Session::set_model`](crate::session::Session::set_model). +/// +/// Pass `None` to `set_model` to switch model without any overrides. +#[derive(Debug, Clone, Default)] +pub struct SetModelOptions { + /// Reasoning effort for the new model (e.g. `"low"`, `"medium"`, + /// `"high"`, `"xhigh"`). + pub reasoning_effort: Option, + /// Override individual model capabilities resolved by the runtime. Only + /// fields set on the override are applied; the rest fall back to the + /// runtime-resolved values for the model. + pub model_capabilities: Option, +} + +impl SetModelOptions { + /// Set [`reasoning_effort`](Self::reasoning_effort). + pub fn with_reasoning_effort(mut self, effort: impl Into) -> Self { + self.reasoning_effort = Some(effort.into()); + self + } + + /// Set [`model_capabilities`](Self::model_capabilities). + pub fn with_model_capabilities( + mut self, + caps: crate::generated::api_types::ModelCapabilitiesOverride, + ) -> Self { + self.model_capabilities = Some(caps); + self + } +} + +/// Response from the top-level `ping` RPC. +/// +/// The `protocol_version` field is the most commonly-inspected piece โ€” +/// see [`Client::verify_protocol_version`]. +/// +/// [`Client::verify_protocol_version`]: crate::Client::verify_protocol_version +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PingResponse { + /// The message echoed back by the CLI. + #[serde(default)] + pub message: String, + /// Server-side timestamp (Unix epoch milliseconds). + #[serde(default)] + pub timestamp: i64, + /// The protocol version negotiated by the CLI, if reported. + #[serde(skip_serializing_if = "Option::is_none")] + pub protocol_version: Option, +} + +/// Parameters for the top-level `sendTelemetry` RPC. +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ServerTelemetryEvent { + /// Telemetry event kind (for example, `"app.launched"`). + pub kind: String, + /// SDK client name. Non-allowlisted values are hashed in telemetry. + pub client_name: String, + /// Non-restricted string properties to include with the telemetry event. + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option>, + /// Restricted string properties that may contain sensitive data. + #[serde(skip_serializing_if = "Option::is_none")] + pub restricted_properties: Option>, + /// Numeric metrics to include with the telemetry event. + #[serde(skip_serializing_if = "Option::is_none")] + pub metrics: Option>, +} + +/// Line range for file attachments. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AttachmentLineRange { + /// First line (1-based). + pub start: u32, + /// Last line (inclusive). + pub end: u32, +} + +/// Cursor position within a file selection. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AttachmentSelectionPosition { + /// Line number (0-based). + pub line: u32, + /// Character offset (0-based). + pub character: u32, +} + +/// Range of selected text within a file. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AttachmentSelectionRange { + /// Start position. + pub start: AttachmentSelectionPosition, + /// End position. + pub end: AttachmentSelectionPosition, +} + +/// Type of GitHub reference attachment. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[non_exhaustive] +pub enum GitHubReferenceType { + /// GitHub issue. + Issue, + /// GitHub pull request. + Pr, + /// GitHub discussion. + Discussion, +} + +/// An attachment included with a user message. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde( + tag = "type", + rename_all = "camelCase", + rename_all_fields = "camelCase" +)] +#[non_exhaustive] +pub enum Attachment { + /// A file path, optionally with a line range. + File { + /// Absolute path to the file. + path: PathBuf, + /// Label shown in the UI. + #[serde(skip_serializing_if = "Option::is_none")] + display_name: Option, + /// Optional line range to focus on. + #[serde(skip_serializing_if = "Option::is_none")] + line_range: Option, + }, + /// A directory path. + Directory { + /// Absolute path to the directory. + path: PathBuf, + /// Label shown in the UI. + #[serde(skip_serializing_if = "Option::is_none")] + display_name: Option, + }, + /// A text selection within a file. + Selection { + /// Path to the file containing the selection. + file_path: PathBuf, + /// The selected text content. + text: String, + /// Label shown in the UI. + #[serde(skip_serializing_if = "Option::is_none")] + display_name: Option, + /// Character range of the selection. + selection: AttachmentSelectionRange, + }, + /// Raw binary data (e.g. an image). + Blob { + /// Base64-encoded data. + data: String, + /// MIME type of the data. + mime_type: String, + /// Label shown in the UI. + #[serde(skip_serializing_if = "Option::is_none")] + display_name: Option, + }, + /// A reference to a GitHub issue, PR, or discussion. + #[serde(rename = "github_reference")] + GitHubReference { + /// Issue/PR/discussion number. + number: u64, + /// Title of the referenced item. + title: String, + /// Kind of reference. + reference_type: GitHubReferenceType, + /// Current state (e.g. "open", "closed"). + state: String, + /// URL to the referenced item. + url: String, + }, +} + +impl Attachment { + /// Returns the display name, if set. + pub fn display_name(&self) -> Option<&str> { + match self { + Self::File { display_name, .. } + | Self::Directory { display_name, .. } + | Self::Selection { display_name, .. } + | Self::Blob { display_name, .. } => display_name.as_deref(), + Self::GitHubReference { .. } => None, + } + } + + /// Returns a human-readable label, deriving one from the path if needed. + pub fn label(&self) -> Option { + if let Some(display_name) = self + .display_name() + .map(str::trim) + .filter(|name| !name.is_empty()) + { + return Some(display_name.to_string()); + } + + match self { + Self::GitHubReference { number, title, .. } => Some(if title.trim().is_empty() { + format!("#{}", number) + } else { + title.trim().to_string() + }), + _ => self.derived_display_name(), + } + } + + /// Ensure `display_name` is populated when the variant supports one. + pub fn ensure_display_name(&mut self) { + if self + .display_name() + .map(str::trim) + .is_some_and(|name| !name.is_empty()) + { + return; + } + + let Some(derived_display_name) = self.derived_display_name() else { + return; + }; + + match self { + Self::File { display_name, .. } + | Self::Directory { display_name, .. } + | Self::Selection { display_name, .. } + | Self::Blob { display_name, .. } => *display_name = Some(derived_display_name), + Self::GitHubReference { .. } => {} + } + } + + fn derived_display_name(&self) -> Option { + match self { + Self::File { path, .. } | Self::Directory { path, .. } => { + Some(attachment_name_from_path(path)) + } + Self::Selection { file_path, .. } => Some(attachment_name_from_path(file_path)), + Self::Blob { .. } => Some("attachment".to_string()), + Self::GitHubReference { .. } => None, + } + } +} + +fn attachment_name_from_path(path: &Path) -> String { + path.file_name() + .map(|name| name.to_string_lossy().into_owned()) + .filter(|name| !name.is_empty()) + .unwrap_or_else(|| { + let full = path.to_string_lossy(); + if full.is_empty() { + "attachment".to_string() + } else { + full.into_owned() + } + }) +} + +/// Normalize a list of attachments so every entry has a `display_name`. +pub fn ensure_attachment_display_names(attachments: &mut [Attachment]) { + for attachment in attachments { + attachment.ensure_display_name(); + } +} + +/// Message delivery mode for [`MessageOptions::mode`]. +/// +/// Controls how a prompt is delivered relative to in-flight session work. +/// Wire values: `"enqueue"` and `"immediate"`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[non_exhaustive] +pub enum DeliveryMode { + /// Queue the prompt behind any in-flight work (default). + Enqueue, + /// Interrupt the session and run the prompt immediately. + Immediate, +} + +/// Options for sending a user message to the agent. +/// +/// Used by both [`Session::send`](crate::session::Session::send) and +/// [`Session::send_and_wait`](crate::session::Session::send_and_wait); the +/// `wait_timeout` field is honored only by `send_and_wait` and is ignored by +/// `send`. +/// +/// `MessageOptions` is `#[non_exhaustive]` and constructed via [`MessageOptions::new`] +/// plus the `with_*` chain so future fields can land without breaking callers. +/// For the trivial case, both `&str` and `String` implement `Into`, +/// so: +/// +/// ```no_run +/// # use github_copilot_sdk::session::Session; +/// # async fn run(session: Session) -> Result<(), github_copilot_sdk::Error> { +/// session.send("hello").await?; +/// # Ok(()) } +/// ``` +/// +/// is equivalent to: +/// +/// ```no_run +/// # use github_copilot_sdk::session::Session; +/// # use github_copilot_sdk::types::MessageOptions; +/// # async fn run(session: Session) -> Result<(), github_copilot_sdk::Error> { +/// session.send(MessageOptions::new("hello")).await?; +/// # Ok(()) } +/// ``` +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct MessageOptions { + /// The user prompt to send. + pub prompt: String, + /// Optional message delivery mode for this turn. + /// + /// Controls whether the prompt is queued behind in-flight work + /// ([`DeliveryMode::Enqueue`], default) or interrupts the session and + /// runs immediately ([`DeliveryMode::Immediate`]). + pub mode: Option, + /// Optional attachments to include with the message. + pub attachments: Option>, + /// Maximum time to wait for the session to go idle. Honored only by + /// `send_and_wait`. Defaults to 60 seconds when unset. + pub wait_timeout: Option, + /// Custom HTTP headers to include in outbound model requests for this + /// turn. When `None` or empty, no `requestHeaders` field is sent on + /// the wire. + pub request_headers: Option>, + /// W3C Trace Context `traceparent` header for this turn. + /// + /// Per-turn override that takes precedence over + /// [`ClientOptions::on_get_trace_context`](crate::ClientOptions::on_get_trace_context). + /// When `None`, the SDK falls back to the provider (if configured) + /// before omitting the field. + pub traceparent: Option, + /// W3C Trace Context `tracestate` header for this turn. + /// + /// Per-turn override paired with [`traceparent`](Self::traceparent). + pub tracestate: Option, +} + +impl MessageOptions { + /// Build a new `MessageOptions` with just a prompt. + pub fn new(prompt: impl Into) -> Self { + Self { + prompt: prompt.into(), + mode: None, + attachments: None, + wait_timeout: None, + request_headers: None, + traceparent: None, + tracestate: None, + } + } + + /// Set the message delivery mode for this turn. + /// + /// Pass [`DeliveryMode::Immediate`] to interrupt the session and run + /// the prompt now; the default ([`DeliveryMode::Enqueue`]) queues the + /// prompt behind in-flight work. + pub fn with_mode(mut self, mode: DeliveryMode) -> Self { + self.mode = Some(mode); + self + } + + /// Attach files / selections / blobs to the message. + pub fn with_attachments(mut self, attachments: Vec) -> Self { + self.attachments = Some(attachments); + self + } + + /// Override the default 60-second wait timeout for `send_and_wait`. + pub fn with_wait_timeout(mut self, timeout: Duration) -> Self { + self.wait_timeout = Some(timeout); + self + } + + /// Set custom HTTP headers for outbound model requests for this turn. + pub fn with_request_headers(mut self, headers: HashMap) -> Self { + self.request_headers = Some(headers); + self + } + + /// Set both `traceparent` and `tracestate` from a [`TraceContext`]. + /// Either field may remain `None` if the [`TraceContext`] has no value + /// for it. Use [`with_traceparent`](Self::with_traceparent) or + /// [`with_tracestate`](Self::with_tracestate) to set them individually. + pub fn with_trace_context(mut self, ctx: TraceContext) -> Self { + self.traceparent = ctx.traceparent; + self.tracestate = ctx.tracestate; + self + } + + /// Set the W3C `traceparent` header for this turn. + pub fn with_traceparent(mut self, traceparent: impl Into) -> Self { + self.traceparent = Some(traceparent.into()); + self + } + + /// Set the W3C `tracestate` header for this turn. + pub fn with_tracestate(mut self, tracestate: impl Into) -> Self { + self.tracestate = Some(tracestate.into()); + self + } +} + +impl From<&str> for MessageOptions { + fn from(prompt: &str) -> Self { + Self::new(prompt) + } +} + +impl From for MessageOptions { + fn from(prompt: String) -> Self { + Self::new(prompt) + } +} + +impl From<&String> for MessageOptions { + fn from(prompt: &String) -> Self { + Self::new(prompt.clone()) + } +} + +/// Response from [`Client::get_status`](crate::Client::get_status). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct GetStatusResponse { + /// Package version (e.g. `"1.0.0"`). + pub version: String, + /// Protocol version for SDK compatibility. + pub protocol_version: u32, +} + +/// Response from [`Client::get_auth_status`](crate::Client::get_auth_status). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct GetAuthStatusResponse { + /// Whether the user is authenticated. + pub is_authenticated: bool, + /// Authentication type (e.g. `"user"`, `"env"`, `"gh-cli"`, `"hmac"`, + /// `"api-key"`, `"token"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_type: Option, + /// GitHub host URL. + #[serde(skip_serializing_if = "Option::is_none")] + pub host: Option, + /// User login name. + #[serde(skip_serializing_if = "Option::is_none")] + pub login: Option, + /// Human-readable status message. + #[serde(skip_serializing_if = "Option::is_none")] + pub status_message: Option, +} + +/// Wrapper for session event notifications received from the CLI. +/// +/// The CLI sends these as JSON-RPC notifications on the `session.event` method. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionEventNotification { + /// The session this event belongs to. + pub session_id: SessionId, + /// The event payload. + pub event: SessionEvent, +} + +/// A single event in a session's timeline. +/// +/// Events form a linked chain via `parent_id`. The `event_type` string +/// identifies the kind (e.g. `"assistant.message_delta"`, `"session.idle"`, +/// `"tool.execution_start"`). Event-specific payload is in `data` as +/// untyped JSON. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionEvent { + /// Unique event ID (UUID v4). + pub id: String, + /// ISO 8601 timestamp. + pub timestamp: String, + /// ID of the preceding event in the chain. + pub parent_id: Option, + /// Transient events that are not persisted to disk. + #[serde(skip_serializing_if = "Option::is_none")] + pub ephemeral: Option, + /// Sub-agent instance identifier. Absent for events emitted by the + /// root/main agent and for session-level events. + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_id: Option, + /// Debug timestamp: when the CLI received this event (ms since epoch). + #[serde(skip_serializing_if = "Option::is_none")] + pub debug_cli_received_at_ms: Option, + /// Debug timestamp: when the event was forwarded over WebSocket. + #[serde(skip_serializing_if = "Option::is_none")] + pub debug_ws_forwarded_at_ms: Option, + /// Event type string (e.g. `"assistant.message"`, `"session.idle"`). + #[serde(rename = "type")] + pub event_type: String, + /// Event-specific data. Structure depends on `event_type`. + pub data: Value, +} + +impl SessionEvent { + /// Parse the string `event_type` into a typed [`SessionEventType`](crate::generated::SessionEventType) enum. + /// + /// Returns `SessionEventType::Unknown` for unrecognized event types, + /// ensuring forward compatibility with newer CLI versions. + pub fn parsed_type(&self) -> crate::generated::SessionEventType { + use serde::de::IntoDeserializer; + let deserializer: serde::de::value::StrDeserializer<'_, serde::de::value::Error> = + self.event_type.as_str().into_deserializer(); + crate::generated::SessionEventType::deserialize(deserializer) + .unwrap_or(crate::generated::SessionEventType::Unknown) + } + + /// Deserialize the event `data` field into a typed struct. + /// + /// Returns `None` if deserialization fails (e.g. unknown event type + /// or schema mismatch). Prefer typed data accessors for specific + /// event types where you need strongly-typed field access. + pub fn typed_data(&self) -> Option { + serde_json::from_value(self.data.clone()).ok() + } + + /// `model_call` errors are transient โ€” the CLI agent loop continues + /// after them and may succeed on the next turn. These should not be + /// treated as session-ending errors. + pub fn is_transient_error(&self) -> bool { + self.event_type == "session.error" + && self.data.get("errorType").and_then(|v| v.as_str()) == Some("model_call") + } +} + +/// A request from the CLI to invoke a client-defined tool. +/// +/// Received as a JSON-RPC request on the `tool.call` method. The client +/// must respond with a [`ToolResultResponse`]. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct ToolInvocation { + /// Session that owns this tool call. + pub session_id: SessionId, + /// Unique ID for this tool call, used to correlate the response. + pub tool_call_id: String, + /// Name of the tool being invoked. + pub tool_name: String, + /// Tool arguments as JSON. + pub arguments: Value, + /// W3C Trace Context `traceparent` header propagated from the CLI's + /// `execute_tool` span. Pass through to OpenTelemetry-aware code so + /// child spans created inside the handler are parented to the CLI + /// span. `None` when the CLI has no trace context for this call. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub traceparent: Option, + /// W3C Trace Context `tracestate` paired with + /// [`traceparent`](Self::traceparent). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tracestate: Option, +} + +impl ToolInvocation { + /// Deserialize this invocation's [`arguments`](Self::arguments) into a + /// strongly-typed parameter struct. + /// + /// Idiomatic way to extract typed parameters when implementing + /// [`ToolHandler`](crate::tool::ToolHandler) directly. Equivalent to + /// `serde_json::from_value(invocation.arguments.clone())` with the SDK's + /// error type. + /// + /// # Example + /// + /// ```rust,no_run + /// # use github_copilot_sdk::{Error, types::ToolInvocation, ToolResult}; + /// # use serde::Deserialize; + /// # #[derive(Deserialize)] struct MyParams { city: String } + /// # async fn example(inv: ToolInvocation) -> Result { + /// let params: MyParams = inv.params()?; + /// // โ€ฆuse `inv.session_id` / `inv.tool_call_id` alongside `params`โ€ฆ + /// # let _ = params; Ok(ToolResult::Text(String::new())) + /// # } + /// ``` + pub fn params(&self) -> Result { + serde_json::from_value(self.arguments.clone()).map_err(crate::Error::from) + } + + /// Returns the propagated [`TraceContext`] for this invocation, or + /// [`TraceContext::default()`] when the CLI sent no headers. + pub fn trace_context(&self) -> TraceContext { + TraceContext { + traceparent: self.traceparent.clone(), + tracestate: self.tracestate.clone(), + } + } +} + +/// Expanded tool result with metadata for the LLM and session log. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolResultExpanded { + /// Result text sent back to the LLM. + pub text_result_for_llm: String, + /// `"success"` or `"failure"`. + pub result_type: String, + /// Optional log message for the session timeline. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_log: Option, + /// Error message, if the tool failed. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// Result of a tool invocation โ€” either a plain text string or an expanded result. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +#[non_exhaustive] +pub enum ToolResult { + /// Simple text result passed directly to the LLM. + Text(String), + /// Structured result with metadata. + Expanded(ToolResultExpanded), +} + +/// JSON-RPC response wrapper for a tool result, sent back to the CLI. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolResultResponse { + /// The tool result payload. + pub result: ToolResult, +} + +/// Metadata for a persisted session, returned by `session.list`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMetadata { + /// The session's unique identifier. + pub session_id: SessionId, + /// ISO 8601 timestamp when the session was created. + pub start_time: String, + /// ISO 8601 timestamp of the last modification. + pub modified_time: String, + /// Agent-generated session summary. + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, + /// Whether the session is running remotely. + pub is_remote: bool, +} + +/// Response from `session.list`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListSessionsResponse { + /// The list of session metadata entries. + pub sessions: Vec, +} + +/// Filter options for [`Client::list_sessions`](crate::Client::list_sessions). +/// +/// All fields are optional; unset fields don't constrain the result. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionListFilter { + /// Filter by exact `cwd` match. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cwd: Option, + /// Filter by git root path. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub git_root: Option, + /// Filter by repository in `owner/repo` form. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub repository: Option, + /// Filter by git branch name. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub branch: Option, +} + +/// Response from `session.getMetadata`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetSessionMetadataResponse { + /// The session metadata, or `None` if the session was not found. + #[serde(skip_serializing_if = "Option::is_none")] + pub session: Option, +} + +/// Response from `session.getLastId`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetLastSessionIdResponse { + /// The most recently updated session ID, or `None` if no sessions exist. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, +} + +/// Response from `session.getForeground`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetForegroundSessionResponse { + /// The current foreground session ID, or `None` if no foreground session. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, +} + +/// Response from `session.getMessages`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetMessagesResponse { + /// Timeline events for the session. + pub events: Vec, +} + +/// Result of an elicitation (interactive UI form) request. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationResult { + /// User's action: `"accept"`, `"decline"`, or `"cancel"`. + pub action: String, + /// Form data submitted by the user (present when action is `"accept"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +/// Elicitation display mode. +/// +/// New modes may be added by the CLI in future protocol versions; the +/// `Unknown` variant keeps deserialization from failing on unrecognised +/// values so the SDK can still surface the request to callers. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub enum ElicitationMode { + /// Structured form input rendered by the host. + Form, + /// Browser redirect to a URL. + Url, + /// A mode not yet known to this SDK version. + #[serde(other)] + Unknown, +} + +/// An incoming elicitation request from the CLI (provider side). +/// +/// Received via `elicitation.requested` session event when the session was +/// created with `request_elicitation: true`. The provider should render a +/// form or dialog and return an [`ElicitationResult`]. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationRequest { + /// Message describing what information is needed from the user. + pub message: String, + /// JSON Schema describing the form fields to present. + #[serde(skip_serializing_if = "Option::is_none")] + pub requested_schema: Option, + /// Elicitation display mode. + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + /// The source that initiated the request (e.g. MCP server name). + #[serde(skip_serializing_if = "Option::is_none")] + pub elicitation_source: Option, + /// URL to open in the user's browser (url mode only). + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, +} + +/// Session-level capabilities reported by the CLI after session creation. +/// +/// Capabilities indicate which features the CLI host supports for this session. +/// Updated at runtime via `capabilities.changed` events. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCapabilities { + /// UI capabilities (elicitation support, etc.). + #[serde(skip_serializing_if = "Option::is_none")] + pub ui: Option, +} + +/// UI-specific capabilities for a session. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UiCapabilities { + /// Whether the host supports interactive elicitation dialogs. + #[serde(skip_serializing_if = "Option::is_none")] + pub elicitation: Option, +} + +/// Options for the [`SessionUi::input`](crate::session::SessionUi::input) convenience method. +#[derive(Debug, Clone, Default)] +pub struct InputOptions<'a> { + /// Title label for the input field. + pub title: Option<&'a str>, + /// Descriptive text shown below the field. + pub description: Option<&'a str>, + /// Minimum character length. + pub min_length: Option, + /// Maximum character length. + pub max_length: Option, + /// Semantic format hint. + pub format: Option, + /// Default value pre-populated in the field. + pub default: Option<&'a str>, +} + +/// Semantic format hints for text input fields. +#[derive(Debug, Clone, Copy)] +#[non_exhaustive] +pub enum InputFormat { + /// Email address. + Email, + /// URI. + Uri, + /// Calendar date. + Date, + /// Date and time. + DateTime, +} + +impl InputFormat { + /// Returns the JSON Schema format string for this variant. + pub fn as_str(&self) -> &'static str { + match self { + Self::Email => "email", + Self::Uri => "uri", + Self::Date => "date", + Self::DateTime => "date-time", + } + } +} + +/// Re-exports of generated protocol types that are part of the SDK's +/// public API surface. The canonical definitions live in +/// [`crate::generated::api_types`]; they live here so the crate-root +/// `pub use types::*` surfaces them alongside hand-written SDK types. +pub use crate::generated::api_types::{ + Model, ModelBilling, ModelCapabilities, ModelCapabilitiesLimits, ModelCapabilitiesLimitsVision, + ModelCapabilitiesSupports, ModelList, ModelPolicy, +}; + +/// Permission categories the CLI may request approval for. +/// +/// Wire values are the lower-kebab strings the CLI sends as the `kind` +/// discriminator on a permission request. Marked `#[non_exhaustive]` +/// because the CLI may add new kinds; matches must include a `_` arm. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +#[non_exhaustive] +pub enum PermissionRequestKind { + /// Run a shell command. + Shell, + /// Write to a file. + Write, + /// Read a file. + Read, + /// Open a URL. + Url, + /// Invoke an MCP server tool. + Mcp, + /// Invoke a client-defined custom tool. + CustomTool, + /// Update agent memory. + Memory, + /// Run a hook callback. + Hook, + /// Unrecognized kind. The original wire string is available in + /// [`PermissionRequestData::extra`] under the `kind` key. + #[serde(other)] + Unknown, +} + +/// Data sent by the CLI for permission-related events. +/// +/// Used for both the `permission.request` RPC call (which expects a response) +/// and `permission.requested` notifications (fire-and-forget). Contains the +/// full params object. Note that `requestId` is also available as a separate +/// field on `HandlerEvent::PermissionRequest`. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestData { + /// The permission category being requested. `None` means the CLI did + /// not include a `kind` field. Use this to branch on common cases + /// (shell, write, etc.) without parsing [`extra`](Self::extra). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub kind: Option, + /// The originating tool-call ID, if this permission request is tied + /// to a specific tool invocation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// The full permission request params from the CLI. The shape varies by + /// permission type and CLI version, so we preserve it as `Value`. + #[serde(flatten)] + pub extra: Value, +} + +/// Data sent by the CLI with an `exitPlanMode.request` RPC call. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExitPlanModeData { + /// Markdown summary of the plan presented to the user. + #[serde(default)] + pub summary: String, + /// Full plan content (e.g. the plan.md body), if available. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub plan_content: Option, + /// Allowed exit actions (e.g. "interactive", "autopilot", "autopilot_fleet"). + #[serde(default)] + pub actions: Vec, + /// Which action the CLI recommends, defaults to "autopilot". + #[serde(default = "default_recommended_action")] + pub recommended_action: String, +} + +fn default_recommended_action() -> String { + "autopilot".to_string() +} + +impl Default for ExitPlanModeData { + fn default() -> Self { + Self { + summary: String::new(), + plan_content: None, + actions: Vec::new(), + recommended_action: default_recommended_action(), + } + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use serde_json::json; + + use super::{ + Attachment, AttachmentLineRange, AttachmentSelectionPosition, AttachmentSelectionRange, + ConnectionState, CustomAgentConfig, DeliveryMode, GitHubReferenceType, + InfiniteSessionConfig, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionEvent, + SessionId, SystemMessageConfig, Tool, ensure_attachment_display_names, + }; + use crate::generated::session_events::TypedSessionEvent; + + #[test] + fn tool_builder_composes() { + let tool = Tool::new("greet") + .with_description("Say hello") + .with_namespaced_name("hello/greet") + .with_instructions("Pass the user's name") + .with_parameters(json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + })) + .with_overrides_built_in_tool(true) + .with_skip_permission(true); + assert_eq!(tool.name, "greet"); + assert_eq!(tool.description, "Say hello"); + assert_eq!(tool.namespaced_name.as_deref(), Some("hello/greet")); + assert_eq!(tool.instructions.as_deref(), Some("Pass the user's name")); + assert_eq!(tool.parameters.get("type").unwrap(), &json!("object")); + assert!(tool.overrides_built_in_tool); + assert!(tool.skip_permission); + } + + #[test] + fn tool_with_parameters_handles_non_object_value() { + let tool = Tool::new("noop").with_parameters(json!(null)); + assert!(tool.parameters.is_empty()); + } + + #[test] + fn session_config_default_enables_permission_flow_flags() { + let cfg = SessionConfig::default(); + assert_eq!(cfg.request_user_input, Some(true)); + assert_eq!(cfg.request_permission, Some(true)); + assert_eq!(cfg.request_exit_plan_mode, Some(true)); + assert_eq!(cfg.request_auto_mode_switch, Some(true)); + assert_eq!(cfg.request_elicitation, Some(true)); + } + + #[test] + fn resume_session_config_new_enables_permission_flow_flags() { + let cfg = ResumeSessionConfig::new(SessionId::from("test-id")); + assert_eq!(cfg.request_user_input, Some(true)); + assert_eq!(cfg.request_permission, Some(true)); + assert_eq!(cfg.request_exit_plan_mode, Some(true)); + assert_eq!(cfg.request_auto_mode_switch, Some(true)); + assert_eq!(cfg.request_elicitation, Some(true)); + } + + #[test] + fn session_config_builder_composes() { + use std::collections::HashMap; + + let cfg = SessionConfig::default() + .with_session_id(SessionId::from("sess-1")) + .with_model("claude-sonnet-4") + .with_client_name("test-app") + .with_reasoning_effort("medium") + .with_streaming(true) + .with_tools([Tool::new("greet")]) + .with_available_tools(["bash", "view"]) + .with_excluded_tools(["dangerous"]) + .with_mcp_servers(HashMap::new()) + .with_env_value_mode("direct") + .with_enable_config_discovery(true) + .with_request_user_input(false) + .with_skill_directories([PathBuf::from("/tmp/skills")]) + .with_disabled_skills(["broken-skill"]) + .with_disabled_mcp_servers(["broken-server"]) + .with_agent("researcher") + .with_config_dir(PathBuf::from("/tmp/config")) + .with_working_directory(PathBuf::from("/tmp/work")) + .with_github_token("ghp_test") + .with_include_sub_agent_streaming_events(false); + + assert_eq!(cfg.session_id.as_ref().map(|s| s.as_str()), Some("sess-1")); + assert_eq!(cfg.model.as_deref(), Some("claude-sonnet-4")); + assert_eq!(cfg.client_name.as_deref(), Some("test-app")); + assert_eq!(cfg.reasoning_effort.as_deref(), Some("medium")); + assert_eq!(cfg.streaming, Some(true)); + assert_eq!(cfg.tools.as_ref().map(|t| t.len()), Some(1)); + assert_eq!( + cfg.available_tools.as_deref(), + Some(&["bash".to_string(), "view".to_string()][..]) + ); + assert_eq!( + cfg.excluded_tools.as_deref(), + Some(&["dangerous".to_string()][..]) + ); + assert!(cfg.mcp_servers.is_some()); + assert_eq!(cfg.env_value_mode.as_deref(), Some("direct")); + assert_eq!(cfg.enable_config_discovery, Some(true)); + assert_eq!(cfg.request_user_input, Some(false)); // overrode default + assert_eq!(cfg.request_permission, Some(true)); // default preserved + assert_eq!( + cfg.skill_directories.as_deref(), + Some(&[PathBuf::from("/tmp/skills")][..]) + ); + assert_eq!( + cfg.disabled_skills.as_deref(), + Some(&["broken-skill".to_string()][..]) + ); + assert_eq!(cfg.agent.as_deref(), Some("researcher")); + assert_eq!(cfg.config_dir, Some(PathBuf::from("/tmp/config"))); + assert_eq!(cfg.working_directory, Some(PathBuf::from("/tmp/work"))); + assert_eq!(cfg.github_token.as_deref(), Some("ghp_test")); + assert_eq!(cfg.include_sub_agent_streaming_events, Some(false)); + } + + #[test] + fn resume_session_config_builder_composes() { + use std::collections::HashMap; + + let cfg = ResumeSessionConfig::new(SessionId::from("sess-2")) + .with_client_name("test-app") + .with_streaming(true) + .with_tools([Tool::new("greet")]) + .with_excluded_tools(["dangerous"]) + .with_mcp_servers(HashMap::new()) + .with_env_value_mode("indirect") + .with_enable_config_discovery(true) + .with_request_user_input(false) + .with_skill_directories([PathBuf::from("/tmp/skills")]) + .with_agent("researcher") + .with_config_dir(PathBuf::from("/tmp/config")) + .with_working_directory(PathBuf::from("/tmp/work")) + .with_github_token("ghp_test") + .with_include_sub_agent_streaming_events(true) + .with_disable_resume(true) + .with_continue_pending_work(true); + + assert_eq!(cfg.session_id.as_str(), "sess-2"); + assert_eq!(cfg.client_name.as_deref(), Some("test-app")); + assert_eq!(cfg.streaming, Some(true)); + assert_eq!(cfg.tools.as_ref().map(|t| t.len()), Some(1)); + assert_eq!( + cfg.excluded_tools.as_deref(), + Some(&["dangerous".to_string()][..]) + ); + assert!(cfg.mcp_servers.is_some()); + assert_eq!(cfg.env_value_mode.as_deref(), Some("indirect")); + assert_eq!(cfg.enable_config_discovery, Some(true)); + assert_eq!(cfg.request_user_input, Some(false)); // overrode default + assert_eq!(cfg.request_permission, Some(true)); // default preserved + assert_eq!( + cfg.skill_directories.as_deref(), + Some(&[PathBuf::from("/tmp/skills")][..]) + ); + assert_eq!(cfg.agent.as_deref(), Some("researcher")); + assert_eq!(cfg.config_dir, Some(PathBuf::from("/tmp/config"))); + assert_eq!(cfg.working_directory, Some(PathBuf::from("/tmp/work"))); + assert_eq!(cfg.github_token.as_deref(), Some("ghp_test")); + assert_eq!(cfg.include_sub_agent_streaming_events, Some(true)); + assert_eq!(cfg.disable_resume, Some(true)); + assert_eq!(cfg.continue_pending_work, Some(true)); + } + + /// `continue_pending_work` must serialize to wire as `continuePendingWork` + /// โ€” the runtime keys off this exact field name to opt into the + /// pending-work-handoff pattern. + #[test] + fn resume_session_config_serializes_continue_pending_work_to_camel_case() { + let cfg = + ResumeSessionConfig::new(SessionId::from("sess-1")).with_continue_pending_work(true); + let wire = serde_json::to_value(&cfg).unwrap(); + assert_eq!(wire["continuePendingWork"], true); + + // Unset case โ€” skip_serializing_if must omit the field. + let cfg = ResumeSessionConfig::new(SessionId::from("sess-2")); + let wire = serde_json::to_value(&cfg).unwrap(); + assert!(wire.get("continuePendingWork").is_none()); + } + + /// `instruction_directories` must serialize to wire as + /// `instructionDirectories` on `SessionConfig`. Cross-SDK parity field + /// (Node/Python pass it through to the CLI verbatim). + #[test] + fn session_config_serializes_instruction_directories_to_camel_case() { + let cfg = + SessionConfig::default().with_instruction_directories([PathBuf::from("/tmp/instr")]); + let wire = serde_json::to_value(&cfg).unwrap(); + assert_eq!( + wire["instructionDirectories"], + serde_json::json!(["/tmp/instr"]) + ); + + // Unset case โ€” skip_serializing_if must omit the field. + let cfg = SessionConfig::default(); + let wire = serde_json::to_value(&cfg).unwrap(); + assert!(wire.get("instructionDirectories").is_none()); + } + + /// Same check on the resume path. Forwarded to the CLI on + /// `session.resume`. + #[test] + fn resume_session_config_serializes_instruction_directories_to_camel_case() { + let cfg = ResumeSessionConfig::new(SessionId::from("sess-1")) + .with_instruction_directories([PathBuf::from("/tmp/instr")]); + let wire = serde_json::to_value(&cfg).unwrap(); + assert_eq!( + wire["instructionDirectories"], + serde_json::json!(["/tmp/instr"]) + ); + + let cfg = ResumeSessionConfig::new(SessionId::from("sess-2")); + let wire = serde_json::to_value(&cfg).unwrap(); + assert!(wire.get("instructionDirectories").is_none()); + } + + #[test] + fn custom_agent_config_builder_composes() { + use std::collections::HashMap; + + let cfg = CustomAgentConfig::new("researcher", "You are a research assistant.") + .with_display_name("Research Assistant") + .with_description("Investigates technical questions.") + .with_tools(["bash", "view"]) + .with_mcp_servers(HashMap::new()) + .with_infer(true) + .with_skills(["rust-coding-skill"]); + + assert_eq!(cfg.name, "researcher"); + assert_eq!(cfg.prompt, "You are a research assistant."); + assert_eq!(cfg.display_name.as_deref(), Some("Research Assistant")); + assert_eq!( + cfg.description.as_deref(), + Some("Investigates technical questions.") + ); + assert_eq!( + cfg.tools.as_deref(), + Some(&["bash".to_string(), "view".to_string()][..]) + ); + assert!(cfg.mcp_servers.is_some()); + assert_eq!(cfg.infer, Some(true)); + assert_eq!( + cfg.skills.as_deref(), + Some(&["rust-coding-skill".to_string()][..]) + ); + } + + #[test] + fn infinite_session_config_builder_composes() { + let cfg = InfiniteSessionConfig::new() + .with_enabled(true) + .with_background_compaction_threshold(0.75) + .with_buffer_exhaustion_threshold(0.92); + + assert_eq!(cfg.enabled, Some(true)); + assert_eq!(cfg.background_compaction_threshold, Some(0.75)); + assert_eq!(cfg.buffer_exhaustion_threshold, Some(0.92)); + } + + #[test] + fn provider_config_builder_composes() { + use std::collections::HashMap; + + let mut headers = HashMap::new(); + headers.insert("X-Custom".to_string(), "value".to_string()); + + let cfg = ProviderConfig::new("https://api.example.com") + .with_provider_type("openai") + .with_wire_api("completions") + .with_api_key("sk-test") + .with_bearer_token("bearer-test") + .with_headers(headers); + + assert_eq!(cfg.base_url, "https://api.example.com"); + assert_eq!(cfg.provider_type.as_deref(), Some("openai")); + assert_eq!(cfg.wire_api.as_deref(), Some("completions")); + assert_eq!(cfg.api_key.as_deref(), Some("sk-test")); + assert_eq!(cfg.bearer_token.as_deref(), Some("bearer-test")); + assert_eq!( + cfg.headers + .as_ref() + .and_then(|h| h.get("X-Custom")) + .map(String::as_str), + Some("value"), + ); + } + + #[test] + fn system_message_config_builder_composes() { + use std::collections::HashMap; + + let cfg = SystemMessageConfig::new() + .with_mode("replace") + .with_content("Custom system message.") + .with_sections(HashMap::new()); + + assert_eq!(cfg.mode.as_deref(), Some("replace")); + assert_eq!(cfg.content.as_deref(), Some("Custom system message.")); + assert!(cfg.sections.is_some()); + } + + #[test] + fn delivery_mode_serializes_to_kebab_case_strings() { + assert_eq!( + serde_json::to_string(&DeliveryMode::Enqueue).unwrap(), + "\"enqueue\"" + ); + assert_eq!( + serde_json::to_string(&DeliveryMode::Immediate).unwrap(), + "\"immediate\"" + ); + let parsed: DeliveryMode = serde_json::from_str("\"immediate\"").unwrap(); + assert_eq!(parsed, DeliveryMode::Immediate); + } + + #[test] + fn connection_state_error_serializes_to_match_go() { + let json = serde_json::to_string(&ConnectionState::Error).unwrap(); + assert_eq!(json, "\"error\""); + let parsed: ConnectionState = serde_json::from_str("\"error\"").unwrap(); + assert_eq!(parsed, ConnectionState::Error); + } + + /// `agentId` is the sub-agent attribution field added in copilot-sdk + /// commit f8cf846 ("Derive session event envelopes from schema"). + /// Every other SDK (Node, Python, Go, .NET) carries it on the event + /// envelope; Rust must too or sub-agent events lose attribution at + /// the deserialization boundary. Cross-SDK parity test. + #[test] + fn session_event_round_trips_agent_id_on_envelope() { + let wire = json!({ + "id": "evt-1", + "timestamp": "2026-04-30T12:00:00Z", + "parentId": null, + "agentId": "sub-agent-42", + "type": "assistant.message", + "data": { "message": "hi" } + }); + + let event: SessionEvent = serde_json::from_value(wire.clone()).unwrap(); + assert_eq!(event.agent_id.as_deref(), Some("sub-agent-42")); + + // Round-trip preserves the field on the wire. + let roundtripped = serde_json::to_value(&event).unwrap(); + assert_eq!(roundtripped["agentId"], "sub-agent-42"); + + // Absent agentId remains absent (skip_serializing_if). + let main_agent_event: SessionEvent = serde_json::from_value(json!({ + "id": "evt-2", + "timestamp": "2026-04-30T12:00:01Z", + "parentId": null, + "type": "session.idle", + "data": {} + })) + .unwrap(); + assert!(main_agent_event.agent_id.is_none()); + let roundtripped = serde_json::to_value(&main_agent_event).unwrap(); + assert!(roundtripped.get("agentId").is_none()); + } + + /// Same parity for the typed event envelope produced by the codegen. + #[test] + fn typed_session_event_round_trips_agent_id_on_envelope() { + let wire = json!({ + "id": "evt-1", + "timestamp": "2026-04-30T12:00:00Z", + "parentId": null, + "agentId": "sub-agent-42", + "type": "session.idle", + "data": {} + }); + + let event: TypedSessionEvent = serde_json::from_value(wire).unwrap(); + assert_eq!(event.agent_id.as_deref(), Some("sub-agent-42")); + + let roundtripped = serde_json::to_value(&event).unwrap(); + assert_eq!(roundtripped["agentId"], "sub-agent-42"); + } + + #[test] + fn connection_state_other_variants_serialize_as_lowercase() { + assert_eq!( + serde_json::to_string(&ConnectionState::Disconnected).unwrap(), + "\"disconnected\"" + ); + assert_eq!( + serde_json::to_string(&ConnectionState::Connecting).unwrap(), + "\"connecting\"" + ); + assert_eq!( + serde_json::to_string(&ConnectionState::Connected).unwrap(), + "\"connected\"" + ); + } + + #[test] + fn deserializes_runtime_attachment_variants() { + let attachments: Vec = serde_json::from_value(json!([ + { + "type": "file", + "path": "/tmp/file.rs", + "displayName": "file.rs", + "lineRange": { "start": 7, "end": 12 } + }, + { + "type": "directory", + "path": "/tmp/project", + "displayName": "project" + }, + { + "type": "selection", + "filePath": "/tmp/lib.rs", + "displayName": "lib.rs", + "text": "fn main() {}", + "selection": { + "start": { "line": 1, "character": 2 }, + "end": { "line": 3, "character": 4 } + } + }, + { + "type": "blob", + "data": "Zm9v", + "mimeType": "image/png", + "displayName": "image.png" + }, + { + "type": "github_reference", + "number": 42, + "title": "Fix rendering", + "referenceType": "issue", + "state": "open", + "url": "https://github.com/example/repo/issues/42" + } + ])) + .expect("attachments should deserialize"); + + assert_eq!(attachments.len(), 5); + assert!(matches!( + &attachments[0], + Attachment::File { + path, + display_name, + line_range: Some(AttachmentLineRange { start: 7, end: 12 }), + } if path == &PathBuf::from("/tmp/file.rs") && display_name.as_deref() == Some("file.rs") + )); + assert!(matches!( + &attachments[1], + Attachment::Directory { path, display_name } + if path == &PathBuf::from("/tmp/project") && display_name.as_deref() == Some("project") + )); + assert!(matches!( + &attachments[2], + Attachment::Selection { + file_path, + display_name, + selection: + AttachmentSelectionRange { + start: AttachmentSelectionPosition { line: 1, character: 2 }, + end: AttachmentSelectionPosition { line: 3, character: 4 }, + }, + .. + } if file_path == &PathBuf::from("/tmp/lib.rs") && display_name.as_deref() == Some("lib.rs") + )); + assert!(matches!( + &attachments[3], + Attachment::Blob { + data, + mime_type, + display_name, + } if data == "Zm9v" && mime_type == "image/png" && display_name.as_deref() == Some("image.png") + )); + assert!(matches!( + &attachments[4], + Attachment::GitHubReference { + number: 42, + title, + reference_type: GitHubReferenceType::Issue, + state, + url, + } if title == "Fix rendering" + && state == "open" + && url == "https://github.com/example/repo/issues/42" + )); + } + + #[test] + fn ensures_display_names_for_variants_that_support_them() { + let mut attachments = vec![ + Attachment::File { + path: PathBuf::from("/tmp/file.rs"), + display_name: None, + line_range: None, + }, + Attachment::Selection { + file_path: PathBuf::from("/tmp/src/lib.rs"), + display_name: None, + text: "fn main() {}".to_string(), + selection: AttachmentSelectionRange { + start: AttachmentSelectionPosition { + line: 0, + character: 0, + }, + end: AttachmentSelectionPosition { + line: 0, + character: 10, + }, + }, + }, + Attachment::Blob { + data: "Zm9v".to_string(), + mime_type: "image/png".to_string(), + display_name: None, + }, + Attachment::GitHubReference { + number: 7, + title: "Track regressions".to_string(), + reference_type: GitHubReferenceType::Issue, + state: "open".to_string(), + url: "https://example.com/issues/7".to_string(), + }, + ]; + + ensure_attachment_display_names(&mut attachments); + + assert_eq!(attachments[0].display_name(), Some("file.rs")); + assert_eq!(attachments[1].display_name(), Some("lib.rs")); + assert_eq!(attachments[2].display_name(), Some("attachment")); + assert_eq!(attachments[3].display_name(), None); + assert_eq!( + attachments[3].label(), + Some("Track regressions".to_string()) + ); + } +} + +#[cfg(test)] +mod permission_builder_tests { + use std::sync::Arc; + + use crate::handler::{ + ApproveAllHandler, HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, + }; + use crate::types::{ + PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionId, + }; + + fn permission_event() -> HandlerEvent { + HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1"), + request_id: RequestId::new("1"), + data: PermissionRequestData { + extra: serde_json::json!({"tool": "shell"}), + ..Default::default() + }, + } + } + + async fn dispatch(handler: &Arc) -> HandlerResponse { + handler.on_event(permission_event()).await + } + + #[tokio::test] + async fn session_config_approve_all_wraps_existing_handler() { + let cfg = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .approve_all_permissions(); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("expected Approved, got {other:?}"), + } + } + + #[tokio::test] + async fn session_config_approve_all_defaults_to_deny_inner() { + // Without with_handler, the wrap defaults to DenyAllHandler. The + // approve-all wrap intercepts permission events, so they're still + // approved -- the inner handler is consulted only for other events. + let cfg = SessionConfig::default().approve_all_permissions(); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("expected Approved, got {other:?}"), + } + } + + #[tokio::test] + async fn session_config_deny_all_denies() { + let cfg = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .deny_all_permissions(); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("expected Denied, got {other:?}"), + } + } + + #[tokio::test] + async fn session_config_approve_permissions_if_consults_predicate() { + let cfg = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .approve_permissions_if(|data| { + data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + }); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("expected Denied for shell, got {other:?}"), + } + } + + #[tokio::test] + async fn resume_session_config_approve_all_wraps_existing_handler() { + let cfg = ResumeSessionConfig::new(SessionId::from("s1")) + .with_handler(Arc::new(ApproveAllHandler)) + .approve_all_permissions(); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("expected Approved, got {other:?}"), + } + } +} diff --git a/rust/tests/integration_test.rs b/rust/tests/integration_test.rs new file mode 100644 index 000000000..90e2e1c7a --- /dev/null +++ b/rust/tests/integration_test.rs @@ -0,0 +1,107 @@ +#![allow(clippy::unwrap_used)] + +use std::time::Instant; + +use github_copilot_sdk::resolve::copilot_binary_with_source; +use github_copilot_sdk::{Client, ClientOptions, SDK_PROTOCOL_VERSION}; + +fn default_options() -> ClientOptions { + let mut opts = ClientOptions::default(); + opts.cwd = std::env::current_dir().expect("cwd"); + opts +} + +#[tokio::test] +#[ignore] // requires `copilot` CLI on PATH โ€” run with `cargo test -- --ignored` +async fn start_ping_stop() { + let client = Client::start(default_options()) + .await + .expect("failed to start copilot CLI"); + + // start() calls verify_protocol_version(), so this should be set + let version = client + .protocol_version() + .expect("protocol version not negotiated"); + assert!((2..=SDK_PROTOCOL_VERSION).contains(&version)); + + client.ping(None).await.expect("ping failed"); + client.stop().await.expect("stop failed"); +} + +#[tokio::test] +#[ignore] // requires `copilot` CLI on PATH โ€” run with `cargo test -- --ignored` +async fn force_stop_kills_real_child() { + let client = Client::start(default_options()) + .await + .expect("failed to start copilot CLI"); + + let pid = client.pid().expect("expected a CLI child pid"); + assert!(pid > 0); + + // force_stop is synchronous and must not panic. After it returns, + // pid() should report None because we've taken the child out of the + // mutex. + client.force_stop(); + assert!(client.pid().is_none()); + + // Calling it again should be a no-op rather than panicking. + client.force_stop(); +} + +/// Measures the latency of individual CLI operations that contribute to +/// session creation time. Run with: +/// +/// cargo test -p github-copilot-sdk --test integration_test -- --ignored --nocapture +#[tokio::test] +#[ignore] +async fn cli_operation_latency() { + // Cold start: spawn CLI process + verify protocol version + let t0 = Instant::now(); + let client = Client::start(default_options()) + .await + .expect("cold start failed"); + let cold_start = t0.elapsed(); + + // Warm ping: RPC round-trip on an already-running process + let t1 = Instant::now(); + client.ping(None).await.expect("warm ping failed"); + let warm_ping = t1.elapsed(); + + // list_models: RPC that fetches available models from the CLI + let t2 = Instant::now(); + let models = client.list_models().await.expect("list_models failed"); + let list_models = t2.elapsed(); + + // Second list_models: does the CLI cache internally? + let t2b = Instant::now(); + let _ = client.list_models().await.expect("list_models 2 failed"); + let list_models_2 = t2b.elapsed(); + + client.stop().await.expect("stop first client failed"); + + // Second cold start: measures process spawn cost when the binary is + // already resolved and cached (no extraction overhead) + let t3 = Instant::now(); + let client2 = Client::start(default_options()) + .await + .expect("second cold start failed"); + let second_start = t3.elapsed(); + + client2.stop().await.expect("stop second client failed"); + + let (cli_path, source) = copilot_binary_with_source().expect("copilot binary not found"); + + eprintln!(); + eprintln!("=== CLI operation latency ==="); + eprintln!(" binary: {} ({:?})", cli_path.display(), source); + eprintln!(" cold Client::start: {:>8.1?}", cold_start); + eprintln!(" warm ping(): {:>8.1?}", warm_ping); + eprintln!( + " list_models() ({:>2}): {:>8.1?}", + models.len(), + list_models + ); + eprintln!(" list_models() again: {:>8.1?}", list_models_2); + eprintln!(" second Client::start: {:>8.1?}", second_start); + eprintln!(); +} diff --git a/rust/tests/jsonrpc_test.rs b/rust/tests/jsonrpc_test.rs new file mode 100644 index 000000000..7f7d43213 --- /dev/null +++ b/rust/tests/jsonrpc_test.rs @@ -0,0 +1,412 @@ +#![cfg(feature = "test-support")] +#![allow(clippy::unwrap_used)] + +use github_copilot_sdk::test_support::{JsonRpcClient, JsonRpcNotification, JsonRpcRequest}; +use tokio::io::{AsyncWrite, AsyncWriteExt, duplex}; +use tokio::sync::{broadcast, mpsc}; + +/// Write a Content-Length framed JSON-RPC message to a writer. +async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) { + let header = format!("Content-Length: {}\r\n\r\n", body.len()); + writer.write_all(header.as_bytes()).await.unwrap(); + writer.write_all(body).await.unwrap(); + writer.flush().await.unwrap(); +} + +#[tokio::test] +async fn request_response_round_trip() { + // duplex: client_write โ†’ server_read, server_write โ†’ client_read + let (client_write, mut server_read) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (_request_tx, _request_rx) = mpsc::unbounded_channel(); + let request_tx = _request_tx; + + let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + + // Spawn a task that reads the request from the server side and sends a response. + let server_handle = tokio::spawn(async move { + let mut buf = Vec::new(); + // Read the Content-Length header + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + buf.resize(length, 0); + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut buf) + .await + .unwrap(); + + let request: JsonRpcRequest = serde_json::from_slice(&buf).unwrap(); + assert_eq!(request.method, "test.echo"); + assert_eq!(request.jsonrpc, "2.0"); + + // Send response + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": request.id, + "result": { "echoed": true } + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + request.id + }); + + let response = client + .send_request("test.echo", Some(serde_json::json!({"hello": "world"}))) + .await + .unwrap(); + + let request_id = server_handle.await.unwrap(); + assert_eq!(response.id, request_id); + assert!(!response.is_error()); + assert_eq!(response.result.unwrap()["echoed"], serde_json::json!(true)); +} + +#[tokio::test] +async fn notification_broadcasting() { + let (_client_write, _discard) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, mut notification_rx) = broadcast::channel(16); + let (request_tx, _request_rx) = mpsc::unbounded_channel(); + + let _client = JsonRpcClient::new(_client_write, client_read, notification_tx, request_tx); + + // Server sends a notification (no id field). + let notification = serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.event", + "params": { "session_id": "s1", "event": "started" } + }); + write_framed( + &mut server_write, + &serde_json::to_vec(¬ification).unwrap(), + ) + .await; + + let received: JsonRpcNotification = + tokio::time::timeout(std::time::Duration::from_secs(2), notification_rx.recv()) + .await + .expect("timed out waiting for notification") + .unwrap(); + + assert_eq!(received.method, "session.event"); + assert_eq!(received.params.unwrap()["session_id"], "s1"); +} + +#[tokio::test] +async fn server_request_forwarding() { + let (_client_write, _discard) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + + let _client = JsonRpcClient::new(_client_write, client_read, notification_tx, request_tx); + + // Server sends a request (has both id and method). + let request = serde_json::json!({ + "jsonrpc": "2.0", + "id": 42, + "method": "permission.request", + "params": { "kind": "shell" } + }); + write_framed(&mut server_write, &serde_json::to_vec(&request).unwrap()).await; + + let received: JsonRpcRequest = + tokio::time::timeout(std::time::Duration::from_secs(2), request_rx.recv()) + .await + .expect("timed out waiting for request") + .unwrap(); + + assert_eq!(received.method, "permission.request"); + assert_eq!(received.id, 42); +} + +#[tokio::test] +async fn error_response_round_trip() { + let (client_write, mut server_read) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, _) = mpsc::unbounded_channel(); + + let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + + let server_handle = tokio::spawn(async move { + // Read request + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + let mut buf = vec![0u8; length]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut buf) + .await + .unwrap(); + let request: JsonRpcRequest = serde_json::from_slice(&buf).unwrap(); + + // Send error response + let error_response = serde_json::json!({ + "jsonrpc": "2.0", + "id": request.id, + "error": { "code": -32600, "message": "Invalid Request" } + }); + write_framed( + &mut server_write, + &serde_json::to_vec(&error_response).unwrap(), + ) + .await; + }); + + let response = client.send_request("bad.method", None).await.unwrap(); + server_handle.await.unwrap(); + + assert!(response.is_error()); + let error = response.error.unwrap(); + assert_eq!(error.code, -32600); + assert_eq!(error.message, "Invalid Request"); +} + +#[tokio::test] +async fn read_loop_terminates_on_eof() { + let (client_write, _discard) = duplex(4096); + let (server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, _) = mpsc::unbounded_channel(); + + let _client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + + // Drop the server side โ€” the read loop should see EOF and stop. + drop(server_write); + + // Give the read loop time to notice EOF. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; +} + +/// Cancel-safety regression: dropping a `write()` future after the actor has +/// committed to writing must NOT produce a partial frame on the wire. +/// +/// Strategy: spawn a reader task that waits before draining the wire, so +/// the actor's `write_all` blocks waiting for room. Race the caller's +/// future against a sleep; when the sleep wins, the caller's future is +/// dropped while suspended on `ack_rx.await`. Release the reader and +/// verify both frames land on the wire intact. +/// +/// Closes RFD-400 finding #1: `JsonRpcClient::write` was holding a Tokio +/// mutex across `write_all` + `flush`, so caller cancellation mid-frame +/// could desync the transport. The writer-actor refactor moves the I/O +/// onto a dedicated task that owns the writer; caller cancellation drops +/// the ack receiver but does not interrupt the in-flight write. +#[tokio::test] +async fn write_actor_completes_on_caller_cancel() { + use std::sync::Arc; + + use tokio::sync::Notify; + + let (client_write, mut server_read) = duplex(8); + let (_server_write, client_read) = duplex(8); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, _) = mpsc::unbounded_channel(); + let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + + // Reader task that waits for `start` before draining; this gives us + // a window where the actor's write_all is suspended waiting for room. + let start = Arc::new(Notify::new()); + let start_clone = start.clone(); + let reader_task = tokio::spawn(async move { + start_clone.notified().await; + let mut frames = Vec::new(); + for _ in 0..2 { + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + let mut body = vec![0u8; length]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut body) + .await + .unwrap(); + let req: JsonRpcRequest = serde_json::from_slice(&body).unwrap(); + frames.push(req); + } + frames + }); + + let frame_a = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: 100, + method: "first.write".to_string(), + params: None, + }; + let frame_b = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: 101, + method: "second.write".to_string(), + params: None, + }; + + // First write: race the future against a sleep. With the reader + // gated, the actor's write_all blocks at the 8-byte buffer boundary, + // so the future stays suspended on `ack_rx.await`. The sleep wins + // after 50ms, dropping the caller's future. The actor still owns the + // write and must complete it once the reader drains. + tokio::select! { + _ = client.write(&frame_a) => panic!("write completed too quickly to test cancellation"), + _ = tokio::time::sleep(std::time::Duration::from_millis(50)) => {} + } + + // Enqueue the second write before releasing the reader. Both frames + // are now in the actor's queue; the actor will drain them in order + // once the reader starts pulling bytes. + let second_handle = tokio::spawn({ + let frame_b = frame_b.clone(); + let client_arc = std::sync::Arc::new(client); + let client_clone = client_arc.clone(); + async move { client_clone.write(&frame_b).await } + }); + + // Release the reader so both frames can flow through the actor. + start.notify_one(); + + let frames = reader_task.await.unwrap(); + second_handle.await.unwrap().unwrap(); + + assert_eq!(frames.len(), 2); + assert_eq!(frames[0].method, "first.write"); + assert_eq!(frames[0].id, 100); + assert_eq!(frames[1].method, "second.write"); + assert_eq!(frames[1].id, 101); +} + +/// Cancel-safety regression: cancelling a `send_request` future before the +/// response arrives must NOT leak the pending-requests entry. The RAII +/// `PendingGuard` removes the entry on drop. +/// +/// Strategy: spawn `send_request`, drop the JoinHandle immediately so the +/// future is cancelled. The CLI eventually sends a response for the +/// cancelled request id; the read loop logs a warning and discards it +/// (the pending entry was already removed by the guard). The next +/// `send_request` should work normally and not collide with the orphan. +/// +/// Closes RFD-400 finding #4. +#[tokio::test] +async fn send_request_cancellation_does_not_leak_pending() { + let (client_write, mut server_read) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, _) = mpsc::unbounded_channel(); + let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + let client = std::sync::Arc::new(client); + + // First request: cancel before the server replies. + let cancelled = tokio::spawn({ + let client = client.clone(); + async move { + // Will await the response oneshot; the JoinHandle abort + // below cancels this future. + let _ = client.send_request("first", None).await; + } + }); + + // Read the first request off the wire so we know it was sent. + async fn read_one_method(reader: &mut tokio::io::DuplexStream) -> (u64, String) { + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + tokio::io::AsyncReadExt::read_exact(reader, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + let mut body = vec![0u8; length]; + tokio::io::AsyncReadExt::read_exact(reader, &mut body) + .await + .unwrap(); + let req: JsonRpcRequest = serde_json::from_slice(&body).unwrap(); + (req.id, req.method) + } + + let (first_id, first_method) = read_one_method(&mut server_read).await; + assert_eq!(first_method, "first"); + + // Now cancel the in-flight request. + cancelled.abort(); + let _ = cancelled.await; + + // Send a (late) response for the cancelled id. The read loop should + // log a warning and not blow up. + let stale_resp = serde_json::json!({ + "jsonrpc": "2.0", + "id": first_id, + "result": {"echo": "ignored"} + }); + write_framed(&mut server_write, &serde_json::to_vec(&stale_resp).unwrap()).await; + + // Second request: should succeed normally without collision. + let server_task = tokio::spawn(async move { + let (id, method) = read_one_method(&mut server_read).await; + assert_eq!(method, "second"); + let resp = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": {"ok": true} + }); + write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await; + }); + + let response = client.send_request("second", None).await.unwrap(); + assert_eq!(response.result.unwrap()["ok"], true); + server_task.await.unwrap(); +} diff --git a/rust/tests/protocol_version_test.rs b/rust/tests/protocol_version_test.rs new file mode 100644 index 000000000..fd4eecada --- /dev/null +++ b/rust/tests/protocol_version_test.rs @@ -0,0 +1,241 @@ +#![allow(clippy::unwrap_used)] + +use github_copilot_sdk::Client; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; + +async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) { + let header = format!("Content-Length: {}\r\n\r\n", body.len()); + writer.write_all(header.as_bytes()).await.unwrap(); + writer.write_all(body).await.unwrap(); + writer.flush().await.unwrap(); +} + +async fn read_framed(reader: &mut (impl tokio::io::AsyncRead + Unpin)) -> serde_json::Value { + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + AsyncReadExt::read_exact(reader, &mut byte).await.unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + let mut buf = vec![0u8; length]; + AsyncReadExt::read_exact(reader, &mut buf).await.unwrap(); + serde_json::from_slice(&buf).unwrap() +} + +/// Verify protocol version against a fake server. Mimics a legacy server +/// that lacks the `connect` JSON-RPC method (-32601 MethodNotFound), +/// forcing the client to fall back to `ping` โ€” the canonical +/// backward-compatibility path documented on `verify_protocol_version`. +async fn verify_with_result( + result: serde_json::Value, +) -> (Result<(), github_copilot_sdk::Error>, Option) { + let (client_write, server_read) = duplex(8192); + let (server_write, client_read) = duplex(8192); + let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap(); + + let mut server_read = server_read; + let mut server_write = server_write; + + let verify_handle = tokio::spawn({ + let client = client.clone(); + async move { client.verify_protocol_version().await } + }); + + // 1. Client sends `connect` first; respond with MethodNotFound so the + // client falls back to `ping` (the legacy-server compatibility path). + let connect_req = read_framed(&mut server_read).await; + assert_eq!(connect_req["method"], "connect"); + let not_found = serde_json::json!({ + "jsonrpc": "2.0", + "id": connect_req["id"], + "error": { "code": -32601, "message": "Method not found" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(¬_found).unwrap()).await; + + // 2. Client falls back to `ping`; respond with the requested result. + let req = read_framed(&mut server_read).await; + assert_eq!(req["method"], "ping"); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": req["id"], + "result": result, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let res = tokio::time::timeout(std::time::Duration::from_secs(2), verify_handle) + .await + .unwrap() + .unwrap(); + let version = client.protocol_version(); + (res, version) +} + +#[tokio::test] +async fn accepted_when_version_in_range() { + let (res, version) = verify_with_result(serde_json::json!({ "protocolVersion": 3 })).await; + assert!(res.is_ok()); + assert_eq!(version, Some(3)); +} + +#[tokio::test] +async fn rejected_when_version_out_of_range() { + let (res, version) = verify_with_result(serde_json::json!({ "protocolVersion": 1 })).await; + let err = res.unwrap_err(); + assert!(matches!( + err, + github_copilot_sdk::Error::Protocol(github_copilot_sdk::ProtocolError::VersionMismatch { + server: 1, + .. + }) + )); + assert_eq!(version, None); +} + +#[tokio::test] +async fn succeeds_when_version_missing() { + let (res, version) = verify_with_result(serde_json::json!({ "message": "pong" })).await; + assert!(res.is_ok()); + assert_eq!(version, None); +} + +/// New `connect` handshake path: when the server supports `connect` (modern +/// CLIs do), the client uses it directly without falling back to `ping`. +/// Validates the protocolVersion negotiated through the new RPC. +#[tokio::test] +async fn connect_handshake_supplies_protocol_version() { + let (client_write, server_read) = duplex(8192); + let (server_write, client_read) = duplex(8192); + let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap(); + + let mut server_read = server_read; + let mut server_write = server_write; + + let verify_handle = tokio::spawn({ + let client = client.clone(); + async move { client.verify_protocol_version().await } + }); + + let req = read_framed(&mut server_read).await; + assert_eq!(req["method"], "connect"); + // Token is None for the from_streams entry point (no transport spawn). + assert!(req["params"].get("token").is_none()); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": req["id"], + "result": { "ok": true, "protocolVersion": 3, "version": "test-1.0.0" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let res = tokio::time::timeout(std::time::Duration::from_secs(2), verify_handle) + .await + .unwrap() + .unwrap(); + assert!(res.is_ok()); + assert_eq!(client.protocol_version(), Some(3)); +} + +/// Positive coverage for token forwarding on the `connect` handshake. A +/// client constructed with a preset `effective_connection_token` MUST +/// place the exact token string in the outbound `connect` request's +/// `token` param. This is the wire-side hand-off that authenticates +/// the SDK to a CLI server started with `COPILOT_CONNECTION_TOKEN`. +#[tokio::test] +async fn connect_handshake_forwards_explicit_token() { + let (client_write, server_read) = duplex(8192); + let (server_write, client_read) = duplex(8192); + let client = Client::from_streams_with_connection_token( + client_read, + client_write, + std::env::temp_dir(), + Some("explicit-token-abc".to_string()), + ) + .unwrap(); + + let mut server_read = server_read; + let mut server_write = server_write; + + let verify_handle = tokio::spawn({ + let client = client.clone(); + async move { client.verify_protocol_version().await } + }); + + let req = read_framed(&mut server_read).await; + assert_eq!(req["method"], "connect"); + assert_eq!(req["params"]["token"], "explicit-token-abc"); + + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": req["id"], + "result": { "ok": true, "protocolVersion": 3, "version": "test-1.0.0" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + tokio::time::timeout(std::time::Duration::from_secs(2), verify_handle) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +/// Auto-generated tokens (the codepath that fires when the SDK spawns +/// its own CLI in TCP mode and the consumer didn't supply one) must +/// reach the wire too. Builds a token via the SDK's exposed test helper +/// and verifies the same string lands in the outbound `connect`. +#[tokio::test] +async fn connect_handshake_forwards_auto_generated_token() { + let token = Client::generate_connection_token_for_test(); + // Sanity-check the generated shape: 32-char lowercase hex (16 bytes, + // 128 bits of entropy). A regression in the helper would silently + // weaken loopback authentication. + assert_eq!(token.len(), 32, "expected 32-char hex, got {token:?}"); + assert!( + token + .chars() + .all(|c| c.is_ascii_hexdigit() && !c.is_uppercase()), + "expected lowercase hex, got {token:?}", + ); + + let (client_write, server_read) = duplex(8192); + let (server_write, client_read) = duplex(8192); + let client = Client::from_streams_with_connection_token( + client_read, + client_write, + std::env::temp_dir(), + Some(token.clone()), + ) + .unwrap(); + + let mut server_read = server_read; + let mut server_write = server_write; + + let verify_handle = tokio::spawn({ + let client = client.clone(); + async move { client.verify_protocol_version().await } + }); + + let req = read_framed(&mut server_read).await; + assert_eq!(req["method"], "connect"); + assert_eq!(req["params"]["token"], token); + + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": req["id"], + "result": { "ok": true, "protocolVersion": 3, "version": "test-1.0.0" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + tokio::time::timeout(std::time::Duration::from_secs(2), verify_handle) + .await + .unwrap() + .unwrap() + .unwrap(); +} diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs new file mode 100644 index 000000000..434ca1eec --- /dev/null +++ b/rust/tests/session_test.rs @@ -0,0 +1,3734 @@ +#![allow(clippy::unwrap_used)] + +use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +use async_trait::async_trait; +use github_copilot_sdk::Client; +use github_copilot_sdk::handler::{ + ApproveAllHandler, AutoModeSwitchResponse, ExitPlanModeResult, HandlerEvent, HandlerResponse, + PermissionResult, SessionHandler, UserInputResponse, +}; +use github_copilot_sdk::types::{ + CommandContext, CommandDefinition, CommandHandler, DeliveryMode, MessageOptions, + ServerTelemetryEvent, SessionConfig, SessionId, SessionTelemetryEvent, ToolResult, +}; +use serde_json::Value; +use tokio::io::{AsyncWrite, AsyncWriteExt, duplex}; +use tokio::sync::mpsc; +use tokio::time::timeout; + +const TIMEOUT: Duration = Duration::from_secs(2); +const METHOD_NOT_FOUND: i32 = -32601; + +struct NoopHandler; +#[async_trait] +impl SessionHandler for NoopHandler { + async fn on_event(&self, _event: HandlerEvent) -> HandlerResponse { + HandlerResponse::Ok + } +} + +async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) { + let header = format!("Content-Length: {}\r\n\r\n", body.len()); + writer.write_all(header.as_bytes()).await.unwrap(); + writer.write_all(body).await.unwrap(); + writer.flush().await.unwrap(); +} + +async fn read_framed(reader: &mut (impl tokio::io::AsyncRead + Unpin)) -> Value { + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + tokio::io::AsyncReadExt::read_exact(reader, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + let mut buf = vec![0u8; length]; + tokio::io::AsyncReadExt::read_exact(reader, &mut buf) + .await + .unwrap(); + serde_json::from_slice(&buf).unwrap() +} + +fn make_client() -> (Client, tokio::io::DuplexStream, tokio::io::DuplexStream) { + let (client_write, server_read) = duplex(8192); + let (server_write, client_read) = duplex(8192); + let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap(); + (client, server_read, server_write) +} + +struct FakeServer { + read: tokio::io::DuplexStream, + write: tokio::io::DuplexStream, + session_id: String, +} + +impl FakeServer { + async fn read_request(&mut self) -> Value { + read_framed(&mut self.read).await + } + + async fn respond(&mut self, request: &Value, result: Value) { + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": result }); + write_framed(&mut self.write, &serde_json::to_vec(&response).unwrap()).await; + } + + async fn send_notification(&mut self, method: &str, params: Value) { + let notification = serde_json::json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }); + write_framed(&mut self.write, &serde_json::to_vec(¬ification).unwrap()).await; + } + + async fn send_event(&mut self, event_type: &str, data: Value) { + self.send_notification( + "session.event", + serde_json::json!({ + "sessionId": self.session_id, + "event": { + "id": format!("evt-{}", rand_id()), + "timestamp": "2025-01-01T00:00:00Z", + "type": event_type, + "data": data, + }, + }), + ) + .await; + } + + async fn send_request(&mut self, id: u64, method: &str, params: Value) { + let request = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + }); + write_framed(&mut self.write, &serde_json::to_vec(&request).unwrap()).await; + } + + async fn read_response(&mut self) -> Value { + read_framed(&mut self.read).await + } +} + +async fn create_session_pair( + handler: Arc, +) -> (github_copilot_sdk::session::Session, FakeServer) { + create_session_pair_with_capabilities(handler, serde_json::json!(null)).await +} + +async fn create_session_pair_with_capabilities( + handler: Arc, + capabilities: Value, +) -> (github_copilot_sdk::session::Session, FakeServer) { + let (client, server_read, server_write) = make_client(); + let session_id = format!("test-session-{}", rand_id()); + + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: session_id.clone(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + let handler = handler.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(handler)) + .await + .unwrap() + } + }); + + let create_req = server.read_request().await; + assert_eq!(create_req["method"], "session.create"); + let mut result = serde_json::json!({ + "sessionId": session_id, + "workspacePath": "/tmp/workspace" + }); + if !capabilities.is_null() { + result["capabilities"] = capabilities; + } + server.respond(&create_req, result).await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + (session, server) +} + +fn rand_id() -> u64 { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + COUNTER.fetch_add(1, Ordering::Relaxed) as u64 +} + +#[tokio::test] +async fn session_subscribe_yields_events_observe_only() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + + let mut events = session.subscribe(); + let count = Arc::new(AtomicUsize::new(0)); + let last_type = Arc::new(parking_lot::Mutex::new(String::new())); + let count_clone = count.clone(); + let last_type_clone = last_type.clone(); + let consumer = tokio::spawn(async move { + while let Ok(event) = events.recv().await { + count_clone.fetch_add(1, Ordering::Relaxed); + *last_type_clone.lock() = event.event_type.clone(); + } + }); + + server.send_event("noop.event", serde_json::json!({})).await; + server + .send_event("another.event", serde_json::json!({"k": "v"})) + .await; + + for _ in 0..50 { + if count.load(Ordering::Relaxed) >= 2 { + break; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + assert_eq!(count.load(Ordering::Relaxed), 2); + assert_eq!(last_type.lock().as_str(), "another.event"); + consumer.abort(); +} + +#[tokio::test] +async fn session_subscribe_drop_stops_delivery() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + + let mut events = session.subscribe(); + let count = Arc::new(AtomicUsize::new(0)); + let count_clone = count.clone(); + let consumer = tokio::spawn(async move { + while let Ok(_event) = events.recv().await { + count_clone.fetch_add(1, Ordering::Relaxed); + } + }); + + server.send_event("first", serde_json::json!({})).await; + for _ in 0..50 { + if count.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + assert_eq!(count.load(Ordering::Relaxed), 1); + + // Aborting the consumer drops its receiver; further events have no + // effect on the (now-zero) subscriber count. + consumer.abort(); + tokio::time::sleep(Duration::from_millis(20)).await; + + server.send_event("second", serde_json::json!({})).await; + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(count.load(Ordering::Relaxed), 1); +} + +#[tokio::test] +async fn create_session_sends_correct_rpc() { + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session({ + let mut cfg = SessionConfig::default(); + cfg.model = Some("gpt-4".to_string()); + cfg.with_handler(Arc::new(NoopHandler)) + }) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.create"); + assert_eq!(request["params"]["model"], "gpt-4"); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": "s1", "workspacePath": "/ws" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + assert_eq!(session.id(), "s1"); + assert_eq!(session.workspace_path(), Some(Path::new("/ws"))); +} + +#[tokio::test] +async fn send_injects_session_id() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send(MessageOptions::new("hello").with_mode(DeliveryMode::Immediate)) + .await + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + assert_eq!(request["params"]["sessionId"], server.session_id); + assert_eq!(request["params"]["prompt"], "hello"); + assert_eq!(request["params"]["mode"], "immediate"); + + server.respond(&request, serde_json::json!({})).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn send_serializes_request_headers() { + use std::collections::HashMap; + + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + let mut headers = HashMap::new(); + headers.insert("X-Custom-Tag".to_string(), "value-1".to_string()); + headers.insert("Authorization".to_string(), "Bearer abc".to_string()); + session + .send(MessageOptions::new("hi").with_request_headers(headers)) + .await + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + assert_eq!(request["params"]["prompt"], "hi"); + let headers = request["params"]["requestHeaders"] + .as_object() + .expect("requestHeaders should be an object"); + assert_eq!(headers["X-Custom-Tag"], "value-1"); + assert_eq!(headers["Authorization"], "Bearer abc"); + assert_eq!(headers.len(), 2); + + server.respond(&request, serde_json::json!({})).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn send_omits_request_headers_when_unset_or_empty() { + use std::collections::HashMap; + + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.send(MessageOptions::new("plain")).await } + }); + let request = server.read_request().await; + assert!( + request["params"].get("requestHeaders").is_none(), + "requestHeaders should be omitted when unset, got: {}", + request["params"] + ); + server.respond(&request, serde_json::json!({})).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send(MessageOptions::new("plain").with_request_headers(HashMap::new())) + .await + } + }); + let request = server.read_request().await; + assert!( + request["params"].get("requestHeaders").is_none(), + "requestHeaders should be omitted for empty map, got: {}", + request["params"] + ); + server.respond(&request, serde_json::json!({})).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn session_rpc_methods_send_correct_method_names() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let cases: Vec<(&str, Option<&str>)> = vec![ + ("session.abort", None), + ("session.plan.delete", None), + ("session.log", Some("message")), + ("session.sendTelemetry", Some("kind")), + ("session.destroy", None), + ]; + + for (expected_method, extra_param_key) in cases { + let s = session.clone(); + let handle = tokio::spawn(async move { + match expected_method { + "session.abort" => s.abort().await.map(|_| ()), + "session.plan.delete" => s.delete_plan().await, + "session.log" => s.log("test msg", None).await, + "session.sendTelemetry" => { + s.send_telemetry(SessionTelemetryEvent { + kind: "sdk_test_event".to_string(), + properties: Some( + [("source".to_string(), "sdk".to_string())] + .into_iter() + .collect(), + ), + restricted_properties: None, + metrics: None, + }) + .await + } + "session.destroy" => s.destroy().await, + _ => unreachable!(), + } + }); + + let request = server.read_request().await; + assert_eq!( + request["method"], expected_method, + "wrong method for {expected_method}" + ); + assert_eq!(request["params"]["sessionId"], server.session_id); + if let Some(key) = extra_param_key { + assert!(!request["params"][key].is_null(), "missing param {key}"); + } + let response = match expected_method { + "session.log" => { + serde_json::json!({ "eventId": "00000000-0000-0000-0000-000000000000" }) + } + _ => serde_json::json!({}), + }; + server.respond(&request, response).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + } +} + +#[tokio::test] +async fn send_telemetry_injects_payload_and_session_id() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_telemetry(SessionTelemetryEvent { + kind: "sdk_test_event".to_string(), + properties: Some( + [ + ("source".to_string(), "sdk".to_string()), + ("feature".to_string(), "shared-api".to_string()), + ] + .into_iter() + .collect(), + ), + restricted_properties: Some( + [("file_path".to_string(), "/tmp/example.ts".to_string())] + .into_iter() + .collect(), + ), + metrics: Some( + [ + ("count".to_string(), 1.0), + ("duration_ms".to_string(), 12.5), + ] + .into_iter() + .collect(), + ), + }) + .await + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.sendTelemetry"); + assert_eq!(request["params"]["sessionId"], server.session_id); + assert_eq!(request["params"]["kind"], "sdk_test_event"); + assert_eq!(request["params"]["properties"]["source"], "sdk"); + assert_eq!( + request["params"]["restrictedProperties"]["file_path"], + "/tmp/example.ts" + ); + assert_eq!(request["params"]["metrics"]["duration_ms"], 12.5); + + server.respond(&request, serde_json::json!(null)).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn client_rpc_methods_send_correct_method_names() { + let (client, mut server_read, mut server_write) = make_client(); + + // Wire method names per the CLI runtime registration in @github/copilot + // app.js โ€” verified against Node/Go/Python/.NET SDK call sites which all + // use these exact strings. The schema doesn't currently define these as + // typed RPCs (top-level methods, not under any namespace), so call site + // strings are the source of truth. + for expected_method in ["status.get", "auth.getStatus"] { + let c = client.clone(); + let handle = tokio::spawn(async move { + match expected_method { + "status.get" => c.get_status().await.map(|_| ()), + "auth.getStatus" => c.get_auth_status().await.map(|_| ()), + _ => unreachable!(), + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], expected_method); + // Regression-prevention: must not have reverted to the + // hand-authored `getStatus` / `getAuthStatus` names that don't + // exist on the wire. + assert_ne!(request["method"], "getStatus"); + assert_ne!(request["method"], "getAuthStatus"); + let id = request["id"].as_u64().unwrap(); + let result = match expected_method { + "status.get" => serde_json::json!({ "version": "1.0.0", "protocolVersion": 1 }), + "auth.getStatus" => serde_json::json!({ "isAuthenticated": true }), + _ => unreachable!(), + }; + let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": result }); + write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + } +} + +#[tokio::test] +async fn server_send_telemetry_sends_correct_payload() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .send_telemetry(ServerTelemetryEvent { + kind: "app.launched".to_string(), + client_name: "github/autopilot".to_string(), + properties: Some( + [("machine_id".to_string(), "machine-123".to_string())] + .into_iter() + .collect(), + ), + restricted_properties: None, + metrics: Some([("launch_count".to_string(), 1.0)].into_iter().collect()), + }) + .await + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "sendTelemetry"); + assert_eq!(request["params"]["kind"], "app.launched"); + assert_eq!(request["params"]["clientName"], "github/autopilot"); + assert_eq!(request["params"]["properties"]["machine_id"], "machine-123"); + assert_eq!(request["params"]["metrics"]["launch_count"], 1.0); + + let id = request["id"].as_u64().unwrap(); + let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": null }); + write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn server_send_telemetry_falls_back_to_namespaced_method_and_caches_it() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .send_telemetry(ServerTelemetryEvent { + kind: "app.launched".to_string(), + client_name: "github/autopilot".to_string(), + properties: Some( + [("machine_id".to_string(), "machine-123".to_string())] + .into_iter() + .collect(), + ), + restricted_properties: None, + metrics: Some([("launch_count".to_string(), 1.0)].into_iter().collect()), + }) + .await?; + client + .send_telemetry(ServerTelemetryEvent { + kind: "app.closed".to_string(), + client_name: "github/autopilot".to_string(), + properties: None, + restricted_properties: None, + metrics: None, + }) + .await + } + }); + + let first_request = read_framed(&mut server_read).await; + assert_eq!(first_request["method"], "sendTelemetry"); + let first_id = first_request["id"].as_u64().unwrap(); + let first_response = serde_json::json!({ + "jsonrpc": "2.0", + "id": first_id, + "error": { + "code": METHOD_NOT_FOUND, + "message": "Unhandled method sendTelemetry" + } + }); + write_framed( + &mut server_write, + &serde_json::to_vec(&first_response).unwrap(), + ) + .await; + + let second_request = read_framed(&mut server_read).await; + assert_eq!(second_request["method"], "server.sendTelemetry"); + assert_eq!(second_request["params"]["kind"], "app.launched"); + assert_eq!(second_request["params"]["clientName"], "github/autopilot"); + assert_eq!( + second_request["params"]["properties"]["machine_id"], + "machine-123" + ); + assert_eq!(second_request["params"]["metrics"]["launch_count"], 1.0); + + let second_id = second_request["id"].as_u64().unwrap(); + let second_response = serde_json::json!({ "jsonrpc": "2.0", "id": second_id, "result": null }); + write_framed( + &mut server_write, + &serde_json::to_vec(&second_response).unwrap(), + ) + .await; + + let third_request = read_framed(&mut server_read).await; + assert_eq!(third_request["method"], "server.sendTelemetry"); + assert_eq!(third_request["params"]["kind"], "app.closed"); + + let third_id = third_request["id"].as_u64().unwrap(); + let third_response = serde_json::json!({ "jsonrpc": "2.0", "id": third_id, "result": null }); + write_framed( + &mut server_write, + &serde_json::to_vec(&third_response).unwrap(), + ) + .await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn list_sessions_returns_typed_metadata() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.list_sessions(None).await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.list"); + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "sessions": [{ + "sessionId": "s1", + "startTime": "2025-01-01T00:00:00Z", + "modifiedTime": "2025-01-01T01:00:00Z", + "summary": "test session", + "isRemote": false, + }] + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let sessions = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].session_id, "s1"); + assert_eq!(sessions[0].summary, Some("test session".to_string())); +} + +#[tokio::test] +async fn list_sessions_serializes_typed_filter() { + use github_copilot_sdk::SessionListFilter; + + let (client, mut server_read, mut server_write) = make_client(); + + let filter = SessionListFilter { + repository: Some("octocat/hello".to_string()), + branch: Some("main".to_string()), + ..Default::default() + }; + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.list_sessions(Some(filter)).await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.list"); + assert_eq!(request["params"]["filter"]["repository"], "octocat/hello"); + assert_eq!(request["params"]["filter"]["branch"], "main"); + // cwd / gitRoot are None and must be omitted from the filter object. + assert!(request["params"]["filter"].get("cwd").is_none()); + assert!(request["params"]["filter"].get("gitRoot").is_none()); + // Regression check: filter must be wrapped under `params.filter`, not + // flattened onto `params` directly. All other SDKs (Node/Python/Go/.NET) + // wrap; flattening is silently ignored by the runtime. + assert!( + request["params"].get("repository").is_none(), + "wire shape is `params.filter.*`, not `params.*` โ€” see Node/Go/Python/.NET" + ); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessions": [] }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap(); +} + +#[test] +fn mcp_server_config_roundtrips_through_tagged_enum() { + use std::collections::HashMap; + + use github_copilot_sdk::{McpServerConfig, McpStdioServerConfig}; + + let stdio = McpServerConfig::Stdio(McpStdioServerConfig { + command: "node".to_string(), + args: vec!["server.js".to_string()], + env: HashMap::new(), + cwd: None, + tools: vec!["*".to_string()], + timeout: None, + }); + let json = serde_json::to_value(&stdio).unwrap(); + assert_eq!(json["type"], "stdio"); + assert_eq!(json["command"], "node"); + + // CLI may emit the legacy "local" alias; we accept it on the wire. + let local: McpServerConfig = serde_json::from_value(serde_json::json!({ + "type": "local", + "command": "node", + })) + .unwrap(); + assert!(matches!(local, McpServerConfig::Stdio(_))); + + // SessionConfig.mcp_servers round-trips a typed map. + let mut servers = HashMap::new(); + servers.insert("github".to_string(), stdio.clone()); + let cfg_json = serde_json::to_value(&servers).unwrap(); + assert_eq!(cfg_json["github"]["type"], "stdio"); +} + +#[test] +fn permission_request_data_extracts_typed_kind() { + use github_copilot_sdk::{PermissionRequestData, PermissionRequestKind}; + + let data: PermissionRequestData = serde_json::from_value(serde_json::json!({ + "kind": "shell", + "toolCallId": "t1", + "command": "ls", + })) + .unwrap(); + assert_eq!(data.kind, Some(PermissionRequestKind::Shell)); + assert_eq!(data.tool_call_id, Some("t1".to_string())); + assert_eq!(data.extra["command"], "ls"); + + let custom: PermissionRequestData = serde_json::from_value(serde_json::json!({ + "kind": "custom-tool", + })) + .unwrap(); + assert_eq!(custom.kind, Some(PermissionRequestKind::CustomTool)); + + // Unknown kinds fall through to the catch-all variant rather than failing. + let unknown: PermissionRequestData = serde_json::from_value(serde_json::json!({ + "kind": "future-permission-type", + })) + .unwrap(); + assert_eq!(unknown.kind, Some(PermissionRequestKind::Unknown)); +} + +#[tokio::test] +async fn force_stop_is_idempotent_with_no_child() { + // Stream-based clients have no child process. force_stop should be a + // no-op and safe to call multiple times. + let (client, _server_read, _server_write) = make_client(); + assert_eq!( + client.state(), + github_copilot_sdk::ConnectionState::Connected + ); + client.force_stop(); + assert_eq!( + client.state(), + github_copilot_sdk::ConnectionState::Disconnected + ); + client.force_stop(); + assert_eq!( + client.state(), + github_copilot_sdk::ConnectionState::Disconnected + ); + assert!(client.pid().is_none()); +} + +#[tokio::test] +async fn stop_transitions_state_to_disconnected() { + let (client, _server_read, _server_write) = make_client(); + assert_eq!( + client.state(), + github_copilot_sdk::ConnectionState::Connected + ); + client.stop().await.expect("stop should succeed"); + assert_eq!( + client.state(), + github_copilot_sdk::ConnectionState::Disconnected + ); +} + +#[tokio::test] +async fn lifecycle_subscribe_yields_events_with_filter() { + use github_copilot_sdk::{SessionLifecycleEventMetadata, SessionLifecycleEventType as Type}; + + let (client, _server_read, mut server_write) = make_client(); + + let mut all_events = client.subscribe_lifecycle(); + let mut foreground_events = client.subscribe_lifecycle(); + + let wildcard_count = Arc::new(AtomicUsize::new(0)); + let foreground_count = Arc::new(AtomicUsize::new(0)); + let last_session = Arc::new(parking_lot::Mutex::new(None)); + + let w_count = wildcard_count.clone(); + let w_last = last_session.clone(); + let w_consumer = tokio::spawn(async move { + while let Ok(event) = all_events.recv().await { + w_count.fetch_add(1, Ordering::Relaxed); + *w_last.lock() = Some(event.session_id.clone()); + } + }); + let f_count = foreground_count.clone(); + let f_consumer = tokio::spawn(async move { + while let Ok(event) = foreground_events.recv().await { + if event.event_type == Type::Foreground { + f_count.fetch_add(1, Ordering::Relaxed); + } + } + }); + + let body1 = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.lifecycle", + "params": { "type": "session.created", "sessionId": "s1" }, + })) + .unwrap(); + let body2 = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.lifecycle", + "params": { + "type": "session.foreground", + "sessionId": "s2", + "metadata": { + "startTime": "2025-01-01T00:00:00Z", + "modifiedTime": "2025-01-02T00:00:00Z", + "summary": "hello", + }, + }, + })) + .unwrap(); + let body3 = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.event", + "params": { "sessionId": "ignored", "event": { + "id": "x", "timestamp": "t", "type": "noop", "data": {} + }}, + })) + .unwrap(); + write_framed(&mut server_write, &body1).await; + write_framed(&mut server_write, &body2).await; + write_framed(&mut server_write, &body3).await; + + for _ in 0..50 { + if wildcard_count.load(Ordering::Relaxed) >= 2 { + break; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + assert_eq!(wildcard_count.load(Ordering::Relaxed), 2); + assert_eq!(foreground_count.load(Ordering::Relaxed), 1); + assert_eq!(last_session.lock().as_deref(), Some("s2")); + w_consumer.abort(); + f_consumer.abort(); + + let meta = SessionLifecycleEventMetadata { + start_time: "t1".into(), + modified_time: "t2".into(), + summary: Some("s".into()), + }; + assert_eq!(meta.summary.as_deref(), Some("s")); +} + +#[tokio::test] +async fn lifecycle_subscribe_drop_stops_delivery() { + let (client, _server_read, mut server_write) = make_client(); + + let mut events = client.subscribe_lifecycle(); + let count = Arc::new(AtomicUsize::new(0)); + let count_clone = count.clone(); + let consumer = tokio::spawn(async move { + while let Ok(_event) = events.recv().await { + count_clone.fetch_add(1, Ordering::Relaxed); + } + }); + + let lifecycle_body = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.lifecycle", + "params": { "type": "session.created", "sessionId": "x" }, + })) + .unwrap(); + + write_framed(&mut server_write, &lifecycle_body).await; + for _ in 0..50 { + if count.load(Ordering::Relaxed) >= 1 { + break; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + assert_eq!(count.load(Ordering::Relaxed), 1); + + consumer.abort(); + tokio::time::sleep(Duration::from_millis(20)).await; + + write_framed(&mut server_write, &lifecycle_body).await; + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(count.load(Ordering::Relaxed), 1); +} + +#[tokio::test] +async fn delete_session_sends_session_id() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.delete_session(&SessionId::new("s-to-delete")).await } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.delete"); + assert_eq!(request["params"]["sessionId"], "s-to-delete"); + + let id = request["id"].as_u64().unwrap(); + let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} }); + write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn get_last_session_id_returns_none_when_empty() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.get_last_session_id().await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.getLastId"); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let last = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert!(last.is_none()); +} + +#[tokio::test] +async fn get_last_session_id_returns_id_when_set() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.get_last_session_id().await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.getLastId"); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": "s-last" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let last = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(last.as_deref(), Some("s-last")); +} + +#[tokio::test] +async fn get_foreground_session_id_returns_id_when_set() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.get_foreground_session_id().await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.getForeground"); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": "s-fg" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let fg = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(fg.as_deref(), Some("s-fg")); +} + +#[tokio::test] +async fn set_foreground_session_id_sends_session_id() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .set_foreground_session_id(&SessionId::new("s-target")) + .await + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.setForeground"); + assert_eq!(request["params"]["sessionId"], "s-target"); + + let id = request["id"].as_u64().unwrap(); + let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} }); + write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await; + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn get_session_metadata_returns_typed_metadata() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .get_session_metadata(&SessionId::new("s1")) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.getMetadata"); + assert_eq!(request["params"]["sessionId"], "s1"); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "session": { + "sessionId": "s1", + "startTime": "2025-01-01T00:00:00Z", + "modifiedTime": "2025-01-01T01:00:00Z", + "summary": "loaded session", + "isRemote": false, + } + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let metadata = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + let metadata = metadata.expect("server returned a session"); + assert_eq!(metadata.session_id, "s1"); + assert_eq!(metadata.summary.as_deref(), Some("loaded session")); +} + +#[tokio::test] +async fn get_session_metadata_returns_none_when_missing() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .get_session_metadata(&SessionId::new("missing")) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.getMetadata"); + + let id = request["id"].as_u64().unwrap(); + // Server responds with an empty result object; `session` is absent. + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": {}, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let metadata = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert!(metadata.is_none()); +} + +#[tokio::test] +async fn list_models_returns_typed_model_info() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.list_models().await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "models.list"); + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "models": [ + { "id": "gpt-4", "name": "GPT-4", "capabilities": {} }, + { "id": "claude-sonnet-4", "name": "Claude Sonnet", "capabilities": {} }, + ] + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let models = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(models.len(), 2); + assert_eq!(models[0].id, "gpt-4"); + assert_eq!(models[1].name, "Claude Sonnet"); +} + +#[tokio::test] +async fn get_messages_returns_typed_events() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.get_messages().await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.getMessages"); + server + .respond( + &request, + serde_json::json!({ + "events": [{ + "id": "e1", + "timestamp": "2025-01-01T00:00:00Z", + "type": "user.message", + "data": { "text": "hello" }, + }] + }), + ) + .await; + + let events = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(events.len(), 1); + assert_eq!(events[0].event_type, "user.message"); +} + +#[tokio::test] +async fn set_model_sends_switch_to_request() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.set_model("claude-sonnet-4", None).await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.model.switchTo"); + assert_eq!(request["params"]["modelId"], "claude-sonnet-4"); + server + .respond( + &request, + serde_json::json!({ "modelId": "claude-sonnet-4" }), + ) + .await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn get_name_returns_name() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.get_name().await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.name.get"); + server + .respond(&request, serde_json::json!({ "name": "Fix input flicker" })) + .await; + + assert_eq!( + timeout(TIMEOUT, handle).await.unwrap().unwrap(), + Some("Fix input flicker".to_string()) + ); +} + +#[tokio::test] +async fn set_name_sends_name() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.set_name("Fix input flicker").await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.name.set"); + assert_eq!(request["params"]["name"], "Fix input flicker"); + server.respond(&request, serde_json::json!(null)).await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn elicitation_returns_typed_result() { + let (session, mut server) = create_session_pair_with_capabilities( + Arc::new(NoopHandler), + serde_json::json!({ "ui": { "elicitation": true } }), + ) + .await; + let session = Arc::new(session); + let schema = serde_json::json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + }); + + let handle = tokio::spawn({ + let session = session.clone(); + let schema = schema.clone(); + async move { + session + .ui() + .elicitation("Enter your name", schema) + .await + .unwrap() + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.ui.elicitation"); + assert_eq!(request["params"]["message"], "Enter your name"); + assert_eq!(request["params"]["requestedSchema"], schema); + assert!( + request["params"].get("schema").is_none(), + "wire field is `requestedSchema`, not `schema`" + ); + server + .respond( + &request, + serde_json::json!({ "action": "accept", "content": { "name": "Octocat" } }), + ) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(result.action, "accept"); + assert_eq!(result.content.unwrap()["name"], "Octocat"); +} + +#[tokio::test] +async fn tool_call_dispatches_to_handler() { + struct ToolHandler; + #[async_trait] + impl SessionHandler for ToolHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::ExternalTool { invocation } => { + assert_eq!(invocation.tool_name, "read_file"); + HandlerResponse::ToolResult(ToolResult::Text("file contents here".to_string())) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(ToolHandler)).await; + server + .send_request( + 100, + "tool.call", + serde_json::json!({ + "sessionId": server.session_id, + "toolCallId": "tc-1", + "toolName": "read_file", + "arguments": { "path": "/foo.txt" }, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 100); + assert_eq!(response["result"]["result"], "file contents here"); +} + +#[tokio::test] +async fn permission_request_dispatches_to_handler() { + struct DenyHandler; + #[async_trait] + impl SessionHandler for DenyHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Denied) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(DenyHandler)).await; + server + .send_request( + 200, + "permission.request", + serde_json::json!({ + "sessionId": server.session_id, + "requestId": "perm-1", + "kind": "shell", + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 200); + assert_eq!(response["result"]["kind"], "reject"); +} + +#[tokio::test] +async fn user_input_request_dispatches_to_handler() { + struct InputHandler; + #[async_trait] + impl SessionHandler for InputHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::UserInput { question, .. } => { + assert_eq!(question, "Pick a color"); + HandlerResponse::UserInput(Some(UserInputResponse { + answer: "blue".to_string(), + was_freeform: true, + })) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(InputHandler)).await; + server + .send_request( + 300, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id, + "question": "Pick a color", + "choices": ["red", "blue"], + "allowFreeform": true, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 300); + assert_eq!(response["result"]["answer"], "blue"); + assert_eq!(response["result"]["wasFreeform"], true); +} + +#[tokio::test] +async fn user_input_requested_notification_does_not_double_dispatch() { + use std::sync::atomic::{AtomicUsize, Ordering}; + // Regression for github/github-app#4249. The CLI sends BOTH a + // `user_input.requested` notification (for observers) AND a + // `userInput.request` JSON-RPC call (the actual prompt) for every + // user-input prompt. Only the JSON-RPC path should reach the + // handler โ€” dispatching from the notification too produced + // duplicate ask_user widgets on the consumer side. + + struct CountingHandler { + invocations: Arc, + } + #[async_trait] + impl SessionHandler for CountingHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + if let HandlerEvent::UserInput { .. } = event { + self.invocations.fetch_add(1, Ordering::SeqCst); + return HandlerResponse::UserInput(Some(UserInputResponse { + answer: "ok".to_string(), + was_freeform: true, + })); + } + HandlerResponse::Ok + } + } + + let invocations = Arc::new(AtomicUsize::new(0)); + let handler = Arc::new(CountingHandler { + invocations: invocations.clone(), + }); + let (_session, mut server) = create_session_pair(handler).await; + + server + .send_event( + "user_input.requested", + serde_json::json!({ + "requestId": "ui-1", + "question": "Allow shell access?", + "choices": ["Yes", "No"], + "allowFreeform": false, + }), + ) + .await; + + // Give the SDK a beat to (incorrectly) auto-dispatch if the + // regression returned. Nothing should arrive on the wire. + let respond_observed = timeout(Duration::from_millis(150), server.read_request()).await; + assert!( + respond_observed.is_err(), + "notification triggered unexpected wire activity: {respond_observed:?}", + ); + assert_eq!( + invocations.load(Ordering::SeqCst), + 0, + "notification path must not invoke the user-input handler", + ); + + // Now drive the JSON-RPC path and confirm the handler still runs once. + server + .send_request( + 301, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id, + "question": "Pick a color", + "allowFreeform": true, + }), + ) + .await; + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 301); + assert_eq!(response["result"]["answer"], "ok"); + assert_eq!(invocations.load(Ordering::SeqCst), 1); +} + +#[tokio::test] +async fn exit_plan_mode_dispatches_to_handler() { + struct PlanHandler; + #[async_trait] + impl SessionHandler for PlanHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::ExitPlanMode { .. } => { + HandlerResponse::ExitPlanMode(ExitPlanModeResult { + approved: true, + selected_action: Some("autopilot".to_string()), + feedback: None, + }) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(PlanHandler)).await; + server + .send_request( + 400, + "exitPlanMode.request", + serde_json::json!({ "sessionId": server.session_id, "plan": "do the thing" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["result"]["approved"], true); + assert_eq!(response["result"]["selectedAction"], "autopilot"); +} + +#[tokio::test] +async fn auto_mode_switch_dispatches_to_handler_and_serializes_response() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + struct AutoModeHandler { + calls: Arc, + last_error_code: Arc>>, + last_retry_after: Arc>>, + } + #[async_trait] + impl SessionHandler for AutoModeHandler { + async fn on_auto_mode_switch( + &self, + _session_id: github_copilot_sdk::types::SessionId, + error_code: Option, + retry_after_seconds: Option, + ) -> AutoModeSwitchResponse { + self.calls.fetch_add(1, Ordering::SeqCst); + *self.last_error_code.lock() = error_code; + *self.last_retry_after.lock() = retry_after_seconds; + AutoModeSwitchResponse::YesAlways + } + } + + let calls = Arc::new(AtomicUsize::new(0)); + let last_error_code = Arc::new(parking_lot::Mutex::new(None)); + let last_retry_after = Arc::new(parking_lot::Mutex::new(None)); + let (_session, mut server) = create_session_pair(Arc::new(AutoModeHandler { + calls: calls.clone(), + last_error_code: last_error_code.clone(), + last_retry_after: last_retry_after.clone(), + })) + .await; + + server + .send_request( + 700, + "autoModeSwitch.request", + serde_json::json!({ + "sessionId": server.session_id, + "errorCode": "user_weekly_rate_limited", + "retryAfterSeconds": 3600, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 700); + assert_eq!(response["result"]["response"], "yes_always"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert_eq!( + last_error_code.lock().as_deref(), + Some("user_weekly_rate_limited") + ); + assert_eq!(*last_retry_after.lock(), Some(3600)); +} + +#[tokio::test] +async fn auto_mode_switch_default_handler_replies_no() { + let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + + server + .send_request( + 701, + "autoModeSwitch.request", + serde_json::json!({ + "sessionId": server.session_id, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["result"]["response"], "no"); +} + +#[tokio::test] +async fn approve_all_handler_approves_permission_and_plan() { + let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + + server + .send_request( + 500, + "permission.request", + serde_json::json!({ + "sessionId": server.session_id, + "requestId": "perm-auto", + "kind": "shell", + }), + ) + .await; + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["result"]["kind"], "approve-once"); + + server + .send_request( + 501, + "exitPlanMode.request", + serde_json::json!({ "sessionId": server.session_id, "plan": "go" }), + ) + .await; + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["result"]["approved"], true); +} + +#[tokio::test] +async fn session_event_notification_reaches_handler() { + let (event_tx, mut event_rx) = mpsc::unbounded_channel::(); + + struct EventCollector { + tx: mpsc::UnboundedSender, + } + #[async_trait] + impl SessionHandler for EventCollector { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + if let HandlerEvent::SessionEvent { event, .. } = event { + self.tx.send(event.event_type).unwrap(); + } + HandlerResponse::Ok + } + } + + let (_session, mut server) = + create_session_pair(Arc::new(EventCollector { tx: event_tx })).await; + server + .send_event("session.idle", serde_json::json!({})) + .await; + + let event_type = timeout(TIMEOUT, event_rx.recv()).await.unwrap().unwrap(); + assert_eq!(event_type, "session.idle"); +} + +#[tokio::test] +async fn router_routes_to_correct_session() { + let (client, mut server_read, mut server_write) = make_client(); + let (tx1, mut rx1) = mpsc::unbounded_channel::(); + let (tx2, mut rx2) = mpsc::unbounded_channel::(); + + struct Collector { + tx: mpsc::UnboundedSender, + } + #[async_trait] + impl SessionHandler for Collector { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + if let HandlerEvent::SessionEvent { event, .. } = event { + self.tx.send(event.event_type).unwrap(); + } + HandlerResponse::Ok + } + } + + // Create two sessions on the same client + let mut sessions = Vec::new(); + for (tx, sid) in [(tx1, "s-one"), (tx2, "s-two")] { + let h = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default().with_handler(Arc::new(Collector { tx })), + ) + .await + .unwrap() + } + }); + let req = read_framed(&mut server_read).await; + let id = req["id"].as_u64().unwrap(); + let resp = serde_json::json!({ + "jsonrpc": "2.0", "id": id, + "result": { "sessionId": sid }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await; + sessions.push(timeout(TIMEOUT, h).await.unwrap().unwrap()); + } + + // Event for s-two should only reach rx2 + let notif = serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.event", + "params": { + "sessionId": "s-two", + "event": { "id": "e1", "timestamp": "2025-01-01T00:00:00Z", "type": "assistant.message", "data": {} }, + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(¬if).unwrap()).await; + assert_eq!( + timeout(TIMEOUT, rx2.recv()).await.unwrap().unwrap(), + "assistant.message" + ); + assert!(rx1.try_recv().is_err()); + + // Event for s-one should only reach rx1 + let notif = serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.event", + "params": { + "sessionId": "s-one", + "event": { "id": "e2", "timestamp": "2025-01-01T00:00:00Z", "type": "session.idle", "data": {} }, + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(¬if).unwrap()).await; + assert_eq!( + timeout(TIMEOUT, rx1.recv()).await.unwrap().unwrap(), + "session.idle" + ); + assert!(rx2.try_recv().is_err()); +} + +#[tokio::test] +async fn send_and_wait_returns_last_assistant_message_on_idle() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("hello").with_wait_timeout(Duration::from_secs(5)), + ) + .await + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + server.respond(&request, serde_json::json!({})).await; + + server + .send_event( + "assistant.message", + serde_json::json!({ "message": "Hello back!" }), + ) + .await; + server + .send_event("session.idle", serde_json::json!({})) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + let event = result.expect("should have captured assistant.message"); + assert_eq!(event.event_type, "assistant.message"); + assert_eq!(event.data["message"], "Hello back!"); +} + +#[tokio::test] +async fn send_and_wait_returns_error_on_session_error() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("fail").with_wait_timeout(Duration::from_secs(5)), + ) + .await + } + }); + + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + server + .send_event( + "session.error", + serde_json::json!({ "message": "something went wrong" }), + ) + .await; + + let err = timeout(TIMEOUT, handle) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!( + matches!(err, github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::AgentError(ref msg)) if msg.contains("something went wrong")) + ); +} + +#[tokio::test] +async fn send_and_wait_times_out() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("hello").with_wait_timeout(Duration::from_millis(100)), + ) + .await + } + }); + + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + + let err = timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!( + err, + github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::Timeout(_)) + )); +} + +/// Cancel-safety regression: an outer `tokio::time::timeout` around +/// `send_and_wait` must NOT leak the `idle_waiter` slot. After the outer +/// timeout fires and drops the future, subsequent `send` and +/// `send_and_wait` calls must succeed without `SendWhileWaiting`. +/// +/// Closes RFD-400 review finding #2. +#[tokio::test] +async fn send_and_wait_outer_cancellation_clears_waiter() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + // First call: wrap in outer timeout much shorter than the inner + // wait_timeout. The outer timeout expires, dropping the + // send_and_wait future before the idle/error event arrives. + let handle = tokio::spawn({ + let session = session.clone(); + async move { + tokio::time::timeout( + Duration::from_millis(50), + session.send_and_wait( + MessageOptions::new("first").with_wait_timeout(Duration::from_secs(60)), + ), + ) + .await + } + }); + + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + + // Outer timeout fires โ†’ Err(Elapsed) returned, future is dropped. + let outer_result = timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + assert!(outer_result.is_err(), "outer timeout should have elapsed"); + + // The WaiterGuard's Drop should have cleared the slot. A subsequent + // `send` must NOT return SendWhileWaiting. + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send("second").await } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + assert_eq!(request["params"]["prompt"], "second"); + server + .respond( + &request, + serde_json::json!({ "messageId": "msg-after-cancel" }), + ) + .await; + + let result = timeout(TIMEOUT, send_handle).await.unwrap().unwrap(); + assert_eq!(result.unwrap(), "msg-after-cancel"); +} + +/// Cancel-safety regression: explicitly dropping the JoinHandle of an +/// in-flight `send_and_wait` must clear the waiter slot via WaiterGuard's +/// Drop. The next `send` must succeed. +/// +/// Closes RFD-400 review finding #2. +#[tokio::test] +async fn send_and_wait_drop_clears_waiter() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + // Start a send_and_wait, let it install the waiter, then abort the + // task before any idle/error event arrives. + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("aborted").with_wait_timeout(Duration::from_secs(60)), + ) + .await + } + }); + + // Drain the session.send RPC so we know the waiter is installed. + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + + // Now abort the in-flight send_and_wait. The WaiterGuard drops as + // the future unwinds, clearing the slot. + handle.abort(); + let _ = handle.await; + + // Give the runtime a moment to run the drop. + tokio::task::yield_now().await; + + // Next `send` must succeed โ€” no SendWhileWaiting. + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send("after-abort").await } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + assert_eq!(request["params"]["prompt"], "after-abort"); + server + .respond( + &request, + serde_json::json!({ "messageId": "msg-after-abort" }), + ) + .await; + + let result = timeout(TIMEOUT, send_handle).await.unwrap().unwrap(); + assert_eq!(result.unwrap(), "msg-after-abort"); +} + +/// Cancel-safety regression: `Session::stop_event_loop` must NOT abort +/// the event-loop task mid-handler. An in-flight handler (here a slow +/// `userInput.request` callback) must run to completion before the loop +/// exits โ€” the CLI receives the response on the wire before the session +/// tears down. +/// +/// Closes RFD-400 review finding #3. +#[tokio::test] +async fn stop_event_loop_completes_in_flight_handler() { + struct SlowHandler; + #[async_trait] + impl SessionHandler for SlowHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::UserInput { .. } => { + // Sleep so stop_event_loop has a chance to fire while + // the handler is mid-flight. The loop must wait for + // this to return rather than abort it. + tokio::time::sleep(Duration::from_millis(150)).await; + HandlerResponse::UserInput(Some(UserInputResponse { + answer: "completed".to_string(), + was_freeform: false, + })) + } + _ => HandlerResponse::Ok, + } + } + } + + let (session, mut server) = create_session_pair(Arc::new(SlowHandler)).await; + let session = Arc::new(session); + + server + .send_request( + 900, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id, + "question": "slow", + "choices": null, + "allowFreeform": true, + }), + ) + .await; + + // Give the loop a moment to dispatch into the handler. + tokio::time::sleep(Duration::from_millis(20)).await; + + // Now request shutdown. The loop is parked in handle_request awaiting + // the slow handler. `notify_one()` buffers the signal until the loop + // re-enters its select, which can only happen after the handler + // returns and the response is sent on the wire. + let stop_handle = tokio::spawn({ + let session = session.clone(); + async move { session.stop_event_loop().await } + }); + + // Verify the handler's response lands on the wire BEFORE the loop + // exits โ€” i.e. stop_event_loop did not abort mid-handler. + let response = timeout(Duration::from_secs(2), server.read_response()) + .await + .unwrap(); + assert_eq!(response["id"], 900); + assert_eq!(response["result"]["answer"], "completed"); + + // stop_event_loop completes after the handler returns and the loop + // observes the buffered shutdown signal on its next select iteration. + timeout(Duration::from_secs(2), stop_handle) + .await + .unwrap() + .unwrap(); +} + +/// Cancel-safety regression: dropping a Session does NOT abort the event +/// loop mid-handler. The loop sees the buffered shutdown signal on its +/// next select iteration and exits cleanly. This is the Drop equivalent +/// of stop_event_loop_completes_in_flight_handler; closes RFD-400 review +/// finding #3 for the implicit-drop path that used to call +/// `JoinHandle::abort()`. +#[tokio::test] +async fn drop_session_does_not_abort_handler() { + use std::sync::atomic::{AtomicBool, Ordering}; + + let handler_completed = Arc::new(AtomicBool::new(false)); + + struct CompletionHandler { + completed: Arc, + } + #[async_trait] + impl SessionHandler for CompletionHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::UserInput { .. } => { + tokio::time::sleep(Duration::from_millis(100)).await; + self.completed.store(true, Ordering::SeqCst); + HandlerResponse::UserInput(Some(UserInputResponse { + answer: "done".to_string(), + was_freeform: false, + })) + } + _ => HandlerResponse::Ok, + } + } + } + + let (session, mut server) = create_session_pair(Arc::new(CompletionHandler { + completed: handler_completed.clone(), + })) + .await; + + server + .send_request( + 901, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id, + "question": "drop-test", + "choices": null, + "allowFreeform": true, + }), + ) + .await; + + tokio::time::sleep(Duration::from_millis(20)).await; + drop(session); + + let response = timeout(Duration::from_secs(2), server.read_response()) + .await + .unwrap(); + assert_eq!(response["id"], 901); + assert_eq!(response["result"]["answer"], "done"); + assert!( + handler_completed.load(Ordering::SeqCst), + "handler must run to completion despite Session being dropped" + ); +} + +/// `Session::cancellation_token()` returns a child token that fires when +/// the session shuts down. Lets external tasks bind their lifetime to the +/// session via `tokio::select!` without taking a strong reference to the +/// session itself. +#[tokio::test] +async fn cancellation_token_fires_on_session_drop() { + let handler = Arc::new(ApproveAllHandler); + let (session, _server) = create_session_pair(handler).await; + + let token = session.cancellation_token(); + assert!(!token.is_cancelled()); + + drop(session); + + // The session's Drop impl cancels the parent token, which propagates + // to all child tokens. + timeout(Duration::from_secs(2), token.cancelled()) + .await + .expect("child token must observe cancellation after session drop"); + assert!(token.is_cancelled()); +} + +/// Cancelling a child token returned by `cancellation_token()` does NOT +/// shut the session down โ€” child tokens isolate consumer-side cancel +/// logic from the session's own lifecycle. +#[tokio::test] +async fn cancellation_token_child_cancel_does_not_kill_session() { + let handler = Arc::new(ApproveAllHandler); + let (session, _server) = create_session_pair(handler).await; + + let child = session.cancellation_token(); + child.cancel(); + + // Session's own token (and event loop) are untouched. Issue a cheap + // RPC and confirm it still works. + let parent = session.cancellation_token(); + assert!(!parent.is_cancelled()); +} + +#[tokio::test] +async fn elicitation_requested_dispatches_to_handler_and_responds() { + use github_copilot_sdk::types::ElicitationResult; + + struct ElicitHandler; + #[async_trait] + impl SessionHandler for ElicitHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::ElicitationRequest { request, .. } => { + assert_eq!(request.message, "Enter your name"); + HandlerResponse::Elicitation(ElicitationResult { + action: "accept".to_string(), + content: Some(serde_json::json!({ "name": "Alice" })), + }) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(ElicitHandler)).await; + + // CLI broadcasts elicitation.requested as a session event notification + server + .send_event( + "elicitation.requested", + serde_json::json!({ + "requestId": "elicit-1", + "message": "Enter your name", + "requestedSchema": { + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + }, + "mode": "form", + }), + ) + .await; + + // The SDK should call session.ui.handlePendingElicitation RPC + let rpc_call = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(rpc_call["method"], "session.ui.handlePendingElicitation"); + assert_eq!(rpc_call["params"]["requestId"], "elicit-1"); + assert_eq!(rpc_call["params"]["result"]["action"], "accept"); + assert_eq!(rpc_call["params"]["result"]["content"]["name"], "Alice"); +} + +#[tokio::test] +async fn elicitation_requested_cancels_on_handler_error() { + struct FailHandler; + #[async_trait] + impl SessionHandler for FailHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + // Return Ok instead of Elicitation โ€” SDK should treat as cancel + HandlerEvent::ElicitationRequest { .. } => HandlerResponse::Ok, + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(FailHandler)).await; + server + .send_event( + "elicitation.requested", + serde_json::json!({ + "requestId": "elicit-2", + "message": "Pick something", + }), + ) + .await; + + let rpc_call = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(rpc_call["method"], "session.ui.handlePendingElicitation"); + assert_eq!(rpc_call["params"]["result"]["action"], "cancel"); +} + +#[tokio::test] +async fn external_tool_requested_dispatches_to_handler_and_responds() { + struct ExternalToolHandler; + #[async_trait] + impl SessionHandler for ExternalToolHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::ExternalTool { invocation } => { + assert_eq!(invocation.tool_name, "run_tests"); + assert_eq!(invocation.tool_call_id, "tc-ext-1"); + assert_eq!(invocation.arguments["suite"], "unit"); + HandlerResponse::ToolResult(ToolResult::Text("all tests passed".to_string())) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(ExternalToolHandler)).await; + + server + .send_event( + "external_tool.requested", + serde_json::json!({ + "requestId": "req-ext-1", + "sessionId": server.session_id, + "toolCallId": "tc-ext-1", + "toolName": "run_tests", + "arguments": { "suite": "unit" }, + }), + ) + .await; + + let rpc_call = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(rpc_call["method"], "session.tools.handlePendingToolCall"); + assert_eq!(rpc_call["params"]["requestId"], "req-ext-1"); + assert_eq!(rpc_call["params"]["result"], "all tests passed"); +} + +#[tokio::test] +async fn capabilities_captured_from_create_response() { + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "sessionId": "cap-session", + "capabilities": { + "ui": { "elicitation": true } + } + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + let caps = session.capabilities(); + assert_eq!(caps.ui.as_ref().unwrap().elicitation, Some(true)); +} + +#[tokio::test] +async fn capabilities_changed_event_updates_session() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + + // Initially no capabilities (create_session_pair doesn't send them) + assert!(session.capabilities().ui.is_none()); + + // CLI sends capabilities.changed event + server + .send_event( + "capabilities.changed", + serde_json::json!({ + "ui": { "elicitation": true } + }), + ) + .await; + + // Poll until the event loop processes the notification + let caps = timeout(TIMEOUT, async { + loop { + let caps = session.capabilities(); + if caps.ui.is_some() { + return caps; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("capabilities should update within timeout"); + + assert_eq!(caps.ui.as_ref().unwrap().elicitation, Some(true)); +} + +#[tokio::test] +async fn request_elicitation_sent_in_create_params() { + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.create"); + assert_eq!(request["params"]["requestElicitation"], true); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": "s-elicit" }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn elicitation_methods_fail_without_capability() { + let (session, _server) = create_session_pair(Arc::new(NoopHandler)).await; + + // Session created without capabilities โ€” elicitation should fail + let err = session + .ui() + .elicitation("test", serde_json::json!({})) + .await + .unwrap_err(); + assert!(matches!( + err, + github_copilot_sdk::Error::Session( + github_copilot_sdk::SessionError::ElicitationNotSupported + ) + )); + + let err = session.ui().confirm("ok?").await.unwrap_err(); + assert!(matches!( + err, + github_copilot_sdk::Error::Session( + github_copilot_sdk::SessionError::ElicitationNotSupported + ) + )); +} + +async fn create_session_pair_with_hooks( + handler: Arc, + hooks: Arc, +) -> (github_copilot_sdk::session::Session, FakeServer) { + let (client, server_read, server_write) = make_client(); + let session_id = format!("test-session-{}", rand_id()); + + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: session_id.clone(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + let handler = handler.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_handler(handler) + .with_hooks(hooks), + ) + .await + .unwrap() + } + }); + + let create_req = server.read_request().await; + assert_eq!(create_req["method"], "session.create"); + // Verify hooks: true is auto-set in the config + assert_eq!(create_req["params"]["hooks"], true); + server + .respond( + &create_req, + serde_json::json!({ + "sessionId": session_id, + "workspacePath": "/tmp/workspace" + }), + ) + .await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + (session, server) +} + +#[tokio::test] +async fn hooks_invoke_dispatches_to_session_hooks() { + use github_copilot_sdk::hooks::{HookEvent, HookOutput, PreToolUseOutput, SessionHooks}; + + struct PolicyHooks; + #[async_trait] + impl SessionHooks for PolicyHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::PreToolUse { input, .. } => { + if input.tool_name == "rm" { + HookOutput::PreToolUse(PreToolUseOutput { + permission_decision: Some("deny".to_string()), + permission_decision_reason: Some("destructive".to_string()), + ..Default::default() + }) + } else { + HookOutput::None + } + } + _ => HookOutput::None, + } + } + } + + let (_session, mut server) = + create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(PolicyHooks)).await; + + // Send a hooks.invoke request for a denied tool + server + .send_request( + 300, + "hooks.invoke", + serde_json::json!({ + "sessionId": server.session_id, + "hookType": "preToolUse", + "input": { + "timestamp": 1234567890, + "cwd": "/tmp", + "toolName": "rm", + "toolArgs": { "path": "/" } + } + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 300); + assert_eq!(response["result"]["output"]["permissionDecision"], "deny"); + assert_eq!( + response["result"]["output"]["permissionDecisionReason"], + "destructive" + ); +} + +#[tokio::test] +async fn hooks_invoke_returns_empty_for_unregistered_hook() { + use github_copilot_sdk::hooks::SessionHooks; + + struct EmptyHooks; + #[async_trait] + impl SessionHooks for EmptyHooks {} + + let (_session, mut server) = + create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(EmptyHooks)).await; + + server + .send_request( + 301, + "hooks.invoke", + serde_json::json!({ + "sessionId": server.session_id, + "hookType": "sessionEnd", + "input": { + "timestamp": 1234567890, + "cwd": "/tmp", + "reason": "complete" + } + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 301); + assert_eq!(response["result"]["output"], serde_json::json!({})); +} + +async fn create_session_pair_with_transforms( + handler: Arc, + transforms: Arc, +) -> (github_copilot_sdk::session::Session, FakeServer) { + let (client, server_read, server_write) = make_client(); + let session_id = format!("test-session-{}", rand_id()); + + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: session_id.clone(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + let handler = handler.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_handler(handler) + .with_transform(transforms), + ) + .await + .unwrap() + } + }); + + let create_req = server.read_request().await; + assert_eq!(create_req["method"], "session.create"); + // Verify transforms inject customize mode and section overrides + assert_eq!(create_req["params"]["systemMessage"]["mode"], "customize"); + server + .respond( + &create_req, + serde_json::json!({ + "sessionId": session_id, + "workspacePath": "/tmp/workspace" + }), + ) + .await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + (session, server) +} + +#[tokio::test] +async fn system_message_transform_dispatches_to_transform() { + use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext}; + + struct AppendTransform; + #[async_trait] + impl SystemMessageTransform for AppendTransform { + fn section_ids(&self) -> Vec { + vec!["instructions".to_string()] + } + + async fn transform_section( + &self, + _section_id: &str, + content: &str, + _ctx: TransformContext, + ) -> Option { + Some(format!("{content}\nAlways be concise.")) + } + } + + let (_session, mut server) = + create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(AppendTransform)).await; + + server + .send_request( + 400, + "systemMessage.transform", + serde_json::json!({ + "sessionId": server.session_id, + "sections": { + "instructions": { "content": "You are helpful." } + } + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 400); + assert_eq!( + response["result"]["sections"]["instructions"]["content"], + "You are helpful.\nAlways be concise." + ); +} + +#[tokio::test] +async fn system_message_transform_returns_error_for_missing_sections() { + use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext}; + + struct DummyTransform; + #[async_trait] + impl SystemMessageTransform for DummyTransform { + fn section_ids(&self) -> Vec { + vec!["instructions".to_string()] + } + + async fn transform_section( + &self, + _section_id: &str, + _content: &str, + _ctx: TransformContext, + ) -> Option { + None + } + } + + let (_session, mut server) = + create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(DummyTransform)).await; + + // Send request with no sections parameter + server + .send_request( + 401, + "systemMessage.transform", + serde_json::json!({ + "sessionId": server.session_id, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 401); + assert_eq!(response["error"]["code"], -32602); +} + +#[tokio::test] +async fn list_workspace_files_uses_plural_method_name() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = tokio::spawn(async move { s.list_workspace_files().await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.workspaces.listFiles"); + assert_eq!(request["params"]["sessionId"], server.session_id); + server + .respond( + &request, + serde_json::json!({ "files": ["a.txt", "subdir/b.txt"] }), + ) + .await; + + let files = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert_eq!(files, vec!["a.txt".to_string(), "subdir/b.txt".to_string()]); +} + +#[tokio::test] +async fn read_workspace_file_uses_plural_method_name_and_forwards_path() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = + tokio::spawn(async move { s.read_workspace_file(Path::new("notes/plan.md")).await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.workspaces.readFile"); + assert_eq!(request["params"]["sessionId"], server.session_id); + assert_eq!(request["params"]["path"], "notes/plan.md"); + server + .respond(&request, serde_json::json!({ "content": "hello" })) + .await; + + let content = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert_eq!(content, "hello"); +} + +#[tokio::test] +async fn create_workspace_file_uses_plural_method_name_and_forwards_payload() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = tokio::spawn(async move { + s.create_workspace_file(Path::new("notes/plan.md"), "body") + .await + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.workspaces.createFile"); + assert_eq!(request["params"]["sessionId"], server.session_id); + assert_eq!(request["params"]["path"], "notes/plan.md"); + assert_eq!(request["params"]["content"], "body"); + server.respond(&request, serde_json::json!({})).await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn rpc_namespace_session_agent_list_dispatches_correctly() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = tokio::spawn(async move { s.rpc().agent().list().await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.agent.list"); + assert_eq!(request["params"]["sessionId"], server.session_id); + server + .respond(&request, serde_json::json!({ "agents": [] })) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert!(result.agents.is_empty()); +} + +#[tokio::test] +async fn rpc_namespace_session_tasks_list_dispatches_correctly() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = tokio::spawn(async move { s.rpc().tasks().list().await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.tasks.list"); + assert_eq!(request["params"]["sessionId"], server.session_id); + server + .respond(&request, serde_json::json!({ "tasks": [] })) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert!(result.tasks.is_empty()); +} + +#[tokio::test] +async fn rpc_namespace_client_models_list_dispatches_correctly() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let client = session.client().clone(); + let handle = tokio::spawn(async move { client.rpc().models().list().await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "models.list"); + server + .respond(&request, serde_json::json!({ "models": [] })) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert!(result.models.is_empty()); +} + +#[tokio::test] +async fn client_stop_sends_session_destroy_for_each_active_session() { + // One client, two registered sessions. Client::stop must send + // session.destroy for each before returning Ok. + let (client, server_read, server_write) = make_client(); + let session_id_a = format!("test-session-{}", rand_id()); + let session_id_b = format!("test-session-{}", rand_id()); + + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: session_id_a.clone(), + }; + + // Spawn both create_session calls. + let create_a = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_a_req = server.read_request().await; + assert_eq!(create_a_req["method"], "session.create"); + server + .respond( + &create_a_req, + serde_json::json!({ "sessionId": session_id_a, "workspacePath": "/tmp/ws-a" }), + ) + .await; + let _session_a = timeout(TIMEOUT, create_a).await.unwrap(); + + let create_b = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_b_req = server.read_request().await; + assert_eq!(create_b_req["method"], "session.create"); + server + .respond( + &create_b_req, + serde_json::json!({ "sessionId": session_id_b, "workspacePath": "/tmp/ws-b" }), + ) + .await; + let _session_b = timeout(TIMEOUT, create_b).await.unwrap(); + + // Drive Client::stop and respond to each destroy in turn. + let stop_handle = tokio::spawn({ + let client = client.clone(); + async move { client.stop().await } + }); + + let mut destroyed = Vec::new(); + for _ in 0..2 { + let req = server.read_request().await; + assert_eq!(req["method"], "session.destroy"); + destroyed.push(req["params"]["sessionId"].as_str().unwrap().to_string()); + server.respond(&req, serde_json::json!(null)).await; + } + destroyed.sort(); + let mut expected = [session_id_a.clone(), session_id_b.clone()]; + expected.sort(); + assert_eq!(destroyed, expected); + + let stop_result = timeout(TIMEOUT, stop_handle).await.unwrap().unwrap(); + assert!(stop_result.is_ok(), "stop returned errors: {stop_result:?}"); +} + +#[tokio::test] +async fn client_stop_aggregates_session_destroy_errors() { + // session.destroy fails on the wire โ€” Client::stop returns + // StopErrors carrying the failure rather than short-circuiting. + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let client = session.client().clone(); + + let stop_handle = tokio::spawn(async move { client.stop().await }); + + let req = server.read_request().await; + assert_eq!(req["method"], "session.destroy"); + let id = req["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "error": { "code": -32000, "message": "session gone" }, + }); + write_framed(&mut server.write, &serde_json::to_vec(&response).unwrap()).await; + + let stop_result = timeout(TIMEOUT, stop_handle).await.unwrap().unwrap(); + let errors = stop_result.expect_err("expected aggregated errors"); + assert_eq!(errors.errors().len(), 1); + let msg = errors.to_string(); + assert!(msg.contains("session gone"), "unexpected message: {msg}"); +} + +#[test] +fn session_config_serializes_bucket_b_fields() { + use std::path::PathBuf; + + use github_copilot_sdk::{SessionConfig, SessionId}; + + let cfg = { + let mut cfg = SessionConfig::default(); + cfg.session_id = Some(SessionId::from("custom-id")); + cfg.config_dir = Some(PathBuf::from("/tmp/cfg")); + cfg.working_directory = Some(PathBuf::from("/tmp/work")); + cfg.github_token = Some("ghs_secret".to_string()); + cfg.include_sub_agent_streaming_events = Some(false); + cfg + }; + let json = serde_json::to_value(&cfg).unwrap(); + assert_eq!(json["sessionId"], "custom-id"); + assert_eq!(json["configDir"], "/tmp/cfg"); + assert_eq!(json["workingDirectory"], "/tmp/work"); + assert_eq!(json["gitHubToken"], "ghs_secret"); + assert_eq!(json["includeSubAgentStreamingEvents"], false); + + // Debug never leaks the token. + let debug = format!("{cfg:?}"); + assert!(!debug.contains("ghs_secret"), "leaked token: {debug}"); + assert!(debug.contains(""), "missing redaction: {debug}"); + + // Unset fields are omitted on the wire. + let empty = serde_json::to_value(SessionConfig::default()).unwrap(); + assert!(empty.get("sessionId").is_none()); + assert!(empty.get("gitHubToken").is_none()); +} + +#[test] +fn resume_session_config_serializes_bucket_b_fields() { + use std::path::PathBuf; + + use github_copilot_sdk::{ResumeSessionConfig, SessionId}; + + let mut cfg = ResumeSessionConfig::new(SessionId::from("sess-1")); + cfg.working_directory = Some(PathBuf::from("/tmp/work")); + cfg.config_dir = Some(PathBuf::from("/tmp/cfg")); + cfg.github_token = Some("ghs_secret".to_string()); + cfg.include_sub_agent_streaming_events = Some(true); + let json = serde_json::to_value(&cfg).unwrap(); + assert_eq!(json["sessionId"], "sess-1"); + assert_eq!(json["workingDirectory"], "/tmp/work"); + assert_eq!(json["configDir"], "/tmp/cfg"); + assert_eq!(json["gitHubToken"], "ghs_secret"); + assert_eq!(json["includeSubAgentStreamingEvents"], true); + + let debug = format!("{cfg:?}"); + assert!(!debug.contains("ghs_secret"), "leaked token: {debug}"); +} + +// ===================================================================== +// Slash commands (ยง 4.1) +// ===================================================================== + +struct CountingCommandHandler { + last_ctx: Arc>>, + error_to_return: Option, +} + +#[async_trait] +impl CommandHandler for CountingCommandHandler { + async fn on_command(&self, ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> { + *self.last_ctx.lock() = Some(ctx); + if let Some(message) = &self.error_to_return { + Err(github_copilot_sdk::Error::Session( + github_copilot_sdk::SessionError::AgentError(message.clone()), + )) + } else { + Ok(()) + } + } +} + +async fn create_session_pair_with_commands( + handler: Arc, + commands: Vec, +) -> (github_copilot_sdk::session::Session, FakeServer, Value) { + let (client, server_read, server_write) = make_client(); + let session_id = format!("test-session-{}", rand_id()); + + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: session_id.clone(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + let handler = handler.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_handler(handler) + .with_commands(commands), + ) + .await + .unwrap() + } + }); + + let create_req = server.read_request().await; + assert_eq!(create_req["method"], "session.create"); + server + .respond( + &create_req, + serde_json::json!({ + "sessionId": session_id, + "workspacePath": "/tmp/workspace" + }), + ) + .await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + (session, server, create_req) +} + +#[tokio::test] +async fn create_serializes_commands_strips_handler() { + let last_ctx = Arc::new(parking_lot::Mutex::new(None)); + let commands = vec![ + CommandDefinition::new( + "deploy", + Arc::new(CountingCommandHandler { + last_ctx: last_ctx.clone(), + error_to_return: None, + }), + ) + .with_description("Deploy to production"), + CommandDefinition::new( + "rollback", + Arc::new(CountingCommandHandler { + last_ctx: last_ctx.clone(), + error_to_return: None, + }), + ), + ]; + + let (_session, _server, create_req) = + create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + + let wire = create_req["params"]["commands"] + .as_array() + .expect("commands should be an array"); + assert_eq!(wire.len(), 2); + + let deploy = &wire[0]; + assert_eq!(deploy["name"], "deploy"); + assert_eq!(deploy["description"], "Deploy to production"); + assert!( + deploy.get("handler").is_none(), + "wire payload must not include handler, got: {deploy}" + ); + let deploy_keys: Vec<&String> = deploy.as_object().unwrap().keys().collect(); + assert_eq!(deploy_keys.len(), 2, "got keys: {deploy_keys:?}"); + + let rollback = &wire[1]; + assert_eq!(rollback["name"], "rollback"); + assert!( + rollback.get("description").is_none(), + "description should be omitted when None, got: {rollback}" + ); + assert!(rollback.get("handler").is_none()); + let rollback_keys: Vec<&String> = rollback.as_object().unwrap().keys().collect(); + assert_eq!(rollback_keys.len(), 1, "got keys: {rollback_keys:?}"); +} + +#[tokio::test] +async fn command_execute_dispatches_to_registered_handler_and_acks_success() { + let last_ctx = Arc::new(parking_lot::Mutex::new(None)); + let commands = vec![CommandDefinition::new( + "deploy", + Arc::new(CountingCommandHandler { + last_ctx: last_ctx.clone(), + error_to_return: None, + }), + )]; + + let (session, mut server, _) = + create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + + server + .send_event( + "command.execute", + serde_json::json!({ + "requestId": "req-deploy-1", + "command": "/deploy production", + "commandName": "deploy", + "args": "production", + }), + ) + .await; + + let ack = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!( + ack["method"], "session.commands.handlePendingCommand", + "expected handlePendingCommand RPC, got: {ack}" + ); + assert_eq!( + ack["params"]["sessionId"].as_str(), + Some(session.id().as_ref()) + ); + assert_eq!(ack["params"]["requestId"], "req-deploy-1"); + assert!( + ack["params"].get("error").is_none(), + "success ack should omit error, got: {ack}" + ); + + server + .respond(&ack, serde_json::json!({ "success": true })) + .await; + + let ctx = last_ctx + .lock() + .clone() + .expect("handler should have been invoked"); + assert_eq!(ctx.command, "/deploy production"); + assert_eq!(ctx.command_name, "deploy"); + assert_eq!(ctx.args, "production"); + assert_eq!(ctx.session_id.as_ref(), session.id().as_ref()); +} + +#[tokio::test] +async fn command_execute_unknown_command_acks_with_error() { + let (session, mut server, _) = + create_session_pair_with_commands(Arc::new(NoopHandler), vec![]).await; + + server + .send_event( + "command.execute", + serde_json::json!({ + "requestId": "req-unknown-1", + "command": "/missing", + "commandName": "missing", + "args": "", + }), + ) + .await; + + let ack = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(ack["method"], "session.commands.handlePendingCommand"); + assert_eq!(ack["params"]["requestId"], "req-unknown-1"); + assert_eq!( + ack["params"]["error"], "Unknown command: missing", + "got: {ack}" + ); + server + .respond(&ack, serde_json::json!({ "success": false })) + .await; + drop(session); +} + +#[tokio::test] +async fn command_execute_handler_error_propagates_to_ack() { + let last_ctx = Arc::new(parking_lot::Mutex::new(None)); + let commands = vec![CommandDefinition::new( + "fail", + Arc::new(CountingCommandHandler { + last_ctx: last_ctx.clone(), + error_to_return: Some("deploy failed: dry-run rejected".to_string()), + }), + )]; + + let (_session, mut server, _) = + create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + + server + .send_event( + "command.execute", + serde_json::json!({ + "requestId": "req-fail-1", + "command": "/fail", + "commandName": "fail", + "args": "", + }), + ) + .await; + + let ack = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(ack["method"], "session.commands.handlePendingCommand"); + assert_eq!(ack["params"]["requestId"], "req-fail-1"); + let error_msg = ack["params"]["error"] + .as_str() + .expect("ack should include error"); + assert!( + error_msg.contains("deploy failed: dry-run rejected"), + "expected handler error in ack, got: {error_msg}" + ); + server + .respond(&ack, serde_json::json!({ "success": false })) + .await; +} + +// SessionFsProvider tests -------------------------------------------------- + +use github_copilot_sdk::session_fs::{ + DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConventions, SessionFsProvider, +}; + +struct RecordingFsProvider { + files: parking_lot::Mutex>, +} + +impl RecordingFsProvider { + fn new() -> Self { + Self { + files: parking_lot::Mutex::new(std::collections::HashMap::new()), + } + } + + fn with_file(self, path: &str, content: &str) -> Self { + self.files + .lock() + .insert(path.to_string(), content.to_string()); + self + } +} + +#[async_trait] +impl SessionFsProvider for RecordingFsProvider { + async fn read_file(&self, path: &str) -> Result { + self.files + .lock() + .get(path) + .cloned() + .ok_or_else(|| FsError::NotFound(path.to_string())) + } + + async fn write_file( + &self, + path: &str, + content: &str, + _mode: Option, + ) -> Result<(), FsError> { + self.files + .lock() + .insert(path.to_string(), content.to_string()); + Ok(()) + } + + async fn stat(&self, path: &str) -> Result { + let files = self.files.lock(); + let content = files + .get(path) + .ok_or_else(|| FsError::NotFound(path.to_string()))?; + Ok(FileInfo::new( + true, + false, + content.len() as i64, + "2025-01-01T00:00:00Z", + "2025-01-01T00:00:00Z", + )) + } + + async fn readdir_with_types(&self, _path: &str) -> Result, FsError> { + Ok(vec![ + DirEntry::new("README.md", DirEntryKind::File), + DirEntry::new("src", DirEntryKind::Directory), + ]) + } + + async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> { + let mut files = self.files.lock(); + if files.remove(path).is_none() && !force { + return Err(FsError::NotFound(path.to_string())); + } + Ok(()) + } +} + +async fn create_session_pair_with_fs_provider( + handler: Arc, + provider: Arc, +) -> (github_copilot_sdk::session::Session, FakeServer) { + let (client, server_read, server_write) = make_client(); + let session_id = format!("test-session-{}", rand_id()); + + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: session_id.clone(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + let handler = handler.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_handler(handler) + .with_session_fs_provider(provider), + ) + .await + .unwrap() + } + }); + + let create_req = server.read_request().await; + assert_eq!(create_req["method"], "session.create"); + server + .respond( + &create_req, + serde_json::json!({ + "sessionId": session_id, + "workspacePath": "/tmp/workspace" + }), + ) + .await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + (session, server) +} + +#[tokio::test] +async fn session_fs_dispatches_read_file_to_provider() { + let provider = Arc::new(RecordingFsProvider::new().with_file("/foo.txt", "hello world")); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + + server + .send_request( + 42, + "sessionFs.readFile", + serde_json::json!({ "sessionId": server.session_id, "path": "/foo.txt" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 42); + assert_eq!(response["result"]["content"], "hello world"); + assert!(response["result"].get("error").is_none() || response["result"]["error"].is_null()); +} + +#[tokio::test] +async fn session_fs_maps_not_found_to_enoent() { + let provider = Arc::new(RecordingFsProvider::new()); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + + server + .send_request( + 7, + "sessionFs.readFile", + serde_json::json!({ "sessionId": server.session_id, "path": "/missing.txt" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 7); + let error = &response["result"]["error"]; + assert_eq!(error["code"], "ENOENT"); + assert!(error["message"].as_str().unwrap().contains("missing.txt")); +} + +#[tokio::test] +async fn session_fs_maps_other_to_unknown() { + struct AlwaysFails; + #[async_trait] + impl SessionFsProvider for AlwaysFails { + async fn stat(&self, _path: &str) -> Result { + Err(FsError::Other("backing store unavailable".to_string())) + } + } + + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), Arc::new(AlwaysFails)).await; + + server + .send_request( + 8, + "sessionFs.stat", + serde_json::json!({ "sessionId": server.session_id, "path": "/x" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + let error = &response["result"]["error"]; + assert_eq!(error["code"], "UNKNOWN"); + assert!( + error["message"] + .as_str() + .unwrap() + .contains("backing store unavailable") + ); +} + +#[tokio::test] +async fn session_fs_dispatches_write_file_with_mode() { + let provider = Arc::new(RecordingFsProvider::new()); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider.clone()).await; + + server + .send_request( + 10, + "sessionFs.writeFile", + serde_json::json!({ "sessionId": server.session_id, "path": "/out.txt", "content": "abc", "mode": 420 }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 10); + assert!(response["result"].get("error").is_none() || response["result"]["error"].is_null()); + assert_eq!(provider.files.lock().get("/out.txt").unwrap(), "abc"); +} + +#[tokio::test] +async fn session_fs_dispatches_readdir_with_types() { + let provider = Arc::new(RecordingFsProvider::new()); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + + server + .send_request( + 11, + "sessionFs.readdirWithTypes", + serde_json::json!({ "sessionId": server.session_id, "path": "/dir" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + let entries = response["result"]["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 2); + assert_eq!(entries[0]["name"], "README.md"); + assert_eq!(entries[0]["type"], "file"); + assert_eq!(entries[1]["name"], "src"); + assert_eq!(entries[1]["type"], "directory"); +} + +#[tokio::test] +async fn session_fs_dispatches_rm_with_force() { + let provider = Arc::new(RecordingFsProvider::new()); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + + server + .send_request( + 12, + "sessionFs.rm", + serde_json::json!({ "sessionId": server.session_id, "path": "/missing", "force": true, "recursive": false }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 12); + assert!(response["result"].get("error").is_none() || response["result"]["error"].is_null()); +} + +#[tokio::test] +async fn validate_session_fs_config_rejects_empty_initial_cwd() { + let cfg = github_copilot_sdk::session_fs::SessionFsConfig::new( + "", + "/state", + SessionFsConventions::Posix, + ); + let opts = { + let mut opts = github_copilot_sdk::ClientOptions::default(); + opts.session_fs = Some(cfg); + opts + }; + let err = github_copilot_sdk::Client::start(opts).await.err(); + let err_string = format!("{err:?}"); + assert!( + err_string.contains("initial_cwd") || err_string.contains("InvalidSessionFsConfig"), + "got: {err_string}" + ); +} + +#[tokio::test] +async fn create_session_errors_when_provider_required_but_missing() { + // Without a CLI we can't exercise the configured-but-missing-provider path + // through Client::start; the unit-level behavior is covered by the + // SessionError::SessionFsProviderRequired variant being constructible. + // This test asserts the error type's display formatting is stable. + let err = github_copilot_sdk::SessionError::SessionFsProviderRequired; + assert!(format!("{err}").contains("session_fs")); +} + +// ---------- 4.3 trace context tests ---------- + +struct StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext, + calls: Arc, +} + +#[async_trait] +impl github_copilot_sdk::types::TraceContextProvider for StaticTraceProvider { + async fn get_trace_context(&self) -> github_copilot_sdk::types::TraceContext { + self.calls.fetch_add(1, Ordering::Relaxed); + self.ctx.clone() + } +} + +fn make_client_with_trace_provider( + provider: Arc, +) -> (Client, tokio::io::DuplexStream, tokio::io::DuplexStream) { + let (client_write, server_read) = duplex(8192); + let (server_write, client_read) = duplex(8192); + let client = Client::from_streams_with_trace_provider( + client_read, + client_write, + std::env::temp_dir(), + provider, + ) + .unwrap(); + (client, server_read, server_write) +} + +#[tokio::test] +async fn on_get_trace_context_called_on_session_create() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-aaaa-bbbb-01") + .with_tracestate("vendor=value"), + calls: calls.clone(), + }); + let (client, server_read, server_write) = make_client_with_trace_provider(provider); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "trace-create".to_string(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + + let req = server.read_request().await; + assert_eq!(req["method"], "session.create"); + assert_eq!(req["params"]["traceparent"], "00-aaaa-bbbb-01"); + assert_eq!(req["params"]["tracestate"], "vendor=value"); + server + .respond( + &req, + serde_json::json!({"sessionId": "trace-create", "workspacePath": "/tmp/ws"}), + ) + .await; + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + assert_eq!(calls.load(Ordering::Relaxed), 1); +} + +#[tokio::test] +async fn on_get_trace_context_called_on_session_resume() { + use github_copilot_sdk::types::ResumeSessionConfig; + let calls = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-resume-trace-01"), + calls: calls.clone(), + }); + let (client, server_read, server_write) = make_client_with_trace_provider(provider); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "trace-resume".to_string(), + }; + + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + let cfg = ResumeSessionConfig::new(SessionId::from("trace-resume")) + .with_handler(Arc::new(NoopHandler)); + client.resume_session(cfg).await.unwrap() + } + }); + + // resume sends `session.resume` then `session.skills.reload`. + let req = server.read_request().await; + assert_eq!(req["method"], "session.resume"); + assert_eq!(req["params"]["traceparent"], "00-resume-trace-01"); + assert!( + req["params"].get("tracestate").is_none(), + "tracestate should be omitted when None" + ); + server + .respond( + &req, + serde_json::json!({"sessionId": "trace-resume", "workspacePath": "/tmp/ws"}), + ) + .await; + let reload_req = server.read_request().await; + assert_eq!(reload_req["method"], "session.skills.reload"); + server.respond(&reload_req, serde_json::json!({})).await; + + timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); + assert_eq!(calls.load(Ordering::Relaxed), 1); +} + +#[tokio::test] +async fn on_get_trace_context_called_on_session_send() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-send-trace-01"), + calls: calls.clone(), + }); + let (client, server_read, server_write) = make_client_with_trace_provider(provider); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "trace-send".to_string(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_req = server.read_request().await; + server + .respond( + &create_req, + serde_json::json!({"sessionId": "trace-send", "workspacePath": "/tmp/ws"}), + ) + .await; + let session = Arc::new(timeout(TIMEOUT, create_handle).await.unwrap().unwrap()); + + // Provider was called once for create; reset by reading the count baseline. + let baseline = calls.load(Ordering::Relaxed); + assert_eq!(baseline, 1, "create_session should call the provider once"); + + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send(MessageOptions::new("hi")).await } + }); + let send_req = server.read_request().await; + assert_eq!(send_req["method"], "session.send"); + assert_eq!(send_req["params"]["traceparent"], "00-send-trace-01"); + server.respond(&send_req, serde_json::json!({})).await; + timeout(TIMEOUT, send_handle) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(calls.load(Ordering::Relaxed), baseline + 1); +} + +#[tokio::test] +async fn message_options_trace_context_overrides_callback() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-callback-01"), + calls: calls.clone(), + }); + let (client, server_read, server_write) = make_client_with_trace_provider(provider); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "trace-override".to_string(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_req = server.read_request().await; + server + .respond( + &create_req, + serde_json::json!({"sessionId": "trace-override", "workspacePath": "/tmp/ws"}), + ) + .await; + let session = Arc::new(timeout(TIMEOUT, create_handle).await.unwrap().unwrap()); + + let baseline = calls.load(Ordering::Relaxed); + + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send( + MessageOptions::new("hi") + .with_traceparent("00-override-01") + .with_tracestate("vendor=override"), + ) + .await + } + }); + let send_req = server.read_request().await; + assert_eq!(send_req["params"]["traceparent"], "00-override-01"); + assert_eq!(send_req["params"]["tracestate"], "vendor=override"); + server.respond(&send_req, serde_json::json!({})).await; + timeout(TIMEOUT, send_handle) + .await + .unwrap() + .unwrap() + .unwrap(); + + // Callback must NOT have been invoked when MessageOptions carried an override. + assert_eq!( + calls.load(Ordering::Relaxed), + baseline, + "callback should be skipped when MessageOptions carries trace headers" + ); +} + +#[tokio::test] +async fn message_options_trace_context_used_without_callback() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send(MessageOptions::new("hi").with_traceparent("00-direct-01")) + .await + } + }); + let req = server.read_request().await; + assert_eq!(req["method"], "session.send"); + assert_eq!(req["params"]["traceparent"], "00-direct-01"); + assert!( + req["params"].get("tracestate").is_none(), + "tracestate should be omitted when only traceparent is set" + ); + server.respond(&req, serde_json::json!({})).await; + timeout(TIMEOUT, send_handle) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tool_invocation_carries_trace_context_from_event() { + use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, SessionHandler}; + + struct CapturingHandler { + captured: parking_lot::Mutex, Option)>>, + signal: tokio::sync::Notify, + } + + #[async_trait] + impl SessionHandler for CapturingHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + if let HandlerEvent::ExternalTool { invocation } = event { + *self.captured.lock() = Some(( + invocation.traceparent.clone(), + invocation.tracestate.clone(), + )); + self.signal.notify_one(); + return HandlerResponse::ToolResult(ToolResult::Text("ok".into())); + } + HandlerResponse::Ok + } + } + + let handler = Arc::new(CapturingHandler { + captured: parking_lot::Mutex::new(None), + signal: tokio::sync::Notify::new(), + }); + let (_session, mut server) = create_session_pair(handler.clone()).await; + + server + .send_event( + "external_tool.requested", + serde_json::json!({ + "requestId": "req-1", + "sessionId": server.session_id, + "toolCallId": "tc-1", + "toolName": "calc", + "arguments": {"x": 1}, + "traceparent": "00-tool-01", + "tracestate": "vendor=tool", + }), + ) + .await; + + // Drain the handlePendingToolCall RPC the dispatcher sends after the handler runs. + let pending = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(pending["method"], "session.tools.handlePendingToolCall"); + + timeout(TIMEOUT, handler.signal.notified()).await.unwrap(); + let captured = handler.captured.lock().clone(); + assert_eq!( + captured, + Some((Some("00-tool-01".into()), Some("vendor=tool".into()))), + ); +} + +#[tokio::test] +async fn wire_omits_trace_fields_when_unset() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send(MessageOptions::new("hi")).await } + }); + let req = server.read_request().await; + assert!(req["params"].get("traceparent").is_none()); + assert!(req["params"].get("tracestate").is_none()); + server.respond(&req, serde_json::json!({})).await; + timeout(TIMEOUT, send_handle) + .await + .unwrap() + .unwrap() + .unwrap(); +} diff --git a/scripts/codegen/package.json b/scripts/codegen/package.json index a2df5dded..c42713d84 100644 --- a/scripts/codegen/package.json +++ b/scripts/codegen/package.json @@ -3,11 +3,12 @@ "private": true, "type": "module", "scripts": { - "generate": "tsx typescript.ts && tsx csharp.ts && tsx python.ts && tsx go.ts", + "generate": "tsx typescript.ts && tsx csharp.ts && tsx python.ts && tsx go.ts && tsx rust.ts", "generate:ts": "tsx typescript.ts", "generate:csharp": "tsx csharp.ts", "generate:python": "tsx python.ts", - "generate:go": "tsx go.ts" + "generate:go": "tsx go.ts", + "generate:rust": "tsx rust.ts" }, "dependencies": { "json-schema": "^0.4.0", diff --git a/scripts/codegen/rust.ts b/scripts/codegen/rust.ts new file mode 100644 index 000000000..c9ed49aca --- /dev/null +++ b/scripts/codegen/rust.ts @@ -0,0 +1,1406 @@ +/** + * Rust code generator for the Copilot protocol JSON Schemas. + * + * Reads api.schema.json and session-events.schema.json, emits idiomatic Rust + * types to rust/src/generated/. + * + * Usage: npx tsx scripts/codegen/rust.ts + */ + +import { execFile } from "child_process"; +import fs from "fs/promises"; +import path from "path"; +import { promisify } from "util"; +import type { JSONSchema7, JSONSchema7Definition } from "json-schema"; +import { + type ApiSchema, + type DefinitionCollections, + EXCLUDED_EVENT_TYPES, + REPO_ROOT, + type RpcMethod, + collectDefinitionCollections, + collectDefinitions, + getApiSchemaPath, + getRpcSchemaTypeName, + getSessionEventsSchemaPath, + isObjectSchema, + isRpcMethod, + isSchemaDeprecated, + isVoidSchema, + postProcessSchema, + refTypeName, + resolveObjectSchema, + resolveRef, + resolveSchema, + stripBooleanLiterals, +} from "./utils.js"; + +const execFileAsync = promisify(execFile); + +const GENERATED_DIR = path.join(REPO_ROOT, "rust/src/generated"); + +/** + * JSON property names that should be emitted as a hand-authored newtype rather + * than `String`. The newtype is `#[serde(transparent)]`, so the wire format is + * unchanged. Add new entries sparingly โ€” these only fire when a schema field + * has type `string` and an exact-match name in this map. + */ +const STRING_NEWTYPE_OVERRIDES: Record = { + sessionId: "SessionId", + remoteSessionId: "SessionId", + requestId: "RequestId", +}; + +// โ”€โ”€ Naming helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +function toPascalCase(s: string): string { + return s + .split(/[._\-\s]+/) + .map((w) => w.charAt(0).toUpperCase() + w.slice(1)) + .join(""); +} + +function toSnakeCase(s: string): string { + return s + .replace(/([A-Z])/g, "_$1") + .replace(/^_/, "") + .replace(/[.\-\s]+/g, "_") + .toLowerCase() + .replace(/_+/g, "_"); +} + +/** Convert a JSON property name (camelCase) to a Rust field name (snake_case). */ +function toRustFieldName(jsonName: string): string { + return toSnakeCase(jsonName); +} + +/** Convert snake_case back to camelCase (matches serde's rename_all = "camelCase"). */ +function snakeToCamelCase(snake: string): string { + return snake.replace(/_([a-z0-9])/g, (_, c: string) => c.toUpperCase()); +} + +/** + * Rust reserved keywords that need raw identifier syntax (r#). + */ +const RUST_KEYWORDS = new Set([ + "as", + "async", + "await", + "break", + "const", + "continue", + "crate", + "dyn", + "else", + "enum", + "extern", + "false", + "fn", + "for", + "if", + "impl", + "in", + "let", + "loop", + "match", + "mod", + "move", + "mut", + "pub", + "ref", + "return", + "self", + "Self", + "static", + "struct", + "super", + "trait", + "true", + "type", + "unsafe", + "use", + "where", + "while", + "yield", +]); + +function safeRustFieldName(name: string): string { + const snake = toRustFieldName(name); + return RUST_KEYWORDS.has(snake) ? `r#${snake}` : snake; +} + +// โ”€โ”€ Codegen context โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +interface RustCodegenCtx { + /** Accumulated struct definitions. */ + structs: string[]; + /** Accumulated enum definitions. */ + enums: string[]; + /** Track generated type names to avoid duplicates. */ + generatedNames: Set; + /** Schema definitions for $ref resolution. */ + definitions?: DefinitionCollections; +} + +function stripOption(typeName: string): string { + return typeName.startsWith("Option<") && typeName.endsWith(">") + ? typeName.slice("Option<".length, -1) + : typeName; +} + +function getUnionVariants(schema: JSONSchema7): JSONSchema7[] | null { + if (schema.anyOf) return schema.anyOf as JSONSchema7[]; + if (schema.oneOf) return schema.oneOf as JSONSchema7[]; + return null; +} + +function tryEmitRustDiscriminatedUnion( + schema: JSONSchema7, + parentTypeName: string, + jsonPropName: string, + ctx: RustCodegenCtx, +): string | null { + const variants = getUnionVariants(schema); + if (!variants) return null; + + const nonNull = variants.filter((variant) => variant.type !== "null"); + if (nonNull.length <= 1) return null; + + const enumName = + (typeof schema.title === "string" && schema.title) || + parentTypeName + toPascalCase(jsonPropName); + + const resolvedVariants = nonNull.map((variant) => { + if (variant.$ref && typeof variant.$ref === "string") { + const resolved = resolveRef(variant.$ref, ctx.definitions); + return { + schema: (resolved ?? variant) as JSONSchema7, + typeName: toPascalCase(refTypeName(variant.$ref, ctx.definitions)), + }; + } + + const resolved = + resolveObjectSchema(variant, ctx.definitions) ?? + resolveSchema(variant, ctx.definitions) ?? + variant; + const kindConst = (resolved.properties?.kind as JSONSchema7 | undefined) + ?.const; + const typeName = + (typeof resolved.title === "string" && resolved.title) || + (typeof kindConst === "string" + ? `${enumName}${toPascalCase(kindConst)}` + : `${enumName}Variant`); + + return { + schema: resolved as JSONSchema7, + typeName, + }; + }); + + const isDiscriminated = resolvedVariants.every( + ({ schema: variantSchema }) => { + if (!isObjectSchema(variantSchema) || !variantSchema.properties) + return false; + const kind = variantSchema.properties.kind as JSONSchema7 | undefined; + return typeof kind?.const === "string"; + }, + ); + if (!isDiscriminated) return null; + + if (ctx.generatedNames.has(enumName)) { + return enumName; + } + ctx.generatedNames.add(enumName); + + for (const { schema: variantSchema, typeName } of resolvedVariants) { + if (isObjectSchema(variantSchema)) { + emitRustStruct(typeName, variantSchema, ctx); + } + } + + const lines: string[] = []; + if (schema.description) { + for (const line of schema.description.split(/\r?\n/)) { + lines.push(`/// ${line}`); + } + } + lines.push("#[derive(Debug, Clone, Serialize, Deserialize)]"); + lines.push("#[serde(untagged)]"); + lines.push(`pub enum ${enumName} {`); + + for (const { schema: variantSchema, typeName } of resolvedVariants) { + const kind = ((variantSchema.properties?.kind as JSONSchema7 | undefined) + ?.const ?? typeName) as string; + lines.push(` ${toPascalCase(kind)}(${stripOption(typeName)}),`); + } + + lines.push("}"); + ctx.enums.push(lines.join("\n")); + return enumName; +} + +function makeCtx(definitions?: DefinitionCollections): RustCodegenCtx { + return { + structs: [], + enums: [], + generatedNames: new Set(), + definitions, + }; +} + +// โ”€โ”€ Type resolution โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/** + * Map a JSON Schema to a Rust type string. Emits nested type definitions as + * side effects into ctx. + */ +function resolveRustType( + propSchema: JSONSchema7, + parentTypeName: string, + jsonPropName: string, + isRequired: boolean, + ctx: RustCodegenCtx, +): string { + const nestedName = parentTypeName + toPascalCase(jsonPropName); + + // $ref โ€” resolve and recurse + if (propSchema.$ref && typeof propSchema.$ref === "string") { + const typeName = toPascalCase( + refTypeName(propSchema.$ref, ctx.definitions), + ); + const resolved = resolveRef(propSchema.$ref, ctx.definitions); + if (resolved) { + if (resolved.enum) { + emitRustStringEnum( + typeName, + resolved.enum as string[], + ctx, + resolved.description, + ); + return wrapOption(typeName, isRequired); + } + if (isObjectSchema(resolved)) { + emitRustStruct(typeName, resolved, ctx); + return wrapOption(typeName, isRequired); + } + return resolveRustType( + resolved, + parentTypeName, + jsonPropName, + isRequired, + ctx, + ); + } + return wrapOption(typeName, isRequired); + } + + // anyOf โ€” nullable pattern or union + if (propSchema.anyOf) { + const discriminatedUnion = tryEmitRustDiscriminatedUnion( + propSchema, + parentTypeName, + jsonPropName, + ctx, + ); + if (discriminatedUnion) { + return wrapOption(discriminatedUnion, isRequired); + } + + const nonNull = (propSchema.anyOf as JSONSchema7[]).filter( + (s) => s.type !== "null", + ); + const hasNull = (propSchema.anyOf as JSONSchema7[]).some( + (s) => s.type === "null", + ); + + if (nonNull.length === 1) { + const innerType = resolveRustType( + nonNull[0], + parentTypeName, + jsonPropName, + true, + ctx, + ); + if (isRequired && !hasNull) return innerType; + return wrapOption(innerType, false); + } + + if (nonNull.length > 1) { + // Multi-type union โ€” use serde_json::Value as escape hatch + return wrapOption("serde_json::Value", isRequired); + } + } + + // oneOf โ€” treat like anyOf for now + if (propSchema.oneOf) { + const discriminatedUnion = tryEmitRustDiscriminatedUnion( + propSchema, + parentTypeName, + jsonPropName, + ctx, + ); + if (discriminatedUnion) { + return wrapOption(discriminatedUnion, isRequired); + } + + const nonNull = (propSchema.oneOf as JSONSchema7[]).filter( + (s) => s.type !== "null", + ); + if (nonNull.length === 1) { + const innerType = resolveRustType( + nonNull[0], + parentTypeName, + jsonPropName, + true, + ctx, + ); + return wrapOption(innerType, isRequired); + } + return wrapOption("serde_json::Value", isRequired); + } + + // allOf โ€” merge and treat as object + if (propSchema.allOf) { + const merged = resolveObjectSchema(propSchema, ctx.definitions); + if (merged && isObjectSchema(merged)) { + const structName = (propSchema.title as string) || nestedName; + emitRustStruct(structName, merged, ctx); + return wrapOption(structName, isRequired); + } + } + + // enum + if (propSchema.enum && Array.isArray(propSchema.enum)) { + const enumName = (propSchema.title as string) || nestedName; + emitRustStringEnum( + enumName, + propSchema.enum as string[], + ctx, + propSchema.description, + ); + return wrapOption(enumName, isRequired); + } + + // const โ€” just a string + if (propSchema.const !== undefined) { + if (typeof propSchema.const === "string") { + const enumName = (propSchema.title as string) || nestedName; + emitRustConstStringEnum( + enumName, + propSchema.const, + ctx, + propSchema.description, + ); + return wrapOption(enumName, isRequired); + } + return wrapOption("serde_json::Value", isRequired); + } + + const schemaType = propSchema.type; + + // Type arrays like ["string", "null"] + if (Array.isArray(schemaType)) { + const nonNullTypes = (schemaType as string[]).filter((t) => t !== "null"); + if (nonNullTypes.length === 1) { + const inner = resolveRustType( + { ...propSchema, type: nonNullTypes[0] as JSONSchema7["type"] }, + parentTypeName, + jsonPropName, + true, + ctx, + ); + return wrapOption(inner, false); + } + return wrapOption("serde_json::Value", isRequired); + } + + // Primitive types + if (schemaType === "string") { + const newtype = STRING_NEWTYPE_OVERRIDES[jsonPropName]; + if (newtype) return wrapOption(newtype, isRequired); + return wrapOption("String", isRequired); + } + if (schemaType === "number") return wrapOption("f64", isRequired); + if (schemaType === "integer") return wrapOption("i64", isRequired); + if (schemaType === "boolean") return wrapOption("bool", isRequired); + + // Array + if (schemaType === "array") { + const items = propSchema.items as JSONSchema7 | undefined; + if (items) { + const itemType = resolveRustType( + items, + parentTypeName, + `${jsonPropName}Item`, + true, + ctx, + ); + return wrapOption(`Vec<${itemType}>`, isRequired); + } + return wrapOption("Vec", isRequired); + } + + // Object + if (schemaType === "object" || (propSchema.properties && !schemaType)) { + if ( + propSchema.properties && + Object.keys(propSchema.properties).length > 0 + ) { + const structName = (propSchema.title as string) || nestedName; + emitRustStruct(structName, propSchema, ctx); + return wrapOption(structName, isRequired); + } + if (propSchema.additionalProperties) { + if ( + typeof propSchema.additionalProperties === "object" && + Object.keys(propSchema.additionalProperties as Record) + .length > 0 + ) { + const ap = propSchema.additionalProperties as JSONSchema7; + if (ap.type === "object" && ap.properties) { + const valueName = (ap.title as string) || `${nestedName}Value`; + emitRustStruct(valueName, ap, ctx); + return wrapOption(`HashMap`, isRequired); + } + const valueType = resolveRustType( + ap, + parentTypeName, + `${jsonPropName}Value`, + true, + ctx, + ); + return wrapOption(`HashMap`, isRequired); + } + return wrapOption("HashMap", isRequired); + } + return wrapOption("serde_json::Value", isRequired); + } + + // Fallback + return wrapOption("serde_json::Value", isRequired); +} + +function wrapOption(rustType: string, isRequired: boolean): string { + if (isRequired) return rustType; + // Don't double-wrap Option, Vec, or HashMap (they're already nullable-ish) + if ( + rustType.startsWith("Option<") || + rustType.startsWith("Vec<") || + rustType.startsWith("HashMap<") + ) { + return rustType; + } + return `Option<${rustType}>`; +} + +// โ”€โ”€ Struct emission โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +function emitRustStruct( + typeName: string, + schema: JSONSchema7, + ctx: RustCodegenCtx, + description?: string, +): void { + if (ctx.generatedNames.has(typeName)) return; + ctx.generatedNames.add(typeName); + + const required = new Set(schema.required || []); + const lines: string[] = []; + const desc = description || schema.description; + if (desc) { + for (const line of desc.split(/\r?\n/)) { + lines.push(`/// ${line}`); + } + } + if (isSchemaDeprecated(schema)) { + lines.push("#[deprecated]"); + } + lines.push("#[derive(Debug, Clone, Serialize, Deserialize)]"); + lines.push(`#[serde(rename_all = "camelCase")]`); + lines.push(`pub struct ${typeName} {`); + + for (const [propName, propSchema] of Object.entries( + schema.properties || {}, + )) { + if (typeof propSchema !== "object") continue; + const prop = propSchema as JSONSchema7; + const isReq = required.has(propName); + const rustField = safeRustFieldName(propName); + const rustType = resolveRustType(prop, typeName, propName, isReq, ctx); + + if (prop.description) { + for (const line of prop.description.split(/\r?\n/)) { + lines.push(` /// ${line}`); + } + } + if (isSchemaDeprecated(prop)) { + lines.push(" #[deprecated]"); + } + + // Determine if an explicit rename is needed. `rename_all = "camelCase"` on + // the struct converts snake_case fields to camelCase automatically, so we + // only need an explicit rename when that automatic conversion doesn't produce + // the original JSON property name. + const snakeField = toRustFieldName(propName); + const autoRename = snakeToCamelCase(snakeField); + const needsRename = autoRename !== propName; + const isOptionType = rustType.startsWith("Option<"); + const needsSkip = !isReq && isOptionType; + + if (needsSkip && needsRename) { + lines.push( + ` #[serde(rename = "${propName}", skip_serializing_if = "Option::is_none")]`, + ); + } else if (needsSkip) { + lines.push(` #[serde(skip_serializing_if = "Option::is_none")]`); + } else if (!isReq && !isOptionType && needsRename) { + lines.push(` #[serde(rename = "${propName}", default)]`); + } else if (!isReq && !isOptionType) { + lines.push(" #[serde(default)]"); + } else if (needsRename) { + lines.push(` #[serde(rename = "${propName}")]`); + } + + lines.push(` pub ${rustField}: ${rustType},`); + } + + lines.push("}"); + ctx.structs.push(lines.join("\n")); +} + +// โ”€โ”€ Enum emission โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +function emitRustStringEnum( + enumName: string, + values: string[], + ctx: RustCodegenCtx, + description?: string, +): void { + if (ctx.generatedNames.has(enumName)) return; + ctx.generatedNames.add(enumName); + + const lines: string[] = []; + if (description) { + for (const line of description.split(/\r?\n/)) { + lines.push(`/// ${line}`); + } + } + lines.push("#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]"); + lines.push(`pub enum ${enumName} {`); + + for (const value of values) { + const variantName = toPascalCase(value); + if (variantName !== value) { + lines.push(` #[serde(rename = "${value}")]`); + } + lines.push(` ${variantName},`); + } + + // Add a catch-all for forward compatibility + lines.push(" /// Unknown variant for forward compatibility."); + lines.push(" #[serde(other)]"); + lines.push(" Unknown,"); + + lines.push("}"); + ctx.enums.push(lines.join("\n")); +} + +function emitRustConstStringEnum( + enumName: string, + value: string, + ctx: RustCodegenCtx, + description?: string, +): void { + if (ctx.generatedNames.has(enumName)) return; + ctx.generatedNames.add(enumName); + + const lines: string[] = []; + if (description) { + for (const line of description.split(/\r?\n/)) { + lines.push(`/// ${line}`); + } + } + lines.push("#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]"); + lines.push(`pub enum ${enumName} {`); + const variantName = toPascalCase(value); + if (variantName !== value) { + lines.push(` #[serde(rename = "${value}")]`); + } + lines.push(` ${variantName},`); + lines.push("}"); + ctx.enums.push(lines.join("\n")); +} + +// โ”€โ”€ Session events generation โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +interface EventVariant { + /** The event type string, e.g. "session.start" */ + typeName: string; + /** PascalCase variant name, e.g. "SessionStart" */ + variantName: string; + /** Data struct name, e.g. "SessionStartData" */ + dataClassName: string; + /** Schema for the data field */ + dataSchema: JSONSchema7; + /** Description of the event */ + description?: string; +} + +function extractEventVariants(schema: JSONSchema7): EventVariant[] { + const definitionCollections = collectDefinitionCollections( + schema as Record, + ); + const sessionEvent = + resolveSchema( + { $ref: "#/definitions/SessionEvent" }, + definitionCollections, + ) ?? resolveSchema({ $ref: "#/$defs/SessionEvent" }, definitionCollections); + if (!sessionEvent?.anyOf) + throw new Error("Schema must have SessionEvent definition with anyOf"); + + return (sessionEvent.anyOf as JSONSchema7[]) + .map((variant) => { + const resolvedVariant = + resolveObjectSchema(variant as JSONSchema7, definitionCollections) ?? + resolveSchema(variant as JSONSchema7, definitionCollections) ?? + (variant as JSONSchema7); + if (typeof resolvedVariant !== "object" || !resolvedVariant.properties) { + throw new Error("Invalid variant"); + } + const typeSchema = resolvedVariant.properties.type as JSONSchema7; + const typeName = typeSchema?.const as string; + if (!typeName) throw new Error("Variant must have type.const"); + + const dataSchema = + resolveObjectSchema( + resolvedVariant.properties.data as JSONSchema7, + definitionCollections, + ) ?? + resolveSchema( + resolvedVariant.properties.data as JSONSchema7, + definitionCollections, + ) ?? + ((resolvedVariant.properties.data as JSONSchema7) || {}); + + return { + typeName, + variantName: toPascalCase(typeName), + dataClassName: `${toPascalCase(typeName)}Data`, + dataSchema, + description: resolvedVariant.description || dataSchema.description, + }; + }) + .filter((v) => !EXCLUDED_EVENT_TYPES.has(v.typeName)); +} + +function generateSessionEventsCode(schema: JSONSchema7): string { + const variants = extractEventVariants(schema); + const ctx = makeCtx( + collectDefinitionCollections(schema as Record), + ); + + // Generate per-event data structs + for (const variant of variants) { + emitRustStruct( + variant.dataClassName, + variant.dataSchema, + ctx, + variant.description, + ); + } + + // Build the SessionEventType enum + const typeEnumLines: string[] = []; + typeEnumLines.push("/// Identifies the kind of session event."); + typeEnumLines.push( + "#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]", + ); + typeEnumLines.push("pub enum SessionEventType {"); + for (const variant of variants) { + typeEnumLines.push(` #[serde(rename = "${variant.typeName}")]`); + typeEnumLines.push(` ${variant.variantName},`); + } + typeEnumLines.push(" /// Unknown event type for forward compatibility."); + typeEnumLines.push(" #[serde(other)]"); + typeEnumLines.push(" Unknown,"); + typeEnumLines.push("}"); + + // Build the SessionEventData enum (adjacently tagged by type/data) + const dataEnumLines: string[] = []; + dataEnumLines.push( + "/// Typed session event data, discriminated by the event `type` field.", + ); + dataEnumLines.push("///"); + dataEnumLines.push( + "/// Use with [`TypedSessionEvent`] for fully typed event handling.", + ); + dataEnumLines.push("#[derive(Debug, Clone, Serialize, Deserialize)]"); + dataEnumLines.push(`#[serde(tag = "type", content = "data")]`); + dataEnumLines.push("pub enum SessionEventData {"); + for (const variant of variants) { + dataEnumLines.push(` #[serde(rename = "${variant.typeName}")]`); + dataEnumLines.push(` ${variant.variantName}(${variant.dataClassName}),`); + } + dataEnumLines.push("}"); + + // Build TypedSessionEvent that combines common fields with typed data + const typedEventLines: string[] = []; + typedEventLines.push("/// A session event with typed data payload."); + typedEventLines.push("///"); + typedEventLines.push( + "/// The common event fields (id, timestamp, parentId, ephemeral, agentId)", + ); + typedEventLines.push( + "/// are available directly. The event-specific data is in the `payload`", + ); + typedEventLines.push("/// field as a [`SessionEventData`] enum."); + typedEventLines.push("#[derive(Debug, Clone, Serialize, Deserialize)]"); + typedEventLines.push(`#[serde(rename_all = "camelCase")]`); + typedEventLines.push("pub struct TypedSessionEvent {"); + typedEventLines.push(" /// Unique event identifier (UUID v4)."); + typedEventLines.push(" pub id: String,"); + typedEventLines.push( + " /// ISO 8601 timestamp when the event was created.", + ); + typedEventLines.push(" pub timestamp: String,"); + typedEventLines.push(" /// ID of the preceding event in the chain."); + typedEventLines.push(` #[serde(skip_serializing_if = "Option::is_none")]`); + typedEventLines.push(" pub parent_id: Option,"); + typedEventLines.push( + " /// When true, the event is transient and not persisted.", + ); + typedEventLines.push(` #[serde(skip_serializing_if = "Option::is_none")]`); + typedEventLines.push(" pub ephemeral: Option,"); + typedEventLines.push( + " /// Sub-agent instance identifier. Absent for events from the root /", + ); + typedEventLines.push( + " /// main agent and session-level events.", + ); + typedEventLines.push(` #[serde(skip_serializing_if = "Option::is_none")]`); + typedEventLines.push(" pub agent_id: Option,"); + typedEventLines.push( + " /// The typed event payload (discriminated by event type).", + ); + typedEventLines.push(" #[serde(flatten)]"); + typedEventLines.push(" pub payload: SessionEventData,"); + typedEventLines.push("}"); + + // Assemble file + const out: string[] = []; + out.push( + "//! Auto-generated from session-events.schema.json โ€” do not edit manually.", + ); + out.push(""); + out.push("use std::collections::HashMap;"); + out.push(""); + out.push("use serde::{Deserialize, Serialize};"); + out.push(""); + out.push("use crate::types::{RequestId, SessionId};"); + out.push(""); + + // SessionEventType enum + out.push(typeEnumLines.join("\n")); + out.push(""); + + // SessionEventData enum + out.push(dataEnumLines.join("\n")); + out.push(""); + + // TypedSessionEvent struct + out.push(typedEventLines.join("\n")); + out.push(""); + + // Per-event data structs + for (const block of ctx.structs) { + out.push(block); + out.push(""); + } + + // Supporting enums + for (const block of ctx.enums) { + out.push(block); + out.push(""); + } + + return out.join("\n"); +} + +// โ”€โ”€ API types generation โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +function collectRpcMethods( + node: Record, + prefix = "", +): RpcMethod[] { + const methods: RpcMethod[] = []; + for (const [key, value] of Object.entries(node)) { + if (isRpcMethod(value)) { + methods.push(value); + } else if (typeof value === "object" && value !== null) { + methods.push( + ...collectRpcMethods( + value as Record, + prefix ? `${prefix}.${key}` : key, + ), + ); + } + } + return methods; +} + +function rustParamsTypeName(method: RpcMethod): string { + return getRpcSchemaTypeName( + method.params, + `${toPascalCase(method.rpcMethod)}Params`, + ); +} + +function rustResultTypeName(method: RpcMethod): string { + return getRpcSchemaTypeName( + method.result, + `${toPascalCase(method.rpcMethod)}Result`, + ); +} + +function generateApiTypesCode(apiSchema: ApiSchema): string { + const definitions = collectDefinitions(apiSchema as Record); + const defCollections = collectDefinitionCollections( + apiSchema as Record, + ); + const ctx = makeCtx(defCollections); + + // Generate shared definitions (structs & enums) + for (const [name, def] of Object.entries(definitions)) { + if (typeof def !== "object" || def === null) continue; + const schema = def as JSONSchema7; + + if (schema.enum && Array.isArray(schema.enum)) { + emitRustStringEnum( + name, + schema.enum as string[], + ctx, + schema.description, + ); + } else if (isObjectSchema(schema)) { + emitRustStruct(name, schema, ctx, schema.description); + } + } + + // Collect all RPC methods and generate request/response types + const allMethods: RpcMethod[] = []; + for (const group of [ + apiSchema.server, + apiSchema.session, + apiSchema.clientSession, + ]) { + if (group) { + allMethods.push(...collectRpcMethods(group as Record)); + } + } + + // RPC method name constants + const methodConstLines: string[] = []; + methodConstLines.push("/// JSON-RPC method name constants."); + methodConstLines.push("pub mod rpc_methods {"); + + for (const method of allMethods) { + const constName = method.rpcMethod.replace(/\./g, "_").toUpperCase(); + methodConstLines.push(` /// \`${method.rpcMethod}\``); + methodConstLines.push( + ` pub const ${constName}: &str = "${method.rpcMethod}";`, + ); + } + methodConstLines.push("}"); + + // Generate param/result types for each method + for (const method of allMethods) { + if ( + method.params && + isObjectSchema(method.params) && + !isVoidSchema(method.params) + ) { + const paramsName = rustParamsTypeName(method); + emitRustStruct(paramsName, method.params, ctx, method.params.description); + } + if (method.result && !isVoidSchema(method.result)) { + const resultName = rustResultTypeName(method); + const resolved = resolveSchema(method.result, defCollections); + if (resolved) { + if (resolved.enum && Array.isArray(resolved.enum)) { + // Already generated from definitions + } else if (isObjectSchema(resolved)) { + emitRustStruct(resultName, resolved, ctx, resolved.description); + } + } + } + } + + // Assemble file + const out: string[] = []; + out.push("//! Auto-generated from api.schema.json โ€” do not edit manually."); + out.push(""); + out.push("#![allow(clippy::large_enum_variant)]"); + out.push(""); + out.push("use std::collections::HashMap;"); + out.push(""); + out.push("use serde::{Deserialize, Serialize};"); + out.push(""); + out.push("use crate::types::{RequestId, SessionId};"); + out.push(""); + + // Method constants + out.push(methodConstLines.join("\n")); + out.push(""); + + // Shared definition types first, then RPC types + for (const block of ctx.structs) { + out.push(block); + out.push(""); + } + + for (const block of ctx.enums) { + out.push(block); + out.push(""); + } + + return out.join("\n"); +} + +// โ”€โ”€ Typed RPC namespace generation โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +interface NamespaceNode { + name: string; + typeName: string; + methods: RpcMethod[]; + children: Map; +} + +function newNamespaceNode(name: string, typeName: string): NamespaceNode { + return { name, typeName, methods: [], children: new Map() }; +} + +/** + * Build a namespace tree from a list of methods. `groupOf(method)` returns the + * dotted group path (e.g. "mcp.config" for "mcp.config.list" / "workspaces" + * for "workspaces.listFiles"); the last segment of `rpcMethod` is the leaf + * method name. + */ +function buildNamespaceTree( + rootTypeName: string, + methods: RpcMethod[], + stripPrefix: string, +): NamespaceNode { + const root = newNamespaceNode("", rootTypeName); + for (const method of methods) { + const trimmed = stripPrefix && method.rpcMethod.startsWith(stripPrefix) + ? method.rpcMethod.slice(stripPrefix.length) + : method.rpcMethod; + const segments = trimmed.split("."); + const groupSegments = segments.slice(0, -1); + let node = root; + for (const seg of groupSegments) { + let child = node.children.get(seg); + if (!child) { + const childTypeName = `${node.typeName}${toPascalCase(seg)}`; + child = newNamespaceNode(seg, childTypeName); + node.children.set(seg, child); + } + node = child; + } + node.methods.push(method); + } + return root; +} + +/** + * Determine if a method has typed params. Returns `{ hasParams, typeName }`. + * Handles `$ref`-based, title-bearing, and inline params uniformly: + * + * - Resolves `$ref` to its definition. + * - For session methods, ignores `sessionId` (the namespace injects it). + * - Returns `hasParams=false` when the resolved property set (after the + * sessionId filter for session methods) is empty. + * - The type name comes from `$ref` (preferred), then the resolved + * definition's `title`, then the inline params `title`. + */ +function getMethodParamsInfo( + method: RpcMethod, + defCollections: DefinitionCollections, + isSession: boolean, +): { hasParams: boolean; typeName: string | null } { + if (!method.params) return { hasParams: false, typeName: null }; + const inline = method.params as JSONSchema7 & { $ref?: string }; + const resolved = resolveSchema(inline, defCollections); + if (!resolved) return { hasParams: false, typeName: null }; + + let typeName: string | null = null; + if (typeof inline.$ref === "string") { + typeName = refTypeName(inline.$ref, defCollections); + } else if (typeof resolved.title === "string") { + typeName = resolved.title; + } else if (typeof inline.title === "string") { + typeName = inline.title; + } + + const allProps = Object.keys(resolved.properties || {}); + const props = isSession + ? allProps.filter((p) => p !== "sessionId") + : allProps; + if (props.length === 0) return { hasParams: false, typeName: null }; + if (!typeName) return { hasParams: false, typeName: null }; + return { hasParams: true, typeName }; +} + +function rpcMethodConstName(method: RpcMethod): string { + return method.rpcMethod.replace(/\./g, "_").toUpperCase(); +} + +function emitNamespaceStruct( + out: string[], + node: NamespaceNode, + holderType: string, + holderField: string, + isSession: boolean, + defCollections: DefinitionCollections, + docPrefix: string, +): void { + const lifetimes = "<'a>"; + out.push(`/// ${docPrefix}`); + out.push(`#[derive(Clone, Copy)]`); + out.push(`pub struct ${node.typeName}${lifetimes} {`); + out.push(` pub(crate) ${holderField}: &'a ${holderType},`); + out.push(`}`); + out.push(""); + + out.push(`impl${lifetimes} ${node.typeName}${lifetimes} {`); + + // Sub-namespace accessors + const childNames = Array.from(node.children.keys()).sort(); + for (const childName of childNames) { + const child = node.children.get(childName)!; + const accessor = toSnakeCase(childName); + const desc = isSession + ? `\`session.${accessorPath(node, childName, isSession)}.*\`` + : `\`${accessorPath(node, childName, isSession)}.*\``; + out.push(` /// ${desc} sub-namespace.`); + out.push( + ` pub fn ${accessor}(&self) -> ${child.typeName}<'a> {`, + ); + out.push(` ${child.typeName} { ${holderField}: self.${holderField} }`); + out.push(` }`); + out.push(""); + } + + // Leaf methods + for (const method of node.methods) { + emitNamespaceMethod(out, method, holderField, isSession, defCollections); + } + + out.push(`}`); + out.push(""); + + // Recursively emit child structs + for (const childName of childNames) { + const child = node.children.get(childName)!; + const childDoc = isSession + ? `\`session.${accessorPath(node, childName, isSession)}.*\` RPCs.` + : `\`${accessorPath(node, childName, isSession)}.*\` RPCs.`; + emitNamespaceStruct( + out, + child, + holderType, + holderField, + isSession, + defCollections, + childDoc, + ); + } +} + +function accessorPath(parent: NamespaceNode, child: string, _isSession: boolean): string { + // Build wire-style dotted path from the namespace tree's "name" chain plus child. + // `parent.name === ""` for root; we accumulate by retrieving parent name only. + // (We don't track full ancestry here; this is just for doc strings โ€” we + // fall back to the child name alone when at the root.) + if (!parent.name) return child; + return `${parent.name}.${child}`; +} + +function getResultTypeName( + method: RpcMethod, + defCollections: DefinitionCollections, +): string | null { + const result = method.result as (JSONSchema7 & { $ref?: string }) | null; + if (!result || isVoidSchema(result)) return null; + if (typeof result.$ref === "string") { + return refTypeName(result.$ref, defCollections); + } + if (typeof result.title === "string") return result.title; + return `${toPascalCase(method.rpcMethod)}Result`; +} + +function emitNamespaceMethod( + out: string[], + method: RpcMethod, + holderField: string, + isSession: boolean, + defCollections: DefinitionCollections, +): void { + const wireMethod = method.rpcMethod; + const constName = rpcMethodConstName(method); + const lastSegment = wireMethod.split(".").pop()!; + const fnName = toSnakeCase(lastSegment); + + const paramsInfo = getMethodParamsInfo(method, defCollections, isSession); + const hasParams = paramsInfo.hasParams; + const paramsTypeName = paramsInfo.typeName; + + const resultTypeName = getResultTypeName(method, defCollections); + const returnType = resultTypeName ? resultTypeName : "()"; + const resultIsVoid = resultTypeName === null; + + const docs: string[] = []; + docs.push(` /// Wire method: \`${wireMethod}\`.`); + if (method.deprecated) docs.push(` #[deprecated]`); + const stability = method.stability; + if (stability === "experimental") { + docs.push(` ///`); + docs.push( + ` ///

`, + ); + docs.push( + ` ///`, + ); + docs.push( + ` /// **Experimental.** This API is part of an experimental wire-protocol surface`, + ); + docs.push( + ` /// and may change or be removed in future SDK or CLI releases. Pin both the`, + ); + docs.push( + ` /// SDK and CLI versions if your code depends on it.`, + ); + docs.push( + ` ///`, + ); + docs.push( + ` ///
`, + ); + } else if (stability && stability !== "stable") { + docs.push(` /// Stability: \`${stability}\`.`); + } + + const paramArg = hasParams ? `, params: ${paramsTypeName}` : ""; + + out.push(...docs); + out.push( + ` pub async fn ${fnName}(&self${paramArg}) -> Result<${returnType}, Error> {`, + ); + + // Build the params Value sent over the wire. + if (isSession) { + if (hasParams) { + out.push(` let mut wire_params = serde_json::to_value(params)?;`); + out.push( + ` wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());`, + ); + } else { + out.push( + ` let wire_params = serde_json::json!({ "sessionId": self.session.id() });`, + ); + } + out.push( + ` let _value = self.session.client().call(rpc_methods::${constName}, Some(wire_params)).await?;`, + ); + } else { + if (hasParams) { + out.push(` let wire_params = serde_json::to_value(params)?;`); + } else { + out.push(` let wire_params = serde_json::json!({});`); + } + out.push( + ` let _value = self.client.call(rpc_methods::${constName}, Some(wire_params)).await?;`, + ); + } + + if (resultIsVoid) { + out.push(` Ok(())`); + } else { + out.push(` Ok(serde_json::from_value(_value)?)`); + } + out.push(` }`); + out.push(""); +} + +function generateRpcCode(apiSchema: ApiSchema): string { + const defCollections = collectDefinitionCollections( + apiSchema as unknown as Record, + ); + + const serverMethods = apiSchema.server + ? collectRpcMethods(apiSchema.server as Record) + : []; + const sessionMethods = apiSchema.session + ? collectRpcMethods(apiSchema.session as Record) + : []; + + const clientRoot = buildNamespaceTree("ClientRpc", serverMethods, ""); + const sessionRoot = buildNamespaceTree( + "SessionRpc", + sessionMethods, + "session.", + ); + + const out: string[] = []; + out.push( + "//! Auto-generated typed JSON-RPC namespace โ€” do not edit manually.", + ); + out.push("//!"); + out.push( + "//! Generated from `api.schema.json` by `scripts/codegen/rust.ts`. The", + ); + out.push( + "//! [`ClientRpc`] and [`SessionRpc`] view structs let callers reach every", + ); + out.push( + "//! protocol method through a typed namespace tree, so wire method names", + ); + out.push( + "//! and request/response shapes live in exactly one place โ€” this file.", + ); + out.push(""); + out.push("#![allow(missing_docs)]"); + out.push("#![allow(clippy::too_many_arguments)]"); + out.push(""); + out.push("use super::api_types::*;"); + out.push("use super::api_types::rpc_methods;"); + out.push("use crate::session::Session;"); + out.push("use crate::{Client, Error};"); + out.push(""); + + emitNamespaceStruct( + out, + clientRoot, + "Client", + "client", + false, + defCollections, + "Typed view over the [`Client`]'s server-level RPC namespace.", + ); + emitNamespaceStruct( + out, + sessionRoot, + "Session", + "session", + true, + defCollections, + "Typed view over a [`Session`]'s RPC namespace.", + ); + + return out.join("\n"); +} + +// โ”€โ”€ mod.rs generation โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +function generateModRs(): string { + const lines: string[] = []; + lines.push("//! Auto-generated protocol types โ€” do not edit manually."); + lines.push("//!"); + lines.push( + "//! Generated from the Copilot protocol JSON Schemas by `scripts/codegen/rust.ts`.", + ); + lines.push("#![allow(missing_docs)]"); + lines.push("#![allow(rustdoc::bare_urls)]"); + lines.push(""); + lines.push("pub mod api_types;"); + lines.push("pub mod rpc;"); + lines.push("pub mod session_events;"); + lines.push(""); + lines.push( + "// Re-export session event types at the module root โ€” no conflicts with", + ); + lines.push( + "// hand-written types. API types are kept namespaced under `api_types::`", + ); + lines.push( + "// because some names (Tool, ModelCapabilities, etc.) overlap with the", + ); + lines.push("// hand-written SDK API types in `types.rs`."); + lines.push("pub use session_events::*;"); + lines.push(""); + return lines.join("\n"); +} + +// โ”€โ”€ Format with rustfmt โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +async function rustfmt(filePath: string): Promise { + try { + await execFileAsync("rustfmt", ["--edition", "2021", filePath]); + } catch (e: unknown) { + const error = e as { stderr?: string }; + console.warn( + `rustfmt warning for ${path.basename(filePath)}: ${error.stderr || e}`, + ); + } +} + +// โ”€โ”€ Main โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +async function generate(): Promise { + console.log("Loading schemas..."); + + const sessionEventsSchemaPath = await getSessionEventsSchemaPath(); + const apiSchemaPath = await getApiSchemaPath(process.argv[2]); + + const sessionEventsRaw = JSON.parse( + await fs.readFile(sessionEventsSchemaPath, "utf-8"), + ); + const apiRaw = JSON.parse( + await fs.readFile(apiSchemaPath, "utf-8"), + ) as ApiSchema; + + const sessionEventsSchema = postProcessSchema( + stripBooleanLiterals(sessionEventsRaw) as JSONSchema7, + ); + const apiSchema = postProcessSchema( + stripBooleanLiterals(apiRaw) as JSONSchema7, + ) as unknown as ApiSchema; + + // Ensure output directory exists + await fs.mkdir(GENERATED_DIR, { recursive: true }); + + // Generate session events + console.log("Generating session_events.rs..."); + const sessionEventsCode = generateSessionEventsCode(sessionEventsSchema); + const sessionEventsPath = path.join(GENERATED_DIR, "session_events.rs"); + await fs.writeFile(sessionEventsPath, sessionEventsCode, "utf-8"); + await rustfmt(sessionEventsPath); + + // Generate API types + console.log("Generating api_types.rs..."); + const apiTypesCode = generateApiTypesCode(apiSchema); + const apiTypesPath = path.join(GENERATED_DIR, "api_types.rs"); + await fs.writeFile(apiTypesPath, apiTypesCode, "utf-8"); + await rustfmt(apiTypesPath); + + // Generate typed RPC namespace + console.log("Generating rpc.rs..."); + const rpcCode = generateRpcCode(apiSchema); + const rpcPath = path.join(GENERATED_DIR, "rpc.rs"); + await fs.writeFile(rpcPath, rpcCode, "utf-8"); + await rustfmt(rpcPath); + + // Generate mod.rs + console.log("Generating mod.rs..."); + const modRsCode = generateModRs(); + const modRsPath = path.join(GENERATED_DIR, "mod.rs"); + await fs.writeFile(modRsPath, modRsCode, "utf-8"); + await rustfmt(modRsPath); + + console.log(`Done! Generated files in ${GENERATED_DIR}`); +} + +generate().catch((err) => { + console.error("Code generation failed:", err); + process.exit(1); +}); diff --git a/test/scenarios/RUST_COVERAGE.md b/test/scenarios/RUST_COVERAGE.md new file mode 100644 index 000000000..f0c61979f --- /dev/null +++ b/test/scenarios/RUST_COVERAGE.md @@ -0,0 +1,61 @@ +# Rust scenario coverage + +Rust SDK scenario samples live alongside the TypeScript / Python / Go / C# samples under +`test/scenarios/*//rust/`. The monorepo's `scenario-builds.yml` workflow +auto-discovers any `*/rust/Cargo.toml` under `test/scenarios/` and verifies it builds. + +## Coverage + +| Category | Scenario | Status | +|----------------|-------------------------|--------| +| `transport/` | `stdio` | โœ… | +| `transport/` | `tcp` | โœ… | +| `transport/` | `external` | โŒ deferred (needs `from_streams`-style sample) | +| `sessions/` | `streaming` | โœ… | +| `sessions/` | `session-resume` | โœ… | +| `sessions/` | `infinite-sessions` | โœ… | +| `sessions/` | `concurrent-sessions` | โœ… | +| `sessions/` | `multi-user-*` | โŒ deferred (multi-client orchestration) | +| `modes/` | `default` | โœ… | +| `modes/` | non-default | โŒ deferred (plan mode, read-only) | +| `tools/` | `no-tools` | โœ… | +| `tools/` | `mcp-servers` | โœ… | +| `tools/` | `skills` | โœ… | +| `tools/` | `tool-filtering` | โœ… | +| `tools/` | `custom-agents` | โœ… | +| `tools/` | `tool-overrides` | โœ… | +| `tools/` | `virtual-filesystem` | โŒ deferred (needs `VirtualFilesystem` hook port) | +| `callbacks/` | `hooks` | โœ… | +| `callbacks/` | `permissions` | โœ… | +| `callbacks/` | `user-input` | โœ… | +| `prompts/` | `system-message` | โœ… | +| `prompts/` | `reasoning-effort` | โœ… | +| `prompts/` | `attachments` | โœ… | +| `bundling/` | * | โŒ app-level concern, not an SDK gap | +| `auth/` | * | โŒ deferred (GitHub-App / token-exchange) | + +## Remaining gaps + +- `transport/external` โ€” needs a sample using an externally-managed CLI process (parity with Node's `from_streams`). +- `tools/virtual-filesystem` โ€” depends on a future `VirtualFilesystem` hook port. +- `modes/*` (non-default) โ€” plan-mode and read-only-mode samples. +- `sessions/multi-user-*` โ€” multi-client orchestration. +- `auth/*` โ€” GitHub-App / token-exchange sample programs. +- `bundling/*` โ€” process bundling is application-level, not an SDK concern. + +## Running the samples locally + +Each scenario's `verify.sh` runs the Rust build + run phase alongside the other +languages. With a token in place (`GITHUB_TOKEN`, or `gh auth login`): + +```sh +cd test/scenarios/transport/stdio && ./verify.sh +``` + +To build all Rust scenario samples without running them (what CI does): + +```sh +for d in $(find test/scenarios -path '*/rust/Cargo.toml'); do + (cd "$(dirname "$d")" && cargo build --quiet) || echo "FAILED: $d" +done +``` diff --git a/test/scenarios/callbacks/hooks/rust/Cargo.toml b/test/scenarios/callbacks/hooks/rust/Cargo.toml new file mode 100644 index 000000000..4c16a91b5 --- /dev/null +++ b/test/scenarios/callbacks/hooks/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "hooks-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } diff --git a/test/scenarios/callbacks/hooks/rust/src/main.rs b/test/scenarios/callbacks/hooks/rust/src/main.rs new file mode 100644 index 000000000..179765d2f --- /dev/null +++ b/test/scenarios/callbacks/hooks/rust/src/main.rs @@ -0,0 +1,131 @@ +//! Session hooks โ€” intercept lifecycle events (session start/end, pre/post +//! tool use, user prompt, errors) and log every firing. + +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::hooks::{ + ErrorOccurredInput, ErrorOccurredOutput, HookContext, PostToolUseInput, PostToolUseOutput, + PreToolUseInput, PreToolUseOutput, SessionEndInput, SessionEndOutput, SessionHooks, + SessionStartInput, SessionStartOutput, UserPromptSubmittedInput, UserPromptSubmittedOutput, +}; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; +use tokio::sync::Mutex; + +struct HookLogger { + log: Arc>>, +} + +impl HookLogger { + async fn append(&self, entry: String) { + self.log.lock().await.push(entry); + } +} + +#[async_trait] +impl SessionHooks for HookLogger { + async fn on_session_start( + &self, + _input: SessionStartInput, + _ctx: HookContext, + ) -> Option { + self.append("onSessionStart".to_string()).await; + None + } + + async fn on_session_end( + &self, + _input: SessionEndInput, + _ctx: HookContext, + ) -> Option { + self.append("onSessionEnd".to_string()).await; + None + } + + async fn on_pre_tool_use( + &self, + input: PreToolUseInput, + _ctx: HookContext, + ) -> Option { + self.append(format!("onPreToolUse:{}", input.tool_name)) + .await; + let mut out = PreToolUseOutput::default(); + out.permission_decision = Some("allow".to_string()); + Some(out) + } + + async fn on_post_tool_use( + &self, + input: PostToolUseInput, + _ctx: HookContext, + ) -> Option { + self.append(format!("onPostToolUse:{}", input.tool_name)) + .await; + None + } + + async fn on_user_prompt_submitted( + &self, + input: UserPromptSubmittedInput, + _ctx: HookContext, + ) -> Option { + self.append("onUserPromptSubmitted".to_string()).await; + let mut out = UserPromptSubmittedOutput::default(); + out.modified_prompt = Some(input.prompt); + Some(out) + } + + async fn on_error_occurred( + &self, + input: ErrorOccurredInput, + _ctx: HookContext, + ) -> Option { + self.append(format!("onErrorOccurred:{}", input.error)) + .await; + None + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let hook_log = Arc::new(Mutex::new(Vec::::new())); + let hooks = Arc::new(HookLogger { + log: hook_log.clone(), + }); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config + .with_handler(Arc::new(ApproveAllHandler)) + .with_hooks(hooks); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "List the files in the current directory using the glob tool with pattern '*.md'.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!("\n--- Hook execution log ---"); + let log = hook_log.lock().await; + for entry in log.iter() { + println!(" {entry}"); + } + println!("\nTotal hooks fired: {}", log.len()); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/callbacks/hooks/verify.sh b/test/scenarios/callbacks/hooks/verify.sh index 8157fed78..e6f706e61 100755 --- a/test/scenarios/callbacks/hooks/verify.sh +++ b/test/scenarios/callbacks/hooks/verify.sh @@ -120,6 +120,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o hooks-go . 2>&1" # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -137,6 +139,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./hooks-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/callbacks/permissions/rust/Cargo.toml b/test/scenarios/callbacks/permissions/rust/Cargo.toml new file mode 100644 index 000000000..a30a94162 --- /dev/null +++ b/test/scenarios/callbacks/permissions/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "permissions-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } diff --git a/test/scenarios/callbacks/permissions/rust/src/main.rs b/test/scenarios/callbacks/permissions/rust/src/main.rs new file mode 100644 index 000000000..214620e35 --- /dev/null +++ b/test/scenarios/callbacks/permissions/rust/src/main.rs @@ -0,0 +1,91 @@ +//! Permission callback โ€” log every `permission.request` from the CLI and +//! approve all of them. + +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; +use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId}; +use github_copilot_sdk::{Client, ClientOptions}; +use tokio::sync::Mutex; + +struct PermissionLogger { + log: Arc>>, +} + +#[async_trait] +impl SessionHandler for PermissionLogger { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + let tool_name = data + .extra + .get("tool") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + self.log.lock().await.push(format!("approved:{tool_name}")); + PermissionResult::Approved + } +} + +struct AllowAllHooks; + +#[async_trait] +impl SessionHooks for AllowAllHooks { + async fn on_pre_tool_use( + &self, + _input: PreToolUseInput, + _ctx: HookContext, + ) -> Option { + let mut out = PreToolUseOutput::default(); + out.permission_decision = Some("allow".to_string()); + Some(out) + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let permission_log = Arc::new(Mutex::new(Vec::::new())); + let handler = Arc::new(PermissionLogger { + log: permission_log.clone(), + }); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config + .with_handler(handler) + .with_hooks(Arc::new(AllowAllHooks)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "List the files in the current directory using glob with pattern '*.md'.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!("\n--- Permission request log ---"); + let log = permission_log.lock().await; + for entry in log.iter() { + println!(" {entry}"); + } + println!("\nTotal permission requests: {}", log.len()); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/callbacks/permissions/verify.sh b/test/scenarios/callbacks/permissions/verify.sh index bc4af1f6a..e63438a6e 100755 --- a/test/scenarios/callbacks/permissions/verify.sh +++ b/test/scenarios/callbacks/permissions/verify.sh @@ -114,6 +114,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o permissions-go . # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -131,6 +133,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./permissions-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/callbacks/user-input/rust/Cargo.toml b/test/scenarios/callbacks/user-input/rust/Cargo.toml new file mode 100644 index 000000000..83430f128 --- /dev/null +++ b/test/scenarios/callbacks/user-input/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "user-input-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } diff --git a/test/scenarios/callbacks/user-input/rust/src/main.rs b/test/scenarios/callbacks/user-input/rust/src/main.rs new file mode 100644 index 000000000..b7fea906e --- /dev/null +++ b/test/scenarios/callbacks/user-input/rust/src/main.rs @@ -0,0 +1,103 @@ +//! User-input callback โ€” answer the agent's `ask_user` prompts and log +//! every question. + +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler, UserInputResponse}; +use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; +use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId}; +use github_copilot_sdk::{Client, ClientOptions}; +use tokio::sync::Mutex; + +struct InputResponder { + log: Arc>>, +} + +#[async_trait] +impl SessionHandler for InputResponder { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } + + async fn on_user_input( + &self, + _session_id: SessionId, + question: String, + _choices: Option>, + _allow_freeform: Option, + ) -> Option { + self.log + .lock() + .await + .push(format!("question: {question}")); + Some(UserInputResponse { + answer: "Paris".to_string(), + was_freeform: true, + }) + } +} + +struct AllowAllHooks; + +#[async_trait] +impl SessionHooks for AllowAllHooks { + async fn on_pre_tool_use( + &self, + _input: PreToolUseInput, + _ctx: HookContext, + ) -> Option { + let mut out = PreToolUseOutput::default(); + out.permission_decision = Some("allow".to_string()); + Some(out) + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let input_log = Arc::new(Mutex::new(Vec::::new())); + let handler = Arc::new(InputResponder { + log: input_log.clone(), + }); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.request_user_input = Some(true); + let config = config + .with_handler(handler) + .with_hooks(Arc::new(AllowAllHooks)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "I want to learn about a city. Use the ask_user tool to ask me \ + which city I'm interested in. Then tell me about that city.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!("\n--- User input log ---"); + let log = input_log.lock().await; + for entry in log.iter() { + println!(" {entry}"); + } + println!("\nTotal user input requests: {}", log.len()); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/callbacks/user-input/verify.sh b/test/scenarios/callbacks/user-input/verify.sh index 4550a4c1f..5e35eb67c 100755 --- a/test/scenarios/callbacks/user-input/verify.sh +++ b/test/scenarios/callbacks/user-input/verify.sh @@ -114,6 +114,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o user-input-go . 2 # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -131,6 +133,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./user-input-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/modes/default/rust/Cargo.toml b/test/scenarios/modes/default/rust/Cargo.toml new file mode 100644 index 000000000..d3483ec64 --- /dev/null +++ b/test/scenarios/modes/default/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "default-mode-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/modes/default/rust/src/main.rs b/test/scenarios/modes/default/rust/src/main.rs new file mode 100644 index 000000000..ba890997d --- /dev/null +++ b/test/scenarios/modes/default/rust/src/main.rs @@ -0,0 +1,36 @@ +//! Default agent mode โ€” the agent has access to built-in tools (grep, view, etc.) +//! and can use them to complete a task. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "Use the grep tool to search for the word 'SDK' in README.md and show the matching lines.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Response: {content}"); + } + } + + println!("Default mode test complete"); + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/modes/default/verify.sh b/test/scenarios/modes/default/verify.sh index 9d9b78578..e8811d0d9 100755 --- a/test/scenarios/modes/default/verify.sh +++ b/test/scenarios/modes/default/verify.sh @@ -107,6 +107,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o default-go . 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -125,6 +128,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./default-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/prompts/attachments/README.md b/test/scenarios/prompts/attachments/README.md index 76b76751d..2bdb551fb 100644 --- a/test/scenarios/prompts/attachments/README.md +++ b/test/scenarios/prompts/attachments/README.md @@ -34,12 +34,14 @@ Demonstrates sending **file attachments** alongside a prompt using the Copilot S | TypeScript | `attachments: [{ type: "file", path: sampleFile }]` | | Python | `"attachments": [{"type": "file", "path": sample_file}]` | | Go | `Attachments: []copilot.Attachment{{Type: "file", Path: sampleFile}}` | +| Rust | `Attachment::File { path, display_name: None, line_range: None }` | | Language | Blob Attachment Syntax | |----------|------------------------| | TypeScript | `attachments: [{ type: "blob", data: base64Data, mimeType: "image/png" }]` | | Python | `"attachments": [{"type": "blob", "data": base64_data, "mimeType": "image/png"}]` | | Go | `Attachments: []copilot.Attachment{{Type: copilot.AttachmentTypeBlob, Data: &data, MIMEType: &mime}}` | +| Rust | `Attachment::Blob { data, mime_type, display_name: None }` | ## Sample Data diff --git a/test/scenarios/prompts/attachments/rust/Cargo.toml b/test/scenarios/prompts/attachments/rust/Cargo.toml new file mode 100644 index 000000000..e87952f14 --- /dev/null +++ b/test/scenarios/prompts/attachments/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "attachments-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/prompts/attachments/rust/src/main.rs b/test/scenarios/prompts/attachments/rust/src/main.rs new file mode 100644 index 000000000..9ba9cc176 --- /dev/null +++ b/test/scenarios/prompts/attachments/rust/src/main.rs @@ -0,0 +1,58 @@ +//! File attachments โ€” send a prompt alongside a file attachment so the +//! model can reference the file's content in its response. + +use std::path::PathBuf; +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{Attachment, MessageOptions, SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const SYSTEM_PROMPT: &str = + "You are a helpful assistant. Answer questions about attached files concisely."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(SYSTEM_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + // CARGO_MANIFEST_DIR resolves to .../prompts/attachments/rust at compile time. + let sample_file: PathBuf = [env!("CARGO_MANIFEST_DIR"), "..", "sample-data.txt"] + .iter() + .collect(); + let sample_file = sample_file.canonicalize().unwrap_or(sample_file); + + let response = session + .send_and_wait( + MessageOptions::new("What languages are listed in the attached file?").with_attachments( + vec![Attachment::File { + path: sample_file, + display_name: None, + line_range: None, + }], + ), + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/prompts/attachments/verify.sh b/test/scenarios/prompts/attachments/verify.sh index cf4a91977..41b4f108c 100755 --- a/test/scenarios/prompts/attachments/verify.sh +++ b/test/scenarios/prompts/attachments/verify.sh @@ -110,6 +110,9 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o attachments-go . # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -127,6 +130,9 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./attachments-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/prompts/reasoning-effort/rust/Cargo.toml b/test/scenarios/prompts/reasoning-effort/rust/Cargo.toml new file mode 100644 index 000000000..c48db3c98 --- /dev/null +++ b/test/scenarios/prompts/reasoning-effort/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "reasoning-effort-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs new file mode 100644 index 000000000..bf1ab9720 --- /dev/null +++ b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs @@ -0,0 +1,40 @@ +//! Reasoning effort โ€” set the model's reasoning depth via +//! `SessionConfig::reasoning_effort`. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some("You are a helpful assistant. Answer concisely.".to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-opus-4.6".to_string()); + config.reasoning_effort = Some("low".to_string()); + config.available_tools = Some(Vec::new()); + config.system_message = Some(sysmsg); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Reasoning effort: low"); + println!("Response: {content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/prompts/reasoning-effort/verify.sh b/test/scenarios/prompts/reasoning-effort/verify.sh index fe528229e..4d32e4d87 100755 --- a/test/scenarios/prompts/reasoning-effort/verify.sh +++ b/test/scenarios/prompts/reasoning-effort/verify.sh @@ -110,6 +110,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o reasoning-effort- # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./reasoning-effort-g # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/prompts/system-message/rust/Cargo.toml b/test/scenarios/prompts/system-message/rust/Cargo.toml new file mode 100644 index 000000000..0d153f9cc --- /dev/null +++ b/test/scenarios/prompts/system-message/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "system-message-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/prompts/system-message/rust/src/main.rs b/test/scenarios/prompts/system-message/rust/src/main.rs new file mode 100644 index 000000000..4218a389b --- /dev/null +++ b/test/scenarios/prompts/system-message/rust/src/main.rs @@ -0,0 +1,40 @@ +//! Custom system message โ€” replace the built-in prompt entirely. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const PIRATE_PROMPT: &str = "You are a pirate. Always respond in pirate speak. Say 'Arrr!' \ +in every response. Use nautical terms and pirate slang throughout."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(PIRATE_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/prompts/system-message/verify.sh b/test/scenarios/prompts/system-message/verify.sh index c2699768b..d1f60e5c4 100755 --- a/test/scenarios/prompts/system-message/verify.sh +++ b/test/scenarios/prompts/system-message/verify.sh @@ -109,6 +109,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o system-message-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./system-message-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/sessions/concurrent-sessions/rust/Cargo.toml b/test/scenarios/sessions/concurrent-sessions/rust/Cargo.toml new file mode 100644 index 000000000..a6de4e273 --- /dev/null +++ b/test/scenarios/sessions/concurrent-sessions/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "concurrent-sessions-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs new file mode 100644 index 000000000..43932b613 --- /dev/null +++ b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs @@ -0,0 +1,53 @@ +//! Concurrent sessions โ€” two sessions on a single client running in +//! parallel with different system prompts. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const PIRATE_PROMPT: &str = "You are a pirate. Always say Arrr!"; +const ROBOT_PROMPT: &str = "You are a robot. Always say BEEP BOOP!"; + +fn make_config(system: &str) -> SessionConfig { + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(system.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + config.with_handler(Arc::new(ApproveAllHandler)) +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let session1 = client.create_session(make_config(PIRATE_PROMPT)).await?; + let session2 = client.create_session(make_config(ROBOT_PROMPT)).await?; + + let (r1, r2) = tokio::join!( + session1.send_and_wait("What is the capital of France?"), + session2.send_and_wait("What is the capital of France?"), + ); + + if let Some(event) = r1? { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Session 1 (pirate): {content}"); + } + } + if let Some(event) = r2? { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Session 2 (robot): {content}"); + } + } + + session1.destroy().await?; + session2.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/sessions/concurrent-sessions/verify.sh b/test/scenarios/sessions/concurrent-sessions/verify.sh index be4e3d309..25e6fab18 100755 --- a/test/scenarios/sessions/concurrent-sessions/verify.sh +++ b/test/scenarios/sessions/concurrent-sessions/verify.sh @@ -138,6 +138,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o concurrent-sessio # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -155,6 +157,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./concurrent-session # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/sessions/infinite-sessions/rust/Cargo.toml b/test/scenarios/sessions/infinite-sessions/rust/Cargo.toml new file mode 100644 index 000000000..1f23af8a6 --- /dev/null +++ b/test/scenarios/sessions/infinite-sessions/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "infinite-sessions-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs new file mode 100644 index 000000000..0c0f06814 --- /dev/null +++ b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs @@ -0,0 +1,55 @@ +//! Infinite sessions โ€” explicit `InfiniteSessionConfig` thresholds and a +//! sequence of three turns to exercise the persistent workspace. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{InfiniteSessionConfig, SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = + Some("You are a helpful assistant. Answer concisely in one sentence.".to_string()); + + let mut infinite = InfiniteSessionConfig::default(); + infinite.enabled = Some(true); + infinite.background_compaction_threshold = Some(0.80); + infinite.buffer_exhaustion_threshold = Some(0.95); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.available_tools = Some(Vec::new()); + config.system_message = Some(sysmsg); + config.infinite_sessions = Some(infinite); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let prompts = [ + "What is the capital of France?", + "What is the capital of Japan?", + "What is the capital of Brazil?", + ]; + + for prompt in prompts { + let response = session.send_and_wait(prompt).await?; + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Q: {prompt}"); + println!("A: {content}\n"); + } + } + } + + println!("Infinite sessions test complete โ€” all messages processed successfully"); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/sessions/infinite-sessions/verify.sh b/test/scenarios/sessions/infinite-sessions/verify.sh index fe4de01e4..367901f28 100755 --- a/test/scenarios/sessions/infinite-sessions/verify.sh +++ b/test/scenarios/sessions/infinite-sessions/verify.sh @@ -116,6 +116,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o infinite-sessions # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -133,6 +135,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./infinite-sessions- # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/sessions/session-resume/rust/Cargo.toml b/test/scenarios/sessions/session-resume/rust/Cargo.toml new file mode 100644 index 000000000..ed6207260 --- /dev/null +++ b/test/scenarios/sessions/session-resume/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "session-resume-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/session-resume/rust/src/main.rs b/test/scenarios/sessions/session-resume/rust/src/main.rs new file mode 100644 index 000000000..10cd4fa62 --- /dev/null +++ b/test/scenarios/sessions/session-resume/rust/src/main.rs @@ -0,0 +1,46 @@ +//! Session resume โ€” create a session, plant a memory, then resume by ID +//! and verify the agent recalls it. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{ResumeSessionConfig, SessionConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + session + .send_and_wait("Remember this: the secret word is PINEAPPLE.") + .await?; + + let session_id = session.id().clone(); + // Note: do NOT destroy โ€” `resume_session` needs the session to persist. + + let resume_config = + ResumeSessionConfig::new(session_id).with_handler(Arc::new(ApproveAllHandler)); + let resumed = client.resume_session(resume_config).await?; + println!("Session resumed"); + + let response = resumed + .send_and_wait("What was the secret word I told you?") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + resumed.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/sessions/session-resume/verify.sh b/test/scenarios/sessions/session-resume/verify.sh index 02cc14d5a..07a5992e9 100755 --- a/test/scenarios/sessions/session-resume/verify.sh +++ b/test/scenarios/sessions/session-resume/verify.sh @@ -117,6 +117,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o session-resume-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -135,6 +137,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./session-resume-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/sessions/streaming/rust/Cargo.toml b/test/scenarios/sessions/streaming/rust/Cargo.toml new file mode 100644 index 000000000..31acc381b --- /dev/null +++ b/test/scenarios/sessions/streaming/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "streaming-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/streaming/rust/src/main.rs b/test/scenarios/sessions/streaming/rust/src/main.rs new file mode 100644 index 000000000..f5cf23764 --- /dev/null +++ b/test/scenarios/sessions/streaming/rust/src/main.rs @@ -0,0 +1,66 @@ +//! Streaming session โ€” count `assistant.message_delta` events while waiting +//! for the final response. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler}; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; + +struct StreamCounter { + chunks: Arc, +} + +#[async_trait] +impl SessionHandler for StreamCounter { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { event, .. } => { + if event.event_type == "assistant.message_delta" { + self.chunks.fetch_add(1, Ordering::Relaxed); + } + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Approved) + } + _ => HandlerResponse::Ok, + } + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let chunks = Arc::new(AtomicUsize::new(0)); + let handler = Arc::new(StreamCounter { + chunks: chunks.clone(), + }); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.streaming = Some(true); + let config = config.with_handler(handler); + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!( + "\nStreaming chunks received: {}", + chunks.load(Ordering::Relaxed) + ); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/sessions/streaming/verify.sh b/test/scenarios/sessions/streaming/verify.sh index 070ef059b..828f42a43 100755 --- a/test/scenarios/sessions/streaming/verify.sh +++ b/test/scenarios/sessions/streaming/verify.sh @@ -114,6 +114,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o streaming-go . 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -132,6 +135,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./streaming-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/tools/custom-agents/rust/Cargo.toml b/test/scenarios/tools/custom-agents/rust/Cargo.toml new file mode 100644 index 000000000..6d536052c --- /dev/null +++ b/test/scenarios/tools/custom-agents/rust/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "custom-agents-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust", features = ["derive"] } +schemars = "1" +serde = { version = "1", features = ["derive"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/custom-agents/rust/src/main.rs b/test/scenarios/tools/custom-agents/rust/src/main.rs new file mode 100644 index 000000000..e707770bc --- /dev/null +++ b/test/scenarios/tools/custom-agents/rust/src/main.rs @@ -0,0 +1,82 @@ +//! Custom agents โ€” define a sub-agent ("researcher") with its own prompt +//! and tool allowlist, alongside a client-defined `analyze-codebase` tool. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::types::{CustomAgentConfig, DefaultAgentConfig, SessionConfig, ToolResult}; +use github_copilot_sdk::{Client, ClientOptions}; +use schemars::JsonSchema; +use serde::Deserialize; + +#[derive(Deserialize, JsonSchema)] +#[schemars(description = "Parameters for analyze-codebase")] +struct AnalyzeParams { + /// the analysis query + query: String, +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let analyze_codebase = define_tool( + "analyze-codebase", + "Performs deep analysis of the codebase", + |_inv, params: AnalyzeParams| async move { + Ok(ToolResult::Text(format!( + "Analysis result for: {}", + params.query + ))) + }, + ); + + let router = ToolHandlerRouter::new(vec![analyze_codebase], Arc::new(ApproveAllHandler)); + let tools = router.tools(); + + let mut researcher = CustomAgentConfig::default(); + researcher.name = "researcher".to_string(); + researcher.display_name = Some("Research Agent".to_string()); + researcher.description = Some( + "A research agent that can only read and search files, not modify them".to_string(), + ); + researcher.tools = Some(vec![ + "grep".to_string(), + "glob".to_string(), + "view".to_string(), + "analyze-codebase".to_string(), + ]); + researcher.prompt = + "You are a research assistant. You can search and read files but cannot modify \ + anything. When asked about your capabilities, list the tools you have access to." + .to_string(); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.tools = Some(tools); + config.default_agent = Some(DefaultAgentConfig { + excluded_tools: Some(vec!["analyze-codebase".to_string()]), + }); + config.custom_agents = Some(vec![researcher]); + let config = config.with_handler(Arc::new(router)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "What custom agents are available? Describe the researcher agent and its capabilities.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/custom-agents/verify.sh b/test/scenarios/tools/custom-agents/verify.sh index 826f9df9d..4d295b47f 100755 --- a/test/scenarios/tools/custom-agents/verify.sh +++ b/test/scenarios/tools/custom-agents/verify.sh @@ -109,6 +109,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o custom-agents-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./custom-agents-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/tools/mcp-servers/rust/Cargo.toml b/test/scenarios/tools/mcp-servers/rust/Cargo.toml new file mode 100644 index 000000000..84c40e3be --- /dev/null +++ b/test/scenarios/tools/mcp-servers/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "mcp-servers-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +serde_json = "1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/mcp-servers/rust/src/main.rs b/test/scenarios/tools/mcp-servers/rust/src/main.rs new file mode 100644 index 000000000..fd76147a1 --- /dev/null +++ b/test/scenarios/tools/mcp-servers/rust/src/main.rs @@ -0,0 +1,68 @@ +//! MCP servers โ€” configure an MCP server from env and pass it through to +//! the CLI via `SessionConfig::mcp_servers`. Build-only when +//! `MCP_SERVER_CMD` is unset. + +use std::collections::HashMap; +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{ + McpServerConfig, McpStdioServerConfig, SessionConfig, SystemMessageConfig, +}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mcp_cmd = std::env::var("MCP_SERVER_CMD").ok(); + let mcp_args_env = std::env::var("MCP_SERVER_ARGS").ok(); + let mcp_servers = mcp_cmd.as_ref().map(|cmd| { + let args: Vec = mcp_args_env + .as_deref() + .map(|s| s.split(' ').map(str::to_string).collect()) + .unwrap_or_default(); + let stdio = McpStdioServerConfig { + tools: vec!["*".to_string()], + command: cmd.clone(), + args, + ..Default::default() + }; + let mut map = HashMap::new(); + map.insert("example".to_string(), McpServerConfig::Stdio(stdio)); + map + }); + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = + Some("You are a helpful assistant. Answer questions concisely.".to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + config.mcp_servers = mcp_servers; + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + if mcp_cmd.is_some() { + println!("\nMCP servers configured: example"); + } else { + println!("\nNo MCP servers configured (set MCP_SERVER_CMD to test with a real server)"); + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/mcp-servers/verify.sh b/test/scenarios/tools/mcp-servers/verify.sh index b087e0625..abde4508e 100755 --- a/test/scenarios/tools/mcp-servers/verify.sh +++ b/test/scenarios/tools/mcp-servers/verify.sh @@ -105,6 +105,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o mcp-servers-go . # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -123,6 +125,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./mcp-servers-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/tools/no-tools/rust/Cargo.toml b/test/scenarios/tools/no-tools/rust/Cargo.toml new file mode 100644 index 000000000..461469946 --- /dev/null +++ b/test/scenarios/tools/no-tools/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "no-tools-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/no-tools/rust/src/main.rs b/test/scenarios/tools/no-tools/rust/src/main.rs new file mode 100644 index 000000000..691ac47ed --- /dev/null +++ b/test/scenarios/tools/no-tools/rust/src/main.rs @@ -0,0 +1,44 @@ +//! No-tools session โ€” replace the system prompt and empty the available tools +//! list so the agent cannot execute code, read files, or call any built-ins. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const SYSTEM_PROMPT: &str = "You are a minimal assistant with no tools available. +You cannot execute code, read files, edit files, search, or perform any actions. +You can only respond with text based on your training data. +If asked about your capabilities or tools, clearly state that you have no tools available."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(SYSTEM_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + let response = session + .send_and_wait("Use the bash tool to run 'echo hello'.") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/no-tools/verify.sh b/test/scenarios/tools/no-tools/verify.sh index 1223c7dcc..286796b70 100755 --- a/test/scenarios/tools/no-tools/verify.sh +++ b/test/scenarios/tools/no-tools/verify.sh @@ -107,6 +107,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o no-tools-go . 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -125,6 +128,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./no-tools-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/tools/skills/rust/Cargo.toml b/test/scenarios/tools/skills/rust/Cargo.toml new file mode 100644 index 000000000..c2de4b20e --- /dev/null +++ b/test/scenarios/tools/skills/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "skills-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/skills/rust/src/main.rs b/test/scenarios/tools/skills/rust/src/main.rs new file mode 100644 index 000000000..845704fac --- /dev/null +++ b/test/scenarios/tools/skills/rust/src/main.rs @@ -0,0 +1,62 @@ +//! Skills โ€” point the CLI at a directory of user-defined skills via +//! `SessionConfig::skill_directories`. + +use std::path::PathBuf; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; + +struct AllowAllHooks; + +#[async_trait] +impl SessionHooks for AllowAllHooks { + async fn on_pre_tool_use( + &self, + _input: PreToolUseInput, + _ctx: HookContext, + ) -> Option { + let mut out = PreToolUseOutput::default(); + out.permission_decision = Some("allow".to_string()); + Some(out) + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + // CARGO_MANIFEST_DIR resolves to .../tools/skills/rust at compile time. + let skills_dir: PathBuf = [env!("CARGO_MANIFEST_DIR"), "..", "sample-skills"] + .iter() + .collect(); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.skill_directories = Some(vec![skills_dir]); + let config = config + .with_handler(Arc::new(ApproveAllHandler)) + .with_hooks(Arc::new(AllowAllHooks)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait("Use the greeting skill to greet someone named Alice.") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!("\nSkill directories configured successfully"); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/skills/verify.sh b/test/scenarios/tools/skills/verify.sh index fb13fcb16..6d1881173 100755 --- a/test/scenarios/tools/skills/verify.sh +++ b/test/scenarios/tools/skills/verify.sh @@ -108,6 +108,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o skills-go . 2>&1" # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -125,6 +127,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./skills-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/tools/tool-filtering/rust/Cargo.toml b/test/scenarios/tools/tool-filtering/rust/Cargo.toml new file mode 100644 index 000000000..88e38073d --- /dev/null +++ b/test/scenarios/tools/tool-filtering/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tool-filtering-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/tool-filtering/rust/src/main.rs b/test/scenarios/tools/tool-filtering/rust/src/main.rs new file mode 100644 index 000000000..edc203550 --- /dev/null +++ b/test/scenarios/tools/tool-filtering/rust/src/main.rs @@ -0,0 +1,47 @@ +//! Tool filtering โ€” restrict the agent to a subset of built-in tools via +//! `SessionConfig::available_tools`. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const SYSTEM_PROMPT: &str = "You are a helpful assistant. You have access to a limited set \ +of tools. When asked about your tools, list exactly which tools you have available."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(SYSTEM_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(vec![ + "grep".to_string(), + "glob".to_string(), + "view".to_string(), + ]); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait("What tools do you have available? List each one by name.") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/tool-filtering/verify.sh b/test/scenarios/tools/tool-filtering/verify.sh index 058b7129e..d73377718 100755 --- a/test/scenarios/tools/tool-filtering/verify.sh +++ b/test/scenarios/tools/tool-filtering/verify.sh @@ -119,6 +119,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o tool-filtering-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -137,6 +139,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./tool-filtering-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/tools/tool-overrides/rust/Cargo.toml b/test/scenarios/tools/tool-overrides/rust/Cargo.toml new file mode 100644 index 000000000..f3b9d6aef --- /dev/null +++ b/test/scenarios/tools/tool-overrides/rust/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "tool-overrides-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust", features = ["derive"] } +schemars = "1" +serde = { version = "1", features = ["derive"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/tool-overrides/rust/src/main.rs b/test/scenarios/tools/tool-overrides/rust/src/main.rs new file mode 100644 index 000000000..ce002a27d --- /dev/null +++ b/test/scenarios/tools/tool-overrides/rust/src/main.rs @@ -0,0 +1,61 @@ +//! Tool overrides โ€” replace the built-in `grep` tool with a custom +//! implementation that returns a distinct marker. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::types::{SessionConfig, ToolResult}; +use github_copilot_sdk::{Client, ClientOptions}; +use schemars::JsonSchema; +use serde::Deserialize; + +#[derive(Deserialize, JsonSchema)] +#[schemars(description = "Parameters for custom grep")] +struct GrepParams { + /// Search query + query: String, +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let grep_tool = define_tool( + "grep", + "A custom grep implementation that overrides the built-in", + |_inv, params: GrepParams| async move { + Ok(ToolResult::Text(format!("CUSTOM_GREP_RESULT: {}", params.query))) + }, + ); + + let router = ToolHandlerRouter::new(vec![grep_tool], Arc::new(ApproveAllHandler)); + let mut tools = router.tools(); + for t in tools.iter_mut() { + if t.name == "grep" { + t.overrides_built_in_tool = true; + } + } + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.tools = Some(tools); + let config = config.with_handler(Arc::new(router)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait("Use grep to search for the word 'hello'") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/tool-overrides/verify.sh b/test/scenarios/tools/tool-overrides/verify.sh index b7687de50..cf9b34d51 100755 --- a/test/scenarios/tools/tool-overrides/verify.sh +++ b/test/scenarios/tools/tool-overrides/verify.sh @@ -109,6 +109,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o tool-overrides-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./tool-overrides-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/transport/stdio/README.md b/test/scenarios/transport/stdio/README.md index 5178935cc..7de2457ec 100644 --- a/test/scenarios/transport/stdio/README.md +++ b/test/scenarios/transport/stdio/README.md @@ -23,6 +23,7 @@ Each sample follows the same flow: | `typescript/` | `@github/copilot-sdk` | TypeScript (Node.js) | | `python/` | `github-copilot-sdk` | Python | | `go/` | `github.com/github/copilot-sdk/go` | Go | +| `rust/` | `copilot-sdk` | Rust | ## Prerequisites diff --git a/test/scenarios/transport/stdio/rust/Cargo.toml b/test/scenarios/transport/stdio/rust/Cargo.toml new file mode 100644 index 000000000..aa22474c0 --- /dev/null +++ b/test/scenarios/transport/stdio/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "stdio-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/transport/stdio/rust/src/main.rs b/test/scenarios/transport/stdio/rust/src/main.rs new file mode 100644 index 000000000..156b3587d --- /dev/null +++ b/test/scenarios/transport/stdio/rust/src/main.rs @@ -0,0 +1,30 @@ +//! Stdio transport โ€” spawn the CLI as a child and exchange JSON-RPC over its stdio. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/transport/stdio/verify.sh b/test/scenarios/transport/stdio/verify.sh index 9a5b11b17..f9f004675 100755 --- a/test/scenarios/transport/stdio/verify.sh +++ b/test/scenarios/transport/stdio/verify.sh @@ -104,6 +104,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o stdio-go . 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -122,6 +125,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./stdio-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/transport/tcp/rust/Cargo.toml b/test/scenarios/transport/tcp/rust/Cargo.toml new file mode 100644 index 000000000..fe5d19a91 --- /dev/null +++ b/test/scenarios/transport/tcp/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tcp-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/transport/tcp/rust/src/main.rs b/test/scenarios/transport/tcp/rust/src/main.rs new file mode 100644 index 000000000..49691c1b2 --- /dev/null +++ b/test/scenarios/transport/tcp/rust/src/main.rs @@ -0,0 +1,43 @@ +//! TCP transport โ€” connect to an externally-running CLI server. Reads +//! `COPILOT_CLI_URL` (default `localhost:3000`) for `host:port`. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions, Transport}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let cli_url = + std::env::var("COPILOT_CLI_URL").unwrap_or_else(|_| "localhost:3000".to_string()); + let (host, port_str) = cli_url + .split_once(':') + .expect("COPILOT_CLI_URL must be 'host:port'"); + let port: u16 = port_str.parse().expect("COPILOT_CLI_URL port must be u16"); + + let mut opts = ClientOptions::default(); + opts.transport = Transport::External { + host: host.to_string(), + port, + }; + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/transport/tcp/verify.sh b/test/scenarios/transport/tcp/verify.sh index 711e0959a..fd30b98f9 100755 --- a/test/scenarios/transport/tcp/verify.sh +++ b/test/scenarios/transport/tcp/verify.sh @@ -163,6 +163,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o tcp-go . 2>&1" # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" @@ -181,6 +183,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./tcp-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" diff --git a/test/scenarios/verify.sh b/test/scenarios/verify.sh index 543c93d2b..7b6b066a0 100755 --- a/test/scenarios/verify.sh +++ b/test/scenarios/verify.sh @@ -43,12 +43,13 @@ TOTAL=${#VERIFY_SCRIPTS[@]} # โ”€โ”€ SDK icon helpers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ sdk_icons() { local log="$1" - local ts py go cs + local ts py go cs rs ts="$(sdk_status "$log" "TypeScript")" py="$(sdk_status "$log" "Python")" go="$(sdk_status "$log" "Go ")" cs="$(sdk_status "$log" "C#")" - printf "TS %s PY %s GO %s C# %s" "$ts" "$py" "$go" "$cs" + rs="$(sdk_status "$log" "Rust")" + printf "TS %s PY %s GO %s C# %s RS %s" "$ts" "$py" "$go" "$cs" "$rs" } sdk_status() {