Skip to content

Commit

Permalink
NA+SPCS PoC
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-fcampbell committed Nov 1, 2024
1 parent f4c69a8 commit 2f3309e
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from typing import List, Literal, Optional, Union

import typer
import yaml
from click import BadOptionUsage, ClickException
from pydantic import Field, field_validator
from snowflake.cli._plugins.nativeapp.artifacts import (
BundleMap,
build_bundle,
find_manifest_file,
find_setup_script_file,
find_version_info_in_manifest_file,
)
from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext
Expand Down Expand Up @@ -40,6 +43,10 @@
PolicyBase,
)
from snowflake.cli._plugins.nativeapp.utils import needs_confirmation
from snowflake.cli._plugins.spcs.entities.service import (
ServiceEntity,
ServiceEntityModel,
)
from snowflake.cli._plugins.stage.diff import DiffResult
from snowflake.cli._plugins.stage.manager import StageManager
from snowflake.cli._plugins.workspace.context import ActionContext
Expand Down Expand Up @@ -109,6 +116,12 @@ class ApplicationPackageEntityModel(EntityModelBase):
title="Path to manifest.yml",
)

### SPCS PoC
services: list[str] = Field(
title="List of Snowpark Container Service entity IDs to integrate into this application package",
default=[],
)

@field_validator("identifier")
@classmethod
def append_test_resource_suffix_to_identifier(
Expand Down Expand Up @@ -191,7 +204,8 @@ def post_deploy_hooks(self) -> list[PostDeployHook] | None:
return model.meta and model.meta.post_deploy

def action_bundle(self, action_ctx: ActionContext, *args, **kwargs):
return self._bundle()
spcs_services = [action_ctx.get_entity(s) for s in self._entity_model.services]
return self._bundle(spcs_services=spcs_services)

def action_deploy(
self,
Expand All @@ -206,6 +220,7 @@ def action_deploy(
*args,
**kwargs,
):
spcs_services = [action_ctx.get_entity(s) for s in self._entity_model.services]
return self._deploy(
bundle_map=None,
prune=prune,
Expand All @@ -216,6 +231,7 @@ def action_deploy(
stage_fqn=stage_fqn or self.stage_fqn,
interactive=interactive,
force=force,
spcs_services=spcs_services,
)

def action_drop(self, action_ctx: ActionContext, force_drop: bool, *args, **kwargs):
Expand Down Expand Up @@ -357,6 +373,8 @@ def action_version_create(
else:
git_policy = AllowAlwaysPolicy()

spcs_services = [action_ctx.get_entity(s) for s in self._entity_model.services]

# Make sure version is not None before proceeding any further.
# This will raise an exception if version information is not found. Patch can be None.
bundle_map = None
Expand All @@ -369,7 +387,7 @@ def action_version_create(
"""
)
)
bundle_map = self._bundle()
bundle_map = self._bundle(spcs_services=spcs_services)
version, patch = find_version_info_in_manifest_file(self.deploy_root)
if not version:
raise ClickException(
Expand Down Expand Up @@ -403,6 +421,7 @@ def action_version_create(
stage_fqn=self.stage_fqn,
interactive=interactive,
force=force,
spcs_services=spcs_services,
)

# Warn if the version exists in a release directive(s)
Expand Down Expand Up @@ -489,7 +508,7 @@ def action_version_drop(
"""
)
)
self._bundle()
self._bundle(spcs_services=[])
version, _ = find_version_info_in_manifest_file(self.deploy_root)
if not version:
raise ClickException(
Expand Down Expand Up @@ -533,7 +552,7 @@ def action_version_drop(
f"Version {version} in application package {self.name} dropped successfully."
)

def _bundle(self):
def _bundle(self, spcs_services: list[ServiceEntity]):
model = self._entity_model
bundle_map = build_bundle(self.project_root, self.deploy_root, model.artifacts)
bundle_context = BundleContext(
Expand All @@ -546,8 +565,76 @@ def _bundle(self):
)
compiler = NativeAppCompiler(bundle_context)
compiler.compile_artifacts()

# TODO should this merged into NativeAppCompiler?
self._inject_spcs(spcs_services)

# TODO should we create a post-deploy script that automatically
# grants CREATE COMPUTE POOL and BIND SERVICE ENDPOINT to the app?

return bundle_map

def _inject_spcs(self, spcs_services: list[ServiceEntity]):
manifest_path = find_manifest_file(self.deploy_root)
manifest = yaml.safe_load(manifest_path.read_text())
if "configuration" not in manifest:
manifest["configuration"] = {}
existing_grant_callback = manifest["configuration"].get("grant_callback")
wrapper_grant_callback = "_spcs_generation.grant_callback"
manifest["configuration"]["grant_callback"] = wrapper_grant_callback
# TODO set default_web_endpoint in manifest?
if manifest_path.is_symlink():
manifest_path.unlink()
manifest_path.write_text(yaml.safe_dump(manifest, sort_keys=False))

generated_setup_script = self._spcs_grant_callback(
name=wrapper_grant_callback,
service=spcs_services[0]._entity_model, # noqa SLF001
existing_grant_callback=existing_grant_callback,
)

setup_script_path = find_setup_script_file(self.deploy_root)
setup_script = setup_script_path.read_text()
if setup_script_path.is_symlink():
setup_script_path.unlink()
setup_script_path.write_text(setup_script + generated_setup_script)

@staticmethod
def _spcs_grant_callback(
name: str, service: ServiceEntityModel, existing_grant_callback: str
):
return dedent(
f"""\
-- Begin generated SPCS services, this section is managed by the Snowflake CLI
create schema if not exists _spcs_generation;
create or replace procedure {name}(privileges array)
returns string
as $$
begin
{f'call {existing_grant_callback}(privileges);' if existing_grant_callback else ''}
if (array_contains('CREATE COMPUTE POOL'::variant, privileges)) then
create compute pool if not exists {service.compute_pool}
min_nodes = {service.min_nodes}
max_nodes = {service.max_nodes}
instance_family = {service.instance_family};
end if;
if (array_contains('BIND SERVICE ENDPOINT'::variant, privileges)) then
create service if not exists _spcs_generation.{service.fqn.name}
in compute pool {service.compute_pool}
from specification_file = '{service.specification_file}';
end if;
return 'done';
end;
$$;
create application role if not exists _spcs_generation_role;
grant usage on procedure {name}(array) to application role _spcs_generation_role;
-- End generated SPCS services
"""
)

def _deploy(
self,
bundle_map: BundleMap | None,
Expand All @@ -559,6 +646,7 @@ def _deploy(
stage_fqn: str,
interactive: bool,
force: bool,
spcs_services: list[ServiceEntity],
run_post_deploy_hooks: bool = True,
) -> DiffResult:
model = self._entity_model
Expand All @@ -574,7 +662,7 @@ def _deploy(
stage_fqn = stage_fqn or self.stage_fqn

# 1. Create a bundle if one wasn't passed in
bundle_map = bundle_map or self._bundle()
bundle_map = bundle_map or self._bundle(spcs_services=spcs_services)

# 2. Create an empty application package, if none exists
try:
Expand Down Expand Up @@ -932,6 +1020,7 @@ def get_validation_result(
stage_fqn=self.scratch_stage_fqn,
interactive=interactive,
force=force,
spcs_services=[], # TODO this affects the setup script, but it's under our control
run_post_deploy_hooks=False,
)
prefixed_stage_fqn = StageManager.get_standard_stage_prefix(stage_fqn)
Expand Down
16 changes: 12 additions & 4 deletions src/snowflake/cli/_plugins/nativeapp/v2_conversions/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,21 @@ def wrapper(*args, **kwargs):
app_definition, app_package_definition = _find_app_and_package_entities(
original_pdf, package_entity_id, app_entity_id, app_required
)
entities_to_keep = {app_package_definition.entity_id}
native_app_entities_to_keep = {app_package_definition.entity_id}
kwargs["package_entity_id"] = app_package_definition.entity_id
if app_definition:
entities_to_keep.add(app_definition.entity_id)
native_app_entities_to_keep.add(app_definition.entity_id)
kwargs["app_entity_id"] = app_definition.entity_id
for entity_id in list(original_pdf.entities):
if entity_id not in entities_to_keep:

native_app_entity_classes = (
ApplicationEntityModel,
ApplicationPackageEntityModel,
)
for entity_id, entity_model in list(original_pdf.entities.items()):
if (
isinstance(entity_model, native_app_entity_classes)
and entity_id not in native_app_entities_to_keep
):
# This happens after templates are rendered,
# so we can safely remove the entity
del original_pdf.entities[entity_id]
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions src/snowflake/cli/_plugins/spcs/entities/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Literal

from pydantic import Field
from snowflake.cli.api.entities.common import EntityBase
from snowflake.cli.api.project.schemas.entities.common import EntityModelBase
from snowflake.cli.api.project.schemas.updatable_model import DiscriminatorField


class ServiceEntityModel(EntityModelBase):
type: Literal["snowpark container service"] = DiscriminatorField() # noqa: A003
specification_file: str = Field(
title="Path to the specification file for the SPCS service, relative to the deploy root",
)
# TODO is a compute pool a separate entity?
compute_pool: str = Field(
title="Name of the compute pool to use for the SPCS service",
)
min_nodes: int = Field(
title="Minimum number of nodes in the compute pool",
default=1,
)
max_nodes: int = Field(
title="Maximum number of nodes in the compute pool",
default=1,
)
instance_family: str = Field(
title="Instance family to use for the compute pool",
default="CPU_X64_XS",
)


class ServiceEntity(EntityBase[ServiceEntityModel]):
# Local deploy of SPSC service not yet implemented
# We only use the model to deploy SPCS services in native apps
pass
6 changes: 6 additions & 0 deletions src/snowflake/cli/api/project/schemas/entities/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
FunctionEntityModel,
ProcedureEntityModel,
)
from snowflake.cli._plugins.spcs.entities.service import (
ServiceEntity,
ServiceEntityModel,
)
from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity
from snowflake.cli._plugins.streamlit.streamlit_entity_model import (
StreamlitEntityModel,
Expand All @@ -43,13 +47,15 @@
StreamlitEntity,
ProcedureEntity,
FunctionEntity,
ServiceEntity,
]
EntityModel = Union[
ApplicationEntityModel,
ApplicationPackageEntityModel,
StreamlitEntityModel,
FunctionEntityModel,
ProcedureEntityModel,
ServiceEntityModel,
]

ALL_ENTITIES: List[Entity] = [*get_args(Entity)]
Expand Down

0 comments on commit 2f3309e

Please sign in to comment.