diff --git a/mcp/client.go b/mcp/client.go index 7349ba9c..d3b0955e 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -63,6 +63,13 @@ type ClientOptions struct { // Setting CreateMessageHandler to a non-nil value causes the client to // advertise the sampling capability. CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // SamplingSupportsTools indicates that the client's CreateMessageHandler + // supports tool use. If true and CreateMessageHandler is set, the + // sampling.tools capability is advertised. + SamplingSupportsTools bool + // SamplingSupportsContext indicates that the client supports + // includeContext values other than "none". + SamplingSupportsContext bool // ElicitationHandler handles incoming requests for elicitation/create. // // Setting ElicitationHandler to a non-nil value causes the client to @@ -131,6 +138,12 @@ func (c *Client) capabilities() *ClientCapabilities { caps.Roots.ListChanged = true if c.opts.CreateMessageHandler != nil { caps.Sampling = &SamplingCapabilities{} + if c.opts.SamplingSupportsTools { + caps.Sampling.Tools = &SamplingToolsCapabilities{} + } + if c.opts.SamplingSupportsContext { + caps.Sampling.Context = &SamplingContextCapabilities{} + } } if c.opts.ElicitationHandler != nil { caps.Elicitation = &ElicitationCapabilities{} diff --git a/mcp/content.go b/mcp/content.go index fb1a0d1e..b00e0a42 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -14,7 +14,10 @@ import ( ) // A Content is a [TextContent], [ImageContent], [AudioContent], -// [ResourceLink], or [EmbeddedResource]. +// [ResourceLink], [EmbeddedResource], [ToolUseContent], or [ToolResultContent]. +// +// Note: [ToolUseContent] and [ToolResultContent] are only valid in sampling +// message contexts (CreateMessageParams/CreateMessageResult). type Content interface { MarshalJSON() ([]byte, error) fromWire(*wireContent) @@ -183,6 +186,104 @@ func (c *EmbeddedResource) fromWire(wire *wireContent) { c.Annotations = wire.Annotations } +// ToolUseContent represents a request from the assistant to invoke a tool. +// This content type is only valid in sampling messages. +type ToolUseContent struct { + // ID is a unique identifier for this tool use, used to match with ToolResultContent. + ID string + // Name is the name of the tool to invoke. + Name string + // Input contains the tool arguments as a JSON object. + Input map[string]any + Meta Meta +} + +func (c *ToolUseContent) MarshalJSON() ([]byte, error) { + input := c.Input + if input == nil { + input = map[string]any{} + } + wire := struct { + Type string `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + Input map[string]any `json:"input"` + Meta Meta `json:"_meta,omitempty"` + }{ + Type: "tool_use", + ID: c.ID, + Name: c.Name, + Input: input, + Meta: c.Meta, + } + return json.Marshal(wire) +} + +func (c *ToolUseContent) fromWire(wire *wireContent) { + c.ID = wire.ID + c.Name = wire.Name + c.Input = wire.Input + c.Meta = wire.Meta +} + +// ToolResultContent represents the result of a tool invocation. +// This content type is only valid in sampling messages with role "user". +type ToolResultContent struct { + // ToolUseID references the ID from the corresponding ToolUseContent. + ToolUseID string + // Content holds the unstructured result of the tool call. + Content []Content + // StructuredContent holds an optional structured result as a JSON object. + StructuredContent any + // IsError indicates whether the tool call ended in an error. + IsError bool + Meta Meta +} + +func (c *ToolResultContent) MarshalJSON() ([]byte, error) { + // Marshal nested content + var contentWire []*wireContent + for _, content := range c.Content { + data, err := content.MarshalJSON() + if err != nil { + return nil, err + } + var w wireContent + if err := json.Unmarshal(data, &w); err != nil { + return nil, err + } + contentWire = append(contentWire, &w) + } + if contentWire == nil { + contentWire = []*wireContent{} // avoid JSON null + } + + wire := struct { + Type string `json:"type"` + ToolUseID string `json:"toolUseId"` + Content []*wireContent `json:"content"` + StructuredContent any `json:"structuredContent,omitempty"` + IsError bool `json:"isError,omitempty"` + Meta Meta `json:"_meta,omitempty"` + }{ + Type: "tool_result", + ToolUseID: c.ToolUseID, + Content: contentWire, + StructuredContent: c.StructuredContent, + IsError: c.IsError, + Meta: c.Meta, + } + return json.Marshal(wire) +} + +func (c *ToolResultContent) fromWire(wire *wireContent) { + c.ToolUseID = wire.ToolUseID + c.StructuredContent = wire.StructuredContent + c.IsError = wire.IsError + c.Meta = wire.Meta + // Content is handled separately in contentFromWire due to nested content +} + // ResourceContents contains the contents of a specific resource or // sub-resource. type ResourceContents struct { @@ -224,10 +325,9 @@ func (r *ResourceContents) MarshalJSON() ([]byte, error) { // wireContent is the wire format for content. // It represents the protocol types TextContent, ImageContent, AudioContent, -// ResourceLink, and EmbeddedResource. +// ResourceLink, EmbeddedResource, ToolUseContent, and ToolResultContent. // The Type field distinguishes them. In the protocol, each type has a constant // value for the field. -// At most one of Text, Data, Resource, and URI is non-zero. type wireContent struct { Type string `json:"type"` Text string `json:"text,omitempty"` @@ -242,6 +342,14 @@ type wireContent struct { Meta Meta `json:"_meta,omitempty"` Annotations *Annotations `json:"annotations,omitempty"` Icons []Icon `json:"icons,omitempty"` + // Fields for ToolUseContent (type: "tool_use") + ID string `json:"id,omitempty"` + Input map[string]any `json:"input,omitempty"` + // Fields for ToolResultContent (type: "tool_result") + ToolUseID string `json:"toolUseId,omitempty"` + ToolResultContent []*wireContent `json:"content,omitempty"` // nested content for tool_result + StructuredContent any `json:"structuredContent,omitempty"` + IsError bool `json:"isError,omitempty"` } func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) { @@ -284,6 +392,27 @@ func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) v := new(EmbeddedResource) v.fromWire(wire) return v, nil + case "tool_use": + v := new(ToolUseContent) + v.fromWire(wire) + return v, nil + case "tool_result": + v := new(ToolResultContent) + v.fromWire(wire) + // Handle nested content - tool_result content can contain text, image, audio, + // resource_link, and resource (same as CallToolResult.content) + if wire.ToolResultContent != nil { + toolResultContentAllow := map[string]bool{ + "text": true, "image": true, "audio": true, + "resource_link": true, "resource": true, + } + nestedContent, err := contentsFromWire(wire.ToolResultContent, toolResultContentAllow) + if err != nil { + return nil, fmt.Errorf("tool_result nested content: %w", err) + } + v.Content = nestedContent + } + return v, nil } - return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type) + return nil, fmt.Errorf("unrecognized content type %q", wire.Type) } diff --git a/mcp/protocol.go b/mcp/protocol.go index 8a88f8e2..b395dc26 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -291,6 +291,11 @@ type CreateMessageParams struct { Meta `json:"_meta,omitempty"` // A request to include context from one or more MCP servers (including the // caller), to be attached to the prompt. The client may ignore this request. + // + // The default behavior is Default is "none". Values "thisServer" and + // "allServers" are soft-deprecated. Servers SHOULD only use these values if + // the client declares ClientCapabilities.sampling.context. These values may + // be removed in future spec releases. IncludeContext string `json:"includeContext,omitempty"` // The maximum number of tokens to sample, as requested by the server. The // client may choose to sample fewer tokens than requested. @@ -307,6 +312,12 @@ type CreateMessageParams struct { // may modify or omit this prompt. SystemPrompt string `json:"systemPrompt,omitempty"` Temperature float64 `json:"temperature,omitempty"` + // Tools is an optional list of tools available for the model to use. + // Requires the client's sampling.tools capability. + Tools []*Tool `json:"tools,omitempty"` + // ToolChoice controls how the model should use tools. + // Requires the client's sampling.tools capability. + ToolChoice *ToolChoice `json:"toolChoice,omitempty"` } func (x *CreateMessageParams) isParams() {} @@ -326,6 +337,12 @@ type CreateMessageResult struct { Model string `json:"model"` Role Role `json:"role"` // The reason why sampling stopped, if known. + // + // Standard values: + // - "endTurn": natural end of the assistant's turn + // - "stopSequence": a stop sequence was encountered + // - "maxTokens": reached the maximyum token limit + // - "toolUse": the model wants to use one or more tools StopReason string `json:"stopReason,omitempty"` } @@ -339,8 +356,9 @@ func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &wire); err != nil { return err } + // Allow text, image, audio, and tool_use in results var err error - if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true}); err != nil { return err } *r = CreateMessageResult(wire.result) @@ -876,7 +894,27 @@ func (x *RootsListChangedParams) GetProgressToken() any { return getProgressTok func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } // SamplingCapabilities describes the capabilities for sampling. -type SamplingCapabilities struct{} +type SamplingCapabilities struct { + // Context indicates the client supports includeContext values other than "none". + Context *SamplingContextCapabilities `json:"context,omitempty"` + // Tools indicates the client supports tools and toolChoice in sampling requests. + Tools *SamplingToolsCapabilities `json:"tools,omitempty"` +} + +// SamplingContextCapabilities indicates the client supports context inclusion. +type SamplingContextCapabilities struct{} + +// SamplingToolsCapabilities indicates the client supports tool use in sampling. +type SamplingToolsCapabilities struct{} + +// ToolChoice controls how the model uses tools during sampling. +type ToolChoice struct { + // Mode controls tool invocation behavior: + // - "auto": Model decides whether to use tools (default) + // - "required": Model must use at least one tool + // - "none": Model must not use any tools + Mode string `json:"mode,omitempty"` +} // ElicitationCapabilities describes the capabilities for elicitation. // @@ -895,6 +933,9 @@ type URLElicitationCapabilities struct { } // Describes a message issued to or received from an LLM API. +// +// For assistant messages, Content may be text, image, audio, or tool_use. +// For user messages, Content may be text, image, audio, or tool_result. type SamplingMessage struct { Content Content `json:"content"` Role Role `json:"role"` @@ -911,8 +952,9 @@ func (m *SamplingMessage) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &wire); err != nil { return err } + // Allow text, image, audio, tool_use, and tool_result in sampling messages var err error - if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true, "tool_result": true}); err != nil { return err } *m = SamplingMessage(wire.msg) diff --git a/mcp/sampling_tools_test.go b/mcp/sampling_tools_test.go new file mode 100644 index 00000000..d7cd1dd4 --- /dev/null +++ b/mcp/sampling_tools_test.go @@ -0,0 +1,762 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "encoding/json" + "reflect" + "testing" +) + +func TestToolUseContent_MarshalJSON(t *testing.T) { + tests := []struct { + name string + content *ToolUseContent + want map[string]any + }{ + { + name: "basic tool use", + content: &ToolUseContent{ + ID: "tool_123", + Name: "calculator", + Input: map[string]any{ + "operation": "add", + "x": 1.0, + "y": 2.0, + }, + }, + want: map[string]any{ + "type": "tool_use", + "id": "tool_123", + "name": "calculator", + "input": map[string]any{ + "operation": "add", + "x": 1.0, + "y": 2.0, + }, + }, + }, + { + name: "tool use with nil input", + content: &ToolUseContent{ + ID: "tool_456", + Name: "no_args_tool", + Input: nil, + }, + want: map[string]any{ + "type": "tool_use", + "id": "tool_456", + "name": "no_args_tool", + "input": map[string]any{}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := tt.content.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToolResultContent_MarshalJSON(t *testing.T) { + tests := []struct { + name string + content *ToolResultContent + want map[string]any + }{ + { + name: "basic tool result", + content: &ToolResultContent{ + ToolUseID: "tool_123", + Content: []Content{&TextContent{Text: "42"}}, + }, + want: map[string]any{ + "type": "tool_result", + "toolUseId": "tool_123", + "content": []any{ + map[string]any{ + "type": "text", + "text": "42", + }, + }, + }, + }, + { + name: "tool result with error", + content: &ToolResultContent{ + ToolUseID: "tool_456", + Content: []Content{&TextContent{Text: "division by zero"}}, + IsError: true, + }, + want: map[string]any{ + "type": "tool_result", + "toolUseId": "tool_456", + "content": []any{ + map[string]any{ + "type": "text", + "text": "division by zero", + }, + }, + "isError": true, + }, + }, + { + name: "tool result with structured content", + content: &ToolResultContent{ + ToolUseID: "tool_789", + Content: []Content{&TextContent{Text: `{"result": 42}`}}, + StructuredContent: map[string]any{"result": 42.0}, + }, + want: map[string]any{ + "type": "tool_result", + "toolUseId": "tool_789", + "structuredContent": map[string]any{"result": 42.0}, + "content": []any{ + map[string]any{ + "type": "text", + "text": `{"result": 42}`, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := tt.content.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToolUseContent_UnmarshalJSON(t *testing.T) { + jsonData := `{ + "type": "tool_use", + "id": "tool_123", + "name": "calculator", + "input": {"x": 1, "y": 2} + }` + + wire := &wireContent{} + if err := json.Unmarshal([]byte(jsonData), wire); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + content, err := contentFromWire(wire, map[string]bool{"tool_use": true}) + if err != nil { + t.Fatalf("contentFromWire() error = %v", err) + } + + toolUse, ok := content.(*ToolUseContent) + if !ok { + t.Fatalf("expected *ToolUseContent, got %T", content) + } + + if toolUse.ID != "tool_123" { + t.Errorf("ID = %v, want %v", toolUse.ID, "tool_123") + } + if toolUse.Name != "calculator" { + t.Errorf("Name = %v, want %v", toolUse.Name, "calculator") + } + if toolUse.Input["x"] != 1.0 || toolUse.Input["y"] != 2.0 { + t.Errorf("Input = %v, want map with x=1, y=2", toolUse.Input) + } +} + +func TestToolResultContent_UnmarshalJSON(t *testing.T) { + jsonData := `{ + "type": "tool_result", + "toolUseId": "tool_123", + "content": [{"type": "text", "text": "42"}], + "isError": false + }` + + wire := &wireContent{} + if err := json.Unmarshal([]byte(jsonData), wire); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + content, err := contentFromWire(wire, map[string]bool{"tool_result": true}) + if err != nil { + t.Fatalf("contentFromWire() error = %v", err) + } + + toolResult, ok := content.(*ToolResultContent) + if !ok { + t.Fatalf("expected *ToolResultContent, got %T", content) + } + + if toolResult.ToolUseID != "tool_123" { + t.Errorf("ToolUseID = %v, want %v", toolResult.ToolUseID, "tool_123") + } + if toolResult.IsError { + t.Errorf("IsError = %v, want false", toolResult.IsError) + } + if len(toolResult.Content) != 1 { + t.Fatalf("len(Content) = %v, want 1", len(toolResult.Content)) + } + textContent, ok := toolResult.Content[0].(*TextContent) + if !ok { + t.Fatalf("expected *TextContent, got %T", toolResult.Content[0]) + } + if textContent.Text != "42" { + t.Errorf("Text = %v, want %v", textContent.Text, "42") + } +} + +func TestCreateMessageResult_ToolUseContent(t *testing.T) { + // Test that CreateMessageResult can unmarshal tool_use content + jsonData := `{ + "content": {"type": "tool_use", "id": "tool_1", "name": "calculator", "input": {"x": 1}}, + "model": "test-model", + "role": "assistant", + "stopReason": "toolUse" + }` + + var result CreateMessageResult + if err := json.Unmarshal([]byte(jsonData), &result); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if result.Model != "test-model" { + t.Errorf("Model = %v, want %v", result.Model, "test-model") + } + if result.StopReason != "toolUse" { + t.Errorf("StopReason = %v, want %v", result.StopReason, "toolUse") + } + + toolUse, ok := result.Content.(*ToolUseContent) + if !ok { + t.Fatalf("Content expected *ToolUseContent, got %T", result.Content) + } + if toolUse.ID != "tool_1" { + t.Errorf("Content.ID = %v, want %v", toolUse.ID, "tool_1") + } + if toolUse.Name != "calculator" { + t.Errorf("Content.Name = %v, want %v", toolUse.Name, "calculator") + } +} + +func TestSamplingMessage_ToolUseContent(t *testing.T) { + // Test that SamplingMessage can unmarshal tool_use content (assistant role) + jsonData := `{ + "content": {"type": "tool_use", "id": "tool_1", "name": "calc", "input": {}}, + "role": "assistant" + }` + + var msg SamplingMessage + if err := json.Unmarshal([]byte(jsonData), &msg); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if msg.Role != "assistant" { + t.Errorf("Role = %v, want %v", msg.Role, "assistant") + } + + toolUse, ok := msg.Content.(*ToolUseContent) + if !ok { + t.Fatalf("Content expected *ToolUseContent, got %T", msg.Content) + } + if toolUse.ID != "tool_1" { + t.Errorf("Content.ID = %v, want %v", toolUse.ID, "tool_1") + } +} + +func TestSamplingMessage_ToolResultContent(t *testing.T) { + // Test that SamplingMessage can unmarshal tool_result content (user role) + jsonData := `{ + "content": {"type": "tool_result", "toolUseId": "tool_1", "content": [{"type": "text", "text": "42"}]}, + "role": "user" + }` + + var msg SamplingMessage + if err := json.Unmarshal([]byte(jsonData), &msg); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if msg.Role != "user" { + t.Errorf("Role = %v, want %v", msg.Role, "user") + } + + toolResult, ok := msg.Content.(*ToolResultContent) + if !ok { + t.Fatalf("Content expected *ToolResultContent, got %T", msg.Content) + } + if toolResult.ToolUseID != "tool_1" { + t.Errorf("Content.ToolUseID = %v, want %v", toolResult.ToolUseID, "tool_1") + } + if len(toolResult.Content) != 1 { + t.Fatalf("len(Content.Content) = %v, want 1", len(toolResult.Content)) + } +} + +func TestSamplingCapabilities_WithTools(t *testing.T) { + caps := &SamplingCapabilities{ + Tools: &SamplingToolsCapabilities{}, + Context: &SamplingContextCapabilities{}, + } + + data, err := json.Marshal(caps) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var caps2 SamplingCapabilities + if err := json.Unmarshal(data, &caps2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if caps2.Tools == nil { + t.Error("Tools capability should not be nil") + } + if caps2.Context == nil { + t.Error("Context capability should not be nil") + } +} + +func TestSamplingCapabilities_Empty(t *testing.T) { + // Test backward compatibility - empty struct should marshal/unmarshal correctly + caps := &SamplingCapabilities{} + + data, err := json.Marshal(caps) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var caps2 SamplingCapabilities + if err := json.Unmarshal(data, &caps2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if caps2.Tools != nil { + t.Error("Tools capability should be nil for empty capabilities") + } + if caps2.Context != nil { + t.Error("Context capability should be nil for empty capabilities") + } +} + +func TestCreateMessageParams_WithTools(t *testing.T) { + params := &CreateMessageParams{ + MaxTokens: 1000, + Messages: []*SamplingMessage{ + { + Role: "user", + Content: &TextContent{Text: "Calculate 1+1"}, + }, + }, + Tools: []*Tool{ + { + Name: "calculator", + Description: "A calculator tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "x": map[string]any{"type": "number"}, + "y": map[string]any{"type": "number"}, + }, + }, + }, + }, + ToolChoice: &ToolChoice{Mode: "auto"}, + } + + data, err := json.Marshal(params) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var params2 CreateMessageParams + if err := json.Unmarshal(data, ¶ms2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if len(params2.Tools) != 1 { + t.Fatalf("len(Tools) = %v, want 1", len(params2.Tools)) + } + if params2.Tools[0].Name != "calculator" { + t.Errorf("Tools[0].Name = %v, want %v", params2.Tools[0].Name, "calculator") + } + if params2.ToolChoice == nil || params2.ToolChoice.Mode != "auto" { + t.Errorf("ToolChoice.Mode = %v, want %v", params2.ToolChoice, &ToolChoice{Mode: "auto"}) + } +} + +func TestToolChoice_Modes(t *testing.T) { + tests := []struct { + name string + mode string + }{ + {"auto", "auto"}, + {"required", "required"}, + {"none", "none"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &ToolChoice{Mode: tt.mode} + data, err := json.Marshal(tc) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var tc2 ToolChoice + if err := json.Unmarshal(data, &tc2); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if tc2.Mode != tt.mode { + t.Errorf("Mode = %v, want %v", tc2.Mode, tt.mode) + } + }) + } +} + +// Integration tests + +func TestSamplingWithTools_Integration(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Track what the client received + var receivedParams *CreateMessageParams + + // Client with tools capability + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + receivedParams = req.Params + // Return a tool use response + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &ToolUseContent{ + ID: "tool_call_1", + Name: "calculator", + Input: map[string]any{"x": 1.0, "y": 2.0}, + }, + StopReason: "toolUse", + }, nil + }, + SamplingSupportsTools: true, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Server sends CreateMessage with tools + result, err := ss.CreateMessage(ctx, &CreateMessageParams{ + MaxTokens: 1000, + Messages: []*SamplingMessage{ + {Role: "user", Content: &TextContent{Text: "Calculate 1+2"}}, + }, + Tools: []*Tool{ + { + Name: "calculator", + Description: "A calculator", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "x": map[string]any{"type": "number"}, + "y": map[string]any{"type": "number"}, + }, + }, + }, + }, + ToolChoice: &ToolChoice{Mode: "auto"}, + }) + if err != nil { + t.Fatalf("CreateMessage() error = %v", err) + } + + // Verify client received the tools + if receivedParams == nil { + t.Fatal("client did not receive params") + } + if len(receivedParams.Tools) != 1 { + t.Errorf("client received %d tools, want 1", len(receivedParams.Tools)) + } + if receivedParams.Tools[0].Name != "calculator" { + t.Errorf("tool name = %v, want calculator", receivedParams.Tools[0].Name) + } + if receivedParams.ToolChoice == nil || receivedParams.ToolChoice.Mode != "auto" { + t.Errorf("tool choice mode = %v, want auto", receivedParams.ToolChoice) + } + + // Verify server received the tool use response + if result.StopReason != "toolUse" { + t.Errorf("StopReason = %v, want toolUse", result.StopReason) + } + toolUse, ok := result.Content.(*ToolUseContent) + if !ok { + t.Fatalf("Content type = %T, want *ToolUseContent", result.Content) + } + if toolUse.ID != "tool_call_1" { + t.Errorf("ToolUse.ID = %v, want tool_call_1", toolUse.ID) + } + if toolUse.Name != "calculator" { + t.Errorf("ToolUse.Name = %v, want calculator", toolUse.Name) + } +} + +func TestSamplingWithToolResult_Integration(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Track messages received by client + var receivedMessages []*SamplingMessage + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + receivedMessages = req.Params.Messages + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "The result is 3"}, + }, nil + }, + SamplingSupportsTools: true, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Server sends CreateMessage with tool result in messages + _, err = ss.CreateMessage(ctx, &CreateMessageParams{ + MaxTokens: 1000, + Messages: []*SamplingMessage{ + {Role: "user", Content: &TextContent{Text: "Calculate 1+2"}}, + {Role: "assistant", Content: &ToolUseContent{ + ID: "tool_1", + Name: "calculator", + Input: map[string]any{"x": 1.0, "y": 2.0}, + }}, + {Role: "user", Content: &ToolResultContent{ + ToolUseID: "tool_1", + Content: []Content{&TextContent{Text: "3"}}, + }}, + }, + }) + if err != nil { + t.Fatalf("CreateMessage() error = %v", err) + } + + // Verify client received all messages including tool content + if len(receivedMessages) != 3 { + t.Fatalf("received %d messages, want 3", len(receivedMessages)) + } + + // Check first message is text + if _, ok := receivedMessages[0].Content.(*TextContent); !ok { + t.Errorf("message[0] content type = %T, want *TextContent", receivedMessages[0].Content) + } + + // Check second message is tool use + toolUse, ok := receivedMessages[1].Content.(*ToolUseContent) + if !ok { + t.Fatalf("message[1] content type = %T, want *ToolUseContent", receivedMessages[1].Content) + } + if toolUse.ID != "tool_1" { + t.Errorf("toolUse.ID = %v, want tool_1", toolUse.ID) + } + + // Check third message is tool result + toolResult, ok := receivedMessages[2].Content.(*ToolResultContent) + if !ok { + t.Fatalf("message[2] content type = %T, want *ToolResultContent", receivedMessages[2].Content) + } + if toolResult.ToolUseID != "tool_1" { + t.Errorf("toolResult.ToolUseID = %v, want tool_1", toolResult.ToolUseID) + } + if len(toolResult.Content) != 1 { + t.Fatalf("toolResult.Content len = %d, want 1", len(toolResult.Content)) + } + if tc, ok := toolResult.Content[0].(*TextContent); !ok || tc.Text != "3" { + t.Errorf("toolResult.Content[0] = %v, want TextContent with '3'", toolResult.Content[0]) + } +} + +func TestSamplingToolsCapability_Integration(t *testing.T) { + ctx := context.Background() + + t.Run("client advertises tools capability", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, _ *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "m", Content: &TextContent{}}, nil + }, + SamplingSupportsTools: true, + SamplingSupportsContext: true, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Check server sees client capabilities + caps := ss.InitializeParams().Capabilities + if caps.Sampling == nil { + t.Fatal("client should advertise sampling capability") + } + if caps.Sampling.Tools == nil { + t.Error("client should advertise sampling.tools capability") + } + if caps.Sampling.Context == nil { + t.Error("client should advertise sampling.context capability") + } + }) + + t.Run("client without tools capability", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, _ *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "m", Content: &TextContent{}}, nil + }, + // SamplingSupportsTools not set + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Check server sees client capabilities + caps := ss.InitializeParams().Capabilities + if caps.Sampling == nil { + t.Fatal("client should advertise sampling capability") + } + if caps.Sampling.Tools != nil { + t.Error("client should NOT advertise sampling.tools capability") + } + if caps.Sampling.Context != nil { + t.Error("client should NOT advertise sampling.context capability") + } + }) +} + +func TestSamplingToolResultWithError_Integration(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + var receivedMessages []*SamplingMessage + + client := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(_ context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { + receivedMessages = req.Params.Messages + return &CreateMessageResult{ + Model: "test-model", + Role: "assistant", + Content: &TextContent{Text: "I see the tool failed"}, + }, nil + }, + SamplingSupportsTools: true, + }) + + server := NewServer(testImpl, nil) + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Server sends CreateMessage with error tool result + _, err = ss.CreateMessage(ctx, &CreateMessageParams{ + MaxTokens: 1000, + Messages: []*SamplingMessage{ + {Role: "user", Content: &ToolResultContent{ + ToolUseID: "tool_1", + Content: []Content{&TextContent{Text: "division by zero"}}, + IsError: true, + }}, + }, + }) + if err != nil { + t.Fatalf("CreateMessage() error = %v", err) + } + + if len(receivedMessages) != 1 { + t.Fatalf("received %d messages, want 1", len(receivedMessages)) + } + + toolResult, ok := receivedMessages[0].Content.(*ToolResultContent) + if !ok { + t.Fatalf("content type = %T, want *ToolResultContent", receivedMessages[0].Content) + } + if !toolResult.IsError { + t.Error("IsError should be true") + } + if toolResult.ToolUseID != "tool_1" { + t.Errorf("ToolUseID = %v, want tool_1", toolResult.ToolUseID) + } +}