diff --git a/bom/application/pom.xml b/bom/application/pom.xml index b09f00c2486..f739f4a5340 100644 --- a/bom/application/pom.xml +++ b/bom/application/pom.xml @@ -80,12 +80,46 @@ - - + + io.projectreactor + reactor-core + 3.4.41 + - io.netty - netty-all - --> + + + io.netty + netty-common + 4.1.118.Final + + + io.netty + netty-buffer + 4.1.118.Final + + + io.netty + netty-transport + 4.1.118.Final + + + io.netty + netty-resolver + 4.1.118.Final + + + io.netty + netty-codec + 4.1.118.Final + + + io.netty + netty-handler + 4.1.118.Final + + dev.langchain4j + langchain4j-azure-open-ai + jakarta.inject jakarta.inject-api diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java index 8effc93212c..c93a10f9298 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java @@ -1,5 +1,9 @@ package com.dotcms.ai.client.langchain4j; +import dev.langchain4j.model.azure.AzureOpenAiChatModel; +import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel; +import dev.langchain4j.model.azure.AzureOpenAiImageModel; +import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel; import dev.langchain4j.model.chat.ChatModel; import dev.langchain4j.model.chat.StreamingChatModel; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -20,8 +24,8 @@ * To add support for a new provider, add a case to each switch block below. * No other class needs to change. * - *

Supported providers (Phase 1): {@code openai} - *

Planned (Phase 2): {@code azure_openai}, {@code bedrock}, {@code vertex_ai} + *

Supported providers: {@code openai}, {@code azure_openai} + *

