Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
*/
public abstract class BaseChatModelConnection extends Resource {

/**
* Reserved {@code modelParams} key carrying the raw output schema (a POJO {@link Class} or an
* {@link org.apache.flink.agents.api.agents.OutputSchema}) down to the connection so it can
* apply the provider's native structured-output mechanism. The key is intra-language only and
* must be removed before the provider SDK call so it never reaches the request body.
*/
public static final String STRUCTURED_OUTPUT_SCHEMA_KEY = "__structured_output_schema__";

public BaseChatModelConnection(ResourceDescriptor descriptor, ResourceContext resourceContext) {
super(descriptor, resourceContext);
}
Expand All @@ -45,6 +53,33 @@ public ResourceType getResourceType() {
return ResourceType.CHAT_MODEL_CONNECTION;
}

/**
* Whether this connection applies the provider's native structured-output API when an output
* schema is supplied. Connections that translate a schema into a native provider parameter
* override this to return {@code true}; the default false keeps non-native connections on the
* prompt-engineering fallback.
*
* @return true if this connection supports native structured output
*/
protected boolean supportsNativeStructuredOutput() {
return false;
}

/**
* Removes and returns the reserved structured-output schema from {@code modelParams}. Every
* connection must call this so the reserved key never leaks into the provider SDK request;
* native connections additionally use the returned value to build the native parameter.
*
* @param modelParams the mutable model parameters map (may be null)
* @return the raw output schema if present, otherwise null
*/
protected static Object popStructuredOutputSchema(Map<String, Object> modelParams) {
if (modelParams == null) {
return null;
}
return modelParams.remove(STRUCTURED_OUTPUT_SCHEMA_KEY);
}

/**
* Process a chat request and return a chat response.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,33 @@ void testChatRefillsTemplateOnSubsequentInvocations() {
assertEquals("tool result", connection.capturedMessages.get(1).getContent());
}

@Test
@DisplayName("popStructuredOutputSchema removes the reserved key and returns its value")
void testPopStructuredOutputSchemaRemovesAndReturns() {
Object schema = new Object();
Map<String, Object> modelParams = new HashMap<>();
modelParams.put(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY, schema);
modelParams.put("temperature", 0.5);

Object popped = BaseChatModelConnection.popStructuredOutputSchema(modelParams);

assertSame(schema, popped);
assertFalse(modelParams.containsKey(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY));
assertTrue(modelParams.containsKey("temperature"));
}

@Test
@DisplayName(
"popStructuredOutputSchema returns null when the reserved key is absent or map is null")
void testPopStructuredOutputSchemaNoKey() {
Map<String, Object> modelParams = new HashMap<>();
modelParams.put("temperature", 0.5);

assertNull(BaseChatModelConnection.popStructuredOutputSchema(modelParams));
assertEquals(1, modelParams.size());
assertNull(BaseChatModelConnection.popStructuredOutputSchema(null));
}

@Test
@DisplayName("Test chat with long input")
void testChatWithLongInput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.openai.client.OpenAIClient;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.core.JsonSchemaLocalValidation;
import com.openai.core.JsonValue;
import com.openai.models.ChatModel;
import com.openai.models.FunctionDefinition;
import com.openai.models.FunctionParameters;
import com.openai.models.ReasoningEffort;
import com.openai.models.ResponseFormatJsonSchema;
import com.openai.models.chat.completions.ChatCompletion;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import com.openai.models.chat.completions.ChatCompletionFunctionTool;
Expand Down Expand Up @@ -119,6 +121,11 @@ public OpenAICompletionsConnection(
this.client = builder.build();
}

@Override
protected boolean supportsNativeStructuredOutput() {
return true;
}

@Override
public ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> modelParams) {
Expand Down Expand Up @@ -150,7 +157,9 @@ public ChatMessage chat(
}
}

private ChatCompletionCreateParams buildRequest(
// Package-private so the request body (including the native response_format) can be asserted
// without issuing a live API call through the final OpenAI client.
ChatCompletionCreateParams buildRequest(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> rawModelParams) {
Map<String, Object> modelParams =
rawModelParams != null ? new HashMap<>(rawModelParams) : new HashMap<>();
Expand All @@ -161,15 +170,25 @@ private ChatCompletionCreateParams buildRequest(
modelName = this.defaultModel;
}

// Always pop the reserved schema so it never leaks into the SDK request body.
Object outputSchema = popStructuredOutputSchema(modelParams);

ChatCompletionCreateParams.Builder builder =
ChatCompletionCreateParams.builder()
.model(ChatModel.of(modelName))
.messages(OpenAIChatCompletionsUtils.convertToOpenAIMessages(messages));

if (tools != null && !tools.isEmpty()) {
boolean hasTools = tools != null && !tools.isEmpty();
if (hasTools) {
builder.tools(convertTools(tools, strictMode));
}

// Native structured output applies only for a POJO Class schema when no tools are bound;
// a RowTypeInfo (wrapped in OutputSchema) keeps the prompt-engineering fallback.
if (outputSchema instanceof Class && !hasTools) {
builder.responseFormat(toNativeResponseFormat((Class<?>) outputSchema));
}

Object temperature = modelParams.remove("temperature");
if (temperature instanceof Number) {
builder.temperature(((Number) temperature).doubleValue());
Expand Down Expand Up @@ -208,6 +227,26 @@ private ChatCompletionCreateParams buildRequest(
return builder.build();
}

// Derives the strict json_schema response format from a POJO class via the SDK's typed
// structured-output builder. The Kotlin-facade StructuredOutputsKt.responseFormatFromClass is
// not callable from Java, so the response format is extracted through the typed builder, which
// generates the same strict draft-2020-12 schema, and then reattached to the standard builder.
private static <T> ResponseFormatJsonSchema toNativeResponseFormat(Class<T> schemaClass) {
return ChatCompletionCreateParams.builder()
.model(ChatModel.of(""))
.addUserMessage("")
.responseFormat(schemaClass, JsonSchemaLocalValidation.NO)
.build()
.rawParams()
.responseFormat()
.orElseThrow(
() ->
new IllegalStateException(
"OpenAI SDK did not produce a response_format for schema "
+ schemaClass.getName()))
.asJsonSchema();
}

private List<ChatCompletionTool> convertTools(List<Tool> tools, boolean strictMode) {
List<ChatCompletionTool> openaiTools = new ArrayList<>(tools.size());
for (Tool tool : tools) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.agents.integrations.chatmodels.openai;

import com.openai.models.ResponseFormatJsonSchema;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.chat.model.BaseChatModelConnection;
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.agents.api.tools.ToolMetadata;
import org.apache.flink.agents.api.tools.ToolParameters;
import org.apache.flink.agents.api.tools.ToolResponse;
import org.apache.flink.agents.api.tools.ToolType;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Unit tests for {@link OpenAICompletionsConnection}'s native structured-output behavior. These
* assert the built request body without a live API call by inspecting {@code buildRequest}.
*/
class OpenAICompletionsConnectionTest {

private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null);

/** A representative POJO output schema. */
public static class Person {
public String name;
public int age;
}

private static OpenAICompletionsConnection connection() {
ResourceDescriptor desc =
ResourceDescriptor.Builder.newBuilder(OpenAICompletionsConnection.class.getName())
.addInitialArgument("api_key", "test-key")
.addInitialArgument("model", "gpt-4o")
.build();
return new OpenAICompletionsConnection(desc, NOOP);
}

private static Map<String, Object> paramsWithSchema(Object schema) {
Map<String, Object> params = new HashMap<>();
params.put("model", "gpt-4o");
if (schema != null) {
params.put(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY, schema);
}
return params;
}

private static List<ChatMessage> userMessage() {
return List.of(new ChatMessage(MessageRole.USER, "hi"));
}

@Test
@DisplayName("Native response_format json_schema strict applied for a POJO and no tools")
void testNativeAppliedForPojoNoTools() {
ChatCompletionCreateParams params =
connection().buildRequest(userMessage(), List.of(), paramsWithSchema(Person.class));

assertThat(params.responseFormat()).isPresent();
ResponseFormatJsonSchema jsonSchema = params.responseFormat().get().asJsonSchema();
assertThat(jsonSchema.jsonSchema().strict()).contains(true);
}

/** Minimal tool stub; only its presence in the list matters for the empty-tools gate. */
private static class StubTool extends Tool {
StubTool() {
super(new ToolMetadata("add", "adds", "{\"type\":\"object\"}"));
}

@Override
public ToolType getToolType() {
return ToolType.FUNCTION;
}

@Override
public ToolResponse call(ToolParameters parameters) {
return ToolResponse.success(null);
}
}

@Test
@DisplayName("Native NOT applied when tools are bound (empty-tools gate)")
void testNativeNotAppliedWithTools() {
ChatCompletionCreateParams params =
connection()
.buildRequest(
userMessage(),
List.of(new StubTool()),
paramsWithSchema(Person.class));

assertThat(params.responseFormat()).isEmpty();
}

@Test
@DisplayName("Native NOT applied for a non-POJO schema form (BaseModel/POJO-only scope)")
void testNativeNotAppliedForNonClassSchema() {
// A RowTypeInfo schema arrives wrapped (not a bare POJO Class), so it must not activate
// native structured output; any non-Class schema object exercises the same gate.
Object nonClassSchema = "row<name STRING>";

ChatCompletionCreateParams params =
connection()
.buildRequest(userMessage(), List.of(), paramsWithSchema(nonClassSchema));

assertThat(params.responseFormat()).isEmpty();
}

@Test
@DisplayName(
"Reserved schema key is consumed as response_format, not passed through as a body property")
void testReservedKeyConsumedNotLeaked() {
// The reserved key is consumed by the native path into response_format rather than left
// in the modelParams to leak. The pop-helper's remove-and-return contract (which makes
// this possible) is exercised directly in BaseChatModelTest; this case pins that for the
// OpenAI connection the reserved key drives response_format and is absent from the body.
ChatCompletionCreateParams params =
connection().buildRequest(userMessage(), List.of(), paramsWithSchema(Person.class));

assertThat(params.responseFormat()).isPresent();
assertThat(params._additionalBodyProperties())
.doesNotContainKey(BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY);
}

@Test
@DisplayName("OpenAI completions declares native structured-output support")
void testDeclaresNativeCapability() {
assertThat(connection().supportsNativeStructuredOutput()).isTrue();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.flink.agents.api.agents.OutputSchema;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.chat.model.BaseChatModelConnection;
import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
import org.apache.flink.agents.api.chat.model.python.PythonChatModelSetup;
import org.apache.flink.agents.api.context.DurableCallable;
Expand Down Expand Up @@ -348,6 +349,18 @@ public static void chat(
int actualRetryCount = 0;
int totalWaitTimeSec = 0;

// Thread the output schema to the connection via a reserved modelParams key so a
// native-capable connection can apply the provider's structured-output API. The
// connection pops the key before its SDK call (see BaseChatModelConnection). Only
// thread it for a same-language (Java) setup: native structured output cannot work
// across the Pemja bridge because a Java schema object is not consumable by a Python
// connection, so a Python-backed setup keeps the prior empty-map behavior.
final Map<String, Object> modelParams =
outputSchema != null && !(chatModel instanceof PythonChatModelSetup)
? Collections.singletonMap(
BaseChatModelConnection.STRUCTURED_OUTPUT_SCHEMA_KEY, outputSchema)
: Map.of();

DurableCallable<ChatMessage> callable =
new DurableCallable<>() {
@Override
Expand All @@ -362,7 +375,7 @@ public Class<ChatMessage> getResultClass() {

@Override
public ChatMessage call() throws Exception {
return chatModel.chat(messages, promptArgs, Map.of());
return chatModel.chat(messages, promptArgs, modelParams);
}
};

Expand Down
Loading
Loading