diff --git a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts index 3b470ce461e..3a01d032e75 100644 --- a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts +++ b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts @@ -1,5 +1,11 @@ // npx vitest run api/providers/__tests__/qwen-code-native-tools.spec.ts +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) + // Mock filesystem - must come before other imports vi.mock("node:fs", () => ({ promises: { @@ -8,25 +14,27 @@ vi.mock("node:fs", () => ({ }, })) -const mockCreate = vi.fn() -vi.mock("openai", () => { +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() return { - __esModule: true, - default: vi.fn().mockImplementation(() => ({ - apiKey: "test-key", - baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", - chat: { - completions: { - create: mockCreate, - }, - }, - })), + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, } }) +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: vi.fn(() => { + // Return a function that returns a mock language model + return vi.fn(() => ({ + modelId: "qwen3-coder-plus", + provider: "qwen-code", + })) + }), +})) + import { promises as fs } from "node:fs" import { QwenCodeHandler } from "../qwen-code" -import { NativeToolCallParser } from "../../../core/assistant-message/NativeToolCallParser" import type { ApiHandlerOptions } from "../../../shared/api" describe("QwenCodeHandler Native Tools", () => { @@ -68,20 +76,18 @@ describe("QwenCodeHandler Native Tools", () => { apiModelId: "qwen3-coder-plus", } handler = new QwenCodeHandler(mockOptions) - - // Clear NativeToolCallParser state before each test - NativeToolCallParser.clearRawChunkState() }) describe("Native Tool Calling Support", () => { it("should include tools in request when model supports native tools and tools are provided", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + }) const stream = handler.createMessage("test prompt", [], { taskId: "test-task-id", @@ -89,29 +95,24 @@ describe("QwenCodeHandler Native Tools", () => { }) await stream.next() - expect(mockCreate).toHaveBeenCalledWith( + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "test_tool", - }), - }), - ]), - parallel_tool_calls: true, + tools: expect.objectContaining({ + test_tool: expect.any(Object), + }), }), ) }) it("should include tool_choice when provided", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, 
outputTokens: 5 }), + }) const stream = handler.createMessage("test prompt", [], { taskId: "test-task-id", @@ -120,21 +121,22 @@ describe("QwenCodeHandler Native Tools", () => { }) await stream.next() - expect(mockCreate).toHaveBeenCalledWith( + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - tool_choice: "auto", + toolChoice: "auto", }), ) }) - it("should always include tools and tool_choice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + it("should always include tools and toolChoice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + }) const stream = handler.createMessage("test prompt", [], { taskId: "test-task-id", @@ -142,51 +144,38 @@ describe("QwenCodeHandler Native Tools", () => { await stream.next() // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0] + const callArgs = mockStreamText.mock.calls[mockStreamText.mock.calls.length - 1][0] expect(callArgs).toHaveProperty("tools") - expect(callArgs).toHaveProperty("tool_choice") - expect(callArgs).toHaveProperty("parallel_tool_calls", true) + expect(callArgs).toHaveProperty("toolChoice") }) - it("should yield tool_call_partial chunks during streaming", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_qwen_123", - function: { - name: "test_tool", - arguments: '{"arg1":', - }, - }, - ], - }, - }, - ], - } - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - function: { - arguments: '"value"}', - }, - }, - ], - }, - }, - ], - } - }, - })) + it("should yield tool call chunks during streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "call_qwen_123", + toolName: "test_tool", + } + yield { + type: "tool-input-delta", + id: "call_qwen_123", + delta: '{"arg1":', + } + yield { + type: "tool-input-delta", + id: "call_qwen_123", + delta: '"value"}', + } + yield { + type: "tool-input-end", + id: "call_qwen_123", + } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + }) const stream = handler.createMessage("test prompt", [], { taskId: "test-task-id", @@ -198,78 +187,37 @@ describe("QwenCodeHandler Native Tools", () => { chunks.push(chunk) } - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: "call_qwen_123", - name: "test_tool", - arguments: '{"arg1":', - }) - - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"value"}', - }) + // Check for tool_call_start, tool_call_delta, and tool_call_end chunks + const startChunks = chunks.filter((chunk) => chunk.type === "tool_call_start") + const deltaChunks = chunks.filter((chunk) => chunk.type === "tool_call_delta") + const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + 
expect(startChunks.length).toBeGreaterThan(0) + expect(deltaChunks.length).toBeGreaterThan(0) + expect(endChunks.length).toBeGreaterThan(0) }) - it("should set parallel_tool_calls based on metadata", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [{ delta: { content: "Test response" } }], - } - }, - })) + it("should yield tool_call_end events when tool call is complete", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "call_qwen_test", + toolName: "test_tool", + } + yield { + type: "tool-input-delta", + id: "call_qwen_test", + delta: '{"arg1":"value"}', + } + yield { + type: "tool-input-end", + id: "call_qwen_test", + } + } - const stream = handler.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - parallelToolCalls: true, + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), }) - await stream.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - parallel_tool_calls: true, - }), - ) - }) - - it("should yield tool_call_end events when finish_reason is tool_calls", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_qwen_test", - function: { - name: "test_tool", - arguments: '{"arg1":"value"}', - }, - }, - ], - }, - }, - ], - } - yield { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - }, - ], - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, - } - }, - })) const stream = handler.createMessage("test prompt", [], { taskId: "test-task-id", @@ -278,68 +226,41 @@ describe("QwenCodeHandler Native Tools", () => { const chunks = [] for await (const chunk of stream) { - // Simulate what Task.ts does: when we receive tool_call_partial, - // process it through NativeToolCallParser to populate rawChunkTracker - if (chunk.type === "tool_call_partial") { - NativeToolCallParser.processRawChunk({ - index: chunk.index, - id: chunk.id, - name: chunk.name, - arguments: chunk.arguments, - }) - } chunks.push(chunk) } - // Should have tool_call_partial and tool_call_end - const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") + // Should have tool_call_end from the tool-input-end event const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") - - expect(partialChunks).toHaveLength(1) expect(endChunks).toHaveLength(1) expect(endChunks[0].id).toBe("call_qwen_test") }) - it("should preserve thinking block handling alongside tool calls", async () => { - mockCreate.mockImplementationOnce(() => ({ - [Symbol.asyncIterator]: async function* () { - yield { - choices: [ - { - delta: { - reasoning_content: "Thinking about this...", - }, - }, - ], - } - yield { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_after_think", - function: { - name: "test_tool", - arguments: '{"arg1":"result"}', - }, - }, - ], - }, - }, - ], - } - yield { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - }, - ], - } - }, - })) + it("should preserve reasoning handling alongside tool calls", async () => { + async function* mockFullStream() { + yield { + type: "reasoning", + text: "Thinking about this...", + } + yield { + type: "tool-input-start", + id: "call_after_think", + toolName: "test_tool", + } + yield { + type: "tool-input-delta", + 
id: "call_after_think", + delta: '{"arg1":"result"}', + } + yield { + type: "tool-input-end", + id: "call_after_think", + } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + }) const stream = handler.createMessage("test prompt", [], { taskId: "test-task-id", @@ -348,26 +269,156 @@ describe("QwenCodeHandler Native Tools", () => { const chunks = [] for await (const chunk of stream) { - if (chunk.type === "tool_call_partial") { - NativeToolCallParser.processRawChunk({ - index: chunk.index, - id: chunk.id, - name: chunk.name, - arguments: chunk.arguments, - }) - } chunks.push(chunk) } - // Should have reasoning, tool_call_partial, and tool_call_end + // Should have reasoning and tool_call_end const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning") - const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") expect(reasoningChunks).toHaveLength(1) expect(reasoningChunks[0].text).toBe("Thinking about this...") - expect(partialChunks).toHaveLength(1) expect(endChunks).toHaveLength(1) }) }) + + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), + ) + }) + }) + + describe("OAuth credential handling", () => { + it("should load credentials from file", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + }) + + const stream = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + }) + await stream.next() + + expect(fs.readFile).toHaveBeenCalled() + }) + + it("should refresh token when expired", async () => { + // Mock expired credentials + const expiredCredentials = { + access_token: "expired-token", + refresh_token: "test-refresh-token", + token_type: "Bearer", + expiry_date: Date.now() - 1000, // Expired 1 second ago + resource_url: "https://dashscope.aliyuncs.com/compatible-mode/v1", + } + ;(fs.readFile as any).mockResolvedValue(JSON.stringify(expiredCredentials)) + + // Mock the token refresh endpoint + const mockFetch = vi.fn().mockResolvedValue({ + ok: true, + json: async () => ({ + access_token: "new-access-token", + refresh_token: "new-refresh-token", + token_type: "Bearer", + expires_in: 3600, + }), + }) + global.fetch = mockFetch + + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + }) + + const stream = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + }) + await stream.next() + + // Should have called fetch to refresh the token + expect(mockFetch).toHaveBeenCalledWith( + expect.stringContaining("oauth2/token"), + expect.objectContaining({ + method: "POST", + }), + ) + + // Should have saved the new credentials + expect(fs.writeFile).toHaveBeenCalled() + }) + }) + + describe("getModel", () => { + it("should return model info for valid model ID", () => { + 
const model = handler.getModel() + expect(model.id).toBe("qwen3-coder-plus") + expect(model.info).toBeDefined() + expect(model.info.maxTokens).toBe(65536) + expect(model.info.contextWindow).toBe(1000000) + }) + + it("should return default model if no model ID is provided", () => { + const handlerWithoutModel = new QwenCodeHandler({}) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe("qwen3-coder-plus") + expect(model.info).toBeDefined() + }) + + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") + }) + }) + + describe("usage metrics", () => { + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + details: {}, + }), + }) + + const stream = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(5) + }) + }) }) diff --git a/src/api/providers/qwen-code.ts b/src/api/providers/qwen-code.ts index 18d09a59f3b..886452d645b 100644 --- a/src/api/providers/qwen-code.ts +++ b/src/api/providers/qwen-code.ts @@ -1,18 +1,26 @@ import { promises as fs } from "node:fs" -import { Anthropic } from "@anthropic-ai/sdk" -import OpenAI from "openai" import * as os from "os" import * as path from "path" +import { Anthropic } from "@anthropic-ai/sdk" +import { createOpenAICompatible } from "@ai-sdk/openai-compatible" +import { streamText, generateText, ToolSet } from "ai" + import { type ModelInfo, type QwenCodeModelId, qwenCodeModels, qwenCodeDefaultModelId } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser" - -import { convertToOpenAiMessages } from "../transform/openai-format" -import { ApiStream } from "../transform/stream" - +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" @@ -51,10 +59,14 @@ function objectToUrlEncoded(data: Record<string, string>): string { .join("&") } +/** + * Qwen Code provider using @ai-sdk/openai-compatible. + * Uses OAuth credentials for authentication with automatic token refresh.
+ */ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHandler { protected options: QwenCodeHandlerOptions private credentials: QwenOAuthCredentials | null = null - private client: OpenAI | undefined + private provider: ReturnType<typeof createOpenAICompatible> | null = null private refreshPromise: Promise<QwenOAuthCredentials> | null = null constructor(options: QwenCodeHandlerOptions) { @@ -62,18 +74,6 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan this.options = options } - private ensureClient(): OpenAI { - if (!this.client) { - // Create the client instance with dummy key initially - // The API key will be updated dynamically via ensureAuthenticated - this.client = new OpenAI({ - apiKey: "dummy-key-will-be-replaced", - baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", - }) - } - return this.client - } - private async loadCachedQwenCredentials(): Promise<QwenOAuthCredentials> { try { const keyFile = getQwenCachedCredentialPath(this.options.qwenCodeOauthPath) @@ -170,12 +170,9 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan if (!this.isTokenValid(this.credentials)) { this.credentials = await this.refreshAccessToken(this.credentials) + // Invalidate provider when credentials change + this.provider = null } - - // After authentication, update the apiKey and baseURL on the existing client - const client = this.ensureClient() - client.apiKey = this.credentials.access_token - client.baseURL = this.getBaseUrl(this.credentials) } private getBaseUrl(creds: QwenOAuthCredentials): string { @@ -186,154 +183,182 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan return baseUrl.endsWith("/v1") ? baseUrl : `${baseUrl}/v1` } - private async callApiWithRetry<T>(apiCall: () => Promise<T>): Promise<T> { - try { - return await apiCall() - } catch (error: any) { - if (error.status === 401) { - // Token expired, refresh and retry - this.credentials = await this.refreshAccessToken(this.credentials!) - const client = this.ensureClient() - client.apiKey = this.credentials.access_token - client.baseURL = this.getBaseUrl(this.credentials) - return await apiCall() - } else { - throw error - } + /** + * Get or create the AI SDK provider. Creates a new provider if credentials have changed. + */ + private getProvider(): ReturnType<typeof createOpenAICompatible> { + if (!this.credentials) { + throw new Error("Not authenticated. Call ensureAuthenticated first.") } + + if (!this.provider) { + this.provider = createOpenAICompatible({ + name: "qwen-code", + baseURL: this.getBaseUrl(this.credentials), + apiKey: this.credentials.access_token, + headers: DEFAULT_HEADERS, + }) + } + + return this.provider + } + + /** + * Get the language model for the configured model ID. + */ + private getLanguageModel() { + const { id } = this.getModel() + const provider = this.getProvider() + return provider(id) + } + + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { + const id = this.options.apiModelId ?? qwenCodeDefaultModelId + const info = qwenCodeModels[id as keyof typeof qwenCodeModels] || qwenCodeModels[qwenCodeDefaultModelId] + const params = getModelParams({ format: "openai", modelId: id, model: info, settings: this.options }) + return { id, info, ...params } + } + + /** + * Process usage metrics from the AI SDK response.
+ */ + protected processUsageMetrics(usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }): ApiStreamUsageChunk { + return { + type: "usage", + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + cacheReadTokens: usage.details?.cachedInputTokens, + reasoningTokens: usage.details?.reasoningTokens, + } + } + + /** + * Get the max output tokens for requests. + */ + protected getMaxOutputTokens(): number | undefined { + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined } + /** + * Create a message stream using the AI SDK. + */ override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { await this.ensureAuthenticated() - const client = this.ensureClient() - const model = this.getModel() - - const systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = { - role: "system", - content: systemPrompt, - } - const convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)] - - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model: model.id, - temperature: 0, - messages: convertedMessages, - stream: true, - stream_options: { include_usage: true }, - max_completion_tokens: model.info.maxTokens, - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, + const model = this.getModel() + const languageModel = this.getLanguageModel() + + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(messages) + + // Convert tools to OpenAI format first, then to AI SDK format + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + + // Build the request options + const requestOptions: Parameters<typeof streamText>[0] = { + model: languageModel, + system: systemPrompt, + messages: aiSdkMessages, + temperature: model.temperature ?? 0, + maxOutputTokens: this.getMaxOutputTokens(), + tools: aiSdkTools, + toolChoice: mapToolChoice(metadata?.tool_choice), } - const stream = await this.callApiWithRetry(() => client.chat.completions.create(requestOptions)) - - let fullContent = "" - - for await (const apiChunk of stream) { - const delta = apiChunk.choices[0]?.delta ??
{} - const finishReason = apiChunk.choices[0]?.finish_reason + try { + // Use streamText for streaming responses + const result = streamText(requestOptions) - if (delta.content) { - let newText = delta.content - if (newText.startsWith(fullContent)) { - newText = newText.substring(fullContent.length) - } - fullContent = delta.content - - if (newText) { - // Check for thinking blocks - if (newText.includes("") || newText.includes("")) { - // Simple parsing for thinking blocks - const parts = newText.split(/<\/?think>/g) - for (let i = 0; i < parts.length; i++) { - if (parts[i]) { - if (i % 2 === 0) { - // Outside thinking block - yield { - type: "text", - text: parts[i], - } - } else { - // Inside thinking block - yield { - type: "reasoning", - text: parts[i], - } - } - } - } - } else { - yield { - type: "text", - text: newText, - } - } + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk } } - if ("reasoning_content" in delta && delta.reasoning_content) { - yield { - type: "reasoning", - text: (delta.reasoning_content as string | undefined) || "", - } + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) } + } catch (error: any) { + // Handle 401 errors by refreshing token and retrying once + if (error?.statusCode === 401 || error?.response?.status === 401) { + this.credentials = await this.refreshAccessToken(this.credentials!) + this.provider = null // Force provider recreation - // Handle tool calls in stream - emit partial chunks for NativeToolCallParser - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, + // Retry with new credentials + const retryResult = streamText({ + ...requestOptions, + model: this.getLanguageModel(), + }) + + for await (const part of retryResult.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk } } - } - // Process finish_reason to emit tool_call_end events - if (finishReason) { - const endEvents = NativeToolCallParser.processFinishReason(finishReason) - for (const event of endEvents) { - yield event + const retryUsage = await retryResult.usage + if (retryUsage) { + yield this.processUsageMetrics(retryUsage) } + return } - if (apiChunk.usage) { - yield { - type: "usage", - inputTokens: apiChunk.usage.prompt_tokens || 0, - outputTokens: apiChunk.usage.completion_tokens || 0, - } - } + // Handle other AI SDK errors + throw handleAiSdkError(error, "Qwen Code") } } - override getModel(): { id: string; info: ModelInfo } { - const id = this.options.apiModelId ?? qwenCodeDefaultModelId - const info = qwenCodeModels[id as keyof typeof qwenCodeModels] || qwenCodeModels[qwenCodeDefaultModelId] - return { id, info } - } - + /** + * Complete a prompt using the AI SDK generateText. 
+ */ async completePrompt(prompt: string): Promise<string> { await this.ensureAuthenticated() - const client = this.ensureClient() - const model = this.getModel() - const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { - model: model.id, - messages: [{ role: "user", content: prompt }], - max_completion_tokens: model.info.maxTokens, - } + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() - const response = await this.callApiWithRetry(() => client.chat.completions.create(requestOptions)) + try { + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: temperature ?? 0, + }) - return response.choices[0]?.message.content || "" + return text + } catch (error: any) { + // Handle 401 errors by refreshing token and retrying once + if (error?.statusCode === 401 || error?.response?.status === 401) { + this.credentials = await this.refreshAccessToken(this.credentials!) + this.provider = null + + const { text } = await generateText({ + model: this.getLanguageModel(), + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: temperature ?? 0, + }) + + return text + } + + throw handleAiSdkError(error, "Qwen Code") + } } }