Planned (Phase 2): {@code bedrock}, {@code vertex_ai} */ public class LangChain4jModelFactory { @@ -35,7 +39,9 @@ private LangChain4jModelFactory() {} * @throws IllegalArgumentException if config or provider is null, or the provider is unsupported */ public static ChatModel buildChatModel(final ProviderConfig config) { - return build(config, "chat", LangChain4jModelFactory::buildOpenAiChatModel); + return build(config, "chat", + LangChain4jModelFactory::buildOpenAiChatModel, + LangChain4jModelFactory::buildAzureOpenAiChatModel); } /** @@ -46,7 +52,9 @@ public static ChatModel buildChatModel(final ProviderConfig config) { * @throws IllegalArgumentException if config or provider is null, or the provider is unsupported */ public static StreamingChatModel buildStreamingChatModel(final ProviderConfig config) { - return build(config, "chat", LangChain4jModelFactory::buildOpenAiStreamingChatModel); + return build(config, "chat", + LangChain4jModelFactory::buildOpenAiStreamingChatModel, + LangChain4jModelFactory::buildAzureOpenAiStreamingChatModel); } /** @@ -57,7 +65,9 @@ public static StreamingChatModel buildStreamingChatModel(final ProviderConfig co * @throws IllegalArgumentException if config or provider is null, or the provider is unsupported */ public static EmbeddingModel buildEmbeddingModel(final ProviderConfig config) { - return build(config, "embeddings", LangChain4jModelFactory::buildOpenAiEmbeddingModel); + return build(config, "embeddings", + LangChain4jModelFactory::buildOpenAiEmbeddingModel, + LangChain4jModelFactory::buildAzureOpenAiEmbeddingModel); } /** @@ -68,23 +78,29 @@ public static EmbeddingModel buildEmbeddingModel(final ProviderConfig config) { * @throws IllegalArgumentException if config or provider is null, or the provider is unsupported */ public static ImageModel buildImageModel(final ProviderConfig config) { - return build(config, "image", LangChain4jModelFactory::buildOpenAiImageModel); + return build(config, "image", + LangChain4jModelFactory::buildOpenAiImageModel, + LangChain4jModelFactory::buildAzureOpenAiImageModel); } private static T build(final ProviderConfig config, final String modelType, - final Function openAiFn) { + final Function openAiFn, + final Function azureOpenAiFn) { if (config == null || config.provider() == null) { throw new IllegalArgumentException("ProviderConfig or provider name is null for model type: " + modelType); } - requireNonBlank(config.model(), "model", modelType); switch (config.provider().toLowerCase()) { case "openai": + requireNonBlank(config.model(), "model", modelType); validateOpenAi(config, modelType); return openAiFn.apply(config); + case "azure_openai": + validateAzureOpenAi(config, modelType); + return azureOpenAiFn.apply(config); default: throw new IllegalArgumentException("Unsupported " + modelType + " provider: " - + config.provider() + ". Supported in Phase 1: openai"); + + config.provider() + ". Supported: openai, azure_openai"); } } @@ -92,6 +108,16 @@ private static void validateOpenAi(final ProviderConfig config, final String mod requireNonBlank(config.apiKey(), "apiKey", modelType); } + private static void validateAzureOpenAi(final ProviderConfig config, final String modelType) { + requireNonBlank(config.apiKey(), "apiKey", modelType); + requireNonBlank(config.endpoint(), "endpoint", modelType); + if ((config.model() == null || config.model().isBlank()) + && (config.deploymentName() == null || config.deploymentName().isBlank())) { + throw new IllegalArgumentException( + "providerConfig." + modelType + ": either 'model' or 'deploymentName' is required for azure_openai"); + } + } + private static void requireNonBlank(final String value, final String field, final String modelType) { if (value == null || value.isBlank()) { throw new IllegalArgumentException( @@ -162,4 +188,55 @@ private static ImageModel buildOpenAiImageModel(final ProviderConfig config) { return builder.build(); } + // ── Azure OpenAI builders ───────────────────────────────────────────────── + + private static StreamingChatModel buildAzureOpenAiStreamingChatModel(final ProviderConfig config) { + final AzureOpenAiStreamingChatModel.Builder builder = AzureOpenAiStreamingChatModel.builder() + .apiKey(config.apiKey()) + .endpoint(config.endpoint()) + .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model()); + if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion()); + if (config.maxRetries() != null) builder.maxRetries(config.maxRetries()); + if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout())); + if (config.temperature() != null) builder.temperature(config.temperature()); + if (config.maxTokens() != null) builder.maxTokens(config.maxTokens()); + return builder.build(); + } + + private static ChatModel buildAzureOpenAiChatModel(final ProviderConfig config) { + final AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder() + .apiKey(config.apiKey()) + .endpoint(config.endpoint()) + .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model()); + if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion()); + if (config.maxRetries() != null) builder.maxRetries(config.maxRetries()); + if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout())); + if (config.temperature() != null) builder.temperature(config.temperature()); + if (config.maxTokens() != null) builder.maxTokens(config.maxTokens()); + return builder.build(); + } + + private static EmbeddingModel buildAzureOpenAiEmbeddingModel(final ProviderConfig config) { + final AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder() + .apiKey(config.apiKey()) + .endpoint(config.endpoint()) + .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model()); + if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion()); + if (config.maxRetries() != null) builder.maxRetries(config.maxRetries()); + if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout())); + return builder.build(); + } + + private static ImageModel buildAzureOpenAiImageModel(final ProviderConfig config) { + final AzureOpenAiImageModel.Builder builder = AzureOpenAiImageModel.builder() + .apiKey(config.apiKey()) + .endpoint(config.endpoint()) + .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model()); + if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion()); + if (config.maxRetries() != null) builder.maxRetries(config.maxRetries()); + if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout())); + if (config.size() != null) builder.size(config.size()); + return builder.build(); + } + } diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java index 04fc382a586..b0e1e0a26f1 100644 --- a/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java +++ b/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java @@ -39,6 +39,32 @@ public void test_buildChatModel_openai_missingApiKey_throws() { assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config)); } + @Test + public void test_buildChatModel_azureOpenai_returnsModel() { + final ChatModel model = LangChain4jModelFactory.buildChatModel(azureOpenAiConfig("gpt-4o")); + assertNotNull(model); + } + + @Test + public void test_buildChatModel_azureOpenai_missingApiKey_throws() { + final ProviderConfig config = ImmutableProviderConfig.builder() + .provider("azure_openai") + .model("gpt-4o") + .endpoint("https://my-company.openai.azure.com/") + .build(); + assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config)); + } + + @Test + public void test_buildChatModel_azureOpenai_missingEndpoint_throws() { + final ProviderConfig config = ImmutableProviderConfig.builder() + .provider("azure_openai") + .model("gpt-4o") + .apiKey("test-key") + .build(); + assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config)); + } + @Test public void test_buildChatModel_unknownProvider_throws() { final ProviderConfig config = ImmutableProviderConfig.builder() @@ -60,6 +86,12 @@ public void test_buildEmbeddingModel_openai_returnsModel() { assertNotNull(model); } + @Test + public void test_buildEmbeddingModel_azureOpenai_returnsModel() { + final EmbeddingModel model = LangChain4jModelFactory.buildEmbeddingModel(azureOpenAiConfig("text-embedding-ada-002")); + assertNotNull(model); + } + @Test public void test_buildEmbeddingModel_unknownProvider_throws() { final ProviderConfig config = ImmutableProviderConfig.builder() @@ -81,6 +113,12 @@ public void test_buildImageModel_openai_returnsModel() { assertNotNull(model); } + @Test + public void test_buildImageModel_azureOpenai_returnsModel() { + final ImageModel model = LangChain4jModelFactory.buildImageModel(azureOpenAiConfig("dall-e-3")); + assertNotNull(model); + } + @Test public void test_buildImageModel_unknownProvider_throws() { final ProviderConfig config = ImmutableProviderConfig.builder() @@ -99,4 +137,15 @@ private static ProviderConfig openAiConfig(final String model) { .build(); } + private static ProviderConfig azureOpenAiConfig(final String model) { + return ImmutableProviderConfig.builder() + .provider("azure_openai") + .model(model) + .apiKey("test-key") + .endpoint("https://my-company.openai.azure.com/") + .deploymentName(model) + .apiVersion("2024-02-01") + .build(); + } + }