jakarta.inject
jakarta.inject-api
diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java
index 8effc93212c..c93a10f9298 100644
--- a/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java
+++ b/dotCMS/src/main/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactory.java
@@ -1,5 +1,9 @@
package com.dotcms.ai.client.langchain4j;
+import dev.langchain4j.model.azure.AzureOpenAiChatModel;
+import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
+import dev.langchain4j.model.azure.AzureOpenAiImageModel;
+import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -20,8 +24,8 @@
* To add support for a new provider, add a case to each switch block below.
* No other class needs to change.
*
- * Supported providers (Phase 1): {@code openai}
- * Supported providers (Phase 1): {@code openai}
- * Planned (Phase 2): {@code azure_openai}, {@code bedrock}, {@code vertex_ai}
+ * Supported providers: {@code openai}, {@code azure_openai}
+ * Planned (Phase 2): {@code bedrock}, {@code vertex_ai}
*/
public class LangChain4jModelFactory {
@@ -35,7 +39,9 @@ private LangChain4jModelFactory() {}
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static ChatModel buildChatModel(final ProviderConfig config) {
- return build(config, "chat", LangChain4jModelFactory::buildOpenAiChatModel);
+ return build(config, "chat",
+ LangChain4jModelFactory::buildOpenAiChatModel,
+ LangChain4jModelFactory::buildAzureOpenAiChatModel);
}
/**
@@ -46,7 +52,9 @@ public static ChatModel buildChatModel(final ProviderConfig config) {
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static StreamingChatModel buildStreamingChatModel(final ProviderConfig config) {
- return build(config, "chat", LangChain4jModelFactory::buildOpenAiStreamingChatModel);
+ return build(config, "chat",
+ LangChain4jModelFactory::buildOpenAiStreamingChatModel,
+ LangChain4jModelFactory::buildAzureOpenAiStreamingChatModel);
}
/**
@@ -57,7 +65,9 @@ public static StreamingChatModel buildStreamingChatModel(final ProviderConfig co
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static EmbeddingModel buildEmbeddingModel(final ProviderConfig config) {
- return build(config, "embeddings", LangChain4jModelFactory::buildOpenAiEmbeddingModel);
+ return build(config, "embeddings",
+ LangChain4jModelFactory::buildOpenAiEmbeddingModel,
+ LangChain4jModelFactory::buildAzureOpenAiEmbeddingModel);
}
/**
@@ -68,23 +78,29 @@ public static EmbeddingModel buildEmbeddingModel(final ProviderConfig config) {
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static ImageModel buildImageModel(final ProviderConfig config) {
- return build(config, "image", LangChain4jModelFactory::buildOpenAiImageModel);
+ return build(config, "image",
+ LangChain4jModelFactory::buildOpenAiImageModel,
+ LangChain4jModelFactory::buildAzureOpenAiImageModel);
}
private static <T> T build(final ProviderConfig config,
                           final String modelType,
-                          final Function<ProviderConfig, T> openAiFn) {
+                          final Function<ProviderConfig, T> openAiFn,
+                          final Function<ProviderConfig, T> azureOpenAiFn) {
if (config == null || config.provider() == null) {
throw new IllegalArgumentException("ProviderConfig or provider name is null for model type: " + modelType);
}
- requireNonBlank(config.model(), "model", modelType);
switch (config.provider().toLowerCase()) {
case "openai":
+ requireNonBlank(config.model(), "model", modelType);
validateOpenAi(config, modelType);
return openAiFn.apply(config);
+ case "azure_openai":
+ validateAzureOpenAi(config, modelType);
+ return azureOpenAiFn.apply(config);
default:
throw new IllegalArgumentException("Unsupported " + modelType + " provider: "
- + config.provider() + ". Supported in Phase 1: openai");
+ + config.provider() + ". Supported: openai, azure_openai");
}
}
@@ -92,6 +108,16 @@ private static void validateOpenAi(final ProviderConfig config, final String mod
requireNonBlank(config.apiKey(), "apiKey", modelType);
}
+ private static void validateAzureOpenAi(final ProviderConfig config, final String modelType) {
+ requireNonBlank(config.apiKey(), "apiKey", modelType);
+ requireNonBlank(config.endpoint(), "endpoint", modelType);
+ if ((config.model() == null || config.model().isBlank())
+ && (config.deploymentName() == null || config.deploymentName().isBlank())) {
+ throw new IllegalArgumentException(
+ "providerConfig." + modelType + ": either 'model' or 'deploymentName' is required for azure_openai");
+ }
+ }
+
private static void requireNonBlank(final String value, final String field, final String modelType) {
if (value == null || value.isBlank()) {
throw new IllegalArgumentException(
@@ -162,4 +188,55 @@ private static ImageModel buildOpenAiImageModel(final ProviderConfig config) {
return builder.build();
}
+ // ── Azure OpenAI builders ─────────────────────────────────────────────────
+
+ private static StreamingChatModel buildAzureOpenAiStreamingChatModel(final ProviderConfig config) {
+ final AzureOpenAiStreamingChatModel.Builder builder = AzureOpenAiStreamingChatModel.builder()
+ .apiKey(config.apiKey())
+ .endpoint(config.endpoint())
+ .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model());
+ if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion());
+ if (config.maxRetries() != null) builder.maxRetries(config.maxRetries());
+ if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout()));
+ if (config.temperature() != null) builder.temperature(config.temperature());
+ if (config.maxTokens() != null) builder.maxTokens(config.maxTokens());
+ return builder.build();
+ }
+
+ private static ChatModel buildAzureOpenAiChatModel(final ProviderConfig config) {
+ final AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
+ .apiKey(config.apiKey())
+ .endpoint(config.endpoint())
+ .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model());
+ if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion());
+ if (config.maxRetries() != null) builder.maxRetries(config.maxRetries());
+ if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout()));
+ if (config.temperature() != null) builder.temperature(config.temperature());
+ if (config.maxTokens() != null) builder.maxTokens(config.maxTokens());
+ return builder.build();
+ }
+
+ private static EmbeddingModel buildAzureOpenAiEmbeddingModel(final ProviderConfig config) {
+ final AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder()
+ .apiKey(config.apiKey())
+ .endpoint(config.endpoint())
+ .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model());
+ if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion());
+ if (config.maxRetries() != null) builder.maxRetries(config.maxRetries());
+ if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout()));
+ return builder.build();
+ }
+
+ private static ImageModel buildAzureOpenAiImageModel(final ProviderConfig config) {
+ final AzureOpenAiImageModel.Builder builder = AzureOpenAiImageModel.builder()
+ .apiKey(config.apiKey())
+ .endpoint(config.endpoint())
+ .deploymentName(config.deploymentName() != null ? config.deploymentName() : config.model());
+ if (config.apiVersion() != null) builder.serviceVersion(config.apiVersion());
+ if (config.maxRetries() != null) builder.maxRetries(config.maxRetries());
+ if (config.timeout() != null) builder.timeout(Duration.ofSeconds(config.timeout()));
+ if (config.size() != null) builder.size(config.size());
+ return builder.build();
+ }
+
}
diff --git a/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java b/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java
index 04fc382a586..b0e1e0a26f1 100644
--- a/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java
+++ b/dotCMS/src/test/java/com/dotcms/ai/client/langchain4j/LangChain4jModelFactoryTest.java
@@ -39,6 +39,32 @@ public void test_buildChatModel_openai_missingApiKey_throws() {
assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config));
}
+ @Test
+ public void test_buildChatModel_azureOpenai_returnsModel() {
+ final ChatModel model = LangChain4jModelFactory.buildChatModel(azureOpenAiConfig("gpt-4o"));
+ assertNotNull(model);
+ }
+
+ @Test
+ public void test_buildChatModel_azureOpenai_missingApiKey_throws() {
+ final ProviderConfig config = ImmutableProviderConfig.builder()
+ .provider("azure_openai")
+ .model("gpt-4o")
+ .endpoint("https://my-company.openai.azure.com/")
+ .build();
+ assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config));
+ }
+
+ @Test
+ public void test_buildChatModel_azureOpenai_missingEndpoint_throws() {
+ final ProviderConfig config = ImmutableProviderConfig.builder()
+ .provider("azure_openai")
+ .model("gpt-4o")
+ .apiKey("test-key")
+ .build();
+ assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config));
+ }
+
@Test
public void test_buildChatModel_unknownProvider_throws() {
final ProviderConfig config = ImmutableProviderConfig.builder()
@@ -60,6 +86,12 @@ public void test_buildEmbeddingModel_openai_returnsModel() {
assertNotNull(model);
}
+ @Test
+ public void test_buildEmbeddingModel_azureOpenai_returnsModel() {
+ final EmbeddingModel model = LangChain4jModelFactory.buildEmbeddingModel(azureOpenAiConfig("text-embedding-ada-002"));
+ assertNotNull(model);
+ }
+
@Test
public void test_buildEmbeddingModel_unknownProvider_throws() {
final ProviderConfig config = ImmutableProviderConfig.builder()
@@ -81,6 +113,12 @@ public void test_buildImageModel_openai_returnsModel() {
assertNotNull(model);
}
+ @Test
+ public void test_buildImageModel_azureOpenai_returnsModel() {
+ final ImageModel model = LangChain4jModelFactory.buildImageModel(azureOpenAiConfig("dall-e-3"));
+ assertNotNull(model);
+ }
+
@Test
public void test_buildImageModel_unknownProvider_throws() {
final ProviderConfig config = ImmutableProviderConfig.builder()
@@ -99,4 +137,15 @@ private static ProviderConfig openAiConfig(final String model) {
.build();
}
+ private static ProviderConfig azureOpenAiConfig(final String model) {
+ return ImmutableProviderConfig.builder()
+ .provider("azure_openai")
+ .model(model)
+ .apiKey("test-key")
+ .endpoint("https://my-company.openai.azure.com/")
+ .deploymentName(model)
+ .apiVersion("2024-02-01")
+ .build();
+ }
+
}