From 3120874b8ef2113c0ca2a3c47410be6b048d2228 Mon Sep 17 00:00:00 2001 From: klpoland Date: Fri, 17 Apr 2026 15:09:44 -0400 Subject: [PATCH 01/10] add captures/files to datasets for use/listing on sdk --- .../serializers/capture_serializers.py | 13 +++- .../serializers/dataset_serializers.py | 68 ++++++++++++++++++- .../serializers/file_serializers.py | 14 ++++ .../tests/test_dataset_endpoints.py | 12 ++++ .../api_methods/views/dataset_endpoints.py | 53 +++++++++++++++ sdk/src/spectrumx/api/datasets.py | 61 ++++++++++++++++- sdk/src/spectrumx/client.py | 21 ++++++ sdk/src/spectrumx/gateway.py | 25 +++++++ sdk/src/spectrumx/models/datasets.py | 30 +++++++- 9 files changed, 289 insertions(+), 8 deletions(-) diff --git a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py index f75a48c5..487c834f 100644 --- a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py @@ -17,11 +17,15 @@ from sds_gateway.api_methods.models import ItemType from sds_gateway.api_methods.models import UserSharePermission from sds_gateway.api_methods.serializers.dataset_serializers import DatasetGetSerializer +from sds_gateway.api_methods.serializers.dataset_serializers import DatasetSummarySerializer +from sds_gateway.api_methods.serializers.user_serializer import UserSharePermissionSerializer from sds_gateway.api_methods.serializers.user_serializer import UserGetSerializer from sds_gateway.api_methods.serializers.user_serializer import ( UserSharePermissionSerializer, ) from sds_gateway.api_methods.utils.asset_access_control import check_if_shared +from sds_gateway.api_methods.utils.asset_access_control import get_connected_asset_ids +from sds_gateway.api_methods.utils.relationship_utils import get_capture_datasets from sds_gateway.api_methods.utils.relationship_utils import get_capture_files @@ -80,7 +84,7 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]): files = serializers.SerializerMethodField() total_file_size = serializers.SerializerMethodField() data_files_info = serializers.SerializerMethodField() - datasets = DatasetGetSerializer(many=True) + datasets = serializers.SerializerMethodField() center_frequency_ghz = serializers.SerializerMethodField() sample_rate_mhz = serializers.SerializerMethodField() length_of_capture_ms = serializers.SerializerMethodField() @@ -89,6 +93,13 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]): formatted_created_at = serializers.SerializerMethodField() capture_type_display = serializers.SerializerMethodField() post_processed_data = serializers.SerializerMethodField() + + def get_datasets(self, capture: Capture) -> list[dict[str, Any]]: + """Datasets linked to this capture; shallow when serializing under dataset detail.""" + qs = get_capture_datasets(capture, include_deleted=False) + if self.context.get("omit_nested_dataset_graph"): + return DatasetSummarySerializer(qs, many=True, context=self.context).data + return DatasetGetSerializer(qs, many=True, context=self.context).data def get_share_permissions(self, capture: Capture) -> list[UserSharePermission]: """Get the share permissions for the capture.""" diff --git a/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py b/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py index 231aea28..4b22f4fa 100644 --- a/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py +++ 
b/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py @@ -6,15 +6,30 @@ from sds_gateway.api_methods.models import ItemType from sds_gateway.api_methods.models import PermissionLevel from sds_gateway.api_methods.models import UserSharePermission -from sds_gateway.api_methods.serializers.user_serializer import ( - UserSharePermissionSerializer, -) from sds_gateway.api_methods.utils.asset_access_control import check_if_shared +from sds_gateway.api_methods.helpers.search_captures import group_captures_by_top_level_dir +from sds_gateway.api_methods.serializers.capture_serializers import build_composite_capture_data +from sds_gateway.api_methods.serializers.capture_serializers import serialize_capture_or_composite +from sds_gateway.api_methods.serializers.file_serializers import FileArtifactSummarySerializer +from sds_gateway.api_methods.serializers.user_serializer import UserGetSerializer +from sds_gateway.api_methods.serializers.user_serializer import UserSharePermissionSerializer +from sds_gateway.api_methods.utils.asset_access_control import check_if_shared +from sds_gateway.api_methods.utils.relationship_utils import get_dataset_artifact_files +from sds_gateway.api_methods.utils.relationship_utils import get_dataset_captures READABLE_ISO_DATE_TIME: str = "%Y-%m-%d %H:%M:%S%z" +class DatasetSummarySerializer(serializers.ModelSerializer[Dataset]): + """Minimal dataset shape for capture ``datasets`` when breaking serializer cycles.""" + + class Meta: + model = Dataset + fields = ["uuid", "name", "version", "status", "is_public"] + + class DatasetGetSerializer(serializers.ModelSerializer[Dataset]): + owner = UserGetSerializer(read_only=True) authors = serializers.SerializerMethodField() keywords = serializers.SerializerMethodField() created_at = serializers.DateTimeField( @@ -26,6 +41,8 @@ class DatasetGetSerializer(serializers.ModelSerializer[Dataset]): status_display = serializers.CharField(source="get_status_display", read_only=True) shared_users = serializers.SerializerMethodField() share_permissions = serializers.SerializerMethodField() + captures = serializers.SerializerMethodField() + files = serializers.SerializerMethodField() owner_name = serializers.SerializerMethodField() owner_email = serializers.SerializerMethodField() permission_level = serializers.SerializerMethodField() @@ -130,6 +147,51 @@ def get_share_permissions(self, obj): is_enabled=True, ) return UserSharePermissionSerializer(user_share_permissions, many=True).data + + def get_files(self, obj: Dataset) -> list[dict]: + """Get the files for the dataset. + + Returns: + A list of serialized file objects + """ + non_deleted_files = get_dataset_artifact_files( + obj, + include_deleted=False, + ) + serializer = FileArtifactSummarySerializer( + non_deleted_files, + many=True, + context=self.context, + ) + return serializer.data + + def get_captures(self, obj: Dataset) -> list[dict]: + """Get captures for the dataset, one entry per logical capture (list API semantics). + + Multi-channel uploads share ``top_level_dir``; those rows are merged into a + single composite payload like :func:`get_composite_captures`. 
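+
+        Nested captures are serialized with ``omit_nested_dataset_graph`` set in the
+        context, so each capture's ``datasets`` field collapses to
+        ``DatasetSummarySerializer`` rows instead of re-entering this serializer.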
+ """ + non_deleted_captures = get_dataset_captures( + obj, + include_deleted=False, + ) + grouped = group_captures_by_top_level_dir(list(non_deleted_captures)) + composite_captures: list[dict] = [] + capture_context = { + **(self.context or {}), + "omit_nested_dataset_graph": True, + } + for capture_list in grouped.values(): + if len(capture_list) > 1: + composite_captures.append(build_composite_capture_data(capture_list)) + else: + composite_captures.append( + serialize_capture_or_composite( + capture_list[0], + context=capture_context, + ) + ) + return composite_captures def get_is_shared(self, obj): """Check if the dataset is shared.""" diff --git a/gateway/sds_gateway/api_methods/serializers/file_serializers.py b/gateway/sds_gateway/api_methods/serializers/file_serializers.py index d8e7b935..51457961 100644 --- a/gateway/sds_gateway/api_methods/serializers/file_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/file_serializers.py @@ -18,6 +18,20 @@ CONFLICT = 409 +class FileArtifactSummarySerializer(serializers.ModelSerializer[File]): + """Subset of file fields for nested dataset payloads (avoids recursive graphs).""" + + class Meta: + model = File + fields = ( + "uuid", + "name", + "directory", + "media_type", + "size", + ) + + class FileGetSerializer(serializers.ModelSerializer[File]): owner = UserGetSerializer() datasets = DatasetGetSerializer(many=True) diff --git a/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py b/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py index 500be261..cc83b1cc 100644 --- a/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py +++ b/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py @@ -117,6 +117,18 @@ def _cleanup_dataset_connections(self): if permission.pk: permission.delete() + def test_retrieve_dataset_success(self): + """GET dataset detail returns metadata, captures, and artifact files.""" + url = reverse("api:datasets-detail", kwargs={"pk": self.dataset.uuid}) + response = self.client.get(url) + assert response.status_code == status.HTTP_200_OK + body = response.json() + assert body["uuid"] == str(self.dataset.uuid) + assert "captures" in body + assert "files" in body + assert isinstance(body["captures"], list) + assert isinstance(body["files"], list) + def test_get_dataset_files_success(self): """Test successful dataset files manifest retrieval.""" # Create test files associated with the dataset with MinIO mocking diff --git a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py index 21a55864..312f95d9 100644 --- a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py @@ -17,6 +17,7 @@ from sds_gateway.api_methods.models import File from sds_gateway.api_methods.models import ItemType from sds_gateway.api_methods.models import user_has_access_to_item +from sds_gateway.api_methods.serializers.dataset_serializers import DatasetGetSerializer from sds_gateway.api_methods.serializers.file_serializers import FileGetSerializer from sds_gateway.api_methods.utils.asset_access_control import check_if_shared from sds_gateway.api_methods.utils.asset_access_control import ( @@ -37,6 +38,58 @@ def _get_file_objects(self, dataset: Dataset) -> QuerySet[File]: """Get all files associated with a dataset.""" return get_dataset_files_including_captures(dataset, include_deleted=False) + @extend_schema( + parameters=[ + OpenApiParameter( + name="id", + 
description="Dataset UUID", + required=True, + type=str, + location=OpenApiParameter.PATH, + ), + ], + responses={ + 200: OpenApiResponse(description="Dataset metadata, captures, and artifact files"), + 403: OpenApiResponse(description="Forbidden"), + 404: OpenApiResponse(description="Not Found"), + }, + description=( + "Return dataset metadata with captures (one row per logical capture, " + "including composite multi-channel) and artifact files linked directly to the dataset." + ), + summary="Retrieve Dataset", + ) + def retrieve(self, request: Request, pk: str | None = None) -> Response: + """Return serialized dataset including captures and direct (artifact) files.""" + if pk is None: + return Response( + {"detail": "Dataset UUID is required."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + target_dataset = get_object_or_404( + Dataset, + pk=pk, + is_deleted=False, + ) + + assert isinstance(request.user, User), ( + "Expected request.user to be an instance of the custom User model" + ) + if not user_has_access_to_item( + request.user, target_dataset.uuid, ItemType.DATASET + ): + return Response( + {"detail": "You do not have permission to access this dataset."}, + status=status.HTTP_403_FORBIDDEN, + ) + + serializer = DatasetGetSerializer( + target_dataset, + context={"request": request}, + ) + return Response(serializer.data) + @extend_schema( parameters=[ OpenApiParameter( diff --git a/sdk/src/spectrumx/api/datasets.py b/sdk/src/spectrumx/api/datasets.py index 519cdd48..a25cd331 100644 --- a/sdk/src/spectrumx/api/datasets.py +++ b/sdk/src/spectrumx/api/datasets.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import json +import uuid +from typing import Any from loguru import logger as log +from spectrumx.gateway import GatewayClient +from spectrumx.models.datasets import Dataset from spectrumx.models.files import File from spectrumx.ops.pagination import Paginator from spectrumx.utils import log_user @@ -32,6 +36,61 @@ def __init__( self.gateway = gateway self.verbose = verbose + def get(self, dataset_uuid: uuid.UUID) -> Dataset: + """Load dataset metadata, captures, and artifact files from SDS. + + Captures are returned in the same grouped shape as the capture list API + (one entry per logical multi-channel capture where applicable). For every + file in the dataset (including capture-linked files), use :meth:`get_files` + instead, which calls the paginated dataset files manifest endpoint. + """ + if self.dry_run: + log_user("Dry run enabled: returning an empty Dataset shell") + return Dataset(uuid=dataset_uuid) + + raw = self.gateway.get_dataset( + dataset_uuid=dataset_uuid, + verbose=self.verbose, + ) + return Dataset.model_validate_json(raw) + + def list_captures(self, dataset_uuid: uuid.UUID) -> list[dict[str, Any]]: + """Return capture payloads linked to the dataset (raw JSON objects). + + Use this when you need composite capture fields (for example ``channels``) + without coercing through :class:`~spectrumx.models.datasets.DatasetCapture`. + """ + if self.dry_run: + log_user("Dry run enabled: returning an empty capture list") + return [] + + raw = self.gateway.get_dataset( + dataset_uuid=dataset_uuid, + verbose=self.verbose, + ) + data = json.loads(raw) + captures = data.get("captures") + return list(captures) if isinstance(captures, list) else [] + + def list_artifact_files(self, dataset_uuid: uuid.UUID) -> list[dict[str, Any]]: + """Return file rows linked directly to the dataset (artifacts), as JSON dicts. 
+ + These are the same objects embedded on :meth:`get` under the ``files`` key. + For the full downloadable manifest (captures plus artifacts), use + :meth:`get_files`. + """ + if self.dry_run: + log_user("Dry run enabled: returning an empty artifact file list") + return [] + + raw = self.gateway.get_dataset( + dataset_uuid=dataset_uuid, + verbose=self.verbose, + ) + data = json.loads(raw) + files = data.get("files") + return list(files) if isinstance(files, list) else [] + def get_files( self, dataset_uuid: UUID, diff --git a/sdk/src/spectrumx/client.py b/sdk/src/spectrumx/client.py index dbee54d5..10623283 100644 --- a/sdk/src/spectrumx/client.py +++ b/sdk/src/spectrumx/client.py @@ -17,6 +17,7 @@ from spectrumx.errors import process_upload_results from spectrumx.models.captures import Capture from spectrumx.models.captures import CaptureType +from spectrumx.models.datasets import Dataset from spectrumx.ops.pagination import Paginator from . import __version__ @@ -499,6 +500,26 @@ def download_dataset( verbose=verbose, ) + def get_dataset(self, dataset_uuid: UUID4 | str) -> Dataset: + """Fetch dataset metadata, captures, and artifact files from SDS.""" + if isinstance(dataset_uuid, str): + dataset_uuid = UUID(dataset_uuid) + return self.datasets.get(dataset_uuid) + + def list_dataset_captures(self, dataset_uuid: UUID4 | str) -> list[dict[str, Any]]: + """List captures linked to a dataset (raw dicts; supports composite payloads).""" + if isinstance(dataset_uuid, str): + dataset_uuid = UUID(dataset_uuid) + return self.datasets.list_captures(dataset_uuid) + + def list_dataset_artifact_files( + self, dataset_uuid: UUID4 | str + ) -> list[dict[str, Any]]: + """List files linked directly to the dataset (not the full download manifest).""" + if isinstance(dataset_uuid, str): + dataset_uuid = UUID(dataset_uuid) + return self.datasets.list_artifact_files(dataset_uuid) + def _upload_deprecated( self, *, diff --git a/sdk/src/spectrumx/gateway.py b/sdk/src/spectrumx/gateway.py index c275bd21..ace0d3ea 100644 --- a/sdk/src/spectrumx/gateway.py +++ b/sdk/src/spectrumx/gateway.py @@ -734,6 +734,31 @@ def revoke_dataset_share_permissions( network.success_or_raise(response, ContextException=DatasetError) return response.content + def get_dataset( + self, + *, + dataset_uuid: uuid.UUID, + verbose: bool = False, + ) -> bytes: + """Fetch dataset metadata including captures and direct (artifact) files. + + Args: + dataset_uuid: UUID of the dataset. + verbose: Whether to log the request. + Returns: + JSON body from the gateway. + Raises: + DatasetError: If the request fails. 
+ """ + response = self._request( + method=HTTPMethods.GET, + endpoint=Endpoints.DATASETS, + asset_id=dataset_uuid.hex, + verbose=verbose, + ) + network.success_or_raise(response, ContextException=DatasetError) + return response.content + def get_dataset_files( self, *, diff --git a/sdk/src/spectrumx/models/datasets.py b/sdk/src/spectrumx/models/datasets.py index f02bcc71..a57c8ea8 100644 --- a/sdk/src/spectrumx/models/datasets.py +++ b/sdk/src/spectrumx/models/datasets.py @@ -7,17 +7,37 @@ from pydantic import BaseModel from pydantic import ConfigDict +from spectrumx.models.captures import CaptureType +from spectrumx.models.captures import CaptureOrigin +from spectrumx.models.user import UserSharePermission from spectrumx.models.user import User from spectrumx.models.user import UserSharePermission +class DatasetFile(BaseModel): + model_config = ConfigDict(extra="ignore") + + uuid: UUID4 | None = None + name: str | None = None + directory: str | None = None + media_type: str | None = None + +class DatasetCapture(BaseModel): + model_config = ConfigDict(extra="ignore") + + uuid: UUID4 | None = None + name: str | None = None + capture_type: CaptureType | None = None + index_name: str | None = None + origin: CaptureOrigin | None = None + top_level_dir: str | None = None + owner: User | None = None + class Dataset(BaseModel): """A dataset in SDS.""" model_config = ConfigDict(extra="ignore") - # TODO ownership: include ownership and access level information - uuid: UUID4 | None = None owner: User | None = None name: str | None = None @@ -30,7 +50,7 @@ class Dataset(BaseModel): institutions: list[str] | None = None release_date: datetime | None = None repository: str | None = None - version: str | None = None + version: int | None = None website: str | None = None provenance: dict[str, Any] | None = None citation: dict[str, Any] | None = None @@ -42,8 +62,12 @@ class Dataset(BaseModel): is_shared: bool = False is_shared_with_me: bool = False share_permissions: list[UserSharePermission] | None = None + captures: list[DatasetCapture] | None = None + files: list[DatasetFile] | None = None __all__ = [ "Dataset", + "DatasetCapture", + "DatasetFile", ] From 543ea8079afa57d73ac78254ad9226ecd5d1abca Mon Sep 17 00:00:00 2001 From: klpoland Date: Thu, 23 Apr 2026 09:37:03 -0400 Subject: [PATCH 02/10] add is_shared_with_me to capture --- .../serializers/capture_serializers.py | 22 +++++++++++++++---- sdk/src/spectrumx/models/captures.py | 2 ++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py index 487c834f..f5f8deeb 100644 --- a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py @@ -17,14 +17,14 @@ from sds_gateway.api_methods.models import ItemType from sds_gateway.api_methods.models import UserSharePermission from sds_gateway.api_methods.serializers.dataset_serializers import DatasetGetSerializer -from sds_gateway.api_methods.serializers.dataset_serializers import DatasetSummarySerializer -from sds_gateway.api_methods.serializers.user_serializer import UserSharePermissionSerializer +from sds_gateway.api_methods.serializers.dataset_serializers import ( + DatasetSummarySerializer, +) from sds_gateway.api_methods.serializers.user_serializer import UserGetSerializer from sds_gateway.api_methods.serializers.user_serializer import ( UserSharePermissionSerializer, ) from 
sds_gateway.api_methods.utils.asset_access_control import check_if_shared -from sds_gateway.api_methods.utils.asset_access_control import get_connected_asset_ids from sds_gateway.api_methods.utils.relationship_utils import get_capture_datasets from sds_gateway.api_methods.utils.relationship_utils import get_capture_files @@ -80,6 +80,7 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]): owner = UserGetSerializer() share_permissions = serializers.SerializerMethodField() is_shared = serializers.SerializerMethodField() + is_shared_with_me = serializers.SerializerMethodField() capture_props = serializers.SerializerMethodField() files = serializers.SerializerMethodField() total_file_size = serializers.SerializerMethodField() @@ -93,7 +94,7 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]): formatted_created_at = serializers.SerializerMethodField() capture_type_display = serializers.SerializerMethodField() post_processed_data = serializers.SerializerMethodField() - + def get_datasets(self, capture: Capture) -> list[dict[str, Any]]: """Datasets linked to this capture; shallow when serializing under dataset detail.""" qs = get_capture_datasets(capture, include_deleted=False) @@ -111,6 +112,19 @@ def get_share_permissions(self, capture: Capture) -> list[UserSharePermission]: ) return UserSharePermissionSerializer(user_share_permissions, many=True).data + def get_is_shared_with_me(self, capture: Capture) -> bool: + """Get whether the capture is shared with the current user.""" + request = self.context.get("request") + if request and hasattr(request, "user"): + return UserSharePermission.objects.filter( + shared_with=request.user, + item_type=ItemType.CAPTURE, + item_uuid=capture.uuid, + is_enabled=True, + is_deleted=False, + ).exists() + return False + def get_is_shared(self, capture: Capture) -> bool: """Get whether the capture is shared. diff --git a/sdk/src/spectrumx/models/captures.py b/sdk/src/spectrumx/models/captures.py index 5bf81604..1d7d7cc7 100644 --- a/sdk/src/spectrumx/models/captures.py +++ b/sdk/src/spectrumx/models/captures.py @@ -48,6 +48,7 @@ class CaptureOrigin(StrEnum): _d_channel = "The channel associated with the capture. Only for RadioHound type." _d_scan_group = "The scan group associated with the capture. Only for Digital-RF type." 
_d_is_shared = "Whether the capture is shared" +_d_is_shared_with_me = "Whether the capture is shared with the current user" _d_datasets = "Datasets this capture is associated with" @@ -91,6 +92,7 @@ class Capture(SDSModel): Field(description=_d_share_permissions, default_factory=list), ] is_shared: Annotated[bool, Field(description=_d_is_shared)] + is_shared_with_me: Annotated[bool, Field(description=_d_is_shared_with_me)] # optional fields created_at: Annotated[ From ad6179ed1263ffc91c29adb267278ccf27f1c2d8 Mon Sep 17 00:00:00 2001 From: klpoland Date: Thu, 23 Apr 2026 12:14:34 -0400 Subject: [PATCH 03/10] add time filtering to file listing methods/endpoint, update gateway tests --- .../api_methods/helpers/temporal_filtering.py | 66 +++--- .../serializers/capture_serializers.py | 222 ++++++++++++++++-- .../api_methods/tests/test_celery_tasks.py | 6 +- .../api_methods/tests/test_file_endpoints.py | 73 ++++++ .../utils/swagger_example_schema.py | 1 + .../api_methods/views/file_endpoints.py | 91 ++++++- sdk/src/spectrumx/api/sds_files.py | 32 ++- sdk/src/spectrumx/client.py | 32 ++- sdk/src/spectrumx/gateway.py | 22 +- sdk/src/spectrumx/models/captures.py | 26 ++ sdk/src/spectrumx/ops/pagination.py | 6 +- 11 files changed, 506 insertions(+), 71 deletions(-) diff --git a/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py b/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py index 19b37be0..50753fbc 100644 --- a/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py +++ b/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py @@ -22,36 +22,17 @@ def drf_rf_filename_from_ms(ms: int) -> str: return f"rf@{ms // 1000}.{ms % 1000:03d}.h5" -def _catch_capture_type_error(capture_type: CaptureType) -> None: +def _catch_value_errors(capture_type: CaptureType, capture: Capture) -> None: + if capture_type != CaptureType.DigitalRF: msg = "Only DigitalRF captures are supported for temporal filtering." 
log.error(msg) raise ValueError(msg) - - -def _filter_capture_data_files_selection_bounds( - capture: Capture, - start_time: int, # relative ms from start of capture (from UI) - end_time: int, # relative ms from start of capture (from UI) -) -> QuerySet[File]: - """Filter the capture file selection bounds to the given start and end times.""" + if capture.start_time is None: msg = f"Capture {capture.uuid} has no indexed start_time for temporal filtering" raise ValueError(msg) - epoch_start_ms = capture.start_time * 1000 - start_ms = epoch_start_ms + start_time - end_ms = epoch_start_ms + end_time - - start_file_name = drf_rf_filename_from_ms(start_ms) - end_file_name = drf_rf_filename_from_ms(end_ms) - - data_files = capture.get_drf_data_files_queryset() - return data_files.filter( - name__gte=start_file_name, - name__lte=end_file_name, - ).order_by("name") - def get_capture_files_with_temporal_filter( capture_type: CaptureType, @@ -60,24 +41,51 @@ def get_capture_files_with_temporal_filter( end_time: int | None = None, ) -> QuerySet[File]: """Get the capture files with temporal filtering.""" - _catch_capture_type_error(capture_type) + _catch_value_errors(capture_type, capture) + + capture_files = get_capture_files(capture) if start_time is None or end_time is None: log.warning( "Start or end time is None; returning all capture files without " "temporal filtering" ) - return get_capture_files(capture) + return capture_files + epoch_start_ms = capture.start_time * 1000 + start_ms = epoch_start_ms + start_time + end_ms = epoch_start_ms + end_time + + return filter_files_by_temporal_bounds( + capture_files, + start_ms, + end_ms, + ) + + +def filter_files_by_temporal_bounds( + files: QuerySet[File], + start_time: int, + end_time: int, +) -> QuerySet[File]: + """Filter files by temporal bounds.""" + # get non-data files - non_data_files = get_capture_files(capture).exclude( + non_data_files = files.exclude( name__regex=DRF_RF_FILENAME_REGEX_STR ) - # get data files with temporal filtering - data_files = _filter_capture_data_files_selection_bounds( - capture, start_time, end_time + unfiltered_data_files = files.filter( + name__regex=DRF_RF_FILENAME_REGEX_STR ) + start_file_name = drf_rf_filename_from_ms(start_time) + end_file_name = drf_rf_filename_from_ms(end_time) + + filtered_data_files = unfiltered_data_files.filter( + name__gte=start_file_name, + name__lte=end_file_name, + ).order_by("name") + # return all files - return non_data_files.union(data_files) + return non_data_files.union(filtered_data_files) \ No newline at end of file diff --git a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py index f5f8deeb..47188e23 100644 --- a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py @@ -1,10 +1,13 @@ """Capture serializers for the SDS Gateway API methods.""" import logging +from datetime import datetime +from datetime import timezone from typing import Any from typing import cast from django.db.models import Sum +from django.utils import timezone as django_timezone from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from rest_framework.utils.serializer_helpers import ReturnList @@ -29,6 +32,18 @@ from sds_gateway.api_methods.utils.relationship_utils import get_capture_files +def _epoch_sec_to_iso_utc_z(epoch_sec: int) -> str: + """Format OpenSearch epoch seconds as an ISO 8601 UTC 
string with ``Z`` suffix.""" + dt = datetime.fromtimestamp(epoch_sec, tz=timezone.utc) + return dt.isoformat().replace("+00:00", "Z") + + +def _epoch_sec_to_local_display(epoch_sec: int) -> str: + """Human-readable local time (same pattern as ``formatted_created_at``).""" + dt = datetime.fromtimestamp(epoch_sec, tz=timezone.utc) + return django_timezone.localtime(dt).strftime("%m/%d/%Y %I:%M:%S %p") + + class FileCaptureListSerializer(serializers.ModelSerializer[File]): class Meta: model = File @@ -91,6 +106,10 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]): length_of_capture_ms = serializers.SerializerMethodField() file_cadence_ms = serializers.SerializerMethodField() capture_start_epoch_sec = serializers.SerializerMethodField() + capture_start_iso_utc = serializers.SerializerMethodField() + capture_end_iso_utc = serializers.SerializerMethodField() + capture_start_display = serializers.SerializerMethodField() + capture_end_display = serializers.SerializerMethodField() formatted_created_at = serializers.SerializerMethodField() capture_type_display = serializers.SerializerMethodField() post_processed_data = serializers.SerializerMethodField() @@ -202,7 +221,7 @@ def get_length_of_capture_ms(self, capture: Capture) -> int | None: """Capture length in milliseconds (OpenSearch bounds are seconds).""" if capture.end_time is None or capture.start_time is None: return None - + return (capture.end_time - capture.start_time) * 1000 @extend_schema_field(serializers.IntegerField(allow_null=True)) @@ -215,6 +234,38 @@ def get_capture_start_epoch_sec(self, capture: Capture) -> int | None: """Capture start as Unix epoch seconds. None if not in OpenSearch.""" return capture.start_time + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_iso_utc(self, capture: Capture) -> str | None: + """Indexed capture start as ISO 8601 UTC (``Z``). None if unavailable.""" + if capture.start_time is None: + return None + + return _epoch_sec_to_iso_utc_z(capture.start_time) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_iso_utc(self, capture: Capture) -> str | None: + """Indexed capture end as ISO 8601 UTC (``Z``). 
None if unavailable.""" + if capture.end_time is None: + return None + + return _epoch_sec_to_iso_utc_z(capture.end_time) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_display(self, capture: Capture) -> str | None: + """Indexed capture start in the active timezone for display.""" + if capture.start_time is None: + return None + + return _epoch_sec_to_local_display(capture.start_time) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_display(self, capture: Capture) -> str | None: + """Indexed capture end in the active timezone for display.""" + if capture.end_time is None: + return None + + return _epoch_sec_to_local_display(capture.end_time) + @extend_schema_field(serializers.DictField) def get_capture_props(self, capture: Capture) -> dict[str, Any]: """Retrieve the indexed metadata for the capture.""" @@ -381,6 +432,22 @@ class ChannelMetadataSerializer(serializers.Serializer): channel_metadata = serializers.DictField() +class CompositeChannelEntrySerializer(serializers.Serializer): + """One channel in a composite capture, including per-channel index bounds.""" + + channel = serializers.CharField() + uuid = serializers.UUIDField() + channel_metadata = serializers.DictField() + capture_start_epoch_sec = serializers.IntegerField(allow_null=True, required=False) + capture_end_epoch_sec = serializers.IntegerField(allow_null=True, required=False) + capture_start_iso_utc = serializers.CharField(allow_null=True, required=False) + capture_end_iso_utc = serializers.CharField(allow_null=True, required=False) + capture_start_display = serializers.CharField(allow_null=True, required=False) + capture_end_display = serializers.CharField(allow_null=True, required=False) + length_of_capture_ms = serializers.IntegerField(allow_null=True, required=False) + file_cadence_ms = serializers.IntegerField(allow_null=True, required=False) + + class CompositeCaptureSerializer(serializers.Serializer): """Serializer for composite captures that contain multiple channels.""" @@ -400,8 +467,8 @@ class CompositeCaptureSerializer(serializers.Serializer): is_shared = serializers.SerializerMethodField() owner = UserGetSerializer() - # Channel-specific fields - channels = serializers.ListField(child=ChannelMetadataSerializer()) + # Channel-specific fields (enriched with OpenSearch bounds per channel) + channels = serializers.SerializerMethodField() # Computed fields share_permissions = serializers.SerializerMethodField() @@ -412,6 +479,81 @@ class CompositeCaptureSerializer(serializers.Serializer): length_of_capture_ms = serializers.SerializerMethodField() file_cadence_ms = serializers.SerializerMethodField() capture_start_epoch_sec = serializers.SerializerMethodField() + capture_start_iso_utc = serializers.SerializerMethodField() + capture_end_iso_utc = serializers.SerializerMethodField() + capture_start_display = serializers.SerializerMethodField() + capture_end_display = serializers.SerializerMethodField() + + def _enriched_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: + """Per-channel rows with OpenSearch bounds (each channel may differ).""" + key = str(obj.get("uuid", "")) + if not hasattr(self, "_enriched_channels_cache"): + self._enriched_channels_cache: dict[str, list[dict[str, Any]]] = {} + if key not in self._enriched_channels_cache: + out: list[dict[str, Any]] = [] + for ch in obj.get("channels") or []: + entry: dict[str, Any] = { + "channel": ch["channel"], + "uuid": ch["uuid"], + "channel_metadata": 
ch.get("channel_metadata", {}), + } + try: + capture = Capture.objects.get(uuid=ch["uuid"]) + start_sec, end_sec = get_capture_bounds( + capture.capture_type, str(capture.uuid) + ) + except ( + ValueError, + IndexError, + KeyError, + Capture.DoesNotExist, + ): + entry["capture_start_epoch_sec"] = None + entry["capture_end_epoch_sec"] = None + entry["capture_start_iso_utc"] = None + entry["capture_end_iso_utc"] = None + entry["capture_start_display"] = None + entry["capture_end_display"] = None + entry["length_of_capture_ms"] = None + entry["file_cadence_ms"] = None + else: + entry["capture_start_epoch_sec"] = start_sec + entry["capture_end_epoch_sec"] = end_sec + entry["capture_start_iso_utc"] = _epoch_sec_to_iso_utc_z(start_sec) + entry["capture_end_iso_utc"] = _epoch_sec_to_iso_utc_z(end_sec) + entry["capture_start_display"] = _epoch_sec_to_local_display( + start_sec + ) + entry["capture_end_display"] = _epoch_sec_to_local_display(end_sec) + entry["length_of_capture_ms"] = (end_sec - start_sec) * 1000 + try: + entry["file_cadence_ms"] = get_file_cadence( + capture.capture_type, capture + ) + except (ValueError, IndexError, KeyError): + entry["file_cadence_ms"] = None + out.append(entry) + self._enriched_channels_cache[key] = out + return self._enriched_channels_cache[key] + + def _composite_envelope_bounds( + self, + obj: dict[str, Any], + ) -> tuple[int, int] | None: + """Earliest channel start and latest channel end (seconds), for composite summary.""" + pairs = [ + (row["capture_start_epoch_sec"], row["capture_end_epoch_sec"]) + for row in self._enriched_channels(obj) + if row.get("capture_start_epoch_sec") is not None + and row.get("capture_end_epoch_sec") is not None + ] + if not pairs: + return None + return min(s for s, _ in pairs), max(e for _, e in pairs) + + @extend_schema_field(CompositeChannelEntrySerializer(many=True)) + def get_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: + return self._enriched_channels(obj) def get_share_permissions(self, obj: dict[str, Any]) -> list[UserSharePermission]: """Get the share permissions for the composite capture.""" @@ -441,7 +583,7 @@ def get_is_shared(self, obj: dict[str, Any]) -> bool: def get_files(self, obj: dict[str, Any]) -> ReturnList[File]: """Get all files from all channels in the composite capture.""" all_files = [] - for channel_data in obj["channels"]: + for channel_data in obj.get("channels") or []: capture_uuid = channel_data["uuid"] capture = Capture.objects.get(uuid=capture_uuid) non_deleted_files = get_capture_files(capture, include_deleted=False) @@ -459,7 +601,7 @@ def get_total_file_size(self, obj: dict[str, Any]) -> int | None: return None total_size = 0 - for channel_data in obj["channels"]: + for channel_data in obj.get("channels") or []: capture_uuid = channel_data["uuid"] capture = Capture.objects.get(uuid=capture_uuid) all_files = get_capture_files(capture, include_deleted=False) @@ -487,7 +629,7 @@ def get_data_files_info(self, obj: dict[str, Any]) -> dict[str, Any]: total_count = 0 total_size = 0 - for channel_data in obj["channels"]: + for channel_data in obj.get("channels") or []: capture_uuid = channel_data["uuid"] capture = Capture.objects.get(uuid=capture_uuid) stats = capture.get_drf_data_files_stats() @@ -512,35 +654,65 @@ def get_formatted_created_at(self, obj: dict[str, Any]) -> str: @extend_schema_field(serializers.IntegerField(allow_null=True)) def get_length_of_capture_ms(self, obj: dict[str, Any]) -> int | None: - """Use first channel's bounds for composite capture duration.""" - 
channels = obj.get("channels") or [] - if not channels: + """Span from earliest channel start to latest channel end (milliseconds).""" + bounds = self._composite_envelope_bounds(obj) + if bounds is None: return None - - capture = Capture.objects.get(uuid=channels[0]["uuid"]) - if capture.end_time is None or capture.start_time is None: - return None - return (capture.end_time - capture.start_time) * 1000 + start_time, end_time = bounds + return (end_time - start_time) * 1000 @extend_schema_field(serializers.IntegerField(allow_null=True)) def get_file_cadence_ms(self, obj: dict[str, Any]) -> int | None: - """Use first channel's file cadence for composite capture.""" - channels = obj.get("channels") or [] - if not channels: + """Mean file cadence across channels (each channel may differ).""" + cadences = [ + row["file_cadence_ms"] + for row in self._enriched_channels(obj) + if row.get("file_cadence_ms") is not None + ] + if not cadences: return None - - capture = Capture.objects.get(uuid=channels[0]["uuid"]) - return capture.file_cadence + return int(round(sum(cadences) / len(cadences))) @extend_schema_field(serializers.IntegerField(allow_null=True)) def get_capture_start_epoch_sec(self, obj: dict[str, Any]) -> int | None: - """Use first channel's start time for composite capture.""" - channels = obj.get("channels") or [] - if not channels: + """Earliest indexed start among channels (epoch seconds).""" + bounds = self._composite_envelope_bounds(obj) + if bounds is None: return None + start_time, _ = bounds + return start_time - capture = Capture.objects.get(uuid=channels[0]["uuid"]) - return capture.start_time + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_iso_utc(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + start_sec, _ = bounds + return _epoch_sec_to_iso_utc_z(start_sec) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_iso_utc(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + _, end_sec = bounds + return _epoch_sec_to_iso_utc_z(end_sec) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_display(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + start_sec, _ = bounds + return _epoch_sec_to_local_display(start_sec) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_display(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + _, end_sec = bounds + return _epoch_sec_to_local_display(end_sec) def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]: diff --git a/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py b/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py index a3a10c02..f922a8aa 100644 --- a/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py +++ b/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py @@ -1238,9 +1238,9 @@ def test_large_file_download_redirects_to_sdk(self): def test_get_item_files_with_temporal_bounds_returns_expected_rf_subset(self): """ Task-level test: start_time/end_time flow into _get_item_files. 
- For DigitalRF captures, ``get_capture_files_with_temporal_filter`` returns - non-DRF capture files (metadata) plus DRF files in the selected time range - (temporal_filtering details are unit-tested in test_temporal_filtering.py). + For DigitalRF captures, ``get_capture_files_with_temporal_filter`` (and + ``filter_files_by_temporal_bounds``) returns non-DRF capture files (metadata) + plus DRF files in the selected time range (see test_temporal_filtering.py). """ # Create DRF-named files for self.capture (epoch 1s..5s) epoch_start_sec = 1 diff --git a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py index 002ff7b7..4f31c031 100644 --- a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py +++ b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py @@ -3,6 +3,8 @@ import time import uuid from collections.abc import Mapping +from datetime import datetime +from datetime import timezone from pathlib import Path from typing import TYPE_CHECKING from typing import Any @@ -201,6 +203,7 @@ def test_retrieve_latest_file_not_accessible_404(self) -> None: self.list_url, data={"path": self.file.directory}, ) + assert response.data.get("warnings") == [] results = response.data.get("results") assert len(results) == 0, f"Expected no files to be returned, got {results}" @@ -265,6 +268,7 @@ def test_retrieve_latest_file_accessible_200(self) -> None: f"Expected 200, got {response.status_code}" ) response = response.data + assert response.get("warnings") == [] assert "results" in response, ( f"Expected a paginated response with 'results', got: {response}" ) @@ -441,6 +445,73 @@ def test_download_file_no_access_403(self): # Verify 403 response assert response.status_code == status.HTTP_403_FORBIDDEN + def test_list_files_with_temporal_params(self) -> None: + """Temporal query params keep non-RF files and only RF data files in time bounds.""" + base_sec = 1_000_000 + for offset in (0, 1, 2, 5): + create_db_file( + owner=self.user, + extras={ + "directory": self.sds_path, + "name": f"rf@{base_sec + offset}.000.h5", + "media_type": "application/x-hdf5", + }, + ) + + path = str(self.file.directory) + # Absolute epoch ms from ISO datetimes; bounds include rf@(base+1)..rf@(base+2). 
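+        # The endpoint converts each ISO datetime to absolute epoch milliseconds,
+        # renders the bounds as rf@<sec>.<ms>.h5 filenames, and keeps only RF data
+        # files whose names fall between them; non-RF files stay in the listing.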
+ start_iso = datetime.fromtimestamp( + base_sec + 1, tz=timezone.utc + ).isoformat() + end_iso = datetime.fromtimestamp( + base_sec + 2, tz=timezone.utc + ).isoformat() + response = self.client.get( + self.list_url, + { + "path": path, + "start_time": start_iso, + "end_time": end_iso, + }, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["warnings"] == [] + names = {row["name"] for row in data["results"]} + assert self.file.name in names + assert {f"rf@{base_sec + 1}.000.h5", f"rf@{base_sec + 2}.000.h5"} <= names + assert f"rf@{base_sec}.000.h5" not in names + assert f"rf@{base_sec + 5}.000.h5" not in names + + def test_list_files_includes_warnings_key(self) -> None: + """Paginated list responses always include ``warnings`` (possibly empty).""" + response = self.client.get(self.list_url) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "warnings" in data + assert data["warnings"] == [] + + def test_list_files_directory_temporal_params_without_rf_data_includes_warning( + self, + ) -> None: + """start_time/end_time on dirs with no RF data: full listing + warning.""" + path = str(self.file.directory) + response = self.client.get( + self.list_url, + { + "path": path, + "start_time": "2020-01-01T00:00:00", + "end_time": "2020-01-02T00:00:00", + }, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "warnings" in data + assert len(data["warnings"]) == 1 + assert "RF data" in data["warnings"][0] + assert data["count"] >= 1 + assert len(data["results"]) >= 1 + def test_list_files_includes_shared_files(self) -> None: """Test that list files includes files from shared captures/datasets.""" @@ -483,6 +554,7 @@ def test_list_files_includes_shared_files(self) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() + assert data.get("warnings") == [] # Should have the original file plus the shared one expected_count = 2 # Original file + shared file assert data["count"] == expected_count, ( @@ -592,6 +664,7 @@ def test_disabled_share_permission_blocks_file_access(self) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() + assert data.get("warnings") == [] # Should only have the original file, not the shared one assert data["count"] == 1, ( f"Expected 1 file (excluding disabled shared), got {data['count']}" diff --git a/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py b/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py index 71cb295c..5abb30df 100644 --- a/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py +++ b/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py @@ -201,6 +201,7 @@ "count": 105, "next": "http://localhost:8000/api/latest/assets/files/?page=2&page_size=3", "previous": None, + "warnings": [], "results": [ { "bucket_name": "spectrumx", diff --git a/gateway/sds_gateway/api_methods/views/file_endpoints.py b/gateway/sds_gateway/api_methods/views/file_endpoints.py index 625d92fc..958d5a38 100644 --- a/gateway/sds_gateway/api_methods/views/file_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/file_endpoints.py @@ -1,10 +1,13 @@ """File operations endpoints for the SDS Gateway API.""" +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING +from typing import Any from typing import cast from django.db.models import CharField +from django.db.models import QuerySet from django.db.models import F as FExpression from 
django.db.models import ProtectedError from django.db.models import Value as WrappedValue @@ -29,7 +32,9 @@ import sds_gateway.api_methods.utils.swagger_example_schema as example_schema from sds_gateway.api_methods.authentication import APIKeyAuthentication from sds_gateway.api_methods.helpers.download_file import download_file +from sds_gateway.api_methods.helpers.temporal_filtering import filter_files_by_temporal_bounds from sds_gateway.api_methods.models import File +from sds_gateway.api_methods.models import DRF_RF_FILENAME_REGEX_STR from sds_gateway.api_methods.serializers.file_serializers import ( FileCheckResponseSerializer, ) @@ -59,6 +64,17 @@ class FileViewSet(ViewSet): authentication_classes = [APIKeyAuthentication] permission_classes = [IsAuthenticated] + @staticmethod + def _paginated_list_response( + paginator: FilePagination, + serializer_data: Any, + warnings: list[str], + ) -> Response: + """Build paginated file list JSON with a always-present ``warnings`` key.""" + response = paginator.get_paginated_response(serializer_data) + response.data["warnings"] = warnings + return response + @extend_schema( request=FilePostSerializer, responses={ @@ -195,6 +211,18 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: serializer = FileGetSerializer(target_file, many=False) return Response(serializer.data) + + def _datetime_string_to_milliseconds(self, datetime_string: str) -> int: + """Converts a datetime string to milliseconds since start of capture.""" + parsed = datetime.fromisoformat(datetime_string) + return int(parsed.timestamp() * 1000) + + def _check_files_includes_rf_data(self, files: QuerySet[File]) -> bool: + """Checks if the files include RF data.""" + return files.filter( + name__regex=DRF_RF_FILENAME_REGEX_STR + ).exists() + @extend_schema( responses={ 200: FileGetSerializer, @@ -220,6 +248,25 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: "or an exact file match (directory + name)." ), ), + OpenApiParameter( + name="start_time", + type=OpenApiTypes.DATETIME, + location=OpenApiParameter.QUERY, + required=False, + description=( + "ISO 8601 datetime; converted to ms for temporal filtering of " + "RF data files when the listing includes Digital RF ``.h5`` data." + ), + ), + OpenApiParameter( + name="end_time", + type=OpenApiTypes.DATETIME, + location=OpenApiParameter.QUERY, + required=False, + description=( + "ISO 8601 datetime; paired with start_time for RF temporal bounds." + ), + ), OpenApiParameter( name="page", type=OpenApiTypes.INT, @@ -246,6 +293,20 @@ def list(self, request: Request) -> Response: it will retrieve the most recent one that matches that path. Wildcards are not yet supported. 
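+
+        When ``start_time`` and ``end_time`` are both provided and the listing
+        includes Digital RF data, ``rf@*.h5`` data files are limited to that time
+        window. If the parameters cannot be applied (only one bound given, or no
+        RF data under the path), the full listing is returned and a message is
+        added to the response's ``warnings`` list.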
""" + # warnings to be returned in the response + warnings = [] + + # Get optional temporal filtering parameters + # time passed as datetime string, need to convert + # to milliseconds since start of capture + start_time = request.GET.get("start_time", None) + end_time = request.GET.get("end_time", None) + if start_time: + start_time = self._datetime_string_to_milliseconds(start_time) + + if end_time: + end_time = self._datetime_string_to_milliseconds(end_time) + unsafe_path = request.GET.get("path", "/").strip() basename = Path(unsafe_path).name @@ -259,7 +320,7 @@ def list(self, request: Request) -> Response: all_valid_user_files, request=request ) serializer = FileGetSerializer(paginated_files, many=True) - return paginator.get_paginated_response(serializer.data) + return self._paginated_list_response(paginator, serializer.data, warnings) # For specific paths, use the existing path-based filtering logic user_rel_path = sanitize_path_rel_to_user( @@ -296,7 +357,9 @@ def list(self, request: Request) -> Response: serializer = FileGetSerializer(paginated_files, many=True) # despite being a single result, we return it paginated for consistency - return paginator.get_paginated_response(serializer.data) + return self._paginated_list_response( + paginator, serializer.data, warnings + ) log.debug( "No exact match found for " f"{inferred_user_rel_path!s} and name {basename}", @@ -307,6 +370,28 @@ def list(self, request: Request) -> Response: directory__startswith=str(user_rel_path), ) + if self._check_files_includes_rf_data(files_matching_dir): + if start_time is not None and end_time is not None: + files_matching_dir = filter_files_by_temporal_bounds( + files_matching_dir, + start_time, + end_time, + ) + elif start_time is not None or end_time is not None: + msg = ( + "Both start_time and end_time are required for temporal filtering " + "when listing Digital RF data." 
+ ) + log.warning(msg) + warnings.append(msg) + elif start_time or end_time: + msg = ( + "Temporal filtering is only supported " + "for file directories that include RF data" + ) + log.warning(msg) + warnings.append(msg) + files_matching_dir = files_matching_dir.annotate( path=Concat( FExpression("directory"), @@ -334,7 +419,7 @@ def list(self, request: Request) -> Response: f"user files for path {user_rel_path!s} - returning {len(serializer.data)}", ) - return paginator.get_paginated_response(serializer.data) + return self._paginated_list_response(paginator, serializer.data, warnings) @extend_schema( parameters=[ diff --git a/sdk/src/spectrumx/api/sds_files.py b/sdk/src/spectrumx/api/sds_files.py index 9839cdb8..ef6ce2ba 100644 --- a/sdk/src/spectrumx/api/sds_files.py +++ b/sdk/src/spectrumx/api/sds_files.py @@ -6,6 +6,8 @@ import os import tempfile import uuid +from datetime import datetime +from datetime import timezone from enum import Enum from enum import auto from multiprocessing.synchronize import RLock @@ -27,6 +29,15 @@ log.trace("Placeholder log to avoid reimporting or resolving unused import warnings.") +def _file_list_time_query_param(value: datetime) -> str: + """Format a datetime for Gateway file list temporal query params (ISO 8601, UTC).""" + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + else: + value = value.astimezone(timezone.utc) + return value.isoformat() + + class FileUploadMode(Enum): """Modes for uploading files to SDS.""" @@ -109,30 +120,47 @@ def list_files( *, client: Client, sds_path: PurePosixPath | Path | str, + start_time: datetime | None = None, + end_time: datetime | None = None, verbose: bool = False, ) -> Paginator[File]: """Lists files in a given SDS path. Args: sds_path: The virtual directory on SDS to list files from. + start_time: When set, lower bound for Digital RF data files (UTC-aligned ISO + sent to the API). Must be used together with ``end_time``. + end_time: Upper bound for Digital RF data files, paired with ``start_time``. Returns: A paginator for the files in the given SDS path. """ + if (start_time is None) ^ (end_time is None): + msg = "start_time and end_time must both be set or both omitted." 
+ raise ValueError(msg) sds_path = PurePosixPath(sds_path) + start_q: str | None = ( + _file_list_time_query_param(start_time) if start_time is not None else None + ) + end_q: str | None = ( + _file_list_time_query_param(end_time) if end_time is not None else None + ) if client.dry_run: log_user("Dry run enabled: files will be simulated") pagination: Paginator[File] = Paginator( gateway=client._gateway, Entry=File, list_method=client._gateway.list_files, - list_kwargs={"sds_path": sds_path}, + list_kwargs={ + "sds_path": sds_path, + "start_time": start_q, + "end_time": end_q, + }, dry_run=client.dry_run, verbose=verbose, ) return pagination - def upload_file( *, client: Client, diff --git a/sdk/src/spectrumx/client.py b/sdk/src/spectrumx/client.py index 10623283..76428b4c 100644 --- a/sdk/src/spectrumx/client.py +++ b/sdk/src/spectrumx/client.py @@ -1,6 +1,7 @@ """Client for the SpectrumX Data System.""" from collections.abc import Mapping +from datetime import datetime from pathlib import Path from pathlib import PurePosixPath from typing import Any @@ -207,6 +208,8 @@ def download( self, *, from_sds_path: PurePosixPath | Path | str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, to_local_path: Path | str, files_to_download: list[File] | Paginator[File] | None = None, skip_contents: bool = False, @@ -217,6 +220,10 @@ def download( Args: from_sds_path: The virtual directory on SDS to download files from. + start_time: The start time to filter DRF capture files by + (optional, only applicable to DRF capture files). + end_time: The end time to filter DRF capture files by + (optional, only applicable to DRF capture files). to_local_path: The local path to save the downloaded files to. files_to_download: A paginator or list (in dry run mode) of files to download. If not provided, all files in the directory will be @@ -242,6 +249,8 @@ def download( self._prepare_download_directory(to_local_path) files_to_download = self._get_files_to_download( from_sds_path=from_sds_path, + start_time=start_time, + end_time=end_time, files_to_download=files_to_download, verbose=verbose, ) @@ -267,6 +276,8 @@ def _get_files_to_download( self, *, from_sds_path: PurePosixPath | None, + start_time: datetime | None = None, + end_time: datetime | None = None, files_to_download: list[File] | Paginator[File] | None, verbose: bool, ) -> list[File] | Paginator[File]: @@ -280,7 +291,11 @@ def _get_files_to_download( log_user(f"Dry run: discovered {len(files_to_download)} files (samples)") else: if from_sds_path is not None: - files_to_download = self.list_files(sds_path=from_sds_path) + files_to_download = self.list_files( + sds_path=from_sds_path, + start_time=start_time, + end_time=end_time, + ) elif files_to_download is None: error_msg = ( "Either a path in the SDS or a paginator/list of files " @@ -397,18 +412,29 @@ def _is_path_relative_to_target( return local_file_path.is_relative_to(to_local_path) def list_files( - self, sds_path: PurePosixPath | Path | str, *, verbose: bool = False + self, + sds_path: PurePosixPath | Path | str, + start_time: datetime | None = None, + end_time: datetime | None = None, + verbose: bool = False, ) -> Paginator[File]: """Lists files in a given SDS path. Args: sds_path: The virtual directory on SDS to list files from. + start_time: Optional inclusive lower time bound for ``rf@*.h5`` files + (naive datetimes are treated as UTC). Requires ``end_time``. + end_time: Optional inclusive upper time bound; requires ``start_time``. 
verbose: Show network requests and other info. Returns: A paginator for the files in the given SDS path. """ return self._sds_files.list_files( - client=self, sds_path=sds_path, verbose=verbose + client=self, + sds_path=sds_path, + start_time=start_time, + end_time=end_time, + verbose=verbose, ) def download_file( diff --git a/sdk/src/spectrumx/gateway.py b/sdk/src/spectrumx/gateway.py index ace0d3ea..34cdd879 100644 --- a/sdk/src/spectrumx/gateway.py +++ b/sdk/src/spectrumx/gateway.py @@ -263,21 +263,33 @@ def list_files( sds_path: PurePosixPath | Path | str, page: int = 1, page_size: int = 30, + start_time: str | None = None, + end_time: str | None = None, verbose: bool = False, ) -> bytes: """Lists files from the SDS API. + Args: + start_time: Optional ISO 8601 instant (UTC recommended) for RF temporal + lower bound; must be paired with ``end_time``. + end_time: Optional ISO 8601 instant for RF temporal upper bound. + Returns: The response content from SDS Gateway. """ + params: dict[str, str | int] = { + "page": page, + "page_size": page_size, + "path": str(sds_path), + } + if start_time is not None: + params["start_time"] = start_time + if end_time is not None: + params["end_time"] = end_time response = self._request( method=HTTPMethods.GET, endpoint=Endpoints.FILES, - params={ - "page": page, - "page_size": page_size, - "path": str(sds_path), - }, + params=params, verbose=verbose, ) network.success_or_raise(response, ContextException=FileError) diff --git a/sdk/src/spectrumx/models/captures.py b/sdk/src/spectrumx/models/captures.py index 1d7d7cc7..044afc01 100644 --- a/sdk/src/spectrumx/models/captures.py +++ b/sdk/src/spectrumx/models/captures.py @@ -50,6 +50,16 @@ class CaptureOrigin(StrEnum): _d_is_shared = "Whether the capture is shared" _d_is_shared_with_me = "Whether the capture is shared with the current user" _d_datasets = "Datasets this capture is associated with" +_d_capture_start_iso_utc = ( + "Indexed capture start from OpenSearch as ISO 8601 UTC (when available)" +) +_d_capture_end_iso_utc = "Indexed capture end from OpenSearch as ISO 8601 UTC" +_d_capture_start_display = ( + "Indexed capture start formatted for display (server/local timezone)" +) +_d_capture_end_display = ( + "Indexed capture end formatted for display (server/local timezone)" +) class CaptureFile(BaseModel): @@ -105,6 +115,22 @@ class Capture(SDSModel): str | None, Field(max_length=255, description=_d_channel, default=None) ] scan_group: Annotated[UUID4 | None, Field(description=_d_scan_group, default=None)] + capture_start_iso_utc: Annotated[ + str | None, + Field(description=_d_capture_start_iso_utc, default=None), + ] + capture_end_iso_utc: Annotated[ + str | None, + Field(description=_d_capture_end_iso_utc, default=None), + ] + capture_start_display: Annotated[ + str | None, + Field(description=_d_capture_start_display, default=None), + ] + capture_end_display: Annotated[ + str | None, + Field(description=_d_capture_end_display, default=None), + ] @field_validator("capture_props", mode="before") @classmethod diff --git a/sdk/src/spectrumx/ops/pagination.py b/sdk/src/spectrumx/ops/pagination.py index fdfdee53..909d643f 100644 --- a/sdk/src/spectrumx/ops/pagination.py +++ b/sdk/src/spectrumx/ops/pagination.py @@ -302,7 +302,11 @@ def main() -> None: # pragma: no cover api_key="does-not-matter-in-dry-run", ), list_method=lambda **kwargs: b'{"count": 25, "results": []}', # Mock response - list_kwargs={"sds_path": "/path/to/files"}, + list_kwargs={ + "sds_path": "/path/to/files", + "start_time": None, + 
"end_time": None, + }, page_size=10, dry_run=True, # in dry-run this should always generate 2.5 pages verbose=True, From 4b59861fd6073ae4d1beec95bab3eddf70b6da0d Mon Sep 17 00:00:00 2001 From: klpoland Date: Thu, 23 Apr 2026 14:18:29 -0400 Subject: [PATCH 04/10] add sdk tests --- sdk/src/spectrumx/ops/pagination.py | 7 ++ sdk/tests/integration/test_file_ops.py | 96 ++++++++++++++++++++++++++ sdk/tests/ops/test_files.py | 13 +++- sdk/tests/ops/test_paginator.py | 28 ++++++++ 4 files changed, 142 insertions(+), 2 deletions(-) diff --git a/sdk/src/spectrumx/ops/pagination.py b/sdk/src/spectrumx/ops/pagination.py index 909d643f..c96eab58 100644 --- a/sdk/src/spectrumx/ops/pagination.py +++ b/sdk/src/spectrumx/ops/pagination.py @@ -18,6 +18,7 @@ from spectrumx.gateway import GatewayClient from spectrumx.models import SDSModel from spectrumx.ops import files +from spectrumx.utils import log_user_warning if TYPE_CHECKING: from collections.abc import Generator @@ -275,6 +276,12 @@ def _ingest_new_page(self, raw_page: bytes) -> None: if not isinstance(self._current_page_data, dict): # pragma: no cover msg = "Failed to load page data: expected a dictionary from JSON." raise TypeError(msg) + if not self._has_fetched: + raw_warnings = self._current_page_data.get("warnings") + if isinstance(raw_warnings, list): + for w in raw_warnings: + if isinstance(w, str) and w: + log_user_warning(w) if "count" in self._current_page_data: self._total_matches = self._current_page_data["count"] self._current_page_entries = ( diff --git a/sdk/tests/integration/test_file_ops.py b/sdk/tests/integration/test_file_ops.py index bf09516c..d81809f9 100644 --- a/sdk/tests/integration/test_file_ops.py +++ b/sdk/tests/integration/test_file_ops.py @@ -1,7 +1,10 @@ """Integration tests for file operations on SDS.""" +import logging import time import uuid +from datetime import datetime +from datetime import timezone from pathlib import Path from pathlib import PurePosixPath from unittest.mock import patch @@ -17,6 +20,8 @@ from spectrumx.utils import get_random_line from tests.integration.conftest import PassthruEndpoints +from tests.integration.test_captures import _upload_drf_capture_test_assets +from tests.integration.test_captures import drf_channel from tests.test_utils import disable_ssl_warnings BLAKE3_HEX_LEN: int = 64 @@ -710,6 +715,97 @@ def test_file_listing(integration_client: Client, temp_file_tree: Path) -> None: ) +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + "_without_responses", + argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + ] + ], + indirect=True, +) +def test_list_files_temporal_rf_narrows_digital_rf_chunks( + integration_client: Client, + drf_sample_top_level_dir: Path, +) -> None: + """Temporal bounds narrow ``rf@*.h5`` listings; non-RF names in the dir stay listed.""" + if not drf_sample_top_level_dir.is_dir(): + pytest.skip( + "Digital RF sample tree missing; cannot test temporal RF listing." 
+ ) + cap_data = _upload_drf_capture_test_assets( + integration_client=integration_client, + drf_sample_top_level_dir=drf_sample_top_level_dir, + ) + rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" + start = datetime.fromtimestamp(1719499740, tz=timezone.utc) + end = datetime.fromtimestamp(1719499740, tz=timezone.utc) + all_rf = { + f.name + for f in integration_client.list_files(sds_path=rf_dir) + if f.name.startswith("rf@") + } + narrow_rf = { + f.name + for f in integration_client.list_files( + sds_path=rf_dir, + start_time=start, + end_time=end, + ) + if f.name.startswith("rf@") + } + assert len(all_rf) == 16, f"Expected 16 RF chunks under {rf_dir}, got {all_rf!r}" + assert narrow_rf == {"rf@1719499740.000.h5"} + + +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + "_without_responses", + argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + ] + ], + indirect=True, +) +def test_list_files_temporal_non_rf_directory_warns( + caplog: pytest.LogCaptureFixture, + integration_client: Client, + temp_file_tree: Path, +) -> None: + """API warns when temporal params are used on a tree with no RF data files.""" + caplog.set_level(logging.WARNING) + random_subdir_name = get_random_line(10, include_punctuation=False) + sds_path = PurePosixPath("/test-tree-temporal-warn") / random_subdir_name + results = integration_client.upload( + local_path=temp_file_tree, + sds_path=sds_path, + verbose=False, + ) + failures = [result for result in results if not result] + assert not failures, f"Upload failed: {failures}" + start = datetime(2020, 1, 1, tzinfo=timezone.utc) + end = datetime(2020, 1, 2, tzinfo=timezone.utc) + listed = list( + integration_client.list_files( + sds_path=sds_path, + start_time=start, + end_time=end, + ) + ) + assert len(listed) > 0 + assert any("RF data" in r.getMessage() for r in caplog.records), ( + f"expected server warning in logs; got {caplog.records!r}" + ) + + @pytest.mark.integration @pytest.mark.usefixtures("_integration_setup_teardown") @pytest.mark.usefixtures("_without_responses") diff --git a/sdk/tests/ops/test_files.py b/sdk/tests/ops/test_files.py index 9928b1fd..3100c4e3 100644 --- a/sdk/tests/ops/test_files.py +++ b/sdk/tests/ops/test_files.py @@ -5,7 +5,7 @@ import sys import tempfile import uuid as uuidlib -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from pathlib import PurePosixPath @@ -13,7 +13,7 @@ import responses from loguru import logger as log from spectrumx import Client -from spectrumx.api.sds_files import delete_file +from spectrumx.api.sds_files import delete_file, list_files from spectrumx.errors import FileError from spectrumx.gateway import API_TARGET_VERSION from spectrumx.ops.files import ( @@ -103,6 +103,15 @@ def test_get_file_by_id(client: Client, responses: responses.RequestsMock) -> No assert file_sample.uuid == uuid +def test_list_files_start_end_must_be_paired(client: Client) -> None: + """Omitting one of ``start_time`` / ``end_time`` raises before any request.""" + t = datetime(2024, 6, 27, 14, 0, tzinfo=timezone.utc) + with pytest.raises(ValueError, match="both set or both omitted"): + list_files(client=client, sds_path="/", start_time=t, end_time=None) + with pytest.raises(ValueError, match="both set or both omitted"): + list_files(client=client, sds_path="/", start_time=None, end_time=t) + + def 
test_file_get_returns_valid( client: Client, ) -> None: diff --git a/sdk/tests/ops/test_paginator.py b/sdk/tests/ops/test_paginator.py index 9760b718..fc8cc09c 100644 --- a/sdk/tests/ops/test_paginator.py +++ b/sdk/tests/ops/test_paginator.py @@ -3,6 +3,8 @@ # ruff: noqa: SLF001 # pyright: reportPrivateUsage=false +import json +import logging import uuid from collections.abc import Generator from unittest.mock import MagicMock @@ -12,6 +14,7 @@ from loguru import logger as log from spectrumx.gateway import GatewayClient from spectrumx.models.files import File +from spectrumx.ops import files as sx_files from spectrumx.ops.pagination import Paginator log.trace("Placeholder log avoid reimporting or resolving unused import warnings.") @@ -318,3 +321,28 @@ def _get_raw_page( raw_page_suffix = "]}" raw_page_str = raw_page_prefix + ",".join(first_page_results) + raw_page_suffix return raw_page_str.encode() + + +def test_paginator_logs_api_warnings_on_first_page( + caplog: pytest.LogCaptureFixture, + gateway: GatewayClient, +) -> None: + """Server ``warnings`` on the first JSON page are forwarded via ``log_user_warning``.""" + caplog.set_level(logging.WARNING) + warn_msg = "Temporal filtering is only supported for RF dirs" + one = sx_files.generate_sample_file(uuid.uuid4()) + body: dict[str, object] = { + "count": 1, + "results": [json.loads(one.model_dump_json())], + "warnings": [warn_msg], + } + gateway.list_files.return_value = json.dumps(body).encode() + paginator = Paginator[File]( + Entry=File, + gateway=gateway, + list_method=gateway.list_files, + list_kwargs={"sds_path": "/only-text"}, + dry_run=False, + ) + assert len(paginator) == 1 + assert any(warn_msg in r.getMessage() for r in caplog.records) From 6f0a4dbbcb1913c594140068042e394628d29aee Mon Sep 17 00:00:00 2001 From: klpoland Date: Fri, 24 Apr 2026 12:25:17 -0400 Subject: [PATCH 05/10] refactor composite capture serialized metadata --- .../serializers/capture_serializers.py | 57 +++--- .../test_composite_capture_serialization.py | 182 ++++++++++++++++++ 2 files changed, 214 insertions(+), 25 deletions(-) create mode 100644 gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py diff --git a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py index 47188e23..cdceb358 100644 --- a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py @@ -1,8 +1,8 @@ """Capture serializers for the SDS Gateway API methods.""" import logging +from datetime import UTC from datetime import datetime -from datetime import timezone from typing import Any from typing import cast @@ -34,13 +34,13 @@ def _epoch_sec_to_iso_utc_z(epoch_sec: int) -> str: """Format OpenSearch epoch seconds as an ISO 8601 UTC string with ``Z`` suffix.""" - dt = datetime.fromtimestamp(epoch_sec, tz=timezone.utc) + dt = datetime.fromtimestamp(epoch_sec, tz=UTC) return dt.isoformat().replace("+00:00", "Z") def _epoch_sec_to_local_display(epoch_sec: int) -> str: """Human-readable local time (same pattern as ``formatted_created_at``).""" - dt = datetime.fromtimestamp(epoch_sec, tz=timezone.utc) + dt = datetime.fromtimestamp(epoch_sec, tz=UTC) return django_timezone.localtime(dt).strftime("%m/%d/%Y %I:%M:%S %p") @@ -221,7 +221,7 @@ def get_length_of_capture_ms(self, capture: Capture) -> int | None: """Capture length in milliseconds (OpenSearch bounds are seconds).""" if capture.end_time is None or 
capture.start_time is None: return None - + return (capture.end_time - capture.start_time) * 1000 @extend_schema_field(serializers.IntegerField(allow_null=True)) @@ -499,15 +499,7 @@ def _enriched_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: } try: capture = Capture.objects.get(uuid=ch["uuid"]) - start_sec, end_sec = get_capture_bounds( - capture.capture_type, str(capture.uuid) - ) - except ( - ValueError, - IndexError, - KeyError, - Capture.DoesNotExist, - ): + except Capture.DoesNotExist: entry["capture_start_epoch_sec"] = None entry["capture_end_epoch_sec"] = None entry["capture_start_iso_utc"] = None @@ -517,21 +509,36 @@ def _enriched_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: entry["length_of_capture_ms"] = None entry["file_cadence_ms"] = None else: + # Per-channel bounds/cadence from ``Capture`` (``get_opensearch_metadata``). + start_sec = capture.start_time + end_sec = capture.end_time entry["capture_start_epoch_sec"] = start_sec entry["capture_end_epoch_sec"] = end_sec - entry["capture_start_iso_utc"] = _epoch_sec_to_iso_utc_z(start_sec) - entry["capture_end_iso_utc"] = _epoch_sec_to_iso_utc_z(end_sec) - entry["capture_start_display"] = _epoch_sec_to_local_display( - start_sec + entry["capture_start_iso_utc"] = ( + _epoch_sec_to_iso_utc_z(start_sec) + if start_sec is not None + else None + ) + entry["capture_end_iso_utc"] = ( + _epoch_sec_to_iso_utc_z(end_sec) + if end_sec is not None + else None + ) + entry["capture_start_display"] = ( + _epoch_sec_to_local_display(start_sec) + if start_sec is not None + else None + ) + entry["capture_end_display"] = ( + _epoch_sec_to_local_display(end_sec) + if end_sec is not None + else None ) - entry["capture_end_display"] = _epoch_sec_to_local_display(end_sec) - entry["length_of_capture_ms"] = (end_sec - start_sec) * 1000 - try: - entry["file_cadence_ms"] = get_file_cadence( - capture.capture_type, capture - ) - except (ValueError, IndexError, KeyError): - entry["file_cadence_ms"] = None + if start_sec is None or end_sec is None: + entry["length_of_capture_ms"] = None + else: + entry["length_of_capture_ms"] = (end_sec - start_sec) * 1000 + entry["file_cadence_ms"] = capture.file_cadence out.append(entry) self._enriched_channels_cache[key] = out return self._enriched_channels_cache[key] diff --git a/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py new file mode 100644 index 00000000..b11920c6 --- /dev/null +++ b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py @@ -0,0 +1,182 @@ +"""Tests for composite (multi-channel) capture serialization. + +Focused on per-channel metadata: indexed ``channel_metadata`` and OpenSearch-derived +bounds/cadence must stay distinct after ``CompositeCaptureSerializer`` runs, and +top-level summary fields must reflect the envelope across channels. 
+""" + +from __future__ import annotations + +from unittest.mock import patch + +from django.contrib.auth import get_user_model +from django.test import TestCase + +from sds_gateway.api_methods.models import Capture +from sds_gateway.api_methods.models import CaptureType +from sds_gateway.api_methods.serializers.capture_serializers import ( + CompositeCaptureSerializer, +) +from sds_gateway.api_methods.serializers.capture_serializers import ( + _epoch_sec_to_iso_utc_z, +) +from sds_gateway.api_methods.serializers.capture_serializers import ( + build_composite_capture_data, +) +from sds_gateway.api_methods.views.capture_endpoints import _normalize_top_level_dir + +User = get_user_model() + + +def _two_drf_captures_same_group() -> tuple[Capture, Capture]: + """Two DRF captures sharing ``top_level_dir`` (multi-channel group).""" + user = User.objects.create( + email="composite-ser@example.com", + password="testpassword", # noqa: S106 + is_approved=True, + ) + top = _normalize_top_level_dir("test-composite-serialization-group") + cap0 = Capture.objects.create( + capture_type=CaptureType.DigitalRF, + channel="ch0", + index_name="captures-test-drf", + owner=user, + top_level_dir=top, + ) + cap1 = Capture.objects.create( + capture_type=CaptureType.DigitalRF, + channel="ch1", + index_name="captures-test-drf", + owner=user, + top_level_dir=top, + ) + return cap0, cap1 + + +class CompositeCaptureSerializationTests(TestCase): + """Serializer-level tests with OpenSearch and index helpers mocked.""" + + def test_distinct_channel_metadata_preserved(self) -> None: + """Per-channel ``retrieve_indexed_metadata`` payloads stay on each channel row.""" + cap0, cap1 = _two_drf_captures_same_group() + + def fake_retrieve(capture: Capture) -> dict: + return { + "channel_key": capture.channel, + "capture_pk": str(capture.uuid), + } + + with patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + side_effect=fake_retrieve, + ): + composite = build_composite_capture_data([cap0, cap1]) + + with patch.object(Capture, "get_opensearch_metadata", return_value={}): + out = CompositeCaptureSerializer(composite, context={}).data + + ch_rows = {row["channel"]: row for row in out["channels"]} + assert ch_rows["ch0"]["channel_metadata"] == { + "channel_key": "ch0", + "capture_pk": str(cap0.uuid), + } + assert ch_rows["ch1"]["channel_metadata"] == { + "channel_key": "ch1", + "capture_pk": str(cap1.uuid), + } + # Ensure we did not collapse to a single shared dict object + assert ( + ch_rows["ch0"]["channel_metadata"] is not ch_rows["ch1"]["channel_metadata"] + ) + + def test_distinct_opensearch_times_cadence_and_envelope(self) -> None: + """Each channel keeps its own bounds/cadence; top-level fields use min start / max end.""" + cap0, cap1 = _two_drf_captures_same_group() + + meta_by_uuid = { + str(cap0.uuid): { + "start_time": 1_700_000_000, + "end_time": 1_700_000_100, + "file_cadence": 400, + }, + str(cap1.uuid): { + "start_time": 1_700_000_050, + "end_time": 1_700_000_200, + "file_cadence": 800, + }, + } + + def opensearch_by_instance(self: Capture) -> dict: + return dict(meta_by_uuid[str(self.uuid)]) + + with patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + return_value={}, + ): + composite = build_composite_capture_data([cap0, cap1]) + + with patch.object( + Capture, + "get_opensearch_metadata", + opensearch_by_instance, + ): + out = CompositeCaptureSerializer(composite, context={}).data + + ch_rows = {row["channel"]: row for 
row in out["channels"]} + assert ch_rows["ch0"]["capture_start_epoch_sec"] == 1_700_000_000 + assert ch_rows["ch0"]["capture_end_epoch_sec"] == 1_700_000_100 + assert ch_rows["ch0"]["length_of_capture_ms"] == 100_000 + assert ch_rows["ch0"]["file_cadence_ms"] == 400 + + assert ch_rows["ch1"]["capture_start_epoch_sec"] == 1_700_000_050 + assert ch_rows["ch1"]["capture_end_epoch_sec"] == 1_700_000_200 + assert ch_rows["ch1"]["length_of_capture_ms"] == 150_000 + assert ch_rows["ch1"]["file_cadence_ms"] == 800 + + assert out["capture_start_epoch_sec"] == 1_700_000_000 + # Composite serializer exposes end time via ISO/display, not epoch field + assert out["capture_end_iso_utc"] == _epoch_sec_to_iso_utc_z(1_700_000_200) + assert out["length_of_capture_ms"] == 200_000 + assert out["file_cadence_ms"] == 600 + + def test_channel_with_incomplete_bounds_excluded_from_envelope(self) -> None: + """Channels missing start or end do not contribute to composite envelope; length null.""" + cap0, cap1 = _two_drf_captures_same_group() + + def opensearch_by_instance(self: Capture) -> dict: + if self.uuid == cap0.uuid: + return { + "start_time": 1_800_000_000, + "end_time": 1_800_000_030, + "file_cadence": 100, + } + return { + "start_time": 1_800_000_010, + "end_time": None, + "file_cadence": 200, + } + + with patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + return_value={}, + ): + composite = build_composite_capture_data([cap0, cap1]) + + with patch.object( + Capture, + "get_opensearch_metadata", + opensearch_by_instance, + ): + out = CompositeCaptureSerializer(composite, context={}).data + + ch_rows = {row["channel"]: row for row in out["channels"]} + assert ch_rows["ch0"]["length_of_capture_ms"] == 30_000 + assert ch_rows["ch1"]["capture_end_epoch_sec"] is None + assert ch_rows["ch1"]["length_of_capture_ms"] is None + # Envelope from only the complete channel + assert out["capture_start_epoch_sec"] == 1_800_000_000 + assert out["capture_end_iso_utc"] == _epoch_sec_to_iso_utc_z(1_800_000_030) + assert out["length_of_capture_ms"] == 30_000 From c54bf3f9f39da2c4adbd0c6816260ab3e1b1f030 Mon Sep 17 00:00:00 2001 From: klpoland Date: Fri, 24 Apr 2026 16:26:29 -0400 Subject: [PATCH 06/10] add/enhance tests --- sdk/docs/mkdocs/changelog.md | 3 +- sdk/src/spectrumx/api/datasets.py | 11 +-- sdk/src/spectrumx/models/capture_enums.py | 21 +++++ sdk/src/spectrumx/models/captures.py | 26 ++---- sdk/src/spectrumx/models/datasets.py | 5 +- sdk/tests/integration/test_file_ops.py | 97 +++++++++++++++++++++++ sdk/tests/ops/test_files.py | 82 ++++++++++++++++++- sdk/tests/ops/test_paginator.py | 52 ++++++++++++ 8 files changed, 262 insertions(+), 35 deletions(-) create mode 100644 sdk/src/spectrumx/models/capture_enums.py diff --git a/sdk/docs/mkdocs/changelog.md b/sdk/docs/mkdocs/changelog.md index a20a42ad..01f61341 100644 --- a/sdk/docs/mkdocs/changelog.md +++ b/sdk/docs/mkdocs/changelog.md @@ -6,14 +6,15 @@ + [**Added `delete` and `revoke_share_permissions` methods for datasets**](https://github.com/spectrumx/sds-code/pull/275): this allows users to (soft) delete datasets in the SDS through the SDK and revoke ALL share permissions from datasets if needed before deletion or in general. + [**Added `revoke_share_permissions` and `detach_from_datasets` methods to captures**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to revoke share permissions or detach captures from connected datasets when they need to delete a capture. 
+ [**Added `detach_from_datasets` methods to files**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to detach files from connected datasets when they need to delete them. Note: Files CANNOT be detached from captures. Delete a the parent capture FIRST to delete the file. + + [**Added `start_time` and `end_time` parameters to `list_files` and `download`**](): this gives users the ability to filter file directory downloads associated with DigitalRF captures based on a time span within the capture bounds (similar to time filtering in the web UI on SDS) + Observability: + [**Added additional fields displaying ownership, share permission, and asset connection information to SDK models**](https://github.com/spectrumx/sds-code/pull/275): this allows users to see more relevant information when retrieving or listing assets like whether they are shared, who they are shared with, who owns them, and what other assets the target is attached to (like files to captures and datasets). + New file attributes: `owner`, `captures`, `datasets` + New capture attributes: `owner`, `is_shared`, `is_shared_with_me`, `share_permissions` + New dataset attributes: `owner`, `is_shared`, `share_permissions`, `datasets` - + [**Added new models for User and UserSharePermission**](https://github.com/spectrumx/sds-code/pull/275): This allows for visibility into the users and share permissions connected to assets. + + [**Added `captures` and `files` attributes to `Dataset` model**](): This allows visibility from the dataset side into attached captures and files. ## `0.1.17` - 2025-12-20 diff --git a/sdk/src/spectrumx/api/datasets.py b/sdk/src/spectrumx/api/datasets.py index a25cd331..ee8ef8ae 100644 --- a/sdk/src/spectrumx/api/datasets.py +++ b/sdk/src/spectrumx/api/datasets.py @@ -14,11 +14,6 @@ from spectrumx.ops.pagination import Paginator from spectrumx.utils import log_user -if TYPE_CHECKING: - from uuid import UUID - - from spectrumx.gateway import GatewayClient - class DatasetAPI: gateway: GatewayClient @@ -93,7 +88,7 @@ def list_artifact_files(self, dataset_uuid: uuid.UUID) -> list[dict[str, Any]]: def get_files( self, - dataset_uuid: UUID, + dataset_uuid: uuid.UUID, ) -> Paginator[File]: """Get files in the dataset as a paginator. @@ -119,7 +114,7 @@ def get_files( def delete( self, - dataset_uuid: UUID, + dataset_uuid: uuid.UUID, ) -> bool: """Deletes a dataset from SDS by its UUID. @@ -142,7 +137,7 @@ def delete( log.debug(f"Dataset deleted with UUID {dataset_uuid}") return True - def revoke_share_permissions(self, dataset_uuid: UUID) -> bool: + def revoke_share_permissions(self, dataset_uuid: uuid.UUID) -> bool: """Revoke all direct share permissions on this dataset (owner-only). Use this (or the web portal) before :meth:`delete` when the dataset is shared. 
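
A minimal sketch of how the dataset-side surface added above could be exercised from the SDK, assuming the usual `Client(host=...)` construction; the host and dataset UUID below are placeholders, not values from this series:

    from uuid import UUID

    from spectrumx import Client

    client = Client(host="sds.example.org")  # placeholder host
    dataset_uuid = UUID("00000000-0000-0000-0000-000000000000")  # placeholder UUID

    # Dataset metadata with its captures and artifact files attached.
    dataset = client.get_dataset(dataset_uuid)

    # Raw capture payloads (composite multi-channel rows included) and the
    # files linked directly to the dataset.
    captures = client.list_dataset_captures(dataset_uuid)
    artifacts = client.list_dataset_artifact_files(dataset_uuid)

    # Paginator over the downloadable File objects for the dataset.
    for sds_file in client.datasets.get_files(dataset_uuid):
        print(sds_file.name)
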
diff --git a/sdk/src/spectrumx/models/capture_enums.py b/sdk/src/spectrumx/models/capture_enums.py new file mode 100644 index 00000000..6f9a7ce1 --- /dev/null +++ b/sdk/src/spectrumx/models/capture_enums.py @@ -0,0 +1,21 @@ +"""Capture type/origin enums (shared by captures and datasets models).""" + +from enum import StrEnum + + +class CaptureType(StrEnum): + """Capture types in SDS.""" + + DigitalRF = "drf" + RadioHound = "rh" + SigMF = "sigmf" + + +class CaptureOrigin(StrEnum): + """Capture origins in SDS.""" + + System = "system" + User = "user" + + +__all__ = ["CaptureOrigin", "CaptureType"] diff --git a/sdk/src/spectrumx/models/captures.py b/sdk/src/spectrumx/models/captures.py index 044afc01..e7efcea4 100644 --- a/sdk/src/spectrumx/models/captures.py +++ b/sdk/src/spectrumx/models/captures.py @@ -1,7 +1,6 @@ """Capture model for SpectrumX.""" from datetime import datetime -from enum import StrEnum from pathlib import Path from pathlib import PurePosixPath from typing import Annotated @@ -14,26 +13,12 @@ from pydantic import field_validator from spectrumx.models.base import SDSModel +from spectrumx.models.capture_enums import CaptureOrigin +from spectrumx.models.capture_enums import CaptureType from spectrumx.models.datasets import Dataset from spectrumx.models.user import User from spectrumx.models.user import UserSharePermission - -class CaptureType(StrEnum): - """Capture types in SDS.""" - - DigitalRF = "drf" - RadioHound = "rh" - SigMF = "sigmf" - - -class CaptureOrigin(StrEnum): - """Capture origins in SDS.""" - - System = "system" - User = "user" - - _d_capture_created_at = "The time the capture was created" _d_capture_props = "The indexed metadata for the capture" _d_capture_type = f"The type of capture {', '.join([x.value for x in CaptureType])}" @@ -101,8 +86,10 @@ class Capture(SDSModel): list[UserSharePermission], Field(description=_d_share_permissions, default_factory=list), ] - is_shared: Annotated[bool, Field(description=_d_is_shared)] - is_shared_with_me: Annotated[bool, Field(description=_d_is_shared_with_me)] + is_shared: Annotated[bool, Field(description=_d_is_shared, default=False)] + is_shared_with_me: Annotated[ + bool, Field(description=_d_is_shared_with_me, default=False) + ] # optional fields created_at: Annotated[ @@ -178,5 +165,6 @@ def __repr__(self) -> str: __all__ = [ "Capture", + "CaptureOrigin", "CaptureType", ] diff --git a/sdk/src/spectrumx/models/datasets.py b/sdk/src/spectrumx/models/datasets.py index a57c8ea8..5b286131 100644 --- a/sdk/src/spectrumx/models/datasets.py +++ b/sdk/src/spectrumx/models/datasets.py @@ -7,9 +7,8 @@ from pydantic import BaseModel from pydantic import ConfigDict -from spectrumx.models.captures import CaptureType -from spectrumx.models.captures import CaptureOrigin -from spectrumx.models.user import UserSharePermission +from spectrumx.models.capture_enums import CaptureOrigin +from spectrumx.models.capture_enums import CaptureType from spectrumx.models.user import User from spectrumx.models.user import UserSharePermission diff --git a/sdk/tests/integration/test_file_ops.py b/sdk/tests/integration/test_file_ops.py index d81809f9..802218b8 100644 --- a/sdk/tests/integration/test_file_ops.py +++ b/sdk/tests/integration/test_file_ops.py @@ -762,6 +762,103 @@ def test_list_files_temporal_rf_narrows_digital_rf_chunks( assert narrow_rf == {"rf@1719499740.000.h5"} +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + 
"_without_responses", + argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + ] + ], + indirect=True, +) +def test_list_files_temporal_rf_inclusive_range_multiple_chunks( + integration_client: Client, + drf_sample_top_level_dir: Path, +) -> None: + """Temporal window spanning several seconds includes each ``rf@*.h5`` in that span.""" + if not drf_sample_top_level_dir.is_dir(): + pytest.skip( + "Digital RF sample tree missing; cannot test temporal RF listing." + ) + cap_data = _upload_drf_capture_test_assets( + integration_client=integration_client, + drf_sample_top_level_dir=drf_sample_top_level_dir, + ) + rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" + start = datetime.fromtimestamp(1719499740, tz=timezone.utc) + end = datetime.fromtimestamp(1719499742, tz=timezone.utc) + expected = { + "rf@1719499740.000.h5", + "rf@1719499741.000.h5", + "rf@1719499742.000.h5", + } + narrow_rf = { + f.name + for f in integration_client.list_files( + sds_path=rf_dir, + start_time=start, + end_time=end, + ) + if f.name.startswith("rf@") + } + assert narrow_rf == expected + + +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + "_without_responses", + argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + *PassthruEndpoints.file_content_download(), + ] + ], + indirect=True, +) +def test_download_respects_temporal_rf_window( + integration_client: Client, + drf_sample_top_level_dir: Path, + tmp_path: Path, +) -> None: + """Bulk download with start/end only fetches ``rf@*.h5`` in the UTC window.""" + if not drf_sample_top_level_dir.is_dir(): + pytest.skip( + "Digital RF sample tree missing; cannot test temporal RF download." 
+ ) + cap_data = _upload_drf_capture_test_assets( + integration_client=integration_client, + drf_sample_top_level_dir=drf_sample_top_level_dir, + ) + rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" + start = datetime.fromtimestamp(1719499740, tz=timezone.utc) + end = datetime.fromtimestamp(1719499741, tz=timezone.utc) + expected_names = {"rf@1719499740.000.h5", "rf@1719499741.000.h5"} + + download_dir = tmp_path / "drf_partial" + results = integration_client.download( + from_sds_path=rf_dir, + start_time=start, + end_time=end, + to_local_path=download_dir, + verbose=False, + ) + failures = [r.error_info for r in results if not r] + assert not failures, f"download failures: {failures}" + downloaded_rf = { + p.name + for p in download_dir.rglob("*") + if p.is_file() and p.name.startswith("rf@") and p.name.endswith(".h5") + } + assert downloaded_rf == expected_names + + @pytest.mark.integration @pytest.mark.usefixtures("_integration_setup_teardown") @pytest.mark.usefixtures("_without_responses") diff --git a/sdk/tests/ops/test_files.py b/sdk/tests/ops/test_files.py index 3100c4e3..97b704eb 100644 --- a/sdk/tests/ops/test_files.py +++ b/sdk/tests/ops/test_files.py @@ -5,15 +5,22 @@ import sys import tempfile import uuid as uuidlib -from datetime import datetime, timezone +from datetime import datetime +from datetime import timedelta +from datetime import timezone from pathlib import Path from pathlib import PurePosixPath +from unittest.mock import patch import pytest import responses from loguru import logger as log from spectrumx import Client -from spectrumx.api.sds_files import delete_file, list_files +from spectrumx.api.sds_files import delete_file +from spectrumx.api.sds_files import list_files +from spectrumx.api.sds_files import ( + _file_list_time_query_param, # noqa: SLF001 +) from spectrumx.errors import FileError from spectrumx.gateway import API_TARGET_VERSION from spectrumx.ops.files import ( @@ -106,12 +113,79 @@ def test_get_file_by_id(client: Client, responses: responses.RequestsMock) -> No def test_list_files_start_end_must_be_paired(client: Client) -> None: """Omitting one of ``start_time`` / ``end_time`` raises before any request.""" t = datetime(2024, 6, 27, 14, 0, tzinfo=timezone.utc) - with pytest.raises(ValueError, match="both set or both omitted"): + with pytest.raises(ValueError, match="both be set or both omitted"): list_files(client=client, sds_path="/", start_time=t, end_time=None) - with pytest.raises(ValueError, match="both set or both omitted"): + with pytest.raises(ValueError, match="both be set or both omitted"): list_files(client=client, sds_path="/", start_time=None, end_time=t) +def test_file_list_time_query_param_naive_treated_as_utc() -> None: + """Naive datetimes are interpreted as UTC for query strings.""" + naive = datetime(2024, 6, 27, 14, 9, 0) + out = _file_list_time_query_param(naive) + assert out.endswith("+00:00") + assert "2024-06-27T14:09:00" in out + + +def test_file_list_time_query_param_converts_non_utc_to_utc() -> None: + """Aware datetimes are formatted in UTC (same instant).""" + eastern = timezone(timedelta(hours=-5)) + local = datetime(2024, 6, 27, 9, 9, 0, tzinfo=eastern) + out = _file_list_time_query_param(local) + assert "2024-06-27T14:09:00" in out + assert out.endswith("+00:00") + + +def test_list_files_passes_iso_temporal_params_to_gateway(client: Client) -> None: + """Temporal bounds are forwarded to ``gateway.list_files`` as ISO strings.""" + client.dry_run = False + start = datetime(2024, 6, 27, 14, 9, 
0, tzinfo=timezone.utc) + end = datetime(2024, 6, 27, 14, 10, 0, tzinfo=timezone.utc) + expected_start = _file_list_time_query_param(start) + expected_end = _file_list_time_query_param(end) + empty_page = b'{"count": 0, "results": []}' + + with patch.object(client._gateway, "list_files", return_value=empty_page) as m: + paginator = list_files( + client=client, + sds_path=PurePosixPath("/drf/dir"), + start_time=start, + end_time=end, + ) + list(paginator) + + m.assert_called() + for call in m.call_args_list: + kw = call.kwargs + assert kw["start_time"] == expected_start + assert kw["end_time"] == expected_end + assert kw["sds_path"] == PurePosixPath("/drf/dir") + + +def test_client_download_forwards_temporal_bounds_to_list_files( + client: Client, tmp_path: Path +) -> None: + """``Client.download(..., start_time=, end_time=)`` lists with the same bounds.""" + client.dry_run = False + start = datetime(2024, 6, 27, 14, 9, 0, tzinfo=timezone.utc) + end = datetime(2024, 6, 27, 14, 11, 0, tzinfo=timezone.utc) + sds = PurePosixPath("/capture/rf") + + with patch.object(client, "list_files", return_value=[]) as m_list: + client.download( + from_sds_path=sds, + start_time=start, + end_time=end, + to_local_path=tmp_path, + verbose=False, + ) + + m_list.assert_called_once() + assert m_list.call_args.kwargs["sds_path"] == sds + assert m_list.call_args.kwargs["start_time"] == start + assert m_list.call_args.kwargs["end_time"] == end + + def test_file_get_returns_valid( client: Client, ) -> None: diff --git a/sdk/tests/ops/test_paginator.py b/sdk/tests/ops/test_paginator.py index fc8cc09c..1248b08a 100644 --- a/sdk/tests/ops/test_paginator.py +++ b/sdk/tests/ops/test_paginator.py @@ -7,6 +7,7 @@ import logging import uuid from collections.abc import Generator +from pathlib import PurePosixPath from unittest.mock import MagicMock from unittest.mock import patch @@ -99,6 +100,57 @@ def test_paginator_dry_run_ingest_list_files(gateway: GatewayClient) -> None: ) +def test_paginator_preserves_temporal_kwargs_across_pages( + gateway: GatewayClient +) -> None: + """ + ``start_time`` / ``end_time`` in ``list_kwargs`` + are sent on every ``list_files`` call. 
+ """ + one = sx_files.generate_sample_file(uuid.uuid4()) + two = sx_files.generate_sample_file(uuid.uuid4()) + three = sx_files.generate_sample_file(uuid.uuid4()) + start_iso = "2024-06-27T14:09:00+00:00" + end_iso = "2024-06-27T14:11:00+00:00" + recorded: list[dict[str, object]] = [] + + def side_effect(**kwargs: object) -> bytes: + recorded.append(dict(kwargs)) + page = kwargs["page"] + if page == 1: + results = [ + json.loads(one.model_dump_json()), + json.loads(two.model_dump_json()), + ] + else: + results = [json.loads(three.model_dump_json())] + body = {"count": 3, "results": results} + return json.dumps(body).encode() + + gateway.list_files.side_effect = side_effect + + paginator = Paginator[File]( + Entry=File, + gateway=gateway, + list_method=gateway.list_files, + list_kwargs={ + "sds_path": PurePosixPath("/drf"), + "start_time": start_iso, + "end_time": end_iso, + }, + page_size=2, + dry_run=False, + ) + + consumed = list(paginator) + assert len(consumed) == 3 + assert len(recorded) == 2 + for call_kw in recorded: + assert call_kw["start_time"] == start_iso + assert call_kw["end_time"] == end_iso + assert call_kw["sds_path"] == PurePosixPath("/drf") + + def test_paginator_dry_run_ingest_get_dataset_files(gateway: GatewayClient) -> None: """Tests the dry-run mode of the paginator for get_dataset_files.""" page_size = 3 From 8e21bca0c43bb6646e352ef5fb7d0a4c19618f66 Mon Sep 17 00:00:00 2001 From: klpoland Date: Fri, 24 Apr 2026 16:51:40 -0400 Subject: [PATCH 07/10] linting --- gateway/pyproject.toml | 5 +++ .../api_methods/helpers/temporal_filtering.py | 16 +++----- .../serializers/capture_serializers.py | 8 ++-- .../serializers/dataset_serializers.py | 29 +++++++++----- .../test_composite_capture_serialization.py | 6 +-- .../api_methods/tests/test_file_endpoints.py | 12 ++---- .../api_methods/views/dataset_endpoints.py | 6 ++- .../api_methods/views/file_endpoints.py | 15 ++++--- sdk/docs/mkdocs/changelog.md | 4 +- sdk/pyproject.toml | 8 +++- sdk/src/spectrumx/api/datasets.py | 20 ++++++---- sdk/src/spectrumx/api/sds_files.py | 8 ++-- sdk/src/spectrumx/client.py | 5 ++- sdk/src/spectrumx/models/datasets.py | 2 + sdk/tests/integration/test_file_ops.py | 39 +++++++++---------- sdk/tests/ops/test_files.py | 23 ++++++----- sdk/tests/ops/test_paginator.py | 12 +++--- 17 files changed, 118 insertions(+), 100 deletions(-) diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index 42663a3e..a885bf39 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -630,6 +630,11 @@ # Controls PLR0913 max-args = 9 + [tool.ruff.lint.per-file-ignores] + "sds_gateway/api_methods/tests/test_composite_capture_serialization.py" = [ + "PLR2004", + ] + [tool.ruff.format] indent-style = "space" line-ending = "auto" diff --git a/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py b/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py index 50753fbc..4cf055ba 100644 --- a/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py +++ b/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py @@ -28,7 +28,7 @@ def _catch_value_errors(capture_type: CaptureType, capture: Capture) -> None: msg = "Only DigitalRF captures are supported for temporal filtering." 
log.error(msg) raise ValueError(msg) - + if capture.start_time is None: msg = f"Capture {capture.uuid} has no indexed start_time for temporal filtering" raise ValueError(msg) @@ -42,7 +42,7 @@ def get_capture_files_with_temporal_filter( ) -> QuerySet[File]: """Get the capture files with temporal filtering.""" _catch_value_errors(capture_type, capture) - + capture_files = get_capture_files(capture) if start_time is None or end_time is None: @@ -69,15 +69,11 @@ def filter_files_by_temporal_bounds( end_time: int, ) -> QuerySet[File]: """Filter files by temporal bounds.""" - + # get non-data files - non_data_files = files.exclude( - name__regex=DRF_RF_FILENAME_REGEX_STR - ) + non_data_files = files.exclude(name__regex=DRF_RF_FILENAME_REGEX_STR) - unfiltered_data_files = files.filter( - name__regex=DRF_RF_FILENAME_REGEX_STR - ) + unfiltered_data_files = files.filter(name__regex=DRF_RF_FILENAME_REGEX_STR) start_file_name = drf_rf_filename_from_ms(start_time) end_file_name = drf_rf_filename_from_ms(end_time) @@ -88,4 +84,4 @@ def filter_files_by_temporal_bounds( ).order_by("name") # return all files - return non_data_files.union(filtered_data_files) \ No newline at end of file + return non_data_files.union(filtered_data_files) diff --git a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py index cdceb358..feac4e2f 100644 --- a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py @@ -115,7 +115,7 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]): post_processed_data = serializers.SerializerMethodField() def get_datasets(self, capture: Capture) -> list[dict[str, Any]]: - """Datasets linked to this capture; shallow when serializing under dataset detail.""" + """Datasets linked to this capture; shallow under dataset detail.""" qs = get_capture_datasets(capture, include_deleted=False) if self.context.get("omit_nested_dataset_graph"): return DatasetSummarySerializer(qs, many=True, context=self.context).data @@ -509,7 +509,7 @@ def _enriched_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: entry["length_of_capture_ms"] = None entry["file_cadence_ms"] = None else: - # Per-channel bounds/cadence from ``Capture`` (``get_opensearch_metadata``). + # Per-channel bounds/cadence (Capture.get_opensearch_metadata). 
start_sec = capture.start_time end_sec = capture.end_time entry["capture_start_epoch_sec"] = start_sec @@ -547,7 +547,7 @@ def _composite_envelope_bounds( self, obj: dict[str, Any], ) -> tuple[int, int] | None: - """Earliest channel start and latest channel end (seconds), for composite summary.""" + """Earliest channel start and latest channel end (seconds).""" pairs = [ (row["capture_start_epoch_sec"], row["capture_end_epoch_sec"]) for row in self._enriched_channels(obj) @@ -678,7 +678,7 @@ def get_file_cadence_ms(self, obj: dict[str, Any]) -> int | None: ] if not cadences: return None - return int(round(sum(cadences) / len(cadences))) + return round(sum(cadences) / len(cadences)) @extend_schema_field(serializers.IntegerField(allow_null=True)) def get_capture_start_epoch_sec(self, obj: dict[str, Any]) -> int | None: diff --git a/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py b/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py index 4b22f4fa..744dbc93 100644 --- a/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py @@ -2,17 +2,26 @@ from rest_framework import serializers +from sds_gateway.api_methods.helpers.search_captures import ( + group_captures_by_top_level_dir, +) from sds_gateway.api_methods.models import Dataset from sds_gateway.api_methods.models import ItemType from sds_gateway.api_methods.models import PermissionLevel from sds_gateway.api_methods.models import UserSharePermission -from sds_gateway.api_methods.utils.asset_access_control import check_if_shared -from sds_gateway.api_methods.helpers.search_captures import group_captures_by_top_level_dir -from sds_gateway.api_methods.serializers.capture_serializers import build_composite_capture_data -from sds_gateway.api_methods.serializers.capture_serializers import serialize_capture_or_composite -from sds_gateway.api_methods.serializers.file_serializers import FileArtifactSummarySerializer +from sds_gateway.api_methods.serializers.capture_serializers import ( + build_composite_capture_data, +) +from sds_gateway.api_methods.serializers.capture_serializers import ( + serialize_capture_or_composite, +) +from sds_gateway.api_methods.serializers.file_serializers import ( + FileArtifactSummarySerializer, +) from sds_gateway.api_methods.serializers.user_serializer import UserGetSerializer -from sds_gateway.api_methods.serializers.user_serializer import UserSharePermissionSerializer +from sds_gateway.api_methods.serializers.user_serializer import ( + UserSharePermissionSerializer, +) from sds_gateway.api_methods.utils.asset_access_control import check_if_shared from sds_gateway.api_methods.utils.relationship_utils import get_dataset_artifact_files from sds_gateway.api_methods.utils.relationship_utils import get_dataset_captures @@ -21,7 +30,7 @@ class DatasetSummarySerializer(serializers.ModelSerializer[Dataset]): - """Minimal dataset shape for capture ``datasets`` when breaking serializer cycles.""" + """Minimal dataset shape for capture ``datasets`` (breaks serializer cycles).""" class Meta: model = Dataset @@ -147,7 +156,7 @@ def get_share_permissions(self, obj): is_enabled=True, ) return UserSharePermissionSerializer(user_share_permissions, many=True).data - + def get_files(self, obj: Dataset) -> list[dict]: """Get the files for the dataset. 
@@ -164,9 +173,9 @@ def get_files(self, obj: Dataset) -> list[dict]: context=self.context, ) return serializer.data - + def get_captures(self, obj: Dataset) -> list[dict]: - """Get captures for the dataset, one entry per logical capture (list API semantics). + """Captures for the dataset (one row per logical capture, list API semantics). Multi-channel uploads share ``top_level_dir``; those rows are merged into a single composite payload like :func:`get_composite_captures`. diff --git a/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py index b11920c6..adacaa80 100644 --- a/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py +++ b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py @@ -57,7 +57,7 @@ class CompositeCaptureSerializationTests(TestCase): """Serializer-level tests with OpenSearch and index helpers mocked.""" def test_distinct_channel_metadata_preserved(self) -> None: - """Per-channel ``retrieve_indexed_metadata`` payloads stay on each channel row.""" + """Per-channel indexed metadata payloads stay on each channel row.""" cap0, cap1 = _two_drf_captures_same_group() def fake_retrieve(capture: Capture) -> dict: @@ -91,7 +91,7 @@ def fake_retrieve(capture: Capture) -> dict: ) def test_distinct_opensearch_times_cadence_and_envelope(self) -> None: - """Each channel keeps its own bounds/cadence; top-level fields use min start / max end.""" + """Per-channel bounds/cadence; top-level uses min start and max end.""" cap0, cap1 = _two_drf_captures_same_group() meta_by_uuid = { @@ -142,7 +142,7 @@ def opensearch_by_instance(self: Capture) -> dict: assert out["file_cadence_ms"] == 600 def test_channel_with_incomplete_bounds_excluded_from_envelope(self) -> None: - """Channels missing start or end do not contribute to composite envelope; length null.""" + """Incomplete channel bounds are excluded from the composite envelope.""" cap0, cap1 = _two_drf_captures_same_group() def opensearch_by_instance(self: Capture) -> dict: diff --git a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py index 4f31c031..97b0f55f 100644 --- a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py +++ b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py @@ -3,8 +3,8 @@ import time import uuid from collections.abc import Mapping +from datetime import UTC from datetime import datetime -from datetime import timezone from pathlib import Path from typing import TYPE_CHECKING from typing import Any @@ -446,7 +446,7 @@ def test_download_file_no_access_403(self): assert response.status_code == status.HTTP_403_FORBIDDEN def test_list_files_with_temporal_params(self) -> None: - """Temporal query params keep non-RF files and only RF data files in time bounds.""" + """Temporal params keep non-RF files; RF listings respect time bounds.""" base_sec = 1_000_000 for offset in (0, 1, 2, 5): create_db_file( @@ -460,12 +460,8 @@ def test_list_files_with_temporal_params(self) -> None: path = str(self.file.directory) # Absolute epoch ms from ISO datetimes; bounds include rf@(base+1)..rf@(base+2). 
- start_iso = datetime.fromtimestamp( - base_sec + 1, tz=timezone.utc - ).isoformat() - end_iso = datetime.fromtimestamp( - base_sec + 2, tz=timezone.utc - ).isoformat() + start_iso = datetime.fromtimestamp(base_sec + 1, tz=UTC).isoformat() + end_iso = datetime.fromtimestamp(base_sec + 2, tz=UTC).isoformat() response = self.client.get( self.list_url, { diff --git a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py index 312f95d9..e01b8873 100644 --- a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py @@ -49,13 +49,15 @@ def _get_file_objects(self, dataset: Dataset) -> QuerySet[File]: ), ], responses={ - 200: OpenApiResponse(description="Dataset metadata, captures, and artifact files"), + 200: OpenApiResponse( + description=("Dataset metadata, captures, and direct artifact files"), + ), 403: OpenApiResponse(description="Forbidden"), 404: OpenApiResponse(description="Not Found"), }, description=( "Return dataset metadata with captures (one row per logical capture, " - "including composite multi-channel) and artifact files linked directly to the dataset." + "including composite multi-channel) and artifact files on the dataset." ), summary="Retrieve Dataset", ) diff --git a/gateway/sds_gateway/api_methods/views/file_endpoints.py b/gateway/sds_gateway/api_methods/views/file_endpoints.py index 958d5a38..b0e752d8 100644 --- a/gateway/sds_gateway/api_methods/views/file_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/file_endpoints.py @@ -7,9 +7,9 @@ from typing import cast from django.db.models import CharField -from django.db.models import QuerySet from django.db.models import F as FExpression from django.db.models import ProtectedError +from django.db.models import QuerySet from django.db.models import Value as WrappedValue from django.db.models.functions import Concat from django.http import HttpResponse @@ -32,9 +32,11 @@ import sds_gateway.api_methods.utils.swagger_example_schema as example_schema from sds_gateway.api_methods.authentication import APIKeyAuthentication from sds_gateway.api_methods.helpers.download_file import download_file -from sds_gateway.api_methods.helpers.temporal_filtering import filter_files_by_temporal_bounds -from sds_gateway.api_methods.models import File +from sds_gateway.api_methods.helpers.temporal_filtering import ( + filter_files_by_temporal_bounds, +) from sds_gateway.api_methods.models import DRF_RF_FILENAME_REGEX_STR +from sds_gateway.api_methods.models import File from sds_gateway.api_methods.serializers.file_serializers import ( FileCheckResponseSerializer, ) @@ -211,7 +213,6 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: serializer = FileGetSerializer(target_file, many=False) return Response(serializer.data) - def _datetime_string_to_milliseconds(self, datetime_string: str) -> int: """Converts a datetime string to milliseconds since start of capture.""" parsed = datetime.fromisoformat(datetime_string) @@ -219,9 +220,7 @@ def _datetime_string_to_milliseconds(self, datetime_string: str) -> int: def _check_files_includes_rf_data(self, files: QuerySet[File]) -> bool: """Checks if the files include RF data.""" - return files.filter( - name__regex=DRF_RF_FILENAME_REGEX_STR - ).exists() + return files.filter(name__regex=DRF_RF_FILENAME_REGEX_STR).exists() @extend_schema( responses={ @@ -285,7 +284,7 @@ def _check_files_includes_rf_data(self, files: QuerySet[File]) -> bool: ), ], ) - 
def list(self, request: Request) -> Response: + def list(self, request: Request) -> Response: # noqa: C901 """ Lists all files accessible to the user (owned + shared via captures/datasets). When `path` is passed, it filters all files matching that subdirectory. diff --git a/sdk/docs/mkdocs/changelog.md b/sdk/docs/mkdocs/changelog.md index 01f61341..2e502e4c 100644 --- a/sdk/docs/mkdocs/changelog.md +++ b/sdk/docs/mkdocs/changelog.md @@ -6,7 +6,7 @@ + [**Added `delete` and `revoke_share_permissions` methods for datasets**](https://github.com/spectrumx/sds-code/pull/275): this allows users to (soft) delete datasets in the SDS through the SDK and revoke ALL share permissions from datasets if needed before deletion or in general. + [**Added `revoke_share_permissions` and `detach_from_datasets` methods to captures**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to revoke share permissions or detach captures from connected datasets when they need to delete a capture. + [**Added `detach_from_datasets` methods to files**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to detach files from connected datasets when they need to delete them. Note: Files CANNOT be detached from captures. Delete a the parent capture FIRST to delete the file. - + [**Added `start_time` and `end_time` parameters to `list_files` and `download`**](): this gives users the ability to filter file directory downloads associated with DigitalRF captures based on a time span within the capture bounds (similar to time filtering in the web UI on SDS) + + [**Added `start_time` and `end_time` parameters to `list_files` and `download`**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to filter file directory downloads associated with DigitalRF captures based on a time span within the capture bounds (similar to time filtering in the web UI on SDS) + Observability: + [**Added additional fields displaying ownership, share permission, and asset connection information to SDK models**](https://github.com/spectrumx/sds-code/pull/275): this allows users to see more relevant information when retrieving or listing assets like whether they are shared, who they are shared with, who owns them, and what other assets the target is attached to (like files to captures and datasets). @@ -14,7 +14,7 @@ + New capture attributes: `owner`, `is_shared`, `is_shared_with_me`, `share_permissions` + New dataset attributes: `owner`, `is_shared`, `share_permissions`, `datasets` + [**Added new models for User and UserSharePermission**](https://github.com/spectrumx/sds-code/pull/275): This allows for visibility into the users and share permissions connected to assets. - + [**Added `captures` and `files` attributes to `Dataset` model**](): This allows visibility from the dataset side into attached captures and files. + + [**Added `captures` and `files` attributes to `Dataset` model**](https://github.com/spectrumx/sds-code/pull/275): This allows visibility from the dataset side into attached captures and files. ## `0.1.17` - 2025-12-20 diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index a1c60c8b..cc0b9db9 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -285,9 +285,13 @@ max-args = 9 [tool.ruff.lint.per-file-ignores] - "*.ipynb" = ["E501", "ERA001", "PLR2004", "T201"] + "*.ipynb" = ["E501", "ERA001", "PLR2004", "T201"] + # Monorepo: pre-commit passes `gateway/...` from repo root; local runs may use `../gateway/...` from sdk/. 
+ "**/test_composite_capture_serialization.py" = ["PLR2004"] + "../gateway/sds_gateway/api_methods/utils/metadata_schemas.py" = ["PLC0415"] "gateway/sds_gateway/api_methods/utils/metadata_schemas.py" = ["PLC0415"] - "gateway/sds_gateway/users/views.py" = ["PLC0415"] + "../gateway/sds_gateway/users/views.py" = ["PLC0415"] + "gateway/sds_gateway/users/views.py" = ["PLC0415"] [tool.ruff.format] indent-style = "space" diff --git a/sdk/src/spectrumx/api/datasets.py b/sdk/src/spectrumx/api/datasets.py index ee8ef8ae..46cb9ed8 100644 --- a/sdk/src/spectrumx/api/datasets.py +++ b/sdk/src/spectrumx/api/datasets.py @@ -3,17 +3,21 @@ from __future__ import annotations import json -import uuid +from typing import TYPE_CHECKING from typing import Any from loguru import logger as log -from spectrumx.gateway import GatewayClient from spectrumx.models.datasets import Dataset from spectrumx.models.files import File from spectrumx.ops.pagination import Paginator from spectrumx.utils import log_user +if TYPE_CHECKING: + from uuid import UUID + + from spectrumx.gateway import GatewayClient + class DatasetAPI: gateway: GatewayClient @@ -31,7 +35,7 @@ def __init__( self.gateway = gateway self.verbose = verbose - def get(self, dataset_uuid: uuid.UUID) -> Dataset: + def get(self, dataset_uuid: UUID) -> Dataset: """Load dataset metadata, captures, and artifact files from SDS. Captures are returned in the same grouped shape as the capture list API @@ -49,7 +53,7 @@ def get(self, dataset_uuid: uuid.UUID) -> Dataset: ) return Dataset.model_validate_json(raw) - def list_captures(self, dataset_uuid: uuid.UUID) -> list[dict[str, Any]]: + def list_captures(self, dataset_uuid: UUID) -> list[dict[str, Any]]: """Return capture payloads linked to the dataset (raw JSON objects). Use this when you need composite capture fields (for example ``channels``) @@ -67,7 +71,7 @@ def list_captures(self, dataset_uuid: uuid.UUID) -> list[dict[str, Any]]: captures = data.get("captures") return list(captures) if isinstance(captures, list) else [] - def list_artifact_files(self, dataset_uuid: uuid.UUID) -> list[dict[str, Any]]: + def list_artifact_files(self, dataset_uuid: UUID) -> list[dict[str, Any]]: """Return file rows linked directly to the dataset (artifacts), as JSON dicts. These are the same objects embedded on :meth:`get` under the ``files`` key. @@ -88,7 +92,7 @@ def list_artifact_files(self, dataset_uuid: uuid.UUID) -> list[dict[str, Any]]: def get_files( self, - dataset_uuid: uuid.UUID, + dataset_uuid: UUID, ) -> Paginator[File]: """Get files in the dataset as a paginator. @@ -114,7 +118,7 @@ def get_files( def delete( self, - dataset_uuid: uuid.UUID, + dataset_uuid: UUID, ) -> bool: """Deletes a dataset from SDS by its UUID. @@ -137,7 +141,7 @@ def delete( log.debug(f"Dataset deleted with UUID {dataset_uuid}") return True - def revoke_share_permissions(self, dataset_uuid: uuid.UUID) -> bool: + def revoke_share_permissions(self, dataset_uuid: UUID) -> bool: """Revoke all direct share permissions on this dataset (owner-only). Use this (or the web portal) before :meth:`delete` when the dataset is shared. 
diff --git a/sdk/src/spectrumx/api/sds_files.py b/sdk/src/spectrumx/api/sds_files.py index ef6ce2ba..c4afef21 100644 --- a/sdk/src/spectrumx/api/sds_files.py +++ b/sdk/src/spectrumx/api/sds_files.py @@ -6,8 +6,8 @@ import os import tempfile import uuid +from datetime import UTC from datetime import datetime -from datetime import timezone from enum import Enum from enum import auto from multiprocessing.synchronize import RLock @@ -31,10 +31,7 @@ def _file_list_time_query_param(value: datetime) -> str: """Format a datetime for Gateway file list temporal query params (ISO 8601, UTC).""" - if value.tzinfo is None: - value = value.replace(tzinfo=timezone.utc) - else: - value = value.astimezone(timezone.utc) + value = value.replace(tzinfo=UTC) if value.tzinfo is None else value.astimezone(UTC) return value.isoformat() @@ -161,6 +158,7 @@ def list_files( return pagination + def upload_file( *, client: Client, diff --git a/sdk/src/spectrumx/client.py b/sdk/src/spectrumx/client.py index 76428b4c..d8c410a9 100644 --- a/sdk/src/spectrumx/client.py +++ b/sdk/src/spectrumx/client.py @@ -416,6 +416,7 @@ def list_files( sds_path: PurePosixPath | Path | str, start_time: datetime | None = None, end_time: datetime | None = None, + *, verbose: bool = False, ) -> Paginator[File]: """Lists files in a given SDS path. @@ -533,7 +534,7 @@ def get_dataset(self, dataset_uuid: UUID4 | str) -> Dataset: return self.datasets.get(dataset_uuid) def list_dataset_captures(self, dataset_uuid: UUID4 | str) -> list[dict[str, Any]]: - """List captures linked to a dataset (raw dicts; supports composite payloads).""" + """List captures for a dataset (raw dicts; composite payloads supported).""" if isinstance(dataset_uuid, str): dataset_uuid = UUID(dataset_uuid) return self.datasets.list_captures(dataset_uuid) @@ -541,7 +542,7 @@ def list_dataset_captures(self, dataset_uuid: UUID4 | str) -> list[dict[str, Any def list_dataset_artifact_files( self, dataset_uuid: UUID4 | str ) -> list[dict[str, Any]]: - """List files linked directly to the dataset (not the full download manifest).""" + """List dataset artifact files (not the full download manifest).""" if isinstance(dataset_uuid, str): dataset_uuid = UUID(dataset_uuid) return self.datasets.list_artifact_files(dataset_uuid) diff --git a/sdk/src/spectrumx/models/datasets.py b/sdk/src/spectrumx/models/datasets.py index 5b286131..61e913ca 100644 --- a/sdk/src/spectrumx/models/datasets.py +++ b/sdk/src/spectrumx/models/datasets.py @@ -21,6 +21,7 @@ class DatasetFile(BaseModel): directory: str | None = None media_type: str | None = None + class DatasetCapture(BaseModel): model_config = ConfigDict(extra="ignore") @@ -32,6 +33,7 @@ class DatasetCapture(BaseModel): top_level_dir: str | None = None owner: User | None = None + class Dataset(BaseModel): """A dataset in SDS.""" diff --git a/sdk/tests/integration/test_file_ops.py b/sdk/tests/integration/test_file_ops.py index 802218b8..69284024 100644 --- a/sdk/tests/integration/test_file_ops.py +++ b/sdk/tests/integration/test_file_ops.py @@ -3,8 +3,8 @@ import logging import time import uuid +from datetime import UTC from datetime import datetime -from datetime import timezone from pathlib import Path from pathlib import PurePosixPath from unittest.mock import patch @@ -25,6 +25,7 @@ from tests.test_utils import disable_ssl_warnings BLAKE3_HEX_LEN: int = 64 +DRF_SAMPLE_RF_CHUNK_COUNT: int = 16 def test_is_valid_file_allowed(temp_file_with_text_contents) -> None: @@ -732,18 +733,16 @@ def test_list_files_temporal_rf_narrows_digital_rf_chunks( 
integration_client: Client, drf_sample_top_level_dir: Path, ) -> None: - """Temporal bounds narrow ``rf@*.h5`` listings; non-RF names in the dir stay listed.""" + """Temporal bounds narrow ``rf@*.h5``; other filenames in the dir still list.""" if not drf_sample_top_level_dir.is_dir(): - pytest.skip( - "Digital RF sample tree missing; cannot test temporal RF listing." - ) + pytest.skip("Digital RF sample tree missing; cannot test temporal RF listing.") cap_data = _upload_drf_capture_test_assets( integration_client=integration_client, drf_sample_top_level_dir=drf_sample_top_level_dir, ) rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" - start = datetime.fromtimestamp(1719499740, tz=timezone.utc) - end = datetime.fromtimestamp(1719499740, tz=timezone.utc) + start = datetime.fromtimestamp(1719499740, tz=UTC) + end = datetime.fromtimestamp(1719499740, tz=UTC) all_rf = { f.name for f in integration_client.list_files(sds_path=rf_dir) @@ -758,7 +757,9 @@ def test_list_files_temporal_rf_narrows_digital_rf_chunks( ) if f.name.startswith("rf@") } - assert len(all_rf) == 16, f"Expected 16 RF chunks under {rf_dir}, got {all_rf!r}" + assert len(all_rf) == DRF_SAMPLE_RF_CHUNK_COUNT, ( + f"Expected {DRF_SAMPLE_RF_CHUNK_COUNT} RF chunks under {rf_dir}, got {all_rf!r}" + ) assert narrow_rf == {"rf@1719499740.000.h5"} @@ -779,18 +780,16 @@ def test_list_files_temporal_rf_inclusive_range_multiple_chunks( integration_client: Client, drf_sample_top_level_dir: Path, ) -> None: - """Temporal window spanning several seconds includes each ``rf@*.h5`` in that span.""" + """A multi-second temporal window lists each ``rf@*.h5`` in that span.""" if not drf_sample_top_level_dir.is_dir(): - pytest.skip( - "Digital RF sample tree missing; cannot test temporal RF listing." - ) + pytest.skip("Digital RF sample tree missing; cannot test temporal RF listing.") cap_data = _upload_drf_capture_test_assets( integration_client=integration_client, drf_sample_top_level_dir=drf_sample_top_level_dir, ) rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" - start = datetime.fromtimestamp(1719499740, tz=timezone.utc) - end = datetime.fromtimestamp(1719499742, tz=timezone.utc) + start = datetime.fromtimestamp(1719499740, tz=UTC) + end = datetime.fromtimestamp(1719499742, tz=UTC) expected = { "rf@1719499740.000.h5", "rf@1719499741.000.h5", @@ -829,16 +828,14 @@ def test_download_respects_temporal_rf_window( ) -> None: """Bulk download with start/end only fetches ``rf@*.h5`` in the UTC window.""" if not drf_sample_top_level_dir.is_dir(): - pytest.skip( - "Digital RF sample tree missing; cannot test temporal RF download." 
- ) + pytest.skip("Digital RF sample tree missing; cannot test temporal RF download.") cap_data = _upload_drf_capture_test_assets( integration_client=integration_client, drf_sample_top_level_dir=drf_sample_top_level_dir, ) rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" - start = datetime.fromtimestamp(1719499740, tz=timezone.utc) - end = datetime.fromtimestamp(1719499741, tz=timezone.utc) + start = datetime.fromtimestamp(1719499740, tz=UTC) + end = datetime.fromtimestamp(1719499741, tz=UTC) expected_names = {"rf@1719499740.000.h5", "rf@1719499741.000.h5"} download_dir = tmp_path / "drf_partial" @@ -888,8 +885,8 @@ def test_list_files_temporal_non_rf_directory_warns( ) failures = [result for result in results if not result] assert not failures, f"Upload failed: {failures}" - start = datetime(2020, 1, 1, tzinfo=timezone.utc) - end = datetime(2020, 1, 2, tzinfo=timezone.utc) + start = datetime(2020, 1, 1, tzinfo=UTC) + end = datetime(2020, 1, 2, tzinfo=UTC) listed = list( integration_client.list_files( sds_path=sds_path, diff --git a/sdk/tests/ops/test_files.py b/sdk/tests/ops/test_files.py index 97b704eb..4e786288 100644 --- a/sdk/tests/ops/test_files.py +++ b/sdk/tests/ops/test_files.py @@ -5,6 +5,7 @@ import sys import tempfile import uuid as uuidlib +from datetime import UTC from datetime import datetime from datetime import timedelta from datetime import timezone @@ -16,11 +17,9 @@ import responses from loguru import logger as log from spectrumx import Client +from spectrumx.api.sds_files import _file_list_time_query_param from spectrumx.api.sds_files import delete_file from spectrumx.api.sds_files import list_files -from spectrumx.api.sds_files import ( - _file_list_time_query_param, # noqa: SLF001 -) from spectrumx.errors import FileError from spectrumx.gateway import API_TARGET_VERSION from spectrumx.ops.files import ( @@ -112,7 +111,7 @@ def test_get_file_by_id(client: Client, responses: responses.RequestsMock) -> No def test_list_files_start_end_must_be_paired(client: Client) -> None: """Omitting one of ``start_time`` / ``end_time`` raises before any request.""" - t = datetime(2024, 6, 27, 14, 0, tzinfo=timezone.utc) + t = datetime(2024, 6, 27, 14, 0, tzinfo=UTC) with pytest.raises(ValueError, match="both be set or both omitted"): list_files(client=client, sds_path="/", start_time=t, end_time=None) with pytest.raises(ValueError, match="both be set or both omitted"): @@ -121,7 +120,7 @@ def test_list_files_start_end_must_be_paired(client: Client) -> None: def test_file_list_time_query_param_naive_treated_as_utc() -> None: """Naive datetimes are interpreted as UTC for query strings.""" - naive = datetime(2024, 6, 27, 14, 9, 0) + naive = datetime(2024, 6, 27, 14, 9, 0) # noqa: DTZ001 out = _file_list_time_query_param(naive) assert out.endswith("+00:00") assert "2024-06-27T14:09:00" in out @@ -139,13 +138,17 @@ def test_file_list_time_query_param_converts_non_utc_to_utc() -> None: def test_list_files_passes_iso_temporal_params_to_gateway(client: Client) -> None: """Temporal bounds are forwarded to ``gateway.list_files`` as ISO strings.""" client.dry_run = False - start = datetime(2024, 6, 27, 14, 9, 0, tzinfo=timezone.utc) - end = datetime(2024, 6, 27, 14, 10, 0, tzinfo=timezone.utc) + start = datetime(2024, 6, 27, 14, 9, 0, tzinfo=UTC) + end = datetime(2024, 6, 27, 14, 10, 0, tzinfo=UTC) expected_start = _file_list_time_query_param(start) expected_end = _file_list_time_query_param(end) empty_page = b'{"count": 0, "results": []}' - with 
patch.object(client._gateway, "list_files", return_value=empty_page) as m: + with patch.object( + client._gateway, # noqa: SLF001 + "list_files", + return_value=empty_page, + ) as m: paginator = list_files( client=client, sds_path=PurePosixPath("/drf/dir"), @@ -167,8 +170,8 @@ def test_client_download_forwards_temporal_bounds_to_list_files( ) -> None: """``Client.download(..., start_time=, end_time=)`` lists with the same bounds.""" client.dry_run = False - start = datetime(2024, 6, 27, 14, 9, 0, tzinfo=timezone.utc) - end = datetime(2024, 6, 27, 14, 11, 0, tzinfo=timezone.utc) + start = datetime(2024, 6, 27, 14, 9, 0, tzinfo=UTC) + end = datetime(2024, 6, 27, 14, 11, 0, tzinfo=UTC) sds = PurePosixPath("/capture/rf") with patch.object(client, "list_files", return_value=[]) as m_list: diff --git a/sdk/tests/ops/test_paginator.py b/sdk/tests/ops/test_paginator.py index 1248b08a..56865f6f 100644 --- a/sdk/tests/ops/test_paginator.py +++ b/sdk/tests/ops/test_paginator.py @@ -101,10 +101,10 @@ def test_paginator_dry_run_ingest_list_files(gateway: GatewayClient) -> None: def test_paginator_preserves_temporal_kwargs_across_pages( - gateway: GatewayClient + gateway: GatewayClient, ) -> None: """ - ``start_time`` / ``end_time`` in ``list_kwargs`` + ``start_time`` / ``end_time`` in ``list_kwargs`` are sent on every ``list_files`` call. """ one = sx_files.generate_sample_file(uuid.uuid4()) @@ -143,8 +143,10 @@ def side_effect(**kwargs: object) -> bytes: ) consumed = list(paginator) - assert len(consumed) == 3 - assert len(recorded) == 2 + expected_items = 3 + expected_pages = 2 + assert len(consumed) == expected_items + assert len(recorded) == expected_pages for call_kw in recorded: assert call_kw["start_time"] == start_iso assert call_kw["end_time"] == end_iso @@ -379,7 +381,7 @@ def test_paginator_logs_api_warnings_on_first_page( caplog: pytest.LogCaptureFixture, gateway: GatewayClient, ) -> None: - """Server ``warnings`` on the first JSON page are forwarded via ``log_user_warning``.""" + """First-page ``warnings`` in JSON are forwarded via ``log_user_warning``.""" caplog.set_level(logging.WARNING) warn_msg = "Temporal filtering is only supported for RF dirs" one = sx_files.generate_sample_file(uuid.uuid4()) From f5626f4bd0d354927e75628bf7f5c2aa6c863f1f Mon Sep 17 00:00:00 2001 From: klpoland Date: Thu, 30 Apr 2026 13:59:26 -0400 Subject: [PATCH 08/10] Add new parameters for selecting dataset files by capture uuid or top_level_dir lists --- .../serializers/file_serializers.py | 28 +++++ gateway/sds_gateway/api_methods/tasks.py | 2 +- .../tests/test_dataset_endpoints.py | 103 +++++++++++++++++ .../tests/test_dataset_manifest_filters.py | 37 ++++++ .../test_file_get_serializer_captures.py | 95 ++++++++++++++++ .../utils/dataset_manifest_filters.py | 68 ++++++++++++ .../api_methods/views/dataset_endpoints.py | 52 ++++++++- .../api_methods/views/file_endpoints.py | 4 +- gateway/sds_gateway/users/views/downloads.py | 2 +- sdk/src/spectrumx/api/datasets.py | 19 +++- sdk/src/spectrumx/api/sds_files.py | 2 +- sdk/src/spectrumx/client.py | 105 ++++++++++++++++-- sdk/src/spectrumx/gateway.py | 20 +++- sdk/tests/models/test_files.py | 64 +++++++++++ sdk/tests/test_client.py | 23 ++++ 15 files changed, 604 insertions(+), 20 deletions(-) create mode 100644 gateway/sds_gateway/api_methods/tests/test_dataset_manifest_filters.py create mode 100644 gateway/sds_gateway/api_methods/tests/test_file_get_serializer_captures.py create mode 100644 gateway/sds_gateway/api_methods/utils/dataset_manifest_filters.py diff 
--git a/gateway/sds_gateway/api_methods/serializers/file_serializers.py b/gateway/sds_gateway/api_methods/serializers/file_serializers.py index 51457961..ccb2acf0 100644 --- a/gateway/sds_gateway/api_methods/serializers/file_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/file_serializers.py @@ -7,6 +7,7 @@ from loguru import logger as log from rest_framework import serializers +from sds_gateway.api_methods.models import Capture from sds_gateway.api_methods.models import File from sds_gateway.api_methods.serializers.capture_serializers import CaptureGetSerializer from sds_gateway.api_methods.serializers.dataset_serializers import DatasetGetSerializer @@ -51,6 +52,33 @@ class Meta: model = File fields = "__all__" + def to_representation(self, instance: File) -> dict[str, Any]: + """Build ``captures`` from M2M plus legacy ``capture`` FK when needed.""" + data = super().to_representation(instance) + merged: list[Capture] = [] + seen_pks: set[uuid.UUID] = set() + for cap in instance.captures.all().order_by("uuid"): + pk = cap.pk + if pk not in seen_pks: + merged.append(cap) + seen_pks.add(pk) + legacy = instance.capture + if legacy is not None and legacy.pk not in seen_pks: + merged.append(legacy) + seen_pks.add(legacy.pk) + context = self.context + data["captures"] = CaptureGetSerializer( + merged, + many=True, + context=context, + ).data + data["capture"] = ( + CaptureGetSerializer(merged[0], context=context).data + if merged + else None + ) + return data + class FilePostSerializer(serializers.ModelSerializer[File]): class Meta: diff --git a/gateway/sds_gateway/api_methods/tasks.py b/gateway/sds_gateway/api_methods/tasks.py index 440b7dfb..1d77fee9 100644 --- a/gateway/sds_gateway/api_methods/tasks.py +++ b/gateway/sds_gateway/api_methods/tasks.py @@ -1316,7 +1316,7 @@ def _get_item_files( end_time=end_time, ) else: - if start_time is not None or end_time is not None: + if start_time or end_time: log.warning( "Temporal filtering is only supported for DigitalRF captures, " "ignoring start_time and end_time" diff --git a/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py b/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py index cc83b1cc..2c1f8201 100644 --- a/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py +++ b/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py @@ -170,6 +170,7 @@ def test_get_dataset_files_success(self): assert "size" in file_info assert "media_type" in file_info assert file_info["capture"] is None + assert file_info["captures"] == [] def test_get_dataset_files_with_owned_captures(self): """Test dataset files manifest including files from owned captures.""" @@ -228,6 +229,12 @@ def test_get_dataset_files_with_owned_captures(self): for file_info in capture_files: assert file_info["capture"]["uuid"] == str(capture.uuid) assert file_info["capture"]["name"] == capture.name + assert file_info["captures"] + assert any(c["uuid"] == str(capture.uuid) for c in file_info["captures"]) + + artifact_only = [f for f in results if f["capture"] is None] + assert len(artifact_only) == 1 + assert artifact_only[0]["captures"] == [] def test_get_dataset_files_with_shared_captures(self): """Test dataset files manifest including files from shared captures.""" @@ -302,6 +309,12 @@ def test_get_dataset_files_with_shared_captures(self): for file_info in capture_files: assert file_info["capture"]["uuid"] == str(capture.uuid) assert file_info["capture"]["name"] == capture.name + assert file_info["captures"] + assert any(c["uuid"] 
== str(capture.uuid) for c in file_info["captures"]) + + artifact_only = [f for f in results if f["capture"] is None] + assert len(artifact_only) == 1 + assert artifact_only[0]["captures"] == [] def test_get_dataset_files_with_both_owned_and_shared_captures(self): """Test dataset files manifest with both owned and shared captures.""" @@ -402,6 +415,10 @@ def test_get_dataset_files_with_both_owned_and_shared_captures(self): for file_info in owned_capture_files: assert file_info["capture"]["uuid"] == str(owned_capture.uuid) assert file_info["capture"]["name"] == owned_capture.name + assert file_info["captures"] + assert any( + c["uuid"] == str(owned_capture.uuid) for c in file_info["captures"] + ) # Verify shared capture file info structure shared_capture_files = [ @@ -415,6 +432,92 @@ def test_get_dataset_files_with_both_owned_and_shared_captures(self): for file_info in shared_capture_files: assert file_info["capture"]["uuid"] == str(shared_capture.uuid) assert file_info["capture"]["name"] == shared_capture.name + assert file_info["captures"] + assert any( + c["uuid"] == str(shared_capture.uuid) for c in file_info["captures"] + ) + + artifact_only = [f for f in results if f["capture"] is None] + assert len(artifact_only) == 1 + assert artifact_only[0]["captures"] == [] + + def test_get_dataset_files_filter_by_capture_query(self): + """``capture`` query param restricts the manifest server-side.""" + capture = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch", + index_name="ix", + name="cap", + ) + self.created_captures.append(capture) + with MockMinIOContext(b"c"): + f_cap_a = create_file_with_minio_mock( + file_content=b"c", owner=self.user, capture=capture + ) + f_cap_b = create_file_with_minio_mock( + file_content=b"c", owner=self.user, capture=capture + ) + f_art = create_file_with_minio_mock( + file_content=b"c", owner=self.user, dataset=self.dataset + ) + self.created_files.extend([f_cap_a, f_cap_b, f_art]) + base_url = reverse("api:datasets-files", kwargs={"pk": self.dataset.uuid}) + response = self.client.get( + base_url, + {"capture": str(capture.uuid)}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["count"] == self.EXPECTED_CAPTURE_FILES + for row in data["results"]: + assert any(c["uuid"] == str(capture.uuid) for c in row["captures"]) + + def test_get_dataset_files_filter_by_capture_invalid_uuid(self): + """Invalid ``capture`` UUID returns 400.""" + with MockMinIOContext(b"x"): + f = create_file_with_minio_mock( + file_content=b"x", owner=self.user, dataset=self.dataset + ) + self.created_files.append(f) + base_url = reverse("api:datasets-files", kwargs={"pk": self.dataset.uuid}) + response = self.client.get(base_url, {"capture": "not-a-uuid"}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_get_dataset_files_filter_by_top_level_dir(self): + """``top_level_dir`` filters by ``File.directory`` prefix.""" + tld = "/pytest/capture/root" + capture = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch", + index_name="ix", + name="cap", + top_level_dir=tld, + ) + self.created_captures.append(capture) + with MockMinIOContext(b"c"): + f_match = create_file_with_minio_mock( + file_content=b"c", + owner=self.user, + capture=capture, + directory=f"{tld}/channel0/", + ) + f_other = create_file_with_minio_mock( + file_content=b"c", owner=self.user, dataset=self.dataset + ) + 
self.created_files.extend([f_match, f_other]) + base_url = reverse("api:datasets-files", kwargs={"pk": self.dataset.uuid}) + response = self.client.get( + base_url, + {"top_level_dir": tld}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["count"] == 1 + assert data["results"][0]["uuid"] == str(f_match.uuid) def test_get_dataset_files_not_found(self): """Test dataset files manifest with non-existent UUID.""" diff --git a/gateway/sds_gateway/api_methods/tests/test_dataset_manifest_filters.py b/gateway/sds_gateway/api_methods/tests/test_dataset_manifest_filters.py new file mode 100644 index 00000000..18cb8a82 --- /dev/null +++ b/gateway/sds_gateway/api_methods/tests/test_dataset_manifest_filters.py @@ -0,0 +1,37 @@ +"""Tests for dataset manifest query parsing and path normalization.""" + +import uuid + +import pytest +from rest_framework.request import Request +from rest_framework.test import APIRequestFactory + +from sds_gateway.api_methods.utils.dataset_manifest_filters import ( + normalize_top_level_dir_prefix, +) +from sds_gateway.api_methods.utils.dataset_manifest_filters import ( + parse_capture_uuid_query, +) + + +def test_normalize_top_level_dir_prefix() -> None: + assert normalize_top_level_dir_prefix("foo/bar") == "/foo/bar" + assert normalize_top_level_dir_prefix("/foo/bar/") == "/foo/bar" + assert normalize_top_level_dir_prefix("/") == "/" + + +def test_parse_capture_uuid_query_accepts_comma_separated() -> None: + u1, u2 = uuid.uuid4(), uuid.uuid4() + factory = APIRequestFactory() + req = factory.get("/x", {"capture": f"{u1},{u2}"}) + drf_request = Request(req) + parsed = parse_capture_uuid_query(drf_request) + assert parsed == [u1, u2] + + +def test_parse_capture_uuid_query_invalid_raises() -> None: + factory = APIRequestFactory() + req = factory.get("/x", {"capture": "not-a-uuid"}) + drf_request = Request(req) + with pytest.raises(ValueError, match="Invalid capture UUID"): + parse_capture_uuid_query(drf_request) diff --git a/gateway/sds_gateway/api_methods/tests/test_file_get_serializer_captures.py b/gateway/sds_gateway/api_methods/tests/test_file_get_serializer_captures.py new file mode 100644 index 00000000..5f3b7e02 --- /dev/null +++ b/gateway/sds_gateway/api_methods/tests/test_file_get_serializer_captures.py @@ -0,0 +1,95 @@ +"""Tests for FileGetSerializer ``capture`` / ``captures`` merge (SDK-facing shape).""" + +from unittest.mock import patch + +from django.test import TestCase + +from sds_gateway.api_methods.models import Capture +from sds_gateway.api_methods.models import CaptureType +from sds_gateway.api_methods.models import File +from sds_gateway.api_methods.serializers.file_serializers import FileGetSerializer +from sds_gateway.api_methods.tests.factories import DatasetFactory +from sds_gateway.api_methods.tests.factories import UserFactory +from sds_gateway.api_methods.tests.test_file_endpoints import create_db_file + + +class FileGetSerializerCapturesTestCase(TestCase): + """``to_representation`` merges M2M ``captures`` with legacy ``capture`` FK.""" + + def setUp(self) -> None: + self.user = UserFactory() + self.dataset = DatasetFactory(owner=self.user) + self.capture_a = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch-a", + index_name="ix-a", + name="cap-a", + ) + self.capture_b = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch-b", + index_name="ix-b", + name="cap-b", + ) + 
self.opensearch_patcher = patch( + "sds_gateway.api_methods.helpers.index_handling.retrieve_indexed_metadata", + return_value={}, + ) + self.opensearch_patcher.start() + + def tearDown(self) -> None: + self.opensearch_patcher.stop() + File.objects.filter(owner=self.user).delete() + Capture.objects.filter(owner=self.user).delete() + self.dataset.delete() + self.user.delete() + + def _serialize(self, file_obj: File) -> dict: + return FileGetSerializer(file_obj).data + + def test_legacy_fk_only_populates_both_capture_and_captures(self) -> None: + f = create_db_file(owner=self.user) + f.capture = self.capture_a + f.save(update_fields=["capture"]) + data = self._serialize(f) + assert data["capture"] is not None + assert data["capture"]["uuid"] == str(self.capture_a.uuid) + assert len(data["captures"]) == 1 + assert data["captures"][0]["uuid"] == str(self.capture_a.uuid) + + def test_m2m_only_populates_both(self) -> None: + f = create_db_file(owner=self.user) + f.captures.add(self.capture_a) + data = self._serialize(f) + assert data["capture"] is not None + assert data["capture"]["uuid"] == str(self.capture_a.uuid) + assert len(data["captures"]) == 1 + assert data["captures"][0]["uuid"] == str(self.capture_a.uuid) + + def test_fk_and_m2m_same_capture_deduplicated(self) -> None: + f = create_db_file(owner=self.user) + f.capture = self.capture_a + f.save(update_fields=["capture"]) + f.captures.add(self.capture_a) + data = self._serialize(f) + assert len(data["captures"]) == 1 + assert data["captures"][0]["uuid"] == str(self.capture_a.uuid) + assert data["capture"]["uuid"] == str(self.capture_a.uuid) + + def test_m2m_two_captures_fk_none_first_in_uuid_order(self) -> None: + f = create_db_file(owner=self.user) + f.captures.add(self.capture_b, self.capture_a) + data = self._serialize(f) + ordered = sorted([str(self.capture_a.uuid), str(self.capture_b.uuid)]) + assert [c["uuid"] for c in data["captures"]] == ordered + assert data["capture"]["uuid"] == ordered[0] + + def test_no_capture_links_both_null_or_empty(self) -> None: + f = create_db_file(owner=self.user) + data = self._serialize(f) + assert data["capture"] is None + assert data["captures"] == [] diff --git a/gateway/sds_gateway/api_methods/utils/dataset_manifest_filters.py b/gateway/sds_gateway/api_methods/utils/dataset_manifest_filters.py new file mode 100644 index 00000000..21ccc34f --- /dev/null +++ b/gateway/sds_gateway/api_methods/utils/dataset_manifest_filters.py @@ -0,0 +1,68 @@ +"""Database filters for dataset file manifest (download) listings.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from uuid import UUID + +from django.db.models import Q +from django.db.models import QuerySet + +if TYPE_CHECKING: + from rest_framework.request import Request + + from sds_gateway.api_methods.models import File + + +def normalize_top_level_dir_prefix(raw: str) -> str: + """Normalize a capture ``top_level_dir`` / path prefix for URL and DB matching.""" + s = raw.strip().replace("\\", "/") + if not s.startswith("/"): + s = f"/{s}" + return s.rstrip("/") or "/" + + +def parse_capture_uuid_query(request: Request) -> list[UUID]: + """Parse ``capture`` query params (repeat or comma-separated) into UUIDs.""" + vals: list[str] = [] + for item in request.query_params.getlist("capture"): + vals.extend(p.strip() for p in item.split(",") if p.strip()) + out: list[UUID] = [] + for s in vals: + try: + out.append(UUID(s)) + except ValueError as err: + msg = f"Invalid capture UUID: {s!r}" + raise ValueError(msg) from err + return 
out + + +def parse_top_level_dir_query(request: Request) -> list[str]: + """Parse ``top_level_dir`` query params into normalized path prefixes.""" + vals: list[str] = [] + for item in request.query_params.getlist("top_level_dir"): + vals.extend(p.strip() for p in item.split(",") if p.strip()) + return [normalize_top_level_dir_prefix(v) for v in vals] + + +def filter_dataset_files_queryset( + qs: QuerySet[File], + *, + capture_uuids: list[UUID], + top_level_dir_prefixes: list[str], +) -> QuerySet[File]: + """Restrict manifest files; capture UUID and path filters are OR-ed (SDK parity).""" + if not capture_uuids and not top_level_dir_prefixes: + return qs + parts: list[Q] = [] + if capture_uuids: + parts.append( + Q(captures__uuid__in=capture_uuids) | Q(capture__uuid__in=capture_uuids) + ) + if top_level_dir_prefixes: + qd = Q() + for pre in top_level_dir_prefixes: + qd |= Q(directory=pre) | Q(directory__startswith=f"{pre}/") + parts.append(qd) + q = parts[0] if len(parts) == 1 else parts[0] | parts[1] + return qs.filter(q).distinct() diff --git a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py index e01b8873..9aa19930 100644 --- a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py @@ -23,6 +23,15 @@ from sds_gateway.api_methods.utils.asset_access_control import ( revoke_share_permissions as revoke_item_share_permissions, ) +from sds_gateway.api_methods.utils.dataset_manifest_filters import ( + filter_dataset_files_queryset, +) +from sds_gateway.api_methods.utils.dataset_manifest_filters import ( + parse_capture_uuid_query, +) +from sds_gateway.api_methods.utils.dataset_manifest_filters import ( + parse_top_level_dir_query, +) from sds_gateway.api_methods.utils.relationship_utils import ( get_dataset_files_including_captures, ) @@ -117,6 +126,27 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: location=OpenApiParameter.QUERY, default=FilePagination.page_size, ), + OpenApiParameter( + name="capture", + description=( + "Only include files linked to this capture UUID " + "(repeat param or comma-separated). OR with top_level_dir." + ), + required=False, + type=str, + location=OpenApiParameter.QUERY, + ), + OpenApiParameter( + name="top_level_dir", + description=( + "Only include files whose directory is this path or under it " + "(repeat param or comma-separated). Normalized like capture " + "top_level_dir." 
+ ), + required=False, + type=str, + location=OpenApiParameter.QUERY, + ), ], responses={ 200: OpenApiResponse(description="Dataset file listing"), @@ -168,8 +198,26 @@ def get_dataset_files(self, request: Request, pk: str | None = None) -> Response status=status.HTTP_404_NOT_FOUND, ) - # Order and deduplicate files by path and created_at - ordered_files = dataset_files.order_by("-created_at") + try: + capture_uuids = parse_capture_uuid_query(request) + except ValueError as err: + return Response( + {"detail": str(err)}, + status=status.HTTP_400_BAD_REQUEST, + ) + top_level_dir_prefixes = parse_top_level_dir_query(request) + if capture_uuids or top_level_dir_prefixes: + dataset_files = filter_dataset_files_queryset( + dataset_files, + capture_uuids=capture_uuids, + top_level_dir_prefixes=top_level_dir_prefixes, + ) + + # Order and deduplicate files by path and created_at; avoid N+1 on captures + ordered_files = dataset_files.order_by("-created_at").select_related( + "capture", + "owner", + ).prefetch_related("captures", "datasets") paginator = FilePagination() paginated_files = paginator.paginate_queryset(ordered_files, request=request) diff --git a/gateway/sds_gateway/api_methods/views/file_endpoints.py b/gateway/sds_gateway/api_methods/views/file_endpoints.py index b0e752d8..13025989 100644 --- a/gateway/sds_gateway/api_methods/views/file_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/file_endpoints.py @@ -370,13 +370,13 @@ def list(self, request: Request) -> Response: # noqa: C901 ) if self._check_files_includes_rf_data(files_matching_dir): - if start_time is not None and end_time is not None: + if start_time and end_time: files_matching_dir = filter_files_by_temporal_bounds( files_matching_dir, start_time, end_time, ) - elif start_time is not None or end_time is not None: + elif start_time or end_time: msg = ( "Both start_time and end_time are required for temporal filtering " "when listing Digital RF data." diff --git a/gateway/sds_gateway/users/views/downloads.py b/gateway/sds_gateway/users/views/downloads.py index 8c15584b..c989c75b 100644 --- a/gateway/sds_gateway/users/views/downloads.py +++ b/gateway/sds_gateway/users/views/downloads.py @@ -163,7 +163,7 @@ def _validate_time_range( start_time: int | None, end_time: int | None ) -> JsonResponse | None: """Return 400 JsonResponse if both provided and start >= end; else None.""" - if start_time is not None and end_time is not None and start_time >= end_time: + if start_time and end_time and start_time >= end_time: return JsonResponse( { "success": False, diff --git a/sdk/src/spectrumx/api/datasets.py b/sdk/src/spectrumx/api/datasets.py index 46cb9ed8..60cd1266 100644 --- a/sdk/src/spectrumx/api/datasets.py +++ b/sdk/src/spectrumx/api/datasets.py @@ -14,6 +14,9 @@ from spectrumx.utils import log_user if TYPE_CHECKING: + from collections.abc import Collection + from pathlib import Path + from pathlib import PurePosixPath from uuid import UUID from spectrumx.gateway import GatewayClient @@ -93,23 +96,35 @@ def list_artifact_files(self, dataset_uuid: UUID) -> list[dict[str, Any]]: def get_files( self, dataset_uuid: UUID, + *, + capture_uuids: Collection[UUID] | None = None, + top_level_dirs: Collection[PurePosixPath | Path | str] | None = None, ) -> Paginator[File]: """Get files in the dataset as a paginator. Args: dataset_uuid: The UUID of the dataset to get files for. + capture_uuids: If set, passed to the gateway to restrict by capture UUID (OR + with ``top_level_dirs``). 
+ top_level_dirs: If set, passed to the gateway as path prefixes under + ``File.directory`` (OR with ``capture_uuids``). Returns: A paginator for the files in the dataset. """ if self.dry_run: log_user("Dry run enabled: files will be simulated") - # Create a paginator that fetches from the dataset files endpoint + list_kwargs: dict[str, Any] = {"dataset_uuid": dataset_uuid} + if capture_uuids is not None: + list_kwargs["capture_uuids"] = tuple(capture_uuids) + if top_level_dirs is not None: + list_kwargs["top_level_dirs"] = tuple(str(p) for p in top_level_dirs) + pagination: Paginator[File] = Paginator( Entry=File, gateway=self.gateway, list_method=self.gateway.get_dataset_files, - list_kwargs={"dataset_uuid": dataset_uuid}, + list_kwargs=list_kwargs, dry_run=self.dry_run, verbose=self.verbose, ) diff --git a/sdk/src/spectrumx/api/sds_files.py b/sdk/src/spectrumx/api/sds_files.py index c4afef21..5cb541be 100644 --- a/sdk/src/spectrumx/api/sds_files.py +++ b/sdk/src/spectrumx/api/sds_files.py @@ -136,7 +136,7 @@ def list_files( raise ValueError(msg) sds_path = PurePosixPath(sds_path) start_q: str | None = ( - _file_list_time_query_param(start_time) if start_time is not None else None + _file_list_time_query_param(start_time) if start_time else None ) end_q: str | None = ( _file_list_time_query_param(end_time) if end_time is not None else None diff --git a/sdk/src/spectrumx/client.py b/sdk/src/spectrumx/client.py index d8c410a9..b6717f58 100644 --- a/sdk/src/spectrumx/client.py +++ b/sdk/src/spectrumx/client.py @@ -1,5 +1,6 @@ """Client for the SpectrumX Data System.""" +from collections.abc import Collection from collections.abc import Mapping from datetime import datetime from pathlib import Path @@ -33,6 +34,47 @@ from .utils import log_user_warning +def _normalize_top_level_dir_prefix( + top_level_dir: PurePosixPath | Path | str, +) -> str: + s = str(top_level_dir).strip().replace("\\", "/") + if not s.startswith("/"): + s = f"/{s}" + return s.rstrip("/") or "/" + + +def _resolve_dataset_capture_filter_params( + *, + capture_uuids: Collection[UUID4 | str] | None, + top_level_dirs: Collection[PurePosixPath | Path | str] | None, + dry_run: bool, +) -> tuple[bool, set[UUID] | None, list[str] | None]: + """Return (filter_active, uuid_set, dir_prefixes_norm) for manifest filtering.""" + filter_by_capture = capture_uuids is not None + filter_by_dir = top_level_dirs is not None + filter_active = filter_by_capture or filter_by_dir + if not filter_active: + return False, None, None + if dry_run: + log_user( + "Dry run: capture_uuids / top_level_dirs filters are ignored " + "(simulated manifest files are not capture-tagged)." 
+ ) + return False, None, None + uuid_set: set[UUID] | None = None + dir_prefixes_norm: list[str] | None = None + if filter_by_capture: + uuid_set = {UUID(str(u)) for u in capture_uuids or ()} + if filter_by_dir: + dir_prefixes_norm = [ + _normalize_top_level_dir_prefix(d) for d in top_level_dirs or () + ] + return True, uuid_set, dir_prefixes_norm + + +DownloadFileSource = list[File] | Paginator[File] + + class Client: """Instantiates an SDS client.""" @@ -211,7 +253,7 @@ def download( start_time: datetime | None = None, end_time: datetime | None = None, to_local_path: Path | str, - files_to_download: list[File] | Paginator[File] | None = None, + files_to_download: DownloadFileSource | None = None, skip_contents: bool = False, overwrite: bool = False, verbose: bool = True, @@ -278,9 +320,9 @@ def _get_files_to_download( from_sds_path: PurePosixPath | None, start_time: datetime | None = None, end_time: datetime | None = None, - files_to_download: list[File] | Paginator[File] | None, + files_to_download: DownloadFileSource | None, verbose: bool, - ) -> list[File] | Paginator[File]: + ) -> DownloadFileSource: """Get the list of files to download.""" if self.dry_run: log_user( @@ -310,7 +352,7 @@ def _get_files_to_download( def _download_files( self, *, - files_to_download: list[File] | Paginator[File], + files_to_download: DownloadFileSource, to_local_path: Path, skip_contents: bool, overwrite: bool, @@ -479,11 +521,13 @@ def download_dataset( *, dataset_uuid: UUID4 | str, to_local_path: Path | str, + capture_uuids: Collection[UUID4 | str] | None = None, + top_level_dirs: Collection[PurePosixPath | Path | str] | None = None, skip_contents: bool = False, overwrite: bool = False, verbose: bool = True, ) -> list[Result[File]]: - """Downloads all files in a dataset using the existing download infrastructure. + """Downloads files in a dataset using the existing download infrastructure. This approach uses the get_dataset_files endpoint to get a paginated list of File objects and then uses the existing download() method with @@ -492,19 +536,64 @@ def download_dataset( Args: dataset_uuid: The UUID of the dataset to download. to_local_path: The local path to save the downloaded files to. + capture_uuids: If set, only files linked to at least one of these capture + UUIDs are downloaded (dataset artifacts with no capture are excluded). + top_level_dirs: If set, only files whose ``directory`` lies under one of + these capture ``top_level_dir`` paths (as from + :meth:`list_dataset_captures`) are included. Leading/trailing slashes + are normalized. skip_contents: When True, only the metadata is downloaded. overwrite: Whether to overwrite existing local files. verbose: Show progress bars and detailed output. + + If both ``capture_uuids`` and ``top_level_dirs`` are set, a file is included + when it matches **either** criterion (enforced on the gateway). In dry run + mode, capture/path filters are ignored because simulated files are not + capture-tagged. + Returns: A list of results for each file downloaded. 
""" if isinstance(dataset_uuid, str): dataset_uuid = UUID(dataset_uuid) - # Get all files in the dataset as a paginator - files_to_download = self.datasets.get_files(dataset_uuid=dataset_uuid) + ( + filter_active, + uuid_set, + dir_prefixes_norm, + ) = _resolve_dataset_capture_filter_params( + capture_uuids=capture_uuids, + top_level_dirs=top_level_dirs, + dry_run=self.dry_run, + ) + + if filter_active: + uuid_nonempty = bool(uuid_set) + dir_nonempty = bool(dir_prefixes_norm) + if not uuid_nonempty and not dir_nonempty: + files_to_download: DownloadFileSource = [] + elif uuid_nonempty and dir_nonempty: + files_to_download = self.datasets.get_files( + dataset_uuid, + capture_uuids=tuple(uuid_set), + top_level_dirs=tuple(dir_prefixes_norm), + ) + elif uuid_nonempty: + files_to_download = self.datasets.get_files( + dataset_uuid, + capture_uuids=tuple(uuid_set), + ) + else: + files_to_download = self.datasets.get_files( + dataset_uuid, + top_level_dirs=tuple(dir_prefixes_norm), + ) + if verbose and (uuid_nonempty or dir_nonempty): + log_user("Dataset capture filter active (applied on the gateway).") + else: + files_to_download = self.datasets.get_files(dataset_uuid=dataset_uuid) - if verbose: + if verbose and not filter_active: log_user( f"Downloading files from dataset " f"(total: {len(files_to_download)} files)" diff --git a/sdk/src/spectrumx/gateway.py b/sdk/src/spectrumx/gateway.py index 34cdd879..5a27816f 100644 --- a/sdk/src/spectrumx/gateway.py +++ b/sdk/src/spectrumx/gateway.py @@ -2,6 +2,7 @@ import json import uuid +from collections.abc import Collection from collections.abc import Iterator from enum import StrEnum from http import HTTPStatus @@ -282,9 +283,9 @@ def list_files( "page_size": page_size, "path": str(sds_path), } - if start_time is not None: + if start_time: params["start_time"] = start_time - if end_time is not None: + if end_time: params["end_time"] = end_time response = self._request( method=HTTPMethods.GET, @@ -777,21 +778,34 @@ def get_dataset_files( dataset_uuid: uuid.UUID, page: int = 1, page_size: int = 30, + capture_uuids: Collection[uuid.UUID] | None = None, + top_level_dirs: Collection[str | PurePosixPath | Path] | None = None, verbose: bool = False, ) -> bytes: """Get a manifest of files in the dataset for efficient downloading. Args: dataset_uuid: The UUID of the dataset to get files for. + capture_uuids: Optional capture UUIDs to filter server-side (repeat query). + top_level_dirs: Optional directory prefixes to filter server-side. verbose: Show network requests and other info. Returns: The response content containing the dataset file manifest. 
""" + params_list: list[tuple[str, str | int]] = [ + ("page", page), + ("page_size", page_size), + ] + if capture_uuids: + params_list.extend(("capture", str(cap_id)) for cap_id in capture_uuids) + if top_level_dirs: + params_list.extend(("top_level_dir", str(path)) for path in top_level_dirs) + response = self._request( method=HTTPMethods.GET, endpoint=Endpoints.DATASET_FILES, endpoint_args={"uuid": dataset_uuid.hex}, - params={"page": page, "page_size": page_size}, + params=params_list, verbose=verbose, ) network.success_or_raise(response, ContextException=DatasetError) diff --git a/sdk/tests/models/test_files.py b/sdk/tests/models/test_files.py index 57373304..b1509b30 100644 --- a/sdk/tests/models/test_files.py +++ b/sdk/tests/models/test_files.py @@ -2,6 +2,7 @@ # pylint: disable=redefined-outer-name +import uuid from datetime import datetime from datetime import timedelta from pathlib import PurePosixPath @@ -11,6 +12,8 @@ from pydantic import BaseModel from pytz import UTC from pytz import timezone +from spectrumx.models.capture_enums import CaptureOrigin +from spectrumx.models.capture_enums import CaptureType from spectrumx.models.files import File from spectrumx.models.files import PermissionRepresentation from spectrumx.models.files import UnixPermissionStr @@ -100,3 +103,64 @@ class Model(BaseModel): assert b.model_dump(context={"mode": PermissionRepresentation.OCTAL}) == { "permission": "0o755" } + + +def test_file_parses_captures_list_from_api( + file_properties: dict[str, Any], +) -> None: + """Gateway file payloads use a ``captures`` list (canonical).""" + cap_uid = uuid.uuid4() + nested_capture = { + "owner": {"id": 1, "email": "test@example.com", "name": "Test User"}, + "is_shared": False, + "share_permissions": [], + "datasets": [], + "capture_props": {}, + "capture_type": CaptureType.DigitalRF.value, + "created_at": datetime.now(UTC).isoformat(), + "index_name": "captures-drf", + "origin": CaptureOrigin.User.value, + "top_level_dir": "/c/tdir", + "uuid": str(cap_uid), + "files": [], + } + raw = { + "uuid": str(uuid.uuid4()), + **file_properties, + "captures": [nested_capture], + } + f = File.model_validate(raw) + assert f.captures is not None + assert len(f.captures) == 1 + assert f.captures[0].uuid == cap_uid + + +def test_file_payload_may_include_redundant_singular_capture( + file_properties: dict[str, Any], +) -> None: + """Gateway may send both ``captures`` and ``capture``; SDK uses ``captures``.""" + cap_uid = uuid.uuid4() + nested_capture = { + "owner": {"id": 1, "email": "test@example.com", "name": "Test User"}, + "is_shared": False, + "share_permissions": [], + "datasets": [], + "capture_props": {}, + "capture_type": CaptureType.DigitalRF.value, + "created_at": datetime.now(UTC).isoformat(), + "index_name": "captures-drf", + "origin": CaptureOrigin.User.value, + "top_level_dir": "/c/tdir", + "uuid": str(cap_uid), + "files": [], + } + raw = { + "uuid": str(uuid.uuid4()), + **file_properties, + "captures": [nested_capture], + "capture": nested_capture, + } + f = File.model_validate(raw) + assert f.captures is not None + assert len(f.captures) == 1 + assert f.captures[0].uuid == cap_uid diff --git a/sdk/tests/test_client.py b/sdk/tests/test_client.py index 28bcaddb..8b393ef1 100644 --- a/sdk/tests/test_client.py +++ b/sdk/tests/test_client.py @@ -11,6 +11,7 @@ import pytest from loguru import logger as log from spectrumx.client import Client +from spectrumx.client import _resolve_dataset_capture_filter_params from spectrumx.config import SDSConfig from 
spectrumx.config import _cfg_name_lookup from spectrumx.models.files import File @@ -329,3 +330,25 @@ def test_existing_local_file_identical_checksum_not_redownloaded( assert result, "Result should not be None" assert result() is file_info, "Returned file should be the original file_info" + + +def test_resolve_dataset_capture_filter_dry_run_disables() -> None: + active, uuids, dirs = _resolve_dataset_capture_filter_params( + capture_uuids=[uuid.uuid4()], + top_level_dirs=None, + dry_run=True, + ) + assert not active + assert uuids is None + assert dirs is None + + +def test_resolve_dataset_capture_filter_normalizes_top_level_dirs() -> None: + active, uuids, dirs = _resolve_dataset_capture_filter_params( + capture_uuids=None, + top_level_dirs=["foo/bar", "/baz/"], + dry_run=False, + ) + assert active + assert uuids is None + assert dirs == ["/foo/bar", "/baz"] From 5af90ae572ee341c93f3d8e638017e6cd62698df Mon Sep 17 00:00:00 2001 From: klpoland Date: Thu, 30 Apr 2026 14:22:26 -0400 Subject: [PATCH 09/10] linting --- .../api_methods/serializers/file_serializers.py | 4 +--- .../api_methods/views/dataset_endpoints.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/gateway/sds_gateway/api_methods/serializers/file_serializers.py b/gateway/sds_gateway/api_methods/serializers/file_serializers.py index ccb2acf0..343fb701 100644 --- a/gateway/sds_gateway/api_methods/serializers/file_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/file_serializers.py @@ -73,9 +73,7 @@ def to_representation(self, instance: File) -> dict[str, Any]: context=context, ).data data["capture"] = ( - CaptureGetSerializer(merged[0], context=context).data - if merged - else None + CaptureGetSerializer(merged[0], context=context).data if merged else None ) return data diff --git a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py index 9aa19930..806c18b8 100644 --- a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py @@ -214,10 +214,14 @@ def get_dataset_files(self, request: Request, pk: str | None = None) -> Response ) # Order and deduplicate files by path and created_at; avoid N+1 on captures - ordered_files = dataset_files.order_by("-created_at").select_related( - "capture", - "owner", - ).prefetch_related("captures", "datasets") + ordered_files = ( + dataset_files.order_by("-created_at") + .select_related( + "capture", + "owner", + ) + .prefetch_related("captures", "datasets") + ) paginator = FilePagination() paginated_files = paginator.paginate_queryset(ordered_files, request=request) From ca1cd4e29a9094baaeeb8cf6b927859db73cc117 Mon Sep 17 00:00:00 2001 From: klpoland Date: Thu, 30 Apr 2026 15:46:14 -0400 Subject: [PATCH 10/10] copilot comments --- .../serializers/capture_serializers.py | 170 ++++++++++++------ .../test_composite_capture_serialization.py | 73 +++++--- .../api_methods/tests/test_file_endpoints.py | 13 ++ .../api_methods/views/file_endpoints.py | 81 +++++++-- sdk/docs/mkdocs/changelog.md | 11 +- 5 files changed, 249 insertions(+), 99 deletions(-) diff --git a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py index feac4e2f..46266d88 100644 --- a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py @@ -44,6 +44,33 @@ def 
_epoch_sec_to_local_display(epoch_sec: int) -> str: return django_timezone.localtime(dt).strftime("%m/%d/%Y %I:%M:%S %p") +def _channel_row_bounds_from_os_meta(meta: dict[str, Any]) -> dict[str, Any]: + """Map OpenSearch metadata to composite channel serializer fields.""" + entry: dict[str, Any] = {} + start_sec = meta.get("start_time") + end_sec = meta.get("end_time") + entry["capture_start_epoch_sec"] = start_sec + entry["capture_end_epoch_sec"] = end_sec + entry["capture_start_iso_utc"] = ( + _epoch_sec_to_iso_utc_z(start_sec) if start_sec is not None else None + ) + entry["capture_end_iso_utc"] = ( + _epoch_sec_to_iso_utc_z(end_sec) if end_sec is not None else None + ) + entry["capture_start_display"] = ( + _epoch_sec_to_local_display(start_sec) if start_sec is not None else None + ) + entry["capture_end_display"] = ( + _epoch_sec_to_local_display(end_sec) if end_sec is not None else None + ) + if start_sec is None or end_sec is None: + entry["length_of_capture_ms"] = None + else: + entry["length_of_capture_ms"] = (end_sec - start_sec) * 1000 + entry["file_cadence_ms"] = meta.get("file_cadence") + return entry + + class FileCaptureListSerializer(serializers.ModelSerializer[File]): class Meta: model = File @@ -484,6 +511,38 @@ class CompositeCaptureSerializer(serializers.Serializer): capture_start_display = serializers.SerializerMethodField() capture_end_display = serializers.SerializerMethodField() + def _captures_bulk_by_uuid(self, obj: dict[str, Any]) -> dict[str, Capture]: + """One DB query per composite when ``include_serializer_aux`` was False.""" + key = str(obj.get("uuid", "")) + if not hasattr(self, "_captures_bulk_cache"): + self._captures_bulk_cache: dict[str, dict[str, Capture]] = {} + if key not in self._captures_bulk_cache: + uuids = [ch["uuid"] for ch in obj.get("channels") or []] + self._captures_bulk_cache[key] = ( + {str(c.uuid): c for c in Capture.objects.filter(uuid__in=uuids)} + if uuids + else {} + ) + return self._captures_bulk_cache[key] + + def _capture_for_channel( + self, obj: dict[str, Any], channel_entry: dict[str, Any] + ) -> Capture | None: + """Resolve Capture; prefer auxiliary map from build, otherwise bulk queryset.""" + by_uuid = obj.get("_captures_by_uuid") + uuid_key = str(channel_entry["uuid"]) + if isinstance(by_uuid, dict): + hit = cast("Capture | None", by_uuid.get(uuid_key)) + if hit is not None: + return hit + hit = self._captures_bulk_by_uuid(obj).get(uuid_key) + if hit is not None: + return hit + try: + return Capture.objects.get(uuid=channel_entry["uuid"]) + except Capture.DoesNotExist: + return None + def _enriched_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: """Per-channel rows with OpenSearch bounds (each channel may differ).""" key = str(obj.get("uuid", "")) @@ -497,48 +556,32 @@ def _enriched_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: "uuid": ch["uuid"], "channel_metadata": ch.get("channel_metadata", {}), } - try: - capture = Capture.objects.get(uuid=ch["uuid"]) - except Capture.DoesNotExist: - entry["capture_start_epoch_sec"] = None - entry["capture_end_epoch_sec"] = None - entry["capture_start_iso_utc"] = None - entry["capture_end_iso_utc"] = None - entry["capture_start_display"] = None - entry["capture_end_display"] = None - entry["length_of_capture_ms"] = None - entry["file_cadence_ms"] = None - else: - # Per-channel bounds/cadence (Capture.get_opensearch_metadata). 
- start_sec = capture.start_time - end_sec = capture.end_time - entry["capture_start_epoch_sec"] = start_sec - entry["capture_end_epoch_sec"] = end_sec - entry["capture_start_iso_utc"] = ( - _epoch_sec_to_iso_utc_z(start_sec) - if start_sec is not None - else None - ) - entry["capture_end_iso_utc"] = ( - _epoch_sec_to_iso_utc_z(end_sec) - if end_sec is not None - else None - ) - entry["capture_start_display"] = ( - _epoch_sec_to_local_display(start_sec) - if start_sec is not None - else None - ) - entry["capture_end_display"] = ( - _epoch_sec_to_local_display(end_sec) - if end_sec is not None - else None + pre_meta = cast( + "dict[str, Any] | None", + ch.get("_per_channel_os_meta"), + ) + if pre_meta is not None: + entry.update(_channel_row_bounds_from_os_meta(pre_meta)) + out.append(entry) + continue + capture = self._capture_for_channel(obj, ch) + if capture is None: + entry.update( + { + "capture_start_epoch_sec": None, + "capture_end_epoch_sec": None, + "capture_start_iso_utc": None, + "capture_end_iso_utc": None, + "capture_start_display": None, + "capture_end_display": None, + "length_of_capture_ms": None, + "file_cadence_ms": None, + } ) - if start_sec is None or end_sec is None: - entry["length_of_capture_ms"] = None - else: - entry["length_of_capture_ms"] = (end_sec - start_sec) * 1000 - entry["file_cadence_ms"] = capture.file_cadence + else: + # One OS round-trip per Capture instance via instance cache. + meta = capture.get_opensearch_metadata() + entry.update(_channel_row_bounds_from_os_meta(meta)) out.append(entry) self._enriched_channels_cache[key] = out return self._enriched_channels_cache[key] @@ -591,8 +634,9 @@ def get_files(self, obj: dict[str, Any]) -> ReturnList[File]: """Get all files from all channels in the composite capture.""" all_files = [] for channel_data in obj.get("channels") or []: - capture_uuid = channel_data["uuid"] - capture = Capture.objects.get(uuid=capture_uuid) + capture = self._capture_for_channel(obj, channel_data) + if capture is None: + continue non_deleted_files = get_capture_files(capture, include_deleted=False) serializer = FileCaptureListSerializer( non_deleted_files, @@ -609,8 +653,9 @@ def get_total_file_size(self, obj: dict[str, Any]) -> int | None: total_size = 0 for channel_data in obj.get("channels") or []: - capture_uuid = channel_data["uuid"] - capture = Capture.objects.get(uuid=capture_uuid) + capture = self._capture_for_channel(obj, channel_data) + if capture is None: + continue all_files = get_capture_files(capture, include_deleted=False) result = all_files.aggregate(total_size=Sum("size")) total_size += result["total_size"] or 0 @@ -637,8 +682,9 @@ def get_data_files_info(self, obj: dict[str, Any]) -> dict[str, Any]: total_count = 0 total_size = 0 for channel_data in obj.get("channels") or []: - capture_uuid = channel_data["uuid"] - capture = Capture.objects.get(uuid=capture_uuid) + capture = self._capture_for_channel(obj, channel_data) + if capture is None: + continue stats = capture.get_drf_data_files_stats() total_count += stats["total_count"] total_size += stats["total_size"] @@ -722,11 +768,20 @@ def get_capture_end_display(self, obj: dict[str, Any]) -> str | None: return _epoch_sec_to_local_display(end_sec) -def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]: +def build_composite_capture_data( + captures: list[Capture], + *, + include_serializer_aux: bool = False, +) -> dict[str, Any]: """Build composite capture data from a list of captures with the same top_level_dir. 
Args: captures: List of Capture objects to combine into composite + include_serializer_aux: When True, attach non-public fields used only by + :class:`CompositeCaptureSerializer`: per-channel cached OpenSearch + metadata (one search per capture) and a Capture map to avoid duplicate + ORM lookups. Keep False for raw API payloads (capture list/search, + nested dataset captures). Returns: dict: Composite capture data structure @@ -742,20 +797,25 @@ def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]: base_capture = captures[0] # Build channel data with metadata - channels = [] + captures_by_uuid: dict[str, Capture] | None = {} if include_serializer_aux else None + channels: list[dict[str, Any]] = [] for capture in captures: - channel_data = { + if captures_by_uuid is not None: + captures_by_uuid[str(capture.uuid)] = capture + channel_data: dict[str, Any] = { "channel": capture.channel, "uuid": capture.uuid, "channel_metadata": retrieve_indexed_metadata(capture), } + if include_serializer_aux: + channel_data["_per_channel_os_meta"] = capture.get_opensearch_metadata() channels.append(channel_data) # Serialize the owner field owner_serializer = UserGetSerializer(base_capture.owner) # Build composite data - return { + composite: dict[str, Any] = { "uuid": base_capture.uuid, # Use first capture's UUID as composite UUID "capture_type": base_capture.capture_type, "capture_type_display": base_capture.get_capture_type_display(), @@ -771,6 +831,9 @@ def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]: "owner": owner_serializer.data, "channels": channels, } + if captures_by_uuid is not None: + composite["_captures_by_uuid"] = captures_by_uuid + return composite def serialize_capture_or_composite( @@ -789,7 +852,10 @@ def serialize_capture_or_composite( if capture_data["is_composite"]: # Serialize as composite - composite_data = build_composite_capture_data(capture_data["captures"]) + composite_data = build_composite_capture_data( + capture_data["captures"], + include_serializer_aux=True, + ) serializer = CompositeCaptureSerializer(composite_data, context=context) return serializer.data # Serialize as single capture diff --git a/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py index adacaa80..4482c0a6 100644 --- a/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py +++ b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py @@ -66,15 +66,20 @@ def fake_retrieve(capture: Capture) -> dict: "capture_pk": str(capture.uuid), } - with patch( - "sds_gateway.api_methods.serializers.capture_serializers" - ".retrieve_indexed_metadata", - side_effect=fake_retrieve, + with ( + patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + side_effect=fake_retrieve, + ), + patch.object(Capture, "get_opensearch_metadata", return_value={}), ): - composite = build_composite_capture_data([cap0, cap1]) + composite = build_composite_capture_data( + [cap0, cap1], + include_serializer_aux=True, + ) - with patch.object(Capture, "get_opensearch_metadata", return_value={}): - out = CompositeCaptureSerializer(composite, context={}).data + out = CompositeCaptureSerializer(composite, context={}).data ch_rows = {row["channel"]: row for row in out["channels"]} assert ch_rows["ch0"]["channel_metadata"] == { @@ -110,19 +115,24 @@ def test_distinct_opensearch_times_cadence_and_envelope(self) -> 
None: def opensearch_by_instance(self: Capture) -> dict: return dict(meta_by_uuid[str(self.uuid)]) - with patch( - "sds_gateway.api_methods.serializers.capture_serializers" - ".retrieve_indexed_metadata", - return_value={}, + with ( + patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + return_value={}, + ), + patch.object( + Capture, + "get_opensearch_metadata", + opensearch_by_instance, + ), ): - composite = build_composite_capture_data([cap0, cap1]) + composite = build_composite_capture_data( + [cap0, cap1], + include_serializer_aux=True, + ) - with patch.object( - Capture, - "get_opensearch_metadata", - opensearch_by_instance, - ): - out = CompositeCaptureSerializer(composite, context={}).data + out = CompositeCaptureSerializer(composite, context={}).data ch_rows = {row["channel"]: row for row in out["channels"]} assert ch_rows["ch0"]["capture_start_epoch_sec"] == 1_700_000_000 @@ -158,19 +168,24 @@ def opensearch_by_instance(self: Capture) -> dict: "file_cadence": 200, } - with patch( - "sds_gateway.api_methods.serializers.capture_serializers" - ".retrieve_indexed_metadata", - return_value={}, + with ( + patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + return_value={}, + ), + patch.object( + Capture, + "get_opensearch_metadata", + opensearch_by_instance, + ), ): - composite = build_composite_capture_data([cap0, cap1]) + composite = build_composite_capture_data( + [cap0, cap1], + include_serializer_aux=True, + ) - with patch.object( - Capture, - "get_opensearch_metadata", - opensearch_by_instance, - ): - out = CompositeCaptureSerializer(composite, context={}).data + out = CompositeCaptureSerializer(composite, context={}).data ch_rows = {row["channel"]: row for row in out["channels"]} assert ch_rows["ch0"]["length_of_capture_ms"] == 30_000 diff --git a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py index 97b0f55f..7ee5add9 100644 --- a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py +++ b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py @@ -479,6 +479,19 @@ def test_list_files_with_temporal_params(self) -> None: assert f"rf@{base_sec}.000.h5" not in names assert f"rf@{base_sec + 5}.000.h5" not in names + def test_list_files_invalid_temporal_param_returns_400(self) -> None: + """Malformed start_time yields 400 instead of 500.""" + response = self.client.get( + self.list_url, + { + "path": "/", + "start_time": "not-a-datetime", + "end_time": "2020-01-02T00:00:00", + }, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "start_time" in response.json()["detail"].lower() + def test_list_files_includes_warnings_key(self) -> None: """Paginated list responses always include ``warnings`` (possibly empty).""" response = self.client.get(self.list_url) diff --git a/gateway/sds_gateway/api_methods/views/file_endpoints.py b/gateway/sds_gateway/api_methods/views/file_endpoints.py index 13025989..fa554ab5 100644 --- a/gateway/sds_gateway/api_methods/views/file_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/file_endpoints.py @@ -1,5 +1,6 @@ """File operations endpoints for the SDS Gateway API.""" +from datetime import UTC from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING @@ -213,11 +214,62 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: serializer = FileGetSerializer(target_file, many=False) return 
diff --git a/gateway/sds_gateway/api_methods/views/file_endpoints.py b/gateway/sds_gateway/api_methods/views/file_endpoints.py
index 13025989..fa554ab5 100644
--- a/gateway/sds_gateway/api_methods/views/file_endpoints.py
+++ b/gateway/sds_gateway/api_methods/views/file_endpoints.py
@@ -1,5 +1,6 @@
 """File operations endpoints for the SDS Gateway API."""
 
+from datetime import UTC
 from datetime import datetime
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -213,11 +214,62 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response:
         serializer = FileGetSerializer(target_file, many=False)
         return Response(serializer.data)
 
-    def _datetime_string_to_milliseconds(self, datetime_string: str) -> int:
-        """Converts a datetime string to milliseconds since start of capture."""
-        parsed = datetime.fromisoformat(datetime_string)
+    def _iso8601_query_string_to_epoch_ms(
+        self, param_name: str, datetime_string: str
+    ) -> int:
+        """Parse ``start_time`` / ``end_time`` query strings to Unix epoch milliseconds.
+
+        Naive datetimes are treated as UTC (same convention as the Python SDK).
+        ``Z`` is accepted as UTC. Raises ``ValueError`` with a client-safe message if
+        the string is not a parseable ISO 8601 datetime.
+        """
+        s = datetime_string.strip()
+        if len(s) >= 1 and s[-1] in "Zz":
+            s = s[:-1] + "+00:00"
+        try:
+            parsed = datetime.fromisoformat(s)
+        except ValueError as exc:
+            msg = (
+                f"Invalid {param_name}: expected ISO 8601 datetime "
+                "(e.g. 2024-01-15T12:30:45+00:00 or naive as UTC)."
+            )
+            raise ValueError(msg) from exc
+        if parsed.tzinfo is None:
+            parsed = parsed.replace(tzinfo=UTC)
+        else:
+            parsed = parsed.astimezone(UTC)
         return int(parsed.timestamp() * 1000)
 
+    def _parsed_temporal_bounds_for_file_list(
+        self, request: Request
+    ) -> Response | tuple[int | None, int | None]:
+        """Parse ``start_time`` / ``end_time`` for :meth:`list`, or a 400 response."""
+        start_raw = request.GET.get("start_time") or None
+        end_raw = request.GET.get("end_time") or None
+        start_ms: int | None = None
+        end_ms: int | None = None
+        if start_raw:
+            try:
+                start_ms = self._iso8601_query_string_to_epoch_ms(
+                    "start_time", cast("str", start_raw)
+                )
+            except ValueError as err:
+                return Response(
+                    {"detail": str(err)},
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+        if end_raw:
+            try:
+                end_ms = self._iso8601_query_string_to_epoch_ms(
+                    "end_time", cast("str", end_raw)
+                )
+            except ValueError as err:
+                return Response(
+                    {"detail": str(err)},
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+        return start_ms, end_ms
+
     def _check_files_includes_rf_data(self, files: QuerySet[File]) -> bool:
         """Checks if the files include RF data."""
         return files.filter(name__regex=DRF_RF_FILENAME_REGEX_STR).exists()
@@ -253,8 +305,10 @@ def _check_files_includes_rf_data(self, files: QuerySet[File]) -> bool:
             location=OpenApiParameter.QUERY,
             required=False,
             description=(
-                "ISO 8601 datetime; converted to ms for temporal filtering of "
-                "RF data files when the listing includes Digital RF ``.h5`` data."
+                "ISO 8601 datetime; parsed to UTC epoch milliseconds for temporal "
+                "filtering of RF data files when the listing includes Digital RF "
+                "``.h5`` data. Timezone-aware inputs are converted to UTC; naive "
+                "values are interpreted as UTC. Invalid values return 400."
             ),
         ),
         OpenApiParameter(
@@ -263,7 +317,8 @@ def _check_files_includes_rf_data(self, files: QuerySet[File]) -> bool:
             location=OpenApiParameter.QUERY,
             required=False,
             description=(
-                "ISO 8601 datetime; paired with start_time for RF temporal bounds."
+                "ISO 8601 datetime; paired with start_time for RF temporal bounds "
+                "(same UTC / naive-as-UTC rules as start_time)."
             ),
         ),
         OpenApiParameter(
@@ -295,16 +350,10 @@ def list(self, request: Request) -> Response:  # noqa: C901
         # warnings to be returned in the response
         warnings = []
 
-        # Get optional temporal filtering parameters
-        # time passed as datetime string, need to convert
-        # to milliseconds since start of capture
-        start_time = request.GET.get("start_time", None)
-        end_time = request.GET.get("end_time", None)
-        if start_time:
-            start_time = self._datetime_string_to_milliseconds(start_time)
-
-        if end_time:
-            end_time = self._datetime_string_to_milliseconds(end_time)
+        temporal = self._parsed_temporal_bounds_for_file_list(request)
+        if isinstance(temporal, Response):
+            return temporal
+        start_time, end_time = temporal
 
         unsafe_path = request.GET.get("path", "/").strip()
         basename = Path(unsafe_path).name
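[Editor's note — illustrative sketch, not part of the patch.] The conversion rules implemented by `_iso8601_query_string_to_epoch_ms` above, restated as a standalone function so the naive-as-UTC and "Z"-as-UTC behaviour can be checked in isolation. Requires Python 3.11+ for `datetime.UTC`, matching the import added in this file.

from datetime import UTC, datetime


def iso8601_to_epoch_ms(value: str) -> int:
    """Parse an ISO 8601 string to UTC epoch milliseconds (naive values -> UTC)."""
    s = value.strip()
    if s and s[-1] in "Zz":
        s = s[:-1] + "+00:00"  # fromisoformat() on older Pythons rejects a bare "Z"
    parsed = datetime.fromisoformat(s)  # invalid input raises ValueError (mapped to 400 above)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=UTC)  # naive input is treated as UTC
    else:
        parsed = parsed.astimezone(UTC)  # aware input is normalized to UTC
    return int(parsed.timestamp() * 1000)


# Both spellings denote the same instant, so they map to the same epoch value:
assert iso8601_to_epoch_ms("2024-01-15T12:30:45Z") == 1_705_321_845_000
assert iso8601_to_epoch_ms("2024-01-15T12:30:45") == 1_705_321_845_000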
diff --git a/sdk/docs/mkdocs/changelog.md b/sdk/docs/mkdocs/changelog.md
index 2e502e4c..86bdd796 100644
--- a/sdk/docs/mkdocs/changelog.md
+++ b/sdk/docs/mkdocs/changelog.md
@@ -1,12 +1,20 @@
 # SpectrumX SDK Changelog
 
+## `0.1.19` - YYYY-MM-DD
+
++ Features:
+  + [**Added `start_time` and `end_time` parameters to `list_files` and `download`**](https://github.com/spectrumx/sds-code/pull/278): this lets users filter file listings and directory downloads for Digital RF captures to a time span within the capture bounds (similar to the time filtering in the SDS web UI).
+  + [**Added `capture_uuids` and `top_level_dirs` parameters to `download_dataset`**](https://github.com/spectrumx/sds-code/pull/278): this lets users download only selected captures from a dataset by passing their UUIDs or top-level directories. Run `list_dataset_captures` first to see the UUID and `top_level_dir` of each capture on the dataset.
+
++ Observability:
+  + [**Added `captures` and `files` attributes to `Dataset` model**](https://github.com/spectrumx/sds-code/pull/278): This allows visibility from the dataset side into attached captures and files.
+
 ## `0.1.18` - YYYY-MM-DD
 
 + Features:
   + [**Added `delete` and `revoke_share_permissions` methods for datasets**](https://github.com/spectrumx/sds-code/pull/275): this allows users to (soft) delete datasets in the SDS through the SDK and revoke ALL share permissions from datasets if needed before deletion or in general.
   + [**Added `revoke_share_permissions` and `detach_from_datasets` methods to captures**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to revoke share permissions or detach captures from connected datasets when they need to delete a capture.
   + [**Added `detach_from_datasets` methods to files**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to detach files from connected datasets when they need to delete them. Note: Files CANNOT be detached from captures. Delete the parent capture FIRST to delete the file.
-  + [**Added `start_time` and `end_time` parameters to `list_files` and `download`**](https://github.com/spectrumx/sds-code/pull/275): this gives users the ability to filter file directory downloads associated with DigitalRF captures based on a time span within the capture bounds (similar to time filtering in the web UI on SDS)
 
 + Observability:
   + [**Added additional fields displaying ownership, share permission, and asset connection information to SDK models**](https://github.com/spectrumx/sds-code/pull/275): this allows users to see more relevant information when retrieving or listing assets like whether they are shared, who they are shared with, who owns them, and what other assets the target is attached to (like files to captures and datasets).
@@ -14,7 +22,6 @@
     + New capture attributes: `owner`, `is_shared`, `is_shared_with_me`, `share_permissions`
     + New dataset attributes: `owner`, `is_shared`, `share_permissions`, `datasets`
   + [**Added new models for User and UserSharePermission**](https://github.com/spectrumx/sds-code/pull/275): This allows for visibility into the users and share permissions connected to assets.
-  + [**Added `captures` and `files` attributes to `Dataset` model**](https://github.com/spectrumx/sds-code/pull/275): This allows visibility from the dataset side into attached captures and files.
 
 ## `0.1.17` - 2025-12-20
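[Editor's note — illustrative sketch, not part of the patch.] How the `0.1.19` additions described above might look from the SDK side. Only the method and parameter names quoted in the changelog (`list_files`, `download`, `download_dataset`, `list_dataset_captures`, `start_time`, `end_time`, `capture_uuids`, `top_level_dirs`) come from this patch; the import path, client constructor, remaining argument names, and accepted value types are assumptions — consult the SDK docs for the real signatures.

from datetime import UTC, datetime

from spectrumx.client import Client  # module exists in this repo; class name assumed

sds = Client(host="sds.example.org")  # placeholder host
sds.authenticate()  # assumed auth flow

# Time-bounded listing/download of Digital RF files under a capture directory.
start = datetime(2024, 1, 15, 12, 30, tzinfo=UTC)
end = datetime(2024, 1, 15, 12, 45, tzinfo=UTC)
files = sds.list_files("/my-drf-capture/", start_time=start, end_time=end)
sds.download("/my-drf-capture/", "local-copy/", start_time=start, end_time=end)

# Download only selected captures from a dataset.
captures = sds.list_dataset_captures("<dataset-uuid>")  # argument name assumed
wanted = [c.uuid for c in captures if "channel-0" in str(c.top_level_dir)]
sds.download_dataset("<dataset-uuid>", capture_uuids=wanted)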