Skip to content
111 changes: 99 additions & 12 deletions unstract/sdk1/src/unstract/sdk1/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
from collections.abc import Callable, Generator, Mapping
from typing import cast
from typing import NoReturn, cast

import litellm

Expand Down Expand Up @@ -229,6 +229,7 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]:
)

response_text = response["choices"][0]["message"]["content"]
finish_reason = response["choices"][0].get("finish_reason")

self._record_usage(
self._cost_model or self.kwargs["model"],
Expand All @@ -237,6 +238,10 @@ def complete(self, prompt: str, **kwargs: object) -> dict[str, object]:
"complete",
)

# Handle refusal or empty content from the LLM provider
if response_text is None:
self._raise_for_empty_response(finish_reason)

# NOTE:
# The typecasting was required to stop the type checker from complaining.
# Improvements in readability are definitely welcome.
Expand Down Expand Up @@ -306,6 +311,7 @@ def stream_complete(
completion_kwargs = self.adapter.validate({**self.kwargs, **kwargs})
completion_kwargs.pop("cost_model", None)

has_yielded_content = False
for chunk in litellm.completion(
messages=messages,
stream=True,
Expand All @@ -322,17 +328,12 @@ def stream_complete(
"stream_complete",
)

text = chunk["choices"][0]["delta"].get("content", "")

if text:
if callback_manager and hasattr(callback_manager, "on_stream"):
callback_manager.on_stream(text)

# Yield LLMResponseCompat for backward compatibility
# with code expecting .delta
stream_response = LLMResponseCompat(text)
stream_response.delta = text
yield stream_response
response = self._process_stream_chunk(
chunk, callback_manager, has_yielded_content
)
if response is not None:
has_yielded_content = True
yield response

except LLMError:
# Already wrapped LLMError, re-raise as is
Expand Down Expand Up @@ -379,6 +380,7 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]:
**completion_kwargs,
)
response_text = response["choices"][0]["message"]["content"]
finish_reason = response["choices"][0].get("finish_reason")

self._record_usage(
self._cost_model or self.kwargs["model"],
Expand All @@ -387,6 +389,10 @@ async def acomplete(self, prompt: str, **kwargs: object) -> dict[str, object]:
"acomplete",
)

# Handle refusal or empty content from the LLM provider
if response_text is None:
self._raise_for_empty_response(finish_reason)

response_object = LLMResponseCompat(response_text)
response_object.raw = (
response # Attach raw litellm response for metadata access
Expand Down Expand Up @@ -490,6 +496,87 @@ def _record_usage(
kwargs={"provider": self.adapter.get_provider(), **self.platform_kwargs},
)

# Finish reasons indicating a safety/policy refusal across providers:
# - "refusal": Anthropic
# - "content_filter": OpenAI / Azure OpenAI
# NOTE: any other finish_reason (e.g. a provider-specific value) is NOT
# treated as a refusal and falls through to the generic error path in
# _raise_for_empty_response.
REFUSAL_FINISH_REASONS = {"refusal", "content_filter"}

def _raise_for_empty_response(self, finish_reason: str | None) -> NoReturn:
    """Raise an appropriate error when the LLM response content is None.

    This typically happens when the LLM provider refuses to generate a
    response (e.g. Anthropic's safety filters, OpenAI's content filter)
    or returns an empty response.

    Args:
        finish_reason: The finish_reason from the LLM response.

    Raises:
        LLMError: With a descriptive message based on the finish_reason.
            Refusals map to status 400 (the prompt triggered a provider
            filter); any other empty response maps to status 500.
    """
    if finish_reason in self.REFUSAL_FINISH_REASONS:
        raise LLMError(
            message=(
                "The LLM refused to generate a response due to safety "
                f"restrictions (finish_reason: {finish_reason!r}). "
                "Please review your prompt and try again."
            ),
            status_code=400,
        )
    # Fix: interpolate with !r (matching the refusal message above) so
    # None and string values render unambiguously, and drop the f-prefix
    # from segments that contain no placeholders.
    raise LLMError(
        message=(
            "The LLM returned an empty response "
            f"(finish_reason: {finish_reason!r}). This may indicate "
            "the model could not generate content for the given prompt."
        ),
        status_code=500,
    )

def _process_stream_chunk(
    self,
    chunk: dict[str, object],
    callback_manager: object | None,
    has_yielded_content: bool = False,
) -> LLMResponseCompat | None:
    """Turn one litellm streaming chunk into a response object, if any.

    Args:
        chunk: A streaming chunk from litellm.
        callback_manager: Optional callback manager for stream events.
        has_yielded_content: Whether any content has already been yielded.

    Returns:
        LLMResponseCompat wrapping the delta text, or None when the chunk
        carries no content (empty choices, refusal, or empty delta).

    Raises:
        LLMError: When the chunk signals a refusal before any content was
            yielded. A refusal arriving after content has streamed is only
            logged as a warning, to avoid a confusing late error.
    """
    choices = chunk.get("choices")
    if not choices:
        return None

    first = choices[0]
    if first.get("finish_reason") in self.REFUSAL_FINISH_REASONS:
        if not has_yielded_content:
            self._raise_for_empty_response(first.get("finish_reason"))
        logger.warning(
            "[sdk1][LLM] Provider sent refusal after content was "
            "already streamed. Partial content may have been returned."
        )
        return None

    delta_text = first.get("delta", {}).get("content", "")
    if not delta_text:
        return None

    if callback_manager and hasattr(callback_manager, "on_stream"):
        callback_manager.on_stream(delta_text)

    compat = LLMResponseCompat(delta_text)
    compat.delta = delta_text
    return compat

def _post_process_response(
self,
response_text: str,
Expand Down
4 changes: 3 additions & 1 deletion unstract/sdk1/src/unstract/sdk1/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,12 @@ def __init__(self, text: str) -> None:

def __str__(self) -> str:
    """Return text for string operations like join().

    Falls back to "" when text is None (e.g. a provider refusal), so
    callers doing "".join(...) or str(resp) never see a TypeError.
    """
    return self.text or ""

def __repr__(self) -> str:
    """Return detailed representation with text preview."""
    if self.text is None:
        return "LLMResponseCompat(text=None)"
    # Truncate long text to the first 50 characters for readability.
    if len(self.text) > 50:
        preview = self.text[:50] + "..."
    else:
        preview = self.text
    return f"LLMResponseCompat(text={preview!r})"

Expand Down
Loading