Skip to content

Commit

Permalink
fix tests and format
Browse files Browse the repository at this point in the history
  • Loading branch information
lafirm committed Feb 9, 2025
1 parent 4576b7d commit 91d134b
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 18 deletions.
11 changes: 4 additions & 7 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sqlmesh.core.console import configure_console, get_console
from sqlmesh.core.config import load_configs
from sqlmesh.core.context import Context
from sqlmesh.utils.date import TimeLike, time_like_to_str
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import MissingDependencyError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -976,12 +976,9 @@ def dlt_refresh(
help="Prints the expiry datetime of the environments.",
default=False,
)
@click.pass_context
@click.pass_obj
@error_handler
@cli_analytics
def environments(ctx: click.Context, show_expiry: bool) -> None:
def environments(obj: Context, show_expiry: bool) -> None:
"""Prints the list of SQLMesh environments with its expiry datetime."""
context = ctx.obj
environment_names = context.state_sync.get_environment_names(get_expiry_ts=show_expiry)
output = [f"{name} - {time_like_to_str(ts)}" for name, ts in environment_names] if show_expiry else [name[0] for name in environment_names]
context.console.log_status_update(f"Number of SQLMesh environments are: {len(output)}\n{"\n".join(output)}")
obj.print_environment_names(show_expiry=show_expiry)
20 changes: 19 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
from sqlmesh.core.user import User
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike, now_ds, to_timestamp, format_tz_datetime
from sqlmesh.utils.date import TimeLike, now_ds, to_timestamp, format_tz_datetime, time_like_to_str
from sqlmesh.utils.errors import (
CircuitBreakerError,
ConfigError,
Expand Down Expand Up @@ -1972,6 +1972,24 @@ def print_info(self, skip_connection: bool = False, verbose: bool = False) -> No
if state_connection:
self._try_connection("state backend", state_connection.connection_validator())

@python_api_analytics
def print_environment_names(self, show_expiry: bool) -> None:
"""Prints all environment names along with expiry datetime if show_expiry is True."""
environment_names = self._new_state_sync().get_environment_names(get_expiry_ts=show_expiry)
if not environment_names:
error_msg = "Environments were not found."
raise SQLMeshError(error_msg)
output = (
[
f"{name} - {time_like_to_str(ts)}" if ts else f"{name} - No Expiry"
for name, ts in environment_names
]
if show_expiry
else [name[0] for name in environment_names]
)
output_str = "\n".join([str(len(output)), *output])
self.console.log_status_update(f"Number of SQLMesh environments are: {output_str}")

def close(self) -> None:
"""Releases all resources allocated by this context."""
if self._snapshot_evaluator:
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def get_environments(self) -> t.List[Environment]:
"""

@abc.abstractmethod
def get_environment_names(self, get_expiry_ts: bool = True) -> t.List[t.Tuple[str]] | t.List[t.Tuple[str, int]]:
def get_environment_names(
self, get_expiry_ts: bool = True
) -> t.Optional[t.List[t.Tuple[str, ...]]]:
"""Fetches all environment names along with expiry datetime if get_expiry_ts is True.
Returns:
Expand Down
14 changes: 10 additions & 4 deletions sqlmesh/core/state_sync/engine_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,14 +733,20 @@ def get_environments(self) -> t.List[Environment]:
self._environment_from_row(row) for row in self._fetchall(self._environments_query())
]

def get_environment_names(self, get_expiry_ts: bool = True) -> t.List[t.Tuple[str]] | t.List[t.Tuple[str, int]]:
def get_environment_names(
self, get_expiry_ts: bool = True
) -> t.Optional[t.List[t.Tuple[str, ...]]]:
"""Fetches all environment names along with expiry datetime if get_expiry_ts is True.
Returns:
A list of all environment names along with expiry datetime if get_expiry_ts is True.
"""
name_field = ["name"]
return self._fetchall(self._environments_query(required_fields=name_field if not get_expiry_ts else name_field + ["expiration_ts"]))
return self._fetchall(
self._environments_query(
required_fields=name_field if not get_expiry_ts else name_field + ["expiration_ts"]
),
)

