diff --git a/gateway/pyproject.toml b/gateway/pyproject.toml index 42663a3e9..a885bf39a 100644 --- a/gateway/pyproject.toml +++ b/gateway/pyproject.toml @@ -630,6 +630,11 @@ # Controls PLR0913 max-args = 9 + [tool.ruff.lint.per-file-ignores] + "sds_gateway/api_methods/tests/test_composite_capture_serialization.py" = [ + "PLR2004", + ] + [tool.ruff.format] indent-style = "space" line-ending = "auto" diff --git a/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py b/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py index 19b37be08..4cf055ba2 100644 --- a/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py +++ b/gateway/sds_gateway/api_methods/helpers/temporal_filtering.py @@ -22,36 +22,17 @@ def drf_rf_filename_from_ms(ms: int) -> str: return f"rf@{ms // 1000}.{ms % 1000:03d}.h5" -def _catch_capture_type_error(capture_type: CaptureType) -> None: +def _catch_value_errors(capture_type: CaptureType, capture: Capture) -> None: + if capture_type != CaptureType.DigitalRF: msg = "Only DigitalRF captures are supported for temporal filtering." log.error(msg) raise ValueError(msg) - -def _filter_capture_data_files_selection_bounds( - capture: Capture, - start_time: int, # relative ms from start of capture (from UI) - end_time: int, # relative ms from start of capture (from UI) -) -> QuerySet[File]: - """Filter the capture file selection bounds to the given start and end times.""" if capture.start_time is None: msg = f"Capture {capture.uuid} has no indexed start_time for temporal filtering" raise ValueError(msg) - epoch_start_ms = capture.start_time * 1000 - start_ms = epoch_start_ms + start_time - end_ms = epoch_start_ms + end_time - - start_file_name = drf_rf_filename_from_ms(start_ms) - end_file_name = drf_rf_filename_from_ms(end_ms) - - data_files = capture.get_drf_data_files_queryset() - return data_files.filter( - name__gte=start_file_name, - name__lte=end_file_name, - ).order_by("name") - def get_capture_files_with_temporal_filter( capture_type: CaptureType, @@ -60,24 +41,47 @@ def get_capture_files_with_temporal_filter( end_time: int | None = None, ) -> QuerySet[File]: """Get the capture files with temporal filtering.""" - _catch_capture_type_error(capture_type) + _catch_value_errors(capture_type, capture) + + capture_files = get_capture_files(capture) if start_time is None or end_time is None: log.warning( "Start or end time is None; returning all capture files without " "temporal filtering" ) - return get_capture_files(capture) + return capture_files - # get non-data files - non_data_files = get_capture_files(capture).exclude( - name__regex=DRF_RF_FILENAME_REGEX_STR - ) + epoch_start_ms = capture.start_time * 1000 + start_ms = epoch_start_ms + start_time + end_ms = epoch_start_ms + end_time - # get data files with temporal filtering - data_files = _filter_capture_data_files_selection_bounds( - capture, start_time, end_time + return filter_files_by_temporal_bounds( + capture_files, + start_ms, + end_ms, ) + +def filter_files_by_temporal_bounds( + files: QuerySet[File], + start_time: int, + end_time: int, +) -> QuerySet[File]: + """Filter files by temporal bounds.""" + + # get non-data files + non_data_files = files.exclude(name__regex=DRF_RF_FILENAME_REGEX_STR) + + unfiltered_data_files = files.filter(name__regex=DRF_RF_FILENAME_REGEX_STR) + + start_file_name = drf_rf_filename_from_ms(start_time) + end_file_name = drf_rf_filename_from_ms(end_time) + + filtered_data_files = unfiltered_data_files.filter( + name__gte=start_file_name, + name__lte=end_file_name, + 
).order_by("name") + # return all files - return non_data_files.union(data_files) + return non_data_files.union(filtered_data_files) diff --git a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py index f75a48c56..46266d88d 100644 --- a/gateway/sds_gateway/api_methods/serializers/capture_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/capture_serializers.py @@ -1,10 +1,13 @@ """Capture serializers for the SDS Gateway API methods.""" import logging +from datetime import UTC +from datetime import datetime from typing import Any from typing import cast from django.db.models import Sum +from django.utils import timezone as django_timezone from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from rest_framework.utils.serializer_helpers import ReturnList @@ -17,14 +20,57 @@ from sds_gateway.api_methods.models import ItemType from sds_gateway.api_methods.models import UserSharePermission from sds_gateway.api_methods.serializers.dataset_serializers import DatasetGetSerializer +from sds_gateway.api_methods.serializers.dataset_serializers import ( + DatasetSummarySerializer, +) from sds_gateway.api_methods.serializers.user_serializer import UserGetSerializer from sds_gateway.api_methods.serializers.user_serializer import ( UserSharePermissionSerializer, ) from sds_gateway.api_methods.utils.asset_access_control import check_if_shared +from sds_gateway.api_methods.utils.relationship_utils import get_capture_datasets from sds_gateway.api_methods.utils.relationship_utils import get_capture_files +def _epoch_sec_to_iso_utc_z(epoch_sec: int) -> str: + """Format OpenSearch epoch seconds as an ISO 8601 UTC string with ``Z`` suffix.""" + dt = datetime.fromtimestamp(epoch_sec, tz=UTC) + return dt.isoformat().replace("+00:00", "Z") + + +def _epoch_sec_to_local_display(epoch_sec: int) -> str: + """Human-readable local time (same pattern as ``formatted_created_at``).""" + dt = datetime.fromtimestamp(epoch_sec, tz=UTC) + return django_timezone.localtime(dt).strftime("%m/%d/%Y %I:%M:%S %p") + + +def _channel_row_bounds_from_os_meta(meta: dict[str, Any]) -> dict[str, Any]: + """Map OpenSearch metadata to composite channel serializer fields.""" + entry: dict[str, Any] = {} + start_sec = meta.get("start_time") + end_sec = meta.get("end_time") + entry["capture_start_epoch_sec"] = start_sec + entry["capture_end_epoch_sec"] = end_sec + entry["capture_start_iso_utc"] = ( + _epoch_sec_to_iso_utc_z(start_sec) if start_sec is not None else None + ) + entry["capture_end_iso_utc"] = ( + _epoch_sec_to_iso_utc_z(end_sec) if end_sec is not None else None + ) + entry["capture_start_display"] = ( + _epoch_sec_to_local_display(start_sec) if start_sec is not None else None + ) + entry["capture_end_display"] = ( + _epoch_sec_to_local_display(end_sec) if end_sec is not None else None + ) + if start_sec is None or end_sec is None: + entry["length_of_capture_ms"] = None + else: + entry["length_of_capture_ms"] = (end_sec - start_sec) * 1000 + entry["file_cadence_ms"] = meta.get("file_cadence") + return entry + + class FileCaptureListSerializer(serializers.ModelSerializer[File]): class Meta: model = File @@ -76,20 +122,32 @@ class CaptureGetSerializer(serializers.ModelSerializer[Capture]): owner = UserGetSerializer() share_permissions = serializers.SerializerMethodField() is_shared = serializers.SerializerMethodField() + is_shared_with_me = serializers.SerializerMethodField() capture_props = 
serializers.SerializerMethodField() files = serializers.SerializerMethodField() total_file_size = serializers.SerializerMethodField() data_files_info = serializers.SerializerMethodField() - datasets = DatasetGetSerializer(many=True) + datasets = serializers.SerializerMethodField() center_frequency_ghz = serializers.SerializerMethodField() sample_rate_mhz = serializers.SerializerMethodField() length_of_capture_ms = serializers.SerializerMethodField() file_cadence_ms = serializers.SerializerMethodField() capture_start_epoch_sec = serializers.SerializerMethodField() + capture_start_iso_utc = serializers.SerializerMethodField() + capture_end_iso_utc = serializers.SerializerMethodField() + capture_start_display = serializers.SerializerMethodField() + capture_end_display = serializers.SerializerMethodField() formatted_created_at = serializers.SerializerMethodField() capture_type_display = serializers.SerializerMethodField() post_processed_data = serializers.SerializerMethodField() + def get_datasets(self, capture: Capture) -> list[dict[str, Any]]: + """Datasets linked to this capture; shallow under dataset detail.""" + qs = get_capture_datasets(capture, include_deleted=False) + if self.context.get("omit_nested_dataset_graph"): + return DatasetSummarySerializer(qs, many=True, context=self.context).data + return DatasetGetSerializer(qs, many=True, context=self.context).data + def get_share_permissions(self, capture: Capture) -> list[UserSharePermission]: """Get the share permissions for the capture.""" user_share_permissions = UserSharePermission.objects.filter( @@ -100,6 +158,19 @@ def get_share_permissions(self, capture: Capture) -> list[UserSharePermission]: ) return UserSharePermissionSerializer(user_share_permissions, many=True).data + def get_is_shared_with_me(self, capture: Capture) -> bool: + """Get whether the capture is shared with the current user.""" + request = self.context.get("request") + if request and hasattr(request, "user"): + return UserSharePermission.objects.filter( + shared_with=request.user, + item_type=ItemType.CAPTURE, + item_uuid=capture.uuid, + is_enabled=True, + is_deleted=False, + ).exists() + return False + def get_is_shared(self, capture: Capture) -> bool: """Get whether the capture is shared. @@ -190,6 +261,38 @@ def get_capture_start_epoch_sec(self, capture: Capture) -> int | None: """Capture start as Unix epoch seconds. None if not in OpenSearch.""" return capture.start_time + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_iso_utc(self, capture: Capture) -> str | None: + """Indexed capture start as ISO 8601 UTC (``Z``). None if unavailable.""" + if capture.start_time is None: + return None + + return _epoch_sec_to_iso_utc_z(capture.start_time) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_iso_utc(self, capture: Capture) -> str | None: + """Indexed capture end as ISO 8601 UTC (``Z``). 
None if unavailable.""" + if capture.end_time is None: + return None + + return _epoch_sec_to_iso_utc_z(capture.end_time) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_display(self, capture: Capture) -> str | None: + """Indexed capture start in the active timezone for display.""" + if capture.start_time is None: + return None + + return _epoch_sec_to_local_display(capture.start_time) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_display(self, capture: Capture) -> str | None: + """Indexed capture end in the active timezone for display.""" + if capture.end_time is None: + return None + + return _epoch_sec_to_local_display(capture.end_time) + @extend_schema_field(serializers.DictField) def get_capture_props(self, capture: Capture) -> dict[str, Any]: """Retrieve the indexed metadata for the capture.""" @@ -356,6 +459,22 @@ class ChannelMetadataSerializer(serializers.Serializer): channel_metadata = serializers.DictField() +class CompositeChannelEntrySerializer(serializers.Serializer): + """One channel in a composite capture, including per-channel index bounds.""" + + channel = serializers.CharField() + uuid = serializers.UUIDField() + channel_metadata = serializers.DictField() + capture_start_epoch_sec = serializers.IntegerField(allow_null=True, required=False) + capture_end_epoch_sec = serializers.IntegerField(allow_null=True, required=False) + capture_start_iso_utc = serializers.CharField(allow_null=True, required=False) + capture_end_iso_utc = serializers.CharField(allow_null=True, required=False) + capture_start_display = serializers.CharField(allow_null=True, required=False) + capture_end_display = serializers.CharField(allow_null=True, required=False) + length_of_capture_ms = serializers.IntegerField(allow_null=True, required=False) + file_cadence_ms = serializers.IntegerField(allow_null=True, required=False) + + class CompositeCaptureSerializer(serializers.Serializer): """Serializer for composite captures that contain multiple channels.""" @@ -375,8 +494,8 @@ class CompositeCaptureSerializer(serializers.Serializer): is_shared = serializers.SerializerMethodField() owner = UserGetSerializer() - # Channel-specific fields - channels = serializers.ListField(child=ChannelMetadataSerializer()) + # Channel-specific fields (enriched with OpenSearch bounds per channel) + channels = serializers.SerializerMethodField() # Computed fields share_permissions = serializers.SerializerMethodField() @@ -387,6 +506,104 @@ class CompositeCaptureSerializer(serializers.Serializer): length_of_capture_ms = serializers.SerializerMethodField() file_cadence_ms = serializers.SerializerMethodField() capture_start_epoch_sec = serializers.SerializerMethodField() + capture_start_iso_utc = serializers.SerializerMethodField() + capture_end_iso_utc = serializers.SerializerMethodField() + capture_start_display = serializers.SerializerMethodField() + capture_end_display = serializers.SerializerMethodField() + + def _captures_bulk_by_uuid(self, obj: dict[str, Any]) -> dict[str, Capture]: + """One DB query per composite when ``include_serializer_aux`` was False.""" + key = str(obj.get("uuid", "")) + if not hasattr(self, "_captures_bulk_cache"): + self._captures_bulk_cache: dict[str, dict[str, Capture]] = {} + if key not in self._captures_bulk_cache: + uuids = [ch["uuid"] for ch in obj.get("channels") or []] + self._captures_bulk_cache[key] = ( + {str(c.uuid): c for c in Capture.objects.filter(uuid__in=uuids)} + if uuids + else {} + ) + return 
self._captures_bulk_cache[key] + + def _capture_for_channel( + self, obj: dict[str, Any], channel_entry: dict[str, Any] + ) -> Capture | None: + """Resolve Capture; prefer auxiliary map from build, otherwise bulk queryset.""" + by_uuid = obj.get("_captures_by_uuid") + uuid_key = str(channel_entry["uuid"]) + if isinstance(by_uuid, dict): + hit = cast("Capture | None", by_uuid.get(uuid_key)) + if hit is not None: + return hit + hit = self._captures_bulk_by_uuid(obj).get(uuid_key) + if hit is not None: + return hit + try: + return Capture.objects.get(uuid=channel_entry["uuid"]) + except Capture.DoesNotExist: + return None + + def _enriched_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: + """Per-channel rows with OpenSearch bounds (each channel may differ).""" + key = str(obj.get("uuid", "")) + if not hasattr(self, "_enriched_channels_cache"): + self._enriched_channels_cache: dict[str, list[dict[str, Any]]] = {} + if key not in self._enriched_channels_cache: + out: list[dict[str, Any]] = [] + for ch in obj.get("channels") or []: + entry: dict[str, Any] = { + "channel": ch["channel"], + "uuid": ch["uuid"], + "channel_metadata": ch.get("channel_metadata", {}), + } + pre_meta = cast( + "dict[str, Any] | None", + ch.get("_per_channel_os_meta"), + ) + if pre_meta is not None: + entry.update(_channel_row_bounds_from_os_meta(pre_meta)) + out.append(entry) + continue + capture = self._capture_for_channel(obj, ch) + if capture is None: + entry.update( + { + "capture_start_epoch_sec": None, + "capture_end_epoch_sec": None, + "capture_start_iso_utc": None, + "capture_end_iso_utc": None, + "capture_start_display": None, + "capture_end_display": None, + "length_of_capture_ms": None, + "file_cadence_ms": None, + } + ) + else: + # One OS round-trip per Capture instance via instance cache. 
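+                # Fallback: no "_per_channel_os_meta" was attached at build
+                # time (include_serializer_aux=False), so this channel's
+                # bounds are resolved directly from the capture's index doc.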
+ meta = capture.get_opensearch_metadata() + entry.update(_channel_row_bounds_from_os_meta(meta)) + out.append(entry) + self._enriched_channels_cache[key] = out + return self._enriched_channels_cache[key] + + def _composite_envelope_bounds( + self, + obj: dict[str, Any], + ) -> tuple[int, int] | None: + """Earliest channel start and latest channel end (seconds).""" + pairs = [ + (row["capture_start_epoch_sec"], row["capture_end_epoch_sec"]) + for row in self._enriched_channels(obj) + if row.get("capture_start_epoch_sec") is not None + and row.get("capture_end_epoch_sec") is not None + ] + if not pairs: + return None + return min(s for s, _ in pairs), max(e for _, e in pairs) + + @extend_schema_field(CompositeChannelEntrySerializer(many=True)) + def get_channels(self, obj: dict[str, Any]) -> list[dict[str, Any]]: + return self._enriched_channels(obj) def get_share_permissions(self, obj: dict[str, Any]) -> list[UserSharePermission]: """Get the share permissions for the composite capture.""" @@ -416,9 +633,10 @@ def get_is_shared(self, obj: dict[str, Any]) -> bool: def get_files(self, obj: dict[str, Any]) -> ReturnList[File]: """Get all files from all channels in the composite capture.""" all_files = [] - for channel_data in obj["channels"]: - capture_uuid = channel_data["uuid"] - capture = Capture.objects.get(uuid=capture_uuid) + for channel_data in obj.get("channels") or []: + capture = self._capture_for_channel(obj, channel_data) + if capture is None: + continue non_deleted_files = get_capture_files(capture, include_deleted=False) serializer = FileCaptureListSerializer( non_deleted_files, @@ -434,9 +652,10 @@ def get_total_file_size(self, obj: dict[str, Any]) -> int | None: return None total_size = 0 - for channel_data in obj["channels"]: - capture_uuid = channel_data["uuid"] - capture = Capture.objects.get(uuid=capture_uuid) + for channel_data in obj.get("channels") or []: + capture = self._capture_for_channel(obj, channel_data) + if capture is None: + continue all_files = get_capture_files(capture, include_deleted=False) result = all_files.aggregate(total_size=Sum("size")) total_size += result["total_size"] or 0 @@ -462,9 +681,10 @@ def get_data_files_info(self, obj: dict[str, Any]) -> dict[str, Any]: total_count = 0 total_size = 0 - for channel_data in obj["channels"]: - capture_uuid = channel_data["uuid"] - capture = Capture.objects.get(uuid=capture_uuid) + for channel_data in obj.get("channels") or []: + capture = self._capture_for_channel(obj, channel_data) + if capture is None: + continue stats = capture.get_drf_data_files_stats() total_count += stats["total_count"] total_size += stats["total_size"] @@ -487,42 +707,81 @@ def get_formatted_created_at(self, obj: dict[str, Any]) -> str: @extend_schema_field(serializers.IntegerField(allow_null=True)) def get_length_of_capture_ms(self, obj: dict[str, Any]) -> int | None: - """Use first channel's bounds for composite capture duration.""" - channels = obj.get("channels") or [] - if not channels: + """Span from earliest channel start to latest channel end (milliseconds).""" + bounds = self._composite_envelope_bounds(obj) + if bounds is None: return None - - capture = Capture.objects.get(uuid=channels[0]["uuid"]) - if capture.end_time is None or capture.start_time is None: - return None - return (capture.end_time - capture.start_time) * 1000 + start_time, end_time = bounds + return (end_time - start_time) * 1000 @extend_schema_field(serializers.IntegerField(allow_null=True)) def get_file_cadence_ms(self, obj: dict[str, Any]) -> int | None: - 
"""Use first channel's file cadence for composite capture.""" - channels = obj.get("channels") or [] - if not channels: + """Mean file cadence across channels (each channel may differ).""" + cadences = [ + row["file_cadence_ms"] + for row in self._enriched_channels(obj) + if row.get("file_cadence_ms") is not None + ] + if not cadences: return None - - capture = Capture.objects.get(uuid=channels[0]["uuid"]) - return capture.file_cadence + return round(sum(cadences) / len(cadences)) @extend_schema_field(serializers.IntegerField(allow_null=True)) def get_capture_start_epoch_sec(self, obj: dict[str, Any]) -> int | None: - """Use first channel's start time for composite capture.""" - channels = obj.get("channels") or [] - if not channels: + """Earliest indexed start among channels (epoch seconds).""" + bounds = self._composite_envelope_bounds(obj) + if bounds is None: return None + start_time, _ = bounds + return start_time - capture = Capture.objects.get(uuid=channels[0]["uuid"]) - return capture.start_time + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_iso_utc(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + start_sec, _ = bounds + return _epoch_sec_to_iso_utc_z(start_sec) + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_iso_utc(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + _, end_sec = bounds + return _epoch_sec_to_iso_utc_z(end_sec) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_start_display(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + start_sec, _ = bounds + return _epoch_sec_to_local_display(start_sec) -def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]: + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_capture_end_display(self, obj: dict[str, Any]) -> str | None: + bounds = self._composite_envelope_bounds(obj) + if bounds is None: + return None + _, end_sec = bounds + return _epoch_sec_to_local_display(end_sec) + + +def build_composite_capture_data( + captures: list[Capture], + *, + include_serializer_aux: bool = False, +) -> dict[str, Any]: """Build composite capture data from a list of captures with the same top_level_dir. Args: captures: List of Capture objects to combine into composite + include_serializer_aux: When True, attach non-public fields used only by + :class:`CompositeCaptureSerializer`: per-channel cached OpenSearch + metadata (one search per capture) and a Capture map to avoid duplicate + ORM lookups. Keep False for raw API payloads (capture list/search, + nested dataset captures). 
Returns: dict: Composite capture data structure @@ -538,20 +797,25 @@ def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]: base_capture = captures[0] # Build channel data with metadata - channels = [] + captures_by_uuid: dict[str, Capture] | None = {} if include_serializer_aux else None + channels: list[dict[str, Any]] = [] for capture in captures: - channel_data = { + if captures_by_uuid is not None: + captures_by_uuid[str(capture.uuid)] = capture + channel_data: dict[str, Any] = { "channel": capture.channel, "uuid": capture.uuid, "channel_metadata": retrieve_indexed_metadata(capture), } + if include_serializer_aux: + channel_data["_per_channel_os_meta"] = capture.get_opensearch_metadata() channels.append(channel_data) # Serialize the owner field owner_serializer = UserGetSerializer(base_capture.owner) # Build composite data - return { + composite: dict[str, Any] = { "uuid": base_capture.uuid, # Use first capture's UUID as composite UUID "capture_type": base_capture.capture_type, "capture_type_display": base_capture.get_capture_type_display(), @@ -567,6 +831,9 @@ def build_composite_capture_data(captures: list[Capture]) -> dict[str, Any]: "owner": owner_serializer.data, "channels": channels, } + if captures_by_uuid is not None: + composite["_captures_by_uuid"] = captures_by_uuid + return composite def serialize_capture_or_composite( @@ -585,7 +852,10 @@ def serialize_capture_or_composite( if capture_data["is_composite"]: # Serialize as composite - composite_data = build_composite_capture_data(capture_data["captures"]) + composite_data = build_composite_capture_data( + capture_data["captures"], + include_serializer_aux=True, + ) serializer = CompositeCaptureSerializer(composite_data, context=context) return serializer.data # Serialize as single capture diff --git a/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py b/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py index 231aea286..744dbc935 100644 --- a/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/dataset_serializers.py @@ -2,19 +2,43 @@ from rest_framework import serializers +from sds_gateway.api_methods.helpers.search_captures import ( + group_captures_by_top_level_dir, +) from sds_gateway.api_methods.models import Dataset from sds_gateway.api_methods.models import ItemType from sds_gateway.api_methods.models import PermissionLevel from sds_gateway.api_methods.models import UserSharePermission +from sds_gateway.api_methods.serializers.capture_serializers import ( + build_composite_capture_data, +) +from sds_gateway.api_methods.serializers.capture_serializers import ( + serialize_capture_or_composite, +) +from sds_gateway.api_methods.serializers.file_serializers import ( + FileArtifactSummarySerializer, +) +from sds_gateway.api_methods.serializers.user_serializer import UserGetSerializer from sds_gateway.api_methods.serializers.user_serializer import ( UserSharePermissionSerializer, ) from sds_gateway.api_methods.utils.asset_access_control import check_if_shared +from sds_gateway.api_methods.utils.relationship_utils import get_dataset_artifact_files +from sds_gateway.api_methods.utils.relationship_utils import get_dataset_captures READABLE_ISO_DATE_TIME: str = "%Y-%m-%d %H:%M:%S%z" +class DatasetSummarySerializer(serializers.ModelSerializer[Dataset]): + """Minimal dataset shape for capture ``datasets`` (breaks serializer cycles).""" + + class Meta: + model = Dataset + fields = ["uuid", "name", "version", 
"status", "is_public"] + + class DatasetGetSerializer(serializers.ModelSerializer[Dataset]): + owner = UserGetSerializer(read_only=True) authors = serializers.SerializerMethodField() keywords = serializers.SerializerMethodField() created_at = serializers.DateTimeField( @@ -26,6 +50,8 @@ class DatasetGetSerializer(serializers.ModelSerializer[Dataset]): status_display = serializers.CharField(source="get_status_display", read_only=True) shared_users = serializers.SerializerMethodField() share_permissions = serializers.SerializerMethodField() + captures = serializers.SerializerMethodField() + files = serializers.SerializerMethodField() owner_name = serializers.SerializerMethodField() owner_email = serializers.SerializerMethodField() permission_level = serializers.SerializerMethodField() @@ -131,6 +157,51 @@ def get_share_permissions(self, obj): ) return UserSharePermissionSerializer(user_share_permissions, many=True).data + def get_files(self, obj: Dataset) -> list[dict]: + """Get the files for the dataset. + + Returns: + A list of serialized file objects + """ + non_deleted_files = get_dataset_artifact_files( + obj, + include_deleted=False, + ) + serializer = FileArtifactSummarySerializer( + non_deleted_files, + many=True, + context=self.context, + ) + return serializer.data + + def get_captures(self, obj: Dataset) -> list[dict]: + """Captures for the dataset (one row per logical capture, list API semantics). + + Multi-channel uploads share ``top_level_dir``; those rows are merged into a + single composite payload like :func:`get_composite_captures`. + """ + non_deleted_captures = get_dataset_captures( + obj, + include_deleted=False, + ) + grouped = group_captures_by_top_level_dir(list(non_deleted_captures)) + composite_captures: list[dict] = [] + capture_context = { + **(self.context or {}), + "omit_nested_dataset_graph": True, + } + for capture_list in grouped.values(): + if len(capture_list) > 1: + composite_captures.append(build_composite_capture_data(capture_list)) + else: + composite_captures.append( + serialize_capture_or_composite( + capture_list[0], + context=capture_context, + ) + ) + return composite_captures + def get_is_shared(self, obj): """Check if the dataset is shared.""" return check_if_shared(obj.uuid, ItemType.DATASET) diff --git a/gateway/sds_gateway/api_methods/serializers/file_serializers.py b/gateway/sds_gateway/api_methods/serializers/file_serializers.py index d8e7b9350..343fb7012 100644 --- a/gateway/sds_gateway/api_methods/serializers/file_serializers.py +++ b/gateway/sds_gateway/api_methods/serializers/file_serializers.py @@ -7,6 +7,7 @@ from loguru import logger as log from rest_framework import serializers +from sds_gateway.api_methods.models import Capture from sds_gateway.api_methods.models import File from sds_gateway.api_methods.serializers.capture_serializers import CaptureGetSerializer from sds_gateway.api_methods.serializers.dataset_serializers import DatasetGetSerializer @@ -18,6 +19,20 @@ CONFLICT = 409 +class FileArtifactSummarySerializer(serializers.ModelSerializer[File]): + """Subset of file fields for nested dataset payloads (avoids recursive graphs).""" + + class Meta: + model = File + fields = ( + "uuid", + "name", + "directory", + "media_type", + "size", + ) + + class FileGetSerializer(serializers.ModelSerializer[File]): owner = UserGetSerializer() datasets = DatasetGetSerializer(many=True) @@ -37,6 +52,31 @@ class Meta: model = File fields = "__all__" + def to_representation(self, instance: File) -> dict[str, Any]: + """Build ``captures`` from 
M2M plus legacy ``capture`` FK when needed.""" + data = super().to_representation(instance) + merged: list[Capture] = [] + seen_pks: set[uuid.UUID] = set() + for cap in instance.captures.all().order_by("uuid"): + pk = cap.pk + if pk not in seen_pks: + merged.append(cap) + seen_pks.add(pk) + legacy = instance.capture + if legacy is not None and legacy.pk not in seen_pks: + merged.append(legacy) + seen_pks.add(legacy.pk) + context = self.context + data["captures"] = CaptureGetSerializer( + merged, + many=True, + context=context, + ).data + data["capture"] = ( + CaptureGetSerializer(merged[0], context=context).data if merged else None + ) + return data + class FilePostSerializer(serializers.ModelSerializer[File]): class Meta: diff --git a/gateway/sds_gateway/api_methods/tasks.py b/gateway/sds_gateway/api_methods/tasks.py index 440b7dfb7..1d77fee9b 100644 --- a/gateway/sds_gateway/api_methods/tasks.py +++ b/gateway/sds_gateway/api_methods/tasks.py @@ -1316,7 +1316,7 @@ def _get_item_files( end_time=end_time, ) else: - if start_time is not None or end_time is not None: + if start_time or end_time: log.warning( "Temporal filtering is only supported for DigitalRF captures, " "ignoring start_time and end_time" diff --git a/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py b/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py index a3a10c026..f922a8aaf 100644 --- a/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py +++ b/gateway/sds_gateway/api_methods/tests/test_celery_tasks.py @@ -1238,9 +1238,9 @@ def test_large_file_download_redirects_to_sdk(self): def test_get_item_files_with_temporal_bounds_returns_expected_rf_subset(self): """ Task-level test: start_time/end_time flow into _get_item_files. - For DigitalRF captures, ``get_capture_files_with_temporal_filter`` returns - non-DRF capture files (metadata) plus DRF files in the selected time range - (temporal_filtering details are unit-tested in test_temporal_filtering.py). + For DigitalRF captures, ``get_capture_files_with_temporal_filter`` (and + ``filter_files_by_temporal_bounds``) returns non-DRF capture files (metadata) + plus DRF files in the selected time range (see test_temporal_filtering.py). """ # Create DRF-named files for self.capture (epoch 1s..5s) epoch_start_sec = 1 diff --git a/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py new file mode 100644 index 000000000..4482c0a6c --- /dev/null +++ b/gateway/sds_gateway/api_methods/tests/test_composite_capture_serialization.py @@ -0,0 +1,197 @@ +"""Tests for composite (multi-channel) capture serialization. + +Focused on per-channel metadata: indexed ``channel_metadata`` and OpenSearch-derived +bounds/cadence must stay distinct after ``CompositeCaptureSerializer`` runs, and +top-level summary fields must reflect the envelope across channels. 
+""" + +from __future__ import annotations + +from unittest.mock import patch + +from django.contrib.auth import get_user_model +from django.test import TestCase + +from sds_gateway.api_methods.models import Capture +from sds_gateway.api_methods.models import CaptureType +from sds_gateway.api_methods.serializers.capture_serializers import ( + CompositeCaptureSerializer, +) +from sds_gateway.api_methods.serializers.capture_serializers import ( + _epoch_sec_to_iso_utc_z, +) +from sds_gateway.api_methods.serializers.capture_serializers import ( + build_composite_capture_data, +) +from sds_gateway.api_methods.views.capture_endpoints import _normalize_top_level_dir + +User = get_user_model() + + +def _two_drf_captures_same_group() -> tuple[Capture, Capture]: + """Two DRF captures sharing ``top_level_dir`` (multi-channel group).""" + user = User.objects.create( + email="composite-ser@example.com", + password="testpassword", # noqa: S106 + is_approved=True, + ) + top = _normalize_top_level_dir("test-composite-serialization-group") + cap0 = Capture.objects.create( + capture_type=CaptureType.DigitalRF, + channel="ch0", + index_name="captures-test-drf", + owner=user, + top_level_dir=top, + ) + cap1 = Capture.objects.create( + capture_type=CaptureType.DigitalRF, + channel="ch1", + index_name="captures-test-drf", + owner=user, + top_level_dir=top, + ) + return cap0, cap1 + + +class CompositeCaptureSerializationTests(TestCase): + """Serializer-level tests with OpenSearch and index helpers mocked.""" + + def test_distinct_channel_metadata_preserved(self) -> None: + """Per-channel indexed metadata payloads stay on each channel row.""" + cap0, cap1 = _two_drf_captures_same_group() + + def fake_retrieve(capture: Capture) -> dict: + return { + "channel_key": capture.channel, + "capture_pk": str(capture.uuid), + } + + with ( + patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + side_effect=fake_retrieve, + ), + patch.object(Capture, "get_opensearch_metadata", return_value={}), + ): + composite = build_composite_capture_data( + [cap0, cap1], + include_serializer_aux=True, + ) + + out = CompositeCaptureSerializer(composite, context={}).data + + ch_rows = {row["channel"]: row for row in out["channels"]} + assert ch_rows["ch0"]["channel_metadata"] == { + "channel_key": "ch0", + "capture_pk": str(cap0.uuid), + } + assert ch_rows["ch1"]["channel_metadata"] == { + "channel_key": "ch1", + "capture_pk": str(cap1.uuid), + } + # Ensure we did not collapse to a single shared dict object + assert ( + ch_rows["ch0"]["channel_metadata"] is not ch_rows["ch1"]["channel_metadata"] + ) + + def test_distinct_opensearch_times_cadence_and_envelope(self) -> None: + """Per-channel bounds/cadence; top-level uses min start and max end.""" + cap0, cap1 = _two_drf_captures_same_group() + + meta_by_uuid = { + str(cap0.uuid): { + "start_time": 1_700_000_000, + "end_time": 1_700_000_100, + "file_cadence": 400, + }, + str(cap1.uuid): { + "start_time": 1_700_000_050, + "end_time": 1_700_000_200, + "file_cadence": 800, + }, + } + + def opensearch_by_instance(self: Capture) -> dict: + return dict(meta_by_uuid[str(self.uuid)]) + + with ( + patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + return_value={}, + ), + patch.object( + Capture, + "get_opensearch_metadata", + opensearch_by_instance, + ), + ): + composite = build_composite_capture_data( + [cap0, cap1], + include_serializer_aux=True, + ) + + out = CompositeCaptureSerializer(composite, 
context={}).data + + ch_rows = {row["channel"]: row for row in out["channels"]} + assert ch_rows["ch0"]["capture_start_epoch_sec"] == 1_700_000_000 + assert ch_rows["ch0"]["capture_end_epoch_sec"] == 1_700_000_100 + assert ch_rows["ch0"]["length_of_capture_ms"] == 100_000 + assert ch_rows["ch0"]["file_cadence_ms"] == 400 + + assert ch_rows["ch1"]["capture_start_epoch_sec"] == 1_700_000_050 + assert ch_rows["ch1"]["capture_end_epoch_sec"] == 1_700_000_200 + assert ch_rows["ch1"]["length_of_capture_ms"] == 150_000 + assert ch_rows["ch1"]["file_cadence_ms"] == 800 + + assert out["capture_start_epoch_sec"] == 1_700_000_000 + # Composite serializer exposes end time via ISO/display, not epoch field + assert out["capture_end_iso_utc"] == _epoch_sec_to_iso_utc_z(1_700_000_200) + assert out["length_of_capture_ms"] == 200_000 + assert out["file_cadence_ms"] == 600 + + def test_channel_with_incomplete_bounds_excluded_from_envelope(self) -> None: + """Incomplete channel bounds are excluded from the composite envelope.""" + cap0, cap1 = _two_drf_captures_same_group() + + def opensearch_by_instance(self: Capture) -> dict: + if self.uuid == cap0.uuid: + return { + "start_time": 1_800_000_000, + "end_time": 1_800_000_030, + "file_cadence": 100, + } + return { + "start_time": 1_800_000_010, + "end_time": None, + "file_cadence": 200, + } + + with ( + patch( + "sds_gateway.api_methods.serializers.capture_serializers" + ".retrieve_indexed_metadata", + return_value={}, + ), + patch.object( + Capture, + "get_opensearch_metadata", + opensearch_by_instance, + ), + ): + composite = build_composite_capture_data( + [cap0, cap1], + include_serializer_aux=True, + ) + + out = CompositeCaptureSerializer(composite, context={}).data + + ch_rows = {row["channel"]: row for row in out["channels"]} + assert ch_rows["ch0"]["length_of_capture_ms"] == 30_000 + assert ch_rows["ch1"]["capture_end_epoch_sec"] is None + assert ch_rows["ch1"]["length_of_capture_ms"] is None + # Envelope from only the complete channel + assert out["capture_start_epoch_sec"] == 1_800_000_000 + assert out["capture_end_iso_utc"] == _epoch_sec_to_iso_utc_z(1_800_000_030) + assert out["length_of_capture_ms"] == 30_000 diff --git a/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py b/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py index 500be2613..2c1f8201a 100644 --- a/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py +++ b/gateway/sds_gateway/api_methods/tests/test_dataset_endpoints.py @@ -117,6 +117,18 @@ def _cleanup_dataset_connections(self): if permission.pk: permission.delete() + def test_retrieve_dataset_success(self): + """GET dataset detail returns metadata, captures, and artifact files.""" + url = reverse("api:datasets-detail", kwargs={"pk": self.dataset.uuid}) + response = self.client.get(url) + assert response.status_code == status.HTTP_200_OK + body = response.json() + assert body["uuid"] == str(self.dataset.uuid) + assert "captures" in body + assert "files" in body + assert isinstance(body["captures"], list) + assert isinstance(body["files"], list) + def test_get_dataset_files_success(self): """Test successful dataset files manifest retrieval.""" # Create test files associated with the dataset with MinIO mocking @@ -158,6 +170,7 @@ def test_get_dataset_files_success(self): assert "size" in file_info assert "media_type" in file_info assert file_info["capture"] is None + assert file_info["captures"] == [] def test_get_dataset_files_with_owned_captures(self): """Test dataset files manifest including 
files from owned captures.""" @@ -216,6 +229,12 @@ def test_get_dataset_files_with_owned_captures(self): for file_info in capture_files: assert file_info["capture"]["uuid"] == str(capture.uuid) assert file_info["capture"]["name"] == capture.name + assert file_info["captures"] + assert any(c["uuid"] == str(capture.uuid) for c in file_info["captures"]) + + artifact_only = [f for f in results if f["capture"] is None] + assert len(artifact_only) == 1 + assert artifact_only[0]["captures"] == [] def test_get_dataset_files_with_shared_captures(self): """Test dataset files manifest including files from shared captures.""" @@ -290,6 +309,12 @@ def test_get_dataset_files_with_shared_captures(self): for file_info in capture_files: assert file_info["capture"]["uuid"] == str(capture.uuid) assert file_info["capture"]["name"] == capture.name + assert file_info["captures"] + assert any(c["uuid"] == str(capture.uuid) for c in file_info["captures"]) + + artifact_only = [f for f in results if f["capture"] is None] + assert len(artifact_only) == 1 + assert artifact_only[0]["captures"] == [] def test_get_dataset_files_with_both_owned_and_shared_captures(self): """Test dataset files manifest with both owned and shared captures.""" @@ -390,6 +415,10 @@ def test_get_dataset_files_with_both_owned_and_shared_captures(self): for file_info in owned_capture_files: assert file_info["capture"]["uuid"] == str(owned_capture.uuid) assert file_info["capture"]["name"] == owned_capture.name + assert file_info["captures"] + assert any( + c["uuid"] == str(owned_capture.uuid) for c in file_info["captures"] + ) # Verify shared capture file info structure shared_capture_files = [ @@ -403,6 +432,92 @@ def test_get_dataset_files_with_both_owned_and_shared_captures(self): for file_info in shared_capture_files: assert file_info["capture"]["uuid"] == str(shared_capture.uuid) assert file_info["capture"]["name"] == shared_capture.name + assert file_info["captures"] + assert any( + c["uuid"] == str(shared_capture.uuid) for c in file_info["captures"] + ) + + artifact_only = [f for f in results if f["capture"] is None] + assert len(artifact_only) == 1 + assert artifact_only[0]["captures"] == [] + + def test_get_dataset_files_filter_by_capture_query(self): + """``capture`` query param restricts the manifest server-side.""" + capture = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch", + index_name="ix", + name="cap", + ) + self.created_captures.append(capture) + with MockMinIOContext(b"c"): + f_cap_a = create_file_with_minio_mock( + file_content=b"c", owner=self.user, capture=capture + ) + f_cap_b = create_file_with_minio_mock( + file_content=b"c", owner=self.user, capture=capture + ) + f_art = create_file_with_minio_mock( + file_content=b"c", owner=self.user, dataset=self.dataset + ) + self.created_files.extend([f_cap_a, f_cap_b, f_art]) + base_url = reverse("api:datasets-files", kwargs={"pk": self.dataset.uuid}) + response = self.client.get( + base_url, + {"capture": str(capture.uuid)}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["count"] == self.EXPECTED_CAPTURE_FILES + for row in data["results"]: + assert any(c["uuid"] == str(capture.uuid) for c in row["captures"]) + + def test_get_dataset_files_filter_by_capture_invalid_uuid(self): + """Invalid ``capture`` UUID returns 400.""" + with MockMinIOContext(b"x"): + f = create_file_with_minio_mock( + file_content=b"x", owner=self.user, dataset=self.dataset + ) + 
self.created_files.append(f) + base_url = reverse("api:datasets-files", kwargs={"pk": self.dataset.uuid}) + response = self.client.get(base_url, {"capture": "not-a-uuid"}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_get_dataset_files_filter_by_top_level_dir(self): + """``top_level_dir`` filters by ``File.directory`` prefix.""" + tld = "/pytest/capture/root" + capture = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch", + index_name="ix", + name="cap", + top_level_dir=tld, + ) + self.created_captures.append(capture) + with MockMinIOContext(b"c"): + f_match = create_file_with_minio_mock( + file_content=b"c", + owner=self.user, + capture=capture, + directory=f"{tld}/channel0/", + ) + f_other = create_file_with_minio_mock( + file_content=b"c", owner=self.user, dataset=self.dataset + ) + self.created_files.extend([f_match, f_other]) + base_url = reverse("api:datasets-files", kwargs={"pk": self.dataset.uuid}) + response = self.client.get( + base_url, + {"top_level_dir": tld}, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["count"] == 1 + assert data["results"][0]["uuid"] == str(f_match.uuid) def test_get_dataset_files_not_found(self): """Test dataset files manifest with non-existent UUID.""" diff --git a/gateway/sds_gateway/api_methods/tests/test_dataset_manifest_filters.py b/gateway/sds_gateway/api_methods/tests/test_dataset_manifest_filters.py new file mode 100644 index 000000000..18cb8a821 --- /dev/null +++ b/gateway/sds_gateway/api_methods/tests/test_dataset_manifest_filters.py @@ -0,0 +1,37 @@ +"""Tests for dataset manifest query parsing and path normalization.""" + +import uuid + +import pytest +from rest_framework.request import Request +from rest_framework.test import APIRequestFactory + +from sds_gateway.api_methods.utils.dataset_manifest_filters import ( + normalize_top_level_dir_prefix, +) +from sds_gateway.api_methods.utils.dataset_manifest_filters import ( + parse_capture_uuid_query, +) + + +def test_normalize_top_level_dir_prefix() -> None: + assert normalize_top_level_dir_prefix("foo/bar") == "/foo/bar" + assert normalize_top_level_dir_prefix("/foo/bar/") == "/foo/bar" + assert normalize_top_level_dir_prefix("/") == "/" + + +def test_parse_capture_uuid_query_accepts_comma_separated() -> None: + u1, u2 = uuid.uuid4(), uuid.uuid4() + factory = APIRequestFactory() + req = factory.get("/x", {"capture": f"{u1},{u2}"}) + drf_request = Request(req) + parsed = parse_capture_uuid_query(drf_request) + assert parsed == [u1, u2] + + +def test_parse_capture_uuid_query_invalid_raises() -> None: + factory = APIRequestFactory() + req = factory.get("/x", {"capture": "not-a-uuid"}) + drf_request = Request(req) + with pytest.raises(ValueError, match="Invalid capture UUID"): + parse_capture_uuid_query(drf_request) diff --git a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py index 002ff7b7b..7ee5add9f 100644 --- a/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py +++ b/gateway/sds_gateway/api_methods/tests/test_file_endpoints.py @@ -3,6 +3,8 @@ import time import uuid from collections.abc import Mapping +from datetime import UTC +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING from typing import Any @@ -201,6 +203,7 @@ def test_retrieve_latest_file_not_accessible_404(self) -> None: self.list_url, data={"path": self.file.directory}, 
) + assert response.data.get("warnings") == [] results = response.data.get("results") assert len(results) == 0, f"Expected no files to be returned, got {results}" @@ -265,6 +268,7 @@ def test_retrieve_latest_file_accessible_200(self) -> None: f"Expected 200, got {response.status_code}" ) response = response.data + assert response.get("warnings") == [] assert "results" in response, ( f"Expected a paginated response with 'results', got: {response}" ) @@ -441,6 +445,82 @@ def test_download_file_no_access_403(self): # Verify 403 response assert response.status_code == status.HTTP_403_FORBIDDEN + def test_list_files_with_temporal_params(self) -> None: + """Temporal params keep non-RF files; RF listings respect time bounds.""" + base_sec = 1_000_000 + for offset in (0, 1, 2, 5): + create_db_file( + owner=self.user, + extras={ + "directory": self.sds_path, + "name": f"rf@{base_sec + offset}.000.h5", + "media_type": "application/x-hdf5", + }, + ) + + path = str(self.file.directory) + # Absolute epoch ms from ISO datetimes; bounds include rf@(base+1)..rf@(base+2). + start_iso = datetime.fromtimestamp(base_sec + 1, tz=UTC).isoformat() + end_iso = datetime.fromtimestamp(base_sec + 2, tz=UTC).isoformat() + response = self.client.get( + self.list_url, + { + "path": path, + "start_time": start_iso, + "end_time": end_iso, + }, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["warnings"] == [] + names = {row["name"] for row in data["results"]} + assert self.file.name in names + assert {f"rf@{base_sec + 1}.000.h5", f"rf@{base_sec + 2}.000.h5"} <= names + assert f"rf@{base_sec}.000.h5" not in names + assert f"rf@{base_sec + 5}.000.h5" not in names + + def test_list_files_invalid_temporal_param_returns_400(self) -> None: + """Malformed start_time yields 400 instead of 500.""" + response = self.client.get( + self.list_url, + { + "path": "/", + "start_time": "not-a-datetime", + "end_time": "2020-01-02T00:00:00", + }, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "start_time" in response.json()["detail"].lower() + + def test_list_files_includes_warnings_key(self) -> None: + """Paginated list responses always include ``warnings`` (possibly empty).""" + response = self.client.get(self.list_url) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "warnings" in data + assert data["warnings"] == [] + + def test_list_files_directory_temporal_params_without_rf_data_includes_warning( + self, + ) -> None: + """start_time/end_time on dirs with no RF data: full listing + warning.""" + path = str(self.file.directory) + response = self.client.get( + self.list_url, + { + "path": path, + "start_time": "2020-01-01T00:00:00", + "end_time": "2020-01-02T00:00:00", + }, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert "warnings" in data + assert len(data["warnings"]) == 1 + assert "RF data" in data["warnings"][0] + assert data["count"] >= 1 + assert len(data["results"]) >= 1 + def test_list_files_includes_shared_files(self) -> None: """Test that list files includes files from shared captures/datasets.""" @@ -483,6 +563,7 @@ def test_list_files_includes_shared_files(self) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() + assert data.get("warnings") == [] # Should have the original file plus the shared one expected_count = 2 # Original file + shared file assert data["count"] == expected_count, ( @@ -592,6 +673,7 @@ def 
test_disabled_share_permission_blocks_file_access(self) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() + assert data.get("warnings") == [] # Should only have the original file, not the shared one assert data["count"] == 1, ( f"Expected 1 file (excluding disabled shared), got {data['count']}" diff --git a/gateway/sds_gateway/api_methods/tests/test_file_get_serializer_captures.py b/gateway/sds_gateway/api_methods/tests/test_file_get_serializer_captures.py new file mode 100644 index 000000000..5f3b7e028 --- /dev/null +++ b/gateway/sds_gateway/api_methods/tests/test_file_get_serializer_captures.py @@ -0,0 +1,95 @@ +"""Tests for FileGetSerializer ``capture`` / ``captures`` merge (SDK-facing shape).""" + +from unittest.mock import patch + +from django.test import TestCase + +from sds_gateway.api_methods.models import Capture +from sds_gateway.api_methods.models import CaptureType +from sds_gateway.api_methods.models import File +from sds_gateway.api_methods.serializers.file_serializers import FileGetSerializer +from sds_gateway.api_methods.tests.factories import DatasetFactory +from sds_gateway.api_methods.tests.factories import UserFactory +from sds_gateway.api_methods.tests.test_file_endpoints import create_db_file + + +class FileGetSerializerCapturesTestCase(TestCase): + """``to_representation`` merges M2M ``captures`` with legacy ``capture`` FK.""" + + def setUp(self) -> None: + self.user = UserFactory() + self.dataset = DatasetFactory(owner=self.user) + self.capture_a = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch-a", + index_name="ix-a", + name="cap-a", + ) + self.capture_b = Capture.objects.create( + owner=self.user, + dataset=self.dataset, + capture_type=CaptureType.DigitalRF, + channel="ch-b", + index_name="ix-b", + name="cap-b", + ) + self.opensearch_patcher = patch( + "sds_gateway.api_methods.helpers.index_handling.retrieve_indexed_metadata", + return_value={}, + ) + self.opensearch_patcher.start() + + def tearDown(self) -> None: + self.opensearch_patcher.stop() + File.objects.filter(owner=self.user).delete() + Capture.objects.filter(owner=self.user).delete() + self.dataset.delete() + self.user.delete() + + def _serialize(self, file_obj: File) -> dict: + return FileGetSerializer(file_obj).data + + def test_legacy_fk_only_populates_both_capture_and_captures(self) -> None: + f = create_db_file(owner=self.user) + f.capture = self.capture_a + f.save(update_fields=["capture"]) + data = self._serialize(f) + assert data["capture"] is not None + assert data["capture"]["uuid"] == str(self.capture_a.uuid) + assert len(data["captures"]) == 1 + assert data["captures"][0]["uuid"] == str(self.capture_a.uuid) + + def test_m2m_only_populates_both(self) -> None: + f = create_db_file(owner=self.user) + f.captures.add(self.capture_a) + data = self._serialize(f) + assert data["capture"] is not None + assert data["capture"]["uuid"] == str(self.capture_a.uuid) + assert len(data["captures"]) == 1 + assert data["captures"][0]["uuid"] == str(self.capture_a.uuid) + + def test_fk_and_m2m_same_capture_deduplicated(self) -> None: + f = create_db_file(owner=self.user) + f.capture = self.capture_a + f.save(update_fields=["capture"]) + f.captures.add(self.capture_a) + data = self._serialize(f) + assert len(data["captures"]) == 1 + assert data["captures"][0]["uuid"] == str(self.capture_a.uuid) + assert data["capture"]["uuid"] == str(self.capture_a.uuid) + + def 
test_m2m_two_captures_fk_none_first_in_uuid_order(self) -> None: + f = create_db_file(owner=self.user) + f.captures.add(self.capture_b, self.capture_a) + data = self._serialize(f) + ordered = sorted([str(self.capture_a.uuid), str(self.capture_b.uuid)]) + assert [c["uuid"] for c in data["captures"]] == ordered + assert data["capture"]["uuid"] == ordered[0] + + def test_no_capture_links_both_null_or_empty(self) -> None: + f = create_db_file(owner=self.user) + data = self._serialize(f) + assert data["capture"] is None + assert data["captures"] == [] diff --git a/gateway/sds_gateway/api_methods/utils/dataset_manifest_filters.py b/gateway/sds_gateway/api_methods/utils/dataset_manifest_filters.py new file mode 100644 index 000000000..21ccc34fa --- /dev/null +++ b/gateway/sds_gateway/api_methods/utils/dataset_manifest_filters.py @@ -0,0 +1,68 @@ +"""Database filters for dataset file manifest (download) listings.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from uuid import UUID + +from django.db.models import Q +from django.db.models import QuerySet + +if TYPE_CHECKING: + from rest_framework.request import Request + + from sds_gateway.api_methods.models import File + + +def normalize_top_level_dir_prefix(raw: str) -> str: + """Normalize a capture ``top_level_dir`` / path prefix for URL and DB matching.""" + s = raw.strip().replace("\\", "/") + if not s.startswith("/"): + s = f"/{s}" + return s.rstrip("/") or "/" + + +def parse_capture_uuid_query(request: Request) -> list[UUID]: + """Parse ``capture`` query params (repeat or comma-separated) into UUIDs.""" + vals: list[str] = [] + for item in request.query_params.getlist("capture"): + vals.extend(p.strip() for p in item.split(",") if p.strip()) + out: list[UUID] = [] + for s in vals: + try: + out.append(UUID(s)) + except ValueError as err: + msg = f"Invalid capture UUID: {s!r}" + raise ValueError(msg) from err + return out + + +def parse_top_level_dir_query(request: Request) -> list[str]: + """Parse ``top_level_dir`` query params into normalized path prefixes.""" + vals: list[str] = [] + for item in request.query_params.getlist("top_level_dir"): + vals.extend(p.strip() for p in item.split(",") if p.strip()) + return [normalize_top_level_dir_prefix(v) for v in vals] + + +def filter_dataset_files_queryset( + qs: QuerySet[File], + *, + capture_uuids: list[UUID], + top_level_dir_prefixes: list[str], +) -> QuerySet[File]: + """Restrict manifest files; capture UUID and path filters are OR-ed (SDK parity).""" + if not capture_uuids and not top_level_dir_prefixes: + return qs + parts: list[Q] = [] + if capture_uuids: + parts.append( + Q(captures__uuid__in=capture_uuids) | Q(capture__uuid__in=capture_uuids) + ) + if top_level_dir_prefixes: + qd = Q() + for pre in top_level_dir_prefixes: + qd |= Q(directory=pre) | Q(directory__startswith=f"{pre}/") + parts.append(qd) + q = parts[0] if len(parts) == 1 else parts[0] | parts[1] + return qs.filter(q).distinct() diff --git a/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py b/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py index 71cb295cc..5abb30df7 100644 --- a/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py +++ b/gateway/sds_gateway/api_methods/utils/swagger_example_schema.py @@ -201,6 +201,7 @@ "count": 105, "next": "http://localhost:8000/api/latest/assets/files/?page=2&page_size=3", "previous": None, + "warnings": [], "results": [ { "bucket_name": "spectrumx", diff --git a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py 
b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py
index 21a558645..806c18b88 100644
--- a/gateway/sds_gateway/api_methods/views/dataset_endpoints.py
+++ b/gateway/sds_gateway/api_methods/views/dataset_endpoints.py
@@ -17,11 +17,21 @@
 from sds_gateway.api_methods.models import File
 from sds_gateway.api_methods.models import ItemType
 from sds_gateway.api_methods.models import user_has_access_to_item
+from sds_gateway.api_methods.serializers.dataset_serializers import DatasetGetSerializer
 from sds_gateway.api_methods.serializers.file_serializers import FileGetSerializer
 from sds_gateway.api_methods.utils.asset_access_control import check_if_shared
 from sds_gateway.api_methods.utils.asset_access_control import (
     revoke_share_permissions as revoke_item_share_permissions,
 )
+from sds_gateway.api_methods.utils.dataset_manifest_filters import (
+    filter_dataset_files_queryset,
+)
+from sds_gateway.api_methods.utils.dataset_manifest_filters import (
+    parse_capture_uuid_query,
+)
+from sds_gateway.api_methods.utils.dataset_manifest_filters import (
+    parse_top_level_dir_query,
+)
 from sds_gateway.api_methods.utils.relationship_utils import (
     get_dataset_files_including_captures,
 )
@@ -37,6 +47,60 @@ def _get_file_objects(self, dataset: Dataset) -> QuerySet[File]:
         """Get all files associated with a dataset."""
         return get_dataset_files_including_captures(dataset, include_deleted=False)
 
+    @extend_schema(
+        parameters=[
+            OpenApiParameter(
+                name="id",
+                description="Dataset UUID",
+                required=True,
+                type=str,
+                location=OpenApiParameter.PATH,
+            ),
+        ],
+        responses={
+            200: OpenApiResponse(
+                description=("Dataset metadata, captures, and direct artifact files"),
+            ),
+            403: OpenApiResponse(description="Forbidden"),
+            404: OpenApiResponse(description="Not Found"),
+        },
+        description=(
+            "Return dataset metadata with captures (one row per logical capture, "
+            "including composite multi-channel) and artifact files on the dataset."
+        ),
+        summary="Retrieve Dataset",
+    )
+    def retrieve(self, request: Request, pk: str | None = None) -> Response:
+        """Return serialized dataset including captures and direct (artifact) files."""
+        if pk is None:
+            return Response(
+                {"detail": "Dataset UUID is required."},
+                status=status.HTTP_400_BAD_REQUEST,
+            )
+
+        target_dataset = get_object_or_404(
+            Dataset,
+            pk=pk,
+            is_deleted=False,
+        )
+
+        assert isinstance(request.user, User), (
+            "Expected request.user to be an instance of the custom User model"
+        )
+        if not user_has_access_to_item(
+            request.user, target_dataset.uuid, ItemType.DATASET
+        ):
+            return Response(
+                {"detail": "You do not have permission to access this dataset."},
+                status=status.HTTP_403_FORBIDDEN,
+            )
+
+        serializer = DatasetGetSerializer(
+            target_dataset,
+            context={"request": request},
+        )
+        return Response(serializer.data)
+
     @extend_schema(
         parameters=[
             OpenApiParameter(
@@ -62,6 +126,27 @@ def _get_file_objects(self, dataset: Dataset) -> QuerySet[File]:
                 location=OpenApiParameter.QUERY,
                 default=FilePagination.page_size,
             ),
+            OpenApiParameter(
+                name="capture",
+                description=(
+                    "Only include files linked to this capture UUID "
+                    "(repeat param or comma-separated). OR-combined with top_level_dir."
+                ),
+                required=False,
+                type=str,
+                location=OpenApiParameter.QUERY,
+            ),
+            OpenApiParameter(
+                name="top_level_dir",
+                description=(
+                    "Only include files whose directory is this path or under it "
+                    "(repeat param or comma-separated). Normalized like capture "
+                    "top_level_dir."
+ ), + required=False, + type=str, + location=OpenApiParameter.QUERY, + ), ], responses={ 200: OpenApiResponse(description="Dataset file listing"), @@ -113,8 +198,30 @@ def get_dataset_files(self, request: Request, pk: str | None = None) -> Response status=status.HTTP_404_NOT_FOUND, ) - # Order and deduplicate files by path and created_at - ordered_files = dataset_files.order_by("-created_at") + try: + capture_uuids = parse_capture_uuid_query(request) + except ValueError as err: + return Response( + {"detail": str(err)}, + status=status.HTTP_400_BAD_REQUEST, + ) + top_level_dir_prefixes = parse_top_level_dir_query(request) + if capture_uuids or top_level_dir_prefixes: + dataset_files = filter_dataset_files_queryset( + dataset_files, + capture_uuids=capture_uuids, + top_level_dir_prefixes=top_level_dir_prefixes, + ) + + # Order and deduplicate files by path and created_at; avoid N+1 on captures + ordered_files = ( + dataset_files.order_by("-created_at") + .select_related( + "capture", + "owner", + ) + .prefetch_related("captures", "datasets") + ) paginator = FilePagination() paginated_files = paginator.paginate_queryset(ordered_files, request=request) diff --git a/gateway/sds_gateway/api_methods/views/file_endpoints.py b/gateway/sds_gateway/api_methods/views/file_endpoints.py index 625d92fc7..fa554ab5d 100644 --- a/gateway/sds_gateway/api_methods/views/file_endpoints.py +++ b/gateway/sds_gateway/api_methods/views/file_endpoints.py @@ -1,12 +1,16 @@ """File operations endpoints for the SDS Gateway API.""" +from datetime import UTC +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING +from typing import Any from typing import cast from django.db.models import CharField from django.db.models import F as FExpression from django.db.models import ProtectedError +from django.db.models import QuerySet from django.db.models import Value as WrappedValue from django.db.models.functions import Concat from django.http import HttpResponse @@ -29,6 +33,10 @@ import sds_gateway.api_methods.utils.swagger_example_schema as example_schema from sds_gateway.api_methods.authentication import APIKeyAuthentication from sds_gateway.api_methods.helpers.download_file import download_file +from sds_gateway.api_methods.helpers.temporal_filtering import ( + filter_files_by_temporal_bounds, +) +from sds_gateway.api_methods.models import DRF_RF_FILENAME_REGEX_STR from sds_gateway.api_methods.models import File from sds_gateway.api_methods.serializers.file_serializers import ( FileCheckResponseSerializer, @@ -59,6 +67,17 @@ class FileViewSet(ViewSet): authentication_classes = [APIKeyAuthentication] permission_classes = [IsAuthenticated] + @staticmethod + def _paginated_list_response( + paginator: FilePagination, + serializer_data: Any, + warnings: list[str], + ) -> Response: + """Build paginated file list JSON with an always-present ``warnings`` key.""" + response = paginator.get_paginated_response(serializer_data) + response.data["warnings"] = warnings + return response + @extend_schema( request=FilePostSerializer, responses={ @@ -195,6 +214,66 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: serializer = FileGetSerializer(target_file, many=False) return Response(serializer.data) + def _iso8601_query_string_to_epoch_ms( + self, param_name: str, datetime_string: str + ) -> int: + """Parse ``start_time`` / ``end_time`` query strings to Unix epoch milliseconds. + + Naive datetimes are treated as UTC (same convention as the Python SDK). + ``Z`` is accepted as UTC.
Raises ``ValueError`` with a client-safe message if + the string is not a parseable ISO 8601 datetime. + """ + s = datetime_string.strip() + if len(s) >= 1 and s[-1] in "Zz": + s = s[:-1] + "+00:00" + try: + parsed = datetime.fromisoformat(s) + except ValueError as exc: + msg = ( + f"Invalid {param_name}: expected ISO 8601 datetime " + "(e.g. 2024-01-15T12:30:45+00:00 or naive as UTC)." + ) + raise ValueError(msg) from exc + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=UTC) + else: + parsed = parsed.astimezone(UTC) + return int(parsed.timestamp() * 1000) + + def _parsed_temporal_bounds_for_file_list( + self, request: Request + ) -> Response | tuple[int | None, int | None]: + """Parse ``start_time`` / ``end_time`` for :meth:`list`, or a 400 response.""" + start_raw = request.GET.get("start_time") or None + end_raw = request.GET.get("end_time") or None + start_ms: int | None = None + end_ms: int | None = None + if start_raw: + try: + start_ms = self._iso8601_query_string_to_epoch_ms( + "start_time", cast("str", start_raw) + ) + except ValueError as err: + return Response( + {"detail": str(err)}, + status=status.HTTP_400_BAD_REQUEST, + ) + if end_raw: + try: + end_ms = self._iso8601_query_string_to_epoch_ms( + "end_time", cast("str", end_raw) + ) + except ValueError as err: + return Response( + {"detail": str(err)}, + status=status.HTTP_400_BAD_REQUEST, + ) + return start_ms, end_ms + + def _check_files_includes_rf_data(self, files: QuerySet[File]) -> bool: + """Checks if the files include RF data.""" + return files.filter(name__regex=DRF_RF_FILENAME_REGEX_STR).exists() + @extend_schema( responses={ 200: FileGetSerializer, @@ -220,6 +299,28 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: "or an exact file match (directory + name)." ), ), + OpenApiParameter( + name="start_time", + type=OpenApiTypes.DATETIME, + location=OpenApiParameter.QUERY, + required=False, + description=( + "ISO 8601 datetime; parsed to UTC epoch milliseconds for temporal " + "filtering of RF data files when the listing includes Digital RF " + "``.h5`` data. Timezone-aware inputs are converted to UTC; naive " + "values are interpreted as UTC. Invalid values return 400." + ), + ), + OpenApiParameter( + name="end_time", + type=OpenApiTypes.DATETIME, + location=OpenApiParameter.QUERY, + required=False, + description=( + "ISO 8601 datetime; paired with start_time for RF temporal bounds " + "(same UTC / naive-as-UTC rules as start_time)." + ), + ), OpenApiParameter( name="page", type=OpenApiTypes.INT, @@ -238,7 +339,7 @@ def retrieve(self, request: Request, pk: str | None = None) -> Response: ), ], ) - def list(self, request: Request) -> Response: + def list(self, request: Request) -> Response: # noqa: C901 """ Lists all files accessible to the user (owned + shared via captures/datasets). When `path` is passed, it filters all files matching that subdirectory. @@ -246,6 +347,14 @@ def list(self, request: Request) -> Response: it will retrieve the most recent one that matches that path. Wildcards are not yet supported. 
""" + # warnings to be returned in the response + warnings = [] + + temporal = self._parsed_temporal_bounds_for_file_list(request) + if isinstance(temporal, Response): + return temporal + start_time, end_time = temporal + unsafe_path = request.GET.get("path", "/").strip() basename = Path(unsafe_path).name @@ -259,7 +368,7 @@ def list(self, request: Request) -> Response: all_valid_user_files, request=request ) serializer = FileGetSerializer(paginated_files, many=True) - return paginator.get_paginated_response(serializer.data) + return self._paginated_list_response(paginator, serializer.data, warnings) # For specific paths, use the existing path-based filtering logic user_rel_path = sanitize_path_rel_to_user( @@ -296,7 +405,9 @@ def list(self, request: Request) -> Response: serializer = FileGetSerializer(paginated_files, many=True) # despite being a single result, we return it paginated for consistency - return paginator.get_paginated_response(serializer.data) + return self._paginated_list_response( + paginator, serializer.data, warnings + ) log.debug( "No exact match found for " f"{inferred_user_rel_path!s} and name {basename}", @@ -307,6 +418,28 @@ def list(self, request: Request) -> Response: directory__startswith=str(user_rel_path), ) + if self._check_files_includes_rf_data(files_matching_dir): + if start_time and end_time: + files_matching_dir = filter_files_by_temporal_bounds( + files_matching_dir, + start_time, + end_time, + ) + elif start_time or end_time: + msg = ( + "Both start_time and end_time are required for temporal filtering " + "when listing Digital RF data." + ) + log.warning(msg) + warnings.append(msg) + elif start_time or end_time: + msg = ( + "Temporal filtering is only supported " + "for file directories that include RF data" + ) + log.warning(msg) + warnings.append(msg) + files_matching_dir = files_matching_dir.annotate( path=Concat( FExpression("directory"), @@ -334,7 +467,7 @@ def list(self, request: Request) -> Response: f"user files for path {user_rel_path!s} - returning {len(serializer.data)}", ) - return paginator.get_paginated_response(serializer.data) + return self._paginated_list_response(paginator, serializer.data, warnings) @extend_schema( parameters=[ diff --git a/gateway/sds_gateway/users/views/downloads.py b/gateway/sds_gateway/users/views/downloads.py index 8c15584bf..c989c75b3 100644 --- a/gateway/sds_gateway/users/views/downloads.py +++ b/gateway/sds_gateway/users/views/downloads.py @@ -163,7 +163,7 @@ def _validate_time_range( start_time: int | None, end_time: int | None ) -> JsonResponse | None: """Return 400 JsonResponse if both provided and start >= end; else None.""" - if start_time is not None and end_time is not None and start_time >= end_time: + if start_time and end_time and start_time >= end_time: return JsonResponse( { "success": False, diff --git a/sdk/docs/mkdocs/changelog.md b/sdk/docs/mkdocs/changelog.md index a20a42adb..86bdd7969 100644 --- a/sdk/docs/mkdocs/changelog.md +++ b/sdk/docs/mkdocs/changelog.md @@ -1,5 +1,14 @@ # SpectrumX SDK Changelog +## `0.1.19` - YYYY-MM-DD + ++ Features: + + [**Added `start_time` and `end_time` parameters to `list_files` and `download`**](https://github.com/spectrumx/sds-code/pull/278): this gives users the ability to filter file directory downloads associated with DigitalRF captures based on a time span within the capture bounds (similar to time filtering in the web UI on SDS) + + [**Added `capture_uuids` and `top_level_dirs` parameters to 
`download_dataset`**](https://github.com/spectrumx/sds-code/pull/278): this lets users restrict a dataset download to specific capture UUIDs or directories associated with the dataset. Run `list_dataset_captures` to see the UUID and `top_level_dir` of each capture on the dataset before choosing what to download. + ++ Observability: + + [**Added `captures` and `files` attributes to `Dataset` model**](https://github.com/spectrumx/sds-code/pull/278): This makes the captures and files attached to a dataset visible from the dataset side. + ## `0.1.18` - YYYY-MM-DD + Features: @@ -12,7 +21,6 @@ + New file attributes: `owner`, `captures`, `datasets` + New capture attributes: `owner`, `is_shared`, `is_shared_with_me`, `share_permissions` + New dataset attributes: `owner`, `is_shared`, `share_permissions`, `datasets` - + [**Added new models for User and UserSharePermission**](https://github.com/spectrumx/sds-code/pull/275): This allows for visibility into the users and share permissions connected to assets. ## `0.1.17` - 2025-12-20 diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index a1c60c8bd..cc0b9db95 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -285,9 +285,13 @@ max-args = 9 [tool.ruff.lint.per-file-ignores] - "*.ipynb" = ["E501", "ERA001", "PLR2004", "T201"] + "*.ipynb" = ["E501", "ERA001", "PLR2004", "T201"] + # Monorepo: pre-commit passes `gateway/...` from repo root; local runs may use `../gateway/...` from sdk/. + "**/test_composite_capture_serialization.py" = ["PLR2004"] + "../gateway/sds_gateway/api_methods/utils/metadata_schemas.py" = ["PLC0415"] "gateway/sds_gateway/api_methods/utils/metadata_schemas.py" = ["PLC0415"] - "gateway/sds_gateway/users/views.py" = ["PLC0415"] + "../gateway/sds_gateway/users/views.py" = ["PLC0415"] + "gateway/sds_gateway/users/views.py" = ["PLC0415"] [tool.ruff.format] indent-style = "space" diff --git a/sdk/src/spectrumx/api/datasets.py b/sdk/src/spectrumx/api/datasets.py index 519cdd48c..60cd1266c 100644 --- a/sdk/src/spectrumx/api/datasets.py +++ b/sdk/src/spectrumx/api/datasets.py @@ -2,15 +2,21 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING +from typing import Any from loguru import logger as log +from spectrumx.models.datasets import Dataset from spectrumx.models.files import File from spectrumx.ops.pagination import Paginator from spectrumx.utils import log_user if TYPE_CHECKING: + from collections.abc import Collection + from pathlib import Path + from pathlib import PurePosixPath from uuid import UUID from spectrumx.gateway import GatewayClient @@ -32,26 +38,93 @@ def __init__( self.gateway = gateway self.verbose = verbose + def get(self, dataset_uuid: UUID) -> Dataset: + """Load dataset metadata, captures, and artifact files from SDS. + + Captures are returned in the same grouped shape as the capture list API + (one entry per logical multi-channel capture where applicable). To enumerate every + file in the dataset (including capture-linked files), use :meth:`get_files` + instead, which calls the paginated dataset files manifest endpoint. + """ + if self.dry_run: + log_user("Dry run enabled: returning an empty Dataset shell") + return Dataset(uuid=dataset_uuid) + + raw = self.gateway.get_dataset( + dataset_uuid=dataset_uuid, + verbose=self.verbose, + ) + return Dataset.model_validate_json(raw) + + def list_captures(self, dataset_uuid: UUID) -> list[dict[str, Any]]: + """Return capture payloads linked to the dataset (raw JSON objects).
+ + Use this when you need composite capture fields (for example ``channels``) + without coercing through :class:`~spectrumx.models.datasets.DatasetCapture`. + """ + if self.dry_run: + log_user("Dry run enabled: returning an empty capture list") + return [] + + raw = self.gateway.get_dataset( + dataset_uuid=dataset_uuid, + verbose=self.verbose, + ) + data = json.loads(raw) + captures = data.get("captures") + return list(captures) if isinstance(captures, list) else [] + + def list_artifact_files(self, dataset_uuid: UUID) -> list[dict[str, Any]]: + """Return file rows linked directly to the dataset (artifacts), as JSON dicts. + + These are the same objects embedded on :meth:`get` under the ``files`` key. + For the full downloadable manifest (captures plus artifacts), use + :meth:`get_files`. + """ + if self.dry_run: + log_user("Dry run enabled: returning an empty artifact file list") + return [] + + raw = self.gateway.get_dataset( + dataset_uuid=dataset_uuid, + verbose=self.verbose, + ) + data = json.loads(raw) + files = data.get("files") + return list(files) if isinstance(files, list) else [] + def get_files( self, dataset_uuid: UUID, + *, + capture_uuids: Collection[UUID] | None = None, + top_level_dirs: Collection[PurePosixPath | Path | str] | None = None, ) -> Paginator[File]: """Get files in the dataset as a paginator. Args: dataset_uuid: The UUID of the dataset to get files for. + capture_uuids: If set, passed to the gateway to restrict by capture UUID (OR + with ``top_level_dirs``). + top_level_dirs: If set, passed to the gateway as path prefixes under + ``File.directory`` (OR with ``capture_uuids``). Returns: A paginator for the files in the dataset. """ if self.dry_run: log_user("Dry run enabled: files will be simulated") - # Create a paginator that fetches from the dataset files endpoint + list_kwargs: dict[str, Any] = {"dataset_uuid": dataset_uuid} + if capture_uuids is not None: + list_kwargs["capture_uuids"] = tuple(capture_uuids) + if top_level_dirs is not None: + list_kwargs["top_level_dirs"] = tuple(str(p) for p in top_level_dirs) + pagination: Paginator[File] = Paginator( Entry=File, gateway=self.gateway, list_method=self.gateway.get_dataset_files, - list_kwargs={"dataset_uuid": dataset_uuid}, + list_kwargs=list_kwargs, dry_run=self.dry_run, verbose=self.verbose, ) diff --git a/sdk/src/spectrumx/api/sds_files.py b/sdk/src/spectrumx/api/sds_files.py index 9839cdb8b..5cb541bed 100644 --- a/sdk/src/spectrumx/api/sds_files.py +++ b/sdk/src/spectrumx/api/sds_files.py @@ -6,6 +6,8 @@ import os import tempfile import uuid +from datetime import UTC +from datetime import datetime from enum import Enum from enum import auto from multiprocessing.synchronize import RLock @@ -27,6 +29,12 @@ log.trace("Placeholder log to avoid reimporting or resolving unused import warnings.") +def _file_list_time_query_param(value: datetime) -> str: + """Format a datetime for Gateway file list temporal query params (ISO 8601, UTC).""" + value = value.replace(tzinfo=UTC) if value.tzinfo is None else value.astimezone(UTC) + return value.isoformat() + + class FileUploadMode(Enum): """Modes for uploading files to SDS.""" @@ -109,23 +117,41 @@ def list_files( *, client: Client, sds_path: PurePosixPath | Path | str, + start_time: datetime | None = None, + end_time: datetime | None = None, verbose: bool = False, ) -> Paginator[File]: """Lists files in a given SDS path. Args: sds_path: The virtual directory on SDS to list files from. 
+ start_time: When set, lower bound for Digital RF data files (UTC-aligned ISO + sent to the API). Must be used together with ``end_time``. + end_time: Upper bound for Digital RF data files, paired with ``start_time``. Returns: A paginator for the files in the given SDS path. """ + if (start_time is None) ^ (end_time is None): + msg = "start_time and end_time must both be set or both omitted." + raise ValueError(msg) sds_path = PurePosixPath(sds_path) + start_q: str | None = ( + _file_list_time_query_param(start_time) if start_time else None + ) + end_q: str | None = ( + _file_list_time_query_param(end_time) if end_time is not None else None + ) if client.dry_run: log_user("Dry run enabled: files will be simulated") pagination: Paginator[File] = Paginator( gateway=client._gateway, Entry=File, list_method=client._gateway.list_files, - list_kwargs={"sds_path": sds_path}, + list_kwargs={ + "sds_path": sds_path, + "start_time": start_q, + "end_time": end_q, + }, dry_run=client.dry_run, verbose=verbose, ) diff --git a/sdk/src/spectrumx/client.py b/sdk/src/spectrumx/client.py index dbee54d52..b6717f586 100644 --- a/sdk/src/spectrumx/client.py +++ b/sdk/src/spectrumx/client.py @@ -1,6 +1,8 @@ """Client for the SpectrumX Data System.""" +from collections.abc import Collection from collections.abc import Mapping +from datetime import datetime from pathlib import Path from pathlib import PurePosixPath from typing import Any @@ -17,6 +19,7 @@ from spectrumx.errors import process_upload_results from spectrumx.models.captures import Capture from spectrumx.models.captures import CaptureType +from spectrumx.models.datasets import Dataset from spectrumx.ops.pagination import Paginator from . import __version__ @@ -31,6 +34,47 @@ from .utils import log_user_warning +def _normalize_top_level_dir_prefix( + top_level_dir: PurePosixPath | Path | str, +) -> str: + s = str(top_level_dir).strip().replace("\\", "/") + if not s.startswith("/"): + s = f"/{s}" + return s.rstrip("/") or "/" + + +def _resolve_dataset_capture_filter_params( + *, + capture_uuids: Collection[UUID4 | str] | None, + top_level_dirs: Collection[PurePosixPath | Path | str] | None, + dry_run: bool, +) -> tuple[bool, set[UUID] | None, list[str] | None]: + """Return (filter_active, uuid_set, dir_prefixes_norm) for manifest filtering.""" + filter_by_capture = capture_uuids is not None + filter_by_dir = top_level_dirs is not None + filter_active = filter_by_capture or filter_by_dir + if not filter_active: + return False, None, None + if dry_run: + log_user( + "Dry run: capture_uuids / top_level_dirs filters are ignored " + "(simulated manifest files are not capture-tagged)." 
+ ) + return False, None, None + uuid_set: set[UUID] | None = None + dir_prefixes_norm: list[str] | None = None + if filter_by_capture: + uuid_set = {UUID(str(u)) for u in capture_uuids or ()} + if filter_by_dir: + dir_prefixes_norm = [ + _normalize_top_level_dir_prefix(d) for d in top_level_dirs or () + ] + return True, uuid_set, dir_prefixes_norm + + +DownloadFileSource = list[File] | Paginator[File] + + class Client: """Instantiates an SDS client.""" @@ -206,8 +250,10 @@ def download( self, *, from_sds_path: PurePosixPath | Path | str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, to_local_path: Path | str, - files_to_download: list[File] | Paginator[File] | None = None, + files_to_download: DownloadFileSource | None = None, skip_contents: bool = False, overwrite: bool = False, verbose: bool = True, @@ -216,6 +262,10 @@ def download( Args: from_sds_path: The virtual directory on SDS to download files from. + start_time: The start time to filter DRF capture files by + (optional, only applicable to DRF capture files). + end_time: The end time to filter DRF capture files by + (optional, only applicable to DRF capture files). to_local_path: The local path to save the downloaded files to. files_to_download: A paginator or list (in dry run mode) of files to download. If not provided, all files in the directory will be @@ -241,6 +291,8 @@ def download( self._prepare_download_directory(to_local_path) files_to_download = self._get_files_to_download( from_sds_path=from_sds_path, + start_time=start_time, + end_time=end_time, files_to_download=files_to_download, verbose=verbose, ) @@ -266,9 +318,11 @@ def _get_files_to_download( self, *, from_sds_path: PurePosixPath | None, - files_to_download: list[File] | Paginator[File] | None, + start_time: datetime | None = None, + end_time: datetime | None = None, + files_to_download: DownloadFileSource | None, verbose: bool, - ) -> list[File] | Paginator[File]: + ) -> DownloadFileSource: """Get the list of files to download.""" if self.dry_run: log_user( @@ -279,7 +333,11 @@ def _get_files_to_download( log_user(f"Dry run: discovered {len(files_to_download)} files (samples)") else: if from_sds_path is not None: - files_to_download = self.list_files(sds_path=from_sds_path) + files_to_download = self.list_files( + sds_path=from_sds_path, + start_time=start_time, + end_time=end_time, + ) elif files_to_download is None: error_msg = ( "Either a path in the SDS or a paginator/list of files " @@ -294,7 +352,7 @@ def _get_files_to_download( def _download_files( self, *, - files_to_download: list[File] | Paginator[File], + files_to_download: DownloadFileSource, to_local_path: Path, skip_contents: bool, overwrite: bool, @@ -396,18 +454,30 @@ def _is_path_relative_to_target( return local_file_path.is_relative_to(to_local_path) def list_files( - self, sds_path: PurePosixPath | Path | str, *, verbose: bool = False + self, + sds_path: PurePosixPath | Path | str, + start_time: datetime | None = None, + end_time: datetime | None = None, + *, + verbose: bool = False, ) -> Paginator[File]: """Lists files in a given SDS path. Args: sds_path: The virtual directory on SDS to list files from. + start_time: Optional inclusive lower time bound for ``rf@*.h5`` files + (naive datetimes are treated as UTC). Requires ``end_time``. + end_time: Optional inclusive upper time bound; requires ``start_time``. verbose: Show network requests and other info. Returns: A paginator for the files in the given SDS path. 
""" return self._sds_files.list_files( - client=self, sds_path=sds_path, verbose=verbose + client=self, + sds_path=sds_path, + start_time=start_time, + end_time=end_time, + verbose=verbose, ) def download_file( @@ -451,11 +521,13 @@ def download_dataset( *, dataset_uuid: UUID4 | str, to_local_path: Path | str, + capture_uuids: Collection[UUID4 | str] | None = None, + top_level_dirs: Collection[PurePosixPath | Path | str] | None = None, skip_contents: bool = False, overwrite: bool = False, verbose: bool = True, ) -> list[Result[File]]: - """Downloads all files in a dataset using the existing download infrastructure. + """Downloads files in a dataset using the existing download infrastructure. This approach uses the get_dataset_files endpoint to get a paginated list of File objects and then uses the existing download() method with @@ -464,19 +536,64 @@ def download_dataset( Args: dataset_uuid: The UUID of the dataset to download. to_local_path: The local path to save the downloaded files to. + capture_uuids: If set, only files linked to at least one of these capture + UUIDs are downloaded (dataset artifacts with no capture are excluded). + top_level_dirs: If set, only files whose ``directory`` lies under one of + these capture ``top_level_dir`` paths (as from + :meth:`list_dataset_captures`) are included. Leading/trailing slashes + are normalized. skip_contents: When True, only the metadata is downloaded. overwrite: Whether to overwrite existing local files. verbose: Show progress bars and detailed output. + + If both ``capture_uuids`` and ``top_level_dirs`` are set, a file is included + when it matches **either** criterion (enforced on the gateway). In dry run + mode, capture/path filters are ignored because simulated files are not + capture-tagged. + Returns: A list of results for each file downloaded. 
""" if isinstance(dataset_uuid, str): dataset_uuid = UUID(dataset_uuid) - # Get all files in the dataset as a paginator - files_to_download = self.datasets.get_files(dataset_uuid=dataset_uuid) + ( + filter_active, + uuid_set, + dir_prefixes_norm, + ) = _resolve_dataset_capture_filter_params( + capture_uuids=capture_uuids, + top_level_dirs=top_level_dirs, + dry_run=self.dry_run, + ) - if verbose: + if filter_active: + uuid_nonempty = bool(uuid_set) + dir_nonempty = bool(dir_prefixes_norm) + if not uuid_nonempty and not dir_nonempty: + files_to_download: DownloadFileSource = [] + elif uuid_nonempty and dir_nonempty: + files_to_download = self.datasets.get_files( + dataset_uuid, + capture_uuids=tuple(uuid_set), + top_level_dirs=tuple(dir_prefixes_norm), + ) + elif uuid_nonempty: + files_to_download = self.datasets.get_files( + dataset_uuid, + capture_uuids=tuple(uuid_set), + ) + else: + files_to_download = self.datasets.get_files( + dataset_uuid, + top_level_dirs=tuple(dir_prefixes_norm), + ) + if verbose and (uuid_nonempty or dir_nonempty): + log_user("Dataset capture filter active (applied on the gateway).") + else: + files_to_download = self.datasets.get_files(dataset_uuid=dataset_uuid) + + if verbose and not filter_active: log_user( f"Downloading files from dataset " f"(total: {len(files_to_download)} files)" @@ -499,6 +616,26 @@ def download_dataset( verbose=verbose, ) + def get_dataset(self, dataset_uuid: UUID4 | str) -> Dataset: + """Fetch dataset metadata, captures, and artifact files from SDS.""" + if isinstance(dataset_uuid, str): + dataset_uuid = UUID(dataset_uuid) + return self.datasets.get(dataset_uuid) + + def list_dataset_captures(self, dataset_uuid: UUID4 | str) -> list[dict[str, Any]]: + """List captures for a dataset (raw dicts; composite payloads supported).""" + if isinstance(dataset_uuid, str): + dataset_uuid = UUID(dataset_uuid) + return self.datasets.list_captures(dataset_uuid) + + def list_dataset_artifact_files( + self, dataset_uuid: UUID4 | str + ) -> list[dict[str, Any]]: + """List dataset artifact files (not the full download manifest).""" + if isinstance(dataset_uuid, str): + dataset_uuid = UUID(dataset_uuid) + return self.datasets.list_artifact_files(dataset_uuid) + def _upload_deprecated( self, *, diff --git a/sdk/src/spectrumx/gateway.py b/sdk/src/spectrumx/gateway.py index c275bd21d..5a27816fd 100644 --- a/sdk/src/spectrumx/gateway.py +++ b/sdk/src/spectrumx/gateway.py @@ -2,6 +2,7 @@ import json import uuid +from collections.abc import Collection from collections.abc import Iterator from enum import StrEnum from http import HTTPStatus @@ -263,21 +264,33 @@ def list_files( sds_path: PurePosixPath | Path | str, page: int = 1, page_size: int = 30, + start_time: str | None = None, + end_time: str | None = None, verbose: bool = False, ) -> bytes: """Lists files from the SDS API. + Args: + start_time: Optional ISO 8601 instant (UTC recommended) for RF temporal + lower bound; must be paired with ``end_time``. + end_time: Optional ISO 8601 instant for RF temporal upper bound. + Returns: The response content from SDS Gateway. 
""" + params: dict[str, str | int] = { + "page": page, + "page_size": page_size, + "path": str(sds_path), + } + if start_time: + params["start_time"] = start_time + if end_time: + params["end_time"] = end_time response = self._request( method=HTTPMethods.GET, endpoint=Endpoints.FILES, - params={ - "page": page, - "page_size": page_size, - "path": str(sds_path), - }, + params=params, verbose=verbose, ) network.success_or_raise(response, ContextException=FileError) @@ -734,27 +747,65 @@ def revoke_dataset_share_permissions( network.success_or_raise(response, ContextException=DatasetError) return response.content + def get_dataset( + self, + *, + dataset_uuid: uuid.UUID, + verbose: bool = False, + ) -> bytes: + """Fetch dataset metadata including captures and direct (artifact) files. + + Args: + dataset_uuid: UUID of the dataset. + verbose: Whether to log the request. + Returns: + JSON body from the gateway. + Raises: + DatasetError: If the request fails. + """ + response = self._request( + method=HTTPMethods.GET, + endpoint=Endpoints.DATASETS, + asset_id=dataset_uuid.hex, + verbose=verbose, + ) + network.success_or_raise(response, ContextException=DatasetError) + return response.content + def get_dataset_files( self, *, dataset_uuid: uuid.UUID, page: int = 1, page_size: int = 30, + capture_uuids: Collection[uuid.UUID] | None = None, + top_level_dirs: Collection[str | PurePosixPath | Path] | None = None, verbose: bool = False, ) -> bytes: """Get a manifest of files in the dataset for efficient downloading. Args: dataset_uuid: The UUID of the dataset to get files for. + capture_uuids: Optional capture UUIDs to filter server-side (repeat query). + top_level_dirs: Optional directory prefixes to filter server-side. verbose: Show network requests and other info. Returns: The response content containing the dataset file manifest. 
""" + params_list: list[tuple[str, str | int]] = [ + ("page", page), + ("page_size", page_size), + ] + if capture_uuids: + params_list.extend(("capture", str(cap_id)) for cap_id in capture_uuids) + if top_level_dirs: + params_list.extend(("top_level_dir", str(path)) for path in top_level_dirs) + response = self._request( method=HTTPMethods.GET, endpoint=Endpoints.DATASET_FILES, endpoint_args={"uuid": dataset_uuid.hex}, - params={"page": page, "page_size": page_size}, + params=params_list, verbose=verbose, ) network.success_or_raise(response, ContextException=DatasetError) diff --git a/sdk/src/spectrumx/models/capture_enums.py b/sdk/src/spectrumx/models/capture_enums.py new file mode 100644 index 000000000..6f9a7ce1f --- /dev/null +++ b/sdk/src/spectrumx/models/capture_enums.py @@ -0,0 +1,21 @@ +"""Capture type/origin enums (shared by captures and datasets models).""" + +from enum import StrEnum + + +class CaptureType(StrEnum): + """Capture types in SDS.""" + + DigitalRF = "drf" + RadioHound = "rh" + SigMF = "sigmf" + + +class CaptureOrigin(StrEnum): + """Capture origins in SDS.""" + + System = "system" + User = "user" + + +__all__ = ["CaptureOrigin", "CaptureType"] diff --git a/sdk/src/spectrumx/models/captures.py b/sdk/src/spectrumx/models/captures.py index 5bf81604d..e7efcea41 100644 --- a/sdk/src/spectrumx/models/captures.py +++ b/sdk/src/spectrumx/models/captures.py @@ -1,7 +1,6 @@ """Capture model for SpectrumX.""" from datetime import datetime -from enum import StrEnum from pathlib import Path from pathlib import PurePosixPath from typing import Annotated @@ -14,26 +13,12 @@ from pydantic import field_validator from spectrumx.models.base import SDSModel +from spectrumx.models.capture_enums import CaptureOrigin +from spectrumx.models.capture_enums import CaptureType from spectrumx.models.datasets import Dataset from spectrumx.models.user import User from spectrumx.models.user import UserSharePermission - -class CaptureType(StrEnum): - """Capture types in SDS.""" - - DigitalRF = "drf" - RadioHound = "rh" - SigMF = "sigmf" - - -class CaptureOrigin(StrEnum): - """Capture origins in SDS.""" - - System = "system" - User = "user" - - _d_capture_created_at = "The time the capture was created" _d_capture_props = "The indexed metadata for the capture" _d_capture_type = f"The type of capture {', '.join([x.value for x in CaptureType])}" @@ -48,7 +33,18 @@ class CaptureOrigin(StrEnum): _d_channel = "The channel associated with the capture. Only for RadioHound type." _d_scan_group = "The scan group associated with the capture. Only for Digital-RF type." 
_d_is_shared = "Whether the capture is shared" +_d_is_shared_with_me = "Whether the capture is shared with the current user" _d_datasets = "Datasets this capture is associated with" +_d_capture_start_iso_utc = ( + "Indexed capture start from OpenSearch as ISO 8601 UTC (when available)" +) +_d_capture_end_iso_utc = "Indexed capture end from OpenSearch as ISO 8601 UTC" +_d_capture_start_display = ( + "Indexed capture start formatted for display (server/local timezone)" +) +_d_capture_end_display = ( + "Indexed capture end formatted for display (server/local timezone)" +) class CaptureFile(BaseModel): @@ -90,7 +86,10 @@ class Capture(SDSModel): list[UserSharePermission], Field(description=_d_share_permissions, default_factory=list), ] - is_shared: Annotated[bool, Field(description=_d_is_shared)] + is_shared: Annotated[bool, Field(description=_d_is_shared, default=False)] + is_shared_with_me: Annotated[ + bool, Field(description=_d_is_shared_with_me, default=False) + ] # optional fields created_at: Annotated[ @@ -103,6 +102,22 @@ class Capture(SDSModel): str | None, Field(max_length=255, description=_d_channel, default=None) ] scan_group: Annotated[UUID4 | None, Field(description=_d_scan_group, default=None)] + capture_start_iso_utc: Annotated[ + str | None, + Field(description=_d_capture_start_iso_utc, default=None), + ] + capture_end_iso_utc: Annotated[ + str | None, + Field(description=_d_capture_end_iso_utc, default=None), + ] + capture_start_display: Annotated[ + str | None, + Field(description=_d_capture_start_display, default=None), + ] + capture_end_display: Annotated[ + str | None, + Field(description=_d_capture_end_display, default=None), + ] @field_validator("capture_props", mode="before") @classmethod @@ -150,5 +165,6 @@ def __repr__(self) -> str: __all__ = [ "Capture", + "CaptureOrigin", "CaptureType", ] diff --git a/sdk/src/spectrumx/models/datasets.py b/sdk/src/spectrumx/models/datasets.py index f02bcc717..61e913ca9 100644 --- a/sdk/src/spectrumx/models/datasets.py +++ b/sdk/src/spectrumx/models/datasets.py @@ -7,17 +7,38 @@ from pydantic import BaseModel from pydantic import ConfigDict +from spectrumx.models.capture_enums import CaptureOrigin +from spectrumx.models.capture_enums import CaptureType from spectrumx.models.user import User from spectrumx.models.user import UserSharePermission +class DatasetFile(BaseModel): + model_config = ConfigDict(extra="ignore") + + uuid: UUID4 | None = None + name: str | None = None + directory: str | None = None + media_type: str | None = None + + +class DatasetCapture(BaseModel): + model_config = ConfigDict(extra="ignore") + + uuid: UUID4 | None = None + name: str | None = None + capture_type: CaptureType | None = None + index_name: str | None = None + origin: CaptureOrigin | None = None + top_level_dir: str | None = None + owner: User | None = None + + class Dataset(BaseModel): """A dataset in SDS.""" model_config = ConfigDict(extra="ignore") - # TODO ownership: include ownership and access level information - uuid: UUID4 | None = None owner: User | None = None name: str | None = None @@ -30,7 +51,7 @@ class Dataset(BaseModel): institutions: list[str] | None = None release_date: datetime | None = None repository: str | None = None - version: str | None = None + version: int | None = None website: str | None = None provenance: dict[str, Any] | None = None citation: dict[str, Any] | None = None @@ -42,8 +63,12 @@ class Dataset(BaseModel): is_shared: bool = False is_shared_with_me: bool = False share_permissions: list[UserSharePermission] | 
None = None + captures: list[DatasetCapture] | None = None + files: list[DatasetFile] | None = None __all__ = [ "Dataset", + "DatasetCapture", + "DatasetFile", ] diff --git a/sdk/src/spectrumx/ops/pagination.py b/sdk/src/spectrumx/ops/pagination.py index fdfdee53e..c96eab586 100644 --- a/sdk/src/spectrumx/ops/pagination.py +++ b/sdk/src/spectrumx/ops/pagination.py @@ -18,6 +18,7 @@ from spectrumx.gateway import GatewayClient from spectrumx.models import SDSModel from spectrumx.ops import files +from spectrumx.utils import log_user_warning if TYPE_CHECKING: from collections.abc import Generator @@ -275,6 +276,12 @@ def _ingest_new_page(self, raw_page: bytes) -> None: if not isinstance(self._current_page_data, dict): # pragma: no cover msg = "Failed to load page data: expected a dictionary from JSON." raise TypeError(msg) + if not self._has_fetched: + raw_warnings = self._current_page_data.get("warnings") + if isinstance(raw_warnings, list): + for w in raw_warnings: + if isinstance(w, str) and w: + log_user_warning(w) if "count" in self._current_page_data: self._total_matches = self._current_page_data["count"] self._current_page_entries = ( @@ -302,7 +309,11 @@ def main() -> None: # pragma: no cover api_key="does-not-matter-in-dry-run", ), list_method=lambda **kwargs: b'{"count": 25, "results": []}', # Mock response - list_kwargs={"sds_path": "/path/to/files"}, + list_kwargs={ + "sds_path": "/path/to/files", + "start_time": None, + "end_time": None, + }, page_size=10, dry_run=True, # in dry-run this should always generate 2.5 pages verbose=True, diff --git a/sdk/tests/integration/test_file_ops.py b/sdk/tests/integration/test_file_ops.py index bf09516cd..692840241 100644 --- a/sdk/tests/integration/test_file_ops.py +++ b/sdk/tests/integration/test_file_ops.py @@ -1,7 +1,10 @@ """Integration tests for file operations on SDS.""" +import logging import time import uuid +from datetime import UTC +from datetime import datetime from pathlib import Path from pathlib import PurePosixPath from unittest.mock import patch @@ -17,9 +20,12 @@ from spectrumx.utils import get_random_line from tests.integration.conftest import PassthruEndpoints +from tests.integration.test_captures import _upload_drf_capture_test_assets +from tests.integration.test_captures import drf_channel from tests.test_utils import disable_ssl_warnings BLAKE3_HEX_LEN: int = 64 +DRF_SAMPLE_RF_CHUNK_COUNT: int = 16 def test_is_valid_file_allowed(temp_file_with_text_contents) -> None: @@ -710,6 +716,190 @@ def test_file_listing(integration_client: Client, temp_file_tree: Path) -> None: ) +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + "_without_responses", + argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + ] + ], + indirect=True, +) +def test_list_files_temporal_rf_narrows_digital_rf_chunks( + integration_client: Client, + drf_sample_top_level_dir: Path, +) -> None: + """Temporal bounds narrow ``rf@*.h5``; other filenames in the dir still list.""" + if not drf_sample_top_level_dir.is_dir(): + pytest.skip("Digital RF sample tree missing; cannot test temporal RF listing.") + cap_data = _upload_drf_capture_test_assets( + integration_client=integration_client, + drf_sample_top_level_dir=drf_sample_top_level_dir, + ) + rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" + start = datetime.fromtimestamp(1719499740, tz=UTC) + end = 
datetime.fromtimestamp(1719499740, tz=UTC) + all_rf = { + f.name + for f in integration_client.list_files(sds_path=rf_dir) + if f.name.startswith("rf@") + } + narrow_rf = { + f.name + for f in integration_client.list_files( + sds_path=rf_dir, + start_time=start, + end_time=end, + ) + if f.name.startswith("rf@") + } + assert len(all_rf) == DRF_SAMPLE_RF_CHUNK_COUNT, ( + f"Expected {DRF_SAMPLE_RF_CHUNK_COUNT} RF chunks under {rf_dir}, got {all_rf!r}" + ) + assert narrow_rf == {"rf@1719499740.000.h5"} + + +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + "_without_responses", + argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + ] + ], + indirect=True, +) +def test_list_files_temporal_rf_inclusive_range_multiple_chunks( + integration_client: Client, + drf_sample_top_level_dir: Path, +) -> None: + """A multi-second temporal window lists each ``rf@*.h5`` in that span.""" + if not drf_sample_top_level_dir.is_dir(): + pytest.skip("Digital RF sample tree missing; cannot test temporal RF listing.") + cap_data = _upload_drf_capture_test_assets( + integration_client=integration_client, + drf_sample_top_level_dir=drf_sample_top_level_dir, + ) + rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" + start = datetime.fromtimestamp(1719499740, tz=UTC) + end = datetime.fromtimestamp(1719499742, tz=UTC) + expected = { + "rf@1719499740.000.h5", + "rf@1719499741.000.h5", + "rf@1719499742.000.h5", + } + narrow_rf = { + f.name + for f in integration_client.list_files( + sds_path=rf_dir, + start_time=start, + end_time=end, + ) + if f.name.startswith("rf@") + } + assert narrow_rf == expected + + +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + "_without_responses", + argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + *PassthruEndpoints.file_content_download(), + ] + ], + indirect=True, +) +def test_download_respects_temporal_rf_window( + integration_client: Client, + drf_sample_top_level_dir: Path, + tmp_path: Path, +) -> None: + """Bulk download with start/end only fetches ``rf@*.h5`` in the UTC window.""" + if not drf_sample_top_level_dir.is_dir(): + pytest.skip("Digital RF sample tree missing; cannot test temporal RF download.") + cap_data = _upload_drf_capture_test_assets( + integration_client=integration_client, + drf_sample_top_level_dir=drf_sample_top_level_dir, + ) + rf_dir = cap_data.capture_top_level / drf_channel / "2024-06-27T14-00-00" + start = datetime.fromtimestamp(1719499740, tz=UTC) + end = datetime.fromtimestamp(1719499741, tz=UTC) + expected_names = {"rf@1719499740.000.h5", "rf@1719499741.000.h5"} + + download_dir = tmp_path / "drf_partial" + results = integration_client.download( + from_sds_path=rf_dir, + start_time=start, + end_time=end, + to_local_path=download_dir, + verbose=False, + ) + failures = [r.error_info for r in results if not r] + assert not failures, f"download failures: {failures}" + downloaded_rf = { + p.name + for p in download_dir.rglob("*") + if p.is_file() and p.name.startswith("rf@") and p.name.endswith(".h5") + } + assert downloaded_rf == expected_names + + +@pytest.mark.integration +@pytest.mark.usefixtures("_integration_setup_teardown") +@pytest.mark.usefixtures("_without_responses") +@pytest.mark.parametrize( + "_without_responses", + 
argvalues=[ + [ + *PassthruEndpoints.file_content_checks(), + *PassthruEndpoints.file_uploads(), + ] + ], + indirect=True, +) +def test_list_files_temporal_non_rf_directory_warns( + caplog: pytest.LogCaptureFixture, + integration_client: Client, + temp_file_tree: Path, +) -> None: + """API warns when temporal params are used on a tree with no RF data files.""" + caplog.set_level(logging.WARNING) + random_subdir_name = get_random_line(10, include_punctuation=False) + sds_path = PurePosixPath("/test-tree-temporal-warn") / random_subdir_name + results = integration_client.upload( + local_path=temp_file_tree, + sds_path=sds_path, + verbose=False, + ) + failures = [result for result in results if not result] + assert not failures, f"Upload failed: {failures}" + start = datetime(2020, 1, 1, tzinfo=UTC) + end = datetime(2020, 1, 2, tzinfo=UTC) + listed = list( + integration_client.list_files( + sds_path=sds_path, + start_time=start, + end_time=end, + ) + ) + assert len(listed) > 0 + assert any("RF data" in r.getMessage() for r in caplog.records), ( + f"expected server warning in logs; got {caplog.records!r}" + ) + + @pytest.mark.integration @pytest.mark.usefixtures("_integration_setup_teardown") @pytest.mark.usefixtures("_without_responses") diff --git a/sdk/tests/models/test_files.py b/sdk/tests/models/test_files.py index 573733046..b1509b300 100644 --- a/sdk/tests/models/test_files.py +++ b/sdk/tests/models/test_files.py @@ -2,6 +2,7 @@ # pylint: disable=redefined-outer-name +import uuid from datetime import datetime from datetime import timedelta from pathlib import PurePosixPath @@ -11,6 +12,8 @@ from pydantic import BaseModel from pytz import UTC from pytz import timezone +from spectrumx.models.capture_enums import CaptureOrigin +from spectrumx.models.capture_enums import CaptureType from spectrumx.models.files import File from spectrumx.models.files import PermissionRepresentation from spectrumx.models.files import UnixPermissionStr @@ -100,3 +103,64 @@ class Model(BaseModel): assert b.model_dump(context={"mode": PermissionRepresentation.OCTAL}) == { "permission": "0o755" } + + +def test_file_parses_captures_list_from_api( + file_properties: dict[str, Any], +) -> None: + """Gateway file payloads use a ``captures`` list (canonical).""" + cap_uid = uuid.uuid4() + nested_capture = { + "owner": {"id": 1, "email": "test@example.com", "name": "Test User"}, + "is_shared": False, + "share_permissions": [], + "datasets": [], + "capture_props": {}, + "capture_type": CaptureType.DigitalRF.value, + "created_at": datetime.now(UTC).isoformat(), + "index_name": "captures-drf", + "origin": CaptureOrigin.User.value, + "top_level_dir": "/c/tdir", + "uuid": str(cap_uid), + "files": [], + } + raw = { + "uuid": str(uuid.uuid4()), + **file_properties, + "captures": [nested_capture], + } + f = File.model_validate(raw) + assert f.captures is not None + assert len(f.captures) == 1 + assert f.captures[0].uuid == cap_uid + + +def test_file_payload_may_include_redundant_singular_capture( + file_properties: dict[str, Any], +) -> None: + """Gateway may send both ``captures`` and ``capture``; SDK uses ``captures``.""" + cap_uid = uuid.uuid4() + nested_capture = { + "owner": {"id": 1, "email": "test@example.com", "name": "Test User"}, + "is_shared": False, + "share_permissions": [], + "datasets": [], + "capture_props": {}, + "capture_type": CaptureType.DigitalRF.value, + "created_at": datetime.now(UTC).isoformat(), + "index_name": "captures-drf", + "origin": CaptureOrigin.User.value, + "top_level_dir": "/c/tdir", + 
"uuid": str(cap_uid), + "files": [], + } + raw = { + "uuid": str(uuid.uuid4()), + **file_properties, + "captures": [nested_capture], + "capture": nested_capture, + } + f = File.model_validate(raw) + assert f.captures is not None + assert len(f.captures) == 1 + assert f.captures[0].uuid == cap_uid diff --git a/sdk/tests/ops/test_files.py b/sdk/tests/ops/test_files.py index 9928b1fd7..4e786288d 100644 --- a/sdk/tests/ops/test_files.py +++ b/sdk/tests/ops/test_files.py @@ -5,15 +5,21 @@ import sys import tempfile import uuid as uuidlib +from datetime import UTC from datetime import datetime +from datetime import timedelta +from datetime import timezone from pathlib import Path from pathlib import PurePosixPath +from unittest.mock import patch import pytest import responses from loguru import logger as log from spectrumx import Client +from spectrumx.api.sds_files import _file_list_time_query_param from spectrumx.api.sds_files import delete_file +from spectrumx.api.sds_files import list_files from spectrumx.errors import FileError from spectrumx.gateway import API_TARGET_VERSION from spectrumx.ops.files import ( @@ -103,6 +109,86 @@ def test_get_file_by_id(client: Client, responses: responses.RequestsMock) -> No assert file_sample.uuid == uuid +def test_list_files_start_end_must_be_paired(client: Client) -> None: + """Omitting one of ``start_time`` / ``end_time`` raises before any request.""" + t = datetime(2024, 6, 27, 14, 0, tzinfo=UTC) + with pytest.raises(ValueError, match="both be set or both omitted"): + list_files(client=client, sds_path="/", start_time=t, end_time=None) + with pytest.raises(ValueError, match="both be set or both omitted"): + list_files(client=client, sds_path="/", start_time=None, end_time=t) + + +def test_file_list_time_query_param_naive_treated_as_utc() -> None: + """Naive datetimes are interpreted as UTC for query strings.""" + naive = datetime(2024, 6, 27, 14, 9, 0) # noqa: DTZ001 + out = _file_list_time_query_param(naive) + assert out.endswith("+00:00") + assert "2024-06-27T14:09:00" in out + + +def test_file_list_time_query_param_converts_non_utc_to_utc() -> None: + """Aware datetimes are formatted in UTC (same instant).""" + eastern = timezone(timedelta(hours=-5)) + local = datetime(2024, 6, 27, 9, 9, 0, tzinfo=eastern) + out = _file_list_time_query_param(local) + assert "2024-06-27T14:09:00" in out + assert out.endswith("+00:00") + + +def test_list_files_passes_iso_temporal_params_to_gateway(client: Client) -> None: + """Temporal bounds are forwarded to ``gateway.list_files`` as ISO strings.""" + client.dry_run = False + start = datetime(2024, 6, 27, 14, 9, 0, tzinfo=UTC) + end = datetime(2024, 6, 27, 14, 10, 0, tzinfo=UTC) + expected_start = _file_list_time_query_param(start) + expected_end = _file_list_time_query_param(end) + empty_page = b'{"count": 0, "results": []}' + + with patch.object( + client._gateway, # noqa: SLF001 + "list_files", + return_value=empty_page, + ) as m: + paginator = list_files( + client=client, + sds_path=PurePosixPath("/drf/dir"), + start_time=start, + end_time=end, + ) + list(paginator) + + m.assert_called() + for call in m.call_args_list: + kw = call.kwargs + assert kw["start_time"] == expected_start + assert kw["end_time"] == expected_end + assert kw["sds_path"] == PurePosixPath("/drf/dir") + + +def test_client_download_forwards_temporal_bounds_to_list_files( + client: Client, tmp_path: Path +) -> None: + """``Client.download(..., start_time=, end_time=)`` lists with the same bounds.""" + client.dry_run = False + start = 
datetime(2024, 6, 27, 14, 9, 0, tzinfo=UTC) + end = datetime(2024, 6, 27, 14, 11, 0, tzinfo=UTC) + sds = PurePosixPath("/capture/rf") + + with patch.object(client, "list_files", return_value=[]) as m_list: + client.download( + from_sds_path=sds, + start_time=start, + end_time=end, + to_local_path=tmp_path, + verbose=False, + ) + + m_list.assert_called_once() + assert m_list.call_args.kwargs["sds_path"] == sds + assert m_list.call_args.kwargs["start_time"] == start + assert m_list.call_args.kwargs["end_time"] == end + + def test_file_get_returns_valid( client: Client, ) -> None: diff --git a/sdk/tests/ops/test_paginator.py b/sdk/tests/ops/test_paginator.py index 9760b7184..56865f6fa 100644 --- a/sdk/tests/ops/test_paginator.py +++ b/sdk/tests/ops/test_paginator.py @@ -3,8 +3,11 @@ # ruff: noqa: SLF001 # pyright: reportPrivateUsage=false +import json +import logging import uuid from collections.abc import Generator +from pathlib import PurePosixPath from unittest.mock import MagicMock from unittest.mock import patch @@ -12,6 +15,7 @@ from loguru import logger as log from spectrumx.gateway import GatewayClient from spectrumx.models.files import File +from spectrumx.ops import files as sx_files from spectrumx.ops.pagination import Paginator log.trace("Placeholder log avoid reimporting or resolving unused import warnings.") @@ -96,6 +100,59 @@ def test_paginator_dry_run_ingest_list_files(gateway: GatewayClient) -> None: ) +def test_paginator_preserves_temporal_kwargs_across_pages( + gateway: GatewayClient, +) -> None: + """ + ``start_time`` / ``end_time`` in ``list_kwargs`` + are sent on every ``list_files`` call. + """ + one = sx_files.generate_sample_file(uuid.uuid4()) + two = sx_files.generate_sample_file(uuid.uuid4()) + three = sx_files.generate_sample_file(uuid.uuid4()) + start_iso = "2024-06-27T14:09:00+00:00" + end_iso = "2024-06-27T14:11:00+00:00" + recorded: list[dict[str, object]] = [] + + def side_effect(**kwargs: object) -> bytes: + recorded.append(dict(kwargs)) + page = kwargs["page"] + if page == 1: + results = [ + json.loads(one.model_dump_json()), + json.loads(two.model_dump_json()), + ] + else: + results = [json.loads(three.model_dump_json())] + body = {"count": 3, "results": results} + return json.dumps(body).encode() + + gateway.list_files.side_effect = side_effect + + paginator = Paginator[File]( + Entry=File, + gateway=gateway, + list_method=gateway.list_files, + list_kwargs={ + "sds_path": PurePosixPath("/drf"), + "start_time": start_iso, + "end_time": end_iso, + }, + page_size=2, + dry_run=False, + ) + + consumed = list(paginator) + expected_items = 3 + expected_pages = 2 + assert len(consumed) == expected_items + assert len(recorded) == expected_pages + for call_kw in recorded: + assert call_kw["start_time"] == start_iso + assert call_kw["end_time"] == end_iso + assert call_kw["sds_path"] == PurePosixPath("/drf") + + def test_paginator_dry_run_ingest_get_dataset_files(gateway: GatewayClient) -> None: """Tests the dry-run mode of the paginator for get_dataset_files.""" page_size = 3 @@ -318,3 +375,28 @@ def _get_raw_page( raw_page_suffix = "]}" raw_page_str = raw_page_prefix + ",".join(first_page_results) + raw_page_suffix return raw_page_str.encode() + + +def test_paginator_logs_api_warnings_on_first_page( + caplog: pytest.LogCaptureFixture, + gateway: GatewayClient, +) -> None: + """First-page ``warnings`` in JSON are forwarded via ``log_user_warning``.""" + caplog.set_level(logging.WARNING) + warn_msg = "Temporal filtering is only supported for RF dirs" + one = 
sx_files.generate_sample_file(uuid.uuid4()) + body: dict[str, object] = { + "count": 1, + "results": [json.loads(one.model_dump_json())], + "warnings": [warn_msg], + } + gateway.list_files.return_value = json.dumps(body).encode() + paginator = Paginator[File]( + Entry=File, + gateway=gateway, + list_method=gateway.list_files, + list_kwargs={"sds_path": "/only-text"}, + dry_run=False, + ) + assert len(paginator) == 1 + assert any(warn_msg in r.getMessage() for r in caplog.records) diff --git a/sdk/tests/test_client.py b/sdk/tests/test_client.py index 28bcaddb0..8b393ef12 100644 --- a/sdk/tests/test_client.py +++ b/sdk/tests/test_client.py @@ -11,6 +11,7 @@ import pytest from loguru import logger as log from spectrumx.client import Client +from spectrumx.client import _resolve_dataset_capture_filter_params from spectrumx.config import SDSConfig from spectrumx.config import _cfg_name_lookup from spectrumx.models.files import File @@ -329,3 +330,25 @@ def test_existing_local_file_identical_checksum_not_redownloaded( assert result, "Result should not be None" assert result() is file_info, "Returned file should be the original file_info" + + +def test_resolve_dataset_capture_filter_dry_run_disables() -> None: + active, uuids, dirs = _resolve_dataset_capture_filter_params( + capture_uuids=[uuid.uuid4()], + top_level_dirs=None, + dry_run=True, + ) + assert not active + assert uuids is None + assert dirs is None + + +def test_resolve_dataset_capture_filter_normalizes_top_level_dirs() -> None: + active, uuids, dirs = _resolve_dataset_capture_filter_params( + capture_uuids=None, + top_level_dirs=["foo/bar", "/baz/"], + dry_run=False, + ) + assert active + assert uuids is None + assert dirs == ["/foo/bar", "/baz"]
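As a usage note for the temporal filtering introduced above: a minimal SDK sketch, assuming a reachable gateway, a configured API key, and a Digital RF capture already uploaded under `/my-captures/drf`. The host, path, and time window are hypothetical; only `list_files`, `download`, and their `start_time` / `end_time` parameters come from this changeset.

```python
from datetime import UTC, datetime
from pathlib import Path

from spectrumx import Client

client = Client(host="sds.crc.nd.edu")  # hypothetical host; constructor usage per SDK docs
client.dry_run = False  # dry-run is the SDK's default safety net
client.authenticate()

# Both bounds must be set together; naive datetimes are interpreted as UTC.
start = datetime(2024, 6, 27, 14, 9, 0, tzinfo=UTC)
end = datetime(2024, 6, 27, 14, 10, 0, tzinfo=UTC)

# Only rf@*.h5 chunks inside the window are returned; non-RF files still list.
for sds_file in client.list_files("/my-captures/drf", start_time=start, end_time=end):
    print(sds_file.name)

# Download the same window; first-page gateway "warnings" are logged by the
# paginator via log_user_warning.
results = client.download(
    from_sds_path="/my-captures/drf",
    start_time=start,
    end_time=end,
    to_local_path=Path("./drf_partial"),
    verbose=True,
)
print(f"{sum(bool(r) for r in results)}/{len(results)} files downloaded")
```

Passing only one bound raises `ValueError` before any request is made; listing a directory with no RF data returns the full listing plus a `warnings` entry instead of failing.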
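Similarly, a hedged sketch of the new dataset-side filters; the dataset UUID and local path are placeholders, and in dry-run mode these filters are ignored because simulated manifest files are not capture-tagged.

```python
from pathlib import Path

from spectrumx import Client

client = Client(host="sds.crc.nd.edu")  # hypothetical host
client.dry_run = False
client.authenticate()

dataset_uuid = "00000000-0000-0000-0000-000000000000"  # placeholder

# Inspect the dataset's captures to learn their UUIDs and top_level_dir values.
for cap in client.list_dataset_captures(dataset_uuid):
    print(cap.get("uuid"), cap.get("top_level_dir"))

# Restrict the download to files under one capture's top_level_dir. If
# capture_uuids were also passed, a file matching either filter is included
# (OR semantics, enforced on the gateway).
results = client.download_dataset(
    dataset_uuid=dataset_uuid,
    to_local_path=Path("./dataset_subset"),
    top_level_dirs=["/my-captures/drf"],
)
failures = [r.error_info for r in results if not r]
print(f"{len(results) - len(failures)} downloaded, {len(failures)} failed")
```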
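For API consumers not using the SDK, the same behavior is reachable through the query parameters declared in the `extend_schema` blocks above. The base URL below matches the swagger example in this changeset; the authorization header format is an assumption, not something this diff specifies.

```python
import requests

BASE_URL = "http://localhost:8000/api/latest/assets"  # from the swagger example
HEADERS = {"Authorization": "Api-Key <your-key>"}  # header format assumed

# Temporal filtering of a file listing (ISO 8601; naive values are read as UTC).
listing = requests.get(
    f"{BASE_URL}/files/",
    headers=HEADERS,
    params={
        "path": "/my-captures/drf",
        "start_time": "2024-06-27T14:09:00+00:00",
        "end_time": "2024-06-27T14:10:00+00:00",
    },
    timeout=30,
)
listing.raise_for_status()
body = listing.json()
print(body["count"], body["warnings"])  # "warnings" is now always present
```

The dataset files manifest endpoint accepts the analogous `capture` and `top_level_dir` parameters (repeated or comma-separated); its URL is built from `Endpoints.DATASET_FILES` and is not spelled out in this diff.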