Allow users to disable schema check and creation on load_file (#1922)
Support running `load_file` without checking if the table schema exists
or trying to create it.

Recently a user reported that the cost of checking if the schema exists
is very high for Snowflake:
"I have a (`load_file`) task that took 1:36 minutes to run, and it was
1:30 running the information schema query."
This is likely happening for other databases as well.

Introduce two ways of disabling the schema check:

1. On a per-task basis, by exposing the argument `schema_exists` in
`aql.load_file` (see the sketch below).
When this argument is `True`, the SDK will not check if the schema
exists or try to create it.
It is `False` by default, so the Python SDK behaves as it did in 1.6
(running the schema check and, if needed, trying to create the schema).

2. Globally, by exposing the Airflow configuration
`load_table_schema_exists` in the `[astro_sdk]` section. This can also
be set using the environment variable
`AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS`. The global configuration
can be overridden per task, using option (1).
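
A minimal sketch of the per-task usage. The DAG id, file path, connection id, table
name, and schema below are hypothetical; only `schema_exists=True` is the new piece:

```python
from datetime import datetime

from airflow import DAG

from astro import sql as aql
from astro.files import File
from astro.table import Metadata, Table

with DAG(dag_id="load_file_schema_exists_example", start_date=datetime(2023, 5, 1), schedule=None):
    aql.load_file(
        input_file=File("s3://my-bucket/homes.csv"),  # hypothetical path
        output_table=Table(
            name="homes",  # hypothetical table name
            conn_id="snowflake_default",  # hypothetical connection id
            metadata=Metadata(schema="EXISTING_SCHEMA"),  # schema assumed to already exist
        ),
        schema_exists=True,  # skip the schema existence check and creation attempt
    )
```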

Closes: #1921
tatiana authored May 5, 2023
1 parent af36feb commit 74a6894
Showing 9 changed files with 117 additions and 14 deletions.
2 changes: 2 additions & 0 deletions python-sdk/docs/astro/sql/operators/load_file.rst
@@ -43,6 +43,8 @@ Parameters to use when loading a file to a database table

Note that if you use ``if_exists='replace'``, the existing table will be dropped and the schema of the new data will be used.

#. **schema_exists** (default is ``False``) - By default, the SDK checks whether the schema of the target table exists and, if not, tries to create it. This query can be costly. Setting this argument to ``True`` makes the SDK skip the check, since the user is asserting that the schema already exists.

#. **output_table** - This parameter defines the output table to load data to, which should be an instance of ``astro.sql.table.Table``. You can specify the schema of the table by providing a list of instances of `sqlalchemy.Column <https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column>`_ to the ``columns`` parameter. If you don't specify a schema, it will be inferred using Pandas.

.. literalinclude:: ../../../../example_dags/example_load_file.py
23 changes: 23 additions & 0 deletions python-sdk/docs/configurations.rst
@@ -45,6 +45,29 @@ or by updating Airflow's configuration
redshift_default_schema = "redshift_tmp"
    mssql_default_schema = "mssql_tmp"

Configuring whether schema existence should be checked and whether the SDK should create schemas
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

By default, during ``aql.load_file``, the SDK checks whether the schema of the target table exists and, if not, tries to create it. This check can be costly.

The configuration ``AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS`` allows users to inform the SDK that the schema already exists, skipping this check for all ``load_file`` tasks.

Users can also take more granular control by setting the ``load_file`` argument ``schema_exists`` on a per-task basis (see :ref:`load_file`).

Example of how to disable the schema existence check using an environment variable:

.. code:: ini

   AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS = True

Or using Airflow's configuration file:

.. code:: ini

   [astro_sdk]
   load_table_schema_exists = True
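
The per-task override can look like the sketch below, to be placed inside a DAG definition. The file path, connection id, table name, and schema are only illustrative:

.. code:: python

   from astro import sql as aql
   from astro.files import File
   from astro.table import Metadata, Table

   aql.load_file(
       input_file=File("s3://my-bucket/homes.csv"),  # illustrative path
       output_table=Table(
           name="homes",  # illustrative table name
           conn_id="snowflake_default",  # illustrative connection id
           metadata=Metadata(schema="EXISTING_SCHEMA"),  # schema assumed to already exist
       ),
       schema_exists=True,  # skip the schema existence check for this task only
   )
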
Configuring the unsafe dataframe storage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The dataframes (generated by ``dataframe`` or ``transform`` operators) are stored in the XCom table using pickling in the Airflow metadata database. Since the dataframe is defined by the user, a very large dataframe can potentially break Airflow's metadata DB by using all the available resources. Hence, unsafe dataframe storage should only be set to ``True`` once you are aware of this risk and are OK with it. Alternatively, you could use a custom XCom backend to store the XCom data.
@@ -85,6 +85,7 @@ def example_snowflake_partial_table_with_append():
schema=os.getenv("SNOWFLAKE_SCHEMA"),
),
),
schema_exists=True,  # Skip queries that check if the table schema exists
)

homes_data2 = load_file(
@@ -96,6 +97,7 @@ def example_snowflake_partial_table_with_append():
schema=os.getenv("SNOWFLAKE_SCHEMA"),
),
),
schema_exists=True,
)

# Define task dependencies
17 changes: 13 additions & 4 deletions python-sdk/src/astro/databases/base.py
@@ -30,7 +30,12 @@
from astro.files.types.base import FileType as FileTypeConstants
from astro.options import LoadOptions
from astro.query_modifier import QueryModifier
from astro.settings import LOAD_FILE_ENABLE_NATIVE_FALLBACK, LOAD_TABLE_AUTODETECT_ROWS_COUNT, SCHEMA
from astro.settings import (
LOAD_FILE_ENABLE_NATIVE_FALLBACK,
LOAD_TABLE_AUTODETECT_ROWS_COUNT,
LOAD_TABLE_SCHEMA_EXISTS,
SCHEMA,
)
from astro.table import BaseTable, Metadata
from astro.utils.compat.functools import cached_property

@@ -359,7 +364,7 @@ def drop_table(self, table: BaseTable) -> None:
# Table load methods
# ---------------------------------------------------------

def create_schema_and_table_if_needed(
def create_table_if_needed(
self,
table: BaseTable,
file: File,
@@ -393,7 +398,6 @@ def create_schema_and_table_if_needed(
):
return

self.create_schema_if_needed(table.metadata.schema)
if if_exists == "replace" or not self.table_exists(table):
files = resolve_file_path_pattern(
file.path,
@@ -449,6 +453,7 @@ def load_file_to_table(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
schema_exists: bool = LOAD_TABLE_SCHEMA_EXISTS,
**kwargs,
):
"""
@@ -465,6 +470,7 @@
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
:param schema_exists: Declare that the table schema already exists and that load_file should not check if it exists
"""
normalize_config = normalize_config or {}
if self.check_for_minio_connection(input_file=input_file):
@@ -474,7 +480,10 @@
)
use_native_support = False

self.create_schema_and_table_if_needed(
if not schema_exists:
self.create_schema_if_needed(output_table.metadata.schema)

self.create_table_if_needed(
file=input_file,
table=output_table,
columns_names_capitalization=columns_names_capitalization,
4 changes: 3 additions & 1 deletion python-sdk/src/astro/databases/databricks/delta.py
@@ -22,6 +22,7 @@
from astro.files import File
from astro.options import LoadOptions
from astro.query_modifier import QueryModifier
from astro.settings import LOAD_TABLE_SCHEMA_EXISTS
from astro.table import BaseTable, Metadata


@@ -123,6 +124,7 @@ def load_file_to_table(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = None,
schema_exists: bool = LOAD_TABLE_SCHEMA_EXISTS,
databricks_job_name: str = "",
**kwargs,
):
@@ -142,7 +144,7 @@
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
:param schema_exists: Declare that the table schema already exists and that load_file should not check if it exists
"""
load_file_to_delta(
input_file=input_file,
14 changes: 13 additions & 1 deletion python-sdk/src/astro/settings.py
@@ -68,10 +68,22 @@
section=SECTION_KEY, key="load_table_autodetect_rows_count", fallback=1000
)


#: Reduce response sizes returned by aql.run_raw_sql to avoid trashing the Airflow DB if the BaseXCom is used.
RAW_SQL_MAX_RESPONSE_SIZE = conf.getint(section=SECTION_KEY, key="run_raw_sql_response_size", fallback=-1)

# Temp changes
# Should Astro SDK automatically add inlets/outlets to take advantage of Airflow 2.4 Data-aware scheduling
AUTO_ADD_INLETS_OUTLETS = conf.getboolean(SECTION_KEY, "auto_add_inlets_outlets", fallback=True)

LOAD_TABLE_SCHEMA_EXISTS = False


def reload():
"""
Reload settings from environment variable during runtime.
"""
global LOAD_TABLE_SCHEMA_EXISTS # skipcq: PYL-W0603
LOAD_TABLE_SCHEMA_EXISTS = conf.getboolean(SECTION_KEY, "load_table_schema_exists", fallback=False)


reload()
8 changes: 6 additions & 2 deletions python-sdk/src/astro/sql/operators/load_file.py
@@ -8,6 +8,7 @@
from airflow.hooks.base import BaseHook
from airflow.models.xcom_arg import XComArg

from astro import settings
from astro.airflow.datasets import kwargs_with_datasets
from astro.constants import DEFAULT_CHUNK_SIZE, ColumnCapitalization, LoadExistStrategy
from astro.databases import create_database
@@ -21,7 +22,6 @@
from astro.dataframes.pandas import PandasDataframe
from astro.files import File, resolve_file_path_pattern
from astro.options import LoadOptions, LoadOptionsList
from astro.settings import LOAD_FILE_ENABLE_NATIVE_FALLBACK
from astro.sql.operators.base_operator import AstroSQLBaseOperator
from astro.table import BaseTable
from astro.utils.compat.typing import Context
@@ -47,6 +47,7 @@ class LoadFileOperator(AstroSQLBaseOperator):
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
:param schema_exists: Declare that the table schema already exists and that load_file should not check if it exists
:return: If ``output_table`` is passed this operator returns a Table object. If not
passed, returns a dataframe.
@@ -65,7 +66,8 @@ def __init__(
native_support_kwargs: dict | None = None,
load_options: LoadOptions | list[LoadOptions] | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
enable_native_fallback: bool | None = settings.LOAD_FILE_ENABLE_NATIVE_FALLBACK,
schema_exists: bool = settings.LOAD_TABLE_SCHEMA_EXISTS,
**kwargs,
) -> None:
kwargs.setdefault("task_id", get_unique_task_id("load_file"))
@@ -112,6 +114,7 @@ def __init__(
self.native_support_kwargs: dict[str, Any] = native_support_kwargs or {}
self.columns_names_capitalization = columns_names_capitalization
self.enable_native_fallback = enable_native_fallback
self.schema_exists = schema_exists
self.load_options_list = LoadOptionsList(load_options)

def execute(self, context: Context) -> BaseTable | File: # skipcq: PYL-W0613
@@ -159,6 +162,7 @@ def load_data_to_table(self, input_file: File, context: Context) -> BaseTable:
native_support_kwargs=self.native_support_kwargs,
columns_names_capitalization=self.columns_names_capitalization,
enable_native_fallback=self.enable_native_fallback,
schema_exists=self.schema_exists,
databricks_job_name=f"Load data {self.dag_id}_{self.task_id}",
)
self.log.info("Completed loading the data into %s.", self.output_table)
39 changes: 33 additions & 6 deletions python-sdk/tests/databases/test_snowflake.py
@@ -20,6 +20,8 @@
SNOWFLAKE_STORAGE_INTEGRATION_AMAZON = SNOWFLAKE_STORAGE_INTEGRATION_AMAZON or "aws_int_python_sdk"
SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE = SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE or "gcs_int_python_sdk"

LOCAL_CSV_FILE = str(CWD.parent / "data/homes_main.csv")


def test_stage_set_name_after():
stage = SnowflakeStage()
@@ -111,11 +113,10 @@ def test_load_file_to_table_natively_for_fallback_raises_exception_if_not_enable


def test_snowflake_load_options():
path = str(CWD) + "/../../data/homes_main.csv"
database = SnowflakeDatabase(
conn_id="fake-conn", load_options=SnowflakeLoadOptions(file_options={"foo": "bar"})
)
file = File(path)
file = File(path=LOCAL_CSV_FILE)
with mock.patch(
"astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
), mock.patch(
@@ -132,9 +133,8 @@


def test_snowflake_load_options_default():
path = str(CWD) + "/../../data/homes_main.csv"
database = SnowflakeDatabase(conn_id="fake-conn", load_options=SnowflakeLoadOptions())
file = File(path)
file = File(path=LOCAL_CSV_FILE)
with mock.patch(
"astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
), mock.patch(
@@ -151,8 +151,7 @@


def test_snowflake_load_options_wrong_options():
path = str(CWD) + "/../../data/homes_main.csv"
file = File(path)
file = File(path=LOCAL_CSV_FILE)
with pytest.raises(ValueError, match="Error: Requires a SnowflakeLoadOptions"):
database = SnowflakeDatabase(conn_id="fake-conn", load_options=LoadOptions())
database.load_file_to_table_natively(source_file=file, target_table=Table())
@@ -211,3 +210,31 @@ def test_storage_integrations_params_in_load_options():
database.load_file_to_table_natively(source_file=file, target_table=table)

assert create_stage.call_args.kwargs["storage_integration"] == "some_integrations"


def test_load_file_to_table_by_default_checks_schema():
database = SnowflakeDatabase(conn_id="fake-conn")
database.run_sql = MagicMock()
database.hook = MagicMock()
database.create_table_using_schema_autodetection = MagicMock()

file_ = File(path=LOCAL_CSV_FILE)
table = Table(conn_id="fake-conn", metadata=Metadata(schema="abc"))
database.load_file_to_table(input_file=file_, output_table=table)
expected = (
"SELECT SCHEMA_NAME from information_schema.schemata WHERE LOWER(SCHEMA_NAME) = %(schema_name)s;"
)
assert database.hook.run.call_args_list[0].args[0] == expected
assert database.hook.run.call_args_list[0].kwargs["parameters"]["schema_name"] == "abc"


def test_load_file_to_table_skips_schema_check():
database = SnowflakeDatabase(conn_id="fake-conn")
database.run_sql = MagicMock()
database.hook = MagicMock()
database.create_table_using_schema_autodetection = MagicMock()

file_ = File(path=LOCAL_CSV_FILE)
table = Table(conn_id="fake-conn", metadata=Metadata(schema="abc"))
database.load_file_to_table(input_file=file_, output_table=table, schema_exists=True)
assert not database.hook.run.call_count
22 changes: 22 additions & 0 deletions python-sdk/tests/test_settings.py
@@ -0,0 +1,22 @@
import importlib
import os
from unittest.mock import patch

import astro
from astro import settings
from astro.files import File


def test_settings_load_table_schema_exists_default():
from astro.sql import LoadFileOperator

load_file = LoadFileOperator(input_file=File("dummy.csv"))
assert not load_file.schema_exists


@patch.dict(os.environ, {"AIRFLOW__ASTRO_SDK__LOAD_TABLE_SCHEMA_EXISTS": "True"})
def test_settings_load_table_schema_exists_override():
settings.reload()
importlib.reload(astro.sql.operators.load_file)
load_file = astro.sql.operators.load_file.LoadFileOperator(input_file=File("dummy.csv"))
assert load_file.schema_exists
