diff --git a/backend/uv.lock b/backend/uv.lock index 28ef59acb..66ef99df8 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -3997,6 +3997,7 @@ test = [ { name = "parameterized", specifier = "==0.9.0" }, { name = "pytest", specifier = "==8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, { name = "pytest-md-report", specifier = ">=0.6.2" }, { name = "pytest-mock", specifier = "==3.14.0" }, diff --git a/platform-service/uv.lock b/platform-service/uv.lock index 380399f9e..66bd331ad 100644 --- a/platform-service/uv.lock +++ b/platform-service/uv.lock @@ -2744,6 +2744,7 @@ test = [ { name = "parameterized", specifier = "==0.9.0" }, { name = "pytest", specifier = "==8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, { name = "pytest-md-report", specifier = ">=0.6.2" }, { name = "pytest-mock", specifier = "==3.14.0" }, diff --git a/prompt-service/uv.lock b/prompt-service/uv.lock index 03d3765d3..dba297a76 100644 --- a/prompt-service/uv.lock +++ b/prompt-service/uv.lock @@ -367,9 +367,9 @@ wheels = [ name = "chardet" version = "5.2.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/f7b6ab21ec75897ed80c17d79b15951a719226b9fababf1e40ea74d69079/chardet-5.2.0.tar.gz", hash = "sha256:1b3b6ff479a8c414bc3fa2c0852995695c4a026dcd6d0633b2dd092ca39c1cf7", size = 2069618 } +sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/f7b6ab21ec75897ed80c17d79b15951a719226b9fababf1e40ea74d69079/chardet-5.2.0.tar.gz", hash = "sha256:1b3b6ff479a8c414bc3fa2c0852995695c4a026dcd6d0633b2dd092ca39c1cf7", size = 2069618, upload-time = "2023-08-01T19:23:02.662Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/38/6f/f5fbc992a329ee4e0f288c1fe0e2ad9485ed064cac731ed2fe47dcc38cbf/chardet-5.2.0-py3-none-any.whl", hash = "sha256:e1cf59446890a00105fe7b7912492ea04b6e6f06d4b742b2c788469e34c82970", size = 199385 }, + { url = "https://files.pythonhosted.org/packages/38/6f/f5fbc992a329ee4e0f288c1fe0e2ad9485ed064cac731ed2fe47dcc38cbf/chardet-5.2.0-py3-none-any.whl", hash = "sha256:e1cf59446890a00105fe7b7912492ea04b6e6f06d4b742b2c788469e34c82970", size = 199385, upload-time = "2023-08-01T19:23:00.661Z" }, ] [[package]] @@ -475,9 +475,9 @@ dependencies = [ { name = "mbstrdecoder" }, { name = "typepy", extra = ["datetime"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0b/81/8c8b64ae873cb9014815214c07b63b12e3b18835780fb342223cfe3fe7d8/dataproperty-1.1.0.tar.gz", hash = "sha256:b038437a4097d1a1c497695c3586ea34bea67fdd35372b9a50f30bf044d77d04", size = 42574 } +sdist = { url = "https://files.pythonhosted.org/packages/0b/81/8c8b64ae873cb9014815214c07b63b12e3b18835780fb342223cfe3fe7d8/dataproperty-1.1.0.tar.gz", hash = "sha256:b038437a4097d1a1c497695c3586ea34bea67fdd35372b9a50f30bf044d77d04", size = 42574, upload-time = "2024-12-31T14:37:26.033Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/21/c2/e12e95e289e6081a40454199ab213139ef16a528c7c86432de545b05a23a/DataProperty-1.1.0-py3-none-any.whl", hash = "sha256:c61fcb2e2deca35e6d1eb1f251a7f22f0dcde63e80e61f0cc18c19f42abfd25b", size = 27581 }, + { url = "https://files.pythonhosted.org/packages/21/c2/e12e95e289e6081a40454199ab213139ef16a528c7c86432de545b05a23a/DataProperty-1.1.0-py3-none-any.whl", hash = "sha256:c61fcb2e2deca35e6d1eb1f251a7f22f0dcde63e80e61f0cc18c19f42abfd25b", size = 27581, upload-time = "2024-12-31T14:37:22.657Z" }, ] [[package]] @@ -1480,9 +1480,9 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "chardet" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/31/ab/05ae008357c8bdb6245ebf8a101d99f26c096e0ea20800b318153da23796/mbstrdecoder-1.1.4.tar.gz", hash = "sha256:8105ef9cf6b7d7d69fe7fd6b68a2d8f281ca9b365d7a9b670be376b2e6c81b21", size = 14527 } +sdist = { url = "https://files.pythonhosted.org/packages/31/ab/05ae008357c8bdb6245ebf8a101d99f26c096e0ea20800b318153da23796/mbstrdecoder-1.1.4.tar.gz", hash = "sha256:8105ef9cf6b7d7d69fe7fd6b68a2d8f281ca9b365d7a9b670be376b2e6c81b21", size = 14527, upload-time = "2025-01-18T10:07:31.089Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/30/ac/5ce64a1d4cce00390beab88622a290420401f1cabf05caf2fc0995157c21/mbstrdecoder-1.1.4-py3-none-any.whl", hash = "sha256:03dae4ec50ec0d2ff4743e63fdbd5e0022815857494d35224b60775d3d934a8c", size = 7933 }, + { url = "https://files.pythonhosted.org/packages/30/ac/5ce64a1d4cce00390beab88622a290420401f1cabf05caf2fc0995157c21/mbstrdecoder-1.1.4-py3-none-any.whl", hash = "sha256:03dae4ec50ec0d2ff4743e63fdbd5e0022815857494d35224b60775d3d934a8c", size = 7933, upload-time = "2025-01-18T10:07:29.562Z" }, ] [[package]] @@ -1812,9 +1812,9 @@ wheels = [ name = "pathvalidate" version = "3.3.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fa/2a/52a8da6fe965dea6192eb716b357558e103aea0a1e9a8352ad575a8406ca/pathvalidate-3.3.1.tar.gz", hash = "sha256:b18c07212bfead624345bb8e1d6141cdcf15a39736994ea0b94035ad2b1ba177", size = 63262 } +sdist = { url = "https://files.pythonhosted.org/packages/fa/2a/52a8da6fe965dea6192eb716b357558e103aea0a1e9a8352ad575a8406ca/pathvalidate-3.3.1.tar.gz", hash = "sha256:b18c07212bfead624345bb8e1d6141cdcf15a39736994ea0b94035ad2b1ba177", size = 63262, upload-time = "2025-06-15T09:07:20.736Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/70/875f4a23bfc4731703a5835487d0d2fb999031bd415e7d17c0ae615c18b7/pathvalidate-3.3.1-py3-none-any.whl", hash = "sha256:5263baab691f8e1af96092fa5137ee17df5bdfbd6cff1fcac4d6ef4bc2e1735f", size = 24305 }, + { url = "https://files.pythonhosted.org/packages/9a/70/875f4a23bfc4731703a5835487d0d2fb999031bd415e7d17c0ae615c18b7/pathvalidate-3.3.1-py3-none-any.whl", hash = "sha256:5263baab691f8e1af96092fa5137ee17df5bdfbd6cff1fcac4d6ef4bc2e1735f", size = 24305, upload-time = "2025-06-15T09:07:19.117Z" }, ] [[package]] @@ -2164,9 +2164,9 @@ dependencies = [ { name = "tcolorpy" }, { name = "typepy", extra = ["datetime"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f6/a1/617730f290f04d347103ab40bf67d317df6691b14746f6e1ea039fb57062/pytablewriter-1.2.1.tar.gz", hash = "sha256:7bd0f4f397e070e3b8a34edcf1b9257ccbb18305493d8350a5dbc9957fced959", size = 619241 } +sdist = { url = "https://files.pythonhosted.org/packages/f6/a1/617730f290f04d347103ab40bf67d317df6691b14746f6e1ea039fb57062/pytablewriter-1.2.1.tar.gz", hash = "sha256:7bd0f4f397e070e3b8a34edcf1b9257ccbb18305493d8350a5dbc9957fced959", size = 619241, upload-time = "2025-01-01T15:37:00.04Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/21/4c/c199512f01c845dfe5a7840ab3aae6c60463b5dc2a775be72502dfd9170a/pytablewriter-1.2.1-py3-none-any.whl", hash = "sha256:e906ff7ff5151d70a5f66e0f7b75642a7f2dce8d893c265b79cc9cf6bc04ddb4", size = 91083 }, + { url = "https://files.pythonhosted.org/packages/21/4c/c199512f01c845dfe5a7840ab3aae6c60463b5dc2a775be72502dfd9170a/pytablewriter-1.2.1-py3-none-any.whl", hash = "sha256:e906ff7ff5151d70a5f66e0f7b75642a7f2dce8d893c265b79cc9cf6bc04ddb4", size = 91083, upload-time = "2025-01-01T15:36:55.63Z" }, ] [[package]] @@ -2191,9 +2191,9 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/de/b4/0b378b7bf26a8ae161c3890c0b48a91a04106c5713ce81b4b080ea2f4f18/pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3", size = 46920 } +sdist = { url = "https://files.pythonhosted.org/packages/de/b4/0b378b7bf26a8ae161c3890c0b48a91a04106c5713ce81b4b080ea2f4f18/pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3", size = 46920, upload-time = "2024-07-17T17:39:34.617Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/82/62e2d63639ecb0fbe8a7ee59ef0bc69a4669ec50f6d3459f74ad4e4189a2/pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2", size = 17663 }, + { url = "https://files.pythonhosted.org/packages/ee/82/62e2d63639ecb0fbe8a7ee59ef0bc69a4669ec50f6d3459f74ad4e4189a2/pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2", size = 17663, upload-time = "2024-07-17T17:39:32.478Z" }, ] [[package]] @@ -2219,9 +2219,9 @@ dependencies = [ { name = "tcolorpy" }, { name = "typepy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f6/63/80d92406f952eee7856114c18ad269192ce179d576343fbcf0679c9a2bdd/pytest_md_report-0.7.0.tar.gz", hash = "sha256:3b832eaf660b470b5742e58d9c9a5e312b0712a7012d251cc04e908a81ce3c96", size = 284275 } +sdist = { url = "https://files.pythonhosted.org/packages/f6/63/80d92406f952eee7856114c18ad269192ce179d576343fbcf0679c9a2bdd/pytest_md_report-0.7.0.tar.gz", hash = "sha256:3b832eaf660b470b5742e58d9c9a5e312b0712a7012d251cc04e908a81ce3c96", size = 284275, upload-time = "2025-05-02T03:07:09.835Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/de/96/6125a1a963b3864b4a3981c9fce27df60370a889c2e79071e33d325f14d9/pytest_md_report-0.7.0-py3-none-any.whl", hash = "sha256:90ccb3b5b9587d064ec83974db34619addd11d3de6ec7056fab4feb3af69ae94", size = 14250 }, + { url = "https://files.pythonhosted.org/packages/de/96/6125a1a963b3864b4a3981c9fce27df60370a889c2e79071e33d325f14d9/pytest_md_report-0.7.0-py3-none-any.whl", hash = "sha256:90ccb3b5b9587d064ec83974db34619addd11d3de6ec7056fab4feb3af69ae94", size = 14250, upload-time = "2025-05-02T03:07:06.539Z" }, ] [[package]] @@ -2545,18 +2545,18 @@ dependencies = [ { name = "dataproperty" }, { name = "typepy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b2/35/171c8977162f1163368406deddde4c59673b62bd0cb2f34948a02effb075/tabledata-1.3.4.tar.gz", hash = "sha256:e9649cab129d718f3bff4150083b77f8a78c30f6634a30caf692b10fdc60cb97", size = 25074 } +sdist = { url = "https://files.pythonhosted.org/packages/b2/35/171c8977162f1163368406deddde4c59673b62bd0cb2f34948a02effb075/tabledata-1.3.4.tar.gz", hash = "sha256:e9649cab129d718f3bff4150083b77f8a78c30f6634a30caf692b10fdc60cb97", size = 25074, upload-time = "2024-12-31T14:12:31.198Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/64/fa4160151976ee4b2cf0c1217a99443ffaeb991956feddfeac9eee9952f8/tabledata-1.3.4-py3-none-any.whl", hash = "sha256:1f56e433bfdeb89f4487abfa48c4603a3b07c5d3a3c7e05ff73dd018c24bd0d4", size = 11820 }, + { url = "https://files.pythonhosted.org/packages/08/64/fa4160151976ee4b2cf0c1217a99443ffaeb991956feddfeac9eee9952f8/tabledata-1.3.4-py3-none-any.whl", hash = "sha256:1f56e433bfdeb89f4487abfa48c4603a3b07c5d3a3c7e05ff73dd018c24bd0d4", size = 11820, upload-time = "2024-12-31T14:12:28.584Z" }, ] [[package]] name = "tcolorpy" version = "0.1.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/80/cc/44f2d81d8f9093aad81c3467a5bf5718d2b5f786e887b6e4adcfc17ec6b9/tcolorpy-0.1.7.tar.gz", hash = "sha256:0fbf6bf238890bbc2e32662aa25736769a29bf6d880328f310c910a327632614", size = 299437 } +sdist = { url = "https://files.pythonhosted.org/packages/80/cc/44f2d81d8f9093aad81c3467a5bf5718d2b5f786e887b6e4adcfc17ec6b9/tcolorpy-0.1.7.tar.gz", hash = "sha256:0fbf6bf238890bbc2e32662aa25736769a29bf6d880328f310c910a327632614", size = 299437, upload-time = "2024-12-29T15:24:23.847Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/05/a2/ed023f2edd1e011b4d99b6727bce8253842d66c3fbf9ed0a26fc09a92571/tcolorpy-0.1.7-py3-none-any.whl", hash = "sha256:26a59d52027e175a37e0aba72efc99dda43f074db71f55b316d3de37d3251378", size = 8096 }, + { url = "https://files.pythonhosted.org/packages/05/a2/ed023f2edd1e011b4d99b6727bce8253842d66c3fbf9ed0a26fc09a92571/tcolorpy-0.1.7-py3-none-any.whl", hash = "sha256:26a59d52027e175a37e0aba72efc99dda43f074db71f55b316d3de37d3251378", size = 8096, upload-time = "2024-12-29T15:24:21.33Z" }, ] [[package]] @@ -2628,9 +2628,9 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mbstrdecoder" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/79/59/4c39942077d7de285f762a91024dbda731be693591732977358f77d120fb/typepy-1.3.4.tar.gz", hash = "sha256:89c1f66de6c6133209c43a94d23431d320ba03ef5db18f241091ea594035d9de", size = 39558 } +sdist = { url = "https://files.pythonhosted.org/packages/79/59/4c39942077d7de285f762a91024dbda731be693591732977358f77d120fb/typepy-1.3.4.tar.gz", hash = "sha256:89c1f66de6c6133209c43a94d23431d320ba03ef5db18f241091ea594035d9de", size = 39558, upload-time = "2024-12-29T09:18:15.774Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/31/e393c3830bdedd01735bd195c85ac3034b6bcaf6c18142bab60a4047ca36/typepy-1.3.4-py3-none-any.whl", hash = "sha256:d5ed3e0c7f49521bff0603dd08cf8d453371cf68d65a29d3d0038552ccc46e2e", size = 31449 }, + { url = "https://files.pythonhosted.org/packages/ee/31/e393c3830bdedd01735bd195c85ac3034b6bcaf6c18142bab60a4047ca36/typepy-1.3.4-py3-none-any.whl", hash = "sha256:d5ed3e0c7f49521bff0603dd08cf8d453371cf68d65a29d3d0038552ccc46e2e", size = 31449, upload-time = "2024-12-29T09:18:13.135Z" }, ] [package.optional-dependencies] @@ -2895,6 +2895,7 @@ test = [ { name = "parameterized", specifier = "==0.9.0" }, { name = "pytest", specifier = "==8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, { name = "pytest-md-report", specifier = ">=0.6.2" }, { name = "pytest-mock", specifier = "==3.14.0" }, diff --git a/unstract/filesystem/uv.lock b/unstract/filesystem/uv.lock index 969dffa4f..4f53174af 100644 --- a/unstract/filesystem/uv.lock +++ b/unstract/filesystem/uv.lock @@ -1880,6 +1880,7 @@ test = [ { name = "parameterized", specifier = "==0.9.0" }, { name = "pytest", specifier = "==8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, { name = "pytest-md-report", specifier = ">=0.6.2" }, { name = "pytest-mock", specifier = "==3.14.0" }, diff --git a/unstract/sdk1/pyproject.toml b/unstract/sdk1/pyproject.toml index 72ee06f6c..1fb8e10dc 100644 --- a/unstract/sdk1/pyproject.toml +++ b/unstract/sdk1/pyproject.toml @@ -73,6 +73,7 @@ test = [ "pytest==8.3.3", "pytest-asyncio>=0.23.0", "pytest-mock==3.14.0", + "pytest-asyncio>=0.24.0", "pytest-cov>=6.0.0", "pytest-md-report>=0.6.2", ] diff --git a/unstract/sdk1/src/unstract/sdk1/embedding.py b/unstract/sdk1/src/unstract/sdk1/embedding.py index 3c718d12f..e54a09339 100644 --- a/unstract/sdk1/src/unstract/sdk1/embedding.py +++ b/unstract/sdk1/src/unstract/sdk1/embedding.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import litellm +import unstract.sdk1.patches.litellm_cohere_timeout # noqa: F401 from llama_index.core.embeddings import BaseEmbedding from pydantic import ValidationError from unstract.sdk1.adapters.constants import Common diff --git a/unstract/sdk1/src/unstract/sdk1/patches/__init__.py b/unstract/sdk1/src/unstract/sdk1/patches/__init__.py new file mode 100644 index 000000000..cdf7246b7 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/patches/__init__.py @@ -0,0 +1,7 @@ +"""Monkey-patches for third-party library bugs. + +Patches in this package are applied via side-effect imports. +Currently activated from unstract.sdk1.embedding — any code path +that reaches Bedrock Cohere embeddings without going through that +module will NOT have patches active. +""" diff --git a/unstract/sdk1/src/unstract/sdk1/patches/litellm_cohere_timeout.py b/unstract/sdk1/src/unstract/sdk1/patches/litellm_cohere_timeout.py new file mode 100644 index 000000000..4090881c3 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/patches/litellm_cohere_timeout.py @@ -0,0 +1,218 @@ +"""Monkey-patch for litellm's cohere embed handler timeout bug. + +Bug: litellm.llms.cohere.embed.handler.embedding() and async_embedding() +receive a `timeout` parameter but don't forward it to client.post(), +causing "Connection timed out after None seconds" errors. + +Affected litellm version: 1.82.3 (also present on latest main as of +2026-03-10). + +Activation: This patch is imported as a side-effect from +unstract.sdk1.embedding. Any code path that invokes Bedrock Cohere +embeddings without going through unstract.sdk1.embedding will NOT +have this patch active. + +#TODO Remove this patch when litellm ships a fix upstream. +Issue link: https://github.com/BerriAI/litellm/issues/14635 +""" + +import importlib.metadata +import logging + +from packaging.version import Version + +logger = logging.getLogger(__name__) + +# --- Version guard --- +# Only apply the patch on the exact litellm version it was written for. +# Any other version (newer or older) skips the patch with a visible +# warning so engineers know to verify compatibility. +_PATCHED_LITELLM_VERSION = "1.82.3" +_litellm_version = importlib.metadata.version("litellm") +_SKIP_PATCH = Version(_litellm_version) != Version(_PATCHED_LITELLM_VERSION) +if _SKIP_PATCH: + logger.warning( + "litellm_cohere_timeout patch was SKIPPED — not applied. " + "Current litellm version: %s. " + "Patch was written for: %s. " + "Please verify the upstream fix and remove this module.", + _litellm_version, + _PATCHED_LITELLM_VERSION, + ) +else: + # Private litellm imports are deferred to here so they are only + # loaded when the patch will actually be applied. + import json + from collections.abc import Callable + + import httpx + import litellm + import litellm.llms.bedrock.embed.embedding as _bedrock_embed + import litellm.llms.cohere.embed.handler as _cohere_handler + from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObj, + ) + from litellm.llms.cohere.embed.handler import ( + validate_environment, + ) + from litellm.llms.cohere.embed.v1_transformation import ( + CohereEmbeddingConfig, + ) + from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + get_async_httpx_client, + ) + from litellm.types.llms.bedrock import CohereEmbeddingRequest + from litellm.types.utils import EmbeddingResponse + + _DEFAULT_TIMEOUT = httpx.Timeout(None) + + # Copied from litellm 1.82.3 cohere/embed/handler.py async_embedding(). + # ONLY CHANGE: Added timeout=timeout to the client.post() call. + # Source: litellm/llms/cohere/embed/handler.py::async_embedding + async def _patched_async_embedding( # type: ignore[return] # noqa: ANN202 + model: str, + data: dict | CohereEmbeddingRequest, + input: list, + model_response: litellm.utils.EmbeddingResponse, + timeout: float | httpx.Timeout | None, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + api_base: str, + api_key: str | None, + headers: dict, + encoding: Callable, + client: AsyncHTTPHandler | None = None, + ): + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": api_base, + }, + ) + + if client is None: + client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.COHERE, + params={"timeout": timeout}, + ) + + try: + response = await client.post( + api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, # ONLY CHANGE: forward timeout to client + ) + except httpx.HTTPStatusError as e: + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=e.response.text, + ) + raise e + except Exception as e: + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise e + + return CohereEmbeddingConfig()._transform_response( + response=response, + api_key=api_key, + logging_obj=logging_obj, + data=data, + model_response=model_response, + model=model, + encoding=encoding, + input=input, + ) + + # Copied from litellm 1.82.3 cohere/embed/handler.py embedding(). + # ONLY CHANGE: Added timeout=timeout to the client.post() call. + # Source: litellm/llms/cohere/embed/handler.py::embedding + def _patched_embedding( # type: ignore[return] # noqa: ANN202 + model: str, + input: list, + model_response: EmbeddingResponse, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + headers: dict, + encoding: object, + data: dict | CohereEmbeddingRequest | None = None, + complete_api_base: str | None = None, + api_key: str | None = None, + aembedding: bool | None = None, + timeout: float | httpx.Timeout | None = _DEFAULT_TIMEOUT, + client: HTTPHandler | AsyncHTTPHandler | None = None, + ): + headers = validate_environment(api_key, headers=headers) + embed_url = complete_api_base or "https://api.cohere.ai/v1/embed" + + data = data or CohereEmbeddingConfig()._transform_request( + model=model, input=input, inference_params=optional_params + ) + + if aembedding is True: + return _patched_async_embedding( + model=model, + data=data, + input=input, + model_response=model_response, + timeout=timeout, + logging_obj=logging_obj, + optional_params=optional_params, + api_base=embed_url, + api_key=api_key, + headers=headers, + encoding=encoding, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), + ) + + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) + + if client is None or not isinstance(client, HTTPHandler): + client = HTTPHandler(concurrent_limit=1) + + response = client.post( + embed_url, + headers=headers, + data=json.dumps(data), + timeout=timeout, # ONLY CHANGE: forward timeout to client + ) + + return CohereEmbeddingConfig()._transform_response( + response=response, + api_key=api_key, + logging_obj=logging_obj, + data=data, + model_response=model_response, + model=model, + encoding=encoding, + input=input, + ) + + # Apply the monkey-patch to both the source module and any existing + # direct bindings (e.g. bedrock's `from ... import embedding as + # cohere_embedding`), since direct imports capture a reference at + # import time and won't see module-level replacements. + _cohere_handler.async_embedding = _patched_async_embedding + _cohere_handler.embedding = _patched_embedding + _bedrock_embed.cohere_embedding = _patched_embedding + logger.info("Applied litellm cohere embed timeout patch") diff --git a/unstract/sdk1/tests/patches/__init__.py b/unstract/sdk1/tests/patches/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unstract/sdk1/tests/patches/test_litellm_cohere_timeout.py b/unstract/sdk1/tests/patches/test_litellm_cohere_timeout.py new file mode 100644 index 000000000..dc9f7c751 --- /dev/null +++ b/unstract/sdk1/tests/patches/test_litellm_cohere_timeout.py @@ -0,0 +1,180 @@ +"""Tests for the litellm cohere embed timeout monkey-patch.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from unstract.sdk1.patches.litellm_cohere_timeout import ( + _patched_async_embedding, + _patched_embedding, +) + + +@pytest.fixture +def mock_logging_obj() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def mock_http_handler() -> MagicMock: + from litellm.llms.custom_httpx.http_handler import HTTPHandler + + mock = MagicMock(spec=HTTPHandler) + mock_response = MagicMock() + mock_response.json.return_value = { + "embeddings": [[0.1, 0.2]], + "id": "test", + "response_type": "embedding_floats", + "texts": ["hello"], + } + mock.post.return_value = mock_response + return mock + + +class TestPatchedEmbeddingSyncTimeoutForwarding: + """Verify sync embedding forwards timeout to client.post().""" + + def test_timeout_passed_to_client_post( + self, + mock_logging_obj: MagicMock, + mock_http_handler: MagicMock, + ) -> None: + timeout_value = 600 + + with ( + patch( + "unstract.sdk1.patches.litellm_cohere_timeout.validate_environment", + side_effect=lambda api_key, headers: headers, + ), + patch( + "unstract.sdk1.patches.litellm_cohere_timeout.CohereEmbeddingConfig" + ) as mock_config, + ): + mock_config.return_value._transform_response.return_value = MagicMock() + mock_config.return_value._transform_request.return_value = { + "texts": ["hello"], + "input_type": "search_document", + } + + _patched_embedding( + model="cohere.embed-multilingual-v3", + input=["hello"], + model_response=MagicMock(), + logging_obj=mock_logging_obj, + optional_params={}, + headers={}, + encoding=MagicMock(), + timeout=timeout_value, + client=mock_http_handler, + ) + + mock_http_handler.post.assert_called_once() + call_kwargs = mock_http_handler.post.call_args + assert call_kwargs.kwargs.get("timeout") is timeout_value + + def test_none_timeout_passed_to_client_post( + self, + mock_logging_obj: MagicMock, + mock_http_handler: MagicMock, + ) -> None: + with ( + patch( + "unstract.sdk1.patches.litellm_cohere_timeout.validate_environment", + side_effect=lambda api_key, headers: headers, + ), + patch( + "unstract.sdk1.patches.litellm_cohere_timeout.CohereEmbeddingConfig" + ) as mock_config, + ): + mock_config.return_value._transform_response.return_value = MagicMock() + mock_config.return_value._transform_request.return_value = { + "texts": ["hello"], + "input_type": "search_document", + } + + _patched_embedding( + model="cohere.embed-multilingual-v3", + input=["hello"], + model_response=MagicMock(), + logging_obj=mock_logging_obj, + optional_params={}, + headers={}, + encoding=MagicMock(), + timeout=None, + client=mock_http_handler, + ) + + call_kwargs = mock_http_handler.post.call_args + assert ( + "timeout" in call_kwargs.kwargs + ), "timeout kwarg must always be passed to client.post()" + assert ( + call_kwargs.kwargs["timeout"] is None + ), f"Expected timeout=None, got timeout={call_kwargs.kwargs['timeout']}" + + def test_httpx_timeout_object_forwarded( + self, + mock_logging_obj: MagicMock, + mock_http_handler: MagicMock, + ) -> None: + timeout_obj = httpx.Timeout(30.0) + + with ( + patch( + "unstract.sdk1.patches.litellm_cohere_timeout.validate_environment", + side_effect=lambda api_key, headers: headers, + ), + patch( + "unstract.sdk1.patches.litellm_cohere_timeout.CohereEmbeddingConfig" + ) as mock_config, + ): + mock_config.return_value._transform_response.return_value = MagicMock() + mock_config.return_value._transform_request.return_value = { + "texts": ["hello"], + "input_type": "search_document", + } + + _patched_embedding( + model="cohere.embed-multilingual-v3", + input=["hello"], + model_response=MagicMock(), + logging_obj=mock_logging_obj, + optional_params={}, + headers={}, + encoding=MagicMock(), + timeout=timeout_obj, + client=mock_http_handler, + ) + + call_kwargs = mock_http_handler.post.call_args + assert call_kwargs.kwargs.get("timeout") is timeout_obj + + +class TestMonkeyPatchApplied: + """Verify the monkey-patch is correctly wired.""" + + def test_cohere_handler_patched(self) -> None: + import litellm.llms.cohere.embed.handler as handler + + assert handler.embedding is _patched_embedding + assert handler.async_embedding is _patched_async_embedding + + def test_bedrock_handler_patched(self) -> None: + import litellm.llms.bedrock.embed.embedding as bedrock + + assert bedrock.cohere_embedding is _patched_embedding + + def test_patch_module_loaded_via_embedding_import(self) -> None: + """Verify unstract.sdk1.embedding causes the patch module to load. + + The binding assertions (handler.embedding is _patched_embedding) + are covered by the other tests in this class. This test only + verifies that the side-effect import line in embedding.py + exists and results in the patch module being present in + sys.modules. + """ + import sys + + import unstract.sdk1.embedding # noqa: F401 + + assert "unstract.sdk1.patches.litellm_cohere_timeout" in sys.modules diff --git a/unstract/sdk1/uv.lock b/unstract/sdk1/uv.lock index 8936864de..1d9baae01 100644 --- a/unstract/sdk1/uv.lock +++ b/unstract/sdk1/uv.lock @@ -2810,6 +2810,7 @@ test = [ { name = "parameterized", specifier = "==0.9.0" }, { name = "pytest", specifier = "==8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, { name = "pytest-md-report", specifier = ">=0.6.2" }, { name = "pytest-mock", specifier = "==3.14.0" }, diff --git a/workers/uv.lock b/workers/uv.lock index 6847b1a01..6f223ae55 100644 --- a/workers/uv.lock +++ b/workers/uv.lock @@ -4759,6 +4759,7 @@ test = [ { name = "parameterized", specifier = "==0.9.0" }, { name = "pytest", specifier = "==8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, { name = "pytest-md-report", specifier = ">=0.6.2" }, { name = "pytest-mock", specifier = "==3.14.0" },