Submit STAC using transactions API #297

Open
wants to merge 12 commits into base: dev
2 changes: 1 addition & 1 deletion dags/veda_data_pipeline/groups/collection_group.py
@@ -44,7 +44,7 @@ def ingest_collection_task(ti=None, collection=None):
         event=collection,
         endpoint="/collections",
         cognito_app_secret=cognito_app_secret,
-        stac_ingestor_api_url=stac_ingestor_api_url
+        ingest_url=stac_ingestor_api_url
     )


53 changes: 33 additions & 20 deletions dags/veda_data_pipeline/groups/processing_tasks.py
@@ -1,14 +1,17 @@
 from datetime import timedelta
 import json
 import logging
 from copy import deepcopy
 import smart_open
 from airflow.models.variable import Variable
 from airflow.decorators import task
 from veda_data_pipeline.utils.submit_stac import submission_handler
+from veda_data_pipeline.utils.submit_stac_transactions import submit_transactions_handler

 group_kwgs = {"group_id": "Process", "tooltip": "Process"}

+airflow_vars = Variable.get("aws_dags_variables")
Member

If we set deserialize_json=True, do you know if we still need to do:

var.transactions_endpoint_enabled==true ? "True" : null

in main.tf?

Suggested change:
-airflow_vars = Variable.get("aws_dags_variables")
+airflow_vars = Variable.get("aws_dags_variables", deserialize_json=True)

And maybe we could refactor this file a bit so we call Variable.get in tasks rather than at the top level, since the docs seem to imply that's best practice: https://airflow.apache.org/docs/apache-airflow/2.8.4/best-practices.html#airflow-variables
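
A minimal sketch of that refactor, assuming deserialize_json=True and a per-task lookup (names mirror this file; the body is illustrative only):

```python
from airflow.decorators import task
from airflow.models.variable import Variable

@task
def submit_to_stac_ingestor_task(built_stac: dict):
    # Fetched at run time rather than at DAG parse time, per the linked best practices.
    airflow_vars = Variable.get("aws_dags_variables", deserialize_json=True)
    transactions_enabled = airflow_vars.get("TRANSACTIONS_ENDPOINT_ENABLED", False)
    ...
```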

Contributor Author
I would love a refactor for this (we call the same set of Variable.get("aws_dags_variables") lines 16 times in our DAGs), but I think that's a different issue. .airflowignore keeps the DAG processor from parsing the groups and utils folders, so there's no performance loss by having these calls here.
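
For context, a sketch of the kind of .airflowignore entries the comment describes (the actual file isn't shown in this PR; patterns are regex by default):

```
# dags/.airflowignore (hypothetical contents)
groups/.*
utils/.*
```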

+airflow_vars_json = json.loads(airflow_vars)
+TRANSACTIONS_ENDPOINT_ENABLED = airflow_vars_json.get("TRANSACTIONS_ENDPOINT_ENABLED", False)

 def log_task(text: str):
     logging.info(text)
@@ -30,32 +33,42 @@ def remove_thumbnail_asset(ti):
     payload.pop("assets", True)
     return payload

+if TRANSACTIONS_ENDPOINT_ENABLED:
+    # assuming default chunk size (500), this matches the current dynamoDB configuration on the STAC ingestor
+    task_kwargs = {"retries": 3, "retry_delay": 10, "retry_exponential_backoff": True, "max_active_tis_per_dag": 2}
+    submit_kwargs = {}
+    ingest_url = airflow_vars_json.get("STAC_URL")
+    app_secret = airflow_vars_json.get("STAC_API_KEYCLOAK_CLIENT_SECRET")
+else:
+    task_kwargs = {"retries": 2, "retry_delay": 60, "retry_exponential_backoff": True, "max_active_tis_per_dag": 5}
+    submit_kwargs = {"endpoint": "/ingestions"}
+    ingest_url = airflow_vars_json.get("STAC_INGESTOR_API_URL")
+    app_secret = airflow_vars_json.get("COGNITO_APP_SECRET")

 # with exponential backoff enabled, retry delay is converted to seconds
-@task(retries=2, retry_delay=60, retry_exponential_backoff=True, max_active_tis_per_dag=5)
+@task(**task_kwargs)
 def submit_to_stac_ingestor_task(built_stac: dict):
     """Submit STAC items to the STAC ingestor API."""
     event = built_stac.copy()
-    success_file = event["payload"]["success_event_key"]
-    with smart_open.open(success_file, "r") as _file:
-        stac_items = json.loads(_file.read())
-
-    airflow_vars = Variable.get("aws_dags_variables")
-    airflow_vars_json = json.loads(airflow_vars)
-    cognito_app_secret = airflow_vars_json.get("COGNITO_APP_SECRET")
-    stac_ingestor_api_url = airflow_vars_json.get("STAC_INGESTOR_API_URL")
+    try:
+        success_file = event["payload"]["success_event_key"]
+        with smart_open.open(success_file, "r") as _file:
+            stac_items = json.loads(_file.read())
+    except KeyError:
+        log_task("No success file found - using event directly")
+        stac_items = [event]

-    for item in stac_items:
-        submission_handler(
-            event=item,
-            endpoint="/ingestions",
-            cognito_app_secret=cognito_app_secret,
-            stac_ingestor_api_url=stac_ingestor_api_url,
-        )
+    if TRANSACTIONS_ENDPOINT_ENABLED:
+        submit_transactions_handler(
+            event=stac_items,
+            cognito_app_secret=app_secret,
+            ingest_url=ingest_url,
+            **submit_kwargs,
+        )
+    else:
+        for item in stac_items:
+            submission_handler(
+                event=item,
+                cognito_app_secret=app_secret,
+                ingest_url=ingest_url,
+                **submit_kwargs,
+            )
     return event

 @task(retries=2, retry_delay=60, retry_exponential_backoff=True, max_active_tis_per_dag=5)
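
One note on the integer retry_delay values above: Airflow coerces plain numbers to second-valued timedeltas, which is the behavior the inline comment alludes to. A hedged illustration of that coercion:

```python
from datetime import timedelta

def normalize_retry_delay(value) -> timedelta:
    # Mirrors Airflow's BaseOperator handling (assumption): numeric
    # retry_delay values are interpreted as seconds.
    return value if isinstance(value, timedelta) else timedelta(seconds=value)

print(normalize_retry_delay(10))                    # 0:00:10
print(normalize_retry_delay(timedelta(minutes=1)))  # 0:01:00
```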
10 changes: 2 additions & 8 deletions dags/veda_data_pipeline/utils/submit_stac.py
@@ -101,11 +101,8 @@ def submission_handler(
     event: Union[S3LinkInput, StacItemInput, Dict[str, Any]],
     endpoint: str = "/ingestions",
     cognito_app_secret=None,
-    stac_ingestor_api_url=None,
-    context=None,
+    ingest_url=None,
 ) -> None | dict:
-    if context is None:
-        context = {}

     stac_item = event

@@ -114,12 +111,9 @@
         print(json.dumps(stac_item, indent=2))
         return

-    cognito_app_secret = cognito_app_secret or os.getenv("COGNITO_APP_SECRET")
-    stac_ingestor_api_url = stac_ingestor_api_url or os.getenv("STAC_INGESTOR_API_URL")
-
     ingestor = IngestionApi.from_veda_auth_secret(
         secret_id=cognito_app_secret,
-        base_url=stac_ingestor_api_url,
+        base_url=ingest_url,
     )
     return ingestor.submit(event=stac_item, endpoint=endpoint)

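For reference, a minimal call sketch of the renamed signature (all values are placeholders, not from this repo):

```python
from veda_data_pipeline.utils.submit_stac import submission_handler

submission_handler(
    event={"type": "Feature", "id": "item-1", "collection": "demo"},
    endpoint="/ingestions",
    cognito_app_secret="cognito-secret-id",   # placeholder secret ID
    ingest_url="https://ingest.example.com",  # placeholder ingestor URL
)
```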
128 changes: 128 additions & 0 deletions dags/veda_data_pipeline/utils/submit_stac_transactions.py
@@ -0,0 +1,128 @@
import json
import logging
import requests
from typing import List, TypedDict
from dataclasses import dataclass

import boto3

logging.basicConfig(level=logging.INFO)


class Creds(TypedDict):
    access_token: str
    expires_in: int
    token_type: str
    scope: str


class Secret(TypedDict):
    userinfo_url: str
    id: str
    secret: str
    auth_url: str
    token_url: str


@dataclass
class TransactionsApi:
    base_url: str
    token: str

    @classmethod
    def from_veda_auth_secret(cls, *, secret_id: str, base_url: str) -> "TransactionsApi":
        secret_details = cls._get_auth_service_details(secret_id)
        credentials = cls._get_app_credentials(**secret_details)
        return cls(token=credentials["access_token"], base_url=base_url)

    @staticmethod
    def _get_auth_service_details(secret_id: str) -> Secret:
        client = boto3.client("secretsmanager")
        response = client.get_secret_value(SecretId=secret_id)
        return json.loads(response["SecretString"])

    @staticmethod
    def _get_app_credentials(
        userinfo_url: str, id: str, secret: str, auth_url: str, token_url: str, **kwargs
    ) -> Creds:
        response = requests.post(
            token_url,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
            data={
                "client_id": id,
                "client_secret": secret,
                "grant_type": "client_credentials",
                "scope": "stac:item:create stac:collection:create stac:collection:update stac:item:update",
            },
        )
        try:
            response.raise_for_status()
        except Exception as ex:
            logging.error(response.text)
            raise RuntimeError(f"Error retrieving app credentials: {ex}") from ex
        return response.json()

    def post_items(self, collection_id: str, items: List[dict]) -> dict:
        """
        Perform a bulk POST request to upsert STAC Items in the given collection.

        :param collection_id: The target collection ID.
        :param items: List of STAC items to be submitted.
        :return: The JSON response (as a dict) from the STAC API.
        :raises RuntimeError: If the response is not 200/201.
        """
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }
        bulk_items = {"items": {item["id"]: item for item in items}, "method": "upsert"}
        response = requests.post(
            f"{self.base_url.rstrip('/')}/collections/{collection_id}/bulk_items",
            headers=headers,
            json=bulk_items,
        )

        if response.status_code not in (200, 201):
            logging.error("Failed POST request: %s %s", response.status_code, response.text)
            raise RuntimeError(f"POST request failed: {response.text}")

        return response.json()


def submit_transactions_handler(
    event,
    cognito_app_secret=None,  # holds the auth secret ID; name kept for signature compatibility w/ the ingest API handler
    ingest_url=None,
):
    """
    Handler that can be integrated in the same way as the existing `submission_handler`,
    but uses the TransactionsApi to bulk-upsert items via the STAC Transactions endpoint.

    :param event: A list of STAC items to submit; all items are assumed to belong
        to the collection of the first item.
    :param cognito_app_secret: Secrets Manager ID of the auth client credentials.
    :param ingest_url: Base URL of the STAC API.
    :return: A dict representing the API response.
    """
    # All items in the batch are posted to the first item's collection.
    collection_id = event[0].get("collection")
    api = TransactionsApi.from_veda_auth_secret(
        secret_id=cognito_app_secret,
        base_url=ingest_url,
    )
    try:
        response = api.post_items(
            collection_id=collection_id,
            items=event,
        )
        logging.info("STAC bulk item POST completed successfully.")
    except RuntimeError as err:
        logging.error("Error while performing POST: %s", str(err))
        raise
    return {
        "statusCode": 200,
        "body": json.dumps({
            "message": "POST request completed successfully",
            "response": response,
        }),
    }
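
A hedged usage sketch of the new handler; item bodies, secret ID, and URL are placeholders, and all items are assumed to share the first item's collection:

```python
from veda_data_pipeline.utils.submit_stac_transactions import submit_transactions_handler

items = [
    {"type": "Feature", "id": "item-1", "collection": "demo-collection"},
    {"type": "Feature", "id": "item-2", "collection": "demo-collection"},
]
result = submit_transactions_handler(
    event=items,
    cognito_app_secret="stac-api-keycloak-secret-id",  # Secrets Manager ID (placeholder)
    ingest_url="https://stac.example.com",             # STAC API base URL (placeholder)
)
```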
2 changes: 0 additions & 2 deletions dags/veda_data_pipeline/veda_promotion_pipeline.py
@@ -96,8 +96,6 @@ def transfer_assets_to_production_bucket(ti=None, payload={}):
     return payload

 with DAG("veda_promotion_pipeline", params=template_dag_run_conf, **dag_args) as dag:
-    # ECS dependency variable
-
     start = EmptyOperator(task_id="start", dag=dag)
     end = EmptyOperator(task_id="end", dag=dag)
4 changes: 3 additions & 1 deletion sm2a/infrastructure/main.tf
@@ -104,7 +104,9 @@ module "sma-base" {
     ASSUME_ROLE_WRITE_ARN = var.assume_role_write_arn,
     SM2A_BASE_URL = module.sma-base.airflow_url,
     CLOUDFRONT_TO_INVALIDATE = var.cloudfront_to_invalidate
-    CLOUDFRONT_PATH_TO_INVALIDATE = var.cloudfront_path_to_invalidate
+    CLOUDFRONT_PATH_TO_INVALIDATE = var.cloudfront_path_to_invalidate,
+    TRANSACTIONS_ENDPOINT_ENABLED = var.transactions_endpoint_enabled==true ? "True" : null
Contributor

@smohiudd commented Apr 23, 2025
I don't think Terraform bools are serialized for Airflow's Python side. I used this hack, but I'm open to a better way.
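
The hack tracks how the value lands in Python: Variable.get returns JSON text, and a serialized boolean would arrive as a string, which is truthy even when it spells false. A quick illustration:

```python
import json

# If Terraform serialized the bool directly, Python would see a string:
vars_json = json.loads('{"TRANSACTIONS_ENDPOINT_ENABLED": "false"}')
print(bool(vars_json.get("TRANSACTIONS_ENDPOINT_ENABLED", False)))  # True -- non-empty string

# With the ternary, the key is absent entirely when disabled:
print(bool(json.loads("{}").get("TRANSACTIONS_ENDPOINT_ENABLED", False)))  # False
```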

+    STAC_API_KEYCLOAK_CLIENT_SECRET = var.stac_api_keycloak_client_secret
     }, var.snapshot_bucket_name != "" ? module.rds_backups[0].rds_backup_environment : {}
   )
 }
10 changes: 10 additions & 0 deletions sm2a/infrastructure/variables.tf
@@ -148,6 +148,7 @@ variable "gh_user_team_id" {

variable "workflows_client_secret" {
}

variable "stac_ingestor_api_url" {
}

@@ -243,3 +244,12 @@ variable "cloudfront_path_to_invalidate" {
variable "lambda_dag_trigger_function_name" {
default = "trigger-sm2a-dag"
}

variable "transactions_endpoint_enabled" {
type = bool
default = false
}

variable stac_api_keycloak_client_secret{
type = string
}