Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/api/providers/__tests__/anthropic-vertex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ vitest.mock("@ai-sdk/google-vertex/anthropic", () => ({
}))

// Mock ai-sdk transform utilities
vitest.mock("../../transform/sanitize-messages", () => ({
sanitizeMessagesForProvider: vitest.fn().mockImplementation((msgs: any[]) => msgs),
}))

vitest.mock("../../transform/ai-sdk", () => ({
convertToAiSdkMessages: vitest.fn().mockReturnValue([{ role: "user", content: [{ type: "text", text: "Hello" }] }]),
convertToolsForAiSdk: vitest.fn().mockReturnValue(undefined),
processAiSdkStreamPart: vitest.fn().mockImplementation(function* (part: any) {
if (part.type === "text-delta") {
Expand All @@ -59,7 +62,7 @@ vitest.mock("../../transform/ai-sdk", () => ({
}))

// Import mocked modules
import { convertToAiSdkMessages, convertToolsForAiSdk, mapToolChoice } from "../../transform/ai-sdk"
import { convertToolsForAiSdk, mapToolChoice } from "../../transform/ai-sdk"
import { Anthropic } from "@anthropic-ai/sdk"

// Helper: create a mock provider function
Expand Down
56 changes: 51 additions & 5 deletions src/api/providers/__tests__/anthropic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ vitest.mock("@ai-sdk/anthropic", () => ({
}))

// Mock ai-sdk transform utilities
vitest.mock("../../transform/sanitize-messages", () => ({
sanitizeMessagesForProvider: vitest.fn().mockImplementation((msgs: any[]) => msgs),
}))

vitest.mock("../../transform/ai-sdk", () => ({
convertToAiSdkMessages: vitest.fn().mockReturnValue([{ role: "user", content: [{ type: "text", text: "Hello" }] }]),
convertToolsForAiSdk: vitest.fn().mockReturnValue(undefined),
processAiSdkStreamPart: vitest.fn().mockImplementation(function* (part: any) {
if (part.type === "text-delta") {
Expand All @@ -54,7 +57,8 @@ vitest.mock("../../transform/ai-sdk", () => ({
}))

// Import mocked modules
import { convertToAiSdkMessages, convertToolsForAiSdk, mapToolChoice } from "../../transform/ai-sdk"
import { convertToolsForAiSdk, mapToolChoice } from "../../transform/ai-sdk"
import { sanitizeMessagesForProvider } from "../../transform/sanitize-messages"
import { Anthropic } from "@anthropic-ai/sdk"

// Helper: create a mock provider function
Expand Down Expand Up @@ -82,9 +86,6 @@ describe("AnthropicHandler", () => {

// Re-set mock defaults after clearAllMocks
mockCreateAnthropic.mockReturnValue(mockProviderFn)
vitest
.mocked(convertToAiSdkMessages)
.mockReturnValue([{ role: "user", content: [{ type: "text", text: "Hello" }] }])
vitest.mocked(convertToolsForAiSdk).mockReturnValue(undefined)
vitest.mocked(mapToolChoice).mockReturnValue(undefined)
})
Expand Down Expand Up @@ -399,6 +400,51 @@ describe("AnthropicHandler", () => {
expect(endChunk).toBeDefined()
})

it("should strip reasoning_details and reasoning_content from messages before sending to API", async () => {
	// Swap the identity stub for the real sanitizer implementation, for this test only
	const actual = await vi.importActual<typeof import("../../transform/sanitize-messages")>(
		"../../transform/sanitize-messages",
	)
	vi.mocked(sanitizeMessagesForProvider).mockImplementation(actual.sanitizeMessagesForProvider)

	setupStreamTextMock([{ type: "text-delta", text: "test" }])

	// Messages carrying legacy extra fields that can survive JSON deserialization
	const staleFieldMessages = [
		{
			role: "user",
			content: [{ type: "text" as const, text: "Hello" }],
		},
		{
			role: "assistant",
			content: [{ type: "text" as const, text: "Hi" }],
			reasoning_details: [{ type: "thinking", thinking: "some reasoning" }],
			reasoning_content: "some reasoning content",
		},
		{
			role: "user",
			content: [{ type: "text" as const, text: "Follow up" }],
		},
	] as any

	const result = handler.createMessage(systemPrompt, staleFieldMessages)

	// Drain the stream so the provider call actually happens
	for await (const _ of result) {
		// consume
	}

	// Exactly one streamText call, with every message scrubbed of the legacy fields
	expect(mockStreamText).toHaveBeenCalledTimes(1)
	const sentArgs = mockStreamText.mock.calls[0]![0]
	for (const message of sentArgs.messages) {
		expect(message).not.toHaveProperty("reasoning_details")
		expect(message).not.toHaveProperty("reasoning_content")
	}
	// The remainder of the assistant message must be untouched
	expect(sentArgs.messages[1].role).toBe("assistant")
	expect(sentArgs.messages[1].content).toEqual([{ type: "text", text: "Hi" }])
})

it("should pass system prompt via system param when no systemProviderOptions", async () => {
setupStreamTextMock([{ type: "text-delta", text: "test" }])

Expand Down
2 changes: 1 addition & 1 deletion src/api/providers/__tests__/azure.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ describe("AzureHandler", () => {
for await (const chunk of stream) {
chunks.push(chunk)
}
}).rejects.toThrow("Azure AI Foundry")
}).rejects.toThrow("API Error")
})
})

Expand Down
4 changes: 2 additions & 2 deletions src/api/providers/__tests__/baseten.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ describe("BasetenHandler", () => {
for await (const _ of stream) {
// consume stream
}
}).rejects.toThrow("Baseten: API Error")
}).rejects.toThrow("API Error")
})

it("should preserve status codes in error handling", async () => {
Expand All @@ -439,7 +439,7 @@ describe("BasetenHandler", () => {
}
expect.fail("Should have thrown an error")
} catch (error: any) {
expect(error.message).toContain("Baseten")
expect(error.message).toContain("Rate limit exceeded")
expect(error.status).toBe(429)
}
})
Expand Down
44 changes: 21 additions & 23 deletions src/api/providers/__tests__/bedrock-error-handling.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,11 @@ describe("AwsBedrockHandler Error Handling", () => {
})

// -----------------------------------------------------------------------
// Non-throttling errors (createMessage) are wrapped by handleAiSdkError
// Non-throttling errors (createMessage) propagate unchanged
// -----------------------------------------------------------------------

describe("Non-throttling errors (createMessage)", () => {
it("should wrap non-throttling errors with provider name via handleAiSdkError", async () => {
it("should propagate non-throttling errors unchanged", async () => {
const genericError = createMockError({
message: "Something completely unexpected happened",
})
Expand All @@ -256,7 +256,7 @@ describe("AwsBedrockHandler Error Handling", () => {
for await (const _chunk of generator) {
// should throw
}
}).rejects.toThrow("Bedrock: Something completely unexpected happened")
}).rejects.toThrow("Something completely unexpected happened")
})

it("should preserve status code from non-throttling API errors", async () => {
Expand All @@ -277,8 +277,7 @@ describe("AwsBedrockHandler Error Handling", () => {
}
throw new Error("Expected error to be thrown")
} catch (error: any) {
expect(error.message).toContain("Bedrock:")
expect(error.message).toContain("Internal server error occurred")
expect(error.message).toBe("Internal server error occurred")
}
})

Expand All @@ -298,7 +297,7 @@ describe("AwsBedrockHandler Error Handling", () => {
for await (const _chunk of generator) {
// should throw
}
}).rejects.toThrow("Bedrock: Too many tokens in request")
}).rejects.toThrow("Too many tokens in request")
})
})