def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment:
return Environment(**{field: row[i] for i, field in enumerate(Environment.all_fields())})
Expand All @@ -751,9 +757,9 @@ def _environments_query(
lock_for_update: bool = False,
required_fields: t.Optional[t.List[str]] = None,
) -> exp.Select:
required_fields = required_fields if required_fields else Environment.all_fields()
query_fields = required_fields if required_fields else Environment.all_fields()
query = (
exp.select(*(exp.to_identifier(field) for field in required_fields))
exp.select(*(exp.to_identifier(field) for field in query_fields))
.from_(self.environments_table)
.where(where)
)
Expand Down
8 changes: 3 additions & 5 deletions sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from sqlmesh.core.dialect import format_model_expressions, parse
from sqlmesh.core.model import load_sql_based_model
from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests
from sqlmesh.utils import date, sqlglot_dialects, yaml
from sqlmesh.utils import sqlglot_dialects, yaml
from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def clean(self, context: Context, line: str) -> None:

@magic_arguments()
@argument(
"--expiry-ds",
"--show-expiry",
"-e",
action="store_true",
help="Prints the expiration datetime of the environments.",
Expand All @@ -1021,9 +1021,7 @@ def clean(self, context: Context, line: str) -> None:
def environments(self, context: Context, line: str) -> None:
"""Prints the list of SQLMesh environments with its expiry datetime."""
args = parse_argstring(self.environments, line)
environment_names = context.state_sync.get_environment_names(get_expiry_ts=args.show_expiry)
output = [f"{name} - {date.time_like_to_str(ts)}" for name, ts in environment_names] if args.show_expiry else [name[0] for name in environment_names]
context.console.log_status_update(f"Number of SQLMesh environments are: {len(output)}\n{"\n".join(output)}")
context.print_environment_names(show_expiry=args.show_expiry)


def register_magics() -> None:
Expand Down
8 changes: 8 additions & 0 deletions sqlmesh/schedulers/airflow/state_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def get_environments(self) -> t.List[Environment]:
"""
return self._client.get_environments()

def get_environment_names(
self, get_expiry_ts: bool = True
) -> t.Optional[t.List[t.Tuple[str, ...]]]:
"""Fetches all environment names along with expiry datetime if get_expiry_ts is True."""
raise NotImplementedError(
"get_environment_names method is not implemented for the Airflow state sync."
)

def max_interval_end_per_model(
self,
environment: str,
Expand Down
47 changes: 47 additions & 0 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,19 @@ def test_environments(runner, tmp_path):
],
)

result = runner.invoke(
cli,
[
"--log-file-dir",
tmp_path,
"--paths",
tmp_path,
"environments",
],
)
assert result.exit_code == 0
assert result.output == "Number of SQLMesh environments are: 1\ndev\n"

# # create dev2 environment from dev environment
# # Input: `y` to apply and virtual update
runner.invoke(
Expand Down Expand Up @@ -1035,3 +1048,37 @@ def test_environments(runner, tmp_path):
assert result.exit_code == 0
ttl = time_like_to_str(to_datetime(now_ds()) + timedelta(days=7))
assert result.output == f"Number of SQLMesh environments are: 2\ndev - {ttl}\ndev2 - {ttl}\n"

# Example project models have start dates, so there are no date prompts
# for the `prod` environment.
# Input: `y` to apply and backfill
runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan"], input="y\n")
result = runner.invoke(
cli,
[
"--log-file-dir",
tmp_path,
"--paths",
tmp_path,
"environments",
],
)
assert result.exit_code == 0
assert result.output == "Number of SQLMesh environments are: 3\ndev\ndev2\nprod\n"

result = runner.invoke(
cli,
[
"--log-file-dir",
tmp_path,
"--paths",
tmp_path,
"environments",
"--show-expiry",
],
)
assert result.exit_code == 0
assert (
result.output
== f"Number of SQLMesh environments are: 3\ndev - {ttl}\ndev2 - {ttl}\nprod - No Expiry\n"
)

0 comments on commit 91d134b

Please sign in to comment.