diff --git a/.gitignore b/.gitignore index 6ff86481d..a445051c6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ # Documentation validation output docs/.validation/ .DS_Store + +# Visual Studio +.vs/ diff --git a/dotnet/.gitignore b/dotnet/.gitignore index ef38c1ee2..870a409f5 100644 --- a/dotnet/.gitignore +++ b/dotnet/.gitignore @@ -16,7 +16,6 @@ src/build/GitHub.Copilot.SDK.props *.sln.docstates # IDE -.vs/ .vscode/ *.swp *~ diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 5447fee51..822b36c93 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -10,7 +10,6 @@ - diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index c1fa14f81..791f70d45 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -5,8 +5,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using StreamJsonRpc; -using StreamJsonRpc.Protocol; using System.Collections.Concurrent; using System.Data; using System.Diagnostics; @@ -80,7 +78,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private readonly List> _lifecycleHandlers = []; private readonly Dictionary>> _typedLifecycleHandlers = []; private readonly object _lifecycleHandlersLock = new(); - private ServerRpc? _rpc; + private ServerRpc? _serverRpc; /// /// Gets the typed RPC client for server-scoped methods (no session required). @@ -92,7 +90,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable /// Thrown if the client is not started. public ServerRpc Rpc => _disposed ? throw new ObjectDisposedException(nameof(CopilotClient)) - : _rpc ?? throw new InvalidOperationException("Client is not started. Call StartAsync first."); + : _serverRpc ?? throw new InvalidOperationException("Client is not started. Call StartAsync first."); /// /// Gets the actual TCP port the CLI server is listening on, if using TCP transport. @@ -341,7 +339,7 @@ private async Task CleanupConnectionAsync(List? errors) catch (Exception ex) { errors?.Add(ex); } // Clear RPC and models cache - _rpc = null; + _serverRpc = null; _modelsCache = null; if (ctx.NetworkStream is not null) @@ -350,12 +348,6 @@ private async Task CleanupConnectionAsync(List? errors) catch (Exception ex) { errors?.Add(ex); } } - if (ctx.TcpClient is not null) - { - try { ctx.TcpClient.Dispose(); } - catch (Exception ex) { errors?.Add(ex); } - } - if (ctx.CliProcess is { } childProcess) { try @@ -1059,9 +1051,9 @@ internal static async Task InvokeRpcAsync(JsonRpc rpc, string method, obje { try { - return await rpc.InvokeWithCancellationAsync(method, args, cancellationToken); + return await rpc.InvokeAsync(method, args, cancellationToken); } - catch (StreamJsonRpc.ConnectionLostException ex) + catch (ConnectionLostException ex) { string? stderrOutput = null; if (stderrBuffer is not null) @@ -1078,7 +1070,7 @@ internal static async Task InvokeRpcAsync(JsonRpc rpc, string method, obje } throw new IOException($"Communication error with Copilot CLI: {ex.Message}", ex); } - catch (StreamJsonRpc.RemoteRpcException ex) + catch (RemoteRpcException ex) { throw new IOException($"Communication error with Copilot CLI: {ex.Message}", ex); } @@ -1329,12 +1321,15 @@ private static (string FileName, IEnumerable Args) ResolveCliCommand(str private async Task ConnectToServerAsync(Process? cliProcess, string? tcpHost, int? tcpPort, StringBuilder? stderrBuffer, CancellationToken cancellationToken) { Stream inputStream, outputStream; - TcpClient? tcpClient = null; NetworkStream? networkStream = null; if (_options.UseStdio) { - if (cliProcess == null) throw new InvalidOperationException("CLI process not started"); + if (cliProcess == null) + { + throw new InvalidOperationException("CLI process not started"); + } + inputStream = cliProcess.StandardOutput.BaseStream; outputStream = cliProcess.StandardInput.BaseStream; } @@ -1345,33 +1340,38 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? throw new InvalidOperationException("Cannot connect because TCP host or port are not available"); } - tcpClient = new(); - await tcpClient.ConnectAsync(tcpHost, tcpPort.Value, cancellationToken); - networkStream = tcpClient.GetStream(); - inputStream = networkStream; - outputStream = networkStream; + var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + try + { + await socket.ConnectAsync(tcpHost, tcpPort.Value, cancellationToken); + } + catch + { + socket.Dispose(); + throw; + } + + inputStream = outputStream = networkStream = new NetworkStream(socket, ownsSocket: true); } - var rpc = new JsonRpc(new HeaderDelimitedMessageHandler( + var rpc = new JsonRpc( outputStream, inputStream, - CreateSystemTextJsonFormatter())) - { - TraceSource = new LoggerTraceSource(_logger), - }; + SerializerOptionsForMessageFormatter, + _logger); var handler = new RpcHandler(this); - rpc.AddLocalRpcMethod("session.event", handler.OnSessionEvent); - rpc.AddLocalRpcMethod("session.lifecycle", handler.OnSessionLifecycle); + rpc.SetLocalRpcMethod("session.event", handler.OnSessionEvent); + rpc.SetLocalRpcMethod("session.lifecycle", handler.OnSessionLifecycle); // Protocol v3 servers send tool calls / permission requests as broadcast events. // Protocol v2 servers use the older tool.call / permission.request RPC model. // We always register v2 adapters because handlers are set up before version // negotiation; a v3 server will simply never send these requests. - rpc.AddLocalRpcMethod("tool.call", handler.OnToolCallV2); - rpc.AddLocalRpcMethod("permission.request", handler.OnPermissionRequestV2); - rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest); - rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); - rpc.AddLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); + rpc.SetLocalRpcMethod("tool.call", handler.OnToolCallV2); + rpc.SetLocalRpcMethod("permission.request", handler.OnPermissionRequestV2); + rpc.SetLocalRpcMethod("userInput.request", handler.OnUserInputRequest); + rpc.SetLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); + rpc.SetLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); ClientSessionApiRegistration.RegisterClientSessionApiHandlers(rpc, sessionId => { var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); @@ -1380,18 +1380,11 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? rpc.StartListening(); // Transition state to Disconnected if the JSON-RPC connection drops - _ = rpc.Completion.ContinueWith(_ => _disconnected = true, TaskScheduler.Default); + _ = rpc.Completion.ContinueWith(_ => _disconnected = true, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); - _rpc = new ServerRpc(rpc); + _serverRpc = new ServerRpc(rpc); - return new Connection(rpc, cliProcess, tcpClient, networkStream, stderrBuffer); - } - - [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Using happy path from https://microsoft.github.io/vs-streamjsonrpc/docs/nativeAOT.html")] - [UnconditionalSuppressMessage("AOT", "IL3050", Justification = "Using happy path from https://microsoft.github.io/vs-streamjsonrpc/docs/nativeAOT.html")] - private static SystemTextJsonFormatter CreateSystemTextJsonFormatter() - { - return new() { JsonSerializerOptions = SerializerOptionsForMessageFormatter }; + return new Connection(rpc, cliProcess, networkStream, stderrBuffer); } private static JsonSerializerOptions SerializerOptionsForMessageFormatter { get; } = CreateSerializerOptions(); @@ -1410,12 +1403,6 @@ private static JsonSerializerOptions CreateSerializerOptions() options.TypeInfoResolverChain.Add(SessionEventsJsonContext.Default); options.TypeInfoResolverChain.Add(SDK.Rpc.RpcJsonContext.Default); - // StreamJsonRpc's RequestId needs serialization when CancellationToken fires during - // JSON-RPC operations. Its built-in converter (RequestIdSTJsonConverter) is internal, - // and [JsonSerializable] can't source-gen for it (SYSLIB1220), so we provide our own - // AOT-safe resolver + converter. - options.TypeInfoResolverChain.Add(new RequestIdTypeInfoResolver()); - options.MakeReadOnly(); return options; @@ -1484,7 +1471,7 @@ public void OnSessionLifecycle(string type, string sessionId, JsonElement? metad client.DispatchLifecycleEvent(evt); } - public async Task OnUserInputRequest(string sessionId, string question, IList? choices = null, bool? allowFreeform = null) + public async ValueTask OnUserInputRequest(string sessionId, string question, IList? choices = null, bool? allowFreeform = null) { var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); var request = new UserInputRequest @@ -1498,14 +1485,14 @@ public async Task OnUserInputRequest(string sessionId, return new UserInputRequestResponse(result.Answer, result.WasFreeform); } - public async Task OnHooksInvoke(string sessionId, string hookType, JsonElement input) + public async ValueTask OnHooksInvoke(string sessionId, string hookType, JsonElement input) { var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); var output = await session.HandleHooksInvokeAsync(hookType, input); return new HooksInvokeResponse(output); } - public async Task OnSystemMessageTransform(string sessionId, JsonElement sections) + public async ValueTask OnSystemMessageTransform(string sessionId, JsonElement sections) { var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); return await session.HandleSystemMessageTransformAsync(sections); @@ -1513,7 +1500,7 @@ public async Task OnSystemMessageTransform(st // Protocol v2 backward-compatibility adapters - public async Task OnToolCallV2(string sessionId, + public async ValueTask OnToolCallV2(string sessionId, string toolCallId, string toolName, object? arguments, @@ -1580,7 +1567,7 @@ public async Task OnToolCallV2(string sessionId, } } - public async Task OnPermissionRequestV2(string sessionId, JsonElement permissionRequest) + public async ValueTask OnPermissionRequestV2(string sessionId, JsonElement permissionRequest) { var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); @@ -1611,12 +1598,10 @@ public async Task OnPermissionRequestV2(string sess private class Connection( JsonRpc rpc, Process? cliProcess, // Set if we created the child process - TcpClient? tcpClient, // Set if using TCP NetworkStream? networkStream, // Set if using TCP StringBuilder? stderrBuffer = null) // Captures stderr for error messages { public Process? CliProcess => cliProcess; - public TcpClient? TcpClient => tcpClient; public JsonRpc Rpc => rpc; public NetworkStream? NetworkStream => networkStream; public StringBuilder? StderrBuffer => stderrBuffer; @@ -1770,90 +1755,11 @@ internal record ToolCallResponseV2( internal record PermissionRequestResponseV2( PermissionRequestResult Result); - /// Trace source that forwards all logs to the ILogger. - internal sealed class LoggerTraceSource : TraceSource - { - public LoggerTraceSource(ILogger logger) : base(nameof(LoggerTraceSource), SourceLevels.All) - { - Listeners.Clear(); - Listeners.Add(new LoggerTraceListener(logger)); - } - - private sealed class LoggerTraceListener(ILogger logger) : TraceListener - { - public override void TraceEvent(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, string? message) - { - LogLevel level = MapLevel(eventType); - if (logger.IsEnabled(level)) - { - logger.Log(level, "[{Source}] {Message}", source, message); - } - } - - public override void TraceEvent(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, string? format, params object?[]? args) - { - LogLevel level = MapLevel(eventType); - if (logger.IsEnabled(level)) - { - logger.Log(level, "[{Source}] {Message}", source, args is null || args.Length == 0 ? format : string.Format(CultureInfo.InvariantCulture, format ?? "", args)); - } - } - - public override void TraceData(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, object? data) - { - LogLevel level = MapLevel(eventType); - if (logger.IsEnabled(level)) - { - logger.Log(level, "[{Source}] {Data}", source, data); - } - } - - public override void TraceData(TraceEventCache? eventCache, string source, TraceEventType eventType, int id, params object?[]? data) - { - LogLevel level = MapLevel(eventType); - if (logger.IsEnabled(level)) - { - logger.Log(level, "[{Source}] {Data}", source, data is null ? null : string.Join(", ", data)); - } - } - - public override void Write(string? message) - { - if (logger.IsEnabled(LogLevel.Trace)) - { - logger.LogTrace("{Message}", message); - } - } - - public override void WriteLine(string? message) - { - if (logger.IsEnabled(LogLevel.Trace)) - { - logger.LogTrace("{Message}", message); - } - } - - private static LogLevel MapLevel(TraceEventType eventType) - { - return eventType switch - { - TraceEventType.Critical => LogLevel.Critical, - TraceEventType.Error => LogLevel.Error, - TraceEventType.Warning => LogLevel.Warning, - TraceEventType.Information => LogLevel.Information, - TraceEventType.Verbose => LogLevel.Debug, - _ => LogLevel.Trace - }; - } - } - } - [JsonSourceGenerationOptions( JsonSerializerDefaults.Web, AllowOutOfOrderMetadataProperties = true, NumberHandling = JsonNumberHandling.AllowReadingFromString, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] - [JsonSerializable(typeof(CommonErrorData))] [JsonSerializable(typeof(CreateSessionRequest))] [JsonSerializable(typeof(CreateSessionResponse))] [JsonSerializable(typeof(CustomAgentConfig))] @@ -1887,50 +1793,6 @@ private static LogLevel MapLevel(TraceEventType eventType) [JsonSerializable(typeof(UserInputResponse))] internal partial class ClientJsonContext : JsonSerializerContext; - /// - /// AOT-safe type info resolver for . - /// StreamJsonRpc's own RequestIdSTJsonConverter is internal (SYSLIB1220/CS0122), - /// so we provide our own converter and wire it through - /// to stay fully AOT/trimming-compatible. - /// - private sealed class RequestIdTypeInfoResolver : IJsonTypeInfoResolver - { - public JsonTypeInfo? GetTypeInfo(Type type, JsonSerializerOptions options) - { - if (type == typeof(RequestId)) - return JsonMetadataServices.CreateValueInfo(options, new RequestIdJsonConverter()); - return null; - } - } - - private sealed class RequestIdJsonConverter : JsonConverter - { - public override RequestId Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - return reader.TokenType switch - { - JsonTokenType.Number => reader.TryGetInt64(out long val) - ? new RequestId(val) - : new RequestId(reader.HasValueSequence - ? Encoding.UTF8.GetString(reader.ValueSequence) - : Encoding.UTF8.GetString(reader.ValueSpan)), - JsonTokenType.String => new RequestId(reader.GetString()!), - JsonTokenType.Null => RequestId.Null, - _ => throw new JsonException($"Unexpected token type for RequestId: {reader.TokenType}"), - }; - } - - public override void Write(Utf8JsonWriter writer, RequestId value, JsonSerializerOptions options) - { - if (value.Number.HasValue) - writer.WriteNumberValue(value.Number.Value); - else if (value.String is not null) - writer.WriteStringValue(value.String); - else - writer.WriteNullValue(); - } - } - [GeneratedRegex(@"listening on port ([0-9]+)", RegexOptions.IgnoreCase)] private static partial Regex ListeningOnPortRegex(); } diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index fce4b4708..f90d836c9 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -12,7 +12,6 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; -using StreamJsonRpc; namespace GitHub.Copilot.SDK.Rpc; @@ -4096,7 +4095,7 @@ public sealed class ClientSessionApiHandlers } /// Registers client session API handlers on a JSON-RPC connection. -public static class ClientSessionApiRegistration +internal static class ClientSessionApiRegistration { /// /// Registers handlers for server-to-client session API calls. @@ -4105,106 +4104,66 @@ public static class ClientSessionApiRegistration /// public static void RegisterClientSessionApiHandlers(JsonRpc rpc, Func getHandlers) { - var registerSessionFsReadFileMethod = (Func>)(async (request, cancellationToken) => + rpc.SetLocalRpcMethod("sessionFs.readFile", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.ReadFileAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsReadFileMethod.Method, registerSessionFsReadFileMethod.Target!, new JsonRpcMethodAttribute("sessionFs.readFile") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsWriteFileMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.writeFile", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.WriteFileAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsWriteFileMethod.Method, registerSessionFsWriteFileMethod.Target!, new JsonRpcMethodAttribute("sessionFs.writeFile") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsAppendFileMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.appendFile", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.AppendFileAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsAppendFileMethod.Method, registerSessionFsAppendFileMethod.Target!, new JsonRpcMethodAttribute("sessionFs.appendFile") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsExistsMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.exists", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.ExistsAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsExistsMethod.Method, registerSessionFsExistsMethod.Target!, new JsonRpcMethodAttribute("sessionFs.exists") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsStatMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.stat", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.StatAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsStatMethod.Method, registerSessionFsStatMethod.Target!, new JsonRpcMethodAttribute("sessionFs.stat") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsMkdirMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.mkdir", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.MkdirAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsMkdirMethod.Method, registerSessionFsMkdirMethod.Target!, new JsonRpcMethodAttribute("sessionFs.mkdir") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsReaddirMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.readdir", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.ReaddirAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsReaddirMethod.Method, registerSessionFsReaddirMethod.Target!, new JsonRpcMethodAttribute("sessionFs.readdir") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsReaddirWithTypesMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.readdirWithTypes", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.ReaddirWithTypesAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsReaddirWithTypesMethod.Method, registerSessionFsReaddirWithTypesMethod.Target!, new JsonRpcMethodAttribute("sessionFs.readdirWithTypes") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsRmMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.rm", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.RmAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsRmMethod.Method, registerSessionFsRmMethod.Target!, new JsonRpcMethodAttribute("sessionFs.rm") - { - UseSingleObjectParameterDeserialization = true - }); - var registerSessionFsRenameMethod = (Func>)(async (request, cancellationToken) => + }), singleObjectParam: true); + rpc.SetLocalRpcMethod("sessionFs.rename", (Func>)(async (request, cancellationToken) => { var handler = getHandlers(request.SessionId).SessionFs; if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); return await handler.RenameAsync(request, cancellationToken); - }); - rpc.AddLocalRpcMethod(registerSessionFsRenameMethod.Method, registerSessionFsRenameMethod.Target!, new JsonRpcMethodAttribute("sessionFs.rename") - { - UseSingleObjectParameterDeserialization = true - }); + }), singleObjectParam: true); } } diff --git a/dotnet/src/GitHub.Copilot.SDK.csproj b/dotnet/src/GitHub.Copilot.SDK.csproj index 38eb0cf3a..abcb8a51a 100644 --- a/dotnet/src/GitHub.Copilot.SDK.csproj +++ b/dotnet/src/GitHub.Copilot.SDK.csproj @@ -37,7 +37,6 @@ - diff --git a/dotnet/src/JsonRpc.cs b/dotnet/src/JsonRpc.cs new file mode 100644 index 000000000..2970b9991 --- /dev/null +++ b/dotnet/src/JsonRpc.cs @@ -0,0 +1,835 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System.Buffers; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using System.Text.Unicode; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace GitHub.Copilot.SDK; + +/// +/// A lightweight JSON-RPC 2.0 implementation covering only the features used +/// by this SDK to talk to the Copilot CLI. Messages are framed using the +/// LSP-style header convention (Content-Length: N\r\n\r\n followed by +/// N bytes of JSON body) — the same wire format used by the Language Server +/// Protocol and the Copilot CLI's other language SDKs (Go, Node, Python). +/// This is not a general-purpose JSON-RPC stack: it is narrowly scoped to the +/// methods, transports, and framing the CLI uses. +/// +internal sealed partial class JsonRpc : IDisposable +{ + private const int ErrorCodeMethodNotFound = -32601; + private const int ErrorCodeInternalError = -32603; + + private readonly Stream _sendStream; + private readonly Stream _receiveStream; + private readonly JsonSerializerOptions _serializerOptions; + private readonly ILogger _logger; + private readonly ConcurrentDictionary _pendingRequests = new(); + private readonly ConcurrentDictionary _methods = new(); + private readonly TaskCompletionSource _completionSource = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly SemaphoreSlim _writeLock = new(1, 1); + private readonly CancellationTokenSource _disposeCts = new(); + private long _nextId; + private bool _disposed; + + /// + /// Initializes a new . + /// + /// The stream to write outgoing messages to. + /// The stream to read incoming messages from. + /// JSON serializer options (should include all needed source-gen contexts). + /// Optional logger for diagnostics. + public JsonRpc(Stream sendStream, Stream receiveStream, JsonSerializerOptions serializerOptions, ILogger? logger = null) + { + _sendStream = sendStream; + _receiveStream = receiveStream; + _serializerOptions = serializerOptions; + _logger = logger ?? NullLogger.Instance; + } + + /// + /// A that completes when the connection is closed or faulted. + /// + public Task Completion => _completionSource.Task; + + /// + /// Begins reading messages from the receive stream. Call once after registering all method handlers. + /// + public void StartListening() + { + _ = ReadLoopAsync(_disposeCts.Token); + } + + /// + /// Sends a JSON-RPC request and waits for the response. + /// + public async Task InvokeAsync(string method, object?[]? args, CancellationToken cancellationToken) + { + var id = Interlocked.Increment(ref _nextId); + var pending = new PendingRequest(); + _pendingRequests[id] = pending; + + CancellationTokenRegistration cancelRegistration = default; + try + { + if (cancellationToken.CanBeCanceled) + { + cancelRegistration = cancellationToken.Register(static state => + { + var (self, reqId, ct) = ((JsonRpc, long, CancellationToken))state!; + if (self._pendingRequests.TryRemove(reqId, out var p)) + { + p.TrySetCanceled(ct); + } + + // Best-effort cancel notification + _ = self.SendCancelNotificationAsync(reqId); + }, (this, id, cancellationToken)); + } + + // Send request message + await SendMessageAsync(new JsonRpcRequest + { + Id = id, + Method = method, + Params = SerializeArgs(args), + }, JsonRpcWireContext.Default.JsonRpcRequest, cancellationToken).ConfigureAwait(false); + + var responseElement = await pending.Task.ConfigureAwait(false); + + if (responseElement.ValueKind == JsonValueKind.Null || responseElement.ValueKind == JsonValueKind.Undefined) + { + return default!; + } + + return (T)responseElement.Deserialize(_serializerOptions.GetTypeInfo(typeof(T)))!; + } + finally + { + _pendingRequests.TryRemove(id, out _); + await cancelRegistration.DisposeAsync().ConfigureAwait(false); + } + } + + /// + /// Registers a method handler that receives positional parameters. + /// If singleObjectParam is false (the default), parameter names and types are inferred from the delegate's signature. + /// If singleObjectParam is true, the entire params object is deserialized as the handler's first parameter. + /// + public void SetLocalRpcMethod(string methodName, Delegate handler, bool singleObjectParam = false) + { + _methods[methodName] = new MethodRegistration(handler, singleObjectParam); + } + + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + _disposeCts.Cancel(); + + // Fail all pending requests + foreach (var kvp in _pendingRequests) + { + if (_pendingRequests.TryRemove(kvp.Key, out var pending)) + { + pending.TrySetException(new ObjectDisposedException(nameof(JsonRpc))); + } + } + + _completionSource.TrySetResult(); + _writeLock.Dispose(); + } + + private async Task SendMessageAsync(T message, JsonTypeInfo typeInfo, CancellationToken cancellationToken) + { + // "Content-Length: " (16) + max int digits (10) + "\r\n\r\n" (4) + const int MaxHeaderLength = 30; + + var json = JsonSerializer.SerializeToUtf8Bytes(message, typeInfo); + + var headerBuf = ArrayPool.Shared.Rent(MaxHeaderLength); + bool wrote = Utf8.TryWrite(headerBuf, $"Content-Length: {json.Length}\r\n\r\n", out int headerLen); + Debug.Assert(wrote && headerLen > 0); + + // Cancellation only applies to *waiting* for the write lock. Once we hold the lock + // and start writing a framed message, we must finish it — cancelling between the + // header and the body (or mid-body) would leave the peer waiting for N body bytes + // that never arrive, desynchronizing the LSP-style stream for every subsequent + // message on this connection. + await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + await _sendStream.WriteAsync(headerBuf.AsMemory(0, headerLen), CancellationToken.None).ConfigureAwait(false); + await _sendStream.WriteAsync(json, CancellationToken.None).ConfigureAwait(false); + await _sendStream.FlushAsync(CancellationToken.None).ConfigureAwait(false); + } + finally + { + _writeLock.Release(); + ArrayPool.Shared.Return(headerBuf); + } + } + + private async Task ReadLoopAsync(CancellationToken cancellationToken) + { + var buffer = new byte[256]; + int carried = 0; // bytes in buffer carried over from previous read + try + { + while (!cancellationToken.IsCancellationRequested) + { + // Read headers and body + var (contentLength, buf, newCarried) = await ReadMessageAsync(buffer, carried, cancellationToken).ConfigureAwait(false); + if (contentLength < 0) + { + break; // Stream ended + } + + // Keep the (possibly grown) buffer and carry-over count for next iteration + buffer = buf; + carried = newCarried; + + // Parse the raw JSON. Body is at buffer[0..contentLength], carried bytes + // for the next message are at buffer[contentLength..contentLength+carried]. + JsonElement? message = null; + try + { + using var doc = JsonDocument.Parse(buffer.AsMemory(0, contentLength)); + message = doc.RootElement.Clone(); + } + catch (JsonException ex) + { + _logger.LogWarning(ex, "Failed to parse incoming JSON-RPC message"); + } + + // Always move carried bytes to the front, even on parse failure — otherwise + // the next ReadMessageAsync call would scan stale body bytes as headers. + // This must happen AFTER parsing because the carried region overlaps where + // the body lived. + if (carried > 0) + { + Buffer.BlockCopy(buffer, contentLength, buffer, 0, carried); + } + + if (message is not { } parsed) + { + continue; + } + + // Route the message + if (parsed.TryGetProperty("id", out var idProp) && !parsed.TryGetProperty("method", out _)) + { + // It's a response to one of our requests + HandleResponse(parsed, idProp); + } + else if (parsed.TryGetProperty("method", out var methodProp) && methodProp.GetString() is string methodName) + { + _ = HandleIncomingMethodAsync(methodName, parsed, cancellationToken); + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Normal shutdown + } + catch (Exception ex) + { + _logger.LogDebug(ex, "JSON-RPC read loop ended"); + } + finally + { + // Fail all pending requests + foreach (var kvp in _pendingRequests) + { + if (_pendingRequests.TryRemove(kvp.Key, out var pending)) + { + pending.TrySetException(new ConnectionLostException()); + } + } + + _completionSource.TrySetResult(); + } + } + + /// + /// Reads headers and body in one pass. + /// On return, body is at buffer[0..ContentLength], and any overflow bytes + /// from the next message are at buffer[ContentLength..ContentLength+Carried]. + /// The caller must move the carried bytes to the front before the next call. + /// + /// Shared buffer (may be grown). + /// Bytes already in buffer[0..carried] from a previous read. + /// Cancellation token. + private async ValueTask<(int ContentLength, byte[] Buffer, int Carried)> ReadMessageAsync(byte[] buffer, int carried, CancellationToken cancellationToken) + { + // Read until we find the \r\n\r\n header terminator. + // carried bytes are already at buffer[0..carried]. + int filled = carried; + int headerEnd = -1; // index of first byte after \r\n\r\n + + // Check carried bytes first for a header terminator + { + int pos = buffer.AsSpan(0, filled).IndexOf("\r\n\r\n"u8); + if (pos >= 0) + { + headerEnd = pos + 4; + } + } + + while (headerEnd < 0) + { + if (filled == buffer.Length) + { + Array.Resize(ref buffer, buffer.Length * 2); + } + + int bytesRead = await _receiveStream.ReadAsync(buffer.AsMemory(filled, buffer.Length - filled), cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + // Clean EOF only if we haven't started a frame; otherwise the peer truncated mid-header. + if (filled == 0) + { + return (-1, buffer, 0); + } + + throw new EndOfStreamException("Stream ended while reading JSON-RPC headers."); + } + + filled += bytesRead; + + // Scan for \r\n\r\n starting from where a match could begin + int scanStart = Math.Max(filled - bytesRead - 3, 0); + int pos = buffer.AsSpan(scanStart, filled - scanStart).IndexOf("\r\n\r\n"u8); + if (pos >= 0) + { + headerEnd = scanStart + pos + 4; + } + } + + // Parse Content-Length. LSP framing puts each header on its own \r\n-terminated + // line; we walk the lines and require an exact "Content-Length: " prefix at the + // start of one of them. A substring match anywhere in the header block would + // false-positive on values like "X-Trace: Content-Length: 5" and desync the stream. + // A missing or unparseable Content-Length means the framing is broken — there's + // no safe way to resync, so throw and let the read loop terminate the connection. + int contentLength = -1; + ReadOnlySpan prefix = "Content-Length: "u8; + // headerEnd points just past the \r\n\r\n terminator. Drop only the trailing + // empty line's \r\n; each remaining header line is still \r\n-terminated and + // gets split out by the IndexOf below. + var headerLines = buffer.AsSpan(0, headerEnd - 2); + while (!headerLines.IsEmpty) + { + int lineEnd = headerLines.IndexOf("\r\n"u8); + ReadOnlySpan line = lineEnd >= 0 ? headerLines.Slice(0, lineEnd) : headerLines; + + if (line.StartsWith(prefix) && + (contentLength >= 0 || + !int.TryParse(line.Slice(prefix.Length), NumberStyles.None, CultureInfo.InvariantCulture, out contentLength) || + contentLength < 0)) + { + throw new InvalidDataException("JSON-RPC frame has a missing, duplicate, or invalid Content-Length header."); + } + + headerLines = lineEnd >= 0 ? headerLines.Slice(lineEnd + 2) : default; + } + + if (contentLength < 0) + { + throw new InvalidDataException("JSON-RPC frame is missing the Content-Length header."); + } + + // Bytes after the header that we already have + int extraBytes = filled - headerEnd; + + // Ensure buffer is large enough for the body and any overflow already read. + int needed = Math.Max(contentLength, extraBytes); + if (needed > buffer.Length) + { + var newBuffer = new byte[needed]; + Buffer.BlockCopy(buffer, headerEnd, newBuffer, 0, extraBytes); + buffer = newBuffer; + } + else if (extraBytes > 0) + { + Buffer.BlockCopy(buffer, headerEnd, buffer, 0, extraBytes); + } + + // Read remaining body bytes if we don't have enough + if (extraBytes < contentLength) + { + await _receiveStream.ReadExactlyAsync(buffer.AsMemory(extraBytes, contentLength - extraBytes), cancellationToken).ConfigureAwait(false); + return (contentLength, buffer, 0); + } + + // We read more than the body — overflow belongs to the next message + int overflow = extraBytes - contentLength; + return (contentLength, buffer, overflow); + } + + private void HandleResponse(JsonElement message, JsonElement idProp) + { + if (!idProp.TryGetInt64(out long id)) + { + return; + } + + if (!_pendingRequests.TryRemove(id, out var pending)) + { + return; + } + + if (message.TryGetProperty("error", out var errorProp)) + { + var errorMessage = errorProp.TryGetProperty("message", out var msgProp) + ? msgProp.GetString() ?? "Unknown error" + : "Unknown error"; + var errorCode = errorProp.TryGetProperty("code", out var codeProp) && codeProp.ValueKind == JsonValueKind.Number + ? codeProp.GetInt32() + : 0; + pending.TrySetException(new RemoteRpcException(errorMessage, errorCode)); + } + else if (message.TryGetProperty("result", out var resultProp)) + { + pending.TrySetResult(resultProp.Clone()); + } + else + { + // Per JSON-RPC 2.0, a response must have either "result" or "error". + // Treat missing result as null result. + pending.TrySetResult(default); + } + } + + private async Task HandleIncomingMethodAsync(string methodName, JsonElement message, CancellationToken cancellationToken) + { + try + { + JsonElement? requestId = null; + if (message.TryGetProperty("id", out var idProp)) + { + requestId = idProp; + } + + if (!_methods.TryGetValue(methodName, out var registration)) + { + if (requestId.HasValue) + { + await SendErrorResponseAsync(requestId.Value, ErrorCodeMethodNotFound, $"Method not found: {methodName}", cancellationToken).ConfigureAwait(false); + } + return; + } + + message.TryGetProperty("params", out var paramsProp); + + try + { + var result = await InvokeHandlerAsync(registration, paramsProp, cancellationToken).ConfigureAwait(false); + + if (requestId.HasValue) + { + await SendResultResponseAsync(requestId.Value, result, cancellationToken).ConfigureAwait(false); + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.LogDebug("Error handling JSON-RPC method {Method}: {Error}", methodName, ex.Message); + } + if (requestId.HasValue) + { + await SendErrorResponseAsync(requestId.Value, ErrorCodeInternalError, ex.Message, cancellationToken).ConfigureAwait(false); + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Normal shutdown — cancellation propagated from the read loop. + } + catch (Exception ex) + { + // Belt-and-braces: this method is fire-and-forget from the read loop, so any + // exception escaping here would become an unobserved task exception. The most + // likely sources are IOException/ObjectDisposedException from sending the error + // response after the underlying transport is gone. + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.LogDebug(ex, "Unobserved error in JSON-RPC method dispatch for {Method}", methodName); + } + } + } + + private async ValueTask InvokeHandlerAsync(MethodRegistration registration, JsonElement paramsProp, CancellationToken cancellationToken) + { + var parameters = registration.Parameters; + + // Build argument list + var invokeArgs = new object?[parameters.Length]; + + if (registration.SingleObjectParam) + { + // Single-object deserialization: entire `params` → first parameter. + // Every singleObjectParam handler has shape (TRequest, CancellationToken), + // so `params` must be a JSON object. + if (paramsProp.ValueKind != JsonValueKind.Object) + { + throw new InvalidOperationException( + $"Expected JSON object for `params` of single-object-param handler; got '{paramsProp.ValueKind}'."); + } + + for (int i = 0; i < parameters.Length; i++) + { + if (parameters[i].ParameterType == typeof(CancellationToken)) + { + invokeArgs[i] = cancellationToken; + } + else if (i == 0) + { + invokeArgs[i] = paramsProp.Deserialize(_serializerOptions.GetTypeInfo(parameters[i].ParameterType)); + } + } + } + else if (paramsProp.ValueKind == JsonValueKind.Array) + { + // Positional parameters. Optional params (with defaults) are filled when absent. + int jsonIndex = 0; + int arrayLength = paramsProp.GetArrayLength(); + for (int i = 0; i < parameters.Length; i++) + { + if (parameters[i].ParameterType == typeof(CancellationToken)) + { + invokeArgs[i] = cancellationToken; + } + else if (jsonIndex < arrayLength) + { + invokeArgs[i] = paramsProp[jsonIndex].Deserialize(_serializerOptions.GetTypeInfo(parameters[i].ParameterType)); + jsonIndex++; + } + else + { + invokeArgs[i] = parameters[i].HasDefaultValue ? parameters[i].DefaultValue : null; + } + } + } + else if (paramsProp.ValueKind == JsonValueKind.Object) + { + // Named parameters. The CLI sends notifications/requests as a JSON object whose + // property names match the handler's parameter names (camelCased per web defaults). + // Look up each parameter by name; missing optional parameters fall back to defaults. + for (int i = 0; i < parameters.Length; i++) + { + if (parameters[i].ParameterType == typeof(CancellationToken)) + { + invokeArgs[i] = cancellationToken; + } + else if (parameters[i].Name is { } paramName && + TryGetPropertyCaseInsensitive(paramsProp, paramName, out var valueProp)) + { + invokeArgs[i] = valueProp.Deserialize(_serializerOptions.GetTypeInfo(parameters[i].ParameterType)); + } + else + { + invokeArgs[i] = parameters[i].HasDefaultValue ? parameters[i].DefaultValue : null; + } + } + } + else + { + // Missing/null `params` for a handler with required positional parameters is a + // protocol violation. Surface it as an error rather than silently filling defaults. + throw new InvalidOperationException( + $"Unsupported JSON-RPC params shape '{paramsProp.ValueKind}' for handler with positional parameters."); + } + + // Invoke + var result = registration.Handler.DynamicInvoke(invokeArgs); + + // Handlers return one of: a synchronous value, Task (void async), or ValueTask. + if (result is Task task) + { + // Task handlers are not supported — use ValueTask for results. + Debug.Assert(!task.GetType().IsGenericType, "Task handlers are not supported; use ValueTask."); + await task.ConfigureAwait(false); + return null; + } + + if (result is not null && registration.ReturnsValueTaskOfT) + { + var resultType = result.GetType(); + var asTask = (Task)resultType.GetMethod("AsTask")!.Invoke(result, null)!; + await asTask.ConfigureAwait(false); + return asTask.GetType().GetProperty("Result")!.GetValue(asTask); + } + + return result; + } + + private static bool TryGetPropertyCaseInsensitive(JsonElement obj, string name, out JsonElement value) + { + // Fast path: exact match. The CLI uses camelCase property names that match the + // C# parameter names exactly, so this should hit in the common case. + if (obj.TryGetProperty(name, out value)) + { + return true; + } + + foreach (var prop in obj.EnumerateObject()) + { + if (string.Equals(prop.Name, name, StringComparison.OrdinalIgnoreCase)) + { + value = prop.Value; + return true; + } + } + + value = default; + return false; + } + + private JsonElement? SerializeArgs(object?[]? args) + { + if (args is null || args.Length == 0) + { + return null; + } + + // The Copilot CLI uses vscode-jsonrpc-style request handlers, which expect + // `params` to be the single request object (not wrapped in a positional array). + // The other SDKs (Node, Python, Go) all send single-object params, and every + // generated call site here passes exactly one request object. For the rare + // multi-arg case, fall back to a positional array. + if (args.Length == 1) + { + var arg = args[0]; + if (arg is null) + { + return null; + } + + var typeInfo = _serializerOptions.GetTypeInfo(arg.GetType()); + return JsonSerializer.SerializeToElement(arg, typeInfo); + } + + // Source-generated JsonSerializerOptions do not provide metadata for object[], + // so build the JSON array manually, serializing each element with a TypeInfo + // looked up by its runtime type from the merged resolver. + var buffer = new ArrayBufferWriter(); + using (var writer = new Utf8JsonWriter(buffer)) + { + writer.WriteStartArray(); + foreach (var arg in args) + { + if (arg is null) + { + writer.WriteNullValue(); + } + else + { + var typeInfo = _serializerOptions.GetTypeInfo(arg.GetType()); + JsonSerializer.Serialize(writer, arg, typeInfo); + } + } + + writer.WriteEndArray(); + } + + using var doc = JsonDocument.Parse(buffer.WrittenMemory); + return doc.RootElement.Clone(); + } + + private async Task SendResultResponseAsync(JsonElement id, object? result, CancellationToken cancellationToken) + { + try + { + // Convert the result to a JsonElement using the runtime type, looked up via + // the merged resolver. Source-gen serialization of an `object`-typed property + // would otherwise have no way to find metadata for the actual response type + // (e.g. SystemMessageTransformRpcResponse, SessionFsReadFileResult, ...). + JsonElement? resultElement = null; + if (result is not null) + { + var typeInfo = _serializerOptions.GetTypeInfo(result.GetType()); + resultElement = JsonSerializer.SerializeToElement(result, typeInfo); + } + + await SendMessageAsync(new JsonRpcResponse + { + Id = id, + Result = resultElement, + }, JsonRpcWireContext.Default.JsonRpcResponse, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (ex is IOException or ObjectDisposedException or OperationCanceledException) + { + // Connection lost during response — nothing we can do + } + } + + private async Task SendErrorResponseAsync(JsonElement id, int code, string message, CancellationToken cancellationToken) + { + try + { + await SendMessageAsync(new JsonRpcErrorResponse + { + Id = id, + Error = new JsonRpcError { Code = code, Message = message }, + }, JsonRpcWireContext.Default.JsonRpcErrorResponse, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (ex is IOException or ObjectDisposedException or OperationCanceledException) + { + // Connection lost during error response — nothing we can do + } + } + + private async Task SendCancelNotificationAsync(long requestId) + { + try + { + await SendMessageAsync(new JsonRpcNotification + { + Method = "$/cancelRequest", + Params = JsonSerializer.SerializeToElement( + new CancelRequestParams { Id = requestId }, + CancelRequestParamsContext.Default.CancelRequestParams), + }, JsonRpcWireContext.Default.JsonRpcNotification, CancellationToken.None).ConfigureAwait(false); + } + catch (Exception ex) when (ex is IOException or ObjectDisposedException or OperationCanceledException) + { + // Best effort — connection may already be gone + } + } + + private sealed class PendingRequest() : TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + private sealed class MethodRegistration + { + public MethodRegistration(Delegate handler, bool singleObjectParam) + { + Handler = handler; + SingleObjectParam = singleObjectParam; + Parameters = handler.Method.GetParameters(); + ReturnsValueTaskOfT = + handler.Method.ReturnType.IsGenericType && + handler.Method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>); + } + + public Delegate Handler { get; } + public bool SingleObjectParam { get; } + public ParameterInfo[] Parameters { get; } + public bool ReturnsValueTaskOfT { get; } + } + + [JsonSourceGenerationOptions( + JsonSerializerDefaults.Web, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] + [JsonSerializable(typeof(JsonRpcRequest))] + [JsonSerializable(typeof(JsonRpcResponse))] + [JsonSerializable(typeof(JsonRpcErrorResponse))] + [JsonSerializable(typeof(JsonRpcNotification))] + private partial class JsonRpcWireContext : JsonSerializerContext; + + private sealed class JsonRpcRequest + { + [JsonPropertyName("jsonrpc")] + public string Jsonrpc { get; } = "2.0"; + + [JsonPropertyName("id")] + public long Id { get; set; } + + [JsonPropertyName("method")] + public string Method { get; set; } = string.Empty; + + [JsonPropertyName("params")] + public JsonElement? Params { get; set; } + } + + private sealed class JsonRpcResponse + { + [JsonPropertyName("jsonrpc")] + public string Jsonrpc { get; } = "2.0"; + + [JsonPropertyName("id")] + public JsonElement Id { get; set; } + + // JSON-RPC 2.0 requires every response to carry either `result` or `error`. + // vscode-jsonrpc (used by the CLI) rejects responses that have neither with + // "The received response has neither a result nor an error property", so we + // must emit `result: null` for void-returning handlers — overriding the + // context-level WhenWritingNull policy. + [JsonPropertyName("result")] + [JsonIgnore(Condition = JsonIgnoreCondition.Never)] + public JsonElement? Result { get; set; } + } + + private sealed class JsonRpcErrorResponse + { + [JsonPropertyName("jsonrpc")] + public string Jsonrpc { get; } = "2.0"; + + [JsonPropertyName("id")] + public JsonElement Id { get; set; } + + [JsonPropertyName("error")] + public JsonRpcError? Error { get; set; } + } + + private sealed class JsonRpcError + { + [JsonPropertyName("code")] + public int Code { get; set; } + + [JsonPropertyName("message")] + public string Message { get; set; } = string.Empty; + } + + private sealed class JsonRpcNotification + { + [JsonPropertyName("jsonrpc")] + public string Jsonrpc { get; } = "2.0"; + + [JsonPropertyName("method")] + public string Method { get; set; } = string.Empty; + + [JsonPropertyName("params")] + public JsonElement? Params { get; set; } + } + + private sealed class CancelRequestParams + { + [JsonPropertyName("id")] + public long Id { get; set; } + } + + [JsonSerializable(typeof(CancelRequestParams))] + private partial class CancelRequestParamsContext : JsonSerializerContext; +} + +/// +/// Thrown when the JSON-RPC connection is lost unexpectedly. +/// +internal sealed class ConnectionLostException() : IOException("The JSON-RPC connection was lost."); + +/// +/// Thrown when the remote side returns a JSON-RPC error response. +/// +internal sealed class RemoteRpcException(string message, int errorCode, Exception? innerException = null) : Exception(message, innerException) +{ + public int ErrorCode { get; } = errorCode; +} diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index a97d54a30..2d3e803e0 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -5,7 +5,6 @@ using GitHub.Copilot.SDK.Rpc; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; -using StreamJsonRpc; using System.Collections.Immutable; using System.Text.Json; using System.Text.Json.Nodes; diff --git a/dotnet/test/GitHub.Copilot.SDK.Test.csproj b/dotnet/test/GitHub.Copilot.SDK.Test.csproj index 8e0dbf6b7..e42dc8e4c 100644 --- a/dotnet/test/GitHub.Copilot.SDK.Test.csproj +++ b/dotnet/test/GitHub.Copilot.SDK.Test.csproj @@ -26,7 +26,6 @@ runtime; build; native; contentfiles; analyzers; buildtransitive all - diff --git a/dotnet/test/Harness/TestHelper.cs b/dotnet/test/Harness/TestHelper.cs index f30f24962..36c9be043 100644 --- a/dotnet/test/Harness/TestHelper.cs +++ b/dotnet/test/Harness/TestHelper.cs @@ -14,17 +14,36 @@ public static class TestHelper var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(60)); + // Both `finalAssistantMessage` and `sawIdle` are set from two threads — the + // subscription callback (CLI read loop) and CheckExistingMessages (RPC reply). + // We complete only once we've observed both, regardless of which path saw which. + var stateLock = new object(); AssistantMessageEvent? finalAssistantMessage = null; + bool sawIdle = false; + + void TryComplete() + { + AssistantMessageEvent? snapshot; + bool idle; + lock (stateLock) + { + snapshot = finalAssistantMessage; + idle = sawIdle; + } + if (snapshot != null && idle) tcs.TrySetResult(snapshot); + } using var subscription = session.On(evt => { switch (evt) { case AssistantMessageEvent msg: - finalAssistantMessage = msg; + lock (stateLock) { finalAssistantMessage = msg; } + TryComplete(); break; - case SessionIdleEvent when finalAssistantMessage != null: - tcs.TrySetResult(finalAssistantMessage); + case SessionIdleEvent: + lock (stateLock) { sawIdle = true; } + TryComplete(); break; case SessionErrorEvent error: tcs.TrySetException(new Exception(error.Data.Message ?? "session error")); @@ -32,7 +51,8 @@ public static class TestHelper } }); - // Check existing messages + // Backfill from already-delivered messages so we don't lose events that arrived + // between SendAsync returning and the subscription being installed. CheckExistingMessages(); cts.Token.Register(() => tcs.TrySetException(new TimeoutException("Timeout waiting for assistant message"))); @@ -43,8 +63,17 @@ async void CheckExistingMessages() { try { - var existing = await GetExistingFinalResponseAsync(session, alreadyIdle); - if (existing != null) tcs.TrySetResult(existing); + var (existingFinal, existingIdle) = await GetExistingMessagesAsync(session, alreadyIdle); + lock (stateLock) + { + // Preserve a newer message captured by the subscription in the meantime. + if (existingFinal != null && finalAssistantMessage == null) + { + finalAssistantMessage = existingFinal; + } + if (existingIdle) sawIdle = true; + } + TryComplete(); } catch (Exception ex) { @@ -53,7 +82,7 @@ async void CheckExistingMessages() } } - private static async Task GetExistingFinalResponseAsync(CopilotSession session, bool alreadyIdle) + private static async Task<(AssistantMessageEvent? Final, bool SawIdle)> GetExistingMessagesAsync(CopilotSession session, bool alreadyIdle) { var messages = (await session.GetMessagesAsync()).ToList(); @@ -64,15 +93,17 @@ async void CheckExistingMessages() if (error != null) throw new Exception(error.Data.Message ?? "session error"); var idleIdx = alreadyIdle ? currentTurn.Count : currentTurn.FindIndex(m => m is SessionIdleEvent); - if (idleIdx == -1) return null; + var sawIdle = alreadyIdle || idleIdx >= 0; - for (var i = idleIdx - 1; i >= 0; i--) + // Find the most recent assistant message in the turn (whether idle has arrived or not). + var searchEnd = idleIdx >= 0 ? idleIdx : currentTurn.Count; + for (var i = searchEnd - 1; i >= 0; i--) { if (currentTurn[i] is AssistantMessageEvent msg) - return msg; + return (msg, sawIdle); } - return null; + return (null, sawIdle); } public static async Task GetNextEventOfTypeAsync( diff --git a/dotnet/test/SerializationTests.cs b/dotnet/test/SerializationTests.cs index 4a976d2bc..720844533 100644 --- a/dotnet/test/SerializationTests.cs +++ b/dotnet/test/SerializationTests.cs @@ -5,68 +5,14 @@ using Xunit; using System.Text.Json; using System.Text.Json.Serialization; -using StreamJsonRpc; namespace GitHub.Copilot.SDK.Test; /// -/// Tests for JSON serialization compatibility, particularly for StreamJsonRpc types -/// that are needed when CancellationTokens fire during JSON-RPC operations. -/// This test suite verifies the fix for https://github.com/PureWeen/PolyPilot/issues/319 +/// Tests for JSON serialization compatibility with the SDK's configured options. /// public class SerializationTests { - /// - /// Verifies that StreamJsonRpc.RequestId can be round-tripped using the SDK's configured - /// JsonSerializerOptions. This is critical for preventing NotSupportedException when - /// StandardCancellationStrategy fires during JSON-RPC operations. - /// - [Fact] - public void RequestId_CanBeSerializedAndDeserialized_WithSdkOptions() - { - var options = GetSerializerOptions(); - - // Long id - var jsonLong = JsonSerializer.Serialize(new RequestId(42L), options); - Assert.Equal("42", jsonLong); - Assert.Equal(new RequestId(42L), JsonSerializer.Deserialize(jsonLong, options)); - - // String id - var jsonStr = JsonSerializer.Serialize(new RequestId("req-1"), options); - Assert.Equal("\"req-1\"", jsonStr); - Assert.Equal(new RequestId("req-1"), JsonSerializer.Deserialize(jsonStr, options)); - - // Null id - var jsonNull = JsonSerializer.Serialize(RequestId.Null, options); - Assert.Equal("null", jsonNull); - Assert.Equal(RequestId.Null, JsonSerializer.Deserialize(jsonNull, options)); - } - - [Theory] - [InlineData(0L)] - [InlineData(-1L)] - [InlineData(long.MaxValue)] - public void RequestId_NumericEdgeCases_RoundTrip(long id) - { - var options = GetSerializerOptions(); - var requestId = new RequestId(id); - var json = JsonSerializer.Serialize(requestId, options); - Assert.Equal(requestId, JsonSerializer.Deserialize(json, options)); - } - - /// - /// Verifies the SDK's options can resolve type info for RequestId, - /// ensuring AOT-safe serialization without falling back to reflection. - /// - [Fact] - public void SerializerOptions_CanResolveRequestIdTypeInfo() - { - var options = GetSerializerOptions(); - var typeInfo = options.GetTypeInfo(typeof(RequestId)); - Assert.NotNull(typeInfo); - Assert.Equal(typeof(RequestId), typeInfo.Type); - } - [Fact] public void ProviderConfig_CanSerializeHeaders_WithSdkOptions() { diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index 4baf8061c..9c8332c09 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -840,7 +840,13 @@ function resolvedResultTypeName(method: RpcMethod): string { return resultTypeName(method); } -/** Returns the Task or Task string for a method's result type. */ +/** Returns the ValueTask or ValueTask string for an incoming-handler's result type. */ +function handlerTaskType(method: RpcMethod): string { + const schema = getMethodResultSchema(method); + return !isVoidSchema(schema) ? `ValueTask<${resolvedResultTypeName(method)}>` : "ValueTask"; +} + +/** Returns the Task or Task string for an outgoing-call wrapper's result type. */ function resultTaskType(method: RpcMethod): string { const schema = getMethodResultSchema(method); return !isVoidSchema(schema) ? `Task<${resolvedResultTypeName(method)}>` : "Task"; @@ -1465,7 +1471,7 @@ function emitClientSessionApiRegistration(clientSchema: Record, lines.push(""); lines.push(`/// Registers client session API handlers on a JSON-RPC connection.`); - lines.push(`public static class ClientSessionApiRegistration`); + lines.push(`internal static class ClientSessionApiRegistration`); lines.push(`{`); lines.push(` /// `); lines.push(` /// Registers handlers for server-to-client session API calls.`); @@ -1482,11 +1488,10 @@ function emitClientSessionApiRegistration(clientSchema: Record, const hasParams = !!effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0; const resultSchema = getMethodResultSchema(method); const paramsClass = paramsTypeName(method); - const taskType = resultTaskType(method); - const registrationVar = `register${typeToClassName(method.rpcMethod)}Method`; + const taskType = handlerTaskType(method); if (hasParams) { - lines.push(` var ${registrationVar} = (Func<${paramsClass}, CancellationToken, ${taskType}>)(async (request, cancellationToken) =>`); + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func<${paramsClass}, CancellationToken, ${taskType}>)(async (request, cancellationToken) =>`); lines.push(` {`); lines.push(` var handler = getHandlers(request.SessionId).${handlerProperty};`); lines.push(` if (handler is null) throw new InvalidOperationException($"No ${groupName} handler registered for session: {request.SessionId}");`); @@ -1495,13 +1500,9 @@ function emitClientSessionApiRegistration(clientSchema: Record, } else { lines.push(` await handler.${handlerMethod}(request, cancellationToken);`); } - lines.push(` });`); - lines.push(` rpc.AddLocalRpcMethod(${registrationVar}.Method, ${registrationVar}.Target!, new JsonRpcMethodAttribute("${method.rpcMethod}")`); - lines.push(` {`); - lines.push(` UseSingleObjectParameterDeserialization = true`); - lines.push(` });`); + lines.push(` }), singleObjectParam: true);`); } else { - lines.push(` rpc.AddLocalRpcMethod("${method.rpcMethod}", (Func)(_ =>`); + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func)(_ =>`); lines.push(` throw new InvalidOperationException("No params provided for ${method.rpcMethod}")));`); } } @@ -1544,7 +1545,6 @@ using System.ComponentModel.DataAnnotations; using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; -using StreamJsonRpc; namespace GitHub.Copilot.SDK.Rpc;