Skip to content

Commit 83bd0d3

Browse files
author
Andrzej Pijanowski
committed
feat: Implement header-based collection and geometry filtering for search and collection endpoints.
1 parent 2b8cb54 commit 83bd0d3

File tree

3 files changed

+350
-2
lines changed

3 files changed

+350
-2
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
2424
from stac_fastapi.core.base_settings import ApiBaseSettings
2525
from stac_fastapi.core.datetime_utils import format_datetime_range
26+
from stac_fastapi.core.header_filters import (
27+
parse_filter_collections,
28+
parse_filter_geometry,
29+
)
2630
from stac_fastapi.core.models.links import PagingLinks
2731
from stac_fastapi.core.queryables import (
2832
QueryablesCache,
@@ -449,6 +453,13 @@ async def all_collections(
449453
else:
450454
filtered_collections = collections
451455

456+
# Filter by header collections if present
457+
header_collections = parse_filter_collections(request)
458+
if header_collections is not None:
459+
filtered_collections = [
460+
c for c in filtered_collections if c.get("id") in header_collections
461+
]
462+
452463
links = [
453464
{"rel": Relations.root.value, "type": MimeTypes.json, "href": base_url},
454465
{"rel": Relations.parent.value, "type": MimeTypes.json, "href": base_url},
@@ -580,6 +591,12 @@ async def get_collection(
580591
NotFoundError: If the collection with the given id cannot be found in the database.
581592
"""
582593
request = kwargs["request"]
594+
595+
# Check if collection is allowed by header filter
596+
header_collections = parse_filter_collections(request)
597+
if header_collections is not None and collection_id not in header_collections:
598+
raise HTTPException(status_code=404, detail="Collection not found")
599+
583600
collection = await self.database.find_collection(collection_id=collection_id)
584601
return self.collection_serializer.db_to_stac(
585602
collection=collection,
@@ -665,7 +682,14 @@ async def get_item(
665682
Exception: If any error occurs while getting the item from the database.
666683
NotFoundError: If the item does not exist in the specified collection.
667684
"""
668-
base_url = str(kwargs["request"].base_url)
685+
request = kwargs["request"]
686+
687+
# Check if collection is allowed by header filter
688+
header_collections = parse_filter_collections(request)
689+
if header_collections is not None and collection_id not in header_collections:
690+
raise HTTPException(status_code=404, detail="Item not found")
691+
692+
base_url = str(request.base_url)
669693
item = await self.database.get_one_item(
670694
item_id=item_id, collection_id=collection_id
671695
)
@@ -821,7 +845,14 @@ async def post_search(
821845
search=search, item_ids=search_request.ids
822846
)
823847

824-
if search_request.collections:
848+
# Apply collection filter from header or request
849+
header_collections = parse_filter_collections(request)
850+
if header_collections is not None:
851+
# Use header collections (stac-auth-proxy already did intersection)
852+
search = self.database.apply_collections_filter(
853+
search=search, collection_ids=header_collections
854+
)
855+
elif search_request.collections:
825856
search = self.database.apply_collections_filter(
826857
search=search, collection_ids=search_request.collections
827858
)
@@ -844,6 +875,19 @@ async def post_search(
844875

845876
search = self.database.apply_bbox_filter(search=search, bbox=bbox)
846877

878+
# Apply geometry filter from header
879+
header_geometry = parse_filter_geometry(request)
880+
if header_geometry is not None:
881+
from types import SimpleNamespace
882+
883+
geometry_obj = SimpleNamespace(
884+
type=header_geometry.get("type", ""),
885+
coordinates=header_geometry.get("coordinates", []),
886+
)
887+
search = self.database.apply_intersects_filter(
888+
search=search, intersects=geometry_obj
889+
)
890+
847891
if hasattr(search_request, "intersects") and getattr(
848892
search_request, "intersects"
849893
):
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Header-based filtering utilities.
2+
3+
This module provides functions for parsing filter headers from stac-auth-proxy.
4+
Headers allow stac-auth-proxy to pass collection and geometry filters to sfeos.
5+
"""
6+
7+
import json
8+
import logging
9+
from typing import Any, Dict, List, Optional
10+
11+
from fastapi import Request
12+
13+
logger = logging.getLogger(__name__)
14+
15+
# Header names
16+
FILTER_COLLECTIONS_HEADER = "X-Filter-Collections"
17+
FILTER_GEOMETRY_HEADER = "X-Filter-Geometry"
18+
19+
20+
def parse_filter_collections(request: Request) -> Optional[List[str]]:
21+
"""Parse collection filter from X-Filter-Collections header.
22+
23+
Args:
24+
request: FastAPI Request object.
25+
26+
Returns:
27+
List of collection IDs if header is present, None otherwise.
28+
Empty list if header value is empty string.
29+
30+
Example:
31+
Header "X-Filter-Collections: col-a,col-b,col-c" returns ["col-a", "col-b", "col-c"]
32+
"""
33+
header_value = request.headers.get(FILTER_COLLECTIONS_HEADER)
34+
35+
if header_value is None:
36+
return None
37+
38+
# Handle empty header value
39+
if not header_value.strip():
40+
return []
41+
42+
# Parse comma-separated list
43+
collections = [c.strip() for c in header_value.split(",") if c.strip()]
44+
logger.debug(f"Parsed filter collections from header: {collections}")
45+
46+
return collections
47+
48+
49+
def parse_filter_geometry(request: Request) -> Optional[Dict[str, Any]]:
50+
"""Parse geometry filter from X-Filter-Geometry header.
51+
52+
Args:
53+
request: FastAPI Request object.
54+
55+
Returns:
56+
GeoJSON geometry dict if header is present and valid, None otherwise.
57+
58+
Example:
59+
Header 'X-Filter-Geometry: {"type":"Polygon","coordinates":[...]}'
60+
returns the parsed GeoJSON dict.
61+
"""
62+
header_value = request.headers.get(FILTER_GEOMETRY_HEADER)
63+
64+
if header_value is None:
65+
return None
66+
67+
if not header_value.strip():
68+
return None
69+
70+
try:
71+
geometry = json.loads(header_value)
72+
logger.debug(
73+
f"Parsed filter geometry from header: {geometry.get('type', 'unknown')}"
74+
)
75+
return geometry
76+
except json.JSONDecodeError as e:
77+
logger.warning(f"Failed to parse geometry header: {e}")
78+
return None
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""Tests for header-based filtering functionality.
2+
3+
This module tests the header filtering feature that allows stac-auth-proxy
4+
to pass allowed collections and geometries via HTTP headers.
5+
"""
6+
7+
import json
8+
9+
import pytest
10+
import pytest_asyncio
11+
12+
from ..conftest import create_collection, create_item, delete_collections_and_items
13+
14+
# Header names
15+
FILTER_COLLECTIONS_HEADER = "X-Filter-Collections"
16+
FILTER_GEOMETRY_HEADER = "X-Filter-Geometry"
17+
18+
19+
@pytest_asyncio.fixture(scope="function")
20+
async def multi_collection_ctx(txn_client, load_test_data):
21+
"""Create multiple collections for testing header filtering."""
22+
await delete_collections_and_items(txn_client)
23+
24+
# Create test collections
25+
collections = []
26+
for suffix in ["a", "b", "c"]:
27+
collection = load_test_data("test_collection.json").copy()
28+
collection["id"] = f"test-collection-{suffix}"
29+
await create_collection(txn_client, collection)
30+
collections.append(collection)
31+
32+
# Create items in each collection
33+
items = []
34+
for collection in collections:
35+
item = load_test_data("test_item.json").copy()
36+
item["id"] = f"test-item-{collection['id']}"
37+
item["collection"] = collection["id"]
38+
await create_item(txn_client, item)
39+
items.append(item)
40+
41+
yield {"collections": collections, "items": items}
42+
43+
await delete_collections_and_items(txn_client)
44+
45+
46+
class TestHeaderFilteringSearch:
47+
"""Tests for search endpoints with header filtering."""
48+
49+
@pytest.mark.asyncio
50+
async def test_search_uses_header_collections(
51+
self, app_client, multi_collection_ctx
52+
):
53+
"""When X-Filter-Collections header is present, search only in those collections."""
54+
# Search with header limiting to collection-a only
55+
response = await app_client.get(
56+
"/search",
57+
headers={FILTER_COLLECTIONS_HEADER: "test-collection-a"},
58+
)
59+
assert response.status_code == 200
60+
data = response.json()
61+
62+
# Should only return items from collection-a
63+
for feature in data["features"]:
64+
assert feature["collection"] == "test-collection-a"
65+
66+
@pytest.mark.asyncio
67+
async def test_search_header_multiple_collections(
68+
self, app_client, multi_collection_ctx
69+
):
70+
"""Header with multiple collections filters to those collections."""
71+
response = await app_client.get(
72+
"/search",
73+
headers={FILTER_COLLECTIONS_HEADER: "test-collection-a,test-collection-b"},
74+
)
75+
assert response.status_code == 200
76+
data = response.json()
77+
78+
# Should only return items from collection-a and collection-b
79+
for feature in data["features"]:
80+
assert feature["collection"] in ["test-collection-a", "test-collection-b"]
81+
82+
@pytest.mark.asyncio
83+
async def test_search_no_header_returns_all(self, app_client, multi_collection_ctx):
84+
"""Without header, search returns items from all collections."""
85+
response = await app_client.get("/search")
86+
assert response.status_code == 200
87+
data = response.json()
88+
89+
# Should have items from all collections
90+
collections_in_response = {f["collection"] for f in data["features"]}
91+
assert "test-collection-a" in collections_in_response
92+
assert "test-collection-b" in collections_in_response
93+
assert "test-collection-c" in collections_in_response
94+
95+
@pytest.mark.asyncio
96+
async def test_post_search_uses_header_collections(
97+
self, app_client, multi_collection_ctx
98+
):
99+
"""POST /search also respects the header."""
100+
response = await app_client.post(
101+
"/search",
102+
json={},
103+
headers={FILTER_COLLECTIONS_HEADER: "test-collection-b"},
104+
)
105+
assert response.status_code == 200
106+
data = response.json()
107+
108+
for feature in data["features"]:
109+
assert feature["collection"] == "test-collection-b"
110+
111+
112+
class TestHeaderFilteringCollections:
113+
"""Tests for collections endpoint with header filtering."""
114+
115+
@pytest.mark.asyncio
116+
async def test_all_collections_filtered_by_header(
117+
self, app_client, multi_collection_ctx
118+
):
119+
"""GET /collections only returns collections from header."""
120+
response = await app_client.get(
121+
"/collections",
122+
headers={FILTER_COLLECTIONS_HEADER: "test-collection-a,test-collection-c"},
123+
)
124+
assert response.status_code == 200
125+
data = response.json()
126+
127+
collection_ids = [c["id"] for c in data["collections"]]
128+
assert "test-collection-a" in collection_ids
129+
assert "test-collection-c" in collection_ids
130+
assert "test-collection-b" not in collection_ids
131+
132+
@pytest.mark.asyncio
133+
async def test_get_collection_allowed_by_header(
134+
self, app_client, multi_collection_ctx
135+
):
136+
"""GET /collections/{id} works when collection is in header."""
137+
response = await app_client.get(
138+
"/collections/test-collection-a",
139+
headers={FILTER_COLLECTIONS_HEADER: "test-collection-a,test-collection-b"},
140+
)
141+
assert response.status_code == 200
142+
assert response.json()["id"] == "test-collection-a"
143+
144+
@pytest.mark.asyncio
145+
async def test_get_collection_no_header_allowed(
146+
self, app_client, multi_collection_ctx
147+
):
148+
"""GET /collections/{id} works without header."""
149+
response = await app_client.get("/collections/test-collection-a")
150+
assert response.status_code == 200
151+
assert response.json()["id"] == "test-collection-a"
152+
153+
154+
class TestHeaderFilteringItems:
155+
"""Tests for item endpoints with header filtering."""
156+
157+
@pytest.mark.asyncio
158+
async def test_item_collection_uses_header(self, app_client, multi_collection_ctx):
159+
"""GET /collections/{id}/items respects header."""
160+
response = await app_client.get(
161+
"/collections/test-collection-a/items",
162+
headers={FILTER_COLLECTIONS_HEADER: "test-collection-a"},
163+
)
164+
assert response.status_code == 200
165+
166+
@pytest.mark.asyncio
167+
async def test_get_item_with_header(self, app_client, multi_collection_ctx):
168+
"""GET /collections/{id}/items/{item_id} works with header."""
169+
response = await app_client.get(
170+
"/collections/test-collection-a/items/test-item-test-collection-a",
171+
headers={FILTER_COLLECTIONS_HEADER: "test-collection-a"},
172+
)
173+
assert response.status_code == 200
174+
175+
176+
class TestGeometryHeaderFiltering:
177+
"""Tests for geometry header filtering."""
178+
179+
@pytest.mark.asyncio
180+
async def test_search_with_geometry_header(self, app_client, ctx):
181+
"""Search respects X-Filter-Geometry header."""
182+
# Geometry that intersects with test item
183+
geometry = {
184+
"type": "Polygon",
185+
"coordinates": [
186+
[
187+
[149.0, -34.5],
188+
[149.0, -32.0],
189+
[151.5, -32.0],
190+
[151.5, -34.5],
191+
[149.0, -34.5],
192+
]
193+
],
194+
}
195+
196+
response = await app_client.get(
197+
"/search",
198+
headers={FILTER_GEOMETRY_HEADER: json.dumps(geometry)},
199+
)
200+
assert response.status_code == 200
201+
# Items should be filtered by geometry
202+
203+
@pytest.mark.asyncio
204+
async def test_search_with_non_intersecting_geometry(self, app_client, ctx):
205+
"""Search with non-intersecting geometry returns no items."""
206+
# Geometry that doesn't intersect with test item
207+
geometry = {
208+
"type": "Polygon",
209+
"coordinates": [
210+
[
211+
[0.0, 0.0],
212+
[0.0, 1.0],
213+
[1.0, 1.0],
214+
[1.0, 0.0],
215+
[0.0, 0.0],
216+
]
217+
],
218+
}
219+
220+
response = await app_client.get(
221+
"/search",
222+
headers={FILTER_GEOMETRY_HEADER: json.dumps(geometry)},
223+
)
224+
assert response.status_code == 200
225+
data = response.json()
226+
assert len(data["features"]) == 0

0 commit comments

Comments
 (0)