Skip to content

Commit

Permalink
SNOW-1720835 param protect thread-safe client side changes (#2401)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Oct 17, 2024
1 parent 0b47a67 commit 367dd23
Show file tree
Hide file tree
Showing 11 changed files with 350 additions and 190 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
### Snowpark Python API Updates

- Added support for 'Service' domain to `session.lineage.trace` API.
- Updated `Session` class to be thread-safe. This allows concurrent dataframe transformations, dataframe actions, UDF and store procedure registration, and concurrent file uploads.
- Added support for `copy_grants` parameter when registering UDxF and stored procedures.

#### New Features
Expand Down
11 changes: 9 additions & 2 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryClient
from snowflake.snowpark._internal.utils import (
create_rlock,
create_thread_local,
escape_quotes,
get_application_name,
get_version,
Expand Down Expand Up @@ -155,8 +157,6 @@ def __init__(
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
) -> None:
self._lock = threading.RLock()
self._thread_store = threading.local()
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
Expand All @@ -171,6 +171,13 @@ def __init__(
except TypeError:
pass

# thread safe param protection
self._thread_safe_session_enabled = self._get_client_side_session_parameter(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._thread_store = create_thread_local(self._thread_safe_session_enabled)

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
self._telemetry_client = TelemetryClient(self._conn)
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import logging
import threading
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Dict

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
from snowflake.snowpark._internal.utils import create_rlock

if TYPE_CHECKING:
from snowflake.snowpark.session import Session # pragma: no cover
Expand All @@ -33,7 +33,7 @@ def __init__(self, session: "Session") -> None:
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# Lock to protect the ref_count_map
self.lock = threading.RLock()
self.lock = create_rlock(session._conn._thread_safe_session_enabled)

def add(self, table: SnowflakeTable) -> None:
with self.lock:
Expand Down
48 changes: 47 additions & 1 deletion src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,12 @@ def normalize_path(path: str, is_local: bool) -> str:
return f"'{path}'"


def warn_session_config_update_in_multithreaded_mode(config) -> None:
def warn_session_config_update_in_multithreaded_mode(
config: str, thread_safe_mode_enabled: bool
) -> None:
if not thread_safe_mode_enabled:
return

if threading.active_count() > 1:
logger.warning(
"You might have more than one threads sharing the Session object trying to update "
Expand Down Expand Up @@ -675,6 +680,47 @@ def warning(self, text: str) -> None:
self.count += 1


# TODO: SNOW-1720855: Remove DummyRLock and DummyThreadLocal after the rollout
class DummyRLock:
    """No-op stand-in for ``threading.RLock()`` used when the thread-safe
    session mode is disabled; every operation does nothing."""

    def __enter__(self):
        # Nothing to acquire; we only mirror the context-manager protocol.
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning a falsy value propagates any in-flight exception,
        # matching the behavior of a real lock's __exit__.
        return None

    def acquire(self, *args, **kwargs):
        return None  # pragma: no cover

    def release(self, *args, **kwargs):
        return None  # pragma: no cover


class DummyThreadLocal:
    """No-op stand-in for ``threading.local`` used when the thread-safe
    session mode is disabled; acts as a plain attribute container shared
    by all threads."""


def create_thread_local(
    thread_safe_session_enabled: bool,
) -> Union[threading.local, DummyThreadLocal]:
    """Return ``threading.local()`` when thread-safe session mode is enabled,
    otherwise a ``DummyThreadLocal`` placeholder with the same attribute
    semantics for single-threaded use."""
    local_cls = threading.local if thread_safe_session_enabled else DummyThreadLocal
    return local_cls()


def create_rlock(
    thread_safe_session_enabled: bool,
) -> Union[threading.RLock, DummyRLock]:
    """Return a real re-entrant lock when thread-safe session mode is enabled,
    otherwise a no-op ``DummyRLock`` that satisfies the same interface."""
    return threading.RLock() if thread_safe_session_enabled else DummyRLock()


warning_dict: Dict[str, WarningHelper] = {}


Expand Down
23 changes: 15 additions & 8 deletions src/snowflake/snowpark/mock/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import functools
import json
import logging
import threading
import uuid
from copy import copy
from decimal import Decimal
Expand All @@ -30,6 +29,7 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.server_connection import DEFAULT_STRING_SIZE
from snowflake.snowpark._internal.utils import (
create_rlock,
is_in_stored_procedure,
result_set_to_rows,
)
Expand Down Expand Up @@ -281,19 +281,26 @@ def read_table_if_exists(
def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
self._conn = MockedSnowflakeConnection()
self._cursor = Mock()
self._lock = threading.RLock()
self._options = options or {}
session_params = self._options.get("session_parameters", {})
# thread safe param protection
self._thread_safe_session_enabled = session_params.get(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._lower_case_parameters = {}
self.remove_query_listener = Mock()
self.add_query_listener = Mock()
self._telemetry_client = Mock()
self.entity_registry = MockServerConnection.TabularEntityRegistry(self)
self.stage_registry = StageEntityRegistry(self)
self._conn._session_parameters = {
"ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS": False,
"_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING": True,
"_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING": True,
}
self._options = options or {}
self._conn._session_parameters = session_params.update(
{
"ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS": False,
"_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING": True,
"_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING": True,
}
)
self._active_account = self._options.get(
"account", snowflake.snowpark.mock._constants.CURRENT_ACCOUNT
)
Expand Down
42 changes: 28 additions & 14 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import re
import sys
import tempfile
import threading
import warnings
from array import array
from functools import reduce
Expand Down Expand Up @@ -93,6 +92,8 @@
PythonObjJSONEncoder,
TempObjectType,
calculate_checksum,
create_rlock,
create_thread_local,
deprecated,
escape_quotes,
experimental,
Expand Down Expand Up @@ -234,6 +235,9 @@
_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = (
"PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND"
)
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION = (
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION"
)
# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
# in Snowflake. This is the limit where we start seeing compilation errors.
DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000
Expand Down Expand Up @@ -520,13 +524,6 @@ def __init__(
if len(_active_sessions) >= 1 and is_in_stored_procedure():
raise SnowparkClientExceptionMessages.DONT_CREATE_SESSION_IN_SP()
self._conn = conn
self._thread_store = threading.local()
self._lock = threading.RLock()

# this lock is used to protect _packages. We use introduce a new lock because add_packages
# launches a query to snowflake to get all version of packages available in snowflake. This
# query can be slow and prevent other threads from moving on waiting for _lock.
self._package_lock = threading.RLock()
self._query_tag = None
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
self._packages: Dict[str, str] = {}
Expand Down Expand Up @@ -618,6 +615,17 @@ def __init__(
DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND,
),
)

self._thread_store = create_thread_local(
self._conn._thread_safe_session_enabled
)
self._lock = create_rlock(self._conn._thread_safe_session_enabled)

# this lock is used to protect _packages. We introduce a new lock because add_packages
# launches a query to snowflake to get all versions of packages available in snowflake. This
# query can be slow and prevent other threads from moving on waiting for _lock.
self._package_lock = create_rlock(self._conn._thread_safe_session_enabled)

self._custom_package_usage_config: Dict = {}
self._conf = self.RuntimeConfig(self, options or {})
self._runtime_version_from_requirement: str = None
Expand Down Expand Up @@ -778,7 +786,9 @@ def custom_package_usage_config(self) -> Dict:

@sql_simplifier_enabled.setter
def sql_simplifier_enabled(self, value: bool) -> None:
warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled")
warn_session_config_update_in_multithreaded_mode(
"sql_simplifier_enabled", self._conn._thread_safe_session_enabled
)

with self._lock:
self._conn._telemetry_client.send_sql_simplifier_telemetry(
Expand All @@ -795,7 +805,9 @@ def sql_simplifier_enabled(self, value: bool) -> None:
@cte_optimization_enabled.setter
@experimental_parameter(version="1.15.0")
def cte_optimization_enabled(self, value: bool) -> None:
warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled")
warn_session_config_update_in_multithreaded_mode(
"cte_optimization_enabled", self._conn._thread_safe_session_enabled
)

with self._lock:
if value:
Expand All @@ -809,7 +821,8 @@ def cte_optimization_enabled(self, value: bool) -> None:
def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None:
"""Set the value for eliminate_numeric_sql_value_cast_enabled"""
warn_session_config_update_in_multithreaded_mode(
"eliminate_numeric_sql_value_cast_enabled"
"eliminate_numeric_sql_value_cast_enabled",
self._conn._thread_safe_session_enabled,
)

if value in [True, False]:
Expand All @@ -828,7 +841,7 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None:
def auto_clean_up_temp_table_enabled(self, value: bool) -> None:
"""Set the value for auto_clean_up_temp_table_enabled"""
warn_session_config_update_in_multithreaded_mode(
"auto_clean_up_temp_table_enabled"
"auto_clean_up_temp_table_enabled", self._conn._thread_safe_session_enabled
)

if value in [True, False]:
Expand All @@ -851,7 +864,7 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
overall performance.
"""
warn_session_config_update_in_multithreaded_mode(
"large_query_breakdown_enabled"
"large_query_breakdown_enabled", self._conn._thread_safe_session_enabled
)

if value in [True, False]:
Expand All @@ -869,7 +882,8 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None:
"""Set the lower and upper bounds for the complexity score used in large query breakdown optimization."""
warn_session_config_update_in_multithreaded_mode(
"large_query_breakdown_complexity_bounds"
"large_query_breakdown_complexity_bounds",
self._conn._thread_safe_session_enabled,
)

if len(value) != 2:
Expand Down
Loading

0 comments on commit 367dd23

Please sign in to comment.