Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import de.tudarmstadt.ukp.inception.documents.api.DocumentService;
import de.tudarmstadt.ukp.inception.documents.api.RepositoryProperties;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ToolLibraryExtensionPoint;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.client.LlmChatClientExtensionPoint;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaClient;
import de.tudarmstadt.ukp.inception.scheduling.SchedulingService;

Expand Down Expand Up @@ -87,9 +88,9 @@ public EncodingRegistry encodingRegistry()

@Bean
public EmbeddingService EmbeddingService(AssistantProperties aProperties,
OllamaClient aOllamaClient)
LlmChatClientExtensionPoint aChatClientExtensionPoint)
{
return new EmbeddingServiceImpl(aProperties, aOllamaClient);
return new EmbeddingServiceImpl(aProperties, aChatClientExtensionPoint);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

Expand All @@ -33,22 +35,27 @@
import org.springframework.context.event.EventListener;

import de.tudarmstadt.ukp.inception.assistant.config.AssistantProperties;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaEmbedRequest;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaOptions;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.client.LlmChatClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.client.LlmChatClientExtensionPoint;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.client.LlmEndpoint;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.ollama.client.OllamaLlmChatClient;

public class EmbeddingServiceImpl
implements EmbeddingService
{
private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private static final String OPT_NUM_CTX = "num_ctx";
private static final String OPT_SEED = "seed";

private final AssistantProperties properties;
private final OllamaClient ollamaClient;
private final LlmChatClientExtensionPoint chatClientExtensionPoint;

public EmbeddingServiceImpl(AssistantProperties aProperties, OllamaClient aOllamaClient)
public EmbeddingServiceImpl(AssistantProperties aProperties,
LlmChatClientExtensionPoint aChatClientExtensionPoint)
{
properties = aProperties;
ollamaClient = aOllamaClient;
chatClientExtensionPoint = aChatClientExtensionPoint;
}

@EventListener
Expand Down Expand Up @@ -96,23 +103,11 @@ public <T> List<Pair<T, float[]>> embed(Function<T, String> aExtractor, Iterable
objects.add(o);
}

var request = OllamaEmbedRequest.builder() //
.withModel(properties.getEmbedding().getModel()) //
.withInput(strings.toArray(String[]::new)) //
.withOption(OllamaOptions.NUM_CTX, properties.getEmbedding().getContextLength()) //
.withOption(OllamaOptions.SEED, properties.getEmbedding().getSeed()) //
// The following options should not be relevant for embeddings
// .withOption(OllamaOptions.TEMPERATURE, 0.0) //
// .withOption(OllamaOptions.TOP_P, 0.0) //
// .withOption(OllamaOptions.TOP_K, 0) //
// .withOption(OllamaOptions.REPEAT_PENALTY, 1.0) //
.build();

var response = ollamaClient.embed(properties.getUrl(), request);
var vectors = client().embed(endpoint(), strings, embeddingOptions());

var result = new ArrayList<Pair<T, float[]>>();
for (var i = 0; i < response.size(); i++) {
result.add(Pair.of(objects.get(i), response.get(i).getValue()));
for (var i = 0; i < vectors.size(); i++) {
result.add(Pair.of(objects.get(i), vectors.get(i)));
}
return result;
}
Expand All @@ -124,7 +119,6 @@ public List<Pair<String, float[]>> embed(String... aStrings) throws IOException

var strings = new ArrayList<String>();
for (var s : aStrings) {

s = removeEmptyLinesAndTrim(s);

if (s.isEmpty() || hasHighProportionOfShortSequences(s)
Expand All @@ -135,18 +129,37 @@ public List<Pair<String, float[]>> embed(String... aStrings) throws IOException
strings.add(s);
}

var request = OllamaEmbedRequest.builder() //
.withModel(properties.getEmbedding().getModel()) //
.withInput(strings.toArray(String[]::new)) //
.withOption(OllamaOptions.NUM_CTX, properties.getEmbedding().getContextLength()) //
.withOption(OllamaOptions.SEED, properties.getEmbedding().getSeed()) //
// The following options should not be relevant for embeddings
// .withOption(OllamaOptions.TEMPERATURE, 0.0) //
// .withOption(OllamaOptions.TOP_P, 0.0) //
// .withOption(OllamaOptions.TOP_K, 0) //
// .withOption(OllamaOptions.REPEAT_PENALTY, 1.0) //
.build();
return ollamaClient.embed(properties.getUrl(), request);
var vectors = client().embed(endpoint(), strings, embeddingOptions());

var result = new ArrayList<Pair<String, float[]>>();
for (var i = 0; i < vectors.size(); i++) {
result.add(Pair.of(strings.get(i), vectors.get(i)));
}
return result;
}

private LlmChatClient client()
{
// Provider is hardcoded to Ollama for now; once assistant config moves to UI-driven
// traits, this becomes traits.getProviderId().
return chatClientExtensionPoint.getExtension(OllamaLlmChatClient.ID) //
.orElseThrow(() -> new IllegalStateException(
"Ollama LLM client not registered — is the inception-imls-ollama module on "
+ "the classpath?"));
}

private LlmEndpoint endpoint()
{
return new LlmEndpoint(OllamaLlmChatClient.ID, properties.getUrl(),
properties.getEmbedding().getModel(), null);
}

private Map<String, Object> embeddingOptions()
{
var options = new LinkedHashMap<String, Object>();
options.put(OPT_NUM_CTX, properties.getEmbedding().getContextLength());
options.put(OPT_SEED, properties.getEmbedding().getSeed());
return options;
}

private void autoDetectEmbeddingDimension()
Expand All @@ -157,15 +170,14 @@ private void autoDetectEmbeddingDimension()
try {
LOG.info("Contacting [{}] to auto-detect dimension of model [{}]...",
properties.getUrl(), embeddingProperties.getModel());
var embedding = ollamaClient.embed(properties.getUrl(), OllamaEmbedRequest
.builder() //
.withModel(embeddingProperties.getModel()) //
.withInput(
"We just need to know the dimension of the generated embedding. Thanks!") //
.build()).get(0).getValue();
embeddingProperties.setDimension(embedding.length);
var vectors = client().embed(endpoint(),
List.of("We just need to know the dimension of the generated "
+ "embedding. Thanks!"),
null);
var dim = vectors.get(0).length;
embeddingProperties.setDimension(dim);
LOG.info("Auto-detected embedding dimension of model [{}]: {}",
embeddingProperties.getModel(), embeddingProperties.getDimension());
embeddingProperties.getModel(), dim);
}
catch (Exception e) {
if (LOG.isDebugEnabled()) {
Expand Down Expand Up @@ -229,5 +241,4 @@ static boolean hasHighProportionOfWhitespaceOrLineBreaks(String aString)
double proportion = (double) whitespaceOrLineBreakCount / totalChars;
return proportion > 0.5;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ public boolean isEvaluable()
return false;
}

@Override
public boolean isDeprecated()
{
// Hide from the recommender-tool dropdown: the assistant recommender is not
// user-configurable
// - its instance is created/managed by the assistant subsystem.
return true;
}

@Override
public RecommendationEngine build(Recommender aRecommender)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Technische Universität Darmstadt under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The Technische Universität Darmstadt
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.assistant.tools;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.model.SourceDocument;
import de.tudarmstadt.ukp.clarin.webanno.security.model.User;
import de.tudarmstadt.ukp.inception.assistant.CommandDispatcher;

/**
* Snapshot of the assistant's per-chat-turn runtime state, captured into each
* {@link AssistantToolInvoker} when the tool registry is built. Any field may be {@code null} when
* the surrounding chat session does not have that context (e.g. no document open).
*/
public record AssistantRuntimeContext( //
User sessionOwner, //
Project project, //
SourceDocument document, //
String dataOwner, //
CommandDispatcher commandDispatcher)
{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Licensed to the Technische Universität Darmstadt under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The Technische Universität Darmstadt
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.assistant.tools;

import static de.tudarmstadt.ukp.inception.recommendation.imls.llm.ToolUtils.getParameterName;
import static de.tudarmstadt.ukp.inception.recommendation.imls.llm.ToolUtils.isParameter;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.model.SourceDocument;
import de.tudarmstadt.ukp.inception.assistant.CommandDispatcher;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.AnnotationEditorContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.client.ToolInvoker;
import de.tudarmstadt.ukp.inception.recommendation.imls.llm.client.ToolDescriptor;
import de.tudarmstadt.ukp.inception.support.json.JSONUtil;
import tools.jackson.databind.JsonNode;

/**
* {@link ToolInvoker} backed by a {@code @Tool}-annotated Java {@link Method}, with binding for the
* assistant's runtime context. Captures the per-turn {@link AssistantRuntimeContext} at
* construction so {@link #invoke} is a pure (arguments) → result call.
* <p>
* Parameter binding mirrors what {@code MToolCall.invoke} did pre-abstraction:
* <ul>
* <li>{@code @ToolParam} parameters → Jackson-converted from the LLM-supplied JSON arguments.
* <li>{@link AnnotationEditorContext} → built per-call from the captured runtime context.
* <li>{@link Project} / {@link SourceDocument} / {@link CommandDispatcher} → captured runtime
* context.
* <li>Anything else → {@link IllegalStateException} at invocation time.
* </ul>
* Exceptions thrown by the target method are unwrapped from {@link InvocationTargetException} so
* callers see the original cause.
*/
public class AssistantToolInvoker
implements ToolInvoker
{
private final Object instance;
private final Method method;
private final ToolDescriptor descriptor;
private final AssistantRuntimeContext context;

public AssistantToolInvoker(Object aInstance, Method aMethod, AssistantRuntimeContext aContext)
{
instance = aInstance;
method = aMethod;
descriptor = ToolDescriptor.fromMethod(aMethod);
context = aContext;
}

@Override
public ToolDescriptor descriptor()
{
return descriptor;
}

@Override
public Object invoke(JsonNode aArguments) throws Exception
{
var mapper = JSONUtil.getObjectMapper();
var typeFactory = mapper.getTypeFactory();
var args = new ArrayList<>();

for (var param : method.getParameters()) {
if (isParameter(param)) {
var paramName = getParameterName(param);
var raw = aArguments != null ? aArguments.get(paramName) : null;
var type = typeFactory.constructType(param.getParameterizedType());
args.add(mapper.convertValue(raw, type));
continue;
}

args.add(resolveContextParameter(param.getType(), param.getName()));
}

try {
return method.invoke(instance, args.toArray());
}
catch (InvocationTargetException e) {
if (e.getCause() instanceof Exception cause) {
throw cause;
}
throw e;
}
}

private Object resolveContextParameter(Class<?> aType, String aParamName)
{
// Strict direction: the parameter type IS-A injectable type. Using the loose direction
// (param.getType().isAssignableFrom(KnownType.class)) — as the pre-abstraction
// MToolCall.invoke did — silently matched parameters declared as Object or other
// supertypes against whichever known type appeared first in the chain.
if (AnnotationEditorContext.class.isAssignableFrom(aType)) {
return AnnotationEditorContext.builder() //
.withSessionOwner(context.sessionOwner()) //
.withProject(context.project()) //
.withDocument(context.document()) //
.withDataOwner(context.dataOwner()) //
.build();
}
if (CommandDispatcher.class.isAssignableFrom(aType)) {
return context.commandDispatcher();
}
if (Project.class.isAssignableFrom(aType)) {
return context.project();
}
if (SourceDocument.class.isAssignableFrom(aType)) {
return context.document();
}
throw new IllegalStateException("Tool [" + descriptor.name() + "] declares parameter ["
+ aParamName + "] of unsupported type [" + aType.getName()
+ "]. Supported context types are: AnnotationEditorContext, "
+ "CommandDispatcher, Project, SourceDocument.");
}

@Override
public String toString()
{
return "AssistantToolInvoker[" + descriptor.name() + " -> " + method.toGenericString()
+ "]";
}
}
Loading
Loading