Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/reference/task_plugin.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ called directly from within a task by passing `context.task_manager`:

```python
from fluid.scheduler import TaskRun, task
from fluid.scheduler.db import get_db_plugin, HistoryQuery
from fluid.scheduler.db import get_db_plugin, TaskHistoryQuery


@task()
async def report(context: TaskRun) -> None:
db_plugin = get_db_plugin(context.task_manager)
page = await db_plugin.get_history(HistoryQuery(task="my-task", limit=10))
page = await db_plugin.get_history(TaskHistoryQuery(task="my-task", limit=10))
for run in page.data:
print(run.id, run.state)
```
Expand All @@ -53,10 +53,10 @@ or the HTTP endpoints added by [with_task_history_router][fluid.scheduler.db.wit
They can be imported from `fluid.scheduler.db`:

```python
from fluid.scheduler.db import HistoryQuery, TaskRunHistory, TaskRunHistoryPage
from fluid.scheduler.db import TaskHistoryQuery, TaskRunHistory, TaskRunHistoryPage
```

::: fluid.scheduler.db.HistoryQuery
::: fluid.scheduler.db.TaskHistoryQuery

::: fluid.scheduler.db.TaskRunHistory

Expand Down
6 changes: 6 additions & 0 deletions fluid/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dateutil.parser import parse as parse_date
from sqlalchemy import Column, Table, func, insert, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection
Expand Down Expand Up @@ -300,6 +301,11 @@ def default_filter_column(
],
) -> Any:
"""Build a SQLAlchemy WHERE clause expression for a single column filter"""
if isinstance(column.type, JSONB) and isinstance(value, dict):
if op == "eq":
return column.contains(value)
return None

if multiple := isinstance(value, (list, tuple)):
value = tuple(column_value_to_python(column, v) for v in value)
else:
Expand Down
51 changes: 44 additions & 7 deletions fluid/scheduler/db.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import json
from datetime import datetime
from typing import Any, ClassVar

import sqlalchemy as sa
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query
from pydantic import BaseModel, Field
from pydantic import BaseModel, BeforeValidator, Field, model_validator
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.exc import NoResultFound
from typing_extensions import Annotated, Doc

Expand All @@ -20,8 +22,18 @@
from .plugin import TaskManagerPlugin


def _parse_json_str(v: Any) -> Any:
"""Parse a JSON string into a dict for JSONB query parameters."""
if isinstance(v, str):
return json.loads(v)
return v


# Annotated alias for ``dict[str, Any]`` whose input is first passed through
# ``_parse_json_str`` (a pydantic ``BeforeValidator``), so a JSON-encoded
# string is decoded before the dict validation itself runs.
JsonDict = Annotated[dict[str, Any], BeforeValidator(_parse_json_str)]


class TaskDbPlugin(TaskManagerPlugin):
"""A plugin to store task runs in a database.
"""A plugin to store [TaskRun][fluid.scheduler.TaskRun] in a postgresql database.

This plugin listens to task state changes and updates the database accordingly.
It requires a CrudDB instance to perform database operations and allows
Expand Down Expand Up @@ -95,7 +107,7 @@ def register(self, task_manager: TaskManager) -> None:
async def get_history(
self,
q: Annotated[
HistoryQuery, Doc("Query parameters for fetching task run history")
TaskHistoryQuery, Doc("Query parameters for fetching task run history")
],
) -> TaskRunHistoryPage:
"""Get task run history based on the provided query parameters."""
Expand Down Expand Up @@ -169,10 +181,15 @@ def task_meta(meta: sa.MetaData, table_name: str = "tasks") -> None:
nullable=False,
index=True,
),
sa.Column("queued", sa.DateTime(timezone=True), nullable=False),
sa.Column("queued", sa.DateTime(timezone=True), nullable=False, index=True),
sa.Column("start", sa.DateTime(timezone=True)),
sa.Column("end", sa.DateTime(timezone=True)),
sa.Column("params", sa.JSON),
sa.Column("params", JSONB),
sa.Index(
f"ix_{table_name}_params",
"params",
postgresql_using="gin",
),
)


Expand Down Expand Up @@ -228,32 +245,52 @@ class TaskRunHistoryPage(BaseModel):
cursor: str = Field(..., description="Pagination cursor to fetch the next page")


class HistoryQuery(BaseModel):
class TaskHistoryQuery(BaseModel):
"""Query parameters for fetching task run history."""

task: Annotated[
str | None,
Query(description="Filter by task name"),
Doc("Filter by task name when provided"),
] = None
start: Annotated[
datetime | None,
Query(description="Filter runs queued at or after this time"),
Doc("Filter runs queued at or after this time when provided"),
] = None
end: Annotated[
datetime | None,
Query(description="Filter runs queued at or before this time"),
Doc("Filter runs queued at or before this time when provided"),
] = None
state: Annotated[
TaskState | None,
Query(description="Filter by task state"),
Doc("Filter by task state when provided"),
] = None
params: Annotated[
dict[str, Any] | str | None,
Query(description="Filter by params using JSON containment"),
Doc("Filter by params using JSON containment when provided"),
] = None