Expand Down Expand Up @@ -334,7 +333,7 @@ describe("AwsBedrockHandler Error Handling", () => {
}).rejects.toThrow("Bedrock is unable to process your request")
})

it("should wrap non-throttling errors that occur mid-stream via handleAiSdkError", async () => {
it("should propagate non-throttling errors that occur mid-stream unchanged", async () => {
const genericError = createMockError({
message: "Some other error",
status: 500,
Expand All @@ -357,32 +356,32 @@ describe("AwsBedrockHandler Error Handling", () => {
for await (const _chunk of generator) {
// should throw
}
}).rejects.toThrow("Bedrock: Some other error")
}).rejects.toThrow("Some other error")
})
})

// -----------------------------------------------------------------------
// completePrompt errors — all go through handleAiSdkError (no throttle check)
// completePrompt errors — propagate unchanged (no throttle check)
// -----------------------------------------------------------------------

describe("completePrompt error handling", () => {
it("should wrap errors with provider name for completePrompt", async () => {
it("should propagate errors unchanged for completePrompt", async () => {
mockGenerateText.mockRejectedValueOnce(new Error("Bedrock API failure"))

await expect(handler.completePrompt("test")).rejects.toThrow("Bedrock: Bedrock API failure")
await expect(handler.completePrompt("test")).rejects.toThrow("Bedrock API failure")
})

it("should wrap throttling-pattern errors with provider name for completePrompt", async () => {
it("should propagate throttling-pattern errors unchanged for completePrompt", async () => {
const throttleError = createMockError({
message: "Bedrock is unable to process your request",
status: 429,
})

mockGenerateText.mockRejectedValueOnce(throttleError)

// completePrompt does NOT have the throttle-rethrow path; it always uses handleAiSdkError
// completePrompt does NOT have the throttle-rethrow path; errors propagate unchanged
await expect(handler.completePrompt("test")).rejects.toThrow(
"Bedrock: Bedrock is unable to process your request",
"Bedrock is unable to process your request",
)
})

Expand All @@ -396,7 +395,7 @@ describe("AwsBedrockHandler Error Handling", () => {
results.forEach((result) => {
expect(result.status).toBe("rejected")
if (result.status === "rejected") {
expect(result.reason.message).toContain("Bedrock:")
expect(result.reason.message).toBe("API failure")
}
})
})
Expand All @@ -413,8 +412,7 @@ describe("AwsBedrockHandler Error Handling", () => {
await handler.completePrompt("test")
throw new Error("Expected error to be thrown")
} catch (error: any) {
expect(error.message).toContain("Bedrock:")
expect(error.message).toContain("Service unavailable")
expect(error.message).toBe("Service unavailable")
}
})
})
Expand Down Expand Up @@ -479,7 +477,8 @@ describe("AwsBedrockHandler Error Handling", () => {
it("should handle non-Error objects thrown by generateText", async () => {
mockGenerateText.mockRejectedValueOnce("string error")

await expect(handler.completePrompt("test")).rejects.toThrow("Bedrock: string error")
// Non-Error values propagate as-is
await expect(handler.completePrompt("test")).rejects.toBe("string error")
})

it("should handle non-Error objects thrown by streamText", async () => {
Expand All @@ -489,12 +488,12 @@ describe("AwsBedrockHandler Error Handling", () => {

const generator = handler.createMessage("system", [{ role: "user", content: "test" }])

// Non-Error values are not detected as throttling → handleAiSdkError path
// Non-Error values are not detected as throttling → propagate as-is
await expect(async () => {
for await (const _chunk of generator) {
// should throw
}
}).rejects.toThrow("Bedrock: string error")
}).rejects.toBe("string error")
})

it("should handle errors with unusual structure gracefully", async () => {
Expand All @@ -505,9 +504,8 @@ describe("AwsBedrockHandler Error Handling", () => {
await handler.completePrompt("test")
throw new Error("Expected error to be thrown")
} catch (error: any) {
// handleAiSdkError wraps with "Bedrock: ..."
expect(error.message).toContain("Bedrock:")
expect(error.message).not.toContain("undefined")
// Errors propagate unchanged — the object's message property is preserved
expect(error.message).toBe("Error with unusual structure")
}
})

Expand Down
2 changes: 1 addition & 1 deletion src/api/providers/__tests__/lmstudio.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ describe("LmStudioHandler", () => {

it("should handle API errors with handleAiSdkError", async () => {
mockGenerateText.mockRejectedValueOnce(new Error("Connection refused"))
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("LM Studio")
await expect(handler.completePrompt("Test prompt")).rejects.toThrow("Connection refused")
})
})

Expand Down
19 changes: 1 addition & 18 deletions src/api/providers/__tests__/minimax.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ const {
mockCreateAnthropic,
mockModel,
mockMergeEnvironmentDetailsForMiniMax,
mockHandleAiSdkError,
} = vi.hoisted(() => {
const mockModel = vi.fn().mockReturnValue("mock-model-instance")
return {
Expand All @@ -24,10 +23,6 @@ const {
mockCreateAnthropic: vi.fn().mockReturnValue(mockModel),
mockModel,
mockMergeEnvironmentDetailsForMiniMax: vi.fn((messages: RooMessage[]) => messages),
mockHandleAiSdkError: vi.fn((error: unknown, providerName: string) => {
const message = error instanceof Error ? error.message : String(error)
return new Error(`${providerName}: ${message}`)
}),
}
})

Expand All @@ -44,13 +39,6 @@ vi.mock("../../transform/minimax-format", () => ({
mergeEnvironmentDetailsForMiniMax: mockMergeEnvironmentDetailsForMiniMax,
}))

vi.mock("../../transform/ai-sdk", async (importOriginal) => {
const actual = await importOriginal<typeof import("../../transform/ai-sdk")>()
return {
...actual,
handleAiSdkError: mockHandleAiSdkError,
}
})

type HandlerOptions = Omit<Partial<ApiHandlerOptions>, "minimaxBaseUrl"> & {
minimaxBaseUrl?: string
Expand Down Expand Up @@ -108,10 +96,6 @@ describe("MiniMaxHandler", () => {
vi.clearAllMocks()
mockCreateAnthropic.mockReturnValue(mockModel)
mockMergeEnvironmentDetailsForMiniMax.mockImplementation((inputMessages: RooMessage[]) => inputMessages)
mockHandleAiSdkError.mockImplementation((error: unknown, providerName: string) => {
const message = error instanceof Error ? error.message : String(error)
return new Error(`${providerName}: ${message}`)
})
})

describe("constructor", () => {
Expand Down Expand Up @@ -359,8 +343,7 @@ describe("MiniMaxHandler", () => {

await expect(async () => {
await collectChunks(stream)
}).rejects.toThrow("MiniMax: API Error")
expect(mockHandleAiSdkError).toHaveBeenCalledWith(expect.any(Error), "MiniMax")
}).rejects.toThrow("API Error")
})
})

Expand Down
27 changes: 27 additions & 0 deletions src/api/providers/__tests__/native-ollama.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,33 @@ describe("NativeOllamaHandler", () => {
}).rejects.toThrow("Ollama service is not running")
})

it("propagates stream error when usage resolution fails after stream error", async () => {
	// A stream that emits a single error part, paired with a usage promise that rejects
	async function* errorOnlyStream() {
		yield { type: "error", error: new Error("upstream provider returned 500") }
	}

	mockStreamText.mockReturnValue({
		fullStream: errorOnlyStream(),
		usage: Promise.reject(new Error("No output generated")),
	})

	const collected: any[] = []
	const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }])

	// Consuming the stream must ultimately reject with the upstream error, not the usage failure
	await expect(async () => {
		for await (const part of stream) {
			collected.push(part)
		}
	}).rejects.toThrow("upstream provider returned 500")

	// The error chunk must have surfaced on the stream before the rejection
	expect(collected).toContainEqual({
		type: "error",
		error: "StreamError",
		message: "upstream provider returned 500",
	})
})

it("should handle model not found errors", async () => {
const error = new Error("Not found") as any
error.status = 404
Expand Down
2 changes: 1 addition & 1 deletion src/api/providers/__tests__/openai-codex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ describe("OpenAiCodexHandler.completePrompt", () => {

mockGenerateText.mockRejectedValue(new Error("API Error"))

await expect(handler.completePrompt("Say hello")).rejects.toThrow("OpenAI Codex")
await expect(handler.completePrompt("Say hello")).rejects.toThrow("API Error")
})

it("should throw when not authenticated", async () => {
Expand Down
Loading
Loading