Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AsyncTyper class to the Python SDK #2453

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions docs/docs/infrahubctl/infrahubctl-schema.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ $ infrahubctl schema [OPTIONS] COMMAND [ARGS]...
**Commands**:

* `load`: Load a schema file into Infrahub.
* `migrate`: Migrate the schema to the latest version.

## `infrahubctl schema load`

Expand All @@ -39,17 +38,3 @@ $ infrahubctl schema load [OPTIONS] SCHEMAS...
* `--branch TEXT`: Branch on which to load the schema. [default: main]
* `--config-file TEXT`: [env var: INFRAHUBCTL_CONFIG; default: infrahubctl.toml]
* `--help`: Show this message and exit.

## `infrahubctl schema migrate`

Migrate the schema to the latest version. (Not Implemented Yet)

**Usage**:

```console
$ infrahubctl schema migrate [OPTIONS]
```

**Options**:

* `--help`: Show this message and exit.
31 changes: 31 additions & 0 deletions python_sdk/infrahub_sdk/async_typer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import asyncio
import inspect
from functools import partial, wraps
from typing import Any, Callable

from typer import Typer


class AsyncTyper(Typer):
@staticmethod
def maybe_run_async(decorator: Callable, func: Callable) -> Any:
if inspect.iscoroutinefunction(func):

@wraps(func)
def runner(*args: Any, **kwargs: Any) -> Any:
return asyncio.run(func(*args, **kwargs))

decorator(runner)
else:
decorator(func)
return func

def callback(self, *args: Any, **kwargs: Any) -> Any:
decorator = super().callback(*args, **kwargs)
return partial(self.maybe_run_async, decorator)

def command(self, *args: Any, **kwargs: Any) -> Any:
decorator = super().command(*args, **kwargs)
return partial(self.maybe_run_async, decorator)
128 changes: 44 additions & 84 deletions python_sdk/infrahub_sdk/ctl/branch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
import sys
from asyncio import run as aiorun
from datetime import datetime
from pathlib import Path
from typing import Dict, Generator, List, Optional, Union
from typing import Dict, Generator, List, Optional

import typer
from rich.console import Console
Expand All @@ -12,6 +11,7 @@
from rich.table import Table

from infrahub_sdk import Error, GraphQLError
from infrahub_sdk.async_typer import AsyncTyper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason why we don't want to have this under infrahub_sdk.ctl instead? Is the intention that the server components would also load the same class from this location?. To me it feels cleaner to have it under .ctl.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal is to have it available for other cli tools not just infrahubctl

from infrahub_sdk.ctl import config
from infrahub_sdk.ctl.client import initialize_client
from infrahub_sdk.ctl.utils import (
Expand All @@ -20,7 +20,7 @@
render_action_rich,
)

app = typer.Typer()
app = AsyncTyper()


DEFAULT_CONFIG_FILE = "infrahubctl.toml"
Expand All @@ -36,7 +36,17 @@ def callback() -> None:
"""


async def _list() -> None:
@app.command("list")
async def list_branch(
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""List all existing branches."""

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

client = await initialize_client()

console = Console()
Expand Down Expand Up @@ -86,21 +96,20 @@ async def _list() -> None:
console.print(table)


@app.command("list")
def list_branch(
@app.command()
async def create(
branch_name: str,
description: str = typer.Argument(""),
data_only: bool = True,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""List all existing branches."""
"""Create a new branch."""

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

aiorun(_list())


async def _create(branch_name: str, description: str, data_only: bool) -> None:
console = Console()

client = await initialize_client()
Expand All @@ -118,23 +127,17 @@ async def _create(branch_name: str, description: str, data_only: bool) -> None:


@app.command()
def create(
async def delete(
branch_name: str,
description: str = typer.Argument(""),
data_only: bool = True,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""Create a new branch."""
"""Delete a branch."""

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

aiorun(_create(branch_name=branch_name, description=description, data_only=data_only))


async def _delete(branch_name: str) -> None:
console = Console()

client = await initialize_client()
Expand All @@ -152,21 +155,17 @@ async def _delete(branch_name: str) -> None:


@app.command()
def delete(
async def rebase(
branch_name: str,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""Delete a branch."""
"""Rebase a Branch with main."""

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

aiorun(_delete(branch_name=branch_name))


async def _rebase(branch_name: str) -> None:
console = Console()

client = await initialize_client()
Expand All @@ -184,21 +183,17 @@ async def _rebase(branch_name: str) -> None:


@app.command()
def rebase(
async def merge(
branch_name: str,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""Rebase a Branch with main."""
"""Merge a Branch with main."""

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

aiorun(_rebase(branch_name=branch_name))


async def _merge(branch_name: str) -> None:
console = Console()

client = await initialize_client()
Expand All @@ -216,21 +211,15 @@ async def _merge(branch_name: str) -> None:


@app.command()
def merge(
async def validate(
branch_name: str,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""Merge a Branch with main."""

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
"""Validate if a branch has some conflict and is passing all the tests (NOT IMPLEMENTED YET)."""

if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

aiorun(_merge(branch_name=branch_name))


async def _validate(branch_name: str) -> None:
console = Console()

client = await initialize_client()
Expand All @@ -247,19 +236,6 @@ async def _validate(branch_name: str) -> None:
console.print(f"Branch '{branch_name}' is valid.")


@app.command()
def validate(
branch_name: str,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""Validate if a branch has some conflict and is passing all the tests (NOT IMPLEMENTED YET)."""

if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

aiorun(_validate(branch_name=branch_name))


@rich_group()
def node_panel_generator(nodes: List[Dict]) -> Generator:
for node in nodes:
Expand Down Expand Up @@ -320,12 +296,24 @@ def node_panel_generator(nodes: List[Dict]) -> Generator:
)


async def _diff(
@app.command()
async def diff(
branch_name: str,
time_from: Union[str, datetime],
time_to: Union[str, datetime],
branch_only: bool,
time_from: Optional[datetime] = typer.Option(
None,
"--from",
help="Start Time used to calculate the Diff, Default: from the start of the branch",
),
time_to: Optional[datetime] = typer.Option(None, "--to", help="End Time used to calculate the Diff, Default: now"),
branch_only: bool = True,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""Show the differences between a Branch and main."""
if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

console = Console()

client = await initialize_client()
Expand All @@ -349,31 +337,3 @@ async def _diff(
title_align="left",
)
)


@app.command()
def diff(
branch_name: str,
time_from: Optional[datetime] = typer.Option(
None,
"--from",
help="Start Time used to calculate the Diff, Default: from the start of the branch",
),
time_to: Optional[datetime] = typer.Option(None, "--to", help="End Time used to calculate the Diff, Default: now"),
branch_only: bool = True,
config_file: Path = typer.Option(DEFAULT_CONFIG_FILE, envvar=ENVVAR_CONFIG_FILE),
) -> None:
"""Show the differences between a Branch and main."""
if not config.SETTINGS:
config.load_and_exit(config_file=config_file)

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

aiorun(
_diff(
branch_name=branch_name,
time_from=time_from or "",
time_to=time_to or "",
branch_only=branch_only,
)
)
Loading
Loading