diff --git a/.circleci/manage-test-db.sh b/.circleci/manage-test-db.sh index fb772ec27..21fe49055 100755 --- a/.circleci/manage-test-db.sh +++ b/.circleci/manage-test-db.sh @@ -97,12 +97,16 @@ bigquery_init() { echo "$BIGQUERY_KEYFILE_CONTENTS" > $BIGQUERY_KEYFILE } -bigquery_up() { - echo "BigQuery doesnt support creating databases" -} -bigquery_down() { - echo "BigQuery doesnt support dropping databases" +# Clickhouse cloud +clickhouse-cloud_init() { + # note: the ping endpoint doesnt seem to need any API keys + until curl https://$CLICKHOUSE_CLOUD_HOST:8443/ping + do + echo "Pinging Clickhouse Cloud service to ensure it's not in idle mode..." + sleep 5 + done + echo "Clickhouse Cloud instance $CLICKHOUSE_CLOUD_HOST is up and running" } INIT_FUNC="${ENGINE}_init" @@ -118,10 +122,10 @@ fi echo "Initializing $ENGINE" $INIT_FUNC -if [ "$DIRECTION" == "up" ]; then +if [ "$DIRECTION" == "up" ] && function_exists $UP_FUNC; then echo "Creating database $DB_NAME" $UP_FUNC $DB_NAME -elif [ "$DIRECTION" == "down" ]; then +elif [ "$DIRECTION" == "down" ] && function_exists $DOWN_FUNC; then echo "Dropping database $DB_NAME" $DOWN_FUNC $DB_NAME fi diff --git a/Makefile b/Makefile index c2bda2e0c..24d9f6e54 100644 --- a/Makefile +++ b/Makefile @@ -211,7 +211,7 @@ redshift-test: guard-REDSHIFT_HOST guard-REDSHIFT_USER guard-REDSHIFT_PASSWORD g pytest -n auto -x -m "redshift" --retries 3 --junitxml=test-results/junit-redshift.xml clickhouse-cloud-test: guard-CLICKHOUSE_CLOUD_HOST guard-CLICKHOUSE_CLOUD_USERNAME guard-CLICKHOUSE_CLOUD_PASSWORD engine-clickhouse-install - pytest -n auto -x -m "clickhouse_cloud" --retries 3 --junitxml=test-results/junit-clickhouse-cloud.xml + pytest -n 1 -m "clickhouse_cloud" --retries 3 --junitxml=test-results/junit-clickhouse-cloud.xml athena-test: guard-AWS_ACCESS_KEY_ID guard-AWS_SECRET_ACCESS_KEY guard-ATHENA_S3_WAREHOUSE_LOCATION engine-athena-install pytest -n auto -x -m "athena" --retries 3 --retry-delay 10 --junitxml=test-results/junit-athena.xml \ No newline at end of file diff --git a/docs/cloud/features/observability/prod_environment.md b/docs/cloud/features/observability/prod_environment.md index 9ced172fb..71fc97dda 100644 --- a/docs/cloud/features/observability/prod_environment.md +++ b/docs/cloud/features/observability/prod_environment.md @@ -2,11 +2,34 @@ A data transformation system's most important component is the production environment, which provides the data your business runs on. -Tobiko Cloud makes it easy to understand your production environment, embedding three observability features directly on your project's homepage: +When you first log in to Tobiko Cloud, you'll see the production environment page. This page shows you at a glance if your data systems are working properly. + +It helps data teams quickly check their work without having to dig through complicated logs - just look at the visual dashboard, and you'll know if everything is running smoothly. + +![tcloud prod env](./prod_environment/tcloud_prod_environment.png) + +## When you might use this + +**After a production update** + +The dashboard helps you check if your recent updates to production are working correctly. It uses a simple color system to show you what's happening: green means everything is good, and red shows where there might be problems. + +If you see red in your current run, plan or freshness, it means there's a problem that needs your attention. 
Don't worry about red marks from the past (in the historical and previous runs/plans) - these are old issues that have already been fixed. + +Best part? You can check all of this in about 5-10 seconds. + +**Quick cost check** + +The homepage also displays cost metrics for your production environment, a feature exclusive to production (not available in development environments). This allows you to quickly understand and monitor your team's model execution costs without diving into detailed reports. + +## Observing production + +Tobiko Cloud makes it easy to understand your production environment, embedding four observability features directly on your project's homepage: 1. [Model Freshness chart](./model_freshness.md) 2. Runs and plans chart 3. Recent activity table +4. Warehouse costs overview ![tcloud prod env](./prod_environment/tcloud_prod_environment_labelled.png) @@ -14,7 +37,7 @@ Tobiko Cloud makes it easy to understand your production environment, embedding Model freshness has its own feature page - learn more [here](./model_freshness.md)! -## Runs and Plans Chart +### Runs and Plans Chart SQLMesh performs two primary actions: running the project's models on a cadence and applying plans to update the project's content/behavior. @@ -31,11 +54,13 @@ Each day displays zero or more vertical bars representing `run` duration. If no The chart's `y-axis` represents `run` duration. The height of each `run`'s bar corresponds to its duration, allowing you to quickly assess execution times. For example, consider the leftmost entry in the figure above: + - The label at the top of the chart shows that it represents November 26 - The entry consists of a single green bar, which tells us that one successful `run` occurred - The bottom of the bar begins at 0 seconds on the `y-axis`, and the top of the bar ends at 20 seconds, telling us the `run` took 20 seconds to execute In contrast, consider the rightmost entry in the figure above: + - The label at the top of the chart shows that it represents December 9 - The entry contains two green bars, which tells us that two successful `run`s occurred - The lower bar begins at 0 seconds on the `y-axis` and reaches up to 13 seconds, telling us the `run` took 13 seconds to execute @@ -43,7 +68,7 @@ In contrast, consider the rightmost entry in the figure above: Learn more about a `run` or `plan` by hovering over its bar, which displays a link to its page, its start and end times, and its duration. -## Recent Activity Table +### Recent Activity Table The recent activity table provides comprehensive information about recent project activities, displaying both `run`s and `plan`s in chronological order. This provides a more granular view than the runs and plans chart. @@ -51,4 +76,11 @@ For each activity entry, you can view its completion status, estimated cost of e ![tcloud recent activity](./prod_environment/recent_activity.png) -The table provides the ability to filter which rows are displayed by typing into the text box in the top right. This helps you locate specific information within the activity log, making it easier to find and analyze particular events or patterns in your system's operational history. \ No newline at end of file +The table provides the ability to filter which rows are displayed by typing into the text box in the top right. This helps you locate specific information within the activity log, making it easier to find and analyze particular events or patterns in your system's operational history. 
+ +### Warehouse Costs Overview +Managing data warehouse costs can be complex. Tobiko Cloud simplifies this by monitoring costs directly. For BigQuery and Snowflake projects, it tracks cost estimates per model and calculates savings from avoided model reruns. + +The costs and savings summary information and chart display the costs to run and host all the models in your production environment over the last 30 days. This provides a great way to quickly see increases and decreases in daily running costs. To learn more, [check out the cost savings docs](../costs_savings.md). + +![tcloud recent activity](./prod_environment/costs.png) \ No newline at end of file diff --git a/docs/cloud/features/observability/prod_environment/costs.png b/docs/cloud/features/observability/prod_environment/costs.png new file mode 100644 index 000000000..469d802ad Binary files /dev/null and b/docs/cloud/features/observability/prod_environment/costs.png differ diff --git a/docs/cloud/features/observability/prod_environment/recent_activity.png b/docs/cloud/features/observability/prod_environment/recent_activity.png index 998945843..727bfe21f 100644 Binary files a/docs/cloud/features/observability/prod_environment/recent_activity.png and b/docs/cloud/features/observability/prod_environment/recent_activity.png differ diff --git a/docs/cloud/features/observability/prod_environment/tcloud_prod_environment.png b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment.png new file mode 100644 index 000000000..f1f6104f8 Binary files /dev/null and b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment.png differ diff --git a/docs/cloud/features/observability/prod_environment/tcloud_prod_environment_labelled.png b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment_labelled.png index 9b2a7bc26..ca52c49f0 100644 Binary files a/docs/cloud/features/observability/prod_environment/tcloud_prod_environment_labelled.png and b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment_labelled.png differ diff --git a/docs/cloud/tcloud_getting_started.md b/docs/cloud/tcloud_getting_started.md index 75b811afb..305cc245a 100644 --- a/docs/cloud/tcloud_getting_started.md +++ b/docs/cloud/tcloud_getting_started.md @@ -1,5 +1,7 @@ # Tobiko Cloud: Getting Started +
+ Tobiko Cloud is a data platform that extends SQLMesh to make it easy to manage data at scale without the waste. We're here to make it easy to get started and feel confident that everything is working as expected. After you've completed the steps below, you'll have achieved the following: @@ -225,10 +227,10 @@ Now we're ready to connect your data warehouse to Tobiko Cloud: skip_pr_backfill: false enable_deploy_command: true auto_categorize_changes: - external: full - python: full - sql: full - seed: full + external: full + python: full + sql: full + seed: full # preview data for forward only models plan: diff --git a/docs/concepts/macros/jinja_macros.md b/docs/concepts/macros/jinja_macros.md index 3b9df4372..7018ecc9c 100644 --- a/docs/concepts/macros/jinja_macros.md +++ b/docs/concepts/macros/jinja_macros.md @@ -50,6 +50,30 @@ JINJA_STATEMENT_BEGIN; JINJA_END; ``` +## SQLMesh predefined variables + +SQLMesh provides multiple [predefined macro variables](./macro_variables.md) you may reference in jinja code. + +Some predefined variables provide information about the SQLMesh project itself, like the [`runtime_stage`](./macro_variables.md#runtime-variables) and [`this_model`](./macro_variables.md#runtime-variables) variables. + +Other predefined variables are [temporal](./macro_variables.md#temporal-variables), like `start_ds` and `execution_date`. They are used to build incremental model queries and are only available in incremental model kinds. + +Access predefined macro variables by passing their unquoted name in curly braces. For example, this demonstrates how to access the `start_ds` and `end_ds` variables: + +```sql linenums="1" +JINJA_QUERY_BEGIN; + +SELECT * +FROM table +WHERE time_column BETWEEN '{{ start_ds }}' and '{{ end_ds }}'; + +JINJA_END; +``` + +Because the two macro variables return string values, we must surround the curly braces with single quotes `'`. Other macro variables, such as `start_epoch`, return numeric values and do not require the single quotes. + +The `gateway` variable uses a slightly different syntax than other predefined variables because it is a function call. Instead of the bare name `{{ gateway }}`, it must include parentheses: `{{ gateway() }}`. + ## User-defined variables SQLMesh supports two kinds of user-defined macro variables: global and local. diff --git a/docs/concepts/macros/macro_variables.md b/docs/concepts/macros/macro_variables.md index 3fab383b5..431ce1560 100644 --- a/docs/concepts/macros/macro_variables.md +++ b/docs/concepts/macros/macro_variables.md @@ -126,6 +126,8 @@ SQLMesh provides two other predefined variables used to modify model behavior ba * 'loading' - The project is being loaded into SQLMesh's runtime context. * 'creating' - The model tables are being created. * 'evaluating' - The model query logic is being evaluated. + * 'promoting' - The model is being promoted in the target environment (virtual layer update). + * 'auditing' - The audit is being run. * 'testing' - The model query logic is being evaluated in the context of a unit test. * @gateway - A string value containing the name of the current [gateway](../../guides/connections.md). * @this_model - A string value containing the name of the physical table the model view selects from. Typically used to create [generic audits](../audits.md#generic-audits). In the case of [on_virtual_update statements](../models/sql_models.md#optional-on-virtual-update-statements) it contains the qualified view name instead. 
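The `jinja_macros.md` and `macro_variables.md` hunks above document the predefined variables (including the new `promoting` and `auditing` runtime stages) and note that `gateway` must be called with parentheses. A minimal sketch of how these could be combined inside a jinja block — the schema/table name and the maintenance statement are purely illustrative, not part of this diff:

```sql
JINJA_STATEMENT_BEGIN;

{% if runtime_stage == 'auditing' and gateway() == 'prod' %}
  {# hypothetical post-audit maintenance; runs only in the prod gateway #}
  ANALYZE my_schema.my_table;
{% endif %}

JINJA_END;
```

As described above, `runtime_stage` is referenced as a bare jinja variable while `gateway` is a function call (`gateway()`), and both return string values.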
diff --git a/docs/faq/faq.md b/docs/faq/faq.md index ed072bb74..e67e7acc1 100644 --- a/docs/faq/faq.md +++ b/docs/faq/faq.md @@ -128,7 +128,11 @@ SQLMesh’s `plan` command is the primary tool for understanding the effects of changes you make to your project. If your project files have changed or are different from the state of an environment, you execute `sqlmesh plan [environment name]` to synchronize the environment's state with your project files. `sqlmesh plan` will generate a summary of the actions needed to implement the changes, automatically run unit tests, and prompt you to `apply` the plan and implement the changes. - If your project files have not changed, you execute `sqlmesh run` to run your project's models and audits. You can execute `sqlmesh run` yourself or with the native [Airflow integration](../integrations/airflow.md). If running it yourself, a sensible approach is to use Linux’s `cron` tool to execute `sqlmesh run` on a cadence at least as frequent as your briefest SQLMesh model `cron` parameter. For example, if your most frequent model’s `cron` is hour, your `cron` tool should execute `sqlmesh run` at least every hour. + If your project files have not changed, you execute `sqlmesh run` to run your project's models and audits. + + `sqlmesh run` does not use models, macros, or audits from your local project files. Everything it executes is based on the model, macro, and audit versions currently promoted in the target environment. Those versions are stored in the metadata SQLMesh captures about the state of your environment. + + A sensible approach to executing `sqlmesh run` is to use Linux’s `cron` tool to execute `sqlmesh run` on a cadence at least as frequent as your briefest SQLMesh model `cron` parameter. For example, if your most frequent model’s `cron` is hour, your `cron` tool should execute `sqlmesh run` at least every hour. ??? question "What are start date and end date for?" SQLMesh uses the ["intervals" approach](https://tobikodata.com/data_load_patterns_101.html) to determine the date ranges that should be included in an incremental by time model query. It divides time into disjoint intervals and tracks which intervals have ever been processed. diff --git a/docs/integrations/engines/clickhouse.md b/docs/integrations/engines/clickhouse.md index b5d5eb55c..2f6627c31 100644 --- a/docs/integrations/engines/clickhouse.md +++ b/docs/integrations/engines/clickhouse.md @@ -394,6 +394,26 @@ If a model has many records in each partition, you may see additional performanc ## Local/Built-in Scheduler **Engine Adapter Type**: `clickhouse` +## Airflow Scheduler +**Engine Name:** `clickhouse` + +In order to share a common implementation across local and Airflow, SQLMesh ClickHouse implements its own hook and operator. + +By default, the connection ID is set to `sqlmesh_clickhouse_default`, but can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: +```python linenums="1" +from sqlmesh.schedulers.airflow import NO_DEFAULT_CATALOG + +sqlmesh_airflow = SQLMeshAirflow( + "clickhouse", + default_catalog=NO_DEFAULT_CATALOG, + engine_operator_args={ + "sqlmesh_clickhouse_conn_id": "" + }, +) +``` + +Note: `NO_DEFAULT_CATALOG` is required for ClickHouse since ClickHouse doesn't support catalogs. 
+ ### Connection options | Option | Description | Type | Required | diff --git a/docs/integrations/engines/databricks.md b/docs/integrations/engines/databricks.md index c31a45261..32157bc60 100644 --- a/docs/integrations/engines/databricks.md +++ b/docs/integrations/engines/databricks.md @@ -14,9 +14,9 @@ SQLMesh connects to Databricks with the [Databricks SQL Connector](https://docs. The SQL Connector is bundled with SQLMesh and automatically installed when you include the `databricks` extra in the command `pip install "sqlmesh[databricks]"`. -The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models locally (the default SQLMesh approach). +The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models that do not return PySpark DataFrames. -The SQL Connector does not support Databricks Serverless Compute. If you require Serverless Compute then you must use the Databricks Connect library. +If you have Python models returning PySpark DataFrames, check out the [Databricks Connect](#databricks-connect-1) section. ### Databricks Connect @@ -229,7 +229,9 @@ If you want Databricks to process PySpark DataFrames in SQLMesh Python models, t SQLMesh **DOES NOT** include/bundle the Databricks Connect library. You must [install the version of Databricks Connect](https://docs.databricks.com/en/dev-tools/databricks-connect/python/install.html) that matches the Databricks Runtime used in your Databricks cluster. -SQLMesh's Databricks Connect implementation supports Databricks Runtime 13.0 or higher. If SQLMesh detects that you have Databricks Connect installed, then it will use it for all Python models (both Pandas and PySpark DataFrames). +If SQLMesh detects that you have Databricks Connect installed, then it will automatically configure the connection and use it for all Python models that return a Pandas or PySpark DataFrame. + +To have databricks-connect installed but ignored by SQLMesh, set `disable_databricks_connect` to `true` in the connection configuration. Databricks Connect can execute SQL and DataFrame operations on different clusters by setting the SQLMesh `databricks_connect_*` connection options. For example, these options could configure SQLMesh to run SQL on a [Databricks SQL Warehouse](https://docs.databricks.com/sql/admin/create-sql-warehouse.html) while still routing DataFrame operations to a normal Databricks Cluster. @@ -259,7 +261,7 @@ The only relevant SQLMesh configuration parameter is the optional `catalog` para | `databricks_connect_server_hostname` | Databricks Connect Only: Databricks Connect server hostname. Uses `server_hostname` if not set. | string | N | | `databricks_connect_access_token` | Databricks Connect Only: Databricks Connect access token. Uses `access_token` if not set. | string | N | | `databricks_connect_cluster_id` | Databricks Connect Only: Databricks Connect cluster ID. Uses `http_path` if not set. Cannot be a Databricks SQL Warehouse. | string | N | -| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect. If using serverless then SQL connector is disabled since Serverless is not supported for SQL Connector | bool | N | +| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect instead of `databricks_connect_cluster_id`. 
| bool | N | | `force_databricks_connect` | When running locally, force the use of Databricks Connect for all model operations (so don't use SQL Connector for SQL models) | bool | N | | `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N | | `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N | diff --git a/docs/quickstart/cli.md b/docs/quickstart/cli.md index 0a4e220c2..403246a59 100644 --- a/docs/quickstart/cli.md +++ b/docs/quickstart/cli.md @@ -324,6 +324,7 @@ $ sqlmesh plan dev ====================================================================== Successfully Ran 1 tests against duckdb ---------------------------------------------------------------------- + New environment `dev` will be created from `prod` Differences from the `prod` environment: @@ -349,39 +350,36 @@ Models: Directly Modified: sqlmesh_example__dev.incremental_model (Non-breaking) └── Indirectly Modified Children: └── sqlmesh_example__dev.full_model (Indirect Non-breaking) -Models needing backfill (missing dates): -└── sqlmesh_example__dev.incremental_model: 2020-01-01 - 2023-05-31 -Enter the backfill start date (eg. '1 year', '2020-01-01') or blank to backfill from the beginning of history: +Models needing backfill: +└── sqlmesh_example__dev.incremental_model: [2020-01-01 - 2023-05-31] +Apply - Backfill Tables [y/n]: y ``` -Line 5 of the output states that a new environment `dev` will be created from the existing `prod` environment. +Line 6 of the output states that a new environment `dev` will be created from the existing `prod` environment. -Lines 7-13 summarize the differences between the modified model and the `prod` environment, detecting that we directly modified `incremental_model` and that `full_model` was indirectly modified because it selects from the incremental model. Note that the model schemas are `sqlmesh_example__dev`, indicating that they are being created in the `dev` environment. +Lines 8-14 summarize the differences between the modified model and the `prod` environment, detecting that we directly modified `incremental_model` and that `full_model` was indirectly modified because it selects from the incremental model. Note that the model schemas are `sqlmesh_example__dev`, indicating that they are being created in the `dev` environment. -On line 27, we see that SQLMesh automatically classified the change as `Non-breaking` because it understood that the change was additive (added a column not used by `full_model`) and did not invalidate any data already in `prod`. +On line 28, we see that SQLMesh automatically classified the change as `Non-breaking` because it understood that the change was additive (added a column not used by `full_model`) and did not invalidate any data already in `prod`. -Hit `Enter` at the prompt to backfill data from our start date `2020-01-01`. Another prompt will appear asking for a backfill end date; hit `Enter` to backfill until now. Finally, enter `y` and press `Enter` to apply the plan and execute the backfill: +Enter `y` at the prompt and press `Enter` to apply the plan and execute the backfill: ```bash linenums="1" -Enter the backfill start date (eg. '1 year', '2020-01-01') or blank to backfill from the beginning of history: -Enter the backfill end date (eg. 
'1 month ago', '2020-01-01') or blank to backfill up until now: Apply - Backfill Tables [y/n]: y Creating physical tables ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 2/2 • 0:00:00 -All model versions have been created successfully +Model versions created successfully [1/1] sqlmesh_example__dev.incremental_model evaluated in 0.01s Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00 - -All model batches have been executed successfully +Model batches executed successfully Virtually Updating 'dev' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:00 -The target environment has been updated successfully +Target environment updated successfully ``` -Line 8 of the output shows that SQLMesh applied the change and evaluated `sqlmesh_example__dev.incremental_model`. +Line 6 of the output shows that SQLMesh applied the change and evaluated `sqlmesh_example__dev.incremental_model`. SQLMesh did not need to backfill anything for the `full_model` since the change was `Non-breaking`. @@ -460,7 +458,7 @@ Directly Modified: sqlmesh_example.incremental_model (Non-breaking) Apply - Virtual Update [y/n]: y Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:00 -The target environment has been updated successfully +Target environment updated successfully Virtual Update executed successfully ``` diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 1c7997f25..a38d92bcf 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -61,15 +61,15 @@ Global variable values may be any of the data types in the table below or lists Configuration for the `sqlmesh plan` command. -| Option | Description | Type | Required | -| ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------: | :------: | -| `auto_categorize_changes` | Indicates whether SQLMesh should attempt to automatically [categorize](../concepts/plans.md#change-categories) model changes during plan creation per each model source type ([additional details](../guides/configuration.md#auto-categorize-changes)) | dict[string, string] | N | -| `include_unmodified` | Indicates whether to create views for all models in the target development environment or only for modified ones (Default: False) | boolean | N | -| `auto_apply` | Indicates whether to automatically apply a new plan after creation (Default: False) | boolean | N | -| `forward_only` | Indicates whether the plan should be [forward-only](../concepts/plans.md#forward-only-plans) (Default: False) | boolean | N | -| `enable_preview` | Indicates whether to enable [data preview](../concepts/plans.md#data-preview) for forward-only models when targeting a development environment (Default: False) | boolean | N | -| `no_diff` | Don't show diffs for changed models (Default: False) | boolean | N | -| `no_prompts` | Disables interactive prompts in CLI (Default: True) | boolean | N | +| Option | Description | Type | Required | +|---------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------:|:--------:| +| 
`auto_categorize_changes` | Indicates whether SQLMesh should attempt to automatically [categorize](../concepts/plans.md#change-categories) model changes during plan creation per each model source type ([additional details](../guides/configuration.md#auto-categorize-changes)) | dict[string, string] | N | +| `include_unmodified` | Indicates whether to create views for all models in the target development environment or only for modified ones (Default: False) | boolean | N | +| `auto_apply` | Indicates whether to automatically apply a new plan after creation (Default: False) | boolean | N | +| `forward_only` | Indicates whether the plan should be [forward-only](../concepts/plans.md#forward-only-plans) (Default: False) | boolean | N | +| `enable_preview` | Indicates whether to enable [data preview](../concepts/plans.md#data-preview) for forward-only models when targeting a development environment (Default: True, except for dbt projects where the target engine does not support cloning) | Boolean | N | +| `no_diff` | Don't show diffs for changed models (Default: False) | boolean | N | +| `no_prompts` | Disables interactive prompts in CLI (Default: True) | boolean | N | ## Run diff --git a/examples/multi/repo_1/config.yaml b/examples/multi/repo_1/config.yaml index f4e111275..6cce77d27 100644 --- a/examples/multi/repo_1/config.yaml +++ b/examples/multi/repo_1/config.yaml @@ -4,7 +4,7 @@ gateways: local: connection: type: duckdb - database: db.db + database: db.duckdb memory: connection: diff --git a/examples/multi/repo_2/config.yaml b/examples/multi/repo_2/config.yaml index 6bd2063a8..0a127b2e7 100644 --- a/examples/multi/repo_2/config.yaml +++ b/examples/multi/repo_2/config.yaml @@ -4,7 +4,7 @@ gateways: local: connection: type: duckdb - database: db.db + database: db.duckdb memory: connection: @@ -13,4 +13,4 @@ gateways: default_gateway: local model_defaults: - dialect: 'duckdb' \ No newline at end of file + dialect: 'duckdb' diff --git a/examples/multi/repo_2/models/e.sql b/examples/multi/repo_2/models/e.sql new file mode 100644 index 000000000..34d079332 --- /dev/null +++ b/examples/multi/repo_2/models/e.sql @@ -0,0 +1,7 @@ +MODEL ( + name silver.e +); + +SELECT + * EXCEPT(dup) +FROM bronze.a diff --git a/setup.py b/setup.py index 6f19583c7..2b07607d6 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ "rich[jupyter]", "ruamel.yaml", "setuptools; python_version>='3.12'", - "sqlglot[rs]~=26.3.9", + "sqlglot[rs]~=26.4.1", "tenacity", "time-machine", ], @@ -99,6 +99,7 @@ "trino", "types-croniter", "types-dateparser", + "types-PyMySQL", "types-python-dateutil", "types-pytz", "types-requests==2.28.8", @@ -132,7 +133,7 @@ "pymssql", ], "mysql": [ - "mysql-connector-python", + "pymysql", ], "mwaa": [ "boto3", diff --git a/sqlmesh/__init__.py b/sqlmesh/__init__.py index 20299ceec..4377c6d2d 100644 --- a/sqlmesh/__init__.py +++ b/sqlmesh/__init__.py @@ -136,12 +136,14 @@ def format(self, record: logging.LogRecord) -> str: def configure_logging( force_debug: bool = False, - ignore_warnings: bool = False, write_to_stdout: bool = False, write_to_file: bool = True, log_limit: int = c.DEFAULT_LOG_LIMIT, log_file_dir: t.Optional[t.Union[str, Path]] = None, ) -> None: + # Remove noisy grpc logs that are not useful for users + os.environ["GRPC_VERBOSITY"] = os.environ.get("GRPC_VERBOSITY", "NONE") + logger = logging.getLogger() debug = force_debug or debug_mode_enabled() @@ -149,12 +151,11 @@ def configure_logging( level = logging.DEBUG if debug else logging.INFO logger.setLevel(level) - stdout_handler = 
logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(CustomFormatter()) - stdout_handler.setLevel( - level if write_to_stdout else (logging.ERROR if ignore_warnings else logging.WARNING) - ) - logger.addHandler(stdout_handler) + if write_to_stdout: + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(CustomFormatter()) + stdout_handler.setLevel(level) + logger.addHandler(stdout_handler) log_file_dir = log_file_dir or c.DEFAULT_LOG_FILE_DIR log_path_prefix = Path(log_file_dir) / LOG_FILENAME_PREFIX diff --git a/sqlmesh/cli/example_project.py b/sqlmesh/cli/example_project.py index 9ad36bf9a..7080f4f49 100644 --- a/sqlmesh/cli/example_project.py +++ b/sqlmesh/cli/example_project.py @@ -48,12 +48,15 @@ def _gen_config( if isinstance(default_value, Enum): default_value = default_value.value elif not isinstance(default_value, PRIMITIVES): - default_value = None + default_value = "" required = field.is_required() or field_name == "type" - option_str = ( - f" {'# ' if not required else ''}{field_name}: {default_value or ''}\n" - ) + option_str = f" {'# ' if not required else ''}{field_name}: {default_value}\n" + + # specify the DuckDB database field so quickstart runs out of the box + if engine == "duckdb" and field_name == "database": + option_str = " database: db.db\n" + required = True if required: required_fields.append(option_str) @@ -74,22 +77,22 @@ def _gen_config( default_configs = { ProjectTemplate.DEFAULT: f"""gateways: - dev: + {dialect}: connection: {connection_settings} -default_gateway: dev +default_gateway: {dialect} model_defaults: dialect: {dialect} start: {start or yesterday_ds()} """, ProjectTemplate.AIRFLOW: f"""gateways: - dev: + {dialect}: connection: {connection_settings} -default_gateway: dev +default_gateway: {dialect} default_scheduler: type: airflow diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 01fae0900..4ea22b447 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -12,6 +12,7 @@ from sqlmesh.cli import options as opt from sqlmesh.cli.example_project import ProjectTemplate, init_example_project from sqlmesh.core.analytics import cli_analytics +from sqlmesh.core.console import configure_console, get_console from sqlmesh.core.config import load_configs from sqlmesh.core.context import Context from sqlmesh.utils.date import TimeLike, time_like_to_str @@ -91,9 +92,8 @@ def cli( configs = load_configs(config, Context.CONFIG_TYPE, paths) log_limit = list(configs.values())[0].log_limit - configure_logging( - debug, ignore_warnings, log_to_stdout, log_limit=log_limit, log_file_dir=log_file_dir - ) + configure_logging(debug, log_to_stdout, log_limit=log_limit, log_file_dir=log_file_dir) + configure_console(ignore_warnings=ignore_warnings) try: context = Context( @@ -433,7 +433,7 @@ def plan( select_models = kwargs.pop("select_model") or None allow_destructive_models = kwargs.pop("allow_destructive_model") or None backfill_models = kwargs.pop("backfill_model") or None - context.console.verbose = verbose + setattr(get_console(), "verbose", verbose) context.plan( environment, restate_models=restate_models, diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index bf22abe22..45a9a06a1 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -75,11 +75,6 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]: """kwargs that are for execution config only""" return {} - @property - def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]: - """Key-value 
arguments that will be passed during cursor construction.""" - return None - @property def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: """A function that is called to initialize the cursor""" @@ -115,7 +110,6 @@ def create_engine_adapter(self, register_comments_override: bool = False) -> Eng return self._engine_adapter( self._connection_factory_with_kwargs, multithreaded=self.concurrent_tasks > 1, - cursor_kwargs=self._cursor_kwargs, default_catalog=self.get_catalog(), cursor_init=self._cursor_init, register_comments=register_comments_override or self.register_comments, @@ -623,6 +617,12 @@ class DatabricksConnectionConfig(ConnectionConfig): @model_validator(mode="before") def _databricks_connect_validator(cls, data: t.Any) -> t.Any: + # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block. + # Disabling this allows SQLMesh to determine what should be shown to the user. + # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show + # the user since it is expected if the table doesn't exist. Without this change the user would see the error. + logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL) + if not isinstance(data, dict): return data @@ -641,10 +641,6 @@ def _databricks_connect_validator(cls, data: t.Any) -> t.Any: data.get("auth_type"), ) - if databricks_connect_use_serverless: - data["force_databricks_connect"] = True - data["disable_databricks_connect"] = False - if (not server_hostname or not http_path or not access_token) and ( not databricks_connect_use_serverless and not auth_type ): @@ -666,11 +662,12 @@ def _databricks_connect_validator(cls, data: t.Any) -> t.Any: data["databricks_connect_access_token"] = access_token if not data.get("databricks_connect_server_hostname"): data["databricks_connect_server_hostname"] = f"https://{server_hostname}" - if not databricks_connect_use_serverless: - if not data.get("databricks_connect_cluster_id"): - if t.TYPE_CHECKING: - assert http_path is not None - data["databricks_connect_cluster_id"] = http_path.split("/")[-1] + if not databricks_connect_use_serverless and not data.get( + "databricks_connect_cluster_id" + ): + if t.TYPE_CHECKING: + assert http_path is not None + data["databricks_connect_cluster_id"] = http_path.split("/")[-1] if auth_type: from databricks.sql.auth.auth import AuthType @@ -1208,7 +1205,9 @@ class MySQLConnectionConfig(ConnectionConfig): user: str password: str port: t.Optional[int] = None + database: t.Optional[str] = None charset: t.Optional[str] = None + collation: t.Optional[str] = None ssl_disabled: t.Optional[bool] = None concurrent_tasks: int = 4 @@ -1217,24 +1216,21 @@ class MySQLConnectionConfig(ConnectionConfig): type_: t.Literal["mysql"] = Field(alias="type", default="mysql") - @property - def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]: - """Key-value arguments that will be passed during cursor construction.""" - return {"buffered": True} - @property def _connection_kwargs_keys(self) -> t.Set[str]: connection_keys = { "host", "user", "password", - "port", - "database", } if self.port is not None: connection_keys.add("port") + if self.database is not None: + connection_keys.add("database") if self.charset is not None: connection_keys.add("charset") + if self.collation is not None: + connection_keys.add("collation") if self.ssl_disabled is not None: connection_keys.add("ssl_disabled") return connection_keys @@ -1245,7 +1241,7 @@ def _engine_adapter(self) -> 
t.Type[EngineAdapter]: @property def _connection_factory(self) -> t.Callable: - from mysql.connector import connect + from pymysql import connect return connect @@ -1776,7 +1772,7 @@ def _connection_config_validator( return parse_connection_config(v) -connection_config_validator = field_validator( +connection_config_validator: t.Callable = field_validator( "connection", "state_connection", "test_connection", diff --git a/sqlmesh/core/config/plan.py b/sqlmesh/core/config/plan.py index 0982d96d4..cac0b3fd7 100644 --- a/sqlmesh/core/config/plan.py +++ b/sqlmesh/core/config/plan.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlmesh.core.config.base import BaseConfig from sqlmesh.core.config.categorizer import CategorizerConfig @@ -23,7 +25,7 @@ class PlanConfig(BaseConfig): forward_only: bool = False auto_categorize_changes: CategorizerConfig = CategorizerConfig() include_unmodified: bool = False - enable_preview: bool = False + enable_preview: t.Optional[bool] = None no_diff: bool = False no_prompts: bool = True auto_apply: bool = False diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index da785c34d..7b0881df6 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -4,7 +4,6 @@ import re import typing as t import zlib -import logging from pydantic import Field from sqlglot import exp @@ -13,6 +12,7 @@ from sqlmesh.cicd.config import CICDBotConfig from sqlmesh.core import constants as c +from sqlmesh.core.console import get_console from sqlmesh.core.config import EnvironmentSuffixTarget from sqlmesh.core.config.base import BaseConfig, UpdateStrategy from sqlmesh.core.config.common import variables_validator, compile_regex_mapping @@ -45,8 +45,6 @@ if t.TYPE_CHECKING: from sqlmesh.core._typing import Self -logger = logging.getLogger(__name__) - class Config(BaseConfig): """An object used by a Context to configure your SQLMesh project. @@ -175,13 +173,13 @@ def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any: ) if "physical_schema_override" in data: - logger.warning( - "`physical_schema_override` is deprecated. Please use `physical_schema_mapping` instead" + get_console().log_warning( + "`physical_schema_override` is deprecated. Please use `physical_schema_mapping` instead." ) if "physical_schema_mapping" in data: raise ConfigError( - "Only one of `physical_schema_override` and `physical_schema_mapping` can be specified" + "Only one of `physical_schema_override` and `physical_schema_mapping` can be specified." 
) physical_schema_override: t.Dict[str, str] = data.pop("physical_schema_override") diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py index 7ad4f75f2..1917a466d 100644 --- a/sqlmesh/core/config/scheduler.py +++ b/sqlmesh/core/config/scheduler.py @@ -1,7 +1,6 @@ from __future__ import annotations import abc -import logging import typing as t from pydantic import Field @@ -10,7 +9,7 @@ from sqlglot.helper import subclasses from sqlmesh.core.config.base import BaseConfig from sqlmesh.core.config.common import concurrent_tasks_validator -from sqlmesh.core.console import Console +from sqlmesh.core.console import Console, get_console from sqlmesh.core.plan import ( AirflowPlanEvaluator, BuiltInPlanEvaluator, @@ -32,8 +31,6 @@ from sqlmesh.utils.config import sensitive_fields, excluded_fields -logger = logging.getLogger(__name__) - class SchedulerConfig(abc.ABC): """Abstract base class for Scheduler configurations.""" @@ -88,10 +85,10 @@ def create_state_sync(self, context: GenericContext) -> StateSync: ): # If we are using DuckDB, ensure that multithreaded mode gets enabled if necessary if warehouse_connection.concurrent_tasks > 1: - logger.warning( + get_console().log_warning( "The duckdb state connection is configured for single threaded mode but the warehouse connection is configured for " + f"multi threaded mode with {warehouse_connection.concurrent_tasks} concurrent tasks." - + " This can cause SQLMesh to hang. Overriding the duckdb state connection config to use multi threaded mode" + + " This can cause SQLMesh to hang. Overriding the duckdb state connection config to use multi threaded mode." ) # this triggers multithreaded mode and has to happen before the engine adapter is created below state_connection.concurrent_tasks = warehouse_connection.concurrent_tasks @@ -109,7 +106,7 @@ def create_state_sync(self, context: GenericContext) -> StateSync: warehouse_connection, DuckDBConnectionConfig ): if not state_connection.is_recommended_for_state_sync: - logger.warning( + get_console().log_warning( f"The {state_connection.type_} engine is not recommended for storing SQLMesh state in production deployments. Please see" + " https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#state-connection for a list of recommended engines and more information." ) @@ -205,7 +202,7 @@ def get_default_catalog(self, context: GenericContext) -> t.Optional[str]: def _max_snapshot_ids_per_request_validator(v: t.Any) -> t.Optional[int]: - logger.warning( + get_console().log_warning( "The `max_snapshot_ids_per_request` field is deprecated and will be removed in a future release." 
) return None diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index a71848cd4..64ed61626 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -5,6 +5,7 @@ import typing as t import unittest import uuid +import logging from hyperscript import h from rich.console import Console as RichConsole @@ -47,6 +48,9 @@ LayoutWidget = t.TypeVar("LayoutWidget", bound=t.Union[widgets.VBox, widgets.HBox]) +logger = logging.getLogger(__name__) + + SNAPSHOT_CHANGE_CATEGORY_STR = { None: "Unknown", SnapshotChangeCategory.BREAKING: "Breaking", @@ -245,6 +249,10 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: def log_error(self, message: str) -> None: """Display error info to the user.""" + @abc.abstractmethod + def log_warning(self, message: str) -> None: + """Display warning info to the user.""" + @abc.abstractmethod def log_success(self, message: str) -> None: """Display a general successful message to the user.""" @@ -279,6 +287,152 @@ def _limit_model_names(self, tree: Tree, verbose: bool = False) -> Tree: return tree +class NoopConsole(Console): + def start_plan_evaluation(self, plan: EvaluatablePlan) -> None: + pass + + def stop_plan_evaluation(self) -> None: + pass + + def start_evaluation_progress( + self, + batches: t.Dict[Snapshot, int], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + + def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None: + pass + + def update_snapshot_evaluation_progress( + self, snapshot: Snapshot, batch_idx: int, duration_ms: t.Optional[int] + ) -> None: + pass + + def stop_evaluation_progress(self, success: bool = True) -> None: + pass + + def start_creation_progress( + self, + total_tasks: int, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + + def update_creation_progress(self, snapshot: SnapshotInfoLike) -> None: + pass + + def stop_creation_progress(self, success: bool = True) -> None: + pass + + def start_cleanup(self, ignore_ttl: bool) -> bool: + return True + + def update_cleanup_progress(self, object_name: str) -> None: + pass + + def stop_cleanup(self, success: bool = True) -> None: + pass + + def start_promotion_progress( + self, + total_tasks: int, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + + def update_promotion_progress(self, snapshot: SnapshotInfoLike, promoted: bool) -> None: + pass + + def stop_promotion_progress(self, success: bool = True) -> None: + pass + + def start_snapshot_migration_progress(self, total_tasks: int) -> None: + pass + + def update_snapshot_migration_progress(self, num_tasks: int) -> None: + pass + + def log_migration_status(self, success: bool = True) -> None: + pass + + def stop_snapshot_migration_progress(self, success: bool = True) -> None: + pass + + def start_env_migration_progress(self, total_tasks: int) -> None: + pass + + def update_env_migration_progress(self, num_tasks: int) -> None: + pass + + def stop_env_migration_progress(self, success: bool = True) -> None: + pass + + def show_model_difference_summary( + self, + context_diff: ContextDiff, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + no_diff: bool = True, + ignored_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, + ) -> None: + pass + + def plan( + self, + plan_builder: PlanBuilder, + auto_apply: bool, + default_catalog: t.Optional[str], + no_diff: bool = False, + 
no_prompts: bool = False, + ) -> None: + if auto_apply: + plan_builder.apply() + + def log_test_results( + self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str + ) -> None: + pass + + def show_sql(self, sql: str) -> None: + pass + + def log_status_update(self, message: str) -> None: + pass + + def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: + pass + + def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: + pass + + def log_error(self, message: str) -> None: + pass + + def log_warning(self, message: str) -> None: + logger.warning(message) + + def log_success(self, message: str) -> None: + pass + + def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID: + return uuid.uuid4() + + def loading_stop(self, id: uuid.UUID) -> None: + pass + + def show_schema_diff(self, schema_diff: SchemaDiff) -> None: + pass + + def show_row_diff( + self, row_diff: RowDiff, show_sample: bool = True, skip_grain_check: bool = False + ) -> None: + pass + + def make_progress_bar(message: str, console: t.Optional[RichConsole] = None) -> Progress: return Progress( TextColumn(f"[bold blue]{message}", justify="right"), @@ -302,6 +456,7 @@ def __init__( console: t.Optional[RichConsole] = None, verbose: bool = False, dialect: DialectType = None, + ignore_warnings: bool = False, **kwargs: t.Any, ) -> None: self.console: RichConsole = console or srich.console @@ -333,6 +488,7 @@ def __init__( self.verbose = verbose self.dialect = dialect + self.ignore_warnings = ignore_warnings def _print(self, value: t.Any, **kwargs: t.Any) -> None: self.console.print(value, **kwargs) @@ -899,7 +1055,7 @@ def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> N missing_intervals = plan.missing_intervals if not missing_intervals: return - backfill = Tree("[bold]Models needing backfill \\[missing dates]:") + backfill = Tree("[bold]Models needing backfill:[/bold]") for missing in missing_intervals: snapshot = plan.context_diff.snapshots[missing.snapshot_id] if not snapshot.is_model: @@ -1030,6 +1186,11 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: def log_error(self, message: str) -> None: self._print(f"[red]{message}[/red]") + def log_warning(self, message: str) -> None: + logger.warning(message) + if not self.ignore_warnings: + self._print(f"[yellow]{message}[/yellow]") + def log_success(self, message: str) -> None: self._print(f"\n[green]{message}[/green]\n") @@ -1136,12 +1297,12 @@ def show_row_diff( columns: dict[str, list[str]] = {} source_prefix, source_name = ( (f"{source_name}__", source_name) - if source_name != row_diff.source + if source_name.lower() != row_diff.source.lower() else ("s__", "SOURCE") ) target_prefix, target_name = ( (f"{target_name}__", target_name) - if target_name != row_diff.target + if target_name.lower() != row_diff.target.lower() else ("t__", "TARGET") ) @@ -1628,6 +1789,9 @@ class MarkdownConsole(CaptureTerminalConsole): where you want to display a plan or test results in markdown. 
""" + def __init__(self, **kwargs: t.Any) -> None: + super().__init__(**{**kwargs, "console": RichConsole(no_color=True)}) + def show_model_difference_summary( self, context_diff: ContextDiff, @@ -1759,7 +1923,7 @@ def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> N missing_intervals = plan.missing_intervals if not missing_intervals: return - self._print("\n**Models needing backfill \\[missing dates]:**") + self._print("\n**Models needing backfill:**") snapshots = [] for missing in missing_intervals: snapshot = plan.context_diff.snapshots[missing.snapshot_id] @@ -1854,7 +2018,10 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: self._print("```\n") def log_error(self, message: str) -> None: - super().log_error(f"```\n{message}```\n\n") + super().log_error(f"```\n\\[ERROR] {message}```\n\n") + + def log_warning(self, message: str) -> None: + super().log_warning(f"```\n\\[WARNING] {message}```\n\n") class DatabricksMagicConsole(CaptureTerminalConsole): @@ -2027,11 +2194,13 @@ def __init__( console: t.Optional[RichConsole], *args: t.Any, dialect: DialectType = None, + ignore_warnings: bool = False, **kwargs: t.Any, ) -> None: self.console: RichConsole = console or srich.console self.dialect = dialect self.verbose = False + self.ignore_warnings = ignore_warnings def _write(self, msg: t.Any, *args: t.Any, **kwargs: t.Any) -> None: self.console.log(msg, *args, **kwargs) @@ -2146,6 +2315,11 @@ def log_status_update(self, message: str) -> None: def log_error(self, message: str) -> None: self._write(message, style="bold red") + def log_warning(self, message: str) -> None: + logger.warning(message) + if not self.ignore_warnings: + self._write(message, style="bold yellow") + def log_success(self, message: str) -> None: self._write(message, style="bold green") @@ -2165,9 +2339,31 @@ def show_row_diff( self._write(row_diff) -def get_console(**kwargs: t.Any) -> TerminalConsole | DatabricksMagicConsole | NotebookMagicConsole: +_CONSOLE: Console = NoopConsole() + + +def set_console(console: Console) -> None: + """Sets the console instance.""" + global _CONSOLE + _CONSOLE = console + + +def configure_console(**kwargs: t.Any) -> None: + """Configures the console instance.""" + global _CONSOLE + _CONSOLE = create_console(**kwargs) + + +def get_console() -> Console: + """Returns the console instance or creates a new one if it hasn't been created yet.""" + return _CONSOLE + + +def create_console( + **kwargs: t.Any, +) -> TerminalConsole | DatabricksMagicConsole | NotebookMagicConsole: """ - Returns the console that is appropriate for the current runtime environment. + Creates a new console instance that is appropriate for the current runtime environment. Note: Google Colab environment is untested and currently assumes is compatible with the base NotebookMagicConsole. 
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 63d6cf460..9869096b4 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -62,7 +62,7 @@ load_configs, ) from sqlmesh.core.config.loader import C -from sqlmesh.core.console import Console, get_console +from sqlmesh.core.console import get_console from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.dialect import ( format_model_expressions, @@ -180,7 +180,7 @@ def default_catalog(self) -> t.Optional[str]: raise NotImplementedError def table(self, model_name: str) -> str: - logger.warning( + get_console().log_warning( "The SQLMesh context's `table` method is deprecated and will be removed " "in a future release. Please use the `resolve_table` method instead." ) @@ -329,12 +329,12 @@ def __init__( concurrent_tasks: t.Optional[int] = None, loader: t.Optional[t.Type[Loader]] = None, load: bool = True, - console: t.Optional[Console] = None, users: t.Optional[t.List[User]] = None, ): self.configs = ( config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths) ) + self._projects = {config.project for config in self.configs.values()} self.dag: DAG[str] = DAG() self._models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") self._audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits") @@ -384,7 +384,9 @@ def __init__( self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None - self.console = console or get_console(dialect=self.engine_adapter.dialect) + self.console = get_console() + setattr(self.console, "dialect", self.engine_adapter.dialect) + self._test_connection_config = self.config.get_test_connection( self.gateway, self.default_catalog, default_catalog_dialect=self.engine_adapter.DIALECT ) @@ -552,7 +554,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: """Load all files in the context's path.""" load_start_ts = time.perf_counter() - projects = [loader.load() for loader in self._loaders] + loaded_projects = [loader.load() for loader in self._loaders] self.dag = DAG() self._standalone_audits.clear() @@ -563,7 +565,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self._requirements.clear() self._excluded_requirements.clear() - for project in projects: + for project in loaded_projects: self._jinja_macros = self._jinja_macros.merge(project.jinja_macros) self._macros.update(project.macros) self._models.update(project.models) @@ -573,15 +575,39 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self._requirements.update(project.requirements) self._excluded_requirements.update(project.excluded_requirements) + uncached = set() + + if any(self._projects): + prod = self.state_reader.get_environment(c.PROD) + + if prod: + for snapshot in self.state_reader.get_snapshots(prod.snapshots).values(): + if snapshot.node.project in self._projects: + uncached.add(snapshot.name) + else: + store = self._standalone_audits if snapshot.is_audit else self._models + store[snapshot.name] = snapshot.node # type: ignore + for model in self._models.values(): self.dag.add(model.fqn, model.depends_on) - # This topologically sorts the DAG & caches the result in-memory for later; - # we do it here to detect any cycles as early as possible and fail if needed - self.dag.sorted - if update_schemas: + for fqn in self.dag: + model = self._models.get(fqn) # type: ignore + + if not model or fqn in uncached: + continue + + # make a copy of remote models that depend on local models or in the downstream chain + # without this, a SELECT 
* FROM local will not propogate properly because the downstream + # model will get mutated (schema changes) but the object is the same as the remote cache + if any(dep in uncached for dep in model.depends_on): + uncached.add(fqn) + self._models.update({fqn: model.copy(update={"mapping_schema": {}})}) + continue + update_model_schemas(self.dag, models=self._models, context_path=self.path) + for model in self.models.values(): # The model definition can be validated correctly only after the schema is set. model.validate_definition() @@ -596,13 +622,8 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: self.default_dialect or "" } - project_types = { - c.DBT if loader.config.loader.__name__.lower().startswith(c.DBT) else c.NATIVE - for loader in self._loaders - } - analytics.collector.on_project_loaded( - project_type=c.HYBRID if len(project_types) > 1 else first(project_types), + project_type=self._project_type, models_count=len(self._models), audits_count=len(self._audits), standalone_audits_count=len(self._standalone_audits), @@ -1390,7 +1411,7 @@ def plan_builder( default_start=default_start, default_end=default_end, enable_preview=( - enable_preview if enable_preview is not None else self.config.plan.enable_preview + enable_preview if enable_preview is not None else self._plan_preview_enabled ), end_bounded=not run, ensure_finalized_snapshots=self.config.plan.use_finalized_state, @@ -1647,7 +1668,7 @@ def render_dag(self, path: str, select_models: t.Optional[t.Collection[str]] = N suffix = file_path.suffix if suffix != ".html": if suffix: - logger.warning( + get_console().log_warning( f"The extension {suffix} does not designate an html file. A file with a `.html` extension will be created instead." ) path = str(file_path.with_suffix(".html")) @@ -2109,63 +2130,39 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: def _snapshots( self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None ) -> t.Dict[str, Snapshot]: - projects = {config.project for config in self.configs.values()} - - if any(projects): - prod = self.state_reader.get_environment(c.PROD) - remote_snapshots = ( - { - snapshot.name: snapshot - for snapshot in self.state_reader.get_snapshots(prod.snapshots).values() - } - if prod - else {} - ) - else: - remote_snapshots = {} - - local_nodes = {**(models_override or self._models), **self._standalone_audits} - nodes = local_nodes.copy() - - for name, snapshot in remote_snapshots.items(): - if name not in nodes and snapshot.node.project not in projects: - nodes[name] = snapshot.node - def _nodes_to_snapshots(nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]: snapshots: t.Dict[str, Snapshot] = {} fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {} for node in nodes.values(): - if node.fqn not in local_nodes and node.fqn in remote_snapshots: - ttl = remote_snapshots[node.fqn].ttl - else: - config = self.config_for_node(node) - ttl = config.snapshot_ttl + kwargs = {} + if node.project in self._projects: + kwargs["ttl"] = self.config_for_node(node).snapshot_ttl snapshot = Snapshot.from_node( node, nodes=nodes, cache=fingerprint_cache, - ttl=ttl, - config=self.config_for_node(node), + **kwargs, ) snapshots[snapshot.name] = snapshot return snapshots + nodes = {**(models_override or self._models), **self._standalone_audits} snapshots = _nodes_to_snapshots(nodes) stored_snapshots = self.state_reader.get_snapshots(snapshots.values()) unrestorable_snapshots = { snapshot for snapshot in stored_snapshots.values() - if 
snapshot.name in local_nodes and snapshot.unrestorable + if snapshot.name in nodes and snapshot.unrestorable } if unrestorable_snapshots: for snapshot in unrestorable_snapshots: logger.info( "Found a unrestorable snapshot %s. Restamping the model...", snapshot.name ) - node = local_nodes[snapshot.name] + node = nodes[snapshot.name] nodes[snapshot.name] = node.copy( update={"stamp": f"revert to {snapshot.identifier}"} ) @@ -2279,6 +2276,23 @@ def _select_models_for_run( result = set(dag.subdag(*result)) return result + @cached_property + def _project_type(self) -> str: + project_types = { + c.DBT if loader.__class__.__name__.lower().startswith(c.DBT) else c.NATIVE + for loader in self._loaders + } + return c.HYBRID if len(project_types) > 1 else first(project_types) + + @property + def _plan_preview_enabled(self) -> bool: + if self.config.plan.enable_preview is not None: + return self.config.plan.enable_preview + # It is dangerous to enable preview by default for dbt projects that rely on engines that don’t support cloning. + # Enabling previews in such cases can result in unintended full refreshes because dbt incremental models rely on + # the maximum timestamp value in the target table. + return self._project_type == c.NATIVE or self.engine_adapter.SUPPORTS_CLONING + class Context(GenericContext[Config]): CONFIG_TYPE = Config diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index 99cee67c8..bf054600f 100644 --- a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -12,17 +12,16 @@ from __future__ import annotations -import logging import sys import typing as t from difflib import ndiff from functools import cached_property from sqlmesh.core import constants as c +from sqlmesh.core.console import get_console from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pydantic import PydanticModel - if sys.version_info >= (3, 12): from importlib import metadata else: @@ -34,8 +33,6 @@ IGNORED_PACKAGES = {"sqlmesh", "sqlglot"} -logger = logging.getLogger(__name__) - class ContextDiff(PydanticModel): """ContextDiff is an object representing the difference between two environments. @@ -116,7 +113,7 @@ def create( env = state_reader.get_environment(create_from.lower()) if not env and create_from != c.PROD: - logger.warning( + get_console().log_warning( f"The environment name '{create_from}' was passed to the `plan` command's `--create-from` argument, but '{create_from}' does not exist. Initializing new environment '{environment}' from scratch." 
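A recurring theme in this diff is replacing module-level `logger.warning(...)` calls with `get_console().log_warning(...)`, so user-facing messages go through whatever console the entry point installed. The snippet below is a minimal standalone sketch of that pattern, not SQLMesh's actual `Console` API; the class and function names are illustrative only.

```python
# Illustrative sketch of a swappable console singleton; not the real SQLMesh Console.
from __future__ import annotations

import sys
import typing as t


class PrintConsole:
    """Default console: writes warnings and errors straight to stderr."""

    def log_warning(self, message: str) -> None:
        print(f"WARNING: {message}", file=sys.stderr)

    def log_error(self, message: str) -> None:
        print(f"ERROR: {message}", file=sys.stderr)


class CapturingConsole(PrintConsole):
    """Example alternative console that buffers warnings (e.g. to render later)."""

    def __init__(self) -> None:
        self.captured_warnings: t.List[str] = []

    def log_warning(self, message: str) -> None:
        self.captured_warnings.append(message)


_CONSOLE: PrintConsole = PrintConsole()


def get_console() -> PrintConsole:
    return _CONSOLE


def set_console(console: PrintConsole) -> None:
    global _CONSOLE
    _CONSOLE = console


def deprecated_table_lookup(name: str) -> str:
    # Library code no longer needs its own logger or an injected console instance.
    get_console().log_warning(f"`table` is deprecated, use `resolve_table` for '{name}' instead.")
    return name


if __name__ == "__main__":
    console = CapturingConsole()
    set_console(console)                      # the entry point picks the console once
    deprecated_table_lookup("db.schema.tbl")  # deep library code just calls get_console()
    print(console.captured_warnings)
```

The practical effect is that the same warning surfaces correctly whether the caller is the CLI, a notebook, or the GitHub bot, which is also what the reworked CLI test near the end of this diff relies on.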
) @@ -396,7 +393,7 @@ def text_diff(self, name: str) -> str: try: return old.node.text_diff(new.node, rendered=self.diff_rendered) except SQLMeshError as e: - logger.warning("Failed to diff model '%s': %s", name, str(e)) + get_console().log_warning(f"Failed to diff model '{name}': {str(e)}.") return "" @@ -426,5 +423,7 @@ def _build_requirements( ): requirements[dist] = metadata.version(dist) except metadata.PackageNotFoundError: - logger.warning("Failed to find package for %s", lib) + from sqlmesh.core.console import get_console + + get_console().log_warning(f"Failed to find package for {lib}.") return requirements diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 5c80a2a2e..ef7c04fa4 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -329,9 +329,11 @@ def _parse_join( def _warn_unsupported(self: Parser) -> None: + from sqlmesh.core.console import get_console + sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context] - logger.warning( + get_console().log_warning( f"'{sql}' could not be semantically understood as it contains unsupported syntax, SQLMesh will treat the command as is. Note that any references to the model's " "underlying physical table can't be resolved in this case, consider using Jinja as explained here https://sqlmesh.readthedocs.io/en/stable/concepts/macros/macro_variables/#audit-only-variables" ) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 3f4b1b743..ad48eb80f 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -43,7 +43,11 @@ from sqlmesh.utils import columns_to_types_all_known, random_id from sqlmesh.utils.connection_pool import create_connection_pool from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column -from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError +from sqlmesh.utils.errors import ( + SQLMeshError, + UnsupportedCatalogOperationError, + MissingDefaultCatalogError, +) from sqlmesh.utils.pandas import columns_to_types_from_df if t.TYPE_CHECKING: @@ -109,7 +113,6 @@ def __init__( dialect: str = "", sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None, multithreaded: bool = False, - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, default_catalog: t.Optional[str] = None, execute_log_level: int = logging.DEBUG, @@ -120,7 +123,7 @@ def __init__( ): self.dialect = dialect.lower() or self.DIALECT self._connection_pool = create_connection_pool( - connection_factory, multithreaded, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init + connection_factory, multithreaded, cursor_init=cursor_init ) self._sql_gen_kwargs = sql_gen_kwargs or {} self._default_catalog = default_catalog @@ -186,7 +189,9 @@ def default_catalog(self) -> t.Optional[str]: return None default_catalog = self._default_catalog or self.get_current_catalog() if not default_catalog: - raise SQLMeshError("Could not determine a default catalog despite it being supported.") + raise MissingDefaultCatalogError( + "Could not determine a default catalog despite it being supported." 
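The base adapter now raises the dedicated `MissingDefaultCatalogError` instead of a generic `SQLMeshError`, which lets an engine-specific adapter catch it and re-raise with better guidance (the Databricks hunk further down does exactly that). Below is a rough standalone sketch of the catch-and-augment pattern; the class names and messages are simplified stand-ins.

```python
from __future__ import annotations


class SQLMeshLikeError(Exception):
    pass


class MissingDefaultCatalogError(SQLMeshLikeError):
    """Raised when an engine supports catalogs but none could be determined."""


class BaseAdapter:
    def __init__(self, default_catalog: str | None = None) -> None:
        self._default_catalog = default_catalog

    def get_current_catalog(self) -> str | None:
        return None  # pretend the engine could not tell us

    @property
    def default_catalog(self) -> str | None:
        catalog = self._default_catalog or self.get_current_catalog()
        if not catalog:
            raise MissingDefaultCatalogError(
                "Could not determine a default catalog despite it being supported."
            )
        return catalog


class DatabricksLikeAdapter(BaseAdapter):
    @property
    def default_catalog(self) -> str | None:
        try:
            return super().default_catalog
        except MissingDefaultCatalogError as e:
            # Re-raise with a hint that is specific to this engine's configuration.
            raise MissingDefaultCatalogError(
                "Could not determine default catalog. Set the `catalog` connection property."
            ) from e


try:
    DatabricksLikeAdapter().default_catalog
except MissingDefaultCatalogError as e:
    print(e)
```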
+ ) return default_catalog @property @@ -2362,7 +2367,7 @@ def _create_table_comment( self.execute(self._build_create_comment_table_exp(table, table_comment, table_kind)) except Exception: logger.warning( - f"Table comment for '{table.alias_or_name}' not registered - this may be due to limited permissions.", + f"Table comment for '{table.alias_or_name}' not registered - this may be due to limited permissions", exc_info=True, ) @@ -2389,7 +2394,7 @@ def _create_column_comments( self.execute(self._build_create_comment_column_exp(table, col, comment, table_kind)) except Exception: logger.warning( - f"Column comments for column '{col}' in table '{table.alias_or_name}' not registered - this may be due to limited permissions.", + f"Column comments for column '{col}' in table '{table.alias_or_name}' not registered - this may be due to limited permissions", exc_info=True, ) diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 184e1e319..bb38de631 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -import os import typing as t +from functools import partial import pandas as pd from sqlglot import exp @@ -17,10 +17,11 @@ from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection +from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError if t.TYPE_CHECKING: - from sqlmesh.core._typing import SchemaName, TableName + from sqlmesh.core._typing import SchemaName, TableName, SessionProperties from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query logger = logging.getLogger(__name__) @@ -47,9 +48,9 @@ class DatabricksEngineAdapter(SparkEngineAdapter): }, ) - def __init__(self, *args: t.Any, **kwargs: t.Any): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) - self._spark: t.Optional[PySparkSession] = None + self._set_spark_engine_adapter_if_needed() @classmethod def can_access_spark_session(cls, disable_spark_session: bool) -> bool: @@ -93,21 +94,43 @@ def _use_spark_session(self) -> bool: ) @property - def use_serverless(self) -> bool: - from sqlmesh import RuntimeEnv - from sqlmesh.utils import str_to_bool + def is_spark_session_connection(self) -> bool: + return isinstance(self.connection, SparkSessionConnection) - if not self._use_spark_session: - return False - return ( - RuntimeEnv.get().is_databricks and str_to_bool(os.environ.get("IS_SERVERLESS", "False")) - ) or bool(self._extra_config["databricks_connect_use_serverless"]) + def _set_spark_engine_adapter_if_needed(self) -> None: + self._spark_engine_adapter = None - @property - def is_spark_session_cursor(self) -> bool: - from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor + if not self._use_spark_session or self.is_spark_session_connection: + return - return isinstance(self.cursor, SparkSessionCursor) + from databricks.connect import DatabricksSession + + connect_kwargs = dict( + host=self._extra_config["databricks_connect_server_hostname"], + token=self._extra_config["databricks_connect_access_token"], + ) + if "databricks_connect_use_serverless" in self._extra_config: + connect_kwargs["serverless"] = True + else: + connect_kwargs["cluster_id"] = 
self._extra_config["databricks_connect_cluster_id"] + + catalog = self._extra_config.get("catalog") + spark = ( + DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate() + ) + self._spark_engine_adapter = SparkEngineAdapter( + partial(connection, spark=spark, catalog=catalog), + default_catalog=catalog, + ) + + @property + def cursor(self) -> t.Any: + if ( + self._connection_pool.get_attribute("use_spark_engine_adapter") + and not self.is_spark_session_connection + ): + return self._spark_engine_adapter.cursor # type: ignore + return super().cursor @property def spark(self) -> PySparkSession: @@ -117,31 +140,22 @@ def spark(self) -> PySparkSession: "Either run from a Databricks Notebook or " "install `databricks-connect` and configure it to connect to your Databricks cluster." ) - - if self.is_spark_session_cursor: - return self._connection_pool.get().spark - - from databricks.connect import DatabricksSession - - if self._spark is None: - self._spark = ( - DatabricksSession.builder.remote( - host=self._extra_config["databricks_connect_server_hostname"], - token=self._extra_config["databricks_connect_access_token"], - cluster_id=self._extra_config["databricks_connect_cluster_id"], - ) - .userAgent("sqlmesh") - .getOrCreate() - ) - catalog = self._extra_config.get("catalog") - if catalog: - self.set_current_catalog(catalog) - return self._spark + if self.is_spark_session_connection: + return self.connection.spark + return self._spark_engine_adapter.spark # type: ignore @property def catalog_support(self) -> CatalogSupport: return CatalogSupport.FULL_SUPPORT + def _begin_session(self, properties: SessionProperties) -> t.Any: + """Begin a new session.""" + # Align the different possible connectors to a single catalog + self.set_current_catalog(self.default_catalog) # type: ignore + + def _end_session(self) -> None: + self._connection_pool.set_attribute("use_spark_engine_adapter", False) + def _df_to_source_queries( self, df: DF, @@ -157,14 +171,8 @@ def _df_to_source_queries( def query_factory() -> Query: temp_table = self._get_temp_table(target_table or "spark", table_only=True) - if self.use_serverless: - # Global temp views are not supported on Databricks Serverless - # This also means we can't mix Python SQL Connection and DB Connect since they wouldn't - # share the same temp objects. 
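In the hunk that follows, staging a DataFrame as a temp view flips a `use_spark_engine_adapter` attribute on the connection pool so that subsequent cursor access stays on the Spark-backed adapter until `_end_session` resets it; as the removed comment above notes, temp objects are not shared between the SQL connector and DB Connect. The sketch below is a toy version of that flag-on-the-pool routing with made-up class names, not the real `ConnectionPool` or adapter interfaces.

```python
from __future__ import annotations

import typing as t


class TinyConnectionPool:
    """Toy pool that can carry session-scoped attributes."""

    def __init__(self) -> None:
        self._attributes: t.Dict[str, t.Any] = {}

    def set_attribute(self, key: str, value: t.Any) -> None:
        self._attributes[key] = value

    def get_attribute(self, key: str) -> t.Any:
        return self._attributes.get(key)


class HybridAdapter:
    """Routes cursor access to one of two backends based on a pool attribute."""

    def __init__(self) -> None:
        self._pool = TinyConnectionPool()
        self.sql_cursor = "sql-connector-cursor"   # stand-ins for real cursors
        self.spark_cursor = "spark-adapter-cursor"

    def register_dataframe(self) -> None:
        # Once a temp view is created through the Spark side, all follow-up
        # statements in this session must use the same backend.
        self._pool.set_attribute("use_spark_engine_adapter", True)

    def end_session(self) -> None:
        self._pool.set_attribute("use_spark_engine_adapter", False)

    @property
    def cursor(self) -> str:
        if self._pool.get_attribute("use_spark_engine_adapter"):
            return self.spark_cursor
        return self.sql_cursor


adapter = HybridAdapter()
assert adapter.cursor == "sql-connector-cursor"
adapter.register_dataframe()
assert adapter.cursor == "spark-adapter-cursor"
adapter.end_session()
assert adapter.cursor == "sql-connector-cursor"
```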
- df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) # type: ignore - else: - df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect)) # type: ignore - temp_table.set("db", "global_temp") + df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) + self._connection_pool.set_attribute("use_spark_engine_adapter", True) return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) if self._use_spark_session: @@ -175,16 +183,12 @@ def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> DF: """Fetches a DataFrame that can be either Pandas or PySpark from the cursor""" - if self.is_spark_session_cursor: + if self.is_spark_session_connection: return super()._fetch_native_df(query, quote_identifiers=quote_identifiers) - if self._use_spark_session: - sql = ( - self._to_sql(query, quote=quote_identifiers) - if isinstance(query, exp.Expression) - else query + if self._spark_engine_adapter: + return self._spark_engine_adapter._fetch_native_df( # type: ignore + query, quote_identifiers=quote_identifiers ) - self._log_sql(sql) - return self.spark.sql(sql) self.execute(query) return self.cursor.fetchall_arrow().to_pandas() @@ -200,37 +204,49 @@ def fetchdf( return df def get_current_catalog(self) -> t.Optional[str]: - # Update the Dataframe API if we have a spark session - if self._use_spark_session: + pyspark_catalog = None + sql_connector_catalog = None + if self._spark_engine_adapter: from py4j.protocol import Py4JError from pyspark.errors.exceptions.connect import SparkConnectGrpcException try: # Note: Spark 3.4+ Only API - return self.spark.catalog.currentCatalog() + pyspark_catalog = self._spark_engine_adapter.get_current_catalog() except (Py4JError, SparkConnectGrpcException): pass - result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION)) - if result: - return result[0] - return None + elif self.is_spark_session_connection: + pyspark_catalog = self.connection.spark.catalog.currentCatalog() + if not self.is_spark_session_connection: + result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION)) + sql_connector_catalog = result[0] if result else None + if self._spark_engine_adapter and pyspark_catalog != sql_connector_catalog: + logger.warning( + f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same." 
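The reworked `get_current_catalog`/`set_current_catalog` pair keeps the SQL connector and the Spark session pointed at the same catalog and warns when they disagree. Here is a simplified, self-contained sketch of that reconciliation using stubbed backends instead of real connections.

```python
from __future__ import annotations

import typing as t


class Backend:
    def __init__(self, catalog: str) -> None:
        self.catalog = catalog


def get_current_catalog(
    spark: t.Optional[Backend], sql_connector: t.Optional[Backend], warn: t.Callable[[str], None]
) -> t.Optional[str]:
    spark_catalog = spark.catalog if spark else None
    sql_catalog = sql_connector.catalog if sql_connector else None
    if spark and sql_connector and spark_catalog != sql_catalog:
        warn(
            f"Current catalog mismatch: `{sql_catalog}` != `{spark_catalog}`. "
            "Set the `catalog` connection property to make them the same."
        )
    return spark_catalog or sql_catalog


def set_current_catalog(
    catalog: str, spark: t.Optional[Backend], sql_connector: t.Optional[Backend]
) -> None:
    # Both backends track their current catalog separately, so both must be updated.
    for backend in (spark, sql_connector):
        if backend:
            backend.catalog = catalog


spark, sql = Backend("hive_metastore"), Backend("main")
get_current_catalog(spark, sql, print)   # prints the mismatch warning
set_current_catalog("main", spark, sql)
assert get_current_catalog(spark, sql, print) == "main"
```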
+ ) + return pyspark_catalog or sql_connector_catalog def set_current_catalog(self, catalog_name: str) -> None: - # Since Databricks splits commands across the Dataframe API and the SQL Connector - # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both - # are set to the same catalog since they maintain their default catalog seperately - self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG")) - # Update the Dataframe API is we have a spark session - if self._use_spark_session: + def _set_spark_session_current_catalog(spark: PySparkSession) -> None: from py4j.protocol import Py4JError from pyspark.errors.exceptions.connect import SparkConnectGrpcException try: # Note: Spark 3.4+ Only API - self.spark.catalog.setCurrentCatalog(catalog_name) + spark.catalog.setCurrentCatalog(catalog_name) except (Py4JError, SparkConnectGrpcException): pass + # Since Databricks splits commands across the Dataframe API and the SQL Connector + # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both + # are set to the same catalog since they maintain their default catalog separately + self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG")) + if self.is_spark_session_connection: + _set_spark_session_current_catalog(self.connection.spark) + + if self._spark_engine_adapter: + _set_spark_session_current_catalog(self._spark_engine_adapter.spark) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: @@ -257,6 +273,15 @@ def clone_table( def wap_supported(self, table_name: TableName) -> bool: return False + @property + def default_catalog(self) -> t.Optional[str]: + try: + return super().default_catalog + except MissingDefaultCatalogError as e: + raise MissingDefaultCatalogError( + "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. 
See SQLMesh Databricks documentation for details" + ) from e + def _build_table_properties_exp( self, catalog_name: t.Optional[str] = None, diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index f09e5b8a4..f80f1816a 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -221,8 +221,8 @@ def query_factory() -> Query: self._convert_df_datetime(df, columns_to_types_create) self.create_table(temp_table, columns_to_types_create) rows: t.List[t.Tuple[t.Any, ...]] = list( - df.replace({np.nan: None}).itertuples(index=False, name=None) - ) # type: ignore + df.replace({np.nan: None}).itertuples(index=False, name=None) # type: ignore + ) conn = self._connection_pool.get() conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows) return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) # type: ignore diff --git a/sqlmesh/core/engine_adapter/shared.py b/sqlmesh/core/engine_adapter/shared.py index fb860754f..aec7786f1 100644 --- a/sqlmesh/core/engine_adapter/shared.py +++ b/sqlmesh/core/engine_adapter/shared.py @@ -317,6 +317,7 @@ def internal_wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: catalog_name = expression.catalog if not catalog_name: return func(*list_args, **kwargs) + # If we have a catalog and this engine doesn't support catalogs then we need to error if catalog_support.is_unsupported: raise UnsupportedCatalogOperationError( diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 092df0652..3b2a203ea 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -65,8 +65,10 @@ def __init__(self, context: GenericContext, path: Path) -> None: try: gateway = self.config.get_gateway(gateway_name) except ConfigError: - logger.warning( - "Gateway '%s' not found in project '%s'", gateway_name, self.config.project + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Gateway '{gateway_name}' not found in project '{self.config.project}'." ) gateway = None self._variables = { diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 6320b509a..ce284d109 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -1,7 +1,6 @@ from __future__ import annotations import inspect -import logging import sys import types import typing as t @@ -57,13 +56,12 @@ UNION_TYPES = (t.Union,) -logger = logging.getLogger(__name__) - - class RuntimeStage(Enum): LOADING = "loading" CREATING = "creating" EVALUATING = "evaluating" + PROMOTING = "promoting" + AUDITING = "auditing" TESTING = "testing" @@ -773,7 +771,9 @@ def star( if exclude and not isinstance(exclude, (exp.Array, exp.Tuple)): raise SQLMeshError(f"Invalid exclude '{exclude}'. Expected an array.") if except_ != exp.tuple_(): - logger.warning( + from sqlmesh.core.console import get_console + + get_console().log_warning( "The 'except_' argument in @STAR will soon be deprecated. Use 'exclude' instead." ) if not isinstance(exclude, (exp.Array, exp.Tuple)): @@ -1303,10 +1303,10 @@ def _coerce( except Exception: if strict: raise - logger.error( - "Coercion of expression '%s' to type '%s' failed. Using non coerced expression at '%s'", - expr, - typ, - path, + + from sqlmesh.core.console import get_console + + get_console().log_error( + f"Coercion of expression '{expr}' to type '{typ}' failed. 
Using non coerced expression at '{path}'", ) return expr diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index 339d0a094..4122d6d2f 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -300,7 +300,7 @@ def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[ return v -expression_validator = field_validator( +expression_validator: t.Callable = field_validator( "query", "expressions_", "pre_statements_", @@ -312,7 +312,7 @@ def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[ )(parse_expression) -bool_validator = field_validator( +bool_validator: t.Callable = field_validator( "skip", "blocking", "forward_only", @@ -327,7 +327,7 @@ def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[ )(parse_bool) -properties_validator = field_validator( +properties_validator: t.Callable = field_validator( "physical_properties_", "virtual_properties_", "session_properties_", @@ -337,14 +337,14 @@ def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[ )(parse_properties) -default_catalog_validator = field_validator( +default_catalog_validator: t.Callable = field_validator( "default_catalog", mode="before", check_fields=False, )(default_catalog) -depends_on_validator = field_validator( +depends_on_validator: t.Callable = field_validator( "depends_on_", mode="before", check_fields=False, diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 1848d2435..17006f4b0 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import typing as t from pathlib import Path import inspect @@ -18,6 +17,7 @@ create_python_model, create_sql_model, get_model_name, + render_meta_fields, ) from sqlmesh.core.model.kind import ModelKindName, _ModelKind from sqlmesh.utils import registry_decorator @@ -25,8 +25,6 @@ from sqlmesh.utils.metaprogramming import build_env, serialize_env -logger = logging.getLogger(__name__) - if t.TYPE_CHECKING: from sqlmesh.core.audit import ModelAudit @@ -108,7 +106,9 @@ def model( kind = self.kwargs.get("kind", None) if kind is not None: if isinstance(kind, _ModelKind): - logger.warning( + from sqlmesh.core.console import get_console + + get_console().log_warning( f"""Python model "{self.name}"'s `kind` argument was passed a SQLMesh `{type(kind).__name__}` object. 
This may result in unexpected behavior - provide a dictionary instead.""" ) elif isinstance(kind, dict): @@ -119,6 +119,21 @@ def model( build_env(self.func, env=env, name=entrypoint, path=module_path) + rendered_fields = render_meta_fields( + fields={"name": self.name, **self.kwargs}, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + ) + + rendered_name = rendered_fields["name"] + if isinstance(rendered_name, exp.Expression): + rendered_fields["name"] = rendered_name.sql(dialect=dialect) + common_kwargs = { "defaults": defaults, "path": path, @@ -134,7 +149,7 @@ def model( "macros": macros, "jinja_macros": jinja_macros, "audit_definitions": audit_definitions, - **self.kwargs, + **rendered_fields, } for key in ("pre_statements", "post_statements", "on_virtual_update"): @@ -147,5 +162,5 @@ def model( if self.is_sql: query = MacroFunc(this=exp.Anonymous(this=entrypoint)) - return create_sql_model(self.name, query, **common_kwargs) - return create_python_model(self.name, entrypoint, **common_kwargs) + return create_sql_model(query=query, **common_kwargs) + return create_python_model(entrypoint=entrypoint, **common_kwargs) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index f0925de34..d8fc276dd 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -62,6 +62,15 @@ logger = logging.getLogger(__name__) +RUNTIME_RENDERED_MODEL_FIELDS = { + "audits", + "signals", + "description", + "cron", + "physical_properties", + "merge_filter", +} + class _Model(ModelMeta, frozen=True): """Model is the core abstraction for user defined datasets. @@ -687,13 +696,23 @@ def text_diff(self, other: Node, rendered: bool = False) -> str: f"Cannot diff model '{self.name} against a non-model node '{other.name}'" ) - return d.text_diff( + text_diff = d.text_diff( self.render_definition(render_query=rendered), other.render_definition(render_query=rendered), self.dialect, other.dialect, ).strip() + if not text_diff and not rendered: + text_diff = d.text_diff( + self.render_definition(render_query=True), + other.render_definition(render_query=True), + self.dialect, + other.dialect, + ).strip() + + return text_diff + def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMAT) -> None: """Sets the default time format for a model. @@ -1255,8 +1274,12 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]: if query is None: return None + unknown = exp.DataType.build("unknown") + self._columns_to_types = { - select.output_name: select.type or exp.DataType.build("unknown") + # copy data type because it is used in the engine to build CTAS and other queries + # this can change the parent which will mess up the diffing algo + select.output_name: (select.type or unknown).copy() for select in query.selects } @@ -1351,9 +1374,16 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: # Can't determine if there's a breaking change if we can't render the query. 
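The new `render_meta_fields` helper (defined later in `definition.py` and called from the Python `@model` decorator above) renders macros and variables in model metadata at load time while leaving the fields listed in `RUNTIME_RENDERED_MODEL_FIELDS` untouched. The sketch below is a heavily simplified, standalone version of that walk; it assumes a trivial `@var`-style substitution in place of SQLMesh's real expression renderer, and the field set shown is illustrative.

```python
from __future__ import annotations

import typing as t

# Fields whose values must stay unrendered until runtime (mirrors the idea of
# RUNTIME_RENDERED_MODEL_FIELDS; the exact set here is illustrative).
RUNTIME_RENDERED_FIELDS = {"audits", "signals", "description", "cron", "physical_properties"}


def render_value(value: t.Any, variables: t.Dict[str, str]) -> t.Any:
    """Pretend renderer: substitutes @name placeholders in strings."""
    if isinstance(value, str) and "@" in value:
        for name, sub in variables.items():
            value = value.replace(f"@{name}", sub)
    return value


def render_meta_fields(
    fields: t.Dict[str, t.Any], variables: t.Dict[str, str]
) -> t.Dict[str, t.Any]:
    for field, value in list(fields.items()):
        if field in RUNTIME_RENDERED_FIELDS:
            continue  # leave runtime-rendered fields exactly as written
        if isinstance(value, dict):
            fields[field] = {k: render_value(v, variables) for k, v in value.items()}
        elif isinstance(value, list):
            fields[field] = [render_value(v, variables) for v in value]
        else:
            fields[field] = render_value(value, variables)
    return fields


fields = {
    "name": "@schema_name.my_model",
    "cron": "@daily",                 # runtime-rendered: untouched
    "tags": ["@team", "core"],
}
print(render_meta_fields(fields, {"schema_name": "analytics", "team": "data_eng"}))
# {'name': 'analytics.my_model', 'cron': '@daily', 'tags': ['data_eng', 'core']}
```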
return None - edits = diff( - previous_query, this_query, matchings=[(previous_query, this_query)], delta_only=True - ) + if previous_query is this_query: + edits = [] + else: + edits = diff( + previous_query, + this_query, + matchings=[(previous_query, this_query)], + delta_only=True, + copy=False, + ) inserted_expressions = {e.expression for e in edits if isinstance(e, Insert)} for edit in edits: @@ -1802,7 +1832,7 @@ def load_sql_based_model( if kind_prop.name.lower() == "merge_filter": unrendered_merge_filter = kind_prop - meta_renderer = _meta_renderer( + rendered_meta_exprs = render_expression( expression=meta, module_path=module_path, macros=macros, @@ -1813,7 +1843,6 @@ def load_sql_based_model( default_catalog=default_catalog, ) - rendered_meta_exprs = meta_renderer.render() if rendered_meta_exprs is None or len(rendered_meta_exprs) != 1: raise_config_error( f"Invalid MODEL statement:\n{meta.sql(dialect=dialect, pretty=True)}", @@ -2003,21 +2032,6 @@ def create_python_model( # Also remove self-references that are found dialect = kwargs.get("dialect") - renderer_kwargs = { - "module_path": module_path, - "macros": macros, - "jinja_macros": jinja_macros, - "variables": variables, - "path": path, - "dialect": dialect, - "default_catalog": kwargs.get("default_catalog"), - } - - name_renderer = _meta_renderer( - expression=d.parse_one(name, dialect=dialect), - **renderer_kwargs, # type: ignore - ) - name = t.cast(t.List[exp.Expression], name_renderer.render())[0].sql(dialect=dialect) dependencies_unspecified = depends_on is None @@ -2029,15 +2043,21 @@ def create_python_model( if dependencies_unspecified: depends_on = parsed_depends_on - {name} else: - depends_on_renderer = _meta_renderer( + depends_on_rendered = render_expression( expression=exp.Array( expressions=[d.parse_one(dep, dialect=dialect) for dep in depends_on or []] ), - **renderer_kwargs, # type: ignore + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=kwargs.get("default_catalog"), ) depends_on = { dep.sql(dialect=dialect) - for dep in t.cast(t.List[exp.Expression], depends_on_renderer.render())[0].expressions + for dep in t.cast(t.List[exp.Expression], depends_on_rendered)[0].expressions } variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables} @@ -2361,7 +2381,61 @@ def _refs_to_sql(values: t.Any) -> exp.Expression: return exp.Tuple(expressions=values) -def _meta_renderer( +def render_meta_fields( + fields: t.Dict[str, t.Any], + module_path: Path, + path: Path, + jinja_macros: t.Optional[JinjaMacroRegistry], + macros: t.Optional[MacroRegistry], + dialect: DialectType, + variables: t.Optional[t.Dict[str, t.Any]], + default_catalog: t.Optional[str], +) -> t.Dict[str, t.Any]: + def render_field_value(value: t.Any) -> t.Any: + if isinstance(value, exp.Expression) or ( + isinstance(value, str) and d.SQLMESH_MACRO_PREFIX in value + ): + expression = exp.maybe_parse(value, dialect=dialect) + rendered_expr = render_expression( + expression=expression, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + ) + if rendered_expr is None: + raise SQLMeshError( + f"Failed to render model attribute `{fields['name']}` at `{path}`\n" + f"'{expression.sql(dialect=dialect)}' must return an expression" + ) + if len(rendered_expr) != 1: + raise SQLMeshError( + f"Failed to render model attribute 
`{fields['name']}` at `{path}`.\n" + f"`{expression.sql(dialect=dialect)}` must return one result, but got {len(rendered_expr)}" + ) + return rendered_expr[0] + + return value + + for field_name, field_info in ModelMeta.all_field_infos().items(): + field = field_info.alias or field_name + if field not in RUNTIME_RENDERED_MODEL_FIELDS and (field_value := fields.get(field)): + if isinstance(field_value, dict): + for key in list(field_value.keys()): + if key not in RUNTIME_RENDERED_MODEL_FIELDS: + fields[field][key] = render_field_value(field_value[key]) + elif isinstance(field_value, list): + fields[field] = [render_field_value(value) for value in field_value] + else: + fields[field] = render_field_value(field_value) + + return fields + + +def render_expression( expression: exp.Expression, module_path: Path, path: Path, @@ -2370,7 +2444,7 @@ def _meta_renderer( dialect: DialectType = None, variables: t.Optional[t.Dict[str, t.Any]] = None, default_catalog: t.Optional[str] = None, -) -> ExpressionRenderer: +) -> t.Optional[t.List[exp.Expression]]: meta_python_env = make_python_env( expressions=expression, jinja_macro_references=None, @@ -2389,7 +2463,7 @@ def _meta_renderer( default_catalog=default_catalog, quote_identifiers=False, normalize_identifiers=False, - ) + ).render() META_FIELD_CONVERTER: t.Dict[str, t.Callable] = { diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index a60036f7f..1dafc3d0e 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -120,7 +120,7 @@ def is_symbolic(self) -> bool: @property def is_materialized(self) -> bool: - return not (self.is_symbolic or self.is_view) + return self.model_kind_name is not None and not (self.is_symbolic or self.is_view) @property def only_execution_time(self) -> bool: @@ -986,7 +986,7 @@ def _model_kind_validator(cls: t.Type, v: t.Any, info: t.Optional[ValidationInfo return create_model_kind(v, dialect, {}) -model_kind_validator = field_validator("kind", mode="before")(_model_kind_validator) +model_kind_validator: t.Callable = field_validator("kind", mode="before")(_model_kind_validator) def _property(name: str, value: t.Any) -> exp.Property: diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index a15ef229b..429551224 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import typing as t from functools import cached_property from typing_extensions import Self @@ -46,8 +45,6 @@ FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]] -logger = logging.getLogger(__name__) - class ModelMeta(_Node): """Metadata for models which can be defined in SQL.""" @@ -302,7 +299,9 @@ def _pre_root_validator(cls, data: t.Any) -> t.Any: if not isinstance(table_properties, str): # Do not warn when deserializing from the state. model_name = data["name"] - logger.warning( + from sqlmesh.core.console import get_console + + get_console().log_warning( f"Model '{model_name}' is using the `table_properties` attribute which is deprecated. Please use `physical_properties` instead." ) physical_properties = data.get("physical_properties") @@ -337,8 +336,10 @@ def _root_validator(self) -> Self: "hudi", "delta", }: - logger.warning( - f"Model {self.name} has `storage_format` set to a table format '{storage_format}' which is deprecated. 
Please use the `table_format` property instead" + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Model {self.name} has `storage_format` set to a table format '{storage_format}' which is deprecated. Please use the `table_format` property instead." ) return self diff --git a/sqlmesh/core/model/schema.py b/sqlmesh/core/model/schema.py index 29486dfe0..86c628f61 100644 --- a/sqlmesh/core/model/schema.py +++ b/sqlmesh/core/model/schema.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import typing as t from concurrent.futures import as_completed from pathlib import Path @@ -21,9 +20,6 @@ from sqlmesh.utils.dag import DAG -logger = logging.getLogger(__name__) - - def update_model_schemas( dag: DAG[str], models: UniqueKeyDict[str, Model], @@ -45,7 +41,9 @@ def _update_schema_with_model(schema: MappingSchema, model: Model) -> None: schema.add_table(model.fqn, columns_to_types, dialect=model.dialect) except SchemaError as e: if "nesting level:" in str(e): - logger.error( + from sqlmesh.core.console import get_console + + get_console().log_error( "SQLMesh requires all model names and references to have the same level of nesting." ) raise diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index 589ffba89..595cb8d5e 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -493,7 +493,7 @@ def _check_destructive_changes(self, directly_modified: t.Set[SnapshotId]) -> No ) warning_msg = f"Plan results in a destructive change to forward-only model '{snapshot.name}'s schema{dropped_column_msg}." if snapshot.model.on_destructive_change.is_warn: - logger.warning(warning_msg) + get_console().log_warning(warning_msg) else: raise PlanError( f"{warning_msg} To allow this, change the model's `on_destructive_change` setting to `warn` or `allow` or include it in the plan's `--allow-destructive-model` option." diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index f2437018b..2139521b3 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -219,11 +219,23 @@ def _push( plan: The plan to source snapshots from. deployability_index: Indicates which snapshots are deployable in the context of this creation. 
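The plan builder hunk above keeps the warn-versus-fail split for destructive changes to forward-only models, now emitting the warning through the console. A compact standalone sketch of that policy, with simplified types and a `print` standing in for `get_console().log_warning`:

```python
from __future__ import annotations

import enum


class OnDestructiveChange(enum.Enum):
    ERROR = "error"
    WARN = "warn"
    ALLOW = "allow"


class PlanError(Exception):
    pass


def check_destructive_change(
    model_name: str, dropped_columns: list[str], setting: OnDestructiveChange
) -> None:
    if setting is OnDestructiveChange.ALLOW or not dropped_columns:
        return
    warning = (
        f"Plan results in a destructive change to forward-only model '{model_name}'s "
        f"schema (dropped columns: {', '.join(dropped_columns)})."
    )
    if setting is OnDestructiveChange.WARN:
        print(f"WARNING: {warning}")  # stand-in for get_console().log_warning(warning)
    else:
        raise PlanError(
            f"{warning} To allow this, change the model's `on_destructive_change` "
            "setting to `warn` or `allow`."
        )


check_destructive_change("db.events", ["legacy_id"], OnDestructiveChange.WARN)  # warns only
```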
""" - snapshots_to_create = [ - s - for s in snapshots.values() - if s.is_model and not s.is_symbolic and plan.is_selected_for_backfill(s.name) - ] + promoted_snapshot_ids = ( + set(plan.environment.promoted_snapshot_ids) + if plan.environment.promoted_snapshot_ids is not None + else None + ) + + def _should_create(s: Snapshot) -> bool: + if not s.is_model or s.is_symbolic: + return False + # Only create tables for snapshots that we're planning to promote or that were selected for backfill + return ( + plan.is_selected_for_backfill(s.name) + or promoted_snapshot_ids is None + or s.snapshot_id in promoted_snapshot_ids + ) + + snapshots_to_create = [s for s in snapshots.values() if _should_create(s)] completed = False progress_stopped = False diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 07ec5b4c0..8388a0322 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -142,7 +142,7 @@ def _render( "default_catalog": self._default_catalog, "runtime_stage": runtime_stage.value, "resolve_table": lambda table: self._resolve_table( - table, + d.normalize_model_name(table, self._default_catalog, self._dialect), snapshots=snapshots, table_mapping=table_mapping, deployability_index=deployability_index, @@ -534,6 +534,8 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: missing_deps.add(dep) if self._model_fqn and not should_optimize and any(s.is_star for s in query.selects): + from sqlmesh.core.console import get_console + deps = ", ".join(f"'{dep}'" for dep in sorted(missing_deps)) warning = ( @@ -545,7 +547,7 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: if self._validate_query: raise_config_error(warning, self._path) - logger.warning(warning) + get_console().log_warning(warning) try: if should_optimize: @@ -564,8 +566,10 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: ) ) except SqlglotError as ex: + from sqlmesh.core.console import get_console + warning = ( - f"{ex} for model '{self._model_fqn}', the column may not exist or is ambiguous" + f"{ex} for model '{self._model_fqn}', the column may not exist or is ambiguous." ) if self._validate_query: @@ -573,7 +577,7 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: query = original - logger.warning(warning) + get_console().log_warning(warning) except Exception as ex: raise_config_error( f"Failed to optimize query, please file an issue at https://github.com/TobikoData/sqlmesh/issues/new. {ex}", diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 28c70a1ac..29e29caff 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -222,7 +222,9 @@ def evaluate( if audit_result.blocking: audit_error_to_raise = error else: - logger.warning(f"{error}\nAudit is non-blocking so proceeding with execution.") + get_console().log_warning( + f"{error}\nAudit is non-blocking so proceeding with execution." 
+ ) if audit_error_to_raise: raise audit_error_to_raise diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py index 9e16bf22e..674b1df76 100644 --- a/sqlmesh/core/schema_diff.py +++ b/sqlmesh/core/schema_diff.py @@ -381,9 +381,12 @@ def _is_coerceable_type(self, current_type: exp.DataType, new_type: exp.DataType if current_type in self.coerceable_types: is_coerceable = new_type in self.coerceable_types[current_type] if is_coerceable: - logger.warning( - f"Coercing type {current_type} to {new_type} which means an alter will not be performed and therefore the resulting table structure will not match what is in the query.\nUpdate your model to cast the value to {current_type} type in order to remove this warning.", + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Coercing type {current_type} to {new_type} which means an alter will not be performed and therefore the resulting table structure will not match what is in the query.\nUpdate your model to cast the value to {current_type} type in order to remove this warning." ) + return is_coerceable return False diff --git a/sqlmesh/core/schema_loader.py b/sqlmesh/core/schema_loader.py index 1b00c7bc1..8df5164a8 100644 --- a/sqlmesh/core/schema_loader.py +++ b/sqlmesh/core/schema_loader.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import typing as t from concurrent.futures import ThreadPoolExecutor from pathlib import Path @@ -8,14 +7,13 @@ from sqlglot import exp from sqlglot.dialects.dialect import DialectType +from sqlmesh.core.console import get_console from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.model.definition import Model from sqlmesh.core.state_sync import StateReader from sqlmesh.utils import UniqueKeyDict, yaml from sqlmesh.utils.errors import SQLMeshError -logger = logging.getLogger(__name__) - def create_external_models_file( path: Path, @@ -51,10 +49,10 @@ def create_external_models_file( # Make sure we don't convert internal models into external ones. existing_model_fqns = state_reader.nodes_exist(external_model_fqns, exclude_external=True) if existing_model_fqns: - logger.warning( - "The following models already exist and can't be converted to external: %s." - "Perhaps these models have been removed, while downstream models that reference them weren't updated accordingly", - ", ".join(existing_model_fqns), + existing_model_fqns_str = ", ".join(existing_model_fqns) + get_console().log_warning( + f"The following models already exist and can't be converted to external: {existing_model_fqns_str}. " + "Perhaps these models have been removed, while downstream models that reference them weren't updated accordingly." ) external_model_fqns -= existing_model_fqns @@ -67,7 +65,7 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]: msg = f"Unable to get schema for '{table}': '{e}'." 
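`create_external_models_file` keeps its `strict` switch: a table whose schema cannot be fetched either raises immediately or is logged as a console warning and skipped, as the hunk around this point shows. A minimal sketch of that behavior with a stubbed schema fetcher:

```python
from __future__ import annotations

import typing as t


class SchemaLoadError(Exception):
    pass


def get_columns(
    table: str,
    fetch_schema: t.Callable[[str], t.Dict[str, str]],
    strict: bool,
    warn: t.Callable[[str], None],
) -> t.Optional[t.Dict[str, str]]:
    try:
        return fetch_schema(table)
    except Exception as e:
        msg = f"Unable to get schema for '{table}': '{e}'."
        if strict:
            raise SchemaLoadError(msg) from e
        warn(msg)
        return None  # the table is simply skipped in non-strict mode


def flaky_fetch(table: str) -> t.Dict[str, str]:
    raise RuntimeError("permission denied")


print(get_columns("raw.orders", flaky_fetch, strict=False, warn=print))  # warns, returns None
```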
if strict: raise SQLMeshError(msg) from e - logger.warning(msg) + get_console().log_warning(msg) return None gateway_part = {"gateway": gateway} if gateway else {} diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 24ba8379b..7944da6a0 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -46,7 +46,6 @@ if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType from sqlmesh.core.environment import EnvironmentNamingInfo - from sqlmesh.core.config import Config Interval = t.Tuple[int, int] Intervals = t.List[Interval] @@ -596,7 +595,6 @@ def from_node( ttl: str = c.DEFAULT_SNAPSHOT_TTL, version: t.Optional[str] = None, cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None, - config: t.Optional[Config] = None, ) -> Snapshot: """Creates a new snapshot for a node. @@ -1480,7 +1478,7 @@ def table_name( table.set("this", exp.to_identifier(f"{name}__{version}{temp_suffix}")) table.set("db", exp.to_identifier(physical_schema)) if not table.catalog and catalog: - table.set("catalog", exp.parse_identifier(catalog)) + table.set("catalog", exp.to_identifier(catalog)) return exp.table_name(table) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 4494f708e..b12ca4d32 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -918,6 +918,7 @@ def _promote_snapshot( snapshots=snapshots, deployability_index=deployability_index, table_mapping=table_mapping, + runtime_stage=RuntimeStage.PROMOTING, ) adapter.execute(snapshot.model.render_on_virtual_update(**render_kwargs)) @@ -1006,6 +1007,7 @@ def _audit( "snapshots": snapshots, "deployability_index": deployability_index, "engine_adapter": adapter, + "runtime_stage": RuntimeStage.AUDITING, **audit_args, **kwargs, } diff --git a/sqlmesh/core/state_sync/engine_adapter.py b/sqlmesh/core/state_sync/engine_adapter.py index 19c83f2f6..a147fac0a 100644 --- a/sqlmesh/core/state_sync/engine_adapter.py +++ b/sqlmesh/core/state_sync/engine_adapter.py @@ -332,7 +332,7 @@ def promote( if ( existing_environment and existing_environment.finalized_ts - and not existing_environment.expiration_ts + and not existing_environment.expired ): # Only promote new snapshots. added_table_infos -= set(existing_environment.promoted_snapshots) diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index ae6cb67c9..c79897783 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -1,7 +1,6 @@ from __future__ import annotations import datetime -import logging import typing as t import unittest from collections import Counter @@ -33,7 +32,6 @@ Row = t.Dict[str, t.Any] -logger = logging.getLogger(__name__) TIME_KWARG_KEYS = { "start", @@ -242,9 +240,11 @@ def assert_equal( elif type(value) is datetime.datetime: expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime() except Exception as e: - logger.warning( + from sqlmesh.core.console import get_console + + get_console().log_warning( f"Failed to convert expected value for {col} into `datetime` " - f"for unit test '{str(self)}'. {str(e)}" + f"for unit test '{str(self)}'. {str(e)}." 
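The snapshot evaluator now tags on-virtual-update statements with the new `PROMOTING` stage and audit queries with `AUDITING`, matching the values added to `RuntimeStage` earlier in this diff. The sketch below only illustrates how a stage value can flow into render-time context; the `render` function is a stand-in, not SQLMesh's renderer.

```python
from __future__ import annotations

import enum


class RuntimeStage(enum.Enum):
    LOADING = "loading"
    CREATING = "creating"
    EVALUATING = "evaluating"
    PROMOTING = "promoting"
    AUDITING = "auditing"
    TESTING = "testing"


def render(template: str, runtime_stage: RuntimeStage) -> str:
    """Stand-in renderer: macros could branch on the stage value at render time."""
    return template.replace("@runtime_stage", runtime_stage.value)


def promote(view_sql: str) -> str:
    # Virtual-update statements rendered during promotion see stage == "promoting".
    return render(view_sql, runtime_stage=RuntimeStage.PROMOTING)


def audit(audit_sql: str) -> str:
    # Audit queries rendered during auditing see stage == "auditing".
    return render(audit_sql, runtime_stage=RuntimeStage.AUDITING)


print(promote("-- stage: @runtime_stage\nCREATE VIEW v AS SELECT 1"))
print(audit("SELECT COUNT(*) FROM t WHERE '@runtime_stage' = 'auditing'"))
```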
) actual = actual.replace({np.nan: None}) @@ -329,7 +329,11 @@ def create_test( name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) model = models.get(name) if not model: - logger.warning(f"Model '{name}' was not found{' at ' + str(path) if path else ''}") + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Model '{name}' was not found{' at ' + str(path) if path else ''}" + ) return None if isinstance(model, SqlModel): diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index fb0857cd4..22ebce302 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -219,7 +219,9 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: try: requirements[target_package] = metadata.version(target_package) except metadata.PackageNotFoundError: - logger.warning("dbt package %s is not installed", target_package) + from sqlmesh.core.console import get_console + + get_console().log_warning(f"dbt package {target_package} is not installed.") return requirements, excluded_requirements diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 8220e244b..c8f181f86 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import typing as t from sqlglot import exp @@ -9,6 +8,7 @@ from sqlmesh.core import dialect as d from sqlmesh.core.config.base import UpdateStrategy +from sqlmesh.core.console import get_console from sqlmesh.core.model import ( EmbeddedKind, FullKind, @@ -33,14 +33,12 @@ from sqlmesh.dbt.context import DbtContext -logger = logging.getLogger(__name__) - INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite"]) INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES = set(["merge"]) def collection_to_str(collection: t.Iterable) -> str: - return ", ".join(f"'{item}'" for item in collection) + return ", ".join(f"'{item}'" for item in sorted(collection)) class ModelConfig(BaseModelConfig): @@ -190,7 +188,9 @@ def _validate_materialized(cls, v: str) -> str: # dictionary materialization raise ConfigError(msg) else: - logger.warning(f"{msg} Falling back to the '{fallback[1]}' materialization.") + get_console().log_warning( + f"{msg} Falling back to the '{fallback[1]}' materialization." + ) return fallback[1] return v @@ -259,11 +259,9 @@ def model_kind(self, context: DbtContext) -> ModelKind: ) if strategy not in INCREMENTAL_BY_TIME_STRATEGIES: - logger.warning( - "SQLMesh incremental by time strategy is not compatible with '%s' incremental strategy in model '%s'. Supported strategies include %s.", - strategy, - self.canonical_name(context), - collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES), + get_console().log_warning( + f"SQLMesh incremental by time strategy is not compatible with '{strategy}' incremental strategy in model '{self.canonical_name(context)}'. " + f"Supported strategies include {collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES)}." ) return IncrementalByTimeRangeKind( @@ -312,11 +310,13 @@ def model_kind(self, context: DbtContext) -> ModelKind: **incremental_by_kind_kwargs, ) - logger.warning( - "Using unmanaged incremental materialization for model '%s'. Some features might not be available. 
Consider adding either a time_column (%s) or a unique_key (%s) configuration to mitigate this", - self.canonical_name(context), - collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES), - collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES.union(["none"])), + incremental_by_time_str = collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES) + incremental_by_unique_key_str = collection_to_str( + INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES.union(["none"]) + ) + get_console().log_warning( + f"Using unmanaged incremental materialization for model '{self.canonical_name(context)}'. " + f"Some features might not be available. Consider adding either a time_column ({incremental_by_time_str}) or a unique_key ({incremental_by_unique_key_str}) configuration to mitigate this.", ) strategy = self.incremental_strategy or target.default_incremental_strategy( IncrementalUnmanagedKind @@ -499,17 +499,17 @@ def to_sqlmesh( self.incremental_strategy = "append" if self.incremental_strategy == "delete+insert": - logger.warning( + get_console().log_warning( f"The '{self.incremental_strategy}' incremental strategy is not supported - SQLMesh will use the temp table/partition swap strategy." ) if self.incremental_predicates: - logger.warning( + get_console().log_warning( "SQLMesh does not support 'incremental_predicates' - they will not be applied." ) if self.query_settings: - logger.warning( + get_console().log_warning( "SQLMesh does not support the 'query_settings' model configuration parameter. Specify the query settings directly in the model query." ) @@ -539,7 +539,7 @@ def to_sqlmesh( physical_properties["primary_key"] = primary_key if self.sharding_key: - logger.warning( + get_console().log_warning( "SQLMesh does not support the 'sharding_key' model configuration parameter or distributed materializations." ) diff --git a/sqlmesh/dbt/project.py b/sqlmesh/dbt/project.py index 8c215e295..e49173608 100644 --- a/sqlmesh/dbt/project.py +++ b/sqlmesh/dbt/project.py @@ -1,9 +1,10 @@ from __future__ import annotations -import logging import typing as t +import logging from pathlib import Path +from sqlmesh.core.console import get_console from sqlmesh.dbt.common import PROJECT_FILENAME, load_yaml from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.manifest import ManifestHelper @@ -79,8 +80,8 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N extra_fields = profile.target.extra if extra_fields: extra_str = ",".join(f"'{field}'" for field in extra_fields) - logger.warning( - "%s adapter does not currently support %s", profile.target.type, extra_str + get_console().log_warning( + f"{profile.target.type} adapter does not currently support {extra_str}." 
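The dbt model conversion above picks a SQLMesh incremental kind from the dbt config and warns through the console when the strategy and keys do not line up. The following is a simplified decision sketch, not the full `model_kind` logic; kind names are returned as plain strings for illustration.

```python
from __future__ import annotations

import typing as t

INCREMENTAL_BY_TIME_STRATEGIES = {"delete+insert", "insert_overwrite"}
INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES = {"merge"}


def collection_to_str(collection: t.Iterable[str]) -> str:
    return ", ".join(f"'{item}'" for item in sorted(collection))


def choose_incremental_kind(
    model_name: str,
    strategy: t.Optional[str],
    time_column: t.Optional[str],
    unique_key: t.Optional[t.List[str]],
    warn: t.Callable[[str], None],
) -> str:
    if time_column:
        if strategy and strategy not in INCREMENTAL_BY_TIME_STRATEGIES:
            warn(
                f"SQLMesh incremental by time strategy is not compatible with '{strategy}' "
                f"in model '{model_name}'. Supported strategies include "
                f"{collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES)}."
            )
        return "INCREMENTAL_BY_TIME_RANGE"
    if unique_key:
        return "INCREMENTAL_BY_UNIQUE_KEY"
    warn(
        f"Using unmanaged incremental materialization for model '{model_name}'. "
        "Consider adding either a time_column or a unique_key configuration."
    )
    return "INCREMENTAL_UNMANAGED"


print(choose_incremental_kind("stg_orders", "merge", "event_ts", None, print))
print(choose_incremental_kind("stg_users", None, None, None, print))
```

Sorting the collections in `collection_to_str` (as this diff does) also keeps the warning text deterministic, which makes it easier to assert on in tests.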
) packages = {} diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py index 74e8fcbd6..5b3814a68 100644 --- a/sqlmesh/dbt/target.py +++ b/sqlmesh/dbt/target.py @@ -1,13 +1,13 @@ from __future__ import annotations import abc -import logging import typing as t from pathlib import Path from dbt.adapters.base import BaseRelation, Column from pydantic import Field +from sqlmesh.core.console import get_console from sqlmesh.core.config.connection import ( AthenaConnectionConfig, BigQueryConnectionConfig, @@ -36,8 +36,6 @@ from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator, model_validator -logger = logging.getLogger(__name__) - IncrementalKind = t.Union[ t.Type[IncrementalByUniqueKeyKind], t.Type[IncrementalByTimeRangeKind], @@ -176,7 +174,7 @@ def validate_authentication(cls, data: t.Any) -> t.Any: ) if "threads" in data and t.cast(int, data["threads"]) > 1: - logger.warning("DuckDB does not support concurrency - setting threads to 1.") + get_console().log_warning("DuckDB does not support concurrency - setting threads to 1.") return data diff --git a/sqlmesh/integrations/github/cicd/command.py b/sqlmesh/integrations/github/cicd/command.py index c238aea23..3aa46c5e9 100644 --- a/sqlmesh/integrations/github/cicd/command.py +++ b/sqlmesh/integrations/github/cicd/command.py @@ -6,6 +6,7 @@ import click from sqlmesh.core.analytics import cli_analytics +from sqlmesh.core.console import set_console, MarkdownConsole from sqlmesh.integrations.github.cicd.controller import ( GithubCheckConclusion, GithubCheckStatus, @@ -26,6 +27,7 @@ @click.pass_context def github(ctx: click.Context, token: str) -> None: """Github Action CI/CD Bot. See https://sqlmesh.readthedocs.io/en/stable/integrations/github/ for details""" + set_console(MarkdownConsole()) ctx.obj["github"] = GithubController( paths=ctx.obj["paths"], token=token, diff --git a/sqlmesh/integrations/github/cicd/controller.py b/sqlmesh/integrations/github/cicd/controller.py index b17c52642..381116d9a 100644 --- a/sqlmesh/integrations/github/cicd/controller.py +++ b/sqlmesh/integrations/github/cicd/controller.py @@ -14,11 +14,10 @@ import requests from hyperscript import Element, h -from rich.console import Console from sqlglot.helper import seq_get from sqlmesh.core import constants as c -from sqlmesh.core.console import SNAPSHOT_CHANGE_CATEGORY_STR, MarkdownConsole +from sqlmesh.core.console import SNAPSHOT_CHANGE_CATEGORY_STR, get_console, MarkdownConsole from sqlmesh.core.context import Context from sqlmesh.core.environment import Environment from sqlmesh.core.plan import Plan, PlanBuilder @@ -303,7 +302,11 @@ def __init__( self._prod_plan_builder: t.Optional[PlanBuilder] = None self._prod_plan_with_gaps_builder: t.Optional[PlanBuilder] = None self._check_run_mapping: t.Dict[str, CheckRun] = {} - self._console = MarkdownConsole(console=Console(no_color=True)) + + if not isinstance(get_console(), MarkdownConsole): + raise CICDBotError("Console must be a markdown console.") + self._console = t.cast(MarkdownConsole, get_console()) + self._client: Github = client or Github( base_url=os.environ["GITHUB_API_URL"], login_or_token=self._token, @@ -326,7 +329,6 @@ def __init__( self._context: Context = Context( paths=self._paths, config=self.config, - console=self._console, ) @property diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 6a40569ae..9b4ac86e2 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -24,7 +24,7 @@ from sqlmesh.core import analytics from sqlmesh.core import constants as c from 
sqlmesh.core.config import load_configs -from sqlmesh.core.console import get_console +from sqlmesh.core.console import create_console, set_console, configure_console from sqlmesh.core.context import Context from sqlmesh.core.dialect import format_model_expressions, parse from sqlmesh.core.model import load_sql_based_model @@ -53,7 +53,9 @@ def wrapper(self: SQLMeshMagics, *args: t.Any, **kwargs: t.Any) -> None: f"Context must be defined and initialized with one of these names: {', '.join(CONTEXT_VARIABLE_NAMES)}" ) old_console = context.console - context.console = get_console(display=self.display) + new_console = create_console(display=self.display) + context.console = new_console + set_console(new_console) context.refresh() magic_name = func.__name__ @@ -81,6 +83,7 @@ def wrapper(self: SQLMeshMagics, *args: t.Any, **kwargs: t.Any) -> None: func(self, context, *args, **kwargs) context.console = old_console + set_console(old_console) return wrapper @@ -128,9 +131,8 @@ def context(self, line: str) -> None: args = parse_argstring(self.context, line) configs = load_configs(args.config, Context.CONFIG_TYPE, args.paths) log_limit = list(configs.values())[0].log_limit - configure_logging( - args.debug, args.ignore_warnings, log_limit=log_limit, log_file_dir=args.log_file_dir - ) + configure_logging(args.debug, log_limit=log_limit, log_file_dir=args.log_file_dir) + configure_console(ignore_warnings=args.ignore_warnings) try: context = Context(paths=args.paths, config=configs, gateway=args.gateway) self._shell.user_ns["context"] = context diff --git a/sqlmesh/schedulers/airflow/hooks/clickhouse.py b/sqlmesh/schedulers/airflow/hooks/clickhouse.py new file mode 100644 index 000000000..8057f12a6 --- /dev/null +++ b/sqlmesh/schedulers/airflow/hooks/clickhouse.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import typing as t + +from airflow.providers.common.sql.hooks.sql import DbApiHook + +if t.TYPE_CHECKING: + from clickhouse_connect.dbapi.connection import Connection + + +class SQLMeshClickHouseHook(DbApiHook): + """ + Uses the ClickHouse Python DB API connector. + """ + + conn_name_attr = "sqlmesh_clickhouse_conn_id" + default_conn_name = "sqlmesh_clickhouse_default" + conn_type = "sqlmesh_clickhouse" + hook_name = "SQLMesh ClickHouse" + + def get_conn(self) -> Connection: + """Returns a ClickHouse connection object""" + from clickhouse_connect.dbapi import connect + + db = self.get_connection(getattr(self, t.cast(str, self.conn_name_attr))) + + return connect( + host=db.host, + port=db.port, + username=db.login, + password=db.password, + database=db.schema, + **db.extra_dejson, + ) diff --git a/sqlmesh/schedulers/airflow/operators/clickhouse.py b/sqlmesh/schedulers/airflow/operators/clickhouse.py new file mode 100644 index 000000000..c9b3301ad --- /dev/null +++ b/sqlmesh/schedulers/airflow/operators/clickhouse.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import typing as t + + +from sqlmesh.schedulers.airflow.hooks.clickhouse import SQLMeshClickHouseHook +from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator +from sqlmesh.schedulers.airflow.operators.targets import BaseTarget + + +class SQLMeshClickHouseOperator(BaseDbApiOperator): + """The operator that evaluates a SQLMesh model snapshot on a ClickHouse target + + Args: + target: The target that will be executed by this operator instance. + postgres_conn_id: The Airflow connection id for the postgres target. 
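With the new ClickHouse hook and operator in place, `discover_engine_operator` (in the hunk that follows) gains a `clickhouse` branch. Here is a toy sketch of that name-to-operator dispatch with placeholder classes standing in for the real Airflow operators; in the actual code the imports stay lazy inside each branch.

```python
from __future__ import annotations

import typing as t


class SQLMeshClickHouseOperator:  # placeholder for the real Airflow operator
    pass


class SQLMeshSparkSubmitOperator:  # placeholder
    pass


def discover_engine_operator(name: str) -> t.Type[t.Any]:
    """Resolve an engine name to its operator class."""
    name = name.lower()
    if name == "clickhouse":
        return SQLMeshClickHouseOperator
    if name == "spark":
        return SQLMeshSparkSubmitOperator
    raise ValueError(f"Unsupported engine name '{name}'.")


assert discover_engine_operator("ClickHouse") is SQLMeshClickHouseOperator
```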
+ """ + + def __init__( + self, + *, + target: BaseTarget, + clickhouse_conn_id: str = SQLMeshClickHouseHook.default_conn_name, + **kwargs: t.Any, + ) -> None: + super().__init__( + target=target, + conn_id=clickhouse_conn_id, + dialect="clickhouse", + hook_type=SQLMeshClickHouseHook, + **kwargs, + ) diff --git a/sqlmesh/schedulers/airflow/util.py b/sqlmesh/schedulers/airflow/util.py index 1eeb9e620..321331a4b 100644 --- a/sqlmesh/schedulers/airflow/util.py +++ b/sqlmesh/schedulers/airflow/util.py @@ -122,6 +122,10 @@ def discover_engine_operator(name: str, sql_only: bool = False) -> t.Type[BaseOp name = name.lower() try: + if name == "clickhouse": + from sqlmesh.schedulers.airflow.operators.clickhouse import SQLMeshClickHouseOperator + + return SQLMeshClickHouseOperator if name == "spark": from sqlmesh.schedulers.airflow.operators.spark_submit import ( SQLMeshSparkSubmitOperator, diff --git a/sqlmesh/utils/connection_pool.py b/sqlmesh/utils/connection_pool.py index e3eb80681..169b4d738 100644 --- a/sqlmesh/utils/connection_pool.py +++ b/sqlmesh/utils/connection_pool.py @@ -115,7 +115,6 @@ class ThreadLocalConnectionPool(_TransactionManagementMixin): def __init__( self, connection_factory: t.Callable[[], t.Any], - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, ): self._connection_factory = connection_factory @@ -126,14 +125,13 @@ def __init__( self._thread_connections_lock = Lock() self._thread_cursors_lock = Lock() self._thread_transactions_lock = Lock() - self._cursor_kwargs = cursor_kwargs or {} self._cursor_init = cursor_init def get_cursor(self) -> t.Any: thread_id = get_ident() with self._thread_cursors_lock: if thread_id not in self._thread_cursors: - self._thread_cursors[thread_id] = self.get().cursor(**self._cursor_kwargs) + self._thread_cursors[thread_id] = self.get().cursor() if self._cursor_init: self._cursor_init(self._thread_cursors[thread_id]) return self._thread_cursors[thread_id] @@ -209,20 +207,18 @@ class SingletonConnectionPool(_TransactionManagementMixin): def __init__( self, connection_factory: t.Callable[[], t.Any], - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, ): self._connection_factory = connection_factory self._connection: t.Optional[t.Any] = None self._cursor: t.Optional[t.Any] = None - self._cursor_kwargs = cursor_kwargs or {} self._attributes: t.Dict[str, t.Any] = {} self._is_transaction_active: bool = False self._cursor_init = cursor_init def get_cursor(self) -> t.Any: if not self._cursor: - self._cursor = self.get().cursor(**self._cursor_kwargs) + self._cursor = self.get().cursor() if self._cursor_init: self._cursor_init(self._cursor) return self._cursor @@ -273,17 +269,12 @@ def close_all(self, exclude_calling_thread: bool = False) -> None: def create_connection_pool( connection_factory: t.Callable[[], t.Any], multithreaded: bool, - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, ) -> ConnectionPool: return ( - ThreadLocalConnectionPool( - connection_factory, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init - ) + ThreadLocalConnectionPool(connection_factory, cursor_init=cursor_init) if multithreaded - else SingletonConnectionPool( - connection_factory, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init - ) + else SingletonConnectionPool(connection_factory, cursor_init=cursor_init) ) diff --git a/sqlmesh/utils/errors.py b/sqlmesh/utils/errors.py index 
d0a3c3840..000dbd8dc 100644 --- a/sqlmesh/utils/errors.py +++ b/sqlmesh/utils/errors.py @@ -159,6 +159,10 @@ class PythonModelEvalError(SQLMeshError): pass +class MissingDefaultCatalogError(SQLMeshError): + pass + + def raise_config_error( msg: str, location: t.Optional[str | Path] = None, diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 6b237fe45..0561ca015 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -2,7 +2,6 @@ from contextlib import contextmanager from os import getcwd, path, remove from pathlib import Path -from unittest.mock import patch import pytest from click.testing import CliRunner import time_machine @@ -416,30 +415,28 @@ def test_plan_dev_bad_create_from(runner, tmp_path): update_incremental_model(tmp_path) # create dev2 environment from non-existent dev3 - logger = logging.getLogger("sqlmesh.core.context_diff") - with patch.object(logger, "warning") as mock_logger: - result = runner.invoke( - cli, - [ - "--log-file-dir", - tmp_path, - "--paths", - tmp_path, - "plan", - "dev2", - "--create-from", - "dev3", - "--no-prompts", - "--auto-apply", - ], - ) + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev2", + "--create-from", + "dev3", + "--no-prompts", + "--auto-apply", + ], + ) - assert result.exit_code == 0 - assert_new_env(result, "dev2", "dev") - assert ( - mock_logger.call_args[0][0] - == "The environment name 'dev3' was passed to the `plan` command's `--create-from` argument, but 'dev3' does not exist. Initializing new environment 'dev2' from scratch." - ) + assert result.exit_code == 0 + assert_new_env(result, "dev2", "dev") + assert ( + "The environment name 'dev3' was passed to the `plan` command's `--create-from` argument, but 'dev3' does not exist. Initializing new environment 'dev2' from scratch." 
+ in result.output.replace("\n", "") + ) def test_plan_dev_no_prompts(runner, tmp_path): @@ -779,6 +776,7 @@ def test_dlt_pipeline_errors(runner, tmp_path): assert "Error: Could not attach to pipeline" in result.output +@time_machine.travel(FREEZE_TIME) def test_plan_dlt(runner, tmp_path): root_dir = path.abspath(getcwd()) pipeline_path = root_dir + "/examples/sushi_dlt/sushi_pipeline.py" @@ -793,12 +791,12 @@ def test_plan_dlt(runner, tmp_path): init_example_project(tmp_path, "duckdb", ProjectTemplate.DLT, "sushi") expected_config = f"""gateways: - dev: + duckdb: connection: type: duckdb database: {dataset_path} -default_gateway: dev +default_gateway: duckdb model_defaults: dialect: duckdb @@ -950,20 +948,21 @@ def test_plan_dlt(runner, tmp_path): remove(dataset_path) -def test_init_project_dialects(runner, tmp_path): +@time_machine.travel(FREEZE_TIME) +def test_init_project_dialects(tmp_path): dialect_to_config = { - "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: ", - "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", - "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", - "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: \n # pretty_sql: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: \n # force_databricks_connect: \n # disable_databricks_connect: \n # disable_spark_session: ", - "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: ", + "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: ", + "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # 
pretty_sql: False\n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", + "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", + "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False", + "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: ", } for dialect, expected_config in dialect_to_config.items(): init_example_project(tmp_path, dialect=dialect) - config_start = f"gateways:\n dev:\n connection:\n # For more information on configuring the connection to your execution engine, visit:\n # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connections\n # https://sqlmesh.readthedocs.io/en/stable/integrations/engines/{dialect}/#connection-options\n type: {dialect}\n " - config_end = f"\n\n\ndefault_gateway: dev\n\nmodel_defaults:\n dialect: {dialect}\n start: {yesterday_ds()}\n" + config_start = f"gateways:\n {dialect}:\n connection:\n # For more information on configuring the connection to your execution engine, visit:\n # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connections\n # https://sqlmesh.readthedocs.io/en/stable/integrations/engines/{dialect}/#connection-options\n type: {dialect}\n " + config_end = f"\n\n\ndefault_gateway: {dialect}\n\nmodel_defaults:\n dialect: {dialect}\n start: {yesterday_ds()}\n" with open(tmp_path / "config.yaml") as file: config = file.read() diff --git a/tests/conftest.py b/tests/conftest.py index 038cf9980..b7e9a071f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -235,6 +235,13 @@ def rescope_lineage_cache(request): yield +@pytest.fixture(autouse=True) +def reset_console(): + from sqlmesh.core.console import set_console, NoopConsole + + set_console(NoopConsole()) + + @pytest.fixture def duck_conn() -> duckdb.DuckDBPyConnection: return duckdb.connect() @@ -433,6 +440,7 @@ def _make_function( klass: t.Type[T], dialect: t.Optional[str] = None, register_comments: bool = True, + default_catalog: t.Optional[str] = None, **kwargs: t.Any, ) -> T: connection_mock = mocker.NonCallableMock() @@ -443,6 +451,7 @@ def _make_function( lambda: connection_mock, dialect=dialect or klass.DIALECT, register_comments=register_comments, + default_catalog=default_catalog, **kwargs, ) if isinstance(adapter, SparkEngineAdapter): diff --git 
a/tests/core/engine_adapter/integration/config.yaml b/tests/core/engine_adapter/integration/config.yaml index 1a64aba0b..fce7ccdbb 100644 --- a/tests/core/engine_adapter/integration/config.yaml +++ b/tests/core/engine_adapter/integration/config.yaml @@ -149,6 +149,9 @@ gateways: port: 8443 username: {{ env_var("CLICKHOUSE_CLOUD_USERNAME") }} password: {{ env_var("CLICKHOUSE_CLOUD_PASSWORD") }} + connect_timeout: 30 + connection_pool_options: + retries: 5 state_connection: type: duckdb diff --git a/tests/core/engine_adapter/integration/conftest.py b/tests/core/engine_adapter/integration/conftest.py index ceaccd29f..a7655c5b8 100644 --- a/tests/core/engine_adapter/integration/conftest.py +++ b/tests/core/engine_adapter/integration/conftest.py @@ -89,7 +89,8 @@ def ctx( ctx = TestContext(test_type, engine_adapter, mark, gateway, is_remote=is_remote) ctx.init() - yield ctx + with ctx.engine_adapter.session({}): + yield ctx try: ctx.cleanup() diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index 4ef71fa90..fa495ca24 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -19,7 +19,10 @@ def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapte "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=False, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.replace_query( "test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("INT")} ) @@ -34,7 +37,10 @@ def test_replace_query_exists(mocker: MockFixture, make_mocked_engine_adapter: t "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=True, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"}) assert to_sql_calls(adapter) == [ @@ -49,7 +55,10 @@ def test_replace_query_pandas_not_exists( "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=False, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) adapter.replace_query( "test_table", df, {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")} @@ -65,7 +74,10 @@ def test_replace_query_pandas_exists(mocker: MockFixture, make_mocked_engine_ada "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=True, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) adapter.replace_query( "test_table", df, {"a": exp.DataType.build("int"), "b": 
exp.DataType.build("int")} @@ -76,31 +88,40 @@ def test_replace_query_pandas_exists(mocker: MockFixture, make_mocked_engine_ada ] -def test_clone_table(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_clone_table(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.clone_table("target_table", "source_table") adapter.cursor.execute.assert_called_once_with( "CREATE TABLE `target_table` SHALLOW CLONE `source_table`" ) -def test_set_current_catalog(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) - adapter.set_current_catalog("test_catalog") +def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + adapter.set_current_catalog("test_catalog2") - assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog`"] + assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"] -def test_get_current_catalog(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.cursor.fetchone.return_value = ("test_catalog",) assert adapter.get_current_catalog() == "test_catalog" assert to_sql_calls(adapter) == ["SELECT CURRENT_CATALOG()"] -def test_get_current_database(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_get_current_database(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.cursor.fetchone.return_value = ("test_database",) assert adapter.get_current_database() == "test_database" @@ -110,7 +131,10 @@ def test_get_current_database(make_mocked_engine_adapter: t.Callable): def test_insert_overwrite_by_partition_query( make_mocked_engine_adapter: t.Callable, mocker: MockFixture, make_temp_table_name: t.Callable ): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") table_name = "test_schema.test_table" @@ -139,8 +163,11 @@ def test_insert_overwrite_by_partition_query( ] -def test_materialized_view_properties(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_materialized_view_properties(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") 
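Every Databricks unit test in these hunks now repeats the same two lines of setup: patch `set_current_catalog` and build the adapter with `default_catalog="test_catalog"`. One possible way to factor that boilerplate into a local fixture is sketched below; `databricks_adapter` is a hypothetical name, while `make_mocked_engine_adapter`, `mocker`, and `DatabricksEngineAdapter` come from the surrounding test module and its conftest.

```python
import pytest

@pytest.fixture
def databricks_adapter(mocker, make_mocked_engine_adapter):
    # Patch the catalog switch so constructing the adapter does not emit SQL, then
    # build it with the default catalog that the assertions in this module expect.
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    return make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
```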
adapter.create_view( "test_table", @@ -161,8 +188,11 @@ def test_materialized_view_properties(make_mocked_engine_adapter: t.Callable): ] -def test_create_table_clustered_by(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_create_table_clustered_by(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") columns_to_types = { "cola": exp.DataType.build("INT"), diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 3d7689623..a11dac78b 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -26,6 +26,7 @@ load_configs, ) from sqlmesh.core.context import Context +from sqlmesh.core.console import create_console from sqlmesh.core.dialect import parse, schema_ from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.environment import Environment @@ -459,8 +460,7 @@ def test_override_builtin_audit_blocking_mode(): ) ) - logger = logging.getLogger("sqlmesh.core.scheduler") - with patch.object(logger, "warning") as mock_logger: + with patch.object(context.console, "log_warning") as mock_logger: plan = context.plan(auto_apply=True, no_prompts=True) new_snapshot = next(iter(plan.context_diff.new_snapshots.values())) @@ -506,6 +506,8 @@ def test_override_builtin_audit_blocking_mode(): def test_python_model_empty_df_raises(sushi_context, capsys): + sushi_context.console = create_console() + @model( "memory.sushi.test_model", columns={"col": "int"}, @@ -524,8 +526,8 @@ def entrypoint(context, **kwargs): sushi_context.plan(no_prompts=True, auto_apply=True) assert ( - "Cannot construct source query from an empty DataFrame. This error is \ncommonly related to Python models that produce no data. For such models, \nconsider yielding from an empty generator if the resulting set is empty, i.e. \nuse" - ) in capsys.readouterr().out + "Cannot construct source query from an empty DataFrame. This error is commonly related to Python models that produce no data. For such models, consider yielding from an empty generator if the resulting set is empty, i.e. 
use" + ) in capsys.readouterr().out.replace("\n", "") def test_env_and_default_schema_normalization(mocker: MockerFixture): @@ -1008,8 +1010,7 @@ def test_load_external_models(copy_to_temp_path): assert context.resolve_table("raw.demographics") == '"memory"."raw"."demographics"' assert context.resolve_table("raw.model2") == '"memory"."raw"."model2"' - logger = logging.getLogger("sqlmesh.core.context") - with patch.object(logger, "warning") as mock_logger: + with patch.object(context.console, "log_warning") as mock_logger: context.table("raw.model1") == '"memory"."raw"."model1"' assert mock_logger.mock_calls == [ @@ -1053,6 +1054,27 @@ def test_disabled_model(copy_to_temp_path): assert not context.get_model("sushi.disabled_py") +def test_disabled_model_python_macro(sushi_context): + @model( + "memory.sushi.disabled_model_2", + columns={"col": "int"}, + enabled="@IF(@gateway = 'dev', True, False)", + ) + def entrypoint(context, **kwargs): + yield pd.DataFrame({"col": []}) + + test_model = model.get_registry()["memory.sushi.disabled_model_2"].model( + module_path=Path("."), path=Path("."), variables={"gateway": "prod"} + ) + assert not test_model.enabled + + with pytest.raises( + SQLMeshError, + match="The disabled model 'memory.sushi.disabled_model_2' cannot be upserted", + ): + sushi_context.upsert_model(test_model) + + def test_get_model_mixed_dialects(copy_to_temp_path): path = copy_to_temp_path("examples/sushi") @@ -1118,7 +1140,7 @@ def test_wildcard(copy_to_temp_path: t.Callable): parent_path = copy_to_temp_path("examples/multi")[0] context = Context(paths=f"{parent_path}/*") - assert len(context.models) == 4 + assert len(context.models) == 5 def test_duckdb_state_connection_automatic_multithreaded_mode(tmp_path): @@ -1265,3 +1287,23 @@ def test_rendered_diff(): ) -DROP VIEW "test" +DROP VIEW IF EXISTS "test"''' in plan.context_diff.text_diff('"test"') + + +def test_plan_enable_preview_default(sushi_context: Context, sushi_dbt_context: Context): + assert sushi_context._plan_preview_enabled + assert not sushi_dbt_context._plan_preview_enabled + + sushi_dbt_context.engine_adapter.SUPPORTS_CLONING = True + assert sushi_dbt_context._plan_preview_enabled + + +def test_catalog_name_needs_to_be_quoted(): + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(catalogs={'"foo--bar"': ":memory:"}), + ) + context = Context(config=config) + parsed_model = parse("MODEL(name db.x, kind FULL); SELECT 1 AS c") + context.upsert_model(load_sql_based_model(parsed_model, default_catalog='"foo--bar"')) + context.plan(auto_apply=True, no_prompts=True) + assert context.fetchdf('select * from "foo--bar".db.x').to_dict() == {"c": {0: 1}} diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 4af9532ac..a50914052 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -232,7 +232,7 @@ def test_forward_only_model_regular_plan(init_and_plan_context: t.Callable): snapshot = context.get_snapshot(model, raise_if_missing=True) top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - plan = context.plan_builder("dev", skip_tests=True).build() + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() assert len(plan.new_snapshots) == 2 assert ( plan.context_diff.snapshots[snapshot.snapshot_id].change_category @@ -253,7 +253,9 @@ def test_forward_only_model_regular_plan(init_and_plan_context: t.Callable): assert not 
dev_df["event_date"].tolist() # Run a restatement plan to preview changes - plan_builder = context.plan_builder("dev", skip_tests=True, restate_models=[model_name]) + plan_builder = context.plan_builder( + "dev", skip_tests=True, restate_models=[model_name], enable_preview=False + ) plan_builder.set_start("2023-01-06") assert plan_builder.build().missing_intervals == [ SnapshotIntervals( @@ -397,7 +399,7 @@ def test_forward_only_model_restate_full_history_in_dev(init_and_plan_context: t assert model.kind.full_history_restatement_only context.upsert_model(model) - context.plan("prod", skip_tests=True, auto_apply=True) + context.plan("prod", skip_tests=True, auto_apply=True, enable_preview=False) model_kwargs = { **model.dict(), @@ -407,7 +409,7 @@ def test_forward_only_model_restate_full_history_in_dev(init_and_plan_context: t context.upsert_model(SqlModel.parse_obj(model_kwargs)) # Apply the model change in dev - plan = context.plan_builder("dev", skip_tests=True).build() + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() assert not plan.missing_intervals context.apply(plan) @@ -425,7 +427,7 @@ def test_forward_only_model_restate_full_history_in_dev(init_and_plan_context: t assert df["cnt"][0] == 1 # Apply a restatement plan in dev - plan = context.plan("dev", restate_models=[model.name], auto_apply=True) + plan = context.plan("dev", restate_models=[model.name], auto_apply=True, enable_preview=False) assert len(plan.missing_intervals) == 1 # Check that the dummy value is not present @@ -833,7 +835,7 @@ def test_forward_only_parent_created_in_dev_child_created_in_prod( ) top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - plan = context.plan_builder("dev", skip_tests=True).build() + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() assert len(plan.new_snapshots) == 2 assert ( plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category @@ -855,7 +857,7 @@ def test_forward_only_parent_created_in_dev_child_created_in_prod( top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - plan = context.plan_builder("prod", skip_tests=True).build() + plan = context.plan_builder("prod", skip_tests=True, enable_preview=False).build() assert len(plan.new_snapshots) == 1 assert ( plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category @@ -869,7 +871,7 @@ def test_forward_only_parent_created_in_dev_child_created_in_prod( def test_new_forward_only_model(init_and_plan_context: t.Callable): context, _ = init_and_plan_context("examples/sushi") - context.plan("dev", skip_tests=True, no_prompts=True, auto_apply=True) + context.plan("dev", skip_tests=True, no_prompts=True, auto_apply=True, enable_preview=False) snapshot = context.get_snapshot("sushi.marketing") @@ -1156,7 +1158,7 @@ def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_co context.upsert_model(model) snapshot = context.get_snapshot(model, raise_if_missing=True) - plan = context.plan_builder("dev", skip_tests=True).build() + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() assert ( plan.context_diff.snapshots[snapshot.snapshot_id].change_category == SnapshotChangeCategory.FORWARD_ONLY @@ -1169,7 +1171,7 @@ def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_co context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) top_waiters_snapshot = 
context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - plan = context.plan_builder("dev", skip_tests=True).build() + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() assert len(plan.new_snapshots) == 1 assert ( plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category @@ -1202,7 +1204,7 @@ def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_co ) top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - plan = context.plan_builder("dev", skip_tests=True).build() + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() assert len(plan.new_snapshots) == 2 assert ( plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category @@ -1233,7 +1235,7 @@ def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_co assert not context.plan_builder("dev", skip_tests=True).build().requires_backfill # Deploy everything to prod. - plan = context.plan_builder("prod", skip_tests=True).build() + plan = context.plan_builder("prod", skip_tests=True, enable_preview=False).build() assert plan.start == to_timestamp("2023-01-01") assert plan.missing_intervals == [ SnapshotIntervals( @@ -1263,7 +1265,11 @@ def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_co ] context.apply(plan) - assert not context.plan_builder("prod", skip_tests=True).build().requires_backfill + assert ( + not context.plan_builder("prod", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) @time_machine.travel("2023-01-08 15:00:00 UTC") @@ -1285,7 +1291,7 @@ def test_forward_only_precedence_over_indirect_non_breaking(init_and_plan_contex non_breaking_snapshot = context.get_snapshot(non_breaking_model, raise_if_missing=True) top_waiter_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - plan = context.plan_builder("dev", skip_tests=True).build() + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() assert ( plan.context_diff.snapshots[forward_only_snapshot.snapshot_id].change_category == SnapshotChangeCategory.FORWARD_ONLY @@ -1315,7 +1321,11 @@ def test_forward_only_precedence_over_indirect_non_breaking(init_and_plan_contex ] context.apply(plan) - assert not context.plan_builder("dev", skip_tests=True).build().requires_backfill + assert ( + not context.plan_builder("dev", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) # Deploy everything to prod. 
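The `enable_preview=False` arguments threaded through these tests opt out of what appears to be the new default of previewing forward-only changes in development environments (see `test_plan_enable_preview_default` later in this diff). A minimal sketch of the opt-out, reusing the `init_and_plan_context` fixture from this test module:

```python
# Sketch only: reuses the init_and_plan_context fixture defined for these tests.
# enable_preview=False restores the earlier behaviour in which forward-only changes
# are not backfilled (previewed) in a dev environment.
def test_plan_without_preview_sketch(init_and_plan_context):
    context, plan = init_and_plan_context("examples/sushi")
    context.apply(plan)

    dev_plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build()
    context.apply(dev_plan)
```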
plan = context.plan_builder("prod", skip_tests=True).build() @@ -1336,7 +1346,11 @@ def test_forward_only_precedence_over_indirect_non_breaking(init_and_plan_contex ] context.apply(plan) - assert not context.plan_builder("prod", skip_tests=True).build().requires_backfill + assert ( + not context.plan_builder("prod", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) @time_machine.travel("2023-01-08 15:00:00 UTC") @@ -1579,12 +1593,6 @@ def test_select_models_for_backfill(init_and_plan_context: t.Callable): context, _ = init_and_plan_context("examples/sushi") expected_intervals = [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), ] @@ -1620,7 +1628,7 @@ def test_select_models_for_backfill(init_and_plan_context: t.Callable): dev_df = context.engine_adapter.fetchdf( "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" ) - assert len(dev_df) == 7 + assert len(dev_df) == 1 schema_objects = context.engine_adapter.get_data_objects("sushi__dev") assert {o.name for o in schema_objects} == { @@ -1854,7 +1862,7 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( full_downstream_model_2_snapshot_id = context.get_snapshot( view_downstream_model_name ).snapshot_id - dev_plan = context.plan("dev", auto_apply=True, no_prompts=True) + dev_plan = context.plan("dev", auto_apply=True, no_prompts=True, enable_preview=False) assert ( dev_plan.snapshots[forward_only_model_snapshot_id].change_category == SnapshotChangeCategory.FORWARD_ONLY @@ -1887,7 +1895,11 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( view_downstream_model_name ).snapshot_id dev_plan = context.plan( - "dev", categorizer_config=CategorizerConfig.all_full(), auto_apply=True, no_prompts=True + "dev", + categorizer_config=CategorizerConfig.all_full(), + auto_apply=True, + no_prompts=True, + enable_preview=False, ) assert ( dev_plan.snapshots[full_downstream_model_snapshot_id].change_category @@ -1917,7 +1929,11 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( view_downstream_model_name ).snapshot_id dev_plan = context.plan( - "dev", categorizer_config=CategorizerConfig.all_full(), auto_apply=True, no_prompts=True + "dev", + categorizer_config=CategorizerConfig.all_full(), + auto_apply=True, + no_prompts=True, + enable_preview=False, ) assert ( dev_plan.snapshots[full_downstream_model_snapshot_id].change_category @@ -2855,6 +2871,35 @@ def test_prod_restatement_plan_missing_model_in_dev( ) +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_plan_snapshot_table_exists_for_promoted_snapshot(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + context.plan("dev", auto_apply=True, no_prompts=True, skip_tests=True) + + # Drop the views and make sure SQLMesh recreates them later + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + context.engine_adapter.drop_view(top_waiters_snapshot.table_name()) + 
context.engine_adapter.drop_view(top_waiters_snapshot.table_name(False)) + + # Make the environment unfinalized to force recreation of all views in the virtual layer + context.state_sync.state_sync.engine_adapter.execute( + "UPDATE sqlmesh._environments SET finalized_ts = NULL WHERE name = 'dev'" + ) + + model = context.get_model("sushi.customers") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + context.plan( + "dev", select_models=["sushi.customers"], auto_apply=True, no_prompts=True, skip_tests=True + ) + assert context.engine_adapter.table_exists(top_waiters_snapshot.table_name()) + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_plan_against_expired_environment(init_and_plan_context: t.Callable): context, plan = init_and_plan_context("examples/sushi") @@ -3821,7 +3866,7 @@ def test_multi(mocker): ) context._new_state_sync().reset(default_catalog=context.default_catalog) plan = context.plan_builder().build() - assert len(plan.new_snapshots) == 4 + assert len(plan.new_snapshots) == 5 context.apply(plan) adapter = context.engine_adapter @@ -3840,12 +3885,13 @@ def test_multi(mocker): assert set(snapshot.name for snapshot in plan.directly_modified) == { '"memory"."bronze"."a"', '"memory"."bronze"."b"', + '"memory"."silver"."e"', } assert sorted([x.name for x in list(plan.indirectly_modified.values())[0]]) == [ '"memory"."silver"."c"', '"memory"."silver"."d"', ] - assert len(plan.missing_intervals) == 2 + assert len(plan.missing_intervals) == 3 context.apply(plan) validate_apply_basics(context, c.PROD, plan.snapshots.values()) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 8503248de..8d7fd729d 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1,6 +1,5 @@ # ruff: noqa: F811 import json -import logging import typing as t from datetime import date, datetime from pathlib import Path @@ -16,6 +15,7 @@ from sqlmesh.core import constants as c from sqlmesh.core import dialect as d +from sqlmesh.core.console import get_console from sqlmesh.core.audit import ModelAudit, load_audit from sqlmesh.core.config import ( Config, @@ -275,8 +275,7 @@ def test_model_validation_union_query(): def test_model_qualification(): - logger = logging.getLogger("sqlmesh.core.renderer") - with patch.object(logger, "warning") as mock_logger: + with patch.object(get_console(), "log_warning") as mock_logger: expressions = d.parse( """ MODEL ( @@ -292,7 +291,7 @@ def test_model_qualification(): model.render_query(needs_optimization=True) assert ( mock_logger.call_args[0][0] - == """Column '"a"' could not be resolved for model '"db"."table"', the column may not exist or is ambiguous""" + == """Column '"a"' could not be resolved for model '"db"."table"', the column may not exist or is ambiguous.""" ) @@ -2082,8 +2081,6 @@ def model2_entrypoint(evaluator: MacroEvaluator) -> str: def test_python_model_decorator_kind() -> None: - logger = logging.getLogger("sqlmesh.core.model.decorator") - # no kind specified -> default Full kind @model("default_kind", columns={'"COL"': "int"}) def a_model(context): @@ -2152,7 +2149,7 @@ def my_model_2(context): pass # warning if kind is ModelKind instance - with patch.object(logger, "warning") as mock_logger: + with patch.object(get_console(), "log_warning") as mock_logger: python_model = model.get_registry()["kind_instance"].model( module_path=Path("."), path=Path("."), @@ -2164,7 +2161,7 @@ def my_model_2(context): ) # no warning with valid kind dict - with patch.object(logger, "warning") as mock_logger: + 
with patch.object(get_console(), "log_warning") as mock_logger: @model("kind_valid_dict", kind=dict(name=ModelKindName.FULL), columns={'"COL"': "int"}) def my_model(context): @@ -2678,8 +2675,7 @@ def test_update_schema(): model.update_schema(schema) assert model.mapping_schema == {'"table_a"': {"a": "INT"}} - logger = logging.getLogger("sqlmesh.core.renderer") - with patch.object(logger, "warning") as mock_logger: + with patch.object(get_console(), "log_warning") as mock_logger: model.render_query(needs_optimization=True) assert mock_logger.call_args[0][0] == missing_schema_warning_msg( '"db"."table"', ('"table_b"',) @@ -2695,8 +2691,6 @@ def test_update_schema(): def test_missing_schema_warnings(): - logger = logging.getLogger("sqlmesh.core.renderer") - full_schema = MappingSchema( { "a": {"x": exp.DataType.build("int")}, @@ -2711,34 +2705,36 @@ def test_missing_schema_warnings(): }, ) + console = get_console() + # star, no schema, no deps - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM (SELECT 1 a) x")) model.render_query(needs_optimization=True) mock_logger.assert_not_called() # star, full schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM a CROSS JOIN b")) model.update_schema(full_schema) model.render_query(needs_optimization=True) mock_logger.assert_not_called() # star, partial schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM a CROSS JOIN b")) model.update_schema(partial_schema) model.render_query(needs_optimization=True) assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"b"',)) # star, no schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM b JOIN a")) model.render_query(needs_optimization=True) assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"a"', '"b"')) # no star, full schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model( d.parse("MODEL (name test); SELECT x::INT FROM a CROSS JOIN b") ) @@ -2747,7 +2743,7 @@ def test_missing_schema_warnings(): mock_logger.assert_not_called() # no star, partial schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model( d.parse("MODEL (name test); SELECT x::INT FROM a CROSS JOIN b") ) @@ -2756,7 +2752,7 @@ def test_missing_schema_warnings(): mock_logger.assert_not_called() # no star, empty schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model( d.parse("MODEL (name test); SELECT x::INT FROM a CROSS JOIN b") ) @@ -4906,7 +4902,7 @@ def external_model_python(context, **kwargs): def test_variables_python_sql_model(mocker: MockerFixture) -> None: @model( - "test_variables_python_model", + "test_variables_python_model_@{bar}", is_sql=True, kind="full", columns={"a": "string", "b": "string", "c": "string"}, @@ -4918,12 +4914,13 @@ def 
model_with_variables(evaluator, **kwargs): exp.convert(evaluator.var("test_var_c")).as_("c"), ) - python_sql_model = model.get_registry()["test_variables_python_model"].model( + python_sql_model = model.get_registry()["test_variables_python_model_@{bar}"].model( module_path=Path("."), path=Path("."), - variables={"test_var_a": "test_value", "test_var_unused": 2}, + variables={"test_var_a": "test_value", "test_var_unused": 2, "bar": "suffix"}, ) + assert python_sql_model.name == "test_variables_python_model_suffix" assert python_sql_model.python_env[c.SQLMESH_VARS] == Executable.value( {"test_var_a": "test_value"} ) @@ -4936,6 +4933,195 @@ def model_with_variables(evaluator, **kwargs): ) +def test_macros_python_model(mocker: MockerFixture) -> None: + @model( + "foo_macro_model_@{bar}", + columns={"a": "string"}, + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="@{time_col}"), + stamp="@{stamp}", + owner="@IF(@gateway = 'dev', @{dev_owner}, @{prod_owner})", + enabled="@IF(@gateway = 'dev', True, False)", + start="@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + partitioned_by=[ + d.parse_one("DATETIME_TRUNC(@{time_col}, MONTH)"), + ], + ) + def model_with_macros(context, **kwargs): + return pd.DataFrame( + [ + { + "a": context.var("TEST_VAR_A"), + } + ] + ) + + python_model = model.get_registry()["foo_macro_model_@{bar}"].model( + module_path=Path("."), + path=Path("."), + variables={ + "test_var_a": "test_value", + "gateway": "prod", + "bar": "suffix", + "dev_owner": "dv_1", + "prod_owner": "pr_1", + "stamp": "bump", + "time_col": "a", + }, + ) + + assert python_model.name == "foo_macro_model_suffix" + assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({"test_var_a": "test_value"}) + assert not python_model.enabled + assert python_model.start == "2024-01-01" + assert python_model.owner == "pr_1" + assert python_model.stamp == "bump" + assert python_model.time_column.column == exp.column("a", quoted=True) + assert python_model.partitioned_by[0].sql() == 'DATETIME_TRUNC("a", MONTH)' + + context = ExecutionContext(mocker.Mock(), {}, None, None) + df = list(python_model.render(context=context))[0] + assert df.to_dict(orient="records") == [{"a": "test_value"}] + + +def test_macros_python_sql_model(mocker: MockerFixture) -> None: + @macro() + def end_date_macro(evaluator: MacroEvaluator, var: bool): + return f"@IF({var} = False, '1 day ago', '2025-01-01 12:00:00')" + + @model( + "test_macros_python_model_@{bar}", + is_sql=True, + kind="full", + cron="@daily", + columns={"a": "string"}, + enabled="@IF(@gateway = 'dev', True, False)", + start="@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + end="@end_date_macro(@{global_var})", + owner="@IF(@gateway = 'dev', @{dev_owner}, @{prod_owner})", + stamp="@{stamp}", + tags=["@{tag1}", "@{tag2}"], + description="Model desc @{test_}", + ) + def model_with_macros(evaluator, **kwargs): + return exp.select( + exp.convert(evaluator.var("TEST_VAR_A")).as_("a"), + ) + + python_sql_model = model.get_registry()["test_macros_python_model_@{bar}"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + variables={ + "test_var_a": "test_value", + "test_var_unused": 2, + "bar": "suffix", + "gateway": "dev", + "global_var": False, + "dev_owner": "dv_1", + "prod_owner": "pr_1", + "stamp": "bump", + "time_col": "a", + "tag1": "tag__1", + "tag2": "tag__2", + }, + ) + + assert python_sql_model.name == "test_macros_python_model_suffix" + assert python_sql_model.python_env[c.SQLMESH_VARS] == 
Executable.value( + {"test_var_a": "test_value"} + ) + + assert python_sql_model.enabled + assert python_sql_model.start == "1 month ago" + assert python_sql_model.end == "1 day ago" + assert python_sql_model.owner == "dv_1" + assert python_sql_model.stamp == "bump" + assert python_sql_model.description == "Model desc @{test_}" + assert python_sql_model.tags == ["tag__1", "tag__2"] + + context = ExecutionContext(mocker.Mock(), {}, None, None) + query = list(python_sql_model.render(context=context))[0] + assert query.sql() == """SELECT 'test_value' AS "a" """.strip() + + +def test_unrendered_macros_python_model(mocker: MockerFixture) -> None: + @model( + "test_unrendered_macros_python_model_@{bar}", + is_sql=True, + kind=dict( + name=ModelKindName.INCREMENTAL_BY_UNIQUE_KEY, + unique_key="@{key}", + merge_filter="source.id > 0 and target.updated_at < @end_ds and source.updated_at > @start_ds", + ), + cron="@daily", + columns={"a": "string"}, + allow_partials="@IF(@gateway = 'dev', True, False)", + physical_properties=dict( + location1="@'s3://bucket/prefix/@{schema_name}/@{table_name}'", + location2="@IF(@gateway = 'dev', @'hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}', @'s3://prod/@{table_name}')", + ), + virtual_properties={"creatable_type": "@{create_type}"}, + session_properties={ + "spark.executor.cores": "@IF(@gateway = 'dev', 1, 2)", + "spark.executor.memory": "1G", + }, + ) + def model_with_macros(evaluator, **kwargs): + return exp.select( + exp.convert(evaluator.var("TEST_VAR_A")).as_("a"), + ) + + python_sql_model = model.get_registry()["test_unrendered_macros_python_model_@{bar}"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + variables={ + "test_var_a": "test_value", + "bar": "suffix", + "gateway": "dev", + "key": "a", + "create_type": "'SECURE'", + }, + ) + + assert python_sql_model.name == "test_unrendered_macros_python_model_suffix" + assert python_sql_model.python_env[c.SQLMESH_VARS] == Executable.value( + {"test_var_a": "test_value"} + ) + assert python_sql_model.enabled + + context = ExecutionContext(mocker.Mock(), {}, None, None) + query = list(python_sql_model.render(context=context))[0] + assert query.sql() == """SELECT 'test_value' AS "a" """.strip() + assert python_sql_model.allow_partials + + assert "location1" in python_sql_model.physical_properties + assert "location2" in python_sql_model.physical_properties + + assert python_sql_model.session_properties == { + "spark.executor.cores": 1, + "spark.executor.memory": "1G", + } + assert python_sql_model.virtual_properties["creatable_type"].this == "SECURE" + + # The physical_properties will stay unrendered at load time + assert ( + python_sql_model.physical_properties["location1"].text("this") + == "@'s3://bucket/prefix/@{schema_name}/@{table_name}'" + ) + assert ( + python_sql_model.physical_properties["location2"].text("this") + == "@IF(@gateway = 'dev', @'hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}', @'s3://prod/@{table_name}')" + ) + + # Merge_filter will stay unrendered as well + assert python_sql_model.unique_key[0] == exp.column("a", quoted=True) + assert ( + python_sql_model.merge_filter.sql() + == "source.id > 0 AND target.updated_at < @end_ds AND source.updated_at > @start_ds" + ) + + def test_columns_python_sql_model() -> None: @model( "test_columns_python_model", @@ -6570,6 +6756,36 @@ def resolve_parent(evaluator, name): assert len(post_statements) == 1 assert post_statements[0].sql() == f'"sqlmesh__default"."parent__{version}"' + # test with 
additional nesting level and default catalog + for post_statement in ( + "JINJA_STATEMENT_BEGIN; {{ resolve_table('schema.parent') }}; JINJA_END;", + "@resolve_parent('schema.parent')", + ): + expressions = d.parse( + f""" + MODEL (name schema.child); + + SELECT c FROM schema.parent; + + {post_statement} + """ + ) + child = load_sql_based_model(expressions, default_catalog="main") + parent = load_sql_based_model( + d.parse("MODEL (name schema.parent); SELECT 1 AS c"), default_catalog="main" + ) + + parent_snapshot = make_snapshot(parent) + parent_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + version = parent_snapshot.version + + post_statements = child.render_post_statements( + snapshots={'"main"."schema"."parent"': parent_snapshot} + ) + + assert len(post_statements) == 1 + assert post_statements[0].sql() == f'"main"."sqlmesh__schema"."schema__parent__{version}"' + def test_cluster_with_complex_expression(): expressions = d.parse( diff --git a/tests/core/test_schema_loader.py b/tests/core/test_schema_loader.py index e88fc358b..a74933553 100644 --- a/tests/core/test_schema_loader.py +++ b/tests/core/test_schema_loader.py @@ -1,5 +1,4 @@ import pytest -import logging import typing as t from pathlib import Path from unittest.mock import patch @@ -349,8 +348,7 @@ def test_missing_table(tmp_path: Path): model = SqlModel(name="a", query=parse_one("select * FROM tbl_source")) filename = tmp_path / c.EXTERNAL_MODELS_YAML - logger = logging.getLogger("sqlmesh.core.schema_loader") - with patch.object(logger, "warning") as mock_logger: + with patch.object(context.console, "log_warning") as mock_logger: create_external_models_file( filename, {"a": model}, # type: ignore diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py index a01f0f9b1..929245e41 100644 --- a/tests/core/test_state_sync.py +++ b/tests/core/test_state_sync.py @@ -880,7 +880,9 @@ def test_promote_environment_expired(state_sync: EngineAdapterStateSync, make_sn end_at="2022-01-01", plan_id="new_plan_id", previous_plan_id=None, # No previous plan ID since it's technically a new environment + expiration_ts=now_timestamp() + 3600, ) + assert new_environment.expiration_ts # This call shouldn't fail. 
promotion_result = state_sync.promote(new_environment) @@ -888,6 +890,16 @@ def test_promote_environment_expired(state_sync: EngineAdapterStateSync, make_sn assert promotion_result.removed == [] assert promotion_result.removed_environment_naming_info is None + state_sync.finalize(new_environment) + + new_environment.previous_plan_id = new_environment.plan_id + new_environment.plan_id = "another_plan_id" + promotion_result = state_sync.promote(new_environment) + # Should be empty since the environment is no longer expired and nothing has changed + assert promotion_result.added == [] + assert promotion_result.removed == [] + assert promotion_result.removed_environment_naming_info is None + def test_promote_snapshots_no_gaps(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): model = SqlModel( diff --git a/tests/core/test_test.py b/tests/core/test_test.py index d39f4fd04..5f632236e 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import datetime import typing as t from pathlib import Path @@ -21,6 +20,7 @@ ModelDefaultsConfig, ) from sqlmesh.core.context import Context +from sqlmesh.core.console import get_console from sqlmesh.core.dialect import parse from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.macros import MacroEvaluator, macro @@ -1771,8 +1771,7 @@ def test_unknown_model_warns(mocker: MockerFixture) -> None: """ ) - logger = logging.getLogger("sqlmesh.core.test.definition") - with patch.object(logger, "warning") as mock_logger: + with patch.object(get_console(), "log_warning") as mock_logger: ModelTest.create_test( body=body, test_name="test_unknown_model", diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 72fc573f5..00c3ec573 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -15,6 +15,7 @@ from sqlmesh.core import dialect as d from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.context import Context +from sqlmesh.core.console import get_console from sqlmesh.core.model import ( EmbeddedKind, FullKind, @@ -73,8 +74,7 @@ def test_materialization(): context.project_name = "Test" context.target = DuckDbConfig(name="target", schema="foo") - logger = logging.getLogger("sqlmesh.dbt.model") - with patch.object(logger, "warning") as mock_logger: + with patch.object(get_console(), "log_warning") as mock_logger: model_config = ModelConfig( name="model", alias="model", schema="schema", materialized="materialized_view" ) @@ -1260,8 +1260,7 @@ def test_clickhouse_properties(mocker: MockerFixture): sql="""SELECT 1 AS one, ds FROM foo""", ) - logger = logging.getLogger("sqlmesh.dbt.model") - with patch.object(logger, "warning") as mock_logger: + with patch.object(get_console(), "log_warning") as mock_logger: model_to_sqlmesh = model_config.to_sqlmesh(context) assert [call[0][0] for call in mock_logger.call_args_list] == [ @@ -1269,7 +1268,7 @@ def test_clickhouse_properties(mocker: MockerFixture): "SQLMesh does not support 'incremental_predicates' - they will not be applied.", "SQLMesh does not support the 'query_settings' model configuration parameter. Specify the query settings directly in the model query.", "SQLMesh does not support the 'sharding_key' model configuration parameter or distributed materializations.", - "Using unmanaged incremental materialization for model '%s'. Some features might not be available. 
Consider adding either a time_column (%s) or a unique_key (%s) configuration to mitigate this", + "Using unmanaged incremental materialization for model '`test`.`model`'. Some features might not be available. Consider adding either a time_column ('delete+insert', 'insert_overwrite') or a unique_key ('merge', 'none') configuration to mitigate this.", ] assert [e.sql("clickhouse") for e in model_to_sqlmesh.partitioned_by] == [ diff --git a/tests/integrations/github/cicd/fixtures.py b/tests/integrations/github/cicd/fixtures.py index 056bbdbac..a53c8b333 100644 --- a/tests/integrations/github/cicd/fixtures.py +++ b/tests/integrations/github/cicd/fixtures.py @@ -3,6 +3,7 @@ import pytest from pytest_mock.plugin import MockerFixture +from sqlmesh.core.console import set_console, MarkdownConsole from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig from sqlmesh.integrations.github.cicd.controller import ( GithubController, @@ -80,6 +81,7 @@ def _make_function( "sqlmesh.integrations.github.cicd.controller.GithubController.bot_config", bot_config, ) + set_console(MarkdownConsole()) return GithubController( paths=["examples/sushi"], token="abc", diff --git a/tests/integrations/github/cicd/test_integration.py b/tests/integrations/github/cicd/test_integration.py index ba12eb2d5..1157a064e 100644 --- a/tests/integrations/github/cicd/test_integration.py +++ b/tests/integrations/github/cicd/test_integration.py @@ -379,7 +379,7 @@ def test_merge_pr_has_non_breaking_change_diff_start( - `sushi.top_waiters` -**Models needing backfill [missing dates]:** +**Models needing backfill:** * `sushi.waiter_revenue_by_day`: [2022-12-25 - 2022-12-28] """ assert prod_plan_preview_checks_runs[2]["output"]["summary"] == expected_prod_plan @@ -1066,7 +1066,7 @@ def test_no_merge_since_no_deploy_signal_no_approvers_defined( - `sushi.top_waiters` -**Models needing backfill [missing dates]:** +**Models needing backfill:** * `sushi.waiter_revenue_by_day`: [2022-12-25 - 2022-12-29] """ assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" diff --git a/tests/integrations/jupyter/test_magics.py b/tests/integrations/jupyter/test_magics.py index 89916d57a..0caff8a51 100644 --- a/tests/integrations/jupyter/test_magics.py +++ b/tests/integrations/jupyter/test_magics.py @@ -77,9 +77,9 @@ def convert_all_html_output_to_tags(): def _convert_html_to_tags(html: str) -> t.List[str]: # BS4 automatically adds html and body tags so we remove those since they are not actually part of the output return [ - tag.name + tag.name # type: ignore for tag in BeautifulSoup(html, "html").find_all() - if tag.name not in {"html", "body"} + if tag.name not in {"html", "body"} # type: ignore ] def _convert(output: CapturedIO) -> t.List[t.List[str]]: @@ -604,7 +604,7 @@ def test_migrate( @pytest.mark.slow -def test_create_external_models(notebook, loaded_sushi_context): +def test_create_external_models(notebook, loaded_sushi_context, convert_all_html_output_to_text): external_model_file = loaded_sushi_context.path / "external_models.yaml" external_model_file.unlink() assert not external_model_file.exists() @@ -614,7 +614,11 @@ def test_create_external_models(notebook, loaded_sushi_context): assert not output.stdout assert not output.stderr - assert not output.outputs + assert len(output.outputs) == 2 + converted = sorted(convert_all_html_output_to_text(output)) + assert 'Unable to get schema for \'"memory"."raw"."model1"\'' in converted[0] + assert 'Unable to get schema for \'"memory"."raw"."model2"\'' in 
converted[1] + assert external_model_file.exists() assert ( external_model_file.read_text() diff --git a/tests/web/conftest.py b/tests/web/conftest.py index e7a731256..6b6fcaad2 100644 --- a/tests/web/conftest.py +++ b/tests/web/conftest.py @@ -4,6 +4,8 @@ from fastapi import FastAPI from sqlmesh.core.context import Context +from sqlmesh.core.console import set_console + from web.server.console import api_console from web.server.settings import Settings, get_loaded_context, get_settings @@ -34,7 +36,8 @@ def get_settings_override() -> Settings: @pytest.fixture def project_context(web_app: FastAPI, project_tmp_path: Path): - context = Context(paths=project_tmp_path, console=api_console) + set_console(api_console) + context = Context(paths=project_tmp_path) def get_loaded_context_override() -> Context: return context diff --git a/web/server/settings.py b/web/server/settings.py index f3e7ba548..8077abf02 100644 --- a/web/server/settings.py +++ b/web/server/settings.py @@ -63,9 +63,11 @@ def get_settings() -> Settings: @lru_cache() def _get_context(path: str | Path, config: str, gateway: str) -> Context: + from sqlmesh.core.console import set_console from web.server.main import api_console - return Context(paths=str(path), config=config, console=api_console, gateway=gateway, load=False) + set_console(api_console) + return Context(paths=str(path), config=config, gateway=gateway, load=False) @lru_cache()
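Taken together, the magics, test fixture, GitHub CICD, and web server hunks converge on a single pattern: rather than passing `console=` to `Context`, callers now install the desired console globally with `set_console` before constructing the context. A minimal sketch of that pattern, using `MarkdownConsole` the same way the CICD fixtures above do:

```python
# Minimal sketch of the console wiring this diff standardizes on: set the global
# console first, then build the Context without a console argument.
from sqlmesh.core.console import MarkdownConsole, set_console
from sqlmesh.core.context import Context

set_console(MarkdownConsole())                         # replaces Context(..., console=...)
context = Context(paths=["examples/sushi"], load=False)
```

Tests get a clean slate through the autouse `reset_console` fixture added to `tests/conftest.py`, which installs a `NoopConsole` before each test.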