Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 211 additions & 26 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
import base64
import copy
import inspect
import json
import random
import re
from collections.abc import AsyncGenerator
from typing import Any
from io import BytesIO
from pathlib import Path
from typing import Any, Literal
from urllib.parse import unquote, urlparse

import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI
Expand All @@ -14,6 +18,8 @@
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion_usage import CompletionUsage
from PIL import Image as PILImage
from PIL import UnidentifiedImageError

import astrbot.core.message.components as Comp
from astrbot import logger
Expand Down Expand Up @@ -133,6 +139,186 @@ def _context_contains_image(contexts: list[dict]) -> bool:
return True
return False

def _is_invalid_attachment_error(self, error: Exception) -> bool:
    """Return whether ``error`` looks like an "invalid attachment" API failure.

    The structured ``error.body["error"]`` payload is inspected first: a
    ``code`` equal to ``invalid_attachment`` (case-insensitive) is decisive.
    Otherwise the lowercased message, the code and every candidate string
    from ``self._extract_error_text_candidates(error)`` are joined and
    searched for either the ``invalid_attachment`` marker or a
    ``download attachment`` failure that also mentions ``404``.

    Args:
        error: The exception raised by the API call.

    Returns:
        True when the error matches one of the patterns above.
    """
    code: str | None = None
    message: str | None = None

    body = getattr(error, "body", None)
    err_obj = body.get("error") if isinstance(body, dict) else None
    if isinstance(err_obj, dict):
        raw_code = err_obj.get("code")
        raw_message = err_obj.get("message")
        # Non-string values are ignored rather than coerced.
        if isinstance(raw_code, str):
            code = raw_code.lower()
        if isinstance(raw_message, str):
            message = raw_message.lower()

    # Fast path: the structured code alone settles the question.
    if code == "invalid_attachment":
        return True

    fragments: list[str] = [text for text in (message, code) if text]
    fragments.extend(
        str(candidate) for candidate in self._extract_error_text_candidates(error)
    )
    haystack = " ".join(fragment.lower() for fragment in fragments if fragment)

    if "invalid_attachment" in haystack:
        return True
    return "download attachment" in haystack and "404" in haystack

@classmethod
def _encode_image_file_to_data_url(
cls,
image_path: str,
*,
mode: Literal["safe", "strict"],
) -> str | None:
try:
image_bytes = Path(image_path).read_bytes()
except OSError:
if mode == "strict":
raise
return None

try:
with PILImage.open(BytesIO(image_bytes)) as image:
image.verify()
image_format = str(image.format or "").upper()
except (OSError, UnidentifiedImageError):
if mode == "strict":
raise ValueError(f"Invalid image file: {image_path}")
return None

mime_type = {
"JPEG": "image/jpeg",
"PNG": "image/png",
"GIF": "image/gif",
"WEBP": "image/webp",
"BMP": "image/bmp",
}.get(image_format, "image/jpeg")
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:{mime_type};base64,{image_bs64}"

@staticmethod
def _file_uri_to_path(file_uri: str) -> str:
    """Normalize a ``file:`` URI to a filesystem path.

    Inputs whose scheme is not ``file`` are returned unchanged.
    ``file://localhost/...`` (host compared case-insensitively) and
    drive-letter forms are treated as local paths. Any other non-empty host
    is preserved as a UNC-style ``//host/path``. Percent-escapes in the host
    and path are decoded.

    Args:
        file_uri: A ``file:`` URI, or any other string.

    Returns:
        The path string produced by ``pathlib.Path``, or ``file_uri``
        itself when it is not a ``file:`` URI.
    """
    parsed = urlparse(file_uri)
    if parsed.scheme != "file":
        return file_uri

    netloc = unquote(parsed.netloc or "")
    path = unquote(parsed.path or "")

    # "file://C:/dir/x" parses the drive letter as the host part.
    if re.fullmatch(r"[A-Za-z]:", netloc):
        return str(Path(f"{netloc}{path}"))

    # "file:///C:/dir/x" leaves a spurious leading slash before the drive.
    if re.match(r"^/[A-Za-z]:/", path):
        path = path[1:]

    # Host names are case-insensitive, so "LOCALHOST" must also count as
    # the local machine rather than becoming a UNC host.
    if netloc and netloc.lower() != "localhost":
        path = f"//{netloc}{path}"
    return str(Path(path))

async def _image_ref_to_data_url(
self,
image_ref: str,
*,
mode: Literal["safe", "strict"] = "safe",
) -> str | None:
if image_ref.startswith("base64://"):
return image_ref.replace("base64://", "data:image/jpeg;base64,")

if image_ref.startswith("http"):
image_path = await download_image_by_url(image_ref)
elif image_ref.startswith("file://"):
image_path = self._file_uri_to_path(image_ref)
else:
image_path = image_ref

return self._encode_image_file_to_data_url(
image_path,
mode=mode,
)

async def _resolve_image_part(
self,
image_url: str,
*,
image_detail: str | None = None,
) -> dict | None:
if image_url.startswith("data:"):
image_payload = {"url": image_url}
else:
image_data = await self._image_ref_to_data_url(image_url, mode="safe")
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
image_payload = {"url": image_data}

if image_detail:
image_payload["detail"] = image_detail
return {
"type": "image_url",
"image_url": image_payload,
}

