SNOW-1720835 param protect thread-safe client side changes #2401

Merged
merged 48 commits into from
Oct 17, 2024
Changes from 42 commits
fb4ecf6
add locks
sfc-gh-aalam Sep 11, 2024
f720701
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Sep 12, 2024
eca13dc
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Sep 17, 2024
96949be
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Sep 18, 2024
5f140ab
SNOW-1418523 make analyzer server connection thread safe (#2282)
sfc-gh-aalam Sep 25, 2024
0624824
SNOW-1418523: concurrent file operations (#2288)
sfc-gh-aalam Sep 25, 2024
42d6e19
SNOW-1418523: make udf and sproc registration thread safe (#2289)
sfc-gh-aalam Sep 25, 2024
801ad6e
merge with main
sfc-gh-aalam Oct 2, 2024
c7fa3ae
Merge branch 'main' into aalam-SNOW-1418523-make-internal-session-var…
sfc-gh-aalam Oct 3, 2024
5672a1d
SNOW-1663726 make session config updates thread safe (#2302)
sfc-gh-aalam Oct 4, 2024
bd0528d
SNOW-1663726 make temp table cleaner thread safe (#2309)
sfc-gh-aalam Oct 4, 2024
39a07d4
SNOW-1642189: collect telemetry about concurrency usage (#2316)
sfc-gh-aalam Oct 4, 2024
4d4e257
SNOW-1546090 add merge gate for future thread safe updates (#2323)
sfc-gh-aalam Oct 4, 2024
5ecb0b4
param protect lock changes
sfc-gh-aalam Oct 4, 2024
66374ee
add plan-builder that was accidentally removed
sfc-gh-aalam Oct 4, 2024
3bec695
Add dummythreadlocal and protect server_connection
sfc-gh-aalam Oct 4, 2024
d41138a
add todo
sfc-gh-aalam Oct 4, 2024
42ca571
undo fixture
sfc-gh-aalam Oct 4, 2024
ee3ce32
fix init
sfc-gh-aalam Oct 4, 2024
816b1d9
fix param read
sfc-gh-aalam Oct 4, 2024
f1ab835
Merge branch 'aalam-SNOW-1418523-make-internal-session-variables-thre…
sfc-gh-aalam Oct 4, 2024
75bb86c
fix test
sfc-gh-aalam Oct 5, 2024
5e447fd
fix test
sfc-gh-aalam Oct 5, 2024
7e7b47a
enable thread-safe session for tests
sfc-gh-aalam Oct 5, 2024
a0b259d
fix tests
sfc-gh-aalam Oct 6, 2024
97de868
fix option name
sfc-gh-aalam Oct 7, 2024
5f9200c
merge with main
sfc-gh-aalam Oct 14, 2024
312b5ad
add cursor created test
sfc-gh-aalam Oct 15, 2024
e0bcfb2
remove commented lines
sfc-gh-aalam Oct 15, 2024
4b83442
fix tests
sfc-gh-aalam Oct 15, 2024
d1592ba
address comments
sfc-gh-aalam Oct 16, 2024
1d63131
address comments
sfc-gh-aalam Oct 16, 2024
1bde81e
minor updates
sfc-gh-aalam Oct 16, 2024
f494010
fix tests
sfc-gh-aalam Oct 16, 2024
1de5526
fix tests
sfc-gh-aalam Oct 16, 2024
8ed4de4
fix tests
sfc-gh-aalam Oct 16, 2024
a155c03
Merge branch 'main' into SNOW-1720835-param-protect-client-side-changes
sfc-gh-aalam Oct 16, 2024
adbc395
disable multithreading mode by default
sfc-gh-aalam Oct 16, 2024
56dd2ce
comment upate
sfc-gh-aalam Oct 16, 2024
27a667d
fixture
sfc-gh-aalam Oct 16, 2024
eef3568
fix coverage
sfc-gh-aalam Oct 16, 2024
d22575c
address comments
sfc-gh-aalam Oct 16, 2024
46b3ca8
address comments
sfc-gh-aalam Oct 16, 2024
55844ea
address comments
sfc-gh-aalam Oct 16, 2024
4e481b2
Merge branch 'main' into SNOW-1720835-param-protect-client-side-changes
sfc-gh-aalam Oct 16, 2024
56169c5
fix integ
sfc-gh-aalam Oct 17, 2024
d491367
Merge branch 'SNOW-1720835-param-protect-client-side-changes' of gith…
sfc-gh-aalam Oct 17, 2024
7c6ab40
fix mock
sfc-gh-aalam Oct 17, 2024
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
### Snowpark Python API Updates

- Added support for 'Service' domain to `session.lineage.trace` API.
- Updated `Session` class to be thread-safe. This allows concurrent dataframe transformations, dataframe actions, UDF and stored procedure registration, and concurrent file uploads.

#### New Features

Expand Down
11 changes: 9 additions & 2 deletions src/snowflake/snowpark/_internal/server_connection.py
Expand Up @@ -49,6 +49,8 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryClient
from snowflake.snowpark._internal.utils import (
create_rlock,
create_thread_local,
escape_quotes,
get_application_name,
get_version,
Expand Down Expand Up @@ -155,8 +157,6 @@ def __init__(
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
) -> None:
self._lock = threading.RLock()
self._thread_store = threading.local()
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
Expand All @@ -171,6 +171,13 @@ def __init__(
except TypeError:
pass

# thread safe param protection
self._thread_safe_session_enabled = self._get_client_side_session_parameter(
Collaborator commented: How come we have two parameter controls here? What are those two ways?

"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._thread_store = create_thread_local(self._thread_safe_session_enabled)

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
self._telemetry_client = TelemetryClient(self._conn)
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
Expand Up @@ -2,12 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import logging
import threading
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Dict

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
from snowflake.snowpark._internal.utils import create_rlock

if TYPE_CHECKING:
from snowflake.snowpark.session import Session # pragma: no cover
Expand All @@ -33,7 +33,7 @@ def __init__(self, session: "Session") -> None:
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# Lock to protect the ref_count_map
self.lock = threading.RLock()
self.lock = create_rlock(session._conn._thread_safe_session_enabled)

def add(self, table: SnowflakeTable) -> None:
with self.lock:
Expand Down
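The cleaner's `ref_count_map` is only ever touched while holding its lock. A minimal sketch of that pattern follows, assuming a hypothetical `RefCounter` class (not part of Snowpark) and a plain `threading.RLock`; it only illustrates why the increments stay consistent when several threads share one map.

```python
import threading
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor


class RefCounter:
    """Hypothetical stand-in for a lock-protected reference-count map."""

    def __init__(self) -> None:
        self.ref_count_map = defaultdict(int)
        # Re-entrant lock guarding every read-modify-write on the map.
        self.lock = threading.RLock()

    def add(self, name: str) -> None:
        with self.lock:
            self.ref_count_map[name] += 1

    def release(self, name: str) -> int:
        with self.lock:
            self.ref_count_map[name] -= 1
            return self.ref_count_map[name]


counter = RefCounter()
# Leaving the executor's with-block waits for all submitted tasks.
with ThreadPoolExecutor(max_workers=8) as pool:
    for _ in range(1000):
        pool.submit(counter.add, "temp_table_1")
```

With the lock swapped for a no-op object, the same code still runs, but the count is only guaranteed to be correct when a single thread uses the map.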
48 changes: 47 additions & 1 deletion src/snowflake/snowpark/_internal/utils.py
Expand Up @@ -298,7 +298,12 @@ def normalize_path(path: str, is_local: bool) -> str:
return f"'{path}'"


def warn_session_config_update_in_multithreaded_mode(config) -> None:
def warn_session_config_update_in_multithreaded_mode(
config: str, thread_safe_mode_enabled: bool
) -> None:
if not thread_safe_mode_enabled:
return

if threading.active_count() > 1:
logger.warning(
"You might have more than one threads sharing the Session object trying to update "
Expand Down Expand Up @@ -675,6 +680,47 @@ def warning(self, text: str) -> None:
self.count += 1


# TODO: SNOW-1720855: Remove DummyRLock and DummyThreadLocal after the rollout
class DummyRLock:
"""This is a dummy lock that is used in place of threading.RLock when multithreading is
disabled."""

def __enter__(self):
pass

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def acquire(self, *args, **kwargs):
pass # pragma: no cover

def release(self, *args, **kwargs):
pass # pragma: no cover


class DummyThreadLocal:
"""This is a dummy thread local class that is used in place of threading.local when
multithreading is disabled."""

pass


def create_thread_local(
thread_safe_session_enabled: bool,
) -> Union[threading.local, DummyThreadLocal]:
if thread_safe_session_enabled:
return threading.local()
return DummyThreadLocal()


def create_rlock(
thread_safe_session_enabled: bool,
) -> Union[threading.RLock, DummyRLock]:
if thread_safe_session_enabled:
return threading.RLock()
return DummyRLock()


warning_dict: Dict[str, WarningHelper] = {}


Expand Down
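The factory functions above return either a real primitive or a do-nothing substitute, so call sites can keep writing `with lock:` regardless of the flag. A small sketch of that idea, using hypothetical names (`NoOpLock`, `make_lock`, `guarded_increment`) rather than the helpers in this PR:

```python
import threading


class NoOpLock:
    """Hypothetical no-op lock exposing the same surface as a re-entrant lock."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False  # never swallow exceptions raised inside the block

    def acquire(self, *args, **kwargs):
        return True

    def release(self, *args, **kwargs):
        return None


def make_lock(thread_safe: bool):
    """Return a real RLock when thread safety is requested, else a no-op."""
    return threading.RLock() if thread_safe else NoOpLock()


def guarded_increment(state: dict, lock) -> None:
    # Nested acquisition works for both: RLock is re-entrant, NoOpLock does nothing.
    with lock:
        with lock:
            state["n"] = state.get("n", 0) + 1
```

The trade-off is that the no-op variant offers no protection at all, so it is only appropriate when a single thread uses the object.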
23 changes: 15 additions & 8 deletions src/snowflake/snowpark/mock/_connection.py
Expand Up @@ -6,7 +6,6 @@
import functools
import json
import logging
import threading
import uuid
from copy import copy
from decimal import Decimal
Expand All @@ -30,6 +29,7 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.server_connection import DEFAULT_STRING_SIZE
from snowflake.snowpark._internal.utils import (
create_rlock,
is_in_stored_procedure,
result_set_to_rows,
)
Expand Down Expand Up @@ -281,19 +281,26 @@ def read_table_if_exists(
def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
self._conn = MockedSnowflakeConnection()
self._cursor = Mock()
self._lock = threading.RLock()
self._options = options or {}
session_params = self._options.get("session_parameters", {})
# thread safe param protection
self._thread_safe_session_enabled = session_params.get(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._lower_case_parameters = {}
self.remove_query_listener = Mock()
self.add_query_listener = Mock()
self._telemetry_client = Mock()
self.entity_registry = MockServerConnection.TabularEntityRegistry(self)
self.stage_registry = StageEntityRegistry(self)
self._conn._session_parameters = {
"ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS": False,
"_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING": True,
"_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING": True,
}
self._options = options or {}
# dict.update() returns None, so build the merged dict explicitly
self._conn._session_parameters = {
**session_params,
"ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS": False,
"_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING": True,
"_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING": True,
}
self._active_account = self._options.get(
"account", snowflake.snowpark.mock._constants.CURRENT_ACCOUNT
)
Expand Down
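This hunk combines caller-supplied session parameters with a fixed set of mock defaults. When merging dictionaries in Python, it is worth remembering that `dict.update` mutates in place and returns `None`. The sketch below uses made-up parameter names (`USER_FLAG`, `DEFAULT_A`, `DEFAULT_B`) purely for illustration:

```python
# Hypothetical parameter names, for illustration only.
session_params = {"USER_FLAG": True}
defaults = {"DEFAULT_A": False, "DEFAULT_B": True}

# update() mutates its receiver and returns None, so its result is not a dict.
result_of_update = dict(session_params).update(defaults)

# Unpacking builds a new dict; later entries win on key collisions.
merged = {**session_params, **defaults}
```

Here `merged` holds all three keys, while `result_of_update` is `None`.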
42 changes: 28 additions & 14 deletions src/snowflake/snowpark/session.py
Expand Up @@ -13,7 +13,6 @@
import re
import sys
import tempfile
import threading
import warnings
from array import array
from functools import reduce
Expand Down Expand Up @@ -93,6 +92,8 @@
PythonObjJSONEncoder,
TempObjectType,
calculate_checksum,
create_rlock,
create_thread_local,
deprecated,
escape_quotes,
experimental,
Expand Down Expand Up @@ -234,6 +235,9 @@
_PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = (
"PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND"
)
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION = (
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION"
)
# The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
# in Snowflake. This is the limit where we start seeing compilation errors.
DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000
Expand Down Expand Up @@ -520,13 +524,6 @@ def __init__(
if len(_active_sessions) >= 1 and is_in_stored_procedure():
raise SnowparkClientExceptionMessages.DONT_CREATE_SESSION_IN_SP()
self._conn = conn
self._thread_store = threading.local()
self._lock = threading.RLock()

# this lock is used to protect _packages. We introduce a new lock because add_packages
# launches a query to snowflake to get all versions of packages available in snowflake. This
# query can be slow and prevent other threads from moving on waiting for _lock.
self._package_lock = threading.RLock()
self._query_tag = None
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
self._packages: Dict[str, str] = {}
Expand Down Expand Up @@ -618,6 +615,17 @@ def __init__(
DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND,
),
)

self._thread_store = create_thread_local(
self._conn._thread_safe_session_enabled
)
self._lock = create_rlock(self._conn._thread_safe_session_enabled)

# this lock is used to protect _packages. We introduce a new lock because add_packages
# launches a query to snowflake to get all versions of packages available in snowflake. This
# query can be slow and prevent other threads from moving on waiting for _lock.
self._package_lock = create_rlock(self._conn._thread_safe_session_enabled)

self._custom_package_usage_config: Dict = {}
self._conf = self.RuntimeConfig(self, options or {})
self._runtime_version_from_requirement: str = None
Expand Down Expand Up @@ -778,7 +786,9 @@ def custom_package_usage_config(self) -> Dict:

@sql_simplifier_enabled.setter
def sql_simplifier_enabled(self, value: bool) -> None:
warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled")
warn_session_config_update_in_multithreaded_mode(
"sql_simplifier_enabled", self._conn._thread_safe_session_enabled
)

with self._lock:
self._conn._telemetry_client.send_sql_simplifier_telemetry(
Expand All @@ -795,7 +805,9 @@ def sql_simplifier_enabled(self, value: bool) -> None:
@cte_optimization_enabled.setter
@experimental_parameter(version="1.15.0")
def cte_optimization_enabled(self, value: bool) -> None:
warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled")
warn_session_config_update_in_multithreaded_mode(
"cte_optimization_enabled", self._conn._thread_safe_session_enabled
)

with self._lock:
if value:
Expand All @@ -809,7 +821,8 @@ def cte_optimization_enabled(self, value: bool) -> None:
def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None:
"""Set the value for eliminate_numeric_sql_value_cast_enabled"""
warn_session_config_update_in_multithreaded_mode(
"eliminate_numeric_sql_value_cast_enabled"
"eliminate_numeric_sql_value_cast_enabled",
self._conn._thread_safe_session_enabled,
)

if value in [True, False]:
Expand All @@ -828,7 +841,7 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None:
def auto_clean_up_temp_table_enabled(self, value: bool) -> None:
"""Set the value for auto_clean_up_temp_table_enabled"""
warn_session_config_update_in_multithreaded_mode(
"auto_clean_up_temp_table_enabled"
"auto_clean_up_temp_table_enabled", self._conn._thread_safe_session_enabled
)

if value in [True, False]:
Expand All @@ -851,7 +864,7 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
overall performance.
"""
warn_session_config_update_in_multithreaded_mode(
"large_query_breakdown_enabled"
"large_query_breakdown_enabled", self._conn._thread_safe_session_enabled
)

if value in [True, False]:
Expand All @@ -869,7 +882,8 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None:
"""Set the lower and upper bounds for the complexity score used in large query breakdown optimization."""
warn_session_config_update_in_multithreaded_mode(
"large_query_breakdown_complexity_bounds"
"large_query_breakdown_complexity_bounds",
self._conn._thread_safe_session_enabled,
)

if len(value) != 2:
Expand Down
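Each setter above passes the thread-safe flag to the warning helper, which returns early when the flag is off. A rough sketch of that gating, assuming a hypothetical `warn_if_shared` function and logger name (the real helper lives in `_internal/utils.py` and returns nothing):

```python
import logging
import threading

logger = logging.getLogger("config_warning_sketch")  # hypothetical logger name


def warn_if_shared(config: str, thread_safe_mode_enabled: bool) -> bool:
    """Warn only when the flag is on and more than one thread is alive."""
    if not thread_safe_mode_enabled:
        return False
    if threading.active_count() > 1:
        logger.warning("More than one thread may be updating %s", config)
        return True
    return False


# Keep a second thread alive so active_count() is at least 2.
stop = threading.Event()
worker = threading.Thread(target=stop.wait)
worker.start()
warned = warn_if_shared("sql_simplifier_enabled", True)
stop.set()
worker.join()
```

Note that `threading.active_count()` counts every live thread in the process, not only those sharing the session, so the warning is a heuristic.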
3 changes: 3 additions & 0 deletions tests/integ/conftest.py
Expand Up @@ -209,6 +209,9 @@ def session(
session = (
Session.builder.configs(db_parameters)
.config("local_testing", local_testing_mode)
.config(
"session_parameters", {"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": False}
)
.create()
)
session.sql_simplifier_enabled = sql_simplifier_enabled
Expand Down
49 changes: 48 additions & 1 deletion tests/integ/test_multithreading.py
Expand Up @@ -13,7 +13,10 @@

import pytest

from snowflake.snowpark.session import Session
from snowflake.snowpark.session import (
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION,
Session,
)
from snowflake.snowpark.types import IntegerType
from tests.integ.test_temp_table_cleanup import wait_for_drop_table_sql_done

Expand All @@ -32,6 +35,21 @@
from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils


@pytest.fixture(scope="module")
Collaborator commented: Let's use a function scope here to avoid messing with other tests.
Contributor Author replied: Renamed to threadsafe_session. This would not mess with others, right?
def session(
Collaborator commented: Let's call this mutlithread_session instead of session.
db_parameters, sql_simplifier_enabled, cte_optimization_enabled, local_testing_mode
):
new_db_parameters = db_parameters.copy()
new_db_parameters["local_testing"] = local_testing_mode
new_db_parameters["session_parameters"] = {
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True
}
with Session.builder.configs(new_db_parameters).create() as session:
session._sql_simplifier_enabled = sql_simplifier_enabled
session._cte_optimization_enabled = cte_optimization_enabled
yield session


def test_concurrent_select_queries(session):
def run_select(session_, thread_id):
df = session_.sql(f"SELECT {thread_id} as A")
Expand Down Expand Up @@ -580,3 +598,32 @@ def change_config_value(session_):
f"You might have more than one threads sharing the Session object trying to update {config}"
in caplog.text
)


@pytest.mark.parametrize("is_enabled", [True, False])
def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode):
if is_enabled and local_testing_mode:
pytest.skip("Multithreading is enabled by default in local testing mode")

num_workers = 5 if is_enabled else 1
new_db_parameters = db_parameters.copy()
new_db_parameters["session_parameters"] = {
_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: is_enabled
}

with Session.builder.configs(new_db_parameters).create() as new_session:

def run_query(session_, thread_id):
assert session_.sql(f"SELECT {thread_id} as A").collect()[0][0] == thread_id

with patch.object(
new_session._conn._telemetry_client, "send_cursor_created_telemetry"
) as mock_telemetry:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
for i in range(10):
executor.submit(run_query, new_session, i)

# when multithreading is enabled, each worker will create a cursor
# otherwise, we will use the same cursor created by the main thread
# thus creating 0 new cursors.
assert mock_telemetry.call_count == (num_workers if is_enabled else 0)
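The comment in the test above says that each worker creates its own cursor when multithreading is enabled. The general mechanism behind "one resource per thread" can be sketched with `threading.local`. Everything here (`PerThreadResource`, the `object()` placeholder, the barrier) is a hypothetical illustration and not how Snowpark manages cursors:

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class PerThreadResource:
    """Hypothetical holder that lazily creates one resource per thread."""

    def __init__(self) -> None:
        self._store = threading.local()
        self._lock = threading.Lock()
        self.created = 0  # how many resources were created in total

    def get(self):
        if not hasattr(self._store, "resource"):
            with self._lock:
                self.created += 1
            self._store.resource = object()  # placeholder for a real resource
        return self._store.resource


num_workers = 5
holder = PerThreadResource()
# The barrier forces all workers to be busy at once, so the pool
# really starts num_workers distinct threads.
barrier = threading.Barrier(num_workers)


def task(_):
    first = holder.get()
    second = holder.get()  # same thread, so the resource is reused
    barrier.wait(timeout=10)
    return first is second


with ThreadPoolExecutor(max_workers=num_workers) as executor:
    reused = list(executor.map(task, range(num_workers)))
```

Without the barrier, a pool may reuse an idle worker for several tasks, so the number of resources created could be lower than `num_workers`.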
5 changes: 4 additions & 1 deletion tests/mock/conftest.py
Expand Up @@ -11,7 +11,10 @@

@pytest.fixture(scope="function")
def session():
with Session(MockServerConnection()) as s:
options = {
"session_parameters": {"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": True}
}
with Session(MockServerConnection(options)) as s:
yield s


Expand Down