Add faster loading through parquet #54

Merged · 21 commits · Jun 27, 2024

Changes from all commits

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -46,7 +46,7 @@ jobs:
       - name: Run tests
         run: |
           source venv/bin/activate
-          python -m pytest -rsx tests/ --data-path=/home/runner/work/fair-mast/fair-mast/tests/mock_data/mini
+          python -m pytest -rsx tests/ --data-path=/home/runner/work/fair-mast/fair-mast/tests/mock_data/index

   ruff-code-check:
     runs-on: ubuntu-latest

1 change: 1 addition & 0 deletions dev/docker/docker-compose.yml
@@ -12,6 +12,7 @@ services:
     restart: always
     volumes:
       - ../../tests/mock_data:/code/data
+      - ../../data/index:/code/index
       - ../../src:/code/src
     ports:
       - '8081:5000'

2 changes: 1 addition & 1 deletion docs/_config.yml
@@ -7,7 +7,7 @@ logo: assets/MAST_plasma_image.jpg

 # Information about where the book exists on the web
 repository:
-  url: https://github.com/samueljackson92/mast-book # Online location of your book
+  url: https://github.com/ukaea/fair-mast/ # Online location of your book
   branch: main

 exclude_patterns: [data/*, 'data']

7 changes: 3 additions & 4 deletions docs/config.yml
@@ -1,5 +1,4 @@
-# host: https://mastapp.site
-host: http://localhost:8081
-rest_api: http://localhost:8081/json
+host: https://mastapp.site
+rest_api: http://mastapp.site/json
 graphql_api: https://mastapp.site/graphql
-s3_api: https://s3.echo.stfc.ac.uk
+s3_api: https://s3.echo.stfc.ac.uk
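
Note: the config now points at the public deployment. A minimal sketch of calling the REST root from Python, assuming the /json prefix serves collection routes; the /json/shots path and the limit parameter are illustrative assumptions, not confirmed by this diff:

import requests

BASE = "https://mastapp.site"

# Hypothetical route and parameter: only the /json prefix is confirmed above.
resp = requests.get(f"{BASE}/json/shots", params={"limit": 5})
resp.raise_for_status()
print(resp.json())
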
186 changes: 53 additions & 133 deletions src/api/create.py
@@ -1,10 +1,12 @@
+import math
 import numpy as np
 from enum import Enum
 from pathlib import Path
 import pandas as pd
 import dask
 import click
 import uuid
+import pyarrow.parquet as pq
 from tqdm import tqdm
 from sqlalchemy_utils.functions import (
     drop_database,
@@ -107,14 +109,14 @@ def create_user(self):

     def create_cpf_summary(self, data_path: Path):
         """Create the CPF summary table"""
-        paths = data_path.glob("*_cpf_columns.parquet")
+        paths = data_path.glob("cpf/*_cpf_columns.parquet")
         for path in paths:
             df = pd.read_parquet(path)
             df.to_sql("cpf_summary", self.uri, if_exists="replace")

     def create_scenarios(self, data_path: Path):
         """Create the scenarios metadata table"""
-        shot_file_name = data_path.parent / "shot_metadata.parquet"
+        shot_file_name = data_path / "shots.parquet"
         shot_metadata = pd.read_parquet(shot_file_name)
         ids = shot_metadata["scenario_id"].unique()
         scenarios = shot_metadata["scenario"].unique()
@@ -125,21 +127,26 @@ def create_scenarios(self, data_path: Path):

     def create_shots(self, data_path: Path):
         """Create the shot metadata table"""
-        shot_file_name = data_path.parent / "shot_metadata.parquet"
-        shot_metadata = pd.read_parquet(shot_file_name)
+        sources_file = data_path / "sources.parquet"
+        sources_metadata = pd.read_parquet(sources_file)
+        shot_ids = sources_metadata.shot_id.unique()
+
+        shot_file_name = data_path / "shots.parquet"
+        shot_metadata = pd.read_parquet(shot_file_name)
         shot_metadata = shot_metadata.loc[shot_metadata["shot_id"] <= LAST_MAST_SHOT]
-        shot_metadata["facility"] = "MAST"
+        shot_metadata = shot_metadata.loc[shot_metadata.shot_id.isin(shot_ids)]
         shot_metadata = shot_metadata.set_index("shot_id", drop=True)
         shot_metadata = shot_metadata.sort_index()

         shot_metadata["scenario"] = shot_metadata["scenario_id"]
+        shot_metadata["facility"] = "MAST"
         shot_metadata = shot_metadata.drop(["scenario_id", "reference_id"], axis=1)
         shot_metadata["uuid"] = shot_metadata.index.map(get_dataset_uuid)
         shot_metadata["url"] = (
             "s3://mast/level1/shots/" + shot_metadata.index.astype(str) + ".zarr"
         )

-        paths = data_path.glob("*_cpf_data.parquet")
+        paths = data_path.glob("cpf/*_cpf_data.parquet")
         cpfs = []
         for path in paths:
             cpf_metadata = read_cpf_metadata(path)
@@ -148,153 +155,69 @@ def create_shots(self, data_path: Path):
             cpfs.append(cpf_metadata)

         cpfs = pd.concat(cpfs, axis=0)
         cpfs = cpfs.reset_index()
+        cpfs = cpfs.loc[cpfs.shot_id <= LAST_MAST_SHOT]
         cpfs = cpfs.drop_duplicates(subset="shot_id")
         cpfs = cpfs.set_index("shot_id")

         shot_metadata = pd.merge(
             shot_metadata,
             cpfs,
             left_on="shot_id",
             right_on="shot_id",
-            how="inner",
+            how="left",
         )

         shot_metadata.to_sql("shots", self.uri, if_exists="append")

-    def create_signal_datasets(self, file_name: str, url_type: URLType = URLType.S3):
-        """Create the signal metadata table"""
-        signal_dataset_metadata = pd.read_parquet(file_name)
-        signal_dataset_metadata = signal_dataset_metadata.loc[
-            ~signal_dataset_metadata.uri.str.contains("mini")
-        ]
-        signal_dataset_metadata = signal_dataset_metadata.loc[
-            ~signal_dataset_metadata["type"].isna()
-        ]
-
-        signal_dataset_metadata["name"] = signal_dataset_metadata["name"].map(
-            normalize_signal_name
-        )
-        signal_dataset_metadata["quality"] = signal_dataset_metadata["status"].map(
-            lookup_status_code
-        )
-
-        signal_dataset_metadata["dimensions"] = signal_dataset_metadata[
-            "dimensions"
-        ].map(list)
-        signal_dataset_metadata["doi"] = ""
-
-        signal_dataset_metadata["url"] = signal_dataset_metadata["name"].map(
-            lambda name: f"s3://mast/{name}.zarr"
-        )
-
-        signal_dataset_metadata["signal_type"] = signal_dataset_metadata["type"]
-        signal_dataset_metadata["csd3_path"] = signal_dataset_metadata["uri"]
-
-        signal_metadata = signal_dataset_metadata[
-            [
-                # "context_",
-                "uuid",
-                "name",
-                "description",
-                "signal_type",
-                "quality",
-                "dimensions",
-                "rank",
-                "units",
-                "doi",
-                "url",
-                "csd3_path",
-            ]
-        ]
-        signal_metadata.to_sql(
-            "signal_datasets", self.uri, if_exists="append", index=False
-        )

     def create_signals(self, data_path: Path):
-        logging.info(f"Loading signals from {data_path}/signals")
-        file_names = data_path.glob("signals/**/*.parquet")
-        file_names = list(file_names)
+        logging.info(f"Loading signals from {data_path}")
+        file_name = data_path / "signals.parquet"
+
+        parquet_file = pq.ParquetFile(file_name)
+        batch_size = 10000
+        n = math.ceil(parquet_file.scan_contents() / batch_size)
+        for batch in tqdm(parquet_file.iter_batches(batch_size=batch_size), total=n):
+            signals_metadata = batch.to_pandas()

-        for file_name in tqdm(file_names):
-            signals_metadata = pd.read_parquet(file_name)
             signals_metadata = signals_metadata.rename(
                 columns=dict(shot_nums="shot_id")
             )

-            if len(signals_metadata) == 0 or "shot_id" not in signals_metadata.columns:
-                continue
-
             df = signals_metadata
+            df = df[df.shot_id <= LAST_MAST_SHOT].copy()
-            df = df.rename({"dataset_item_uuid": "uuid"}, axis=1)
+            df["uuid"] = [
+                get_dataset_item_uuid(item["name"], item["shot_id"])
+                for key, item in df.iterrows()
+            ]
-            df = df[df.shot_id <= LAST_MAST_SHOT]
             df = df.drop_duplicates(subset="uuid")

             df["quality"] = df["status"].map(lookup_status_code)

-            df["shape"] = df["shape"].map(
-                lambda x: x.tolist() if x is not None else None
-            )
+            df["shape"] = df["shape"].map(lambda x: x.tolist())
             df["dimensions"] = df["dimensions"].map(lambda x: x.tolist())

             df["url"] = (
-                "s3://mast/shots/M9/" + df["shot_id"].map(str) + ".zarr/" + df["group"]
+                "s3://mast/level1/shots/"
+                + df["shot_id"].map(str)
+                + ".zarr/"
+                + df["name"]
             )

             df["version"] = 0
             df["signal_type"] = df["type"]

             if "IMAGE_SUBCLASS" not in df:
                 df["IMAGE_SUBCLASS"] = None

             df["subclass"] = df["IMAGE_SUBCLASS"]

             if "format" not in df:
                 df["format"] = None

             if "units" not in df:
                 df["units"] = ""

-            uda_attributes = ["uda_name", "mds_name", "file_name", "format"]
-            df = df.drop(uda_attributes, axis=1)
             df["shot_id"] = df.shot_id.astype(int)
             columns = [
                 "uuid",
                 "shot_id",
                 "quality",
                 "shape",
                 "name",
                 "url",
                 "version",
                 "units",
                 "signal_type",
                 "description",
                 "subclass",
                 "format",
             ]
             df = df[columns]
-            df = df.set_index("shot_id")
+            df = df.set_index("shot_id", drop=True)
             df["description"] = df.description.map(lambda x: "" if x is None else x)
             df.to_sql("signals", self.uri, if_exists="append")

     def create_sources(self, data_path: Path):
-        source_metadata = pd.read_parquet(data_path.parent / "sources_metadata.parquet")
-        source_metadata["name"] = source_metadata["source_alias"]
-        source_metadata["source_type"] = source_metadata["type"]
-        source_metadata = source_metadata[["description", "name", "source_type"]]
-        source_metadata = source_metadata.drop_duplicates()
-        source_metadata = source_metadata.sort_values("name")
-        source_metadata.to_sql("sources", self.uri, if_exists="append", index=False)
-
-    def create_shot_source_links(self, data_path: Path):
-        sources_metadata = pd.read_parquet(
-            data_path.parent / "sources_metadata.parquet"
-        )
-        sources_metadata["source"] = sources_metadata["source_alias"]
-        sources_metadata["quality"] = sources_metadata["status"].map(lookup_status_code)
-        sources_metadata["shot_id"] = sources_metadata["shot"].astype(int)
-        sources_metadata = sources_metadata[
-            ["source", "shot_id", "quality", "pass", "format"]
-        ]
-        sources_metadata = sources_metadata.sort_values("source")
-        sources_metadata.to_sql(
-            "shot_source_link", self.uri, if_exists="append", index=False
-        )
+        source_metadata = pd.read_parquet(data_path / "sources.parquet")
+        source_metadata = source_metadata.drop_duplicates("uuid")
+        source_metadata = source_metadata.loc[source_metadata.shot_id <= LAST_MAST_SHOT]
+        source_metadata["url"] = (
+            "s3://mast/level1/shots/"
+            + source_metadata["shot_id"].map(str)
+            + ".zarr/"
+            + source_metadata["name"]
+        )
+        column_names = ["uuid", "shot_id", "name", "description", "quality", "url"]
+        source_metadata = source_metadata[column_names]
+        source_metadata.to_sql("sources", self.uri, if_exists="append", index=False)


 def read_cpf_metadata(cpf_file_name: Path) -> pd.DataFrame:
@@ -320,22 +243,19 @@ def create_db_and_tables(data_path):

     # populate the database tables
     logging.info("Create CPF summary")
-    client.create_cpf_summary(data_path)
+    client.create_cpf_summary(data_path / "cpf")

     logging.info("Create Scenarios")
     client.create_scenarios(data_path)

     logging.info("Create Shots")
     client.create_shots(data_path)

-    logging.info("Create Signals")
-    client.create_signals(data_path)

     logging.info("Create Sources")
     client.create_sources(data_path)

-    logging.info("Create Shot Source Links")
-    client.create_shot_source_links(data_path)
+    logging.info("Create Signals")
+    client.create_signals(data_path)

     client.create_user()
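Note: the core of this PR is the switch in create_signals from globbing many per-shot parquet files to streaming one consolidated signals.parquet in batches. A minimal sketch of that pattern in isolation, assuming only pyarrow, pandas, and tqdm; the helper name iter_parquet_frames is illustrative, not part of this PR:

import math

import pyarrow.parquet as pq
from tqdm import tqdm


def iter_parquet_frames(path, batch_size=10_000):
    """Yield pandas DataFrames from one parquet file, batch by batch,
    so the full table never has to sit in memory at once."""
    parquet_file = pq.ParquetFile(path)
    # num_rows comes from the parquet footer, so sizing the progress bar is cheap;
    # the PR's scan_contents() returns the same count but scans the data to get it.
    n_batches = math.ceil(parquet_file.metadata.num_rows / batch_size)
    for batch in parquet_file.iter_batches(batch_size=batch_size):
        yield batch.to_pandas()


for df in tqdm(iter_parquet_frames("signals.parquet")):
    ...  # e.g. df.to_sql("signals", uri, if_exists="append")
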
18 changes: 8 additions & 10 deletions src/api/graphql.py
@@ -260,21 +260,19 @@ def on_request_end(self):
 )
 class Shot:
     @strawberry.field
-    def signal_datasets(
-        self,
-        limit: Optional[int] = None,
-        where: Optional[ShotWhereFilter] = None,
-    ) -> List[strawberry.LazyType["Shot", __module__]]:  # noqa: F821
-        results = do_where_child_member(self.signal_datasets, where)
+    def signals(
+        self, limit: Optional[int] = None, where: Optional[SignalWhereFilter] = None
+    ) -> List[strawberry.LazyType["Signal", __module__]]:  # noqa: F821
+        results = do_where_child_member(self.signals, where)
         if limit is not None:
             results = results[:limit]
         return results

     @strawberry.field
-    def signals(
-        self, limit: Optional[int] = None, where: Optional[SignalWhereFilter] = None
-    ) -> List[strawberry.LazyType["Signal", __module__]]:  # noqa: F821
-        results = do_where_child_member(self.signals, where)
+    def sources(
+        self, limit: Optional[int] = None, where: Optional[SourceWhereFilter] = None
+    ) -> List[strawberry.LazyType["Source", __module__]]:  # noqa: F821
+        results = do_where_child_member(self.sources, where)
         if limit is not None:
             results = results[:limit]
         return results
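
Note: with signal_datasets gone from the schema, nested queries go through the new signals and sources fields. A sketch of what a query could look like against the /graphql endpoint from docs/config.yml; the root shots field and the nested attribute names are assumptions based on the columns written in create.py, not confirmed by this diff:

import requests

# Hypothetical query shape: all field names here are assumptions.
QUERY = """
{
  shots(limit: 1) {
    signals(limit: 5) { name url quality }
    sources(limit: 5) { name description }
  }
}
"""

resp = requests.post("https://mastapp.site/graphql", json={"query": QUERY})
resp.raise_for_status()
print(resp.json())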