def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]:
if not isinstance(part, dict) or part.get("type") != "image_url":
return None, None

image_url_data = part.get("image_url")
if not isinstance(image_url_data, dict):
logger.warning("图片内容块格式无效,将保留原始内容。")
return None, None

url = image_url_data.get("url")
if not isinstance(url, str) or not url:
logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
return None, None

image_detail = image_url_data.get("detail")
if not isinstance(image_detail, str):
image_detail = None
return url, image_detail

async def _transform_content_part(self, part: dict) -> dict:
    """Replace an image content part with its resolved form when possible.

    Parts without an extractable image URL are returned untouched. If
    resolving the image raises, or yields nothing, the original part is kept
    so the message is never dropped (a warning is logged on exceptions).

    Args:
        part: A single content part of a message.

    Returns:
        The resolved image part, or ``part`` itself as a fallback.
    """
    url, image_detail = self._extract_image_part_info(part)
    if url:
        try:
            resolved = await self._resolve_image_part(url, image_detail=image_detail)
        except Exception as exc:
            # Best effort: preprocessing failures must not break the request.
            logger.warning(
                "图片 %s 预处理失败,将保留原始内容。错误: %s",
                url,
                exc,
            )
        else:
            if resolved:
                return resolved
    return part

async def _materialize_message_image_parts(self, message: dict) -> dict:
    """Return a shallow copy of ``message`` with its content parts transformed.

    When ``content`` is a list, each part is passed through
    ``_transform_content_part`` in order. Any other content type is left
    as-is. The input message is never mutated.

    Args:
        message: A chat message dict.

    Returns:
        A new dict with the same keys, and a new ``content`` list when the
        original content was a list.
    """
    materialized = dict(message)
    parts = message.get("content")
    if isinstance(parts, list):
        converted = []
        for part in parts:
            converted.append(await self._transform_content_part(part))
        materialized["content"] = converted
    return materialized

async def _materialize_context_image_parts(
    self, context_query: list[dict]
) -> list[dict]:
    """Apply ``_materialize_message_image_parts`` to every message in order.

    Args:
        context_query: The list of chat message dicts to process.

    Returns:
        A new list holding the processed messages, in the same order.
    """
    materialized: list[dict] = []
    for message in context_query:
        # Messages are awaited one at a time, preserving input order.
        materialized.append(await self._materialize_message_image_parts(message))
    return materialized

async def _fallback_to_text_only_and_retry(
self,
payloads: dict,
Expand Down Expand Up @@ -594,7 +780,7 @@ async def _prepare_chat_payload(
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts))
if new_record:
context_query.append(new_record)
if system_prompt:
Expand All @@ -612,6 +798,9 @@ async def _prepare_chat_payload(
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())

if self._context_contains_image(context_query):
context_query = await self._materialize_context_image_parts(context_query)

model = model or self.get_model()

payloads = {"messages": context_query, "model": model}
Expand Down Expand Up @@ -712,6 +901,18 @@ async def _handle_api_error(
"image_content_moderated",
image_fallback_used=True,
)
if self._is_invalid_attachment_error(e):
if image_fallback_used or not self._context_contains_image(context_query):
raise e
return await self._fallback_to_text_only_and_retry(
payloads,
context_query,
chosen_key,
available_api_keys,
func_tool,
"invalid_attachment",
image_fallback_used=True,
)

if (
"Function calling is not enabled" in str(e)
Expand Down Expand Up @@ -913,23 +1114,6 @@ async def assemble_context(
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""

async def resolve_image_part(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
return {
"type": "image_url",
"image_url": {"url": image_data},
}

# 构建内容块列表
content_blocks = []

Expand All @@ -949,7 +1133,9 @@ async def resolve_image_part(image_url: str) -> dict | None:
if isinstance(part, TextPart):
content_blocks.append({"type": "text", "text": part.text})
elif isinstance(part, ImageURLPart):
image_part = await resolve_image_part(part.image_url.url)
image_part = await self._resolve_image_part(
part.image_url.url,
)
if image_part:
content_blocks.append(image_part)
else:
Expand All @@ -958,7 +1144,7 @@ async def resolve_image_part(image_url: str) -> dict | None:
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_part = await resolve_image_part(image_url)
image_part = await self._resolve_image_part(image_url)
if image_part:
content_blocks.append(image_part)

Expand All @@ -977,11 +1163,10 @@ async def resolve_image_part(image_url: str) -> dict | None:

async def encode_image_bs64(self, image_url: str) -> str:
    """Convert an image reference into a base64 ``data:`` URL.

    Delegates to ``_image_ref_to_data_url`` in strict mode, so failures are
    raised rather than silently ignored.

    Args:
        image_url: Any reference accepted by ``_image_ref_to_data_url``.

    Returns:
        The ``data:`` URL for the image.

    Raises:
        RuntimeError: If the conversion yields no data.
    """
    # The superseded open()/b64encode implementation is dropped here: left
    # in place it returned first and made this delegation unreachable.
    image_data = await self._image_ref_to_data_url(image_url, mode="strict")
    if image_data is None:
        raise RuntimeError(f"Failed to encode image data: {image_url}")
    return image_data

async def terminate(self):
if self.client:
Expand Down
Loading
Loading