diff --git a/pylock.toml b/pylock.toml index 8c6438fa2d..8efe8fbd0e 100644 --- a/pylock.toml +++ b/pylock.toml @@ -877,6 +877,12 @@ version = "0.2.13" sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", upload-time = 2024-01-06T02:10:57Z, size = 101301, hashes = { sha256 = "72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5" } } wheels = [{ url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", upload-time = 2024-01-06T02:10:55Z, size = 34166, hashes = { sha256 = "3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859" } }] +[[packages]] +name = "websocket-client" +version = "1.9.0" +sdist = { url = "https://files.pythonhosted.org/packages/2c/41/aa4bf9664e4cda14c3b39865b12251e8e7d239f4cd0e3cc1b6c2ccde25c1/websocket_client-1.9.0.tar.gz", upload-time = 2025-10-07T21:16:36Z, size = 70576, hashes = { sha256 = "9e813624b6eb619999a97dc7958469217c3176312b3a16a4bd1bc7e08a46ec98" } } +wheels = [{ url = "https://files.pythonhosted.org/packages/34/db/b10e48aa8fff7407e67470363eac595018441cf32d5e1001567a7aeba5d2/websocket_client-1.9.0-py3-none-any.whl", upload-time = 2025-10-07T21:16:34Z, size = 82616, hashes = { sha256 = "af248a825037ef591efbf6ed20cc5faa03d3b47b9e5a2230a529eeee1c1fc3ef" } }] + [[packages]] name = "wheel" version = "0.45.1" diff --git a/pyproject.toml b/pyproject.toml index 1f4ee48882..4b12cf6d5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "pip==25.3", "pluggy==1.6.0", "prompt-toolkit==3.0.51", + "protobuf>=3.20,<7", "pydantic==2.12.5", "requests==2.32.4", "requirements-parser==0.13.0", @@ -46,6 +47,7 @@ dependencies = [ "tomlkit==0.13.3", "typer==0.17.3", "urllib3>=2.6.3,<3", + "websocket-client>=1.6.0,<2", ] classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/snyk/requirements.txt b/snyk/requirements.txt index 144536850a..ccd92971c6 100644 --- a/snyk/requirements.txt +++ b/snyk/requirements.txt @@ -66,5 +66,6 @@ tzdata==2025.2 ; sys_platform == 'win32' tzlocal==5.3.1 urllib3==2.6.3 wcwidth==0.2.13 +websocket-client==1.9.0 wheel==0.45.1 zipp==3.23.0 ; python_full_version < '3.12' diff --git a/src/snowflake/cli/_plugins/streamlit/commands.py b/src/snowflake/cli/_plugins/streamlit/commands.py index 1c8ef10ccb..731a8d2749 100644 --- a/src/snowflake/cli/_plugins/streamlit/commands.py +++ b/src/snowflake/cli/_plugins/streamlit/commands.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 Snowflake Inc. +# Copyright (c) 2026 Snowflake Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,10 @@ add_object_command_aliases, scope_option, ) +from snowflake.cli._plugins.streamlit.log_streaming import ( + stream_logs, + validate_spcs_v2_runtime, +) from snowflake.cli._plugins.streamlit.manager import StreamlitManager from snowflake.cli._plugins.streamlit.streamlit_entity import StreamlitEntity from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext @@ -33,6 +37,7 @@ with_project_definition, ) from snowflake.cli.api.commands.flags import ( + IdentifierType, PruneOption, ReplaceOption, entity_argument, @@ -215,6 +220,84 @@ def get_url( return MessageResult(url) +@app.command("logs", requires_connection=True) +@with_project_definition(is_optional=True) +def streamlit_logs( + entity_id: str = entity_argument("streamlit"), + name: FQN = typer.Option( + None, + "--name", + help="Fully qualified name of the Streamlit app (e.g. my_app, schema.my_app, or db.schema.my_app). " + "Overrides the project definition when provided.", + click_type=IdentifierType(), + ), + tail: int = typer.Option( + 100, + "--tail", + "-n", + min=0, + max=1000, # server-side buffer size limit (see logs_service.proto) + help="Number of historical log lines to fetch. Use 0 for live logs only.", + ), + **options, +) -> CommandResult: + """ + Streams live logs from a deployed Streamlit app to your terminal. + + Reads the Streamlit app name from the project definition file (snowflake.yml) + or from the --name option. Connects to the app's developer log service via + WebSocket and prints log entries in real time. Press Ctrl+C to stop streaming. + + Log streaming requires SPCSv2 runtime. + """ + cli_context = get_cli_context() + conn = cli_context.connection + + if name is not None: + if entity_id is not None: + raise ClickException( + "Cannot specify both --name and an entity ID. " + "Use --name to identify the app directly, or use an " + "entity ID to reference a snowflake.yml definition." + ) + # --name flag provided: resolve FQN and validate via server-side DESCRIBE + fqn = name.using_connection(conn) + validate_spcs_v2_runtime(conn, str(fqn)) + else: + # No --name: require project definition + pd = cli_context.project_definition + if pd is None: + raise ClickException( + "No Streamlit app specified. Provide --name or run from a " + "directory with a snowflake.yml project definition." + ) + if not pd.meets_version_requirement("2"): + if not pd.streamlit: + raise NoProjectDefinitionError( + project_type="streamlit", project_root=cli_context.project_root + ) + pd = convert_project_definition_to_v2(cli_context.project_root, pd) + + entity_model = get_entity_for_operation( + cli_context=cli_context, + entity_id=entity_id, + project_definition=pd, + entity_type=ObjectType.STREAMLIT.value.cli_name, + ) + + fqn = entity_model.fqn.using_connection(conn) + # Validate SPCSv2 runtime via server-side DESCRIBE (same path as --name) + validate_spcs_v2_runtime(conn, str(fqn)) + + stream_logs( + conn=conn, + fqn=str(fqn), + tail_lines=tail, + json_output=cli_context.output_format.is_json, + ) + return MessageResult("Log streaming ended.") + + def _get_current_workspace_context(): ctx = get_cli_context() diff --git a/src/snowflake/cli/_plugins/streamlit/log_streaming.py b/src/snowflake/cli/_plugins/streamlit/log_streaming.py new file mode 100644 index 0000000000..c9b290c294 --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/log_streaming.py @@ -0,0 +1,228 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +WebSocket log streaming client for Streamlit developer logs. + +Connects to the Streamlit container runtime's developer log service +via WebSocket and streams log entries in real time. +""" + +from __future__ import annotations + +import json +import logging +import sys +from dataclasses import dataclass + +import websocket +from click import ClickException +from google.protobuf.message import DecodeError +from snowflake.cli._plugins.streamlit.proto_codec import ( + decode_log_entry, + encode_stream_logs_request, +) +from snowflake.cli._plugins.streamlit.streamlit_entity_model import ( + SPCS_RUNTIME_V2_NAME, +) +from snowflake.cli.api.console import cli_console +from snowflake.connector import SnowflakeConnection + +log = logging.getLogger(__name__) + +# Timeout for each ws.recv_data() call — mirrors the Go client's 90-second +# read deadline. When no log entry arrives within this window, we re-issue +# recv_data() so the loop stays responsive to KeyboardInterrupt. +_WS_RECV_TIMEOUT_SECONDS = 90 + +_HANDSHAKE_TIMEOUT_SECONDS = 10 + + +@dataclass +class DeveloperApiToken: + token: str + resource_uri: str + + +def get_developer_api_token(conn: SnowflakeConnection, fqn: str) -> DeveloperApiToken: + """ + Calls SYSTEM$GET_STREAMLIT_DEVELOPER_API_TOKEN and returns a + DeveloperApiToken with the token and resource URI. + """ + if "'" in fqn: + raise ClickException( + f"Invalid Streamlit app name: {fqn}. Name must not contain single quotes." + ) + + query = f"CALL SYSTEM$GET_STREAMLIT_DEVELOPER_API_TOKEN('{fqn}', false);" + log.debug("Fetching developer API token for %s", fqn) + + cursor = conn.cursor() + try: + cursor.execute(query) + row = cursor.fetchone() + if not row: + raise ClickException( + "Empty response from SYSTEM$GET_STREAMLIT_DEVELOPER_API_TOKEN" + ) + raw = row[0] + finally: + cursor.close() + + try: + resp = json.loads(raw) + except (json.JSONDecodeError, TypeError) as e: + raise ClickException(f"Failed to parse token response: {e}") from e + + token = resp.get("token", "") + resource_uri = resp.get("resourceUri", "") + + if not token: + raise ClickException("Empty token in developer API response") + if not resource_uri: + raise ClickException("Empty resourceUri in developer API response") + + log.debug("Resource URI: %s", resource_uri) + return DeveloperApiToken(token=token, resource_uri=resource_uri) + + +def build_ws_url(resource_uri: str) -> str: + """Convert resource URI to WebSocket URL and append /logs path.""" + ws_url = resource_uri.replace("https://", "wss://", 1).replace( + "http://", "ws://", 1 + ) + return ws_url.rstrip("/") + "/logs" + + +def validate_spcs_v2_runtime(conn: SnowflakeConnection, fqn: str) -> None: + """ + Run DESCRIBE STREAMLIT and verify the app uses SPCSv2 runtime. + + Raises ClickException if the app does not use the SPCS Runtime V2 + (required for log streaming). + """ + cursor = conn.cursor() + try: + # fqn is already validated by IdentifierType / FQN.using_connection — + # DESCRIBE uses identifier syntax, not string literals, so no + # single-quote injection risk. + cursor.execute(f"DESCRIBE STREAMLIT {fqn}") + row = cursor.fetchone() + description = cursor.description + finally: + cursor.close() + + if not row or not description: + raise ClickException( + f"Could not describe Streamlit app {fqn}. " + "Verify the app exists and you have access." + ) + + # Build column-name -> value mapping from cursor.description + columns = {desc[0].lower(): val for desc, val in zip(description, row)} + runtime_name = columns.get("runtime_name") + + if runtime_name != SPCS_RUNTIME_V2_NAME: + raise ClickException( + f"Log streaming is only supported for Streamlit apps running on " + f"SPCSv2 runtime ({SPCS_RUNTIME_V2_NAME}). " + f"App '{fqn}' has runtime_name='{runtime_name}'." + ) + + +def stream_logs( + conn: SnowflakeConnection, + fqn: str, + tail_lines: int = 100, + json_output: bool = False, +) -> None: + """ + Connect to the Streamlit developer log streaming WebSocket and print + log entries to stdout until interrupted. + + When *json_output* is True each log entry is emitted as a single-line + JSON object (JSONL), suitable for piping to ``jq`` or other tools. + """ + # 1. Get token + cli_console.step("Fetching developer API token...") + token_info = get_developer_api_token(conn, fqn) + + # 2. Build WebSocket URL + ws_url = build_ws_url(token_info.resource_uri) + cli_console.step(f"Connecting to log stream: {ws_url}") + + # 3. Connect + # NOTE: Do not log `header` — it contains the auth token. Also be aware + # that websocket.enableTrace(True) will dump headers to stderr. + header = [f'Authorization: Snowflake Token="{token_info.token}"'] + ws = websocket.WebSocket() + ws.timeout = _WS_RECV_TIMEOUT_SECONDS + streaming = False + + try: + try: + ws.connect(ws_url, header=header, timeout=_HANDSHAKE_TIMEOUT_SECONDS) + except Exception as e: + raise ClickException(f"Failed to connect to log stream: {e}") from e + + # 4. Send StreamLogsRequest + ws.send_binary(encode_stream_logs_request(tail_lines)) + log.debug("Sent StreamLogsRequest with tail_lines=%d", tail_lines) + + cli_console.step(f"Streaming logs (tail={tail_lines}). Press Ctrl+C to stop.") + sys.stdout.write("---\n") + sys.stdout.flush() + streaming = True + + # 5. Read loop + while True: + try: + opcode, data = ws.recv_data() + except websocket.WebSocketTimeoutException: + # No message within the timeout window — loop back so we + # stay responsive to KeyboardInterrupt. + continue + except websocket.WebSocketConnectionClosedException: + log.debug("WebSocket connection closed by server") + break + except (websocket.WebSocketException, OSError) as e: + log.debug("WebSocket recv error: %s", e) + break + + if opcode == websocket.ABNF.OPCODE_BINARY: + try: + entry = decode_log_entry(data) + except (DecodeError, ValueError) as e: + log.warning("Failed to decode log entry: %s", e) + continue + if json_output: + sys.stdout.write(json.dumps(entry.to_dict()) + "\n") + else: + sys.stdout.write(entry.format_line() + "\n") + sys.stdout.flush() + elif opcode == websocket.ABNF.OPCODE_CLOSE: + break + elif opcode == websocket.ABNF.OPCODE_PING: + ws.pong(data) + + except KeyboardInterrupt: + pass + finally: + try: + ws.close(status=websocket.STATUS_NORMAL) + except Exception as e: + log.debug("Error closing WebSocket: %s", e) + if streaming: + sys.stdout.write("\n--- Log streaming stopped.\n") + sys.stdout.flush() diff --git a/src/snowflake/cli/_plugins/streamlit/proto/__init__.py b/src/snowflake/cli/_plugins/streamlit/proto/__init__.py new file mode 100644 index 0000000000..74bcb8a780 --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/proto/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/snowflake/cli/_plugins/streamlit/proto/developer/v1/logs_service.proto b/src/snowflake/cli/_plugins/streamlit/proto/developer/v1/logs_service.proto new file mode 100644 index 0000000000..de25490e5f --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/proto/developer/v1/logs_service.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package developer.v1; + +import "google/protobuf/timestamp.proto"; + +option go_package = "github.com/snowflakedb/streamlit-container-runtime/gen/developer/v1;developerv1"; + +// LogSource identifies the origin of log entries +enum LogSource { + LOG_SOURCE_UNSPECIFIED = 0; + LOG_SOURCE_APP = 1; + LOG_SOURCE_MANAGER = 2; +} + +// LogLevel represents the severity of a log entry +enum LogLevel { + LOG_LEVEL_UNSPECIFIED = 0; + LOG_LEVEL_DEBUG = 1; + LOG_LEVEL_INFO = 2; + LOG_LEVEL_WARN = 3; + LOG_LEVEL_ERROR = 4; +} + +// StreamLogsRequest configures the log stream +message StreamLogsRequest { + // Number of historical lines to send before streaming live logs. + // If 0, only stream live logs. Max: 1000 (buffer size). + int32 tail_lines = 1; +} + +// LogEntry represents a single log line +message LogEntry { + LogSource log_source = 1; + string content = 2; + // Timestamp when the log was captured (UTC) + google.protobuf.Timestamp timestamp = 3; + int64 sequence = 4; + LogLevel level = 5; +} diff --git a/src/snowflake/cli/_plugins/streamlit/proto/generated/__init__.py b/src/snowflake/cli/_plugins/streamlit/proto/generated/__init__.py new file mode 100644 index 0000000000..74bcb8a780 --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/proto/generated/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/__init__.py b/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/__init__.py new file mode 100644 index 0000000000..74bcb8a780 --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/v1/__init__.py b/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/v1/__init__.py new file mode 100644 index 0000000000..74bcb8a780 --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/v1/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/v1/logs_service_pb2.py b/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/v1/logs_service_pb2.py new file mode 100644 index 0000000000..faa82664bd --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/proto/generated/developer/v1/logs_service_pb2.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: developer/v1/logs_service.proto +# Regenerate with: +# python -m grpc_tools.protoc \ +# --proto_path=src/snowflake/cli/_plugins/streamlit/proto \ +# --python_out=src/snowflake/cli/_plugins/streamlit/proto/generated \ +# developer/v1/logs_service.proto +# ruff: noqa: SLF001 +# NOTE: The runtime version check below is wrapped in a try/except for +# compatibility with protobuf 5.x (pulled by snowflake-connector-python) and 6.x. +# IMPORTANT: After regenerating, you must re-apply the try/except wrapper around +# the ValidateProtobufRuntimeVersion call. +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import timestamp_pb2 as _timestamp_pb2 # noqa: F401 +from google.protobuf.internal import builder as _builder + +try: + from google.protobuf import runtime_version as _runtime_version + + _runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 32, + 1, + "", + "developer/v1/logs_service.proto", + ) +except Exception: + pass # protobuf 5.x compat: may lack runtime_version or fail version check + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1f\x64\x65veloper/v1/logs_service.proto\x12\x0c\x64\x65veloper.v1\x1a\x1fgoogle/protobuf/timestamp.proto"\'\n\x11StreamLogsRequest\x12\x12\n\ntail_lines\x18\x01 \x01(\x05"\xb0\x01\n\x08LogEntry\x12+\n\nlog_source\x18\x01 \x01(\x0e\x32\x17.developer.v1.LogSource\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\t\x12-\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x10\n\x08sequence\x18\x04 \x01(\x03\x12%\n\x05level\x18\x05 \x01(\x0e\x32\x16.developer.v1.LogLevel*S\n\tLogSource\x12\x1a\n\x16LOG_SOURCE_UNSPECIFIED\x10\x00\x12\x12\n\x0eLOG_SOURCE_APP\x10\x01\x12\x16\n\x12LOG_SOURCE_MANAGER\x10\x02*w\n\x08LogLevel\x12\x19\n\x15LOG_LEVEL_UNSPECIFIED\x10\x00\x12\x13\n\x0fLOG_LEVEL_DEBUG\x10\x01\x12\x12\n\x0eLOG_LEVEL_INFO\x10\x02\x12\x12\n\x0eLOG_LEVEL_WARN\x10\x03\x12\x13\n\x0fLOG_LEVEL_ERROR\x10\x04\x42QZOgithub.com/snowflakedb/streamlit-container-runtime/gen/developer/v1;developerv1b\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "developer.v1.logs_service_pb2", _globals +) +if not _descriptor._USE_C_DESCRIPTORS: + _globals["DESCRIPTOR"]._loaded_options = None + _globals[ + "DESCRIPTOR" + ]._serialized_options = b"ZOgithub.com/snowflakedb/streamlit-container-runtime/gen/developer/v1;developerv1" + _globals["_LOGSOURCE"]._serialized_start = 302 + _globals["_LOGSOURCE"]._serialized_end = 385 + _globals["_LOGLEVEL"]._serialized_start = 387 + _globals["_LOGLEVEL"]._serialized_end = 506 + _globals["_STREAMLOGSREQUEST"]._serialized_start = 82 + _globals["_STREAMLOGSREQUEST"]._serialized_end = 121 + _globals["_LOGENTRY"]._serialized_start = 124 + _globals["_LOGENTRY"]._serialized_end = 300 +# @@protoc_insertion_point(module_scope) diff --git a/src/snowflake/cli/_plugins/streamlit/proto_codec.py b/src/snowflake/cli/_plugins/streamlit/proto_codec.py new file mode 100644 index 0000000000..4d5f5482bf --- /dev/null +++ b/src/snowflake/cli/_plugins/streamlit/proto_codec.py @@ -0,0 +1,109 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Protobuf codec for the Streamlit developer log streaming protocol. + +Uses generated protobuf classes from logs_service.proto and provides +a Python-friendly dataclass wrapper for log entries. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone + +from snowflake.cli._plugins.streamlit.proto.generated.developer.v1 import ( + logs_service_pb2 as pb2, +) + +# Re-export enum values for convenience +LOG_SOURCE_APP = pb2.LOG_SOURCE_APP +LOG_SOURCE_MANAGER = pb2.LOG_SOURCE_MANAGER +LOG_SOURCE_UNSPECIFIED = pb2.LOG_SOURCE_UNSPECIFIED + +LOG_LEVEL_DEBUG = pb2.LOG_LEVEL_DEBUG +LOG_LEVEL_INFO = pb2.LOG_LEVEL_INFO +LOG_LEVEL_WARN = pb2.LOG_LEVEL_WARN +LOG_LEVEL_ERROR = pb2.LOG_LEVEL_ERROR +LOG_LEVEL_UNSPECIFIED = pb2.LOG_LEVEL_UNSPECIFIED + +LOG_SOURCE_LABELS = { + LOG_SOURCE_APP: "APP", + LOG_SOURCE_MANAGER: "MGR", + LOG_SOURCE_UNSPECIFIED: "UNKNOWN", +} + +LOG_LEVEL_LABELS = { + LOG_LEVEL_UNSPECIFIED: "UNKNOWN", + LOG_LEVEL_DEBUG: "DEBUG", + LOG_LEVEL_INFO: "INFO", + LOG_LEVEL_WARN: "WARN", + LOG_LEVEL_ERROR: "ERROR", +} + + +@dataclass +class LogEntry: + log_source: int + content: str + timestamp: datetime + sequence: int + level: int + + @property + def source_label(self) -> str: + return LOG_SOURCE_LABELS.get(self.log_source, "UNKNOWN") + + @property + def level_label(self) -> str: + return LOG_LEVEL_LABELS.get(self.level, "UNKNOWN") + + def format_line(self) -> str: + ts = self.timestamp.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + return f"[{ts}] [{self.level_label}] [{self.source_label}] [seq:{self.sequence}] {self.content}" + + def to_dict(self) -> dict[str, str | int]: + return { + "timestamp": self.timestamp.isoformat(), + "level": self.level_label, + "source": self.source_label, + "sequence": self.sequence, + "content": self.content, + } + + +def encode_stream_logs_request(tail_lines: int) -> bytes: + """Encode a StreamLogsRequest protobuf message to binary.""" + request = pb2.StreamLogsRequest(tail_lines=tail_lines) + return request.SerializeToString() + + +def decode_log_entry(data: bytes) -> LogEntry: + """Decode a binary protobuf LogEntry message into a Python dataclass.""" + entry = pb2.LogEntry() + entry.ParseFromString(data) + + if entry.HasField("timestamp"): + ts = entry.timestamp.ToDatetime(tzinfo=timezone.utc) + else: + ts = datetime.fromtimestamp(0, tz=timezone.utc) + + return LogEntry( + log_source=entry.log_source, + content=entry.content, + timestamp=ts, + sequence=entry.sequence, + level=entry.level, + ) diff --git a/tests/__snapshots__/test_help_messages.ambr b/tests/__snapshots__/test_help_messages.ambr index 328b75547a..0fc00e732d 100644 --- a/tests/__snapshots__/test_help_messages.ambr +++ b/tests/__snapshots__/test_help_messages.ambr @@ -22164,6 +22164,168 @@ +------------------------------------------------------------------------------+ + ''' +# --- +# name: test_help_messages[streamlit.logs] + ''' + + Usage: root streamlit logs [OPTIONS] [ENTITY_ID] + + Streams live logs from a deployed Streamlit app to your terminal. + + Reads the Streamlit app name from the project definition file (snowflake.yml) + or from the --name option. Connects to the app's developer log service via + WebSocket and prints log entries in real time. Press Ctrl+C to stop streaming. + + Log streaming requires SPCSv2 runtime. + + +- Arguments ------------------------------------------------------------------+ + | entity_id [ENTITY_ID] ID of streamlit entity. | + +------------------------------------------------------------------------------+ + +- Options --------------------------------------------------------------------+ + | --name TEXT Fully qualified name of the | + | Streamlit app (e.g. my_app, | + | schema.my_app, or | + | db.schema.my_app). Overrides | + | the project definition when | + | provided. | + | --tail -n INTEGER RANGE [0<=x<=1000] Number of historical log | + | lines to fetch. Use 0 for | + | live logs only. | + | [default: 100] | + | --project -p TEXT Path where the Snowflake | + | project is stored. Defaults | + | to the current working | + | directory. | + | --env TEXT String in the format | + | key=value. Overrides | + | variables from the env | + | section used for templates. | + | --help -h Show this message and exit. | + +------------------------------------------------------------------------------+ + +- Connection configuration ---------------------------------------------------+ + | --connection,--environment -c TEXT Name of the connection, as | + | defined in your config.toml | + | file. Default: default. | + | --host TEXT Host address for the | + | connection. Overrides the | + | value specified for the | + | connection. | + | --port INTEGER Port for the connection. | + | Overrides the value specified | + | for the connection. | + | --account,--accountname TEXT Name assigned to your | + | Snowflake account. Overrides | + | the value specified for the | + | connection. | + | --user,--username TEXT Username to connect to | + | Snowflake. Overrides the | + | value specified for the | + | connection. | + | --password TEXT Snowflake password. Overrides | + | the value specified for the | + | connection. | + | --authenticator TEXT Snowflake authenticator. | + | Overrides the value specified | + | for the connection. | + | --workload-identity-provider TEXT Workload identity provider | + | (AWS, AZURE, GCP, OIDC). | + | Overrides the value specified | + | for the connection | + | --private-key-file,--privat… TEXT Snowflake private key file | + | path. Overrides the value | + | specified for the connection. | + | --token TEXT OAuth token to use when | + | connecting to Snowflake. | + | --token-file-path TEXT Path to file with an OAuth | + | token to use when connecting | + | to Snowflake. | + | --database,--dbname TEXT Database to use. Overrides | + | the value specified for the | + | connection. | + | --schema,--schemaname TEXT Database schema to use. | + | Overrides the value specified | + | for the connection. | + | --role,--rolename TEXT Role to use. Overrides the | + | value specified for the | + | connection. | + | --warehouse TEXT Warehouse to use. Overrides | + | the value specified for the | + | connection. | + | --temporary-connection -x Uses a connection defined | + | with command-line parameters, | + | instead of one defined in | + | config | + | --mfa-passcode TEXT Token to use for multi-factor | + | authentication (MFA) | + | --enable-diag Whether to generate a | + | connection diagnostic report. | + | --diag-log-path TEXT Path for the generated | + | report. Defaults to system | + | temporary directory. | + | --diag-allowlist-path TEXT Path to a JSON file that | + | contains allowlist | + | parameters. | + | --oauth-client-id TEXT Value of client id provided | + | by the Identity Provider for | + | Snowflake integration. | + | --oauth-client-secret TEXT Value of the client secret | + | provided by the Identity | + | Provider for Snowflake | + | integration. | + | --oauth-authorization-url TEXT Identity Provider endpoint | + | supplying the authorization | + | code to the driver. | + | --oauth-token-request-url TEXT Identity Provider endpoint | + | supplying the access tokens | + | to the driver. | + | --oauth-redirect-uri TEXT URI to use for authorization | + | code redirection. | + | --oauth-scope TEXT Scope requested in the | + | Identity Provider | + | authorization request. | + | --oauth-disable-pkce Disables Proof Key for Code | + | Exchange (PKCE). Default: | + | False. | + | --oauth-enable-refresh-toke… Enables a silent | + | re-authentication when the | + | actual access token becomes | + | outdated. Default: False. | + | --oauth-enable-single-use-r… Whether to opt-in to | + | single-use refresh token | + | semantics. Default: False. | + | --client-store-temporary-cr… Store the temporary | + | credential. | + +------------------------------------------------------------------------------+ + +- Global configuration -------------------------------------------------------+ + | --format [TABLE|JSON|JSON_EXT| Specifies the output | + | CSV] format. | + | [default: TABLE] | + | --verbose -v Displays log entries | + | for log levels info | + | and higher. | + | --debug Displays log entries | + | for log levels debug | + | and higher; debug logs | + | contain additional | + | information. | + | --silent Turns off intermediate | + | output to console. | + | --enhanced-exit-codes Differentiate exit | + | error codes based on | + | failure type. | + | [env var: | + | SNOWFLAKE_ENHANCED_EX… | + | --decimal-precision INTEGER Number of decimal | + | places to display for | + | decimal values. Uses | + | Python's default | + | precision if not | + | specified. [env var: | + | SNOWFLAKE_DECIMAL_PRE… | + +------------------------------------------------------------------------------+ + + ''' # --- # name: test_help_messages[streamlit.share] @@ -22331,6 +22493,7 @@ | execute Executes a streamlit in a headless mode. | | get-url Returns a URL to the specified Streamlit app | | list Lists all available streamlits. | + | logs Streams live logs from a deployed Streamlit app to your terminal. | | share Shares a Streamlit app with another role. | +------------------------------------------------------------------------------+ @@ -23111,6 +23274,7 @@ | execute Executes a streamlit in a headless mode. | | get-url Returns a URL to the specified Streamlit app | | list Lists all available streamlits. | + | logs Streams live logs from a deployed Streamlit app to your terminal. | | share Shares a Streamlit app with another role. | +------------------------------------------------------------------------------+ diff --git a/tests/streamlit/__init__.py b/tests/streamlit/__init__.py new file mode 100644 index 0000000000..74bcb8a780 --- /dev/null +++ b/tests/streamlit/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/streamlit/test_streamlit_logs.py b/tests/streamlit/test_streamlit_logs.py new file mode 100644 index 0000000000..2b77c6066e --- /dev/null +++ b/tests/streamlit/test_streamlit_logs.py @@ -0,0 +1,598 @@ +# Copyright (c) 2026 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from datetime import datetime, timezone +from unittest import mock + +import pytest +import websocket as ws_lib +from click import ClickException +from snowflake.cli._plugins.streamlit.commands import streamlit_logs +from snowflake.cli._plugins.streamlit.log_streaming import ( + DeveloperApiToken, + build_ws_url, + get_developer_api_token, + stream_logs, + validate_spcs_v2_runtime, +) +from snowflake.cli._plugins.streamlit.proto.generated.developer.v1 import ( + logs_service_pb2 as pb2, +) +from snowflake.cli._plugins.streamlit.proto_codec import ( + LOG_LEVEL_INFO, + LOG_LEVEL_WARN, + LOG_SOURCE_APP, + LOG_SOURCE_MANAGER, + LogEntry, + decode_log_entry, + encode_stream_logs_request, +) +from snowflake.cli.api.identifiers import FQN + + +class TestBuildWsUrl: + def test_https_to_wss(self): + url = build_ws_url("https://my-app.snowflakecomputing.com/api/v1") + assert url == "wss://my-app.snowflakecomputing.com/api/v1/logs" + + def test_http_to_ws(self): + url = build_ws_url("http://localhost:8702") + assert url == "ws://localhost:8702/logs" + + def test_preserves_path(self): + url = build_ws_url("https://host.example.com/some/deep/path") + assert url == "wss://host.example.com/some/deep/path/logs" + + def test_strips_trailing_slash(self): + url = build_ws_url("https://host.example.com/api/v1/") + assert url == "wss://host.example.com/api/v1/logs" + + def test_replaces_only_first_occurrence(self): + url = build_ws_url("https://proxy.example.com/redirect/https://target.com") + assert url == "wss://proxy.example.com/redirect/https://target.com/logs" + + +class TestGetDeveloperApiToken: + def test_success(self): + mock_cursor = mock.Mock() + mock_cursor.fetchone.return_value = ( + '{"token": "abc123", "resourceUri": "https://my-app.snowflakecomputing.com"}', + ) + + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + result = get_developer_api_token(mock_conn, "DB.SCHEMA.APP") + + assert isinstance(result, DeveloperApiToken) + assert result.token == "abc123" + assert result.resource_uri == "https://my-app.snowflakecomputing.com" + mock_cursor.execute.assert_called_once_with( + "CALL SYSTEM$GET_STREAMLIT_DEVELOPER_API_TOKEN('DB.SCHEMA.APP', false);" + ) + mock_cursor.close.assert_called_once() + + def test_empty_response_raises(self): + mock_cursor = mock.Mock() + mock_cursor.fetchone.return_value = None + + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(ClickException, match="Empty response"): + get_developer_api_token(mock_conn, "DB.SCHEMA.APP") + + def test_empty_token_raises(self): + mock_cursor = mock.Mock() + mock_cursor.fetchone.return_value = ( + '{"token": "", "resourceUri": "https://example.com"}', + ) + + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(ClickException, match="Empty token"): + get_developer_api_token(mock_conn, "DB.SCHEMA.APP") + + def test_empty_resource_uri_raises(self): + mock_cursor = mock.Mock() + mock_cursor.fetchone.return_value = ('{"token": "abc", "resourceUri": ""}',) + + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(ClickException, match="Empty resourceUri"): + get_developer_api_token(mock_conn, "DB.SCHEMA.APP") + + def test_single_quote_in_fqn_raises(self): + mock_conn = mock.Mock() + + with pytest.raises(ClickException, match="single quotes"): + get_developer_api_token(mock_conn, "DB.SCHEMA.APP'; DROP TABLE --") + + def test_cursor_closed_on_error(self): + mock_cursor = mock.Mock() + mock_cursor.execute.side_effect = Exception("SQL error") + + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(Exception, match="SQL error"): + get_developer_api_token(mock_conn, "DB.SCHEMA.APP") + + mock_cursor.close.assert_called_once() + + +class TestEncodeStreamLogsRequest: + def test_encodes_tail_lines(self): + data = encode_stream_logs_request(100) + assert isinstance(data, bytes) + assert len(data) > 0 + + def test_zero_tail_lines(self): + data = encode_stream_logs_request(0) + assert isinstance(data, bytes) + + def test_roundtrip_via_pb2(self): + """Verify encoding matches what the protobuf library produces.""" + for tail_lines in [0, 1, 50, 100, 1000, 10000]: + encoded = encode_stream_logs_request(tail_lines) + decoded = pb2.StreamLogsRequest() + decoded.ParseFromString(encoded) + assert decoded.tail_lines == tail_lines + + +class TestDecodeLogEntry: + def _make_pb2_log_entry(self, log_source, content, seconds, nanos, sequence, level): + entry = pb2.LogEntry() + entry.log_source = log_source + entry.content = content + entry.timestamp.seconds = seconds + entry.timestamp.nanos = nanos + entry.sequence = sequence + entry.level = level + return entry.SerializeToString() + + def test_decode_app_log(self): + data = self._make_pb2_log_entry( + log_source=1, # LOG_SOURCE_APP + content="Hello from app", + seconds=1700000000, + nanos=500000000, + sequence=42, + level=2, # LOG_LEVEL_INFO + ) + entry = decode_log_entry(data) + + assert entry.log_source == LOG_SOURCE_APP + assert entry.content == "Hello from app" + assert entry.sequence == 42 + assert entry.level == LOG_LEVEL_INFO + assert entry.source_label == "APP" + assert entry.level_label == "INFO" + + def test_decode_manager_log(self): + data = self._make_pb2_log_entry( + log_source=2, # LOG_SOURCE_MANAGER + content="Manager message", + seconds=1700000000, + nanos=0, + sequence=1, + level=3, # LOG_LEVEL_WARN + ) + entry = decode_log_entry(data) + + assert entry.log_source == LOG_SOURCE_MANAGER + assert entry.content == "Manager message" + assert entry.source_label == "MGR" + assert entry.level_label == "WARN" + + def test_format_line_includes_level(self): + entry = LogEntry( + log_source=LOG_SOURCE_APP, + content="test message", + timestamp=datetime(2024, 1, 15, 10, 30, 45, 123000, tzinfo=timezone.utc), + sequence=7, + level=LOG_LEVEL_INFO, + ) + line = entry.format_line() + assert line == "[2024-01-15 10:30:45.123] [INFO] [APP] [seq:7] test message" + + def test_format_line_warn_level(self): + entry = LogEntry( + log_source=LOG_SOURCE_MANAGER, + content="warning msg", + timestamp=datetime(2024, 6, 1, 12, 0, 0, 0, tzinfo=timezone.utc), + sequence=99, + level=LOG_LEVEL_WARN, + ) + line = entry.format_line() + assert line == "[2024-06-01 12:00:00.000] [WARN] [MGR] [seq:99] warning msg" + + def test_to_dict(self): + entry = LogEntry( + log_source=LOG_SOURCE_APP, + content="some content", + timestamp=datetime(2024, 3, 10, 8, 0, 0, tzinfo=timezone.utc), + sequence=5, + level=LOG_LEVEL_INFO, + ) + d = entry.to_dict() + assert d == { + "timestamp": "2024-03-10T08:00:00+00:00", + "level": "INFO", + "source": "APP", + "sequence": 5, + "content": "some content", + } + + +def _make_entry_bytes(log_source, content, seconds, sequence, level): + """Serialize a protobuf LogEntry for use in tests.""" + entry = pb2.LogEntry( + log_source=log_source, content=content, sequence=sequence, level=level + ) + entry.timestamp.seconds = seconds + return entry.SerializeToString() + + +def _mock_conn_with_token(): + """Return a mock connection that returns a valid token response.""" + mock_cursor = mock.Mock() + mock_cursor.fetchone.return_value = ( + '{"token": "test-token", "resourceUri": "https://test.snowflakecomputing.com/api"}', + ) + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + return mock_conn + + +@pytest.fixture +def mock_ws(): + """Patch the websocket module in log_streaming and wire up real constants.""" + with mock.patch( + "snowflake.cli._plugins.streamlit.log_streaming.websocket" + ) as mock_ws_module: + ws = mock.Mock() + mock_ws_module.WebSocket.return_value = ws + mock_ws_module.ABNF = ws_lib.ABNF + mock_ws_module.WebSocketTimeoutException = ws_lib.WebSocketTimeoutException + mock_ws_module.WebSocketConnectionClosedException = ( + ws_lib.WebSocketConnectionClosedException + ) + mock_ws_module.WebSocketException = ws_lib.WebSocketException + mock_ws_module.STATUS_NORMAL = ws_lib.STATUS_NORMAL + yield ws + + +@pytest.fixture +def mock_console(): + with mock.patch( + "snowflake.cli._plugins.streamlit.log_streaming.cli_console" + ) as console: + yield console + + +class TestStreamLogs: + def test_streams_log_entries_to_stdout(self, mock_ws, mock_console, capsys): + entry1 = _make_entry_bytes(1, "line one", 1700000000, 1, 2) + entry2 = _make_entry_bytes(2, "line two", 1700000001, 2, 3) + + mock_ws.recv_data.side_effect = [ + (ws_lib.ABNF.OPCODE_BINARY, entry1), + (ws_lib.ABNF.OPCODE_BINARY, entry2), + (ws_lib.ABNF.OPCODE_CLOSE, b""), + ] + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + captured = capsys.readouterr() + assert "line one" in captured.out + assert "line two" in captured.out + assert "[APP]" in captured.out + assert "[MGR]" in captured.out + + def test_json_output(self, mock_ws, mock_console, capsys): + entry_bytes = _make_entry_bytes(1, "json test", 1700000000, 1, 2) + + mock_ws.recv_data.side_effect = [ + (ws_lib.ABNF.OPCODE_BINARY, entry_bytes), + (ws_lib.ABNF.OPCODE_CLOSE, b""), + ] + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=50, json_output=True) + + captured = capsys.readouterr() + # Skip the "---" header line and the trailing "--- Log streaming stopped." + json_lines = [ + line for line in captured.out.strip().split("\n") if line.startswith("{") + ] + assert len(json_lines) == 1 + parsed = json.loads(json_lines[0]) + assert parsed["content"] == "json test" + assert parsed["source"] == "APP" + assert parsed["level"] == "INFO" + + def test_handles_connection_closed(self, mock_ws, mock_console, capsys): + mock_ws.recv_data.side_effect = ws_lib.WebSocketConnectionClosedException() + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + captured = capsys.readouterr() + assert "Log streaming stopped" in captured.out + + def test_timeout_continues_loop(self, mock_ws, mock_console, capsys): + entry_bytes = _make_entry_bytes(1, "after timeout", 1700000000, 1, 2) + + # Timeout once, then get a message, then close + mock_ws.recv_data.side_effect = [ + ws_lib.WebSocketTimeoutException(), + (ws_lib.ABNF.OPCODE_BINARY, entry_bytes), + (ws_lib.ABNF.OPCODE_CLOSE, b""), + ] + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + captured = capsys.readouterr() + assert "after timeout" in captured.out + + def test_graceful_close(self, mock_ws, mock_console): + mock_ws.recv_data.side_effect = [ + (ws_lib.ABNF.OPCODE_CLOSE, b""), + ] + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + mock_ws.close.assert_called_once_with(status=ws_lib.STATUS_NORMAL) + + def test_skips_malformed_protobuf(self, mock_ws, mock_console, capsys): + good_entry = _make_entry_bytes(1, "good line", 1700000000, 1, 2) + + mock_ws.recv_data.side_effect = [ + (ws_lib.ABNF.OPCODE_BINARY, b"\xff\xff\xff"), # invalid protobuf + (ws_lib.ABNF.OPCODE_BINARY, good_entry), + (ws_lib.ABNF.OPCODE_CLOSE, b""), + ] + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + captured = capsys.readouterr() + # The malformed entry is skipped but the good entry still shows + assert "good line" in captured.out + + def test_responds_to_ping(self, mock_ws, mock_console): + mock_ws.recv_data.side_effect = [ + (ws_lib.ABNF.OPCODE_PING, b"ping-data"), + (ws_lib.ABNF.OPCODE_CLOSE, b""), + ] + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + mock_ws.pong.assert_called_once_with(b"ping-data") + + def test_keyboard_interrupt_prints_stopped(self, mock_ws, mock_console, capsys): + mock_ws.recv_data.side_effect = KeyboardInterrupt() + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + captured = capsys.readouterr() + assert "Log streaming stopped" in captured.out + mock_ws.close.assert_called_once_with(status=ws_lib.STATUS_NORMAL) + + def test_connect_failure_raises(self, mock_ws, mock_console): + mock_ws.connect.side_effect = ConnectionRefusedError("Connection refused") + + conn = _mock_conn_with_token() + with pytest.raises(ClickException, match="Failed to connect"): + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=100) + + # WebSocket should still be closed in the finally block + mock_ws.close.assert_called_once_with(status=ws_lib.STATUS_NORMAL) + + def test_sends_stream_logs_request(self, mock_ws, mock_console): + mock_ws.recv_data.side_effect = [ + (ws_lib.ABNF.OPCODE_CLOSE, b""), + ] + + conn = _mock_conn_with_token() + stream_logs(conn=conn, fqn="DB.SCHEMA.APP", tail_lines=42) + + mock_ws.send_binary.assert_called_once() + sent_bytes = mock_ws.send_binary.call_args[0][0] + + # Verify the sent bytes decode to a StreamLogsRequest with tail_lines=42 + request = pb2.StreamLogsRequest() + request.ParseFromString(sent_bytes) + assert request.tail_lines == 42 + + +class TestValidateSpcsV2Runtime: + SPCS_V2 = "SYSTEM$ST_CONTAINER_RUNTIME_PY3_11" + + def _mock_describe_cursor(self, runtime_name): + """Return a mock cursor whose DESCRIBE STREAMLIT result has the given runtime_name.""" + mock_cursor = mock.Mock() + # Simulate DESCRIBE STREAMLIT columns (subset relevant to our code) + mock_cursor.description = [ + ("title",), + ("main_file",), + ("query_warehouse",), + ("compute_pool",), + ("runtime_name",), + ("name",), + ] + mock_cursor.fetchone.return_value = ( + "My App", + "streamlit_app.py", + "WH", + "my_pool", + runtime_name, + "MY_APP", + ) + return mock_cursor + + def test_passes_for_spcs_v2_runtime(self): + mock_cursor = self._mock_describe_cursor(self.SPCS_V2) + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + # Should not raise + validate_spcs_v2_runtime(mock_conn, "DB.SCHEMA.MY_APP") + + mock_cursor.execute.assert_called_once_with( + "DESCRIBE STREAMLIT DB.SCHEMA.MY_APP" + ) + mock_cursor.close.assert_called_once() + + def test_raises_for_non_spcs_v2_runtime(self): + mock_cursor = self._mock_describe_cursor(None) + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(ClickException, match="only supported for Streamlit apps"): + validate_spcs_v2_runtime(mock_conn, "DB.SCHEMA.MY_APP") + + mock_cursor.close.assert_called_once() + + def test_raises_for_wrong_runtime_name(self): + mock_cursor = self._mock_describe_cursor("SOME_OTHER_RUNTIME") + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(ClickException, match="SOME_OTHER_RUNTIME"): + validate_spcs_v2_runtime(mock_conn, "DB.SCHEMA.MY_APP") + + def test_raises_for_empty_describe_result(self): + mock_cursor = mock.Mock() + mock_cursor.fetchone.return_value = None + mock_cursor.description = None + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(ClickException, match="Could not describe"): + validate_spcs_v2_runtime(mock_conn, "DB.SCHEMA.MY_APP") + + mock_cursor.close.assert_called_once() + + def test_cursor_closed_on_sql_error(self): + mock_cursor = mock.Mock() + mock_cursor.execute.side_effect = Exception("SQL error") + mock_conn = mock.Mock() + mock_conn.cursor.return_value = mock_cursor + + with pytest.raises(Exception, match="SQL error"): + validate_spcs_v2_runtime(mock_conn, "DB.SCHEMA.MY_APP") + + mock_cursor.close.assert_called_once() + + +SPCS_V2_NAME = "SYSTEM$ST_CONTAINER_RUNTIME_PY3_11" + + +class TestStreamlitLogsCommand: + """Tests for the streamlit_logs command handler in commands.py.""" + + @mock.patch("snowflake.cli._plugins.streamlit.commands.get_cli_context") + @mock.patch("snowflake.cli._plugins.streamlit.commands.validate_spcs_v2_runtime") + @mock.patch("snowflake.cli._plugins.streamlit.commands.stream_logs") + def test_name_flag_resolves_fqn_and_validates( + self, mock_stream_logs, mock_validate, mock_get_ctx + ): + """When --name is provided, resolve FQN and validate via DESCRIBE.""" + mock_conn = mock.Mock() + mock_conn.database = "DB" + mock_conn.schema = "SCHEMA" + + mock_ctx = mock.Mock() + mock_ctx.connection = mock_conn + mock_ctx.output_format.is_json = False + mock_get_ctx.return_value = mock_ctx + + fqn = FQN.from_string("MY_APP") + resolved = fqn.using_connection(mock_conn) + + result = streamlit_logs(entity_id=None, name=fqn, tail=100) + + mock_validate.assert_called_once_with(mock_conn, str(resolved)) + mock_stream_logs.assert_called_once_with( + conn=mock_conn, + fqn=str(resolved), + tail_lines=100, + json_output=False, + ) + assert result.message == "Log streaming ended." + + @mock.patch("snowflake.cli._plugins.streamlit.commands.get_cli_context") + def test_name_and_entity_id_raises(self, mock_get_ctx): + """When both --name and entity_id are provided, raise an error.""" + mock_ctx = mock.Mock() + mock_ctx.connection = mock.Mock() + mock_get_ctx.return_value = mock_ctx + + with pytest.raises(ClickException, match="Cannot specify both"): + streamlit_logs( + entity_id="my_entity", name=FQN.from_string("MY_APP"), tail=100 + ) + + @mock.patch("snowflake.cli._plugins.streamlit.commands.get_cli_context") + def test_no_name_no_project_definition_raises(self, mock_get_ctx): + """When neither --name nor project definition is available, raise an error.""" + mock_ctx = mock.Mock() + mock_ctx.connection = mock.Mock() + mock_ctx.project_definition = None + mock_get_ctx.return_value = mock_ctx + + with pytest.raises(ClickException, match="No Streamlit app specified"): + streamlit_logs(entity_id=None, name=None, tail=100) + + @mock.patch("snowflake.cli._plugins.streamlit.commands.get_cli_context") + @mock.patch("snowflake.cli._plugins.streamlit.commands.get_entity_for_operation") + @mock.patch("snowflake.cli._plugins.streamlit.commands.validate_spcs_v2_runtime") + @mock.patch("snowflake.cli._plugins.streamlit.commands.stream_logs") + def test_project_definition_path( + self, mock_stream_logs, mock_validate, mock_get_entity, mock_get_ctx + ): + """When using project definition, resolve entity and validate via DESCRIBE.""" + mock_conn = mock.Mock() + mock_conn.database = "DB" + mock_conn.schema = "PUBLIC" + + mock_pd = mock.Mock() + mock_pd.meets_version_requirement.return_value = True + + mock_ctx = mock.Mock() + mock_ctx.connection = mock_conn + mock_ctx.project_definition = mock_pd + mock_ctx.output_format.is_json = False + mock_get_ctx.return_value = mock_ctx + + mock_entity = mock.Mock() + mock_entity.fqn = FQN.from_string("DB.PUBLIC.MY_STREAMLIT") + mock_get_entity.return_value = mock_entity + + result = streamlit_logs(entity_id=None, name=None, tail=50) + + mock_validate.assert_called_once() + mock_stream_logs.assert_called_once() + assert mock_stream_logs.call_args.kwargs["tail_lines"] == 50 + assert result.message == "Log streaming ended."