Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/litserve/callbacks/defaults/metric_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import time
import typing

Expand All @@ -6,16 +7,18 @@
if typing.TYPE_CHECKING:
from litserve import LitAPI

logger = logging.getLogger(__name__)


# Callback that times a prediction: it stores a perf_counter timestamp before
# predict and reports the elapsed seconds (two decimals) after predict.
class PredictionTimeLogger(Callback):
# Store the start timestamp on the instance so on_after_predict can read it.
def on_before_predict(self, lit_api: "LitAPI"):
self._start_time = time.perf_counter()

# Reads self._start_time, which is only assigned in on_before_predict.
def on_after_predict(self, lit_api: "LitAPI"):
elapsed = time.perf_counter() - self._start_time
# NOTE(review): the print and logger.info lines below emit the same message and
# look like the removed/added sides of one diff hunk (markers lost) — confirm
# that only the logger.info line is live.
print(f"Prediction took {elapsed:.2f} seconds", flush=True)
logger.info(f"Prediction took {elapsed:.2f} seconds")


# Callback that reports the number of active requests each time on_request fires.
class RequestTracker(Callback):
# Extra keyword arguments are accepted and ignored.
def on_request(self, active_requests: int, **kwargs):
# NOTE(review): the print and logger.info lines below emit the same message and
# look like the removed/added sides of one diff hunk (markers lost) — confirm
# that only the logger.info line is live.
print(f"Active requests: {active_requests}", flush=True)
logger.info(f"Active requests: {active_requests}")
13 changes: 7 additions & 6 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,8 @@ def run(
)

if not self._disable_openapi_url:
print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
host = "127.0.0.1" if sys.platform == "win32" else "0.0.0.0"
logger.info(f"Swagger UI is available at http://{host}:{port}/docs")

if self._monitor_workers:
self._start_worker_monitoring(manager, uvicorn_workers)
Expand Down Expand Up @@ -1655,17 +1656,17 @@ def monitor():
resp.response_queue.append((None, LitAPIStatus.ERROR))

resp.event.set()
print(f"[monoriting] Marked {uid} set")
logger.info(f"[monoriting] Marked {uid} set")

print(f"[monoriting] Worker {worker_id} is dead. Restarting it")
logger.info(f"[monoriting] Worker {worker_id} is dead. Restarting it")
lit_api = self.litapi_connector.lit_apis[lit_api_id]
self.inference_workers[idx] = self.launch_single_inference_worker(lit_api, worker_id)
print(f"[monoriting] Worker {worker_id} has been started.")
logger.info(f"[monoriting] Worker {worker_id} has been started.")

time.sleep(self.monitor_internal)

except Exception as e:
print(e)
except Exception:
logger.exception("Monitoring worker crashed")

t = threading.Thread(target=monitor, daemon=True, name="litserve-monitoring")
t.start()
2 changes: 1 addition & 1 deletion src/litserve/specs/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def pre_setup(self, lit_api: "LitAPI"):

# Delegate to the parent class's setup for this server, then report completion.
def setup(self, server: "LitServer"):
super().setup(server)
# NOTE(review): the print and logger.info lines below carry the same message and
# look like the removed/added sides of one diff hunk — confirm which is live.
print("OpenAI spec setup complete")
logger.info("OpenAI spec setup complete")

def as_async(self) -> "_AsyncOpenAISpecWrapper":
    """Return this spec wrapped in an ``_AsyncOpenAISpecWrapper``."""
    async_wrapper = _AsyncOpenAISpecWrapper(self)
    return async_wrapper
Expand Down
2 changes: 1 addition & 1 deletion src/litserve/specs/openai_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def pre_setup(self, lit_api: "LitAPI"):

# Delegate to the parent class's setup for this server, then report readiness.
def setup(self, server: "LitServer"):
super().setup(server)
# NOTE(review): the print and logger.info lines below carry the same message and
# look like the removed/added sides of one diff hunk — confirm which is live.
print("OpenAI Embedding Spec is ready.")
logger.info("OpenAI Embedding Spec is ready.")

def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> list[str]:
    """Return the ``input`` attribute of the embedding request unchanged.

    ``context_kwargs`` is accepted but not used by this method.
    """
    request_input = request.input
    return request_input
Expand Down
75 changes: 54 additions & 21 deletions tests/unit/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import io
import logging
import re
import time

Expand Down Expand Up @@ -54,36 +56,65 @@ def test_callback(capfd):
assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}"


# NOTE(review): two versions of this test are interleaved below — a capfd-based one
# and a logging-handler-based one — which looks like the removed/added sides of a
# diff whose markers were lost. Confirm which lines are live before editing.
def test_metric_logger(capfd):
cb = PredictionTimeLogger()
cb_runner = CallbackRunner()
cb_runner._add_callbacks(cb)
assert cb_runner._callbacks == [cb], "Callback not added to runner"
cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None)
cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None)

captured = capfd.readouterr()
# Handler-based version: attach a StringIO-backed StreamHandler to the metric
# callback logger, fire the before/after predict events, then inspect the stream.
def test_metric_logger():
# litserve logger has propagate=False so we add a temporary handler directly
log_capture = io.StringIO()
handler = logging.StreamHandler(log_capture)
metric_logger = logging.getLogger("litserve.callbacks.defaults.metric_callback")
metric_logger.addHandler(handler)
# try/finally so the temporary handler is removed even if an assertion or event fails.
try:
cb = PredictionTimeLogger()
cb_runner = CallbackRunner()
cb_runner._add_callbacks(cb)
assert cb_runner._callbacks == [cb], "Callback not added to runner"
cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None)
cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None)
finally:
metric_logger.removeHandler(handler)

