diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 4ea22b447..03ebfa49c 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -15,7 +15,7 @@ from sqlmesh.core.console import configure_console, get_console from sqlmesh.core.config import load_configs from sqlmesh.core.context import Context -from sqlmesh.utils.date import TimeLike, time_like_to_str +from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import MissingDependencyError logger = logging.getLogger(__name__) @@ -976,12 +976,9 @@ def dlt_refresh( help="Prints the expiry datetime of the environments.", default=False, ) -@click.pass_context +@click.pass_obj @error_handler @cli_analytics -def environments(ctx: click.Context, show_expiry: bool) -> None: +def environments(obj: Context, show_expiry: bool) -> None: """Prints the list of SQLMesh environments with its expiry datetime.""" - context = ctx.obj - environment_names = context.state_sync.get_environment_names(get_expiry_ts=show_expiry) - output = [f"{name} - {time_like_to_str(ts)}" for name, ts in environment_names] if show_expiry else [name[0] for name in environment_names] - context.console.log_status_update(f"Number of SQLMesh environments are: {len(output)}\n{"\n".join(output)}") + obj.print_environment_names(show_expiry=show_expiry) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 9869096b4..f755a0b6c 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -113,7 +113,7 @@ from sqlmesh.core.user import User from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now_ds, to_timestamp, format_tz_datetime +from sqlmesh.utils.date import TimeLike, now_ds, to_timestamp, format_tz_datetime, time_like_to_str from sqlmesh.utils.errors import ( CircuitBreakerError, ConfigError, @@ -1972,6 +1972,24 @@ def print_info(self, skip_connection: bool = False, verbose: bool = False) -> No if state_connection: self._try_connection("state backend", state_connection.connection_validator()) + @python_api_analytics + def print_environment_names(self, show_expiry: bool) -> None: + """Prints all environment names along with expiry datetime if show_expiry is True.""" + environment_names = self._new_state_sync().get_environment_names(get_expiry_ts=show_expiry) + if not environment_names: + error_msg = "Environments were not found." + raise SQLMeshError(error_msg) + output = ( + [ + f"{name} - {time_like_to_str(ts)}" if ts else f"{name} - No Expiry" + for name, ts in environment_names + ] + if show_expiry + else [name[0] for name in environment_names] + ) + output_str = "\n".join([str(len(output)), *output]) + self.console.log_status_update(f"Number of SQLMesh environments are: {output_str}") + def close(self) -> None: """Releases all resources allocated by this context.""" if self._snapshot_evaluator: diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index e2065e1e1..2a61f5aab 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -136,7 +136,9 @@ def get_environments(self) -> t.List[Environment]: """ @abc.abstractmethod - def get_environment_names(self, get_expiry_ts: bool = True) -> t.List[t.Tuple[str]] | t.List[t.Tuple[str, int]]: + def get_environment_names( + self, get_expiry_ts: bool = True + ) -> t.Optional[t.List[t.Tuple[str, ...]]]: """Fetches all environment names along with expiry datetime if get_expiry_ts is True. Returns: diff --git a/sqlmesh/core/state_sync/engine_adapter.py b/sqlmesh/core/state_sync/engine_adapter.py index a147fac0a..0ad555242 100644 --- a/sqlmesh/core/state_sync/engine_adapter.py +++ b/sqlmesh/core/state_sync/engine_adapter.py @@ -733,14 +733,20 @@ def get_environments(self) -> t.List[Environment]: self._environment_from_row(row) for row in self._fetchall(self._environments_query()) ] - def get_environment_names(self, get_expiry_ts: bool = True) -> t.List[t.Tuple[str]] | t.List[t.Tuple[str, int]]: + def get_environment_names( + self, get_expiry_ts: bool = True + ) -> t.Optional[t.List[t.Tuple[str, ...]]]: """Fetches all environment names along with expiry datetime if get_expiry_ts is True. Returns: A list of all environment names along with expiry datetime if get_expiry_ts is True. """ name_field = ["name"] - return self._fetchall(self._environments_query(required_fields=name_field if not get_expiry_ts else name_field + ["expiration_ts"])) + return self._fetchall( + self._environments_query( + required_fields=name_field if not get_expiry_ts else name_field + ["expiration_ts"] + ), + ) def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment: return Environment(**{field: row[i] for i, field in enumerate(Environment.all_fields())}) @@ -751,9 +757,9 @@ def _environments_query( lock_for_update: bool = False, required_fields: t.Optional[t.List[str]] = None, ) -> exp.Select: - required_fields = required_fields if required_fields else Environment.all_fields() + query_fields = required_fields if required_fields else Environment.all_fields() query = ( - exp.select(*(exp.to_identifier(field) for field in required_fields)) + exp.select(*(exp.to_identifier(field) for field in query_fields)) .from_(self.environments_table) .where(where) ) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 9b4ac86e2..18fe79464 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -29,7 +29,7 @@ from sqlmesh.core.dialect import format_model_expressions, parse from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests -from sqlmesh.utils import date, sqlglot_dialects, yaml +from sqlmesh.utils import sqlglot_dialects, yaml from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError logger = logging.getLogger(__name__) @@ -1011,7 +1011,7 @@ def clean(self, context: Context, line: str) -> None: @magic_arguments() @argument( - "--expiry-ds", + "--show-expiry", "-e", action="store_true", help="Prints the expiration datetime of the environments.", @@ -1021,9 +1021,7 @@ def clean(self, context: Context, line: str) -> None: def environments(self, context: Context, line: str) -> None: """Prints the list of SQLMesh environments with its expiry datetime.""" args = parse_argstring(self.environments, line) - environment_names = context.state_sync.get_environment_names(get_expiry_ts=args.show_expiry) - output = [f"{name} - {date.time_like_to_str(ts)}" for name, ts in environment_names] if args.show_expiry else [name[0] for name in environment_names] - context.console.log_status_update(f"Number of SQLMesh environments are: {len(output)}\n{"\n".join(output)}") + context.print_environment_names(show_expiry=args.show_expiry) def register_magics() -> None: diff --git a/sqlmesh/schedulers/airflow/state_sync.py b/sqlmesh/schedulers/airflow/state_sync.py index 2b58395ad..5b9af9986 100644 --- a/sqlmesh/schedulers/airflow/state_sync.py +++ b/sqlmesh/schedulers/airflow/state_sync.py @@ -68,6 +68,14 @@ def get_environments(self) -> t.List[Environment]: """ return self._client.get_environments() + def get_environment_names( + self, get_expiry_ts: bool = True + ) -> t.Optional[t.List[t.Tuple[str, ...]]]: + """Fetches all environment names along with expiry datetime if get_expiry_ts is True.""" + raise NotImplementedError( + "get_environment_names method is not implemented for the Airflow state sync." + ) + def max_interval_end_per_model( self, environment: str, diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 0561ca015..4e5945d2a 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -990,6 +990,19 @@ def test_environments(runner, tmp_path): ], ) + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "environments", + ], + ) + assert result.exit_code == 0 + assert result.output == "Number of SQLMesh environments are: 1\ndev\n" + # # create dev2 environment from dev environment # # Input: `y` to apply and virtual update runner.invoke( @@ -1035,3 +1048,37 @@ def test_environments(runner, tmp_path): assert result.exit_code == 0 ttl = time_like_to_str(to_datetime(now_ds()) + timedelta(days=7)) assert result.output == f"Number of SQLMesh environments are: 2\ndev - {ttl}\ndev2 - {ttl}\n" + + # Example project models have start dates, so there are no date prompts + # for the `prod` environment. + # Input: `y` to apply and backfill + runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan"], input="y\n") + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "environments", + ], + ) + assert result.exit_code == 0 + assert result.output == "Number of SQLMesh environments are: 3\ndev\ndev2\nprod\n" + + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "environments", + "--show-expiry", + ], + ) + assert result.exit_code == 0 + assert ( + result.output + == f"Number of SQLMesh environments are: 3\ndev - {ttl}\ndev2 - {ttl}\nprod - No Expiry\n" + )