-
Notifications
You must be signed in to change notification settings - Fork 333
[Feat] Enable registered entity summary output and quiet flag in pyflyte register #3028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 6 commits
f51be71
104783e
9bba92f
f1be205
db9f744
a3734e4
42510ba
8a82a78
05ab275
386eea0
4b2b3b9
cf16cb4
43704b8
4068ad0
dc490a5
5fcf464
aad07e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
import asyncio | ||
import functools | ||
import json | ||
import os | ||
import tarfile | ||
import tempfile | ||
import typing | ||
from contextlib import contextmanager | ||
from pathlib import Path | ||
|
||
import click | ||
import yaml | ||
from rich import print as rprint | ||
|
||
from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings | ||
|
@@ -22,6 +25,8 @@ | |
from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities | ||
from flytekit.tools.translator import FlyteControlPlaneEntity, Options | ||
|
||
original_secho = click.secho | ||
|
||
|
||
class NoSerializableEntitiesError(Exception): | ||
pass | ||
|
@@ -237,6 +242,20 @@ def print_registration_status( | |
rprint(f"[{color}]{state_ind} {name}: {i.name} (Failed)") | ||
|
||
|
||
@contextmanager | ||
def temporary_secho(): | ||
""" | ||
Temporarily restores the original click.secho function. | ||
Useful when you need to temporarily disable quiet mode. | ||
""" | ||
current_secho = click.secho | ||
try: | ||
click.secho = original_secho | ||
yield | ||
finally: | ||
click.secho = current_secho | ||
|
||
|
||
def register( | ||
project: str, | ||
domain: str, | ||
|
@@ -251,6 +270,8 @@ def register( | |
remote: FlyteRemote, | ||
copy_style: CopyFileDetection, | ||
env: typing.Optional[typing.Dict[str, str]], | ||
summary_format: typing.Optional[str], | ||
quiet: bool = False, | ||
dry_run: bool = False, | ||
activate_launchplans: bool = False, | ||
skip_errors: bool = False, | ||
|
@@ -261,6 +282,10 @@ def register( | |
Temporarily, for fast register, specify both the fast arg as well as copy_style. | ||
fast == True with copy_style == None means use the old fast register tar'ring method. | ||
""" | ||
|
||
if quiet: | ||
click.secho = lambda *args, **kw: None | ||
|
||
detected_root = find_common_root(package_or_module) | ||
click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") | ||
|
||
|
@@ -316,6 +341,7 @@ def register( | |
registrable_entities = serialize_get_control_plane_entities( | ||
serialization_settings, str(detected_root), options, is_registration=True | ||
) | ||
|
||
click.secho( | ||
f"Serializing and registering {len(registrable_entities)} flyte entities", | ||
fg="green", | ||
|
@@ -333,6 +359,14 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity): | |
is_lp = True | ||
else: | ||
og_id = cp_entity.template.id | ||
|
||
result = { | ||
"id": og_id.name, | ||
"type": og_id.resource_type_name(), | ||
"version": og_id.version, | ||
"status": "skipped", # default status | ||
} | ||
|
||
try: | ||
if not dry_run: | ||
try: | ||
|
@@ -347,33 +381,56 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity): | |
print_activation_message = True | ||
if cp_entity.should_auto_activate: | ||
print_activation_message = True | ||
print_registration_status( | ||
i, console_url=console_url, verbosity=verbosity, activation=print_activation_message | ||
) | ||
if not quiet: | ||
print_registration_status( | ||
i, console_url=console_url, verbosity=verbosity, activation=print_activation_message | ||
) | ||
result["status"] = "success" | ||
|
||
except Exception as e: | ||
if not skip_errors: | ||
raise e | ||
print_registration_status(og_id, success=False) | ||
if not quiet: | ||
print_registration_status(og_id, success=False) | ||
result["status"] = "failed" | ||
|
||
else: | ||
print_registration_status(og_id, dry_run=True) | ||
if not quiet: | ||
print_registration_status(og_id, dry_run=True) | ||
except RegistrationSkipped: | ||
print_registration_status(og_id, success=False) | ||
if not quiet: | ||
print_registration_status(og_id, success=False) | ||
result["status"] = "skipped" | ||
|
||
return result | ||
|
||
async def _register(entities: typing.List[task.TaskSpec]): | ||
loop = asyncio.get_running_loop() | ||
tasks = [] | ||
for entity in entities: | ||
tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) | ||
await asyncio.gather(*tasks) | ||
return | ||
results = await asyncio.gather(*tasks) | ||
return results | ||
|
||
# concurrent register | ||
cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities)) | ||
asyncio.run(_register(cp_task_entities)) | ||
task_results = asyncio.run(_register(cp_task_entities)) | ||
# serial register | ||
cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities)) | ||
other_results = [] | ||
for entity in cp_other_entities: | ||
_raw_register(entity) | ||
other_results.append(_raw_register(entity)) | ||
|
||
all_results = task_results + other_results | ||
|
||
click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") | ||
|
||
if summary_format is not None: | ||
supported_format = {"json", "yaml"} | ||
if summary_format not in supported_format: | ||
raise ValueError(f"Unsupported file format: {summary_format}") | ||
|
||
with temporary_secho(): | ||
if summary_format == "json": | ||
click.secho(json.dumps(all_results, indent=2)) | ||
elif summary_format == "yaml": | ||
click.secho(yaml.dump(all_results)) | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,6 +1,8 @@ | ||||||||||||||||
import os | ||||||||||||||||
import shutil | ||||||||||||||||
import subprocess | ||||||||||||||||
import json | ||||||||||||||||
import yaml | ||||||||||||||||
|
||||||||||||||||
import mock | ||||||||||||||||
import pytest | ||||||||||||||||
|
@@ -163,3 +165,99 @@ def test_non_fast_register_require_version(mock_client, mock_remote): | |||||||||||||||
result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"]) | ||||||||||||||||
assert result.exit_code == 1 | ||||||||||||||||
shutil.rmtree("core3") | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) | ||||||||||||||||
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) | ||||||||||||||||
def test_register_registrated_summary_json(mock_client, mock_remote): | ||||||||||||||||
ctx = FlyteContextManager.current_context() | ||||||||||||||||
mock_remote._client = mock_client | ||||||||||||||||
mock_remote.return_value.context = ctx | ||||||||||||||||
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" | ||||||||||||||||
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" | ||||||||||||||||
runner = CliRunner() | ||||||||||||||||
context_manager.FlyteEntities.entities.clear() | ||||||||||||||||
|
||||||||||||||||
with runner.isolated_filesystem(): | ||||||||||||||||
out = subprocess.run(["git", "init"], capture_output=True) | ||||||||||||||||
assert out.returncode == 0 | ||||||||||||||||
os.makedirs("core5", exist_ok=True) | ||||||||||||||||
with open(os.path.join("core5", "sample.py"), "w") as f: | ||||||||||||||||
f.write(sample_file_contents) | ||||||||||||||||
f.close() | ||||||||||||||||
|
||||||||||||||||
result = runner.invoke( | ||||||||||||||||
pyflyte.main, | ||||||||||||||||
["register", "--summary-format", "json", "core5"] | ||||||||||||||||
) | ||||||||||||||||
assert result.exit_code == 0 | ||||||||||||||||
summary_data = json.loads(result.output) | ||||||||||||||||
|
summary_data = json.loads(result.output) | |
try: | |
summary_data = json.loads(result.output) | |
except json.JSONDecodeError as e: | |
pytest.fail(f"Failed to parse registration summary JSON: {e}") | |
except Exception as e: | |
pytest.fail(f"Unexpected error while parsing registration summary: {e}") |
Code Review Run #5721dc
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding error handling when parsing YAML from result.output
. The yaml.safe_load()
could raise yaml.YAMLError
if the output is not valid YAML.
Code suggestion
Check the AI-generated fix before applying
summary_data = yaml.safe_load(result.output) | |
try: | |
summary_data = yaml.safe_load(result.output) | |
except yaml.YAMLError as e: | |
pytest.fail(f"Failed to parse YAML output: {e}") | |
Code Review Run #5721dc
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding more detailed status information in the
result
dictionary. The current status field only captures high-level states ('skipped', 'success', 'failed'). Additional fields likeerror_message
andtimestamp
could provide more context for debugging and monitoring.Code suggestion
Code Review Run #9a3edb
Is this a valid issue, or was it incorrectly flagged by the Agent?