output = log_capture.getvalue()
# The message must contain a duration formatted with exactly two decimals.
pattern = r"Prediction took \d+\.\d{2} seconds"
# NOTE(review): the first assert reads captured.out (capfd version), the second
# reads the handler's stream (logging version) — presumably only one is live.
assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}"
assert re.search(pattern, output), f"Expected pattern not found in log: {output}"


@pytest.fixture
def metric_log_capture():
    """Yield an in-memory stream that receives the metric callback logger's output.

    A ``StreamHandler`` backed by a ``StringIO`` is attached directly to the
    ``litserve.callbacks.defaults.metric_callback`` logger before the test and
    detached afterwards. litserve's logger has propagate=False, so capfd/caplog
    won't intercept its records; for callbacks that fire in the main process
    (e.g. RequestTracker) attaching a handler is the reliable way to capture
    log output on all platforms.

    """
    target_logger = logging.getLogger("litserve.callbacks.defaults.metric_callback")
    captured_stream = io.StringIO()
    stream_handler = logging.StreamHandler(captured_stream)
    target_logger.addHandler(stream_handler)
    yield captured_stream
    # Teardown: detach the temporary handler once the test has finished.
    target_logger.removeHandler(stream_handler)


# Checks the RequestTracker callback output for a server with request tracking
# disabled (expects "None") and then enabled (expects the request count).
# NOTE(review): capfd-based and metric_log_capture-based lines are interleaved
# below, which looks like the removed/added sides of a diff whose markers were
# lost — confirm which signature and assertions are live.
@pytest.mark.asyncio
async def test_request_tracker(capfd):
async def test_request_tracker(metric_log_capture):
lit_api = SlowAPI()

# First run: track_requests=False, one request.
server = ls.LitServer(lit_api, track_requests=False, callbacks=[RequestTracker()])
await run_simple_request(server, 1)
captured = capfd.readouterr()
assert "Active requests: None" in captured.out, f"Expected pattern not found in output: {captured.out}"
assert "Active requests: None" in metric_log_capture.getvalue(), (
f"Expected pattern not found in log: {metric_log_capture.getvalue()}"
)

# Empty the capture stream and rewind it so the next assertion only sees
# output written during the second server run.
metric_log_capture.truncate(0)
metric_log_capture.seek(0)

# Second run: track_requests=True, four requests.
server = ls.LitServer(lit_api, track_requests=True, callbacks=[RequestTracker()])
await run_simple_request(server, 4)
captured = capfd.readouterr()
assert "Active requests: 4" in captured.out, f"Expected pattern not found in output: {captured.out}"
assert "Active requests: 4" in metric_log_capture.getvalue(), (
f"Expected pattern not found in log: {metric_log_capture.getvalue()}"
)


@pytest.mark.asyncio
async def test_request_tracker_with_spec(capfd):
async def test_request_tracker_with_spec(metric_log_capture):
from litserve.specs.openai_embedding import OpenAIEmbeddingSpec
from litserve.test_examples.openai_embedding_spec_example import TestEmbedAPI

Expand All @@ -98,12 +129,13 @@ async def test_request_tracker_with_spec(capfd):
resp = await ac.post("/v1/embeddings", json={"input": "test", "model": "test"})
assert resp.status_code == 200

captured = capfd.readouterr()
assert "Active requests: 1" in captured.out, f"Expected pattern not found in output: {captured.out}"
assert "Active requests: 1" in metric_log_capture.getvalue(), (
f"Expected pattern not found in log: {metric_log_capture.getvalue()}"
)


@pytest.mark.asyncio
async def test_request_tracker_with_openai_spec(capfd):
async def test_request_tracker_with_openai_spec(metric_log_capture):
from litserve.specs.openai import OpenAISpec
from litserve.test_examples.openai_spec_example import TestAPI

Expand All @@ -120,5 +152,6 @@ async def test_request_tracker_with_openai_spec(capfd):
)
assert resp.status_code == 200

captured = capfd.readouterr()
assert "Active requests: 1" in captured.out, f"Expected pattern not found in output: {captured.out}"
assert "Active requests: 1" in metric_log_capture.getvalue(), (
f"Expected pattern not found in log: {metric_log_capture.getvalue()}"
)
11 changes: 6 additions & 5 deletions tests/unit/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,10 @@ def test_server_terminate():


@pytest.mark.parametrize(("disable_openapi_url", "should_print"), [(False, True), (True, False)])
@patch("builtins.print")
@patch("litserve.server.logger")
@patch("litserve.server.uvicorn")
def test_disable_openapi_url_print_message(mock_uvicorn, mock_print, mock_manager, disable_openapi_url, should_print):
"""Test that the Swagger UI message is only printed when disable_openapi_url=False."""
def test_disable_openapi_url_print_message(mock_uvicorn, mock_logger, mock_manager, disable_openapi_url, should_print):
"""Test that the Swagger UI message is only logged when disable_openapi_url=False."""
server = LitServer(SimpleLitAPI(), disable_openapi_url=disable_openapi_url, restart_workers=False)
server.verify_worker_status = MagicMock()

Expand All @@ -355,10 +355,11 @@ def test_disable_openapi_url_print_message(mock_uvicorn, mock_print, mock_manage
server._monitor_workers = False
server.run(port=8000)

swagger_calls = [c for c in mock_logger.info.call_args_list if c.args and "Swagger UI" in c.args[0]]
if should_print:
mock_print.assert_called_with("Swagger UI is available at http://0.0.0.0:8000/docs")
assert len(swagger_calls) == 1
else:
mock_print.assert_not_called()
assert len(swagger_calls) == 0


class IdentityAPI(ls.test_examples.SimpleLitAPI):
Expand Down