Skip to content

Commit

Permalink
add from_extensions class method to create CollectionSearch extension…
Browse files Browse the repository at this point in the history
…s classes (#745)

* add from_extensions class method to create CollectionSearch extensions classes

* Apply suggestions from code review

* Apply suggestions from code review

* fix

* update makefile
  • Loading branch information
vincentsarago authored Sep 3, 2024
1 parent cf55d66 commit 3ae8d86
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Added

* Add `from_extensions()` method to `CollectionSearchExtension` and `CollectionSearchPostExtension` extensions to build the class based on a list of available extensions.

## [3.0.1] - 2024-08-27

### Changed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Collection-Search extension."""

import warnings
from enum import Enum
from typing import List, Optional, Union

Expand All @@ -8,7 +9,7 @@
from stac_pydantic.api.collections import Collections
from stac_pydantic.shared import MimeTypes

from stac_fastapi.api.models import GeoJSONResponse
from stac_fastapi.api.models import GeoJSONResponse, create_request_model
from stac_fastapi.api.routes import create_async_endpoint
from stac_fastapi.types.config import ApiSettings
from stac_fastapi.types.extension import ApiExtension
Expand Down Expand Up @@ -71,6 +72,48 @@ def register(self, app: FastAPI) -> None:
"""
pass

@classmethod
def from_extensions(
cls,
extensions: List[ApiExtension],
schema_href: Optional[str] = None,
) -> "CollectionSearchExtension":
"""Create CollectionSearchExtension object from extensions."""
supported_extensions = {
"FreeTextExtension": ConformanceClasses.FREETEXT,
"FreeTextAdvancedExtension": ConformanceClasses.FREETEXT,
"QueryExtension": ConformanceClasses.QUERY,
"SortExtension": ConformanceClasses.SORT,
"FieldsExtension": ConformanceClasses.FIELDS,
"FilterExtension": ConformanceClasses.FILTER,
}
conformance_classes = [
ConformanceClasses.COLLECTIONSEARCH,
ConformanceClasses.BASIS,
]
for ext in extensions:
conf = supported_extensions.get(ext.__class__.__name__, None)
if not conf:
warnings.warn(
f"Conformance class for `{ext.__class__.__name__}` extension not found.", # noqa: E501
UserWarning,
)
else:
conformance_classes.append(conf)

get_request_model = create_request_model(
model_name="CollectionsGetRequest",
base_model=BaseCollectionSearchGetRequest,
extensions=extensions,
request_type="GET",
)

return cls(
GET=get_request_model,
conformance_classes=conformance_classes,
schema_href=schema_href,
)


@attr.s
class CollectionSearchPostExtension(CollectionSearchExtension):
Expand Down Expand Up @@ -132,3 +175,60 @@ def register(self, app: FastAPI) -> None:
endpoint=create_async_endpoint(self.client.post_all_collections, self.POST),
)
app.include_router(self.router)

@classmethod
def from_extensions(
cls,
extensions: List[ApiExtension],
*,
client: Union[AsyncBaseCollectionSearchClient, BaseCollectionSearchClient],
settings: ApiSettings,
schema_href: Optional[str] = None,
router: Optional[APIRouter] = None,
) -> "CollectionSearchPostExtension":
"""Create CollectionSearchPostExtension object from extensions."""
supported_extensions = {
"FreeTextExtension": ConformanceClasses.FREETEXT,
"FreeTextAdvancedExtension": ConformanceClasses.FREETEXT,
"QueryExtension": ConformanceClasses.QUERY,
"SortExtension": ConformanceClasses.SORT,
"FieldsExtension": ConformanceClasses.FIELDS,
"FilterExtension": ConformanceClasses.FILTER,
}
conformance_classes = [
ConformanceClasses.COLLECTIONSEARCH,
ConformanceClasses.BASIS,
]
for ext in extensions:
conf = supported_extensions.get(ext.__class__.__name__, None)
if not conf:
warnings.warn(
f"Conformance class for `{ext.__class__.__name__}` extension not found.", # noqa: E501
UserWarning,
)
else:
conformance_classes.append(conf)

get_request_model = create_request_model(
model_name="CollectionsGetRequest",
base_model=BaseCollectionSearchGetRequest,
extensions=extensions,
request_type="GET",
)

post_request_model = create_request_model(
model_name="CollectionsPostRequest",
base_model=BaseCollectionSearchPostRequest,
extensions=extensions,
request_type="POST",
)

return cls(
client=client,
settings=settings,
GET=get_request_model,
POST=post_request_model,
conformance_classes=conformance_classes,
router=router or APIRouter(),
schema_href=schema_href,
)
119 changes: 118 additions & 1 deletion stac_fastapi/extensions/tests/test_collection_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@
from urllib.parse import quote_plus

import attr
import pytest
from starlette.testclient import TestClient

from stac_fastapi.api.app import StacApi
from stac_fastapi.api.models import create_request_model
from stac_fastapi.extensions.core import (
AggregationExtension,
CollectionSearchExtension,
CollectionSearchPostExtension,
FieldsExtension,
FilterExtension,
FreeTextAdvancedExtension,
FreeTextExtension,
QueryExtension,
SortExtension,
)
from stac_fastapi.extensions.core.collection_search import ConformanceClasses
from stac_fastapi.extensions.core.collection_search.client import (
Expand Down Expand Up @@ -302,8 +310,8 @@ def test_collection_search_extension_post_models():
client=DummyCoreClient(),
extensions=[
CollectionSearchPostExtension(
settings=settings,
client=DummyPostClient(),
settings=settings,
GET=get_request_model,
POST=post_request_model,
conformance_classes=[
Expand Down Expand Up @@ -392,3 +400,112 @@ def test_collection_search_extension_post_models():
assert response_dict["query"]
assert response_dict["sortby"]
assert response_dict["fields"]


@pytest.mark.parametrize(
"extensions",
[
# with FreeTextExtension
[
FieldsExtension(),
FilterExtension(),
FreeTextExtension(),
QueryExtension(),
SortExtension(),
],
# with FreeTextAdvancedExtension
[
FieldsExtension(),
FilterExtension(),
FreeTextAdvancedExtension(),
QueryExtension(),
SortExtension(),
],
],
)
def test_from_extensions_methods(extensions):
"""
Make sure `from_extensions` create the correct
models and adds desired conformances classes.
"""
ext = CollectionSearchExtension.from_extensions(
extensions,
)
collection_search = ext.GET()
assert collection_search.__class__.__name__ == "CollectionsGetRequest"
assert hasattr(collection_search, "bbox")
assert hasattr(collection_search, "datetime")
assert hasattr(collection_search, "limit")
assert hasattr(collection_search, "fields")
assert hasattr(collection_search, "q")
assert hasattr(collection_search, "sortby")
assert hasattr(collection_search, "filter")
assert ext.conformance_classes == [
ConformanceClasses.COLLECTIONSEARCH,
ConformanceClasses.BASIS,
ConformanceClasses.FIELDS,
ConformanceClasses.FILTER,
ConformanceClasses.FREETEXT,
ConformanceClasses.QUERY,
ConformanceClasses.SORT,
]

ext = CollectionSearchPostExtension.from_extensions(
extensions,
client=DummyPostClient(),
settings=ApiSettings(),
)
collection_search = ext.POST()
assert collection_search.__class__.__name__ == "CollectionsPostRequest"
assert hasattr(collection_search, "bbox")
assert hasattr(collection_search, "datetime")
assert hasattr(collection_search, "limit")
assert hasattr(collection_search, "fields")
assert hasattr(collection_search, "q")
assert hasattr(collection_search, "sortby")
assert hasattr(collection_search, "filter")
assert ext.conformance_classes == [
ConformanceClasses.COLLECTIONSEARCH,
ConformanceClasses.BASIS,
ConformanceClasses.FIELDS,
ConformanceClasses.FILTER,
ConformanceClasses.FREETEXT,
ConformanceClasses.QUERY,
ConformanceClasses.SORT,
]


def test_from_extensions_methods_invalid():
"""Should raise warnings for invalid extensions."""
extensions = [
AggregationExtension(),
]
with pytest.warns((UserWarning)):
ext = CollectionSearchExtension.from_extensions(
extensions,
)
collection_search = ext.GET()
assert collection_search.__class__.__name__ == "CollectionsGetRequest"
assert hasattr(collection_search, "bbox")
assert hasattr(collection_search, "datetime")
assert hasattr(collection_search, "limit")
assert ext.conformance_classes == [
ConformanceClasses.COLLECTIONSEARCH,
ConformanceClasses.BASIS,
]

with pytest.warns((UserWarning)):
ext = CollectionSearchPostExtension.from_extensions(
extensions,
client=DummyPostClient(),
settings=ApiSettings(),
)
collection_search = ext.POST()
assert collection_search.__class__.__name__ == "CollectionsPostRequest"
assert hasattr(collection_search, "bbox")
assert hasattr(collection_search, "datetime")
assert hasattr(collection_search, "limit")
assert ext.conformance_classes == [
ConformanceClasses.COLLECTIONSEARCH,
ConformanceClasses.BASIS,
]

0 comments on commit 3ae8d86

Please sign in to comment.