diff --git a/src/litserve/callbacks/defaults/metric_callback.py b/src/litserve/callbacks/defaults/metric_callback.py index 8782b6156..6c90d61db 100644 --- a/src/litserve/callbacks/defaults/metric_callback.py +++ b/src/litserve/callbacks/defaults/metric_callback.py @@ -1,3 +1,4 @@ +import logging import time import typing @@ -6,6 +7,8 @@ if typing.TYPE_CHECKING: from litserve import LitAPI +logger = logging.getLogger(__name__) + class PredictionTimeLogger(Callback): def on_before_predict(self, lit_api: "LitAPI"): @@ -13,9 +16,9 @@ def on_before_predict(self, lit_api: "LitAPI"): def on_after_predict(self, lit_api: "LitAPI"): elapsed = time.perf_counter() - self._start_time - print(f"Prediction took {elapsed:.2f} seconds", flush=True) + logger.info(f"Prediction took {elapsed:.2f} seconds") class RequestTracker(Callback): def on_request(self, active_requests: int, **kwargs): - print(f"Active requests: {active_requests}", flush=True) + logger.info(f"Active requests: {active_requests}") diff --git a/src/litserve/server.py b/src/litserve/server.py index ccefb5c23..1aa098267 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -1510,7 +1510,8 @@ def run( ) if not self._disable_openapi_url: - print(f"Swagger UI is available at http://0.0.0.0:{port}/docs") + host = "127.0.0.1" if sys.platform == "win32" else "0.0.0.0" + logger.info(f"Swagger UI is available at http://{host}:{port}/docs") if self._monitor_workers: self._start_worker_monitoring(manager, uvicorn_workers) @@ -1655,17 +1656,17 @@ def monitor(): resp.response_queue.append((None, LitAPIStatus.ERROR)) resp.event.set() - print(f"[monoriting] Marked {uid} set") + logger.info(f"[monoriting] Marked {uid} set") - print(f"[monoriting] Worker {worker_id} is dead. Restarting it") + logger.info(f"[monoriting] Worker {worker_id} is dead. Restarting it") lit_api = self.litapi_connector.lit_apis[lit_api_id] self.inference_workers[idx] = self.launch_single_inference_worker(lit_api, worker_id) - print(f"[monoriting] Worker {worker_id} has been started.") + logger.info(f"[monoriting] Worker {worker_id} has been started.") time.sleep(self.monitor_internal) - except Exception as e: - print(e) + except Exception: + logger.exception("Monitoring worker crashed") t = threading.Thread(target=monitor, daemon=True, name="litserve-monitoring") t.start() diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py index 74652f24b..f26088ffd 100644 --- a/src/litserve/specs/openai.py +++ b/src/litserve/specs/openai.py @@ -423,7 +423,7 @@ def pre_setup(self, lit_api: "LitAPI"): def setup(self, server: "LitServer"): super().setup(server) - print("OpenAI spec setup complete") + logger.info("OpenAI spec setup complete") def as_async(self) -> "_AsyncOpenAISpecWrapper": return _AsyncOpenAISpecWrapper(self) diff --git a/src/litserve/specs/openai_embedding.py b/src/litserve/specs/openai_embedding.py index 7ab40d1dc..a6640ab4d 100644 --- a/src/litserve/specs/openai_embedding.py +++ b/src/litserve/specs/openai_embedding.py @@ -168,7 +168,7 @@ def pre_setup(self, lit_api: "LitAPI"): def setup(self, server: "LitServer"): super().setup(server) - print("OpenAI Embedding Spec is ready.") + logger.info("OpenAI Embedding Spec is ready.") def decode_request(self, request: EmbeddingRequest, context_kwargs: Optional[dict] = None) -> list[str]: return request.input diff --git a/tests/unit/test_callbacks.py b/tests/unit/test_callbacks.py index 838b22643..32b77ab1f 100644 --- a/tests/unit/test_callbacks.py +++ b/tests/unit/test_callbacks.py @@ -1,4 +1,6 @@ import asyncio +import io +import logging import re import time @@ -54,36 +56,65 @@ def test_callback(capfd): assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}" -def test_metric_logger(capfd): - cb = PredictionTimeLogger() - cb_runner = CallbackRunner() - cb_runner._add_callbacks(cb) - assert cb_runner._callbacks == [cb], "Callback not added to runner" - cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None) - cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None) - - captured = capfd.readouterr() +def test_metric_logger(): + # litserve logger has propagate=False so we add a temporary handler directly + log_capture = io.StringIO() + handler = logging.StreamHandler(log_capture) + metric_logger = logging.getLogger("litserve.callbacks.defaults.metric_callback") + metric_logger.addHandler(handler) + try: + cb = PredictionTimeLogger() + cb_runner = CallbackRunner() + cb_runner._add_callbacks(cb) + assert cb_runner._callbacks == [cb], "Callback not added to runner" + cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None) + cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None) + finally: + metric_logger.removeHandler(handler) + + output = log_capture.getvalue() pattern = r"Prediction took \d+\.\d{2} seconds" - assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}" + assert re.search(pattern, output), f"Expected pattern not found in log: {output}" + + +@pytest.fixture +def metric_log_capture(): + """Add a StringIO handler to the litserve metric logger and yield the stream. + + litserve's logger has propagate=False so capfd/caplog won't intercept it. For callbacks that fire in the main + process (e.g. RequestTracker), this is the only reliable way to capture log output on all platforms. + + """ + stream = io.StringIO() + handler = logging.StreamHandler(stream) + metric_logger = logging.getLogger("litserve.callbacks.defaults.metric_callback") + metric_logger.addHandler(handler) + yield stream + metric_logger.removeHandler(handler) @pytest.mark.asyncio -async def test_request_tracker(capfd): +async def test_request_tracker(metric_log_capture): lit_api = SlowAPI() server = ls.LitServer(lit_api, track_requests=False, callbacks=[RequestTracker()]) await run_simple_request(server, 1) - captured = capfd.readouterr() - assert "Active requests: None" in captured.out, f"Expected pattern not found in output: {captured.out}" + assert "Active requests: None" in metric_log_capture.getvalue(), ( + f"Expected pattern not found in log: {metric_log_capture.getvalue()}" + ) + + metric_log_capture.truncate(0) + metric_log_capture.seek(0) server = ls.LitServer(lit_api, track_requests=True, callbacks=[RequestTracker()]) await run_simple_request(server, 4) - captured = capfd.readouterr() - assert "Active requests: 4" in captured.out, f"Expected pattern not found in output: {captured.out}" + assert "Active requests: 4" in metric_log_capture.getvalue(), ( + f"Expected pattern not found in log: {metric_log_capture.getvalue()}" + ) @pytest.mark.asyncio -async def test_request_tracker_with_spec(capfd): +async def test_request_tracker_with_spec(metric_log_capture): from litserve.specs.openai_embedding import OpenAIEmbeddingSpec from litserve.test_examples.openai_embedding_spec_example import TestEmbedAPI @@ -98,12 +129,13 @@ async def test_request_tracker_with_spec(capfd): resp = await ac.post("/v1/embeddings", json={"input": "test", "model": "test"}) assert resp.status_code == 200 - captured = capfd.readouterr() - assert "Active requests: 1" in captured.out, f"Expected pattern not found in output: {captured.out}" + assert "Active requests: 1" in metric_log_capture.getvalue(), ( + f"Expected pattern not found in log: {metric_log_capture.getvalue()}" + ) @pytest.mark.asyncio -async def test_request_tracker_with_openai_spec(capfd): +async def test_request_tracker_with_openai_spec(metric_log_capture): from litserve.specs.openai import OpenAISpec from litserve.test_examples.openai_spec_example import TestAPI @@ -120,5 +152,6 @@ async def test_request_tracker_with_openai_spec(capfd): ) assert resp.status_code == 200 - captured = capfd.readouterr() - assert "Active requests: 1" in captured.out, f"Expected pattern not found in output: {captured.out}" + assert "Active requests: 1" in metric_log_capture.getvalue(), ( + f"Expected pattern not found in log: {metric_log_capture.getvalue()}" + ) diff --git a/tests/unit/test_lit_server.py b/tests/unit/test_lit_server.py index d319ae192..5075765f2 100644 --- a/tests/unit/test_lit_server.py +++ b/tests/unit/test_lit_server.py @@ -341,10 +341,10 @@ def test_server_terminate(): @pytest.mark.parametrize(("disable_openapi_url", "should_print"), [(False, True), (True, False)]) -@patch("builtins.print") +@patch("litserve.server.logger") @patch("litserve.server.uvicorn") -def test_disable_openapi_url_print_message(mock_uvicorn, mock_print, mock_manager, disable_openapi_url, should_print): - """Test that the Swagger UI message is only printed when disable_openapi_url=False.""" +def test_disable_openapi_url_print_message(mock_uvicorn, mock_logger, mock_manager, disable_openapi_url, should_print): + """Test that the Swagger UI message is only logged when disable_openapi_url=False.""" server = LitServer(SimpleLitAPI(), disable_openapi_url=disable_openapi_url, restart_workers=False) server.verify_worker_status = MagicMock() @@ -355,10 +355,11 @@ def test_disable_openapi_url_print_message(mock_uvicorn, mock_print, mock_manage server._monitor_workers = False server.run(port=8000) + swagger_calls = [c for c in mock_logger.info.call_args_list if c.args and "Swagger UI" in c.args[0]] if should_print: - mock_print.assert_called_with("Swagger UI is available at http://0.0.0.0:8000/docs") + assert len(swagger_calls) == 1 else: - mock_print.assert_not_called() + assert len(swagger_calls) == 0 class IdentityAPI(ls.test_examples.SimpleLitAPI):