@model_validator(mode="before")
@classmethod
def _parse_params_str(cls, data: Any) -> Any:
if isinstance(data, dict) and "params" in data:
data = {**data}
data["params"] = _parse_json_str(data["params"])
return data

limit: Annotated[
int | None,
Query(description="Maximum number of results to return", ge=1),
Doc("Maximum number of results to return when provided"),
] = None
cursor: Annotated[
str,
Query(description="Pagination cursor from a previous response"),
Doc("Pagination cursor from a previous response when provided"),
] = ""

_filter_map: ClassVar[dict[str, str]] = {
Expand All @@ -278,7 +315,7 @@ def filters(self) -> dict:
)
async def get_history(
db_plugin: TaskDbPluginDep,
q: Annotated[HistoryQuery, Depends()],
q: Annotated[TaskHistoryQuery, Query()],
) -> TaskRunHistoryPage:
return await db_plugin.get_history(q)

Expand Down
49 changes: 48 additions & 1 deletion tests/scheduler/test_db_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
from datetime import datetime, timezone
from typing import Any, AsyncIterator, cast

Expand All @@ -10,7 +11,12 @@
from examples import tasks
from fluid.scheduler import TaskState
from fluid.scheduler.consumer import TaskConsumer
from fluid.scheduler.db import TaskDbPlugin, get_db_plugin, with_task_history_router
from fluid.scheduler.db import (
TaskDbPlugin,
TaskHistoryQuery,
get_db_plugin,
with_task_history_router,
)
from fluid.scheduler.endpoints import get_task_manager, task_manager_fastapi
from fluid.utils.http_client import HttpResponseError
from tests.scheduler.tasks import TaskClient, redis_broker, start_fastapi
Expand Down Expand Up @@ -268,3 +274,44 @@ async def test_get_history_filter_by_end(
assert any(item["id"] == task_run.id for item in data)
data_empty = await get_history(cli_db, end="2000-01-01T00:00:00Z")
assert data_empty == []


async def test_get_history_filter_by_params_programmatic(
    task_manager_db: TaskConsumer, db_plugin: TaskDbPlugin
) -> None:
    """Check the ``params`` history filter through the plugin API.

    A run queued with ``a=7.0`` must be returned when filtering on
    ``{"a": 7.0}`` and must not be returned when filtering on ``{"a": 999.0}``.
    """
    finished = await task_manager_db.queue_and_wait("add", timeout=5, a=7.0, b=8.0)
    assert finished.state == TaskState.success

    # Make sure the run has reached the database before querying it.
    await wait_for_task_run(db_plugin, finished.id)

    matching_query = TaskHistoryQuery(params={"a": 7.0})
    page = await db_plugin.get_history(matching_query)
    assert len(page.data) >= 1
    assert any(row.id == finished.id for row in page.data)
    assert all(row.params.get("a") == 7.0 for row in page.data)

    # Negative: filter that shouldn't match
    other_query = TaskHistoryQuery(params={"a": 999.0})
    page_empty = await db_plugin.get_history(other_query)
    assert not any(row.id == finished.id for row in page_empty.data)


async def test_get_history_filter_by_params_http(
    cli_db: TaskClient, task_manager_db: TaskConsumer, db_plugin: TaskDbPlugin
) -> None:
    """Check the ``params`` history filter when sent as a JSON string over HTTP.

    A run queued with ``a=9.0`` must appear in the response, and every
    returned item must carry ``a == 9.0`` in its params.
    """
    finished = await task_manager_db.queue_and_wait("add", timeout=5, a=9.0, b=10.0)
    assert finished.state == TaskState.success

    # Make sure the run has reached the database before querying it.
    await wait_for_task_run(db_plugin, finished.id)

    encoded_filter = json.dumps({"a": 9.0})
    items = await get_history(cli_db, params=encoded_filter)
    assert any(entry["id"] == finished.id for entry in items)
    assert all(entry["params"].get("a") == 9.0 for entry in items)


async def test_get_history_filter_by_params_http_negative(
    cli_db: TaskClient, task_manager_db: TaskConsumer, db_plugin: TaskDbPlugin
) -> None:
    """Check that a non-matching ``params`` filter over HTTP excludes the run.

    A run queued with ``a=11.0`` must not be returned when the filter asks
    for ``a == 999.0``.
    """
    finished = await task_manager_db.queue_and_wait("add", timeout=5, a=11.0, b=12.0)
    assert finished.state == TaskState.success

    # Make sure the run has reached the database before querying it.
    await wait_for_task_run(db_plugin, finished.id)

    encoded_filter = json.dumps({"a": 999.0})
    data_empty = await get_history(cli_db, params=encoded_filter)
    assert not any(entry["id"] == finished.id for entry in data_empty)
Loading