diff --git a/CHANGELOG.md b/CHANGELOG.md
index 113617e20df..b0e488d718a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,6 @@
 ### Snowpark Python API Updates
 
 - Added support for 'Service' domain to `session.lineage.trace` API.
-- Updated `Session` class to be thread-safe. This allows concurrent dataframe transformations, dataframe actions, UDF and store procedure registration, and concurrent file uploads.
 - Added support for `copy_grants` parameter when registering UDxF and stored procedures.
 
 #### New Features
diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py
index 1b57697ad6d..ae7c635a88d 100644
--- a/src/snowflake/snowpark/_internal/server_connection.py
+++ b/src/snowflake/snowpark/_internal/server_connection.py
@@ -49,6 +49,8 @@
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import TelemetryClient
 from snowflake.snowpark._internal.utils import (
+    create_rlock,
+    create_thread_local,
     escape_quotes,
     get_application_name,
     get_version,
@@ -155,8 +157,6 @@ def __init__(
         self,
         options: Dict[str, Union[int, str]],
         conn: Optional[SnowflakeConnection] = None,
     ) -> None:
-        self._lock = threading.RLock()
-        self._thread_store = threading.local()
         self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
         self._add_application_parameters()
         self._conn = conn if conn else connect(**self._lower_case_parameters)
@@ -171,6 +171,13 @@
         except TypeError:
             pass
 
+        # thread safe param protection
+        self._thread_safe_session_enabled = self._get_client_side_session_parameter(
+            "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
+        )
+        self._lock = create_rlock(self._thread_safe_session_enabled)
+        self._thread_store = create_thread_local(self._thread_safe_session_enabled)
+
         if "password" in self._lower_case_parameters:
             self._lower_case_parameters["password"] = None
         self._telemetry_client = TelemetryClient(self._conn)
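For orientation before the mechanical changes below: the hunk above reads the rollout flag from the session parameters returned at login and only then builds the synchronization primitives, which is why `_lock`/`_thread_store` creation moves to after `connect()`. A minimal standalone sketch of that gating pattern (the free function stands in for `ServerConnection._get_client_side_session_parameter`; nothing here is verbatim library code):

```python
import threading
from typing import Any, Dict


def get_client_side_session_parameter(params: Dict[str, Any], name: str, default: Any) -> Any:
    # Stand-in for ServerConnection._get_client_side_session_parameter:
    # fall back to the default when the server did not send the parameter.
    return params.get(name, default)


session_parameters: Dict[str, Any] = {}  # as delivered by the server after login
enabled = get_client_side_session_parameter(
    session_parameters, "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
# With the flag off (the default during rollout), no real lock is created, so
# existing single-threaded workloads pay no synchronization overhead.
lock = threading.RLock() if enabled else None  # the diff instead uses create_rlock() from utils.py
```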
diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
index 69453d48596..d5b8387a268 100644
--- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
+++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
@@ -2,12 +2,12 @@
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
 import logging
-import threading
 import weakref
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict
 
 from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
+from snowflake.snowpark._internal.utils import create_rlock
 
 if TYPE_CHECKING:
     from snowflake.snowpark.session import Session  # pragma: no cover
@@ -33,7 +33,7 @@ def __init__(self, session: "Session") -> None:
         # this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
         self.ref_count_map: Dict[str, int] = defaultdict(int)
         # Lock to protect the ref_count_map
-        self.lock = threading.RLock()
+        self.lock = create_rlock(session._conn._thread_safe_session_enabled)
 
     def add(self, table: SnowflakeTable) -> None:
         with self.lock:
diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py
index 8783faa39d6..6f45605b3ce 100644
--- a/src/snowflake/snowpark/_internal/utils.py
+++ b/src/snowflake/snowpark/_internal/utils.py
@@ -298,7 +298,12 @@ def normalize_path(path: str, is_local: bool) -> str:
     return f"'{path}'"
 
 
-def warn_session_config_update_in_multithreaded_mode(config) -> None:
+def warn_session_config_update_in_multithreaded_mode(
+    config: str, thread_safe_mode_enabled: bool
+) -> None:
+    if not thread_safe_mode_enabled:
+        return
+
     if threading.active_count() > 1:
         logger.warning(
             "You might have more than one threads sharing the Session object trying to update "
@@ -675,6 +680,47 @@ def warning(self, text: str) -> None:
         self.count += 1
 
 
+# TODO: SNOW-1720855: Remove DummyRLock and DummyThreadLocal after the rollout
+class DummyRLock:
+    """This is a dummy lock that is used in place of threading.RLock when multithreading is
+    disabled."""
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    def acquire(self, *args, **kwargs):
+        pass  # pragma: no cover
+
+    def release(self, *args, **kwargs):
+        pass  # pragma: no cover
+
+
+class DummyThreadLocal:
+    """This is a dummy thread-local class that is used in place of threading.local when
+    multithreading is disabled."""
+
+    pass
+
+
+def create_thread_local(
+    thread_safe_session_enabled: bool,
+) -> Union[threading.local, DummyThreadLocal]:
+    if thread_safe_session_enabled:
+        return threading.local()
+    return DummyThreadLocal()
+
+
+def create_rlock(
+    thread_safe_session_enabled: bool,
+) -> Union[threading.RLock, DummyRLock]:
+    if thread_safe_session_enabled:
+        return threading.RLock()
+    return DummyRLock()
+
+
 warning_dict: Dict[str, WarningHelper] = {}
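A short usage sketch for the factories just added (standalone; only assumes the module shown above): both return values support the context-manager protocol, so call sites like `with self._lock:` stay identical whether the rollout flag is on or off.

```python
from snowflake.snowpark._internal.utils import DummyRLock, create_rlock

lock = create_rlock(False)  # flag off -> DummyRLock, every operation is a no-op
assert isinstance(lock, DummyRLock)
with lock:  # works: DummyRLock implements __enter__/__exit__
    pass

lock = create_rlock(True)  # flag on -> a real threading.RLock
with lock:
    pass
```

One behavioral difference worth noting: a real `RLock.acquire()` returns `True` while the dummy's returns `None`; the diff sidesteps this by only using the locks as context managers, which is why the explicit `acquire`/`release` paths are marked `pragma: no cover`.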
diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py
index b384931cb89..37a2d77a446 100644
--- a/src/snowflake/snowpark/mock/_connection.py
+++ b/src/snowflake/snowpark/mock/_connection.py
@@ -6,7 +6,6 @@
 import functools
 import json
 import logging
-import threading
 import uuid
 from copy import copy
 from decimal import Decimal
@@ -30,6 +29,7 @@
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.server_connection import DEFAULT_STRING_SIZE
 from snowflake.snowpark._internal.utils import (
+    create_rlock,
     is_in_stored_procedure,
     result_set_to_rows,
 )
@@ -281,19 +281,26 @@ def read_table_if_exists(
 
     def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
         self._conn = MockedSnowflakeConnection()
         self._cursor = Mock()
-        self._lock = threading.RLock()
+        self._options = options or {}
+        session_params = self._options.get("session_parameters", {})
+        # thread safe param protection
+        self._thread_safe_session_enabled = session_params.get(
+            "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
+        )
+        self._lock = create_rlock(self._thread_safe_session_enabled)
        self._lower_case_parameters = {}
         self.remove_query_listener = Mock()
         self.add_query_listener = Mock()
         self._telemetry_client = Mock()
         self.entity_registry = MockServerConnection.TabularEntityRegistry(self)
         self.stage_registry = StageEntityRegistry(self)
-        self._conn._session_parameters = {
-            "ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS": False,
-            "_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING": True,
-            "_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING": True,
-        }
-        self._options = options or {}
+        # dict.update() returns None, so merge first and assign afterwards
+        session_params.update(
+            {
+                "ENABLE_ASYNC_QUERY_IN_PYTHON_STORED_PROCS": False,
+                "_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING": True,
+                "_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING": True,
+            }
+        )
+        self._conn._session_parameters = session_params
         self._active_account = self._options.get(
             "account", snowflake.snowpark.mock._constants.CURRENT_ACCOUNT
         )
diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py
index 42ead058478..81ffb9bc82d 100644
--- a/src/snowflake/snowpark/session.py
+++ b/src/snowflake/snowpark/session.py
@@ -13,7 +13,6 @@
 import re
 import sys
 import tempfile
-import threading
 import warnings
 from array import array
 from functools import reduce
@@ -93,6 +92,8 @@
     PythonObjJSONEncoder,
     TempObjectType,
     calculate_checksum,
+    create_rlock,
+    create_thread_local,
     deprecated,
     escape_quotes,
     experimental,
@@ -234,6 +235,9 @@
 _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = (
     "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND"
 )
+_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION = (
+    "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION"
+)
 # The complexity score lower bound is set to match COMPILATION_MEMORY_LIMIT
 # in Snowflake. This is the limit where we start seeing compilation errors.
 DEFAULT_COMPLEXITY_SCORE_LOWER_BOUND = 10_000_000
@@ -520,13 +524,6 @@ def __init__(
         if len(_active_sessions) >= 1 and is_in_stored_procedure():
             raise SnowparkClientExceptionMessages.DONT_CREATE_SESSION_IN_SP()
         self._conn = conn
-        self._thread_store = threading.local()
-        self._lock = threading.RLock()
-
-        # this lock is used to protect _packages. We use introduce a new lock because add_packages
-        # launches a query to snowflake to get all version of packages available in snowflake. This
-        # query can be slow and prevent other threads from moving on waiting for _lock.
-        self._package_lock = threading.RLock()
         self._query_tag = None
         self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
         self._packages: Dict[str, str] = {}
@@ -618,6 +615,17 @@
                 DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND,
             ),
         )
+
+        self._thread_store = create_thread_local(
+            self._conn._thread_safe_session_enabled
+        )
+        self._lock = create_rlock(self._conn._thread_safe_session_enabled)
+
+        # this lock is used to protect _packages. We introduce a new lock because add_packages
+        # launches a query to snowflake to get all versions of packages available in snowflake. This
+        # query can be slow and prevent other threads from moving on waiting for _lock.
+        self._package_lock = create_rlock(self._conn._thread_safe_session_enabled)
+
         self._custom_package_usage_config: Dict = {}
         self._conf = self.RuntimeConfig(self, options or {})
         self._runtime_version_from_requirement: str = None
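End to end, opting a session into the new code path looks roughly like the sketch below. This is a hedged illustration, not library documentation: the `connection_parameters` values are placeholders, and the `session_parameters` mechanism mirrors what the integration tests later in this diff do via `Session.builder.configs`.

```python
from snowflake.snowpark.session import Session

connection_parameters = {
    "account": "<account>",  # placeholders -- substitute real credentials
    "user": "<user>",
    "password": "<password>",
    # Ask the server to turn the rollout flag on for this session only:
    "session_parameters": {"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": True},
}

session = Session.builder.configs(connection_parameters).create()
# ServerConnection.__init__ now sees the parameter and builds real
# RLock/thread-local objects instead of the dummy stand-ins.
```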
@@ -778,7 +786,9 @@ def custom_package_usage_config(self) -> Dict:
 
     @sql_simplifier_enabled.setter
     def sql_simplifier_enabled(self, value: bool) -> None:
-        warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled")
+        warn_session_config_update_in_multithreaded_mode(
+            "sql_simplifier_enabled", self._conn._thread_safe_session_enabled
+        )
 
         with self._lock:
             self._conn._telemetry_client.send_sql_simplifier_telemetry(
@@ -795,7 +805,9 @@ def sql_simplifier_enabled(self, value: bool) -> None:
     @cte_optimization_enabled.setter
     @experimental_parameter(version="1.15.0")
     def cte_optimization_enabled(self, value: bool) -> None:
-        warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled")
+        warn_session_config_update_in_multithreaded_mode(
+            "cte_optimization_enabled", self._conn._thread_safe_session_enabled
+        )
 
         with self._lock:
             if value:
@@ -809,7 +821,8 @@ def cte_optimization_enabled(self, value: bool) -> None:
     def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None:
         """Set the value for eliminate_numeric_sql_value_cast_enabled"""
         warn_session_config_update_in_multithreaded_mode(
-            "eliminate_numeric_sql_value_cast_enabled"
+            "eliminate_numeric_sql_value_cast_enabled",
+            self._conn._thread_safe_session_enabled,
         )
 
         if value in [True, False]:
@@ -828,7 +841,7 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None:
     def auto_clean_up_temp_table_enabled(self, value: bool) -> None:
         """Set the value for auto_clean_up_temp_table_enabled"""
         warn_session_config_update_in_multithreaded_mode(
-            "auto_clean_up_temp_table_enabled"
+            "auto_clean_up_temp_table_enabled", self._conn._thread_safe_session_enabled
         )
 
         if value in [True, False]:
@@ -851,7 +864,7 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
         overall performance.
         """
         warn_session_config_update_in_multithreaded_mode(
-            "large_query_breakdown_enabled"
+            "large_query_breakdown_enabled", self._conn._thread_safe_session_enabled
         )
 
         if value in [True, False]:
@@ -869,7 +882,8 @@ def large_query_breakdown_enabled(self, value: bool) -> None:
     def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None:
         """Set the lower and upper bounds for the complexity score used in large query breakdown optimization."""
         warn_session_config_update_in_multithreaded_mode(
-            "large_query_breakdown_complexity_bounds"
+            "large_query_breakdown_complexity_bounds",
+            self._conn._thread_safe_session_enabled,
         )
 
         if len(value) != 2:
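The effect of threading `self._conn._thread_safe_session_enabled` through every setter above: the multithreading warning can now only fire when the feature flag is on. A minimal sketch of the scenario the warning targets (illustrative only; assumes an existing `session` created with the flag enabled, as in the earlier example):

```python
import logging
from concurrent.futures import ThreadPoolExecutor

logging.basicConfig(level=logging.WARNING)


def toggle_simplifier(session_, value):
    # Each assignment runs warn_session_config_update_in_multithreaded_mode();
    # with more than one live thread it logs the "more than one threads
    # sharing the Session object" warning, because the setting is shared
    # session-wide rather than per thread.
    session_.sql_simplifier_enabled = value


with ThreadPoolExecutor(max_workers=4) as executor:
    for flag in (True, False, True, False):
        executor.submit(toggle_simplifier, session, flag)
```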
""" warn_session_config_update_in_multithreaded_mode( - "large_query_breakdown_enabled" + "large_query_breakdown_enabled", self._conn._thread_safe_session_enabled ) if value in [True, False]: @@ -869,7 +882,8 @@ def large_query_breakdown_enabled(self, value: bool) -> None: def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" warn_session_config_update_in_multithreaded_mode( - "large_query_breakdown_complexity_bounds" + "large_query_breakdown_complexity_bounds", + self._conn._thread_safe_session_enabled, ) if len(value) != 2: diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index a6d13584bcf..b427ef8aa40 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -13,7 +13,10 @@ import pytest -from snowflake.snowpark.session import Session +from snowflake.snowpark.session import ( + _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION, + Session, +) from snowflake.snowpark.types import IntegerType from tests.integ.test_temp_table_cleanup import wait_for_drop_table_sql_done @@ -32,21 +35,51 @@ from tests.utils import IS_IN_STORED_PROC, IS_LINUX, IS_WINDOWS, TestFiles, Utils -def test_concurrent_select_queries(session): +@pytest.fixture(scope="module") +def threadsafe_session( + db_parameters, sql_simplifier_enabled, cte_optimization_enabled, local_testing_mode +): + new_db_parameters = db_parameters.copy() + new_db_parameters["local_testing"] = local_testing_mode + new_db_parameters["session_parameters"] = { + _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True + } + with Session.builder.configs(new_db_parameters).create() as session: + session._sql_simplifier_enabled = sql_simplifier_enabled + session._cte_optimization_enabled = cte_optimization_enabled + yield session + + +@pytest.fixture(scope="function") +def threadsafe_temp_stage(threadsafe_session, resources_path, local_testing_mode): + tmp_stage_name = Utils.random_stage_name() + test_files = TestFiles(resources_path) + + if not local_testing_mode: + Utils.create_stage(threadsafe_session, tmp_stage_name, is_temporary=True) + Utils.upload_to_stage( + threadsafe_session, tmp_stage_name, test_files.test_file_parquet, compress=False + ) + yield tmp_stage_name + if not local_testing_mode: + Utils.drop_stage(threadsafe_session, tmp_stage_name) + + +def test_concurrent_select_queries(threadsafe_session): def run_select(session_, thread_id): df = session_.sql(f"SELECT {thread_id} as A") assert df.collect()[0][0] == thread_id with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): - executor.submit(run_select, session, i) + executor.submit(run_select, threadsafe_session, i) -def test_concurrent_dataframe_operations(session): +def test_concurrent_dataframe_operations(threadsafe_session): try: table_name = Utils.random_table_name() data = [(i, 11 * i) for i in range(10)] - df = session.create_dataframe(data, ["A", "B"]) + df = threadsafe_session.create_dataframe(data, ["A", "B"]) df.write.save_as_table(table_name, table_type="temporary") def run_dataframe_operation(session_, thread_id): @@ -59,7 +92,8 @@ def run_dataframe_operation(session_, thread_id): dfs = [] with ThreadPoolExecutor(max_workers=10) as executor: df_futures = [ - executor.submit(run_dataframe_operation, session, i) for i in range(10) + executor.submit(run_dataframe_operation, threadsafe_session, i) + for i in range(10) ] for future in as_completed(df_futures): @@ -74,7 
+108,7 @@ def run_dataframe_operation(session_, thread_id): ) finally: - Utils.drop_table(session, table_name) + Utils.drop_table(threadsafe_session, table_name) @pytest.mark.xfail( @@ -82,14 +116,14 @@ def run_dataframe_operation(session_, thread_id): reason="SQL query and query listeners are not supported", run=False, ) -def test_query_listener(session): +def test_query_listener(threadsafe_session): def run_select(session_, thread_id): session_.sql(f"SELECT {thread_id} as A").collect() - with session.query_history() as history: + with threadsafe_session.query_history() as history: with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): - executor.submit(run_select, session, i) + executor.submit(run_select, threadsafe_session, i) queries_sent = [query.sql_text for query in history.queries] assert len(queries_sent) == 10 @@ -105,16 +139,18 @@ def run_select(session_, thread_id): @pytest.mark.skipif( IS_IN_STORED_PROC, reason="show parameters is not supported in stored procedure" ) -def test_query_tagging(session): +def test_query_tagging(threadsafe_session): def set_query_tag(session_, thread_id): session_.query_tag = f"tag_{thread_id}" with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): - executor.submit(set_query_tag, session, i) + executor.submit(set_query_tag, threadsafe_session, i) - actual_query_tag = session.sql("SHOW PARAMETERS LIKE 'QUERY_TAG'").collect()[0][1] - assert actual_query_tag == session.query_tag + actual_query_tag = threadsafe_session.sql( + "SHOW PARAMETERS LIKE 'QUERY_TAG'" + ).collect()[0][1] + assert actual_query_tag == threadsafe_session.query_tag @pytest.mark.xfail( @@ -122,21 +158,24 @@ def set_query_tag(session_, thread_id): reason="SQL query is not supported", run=False, ) -def test_session_stage_created_once(session): +def test_session_stage_created_once(threadsafe_session): with patch.object( - session._conn, "run_query", wraps=session._conn.run_query + threadsafe_session._conn, "run_query", wraps=threadsafe_session._conn.run_query ) as patched_run_query: with ThreadPoolExecutor(max_workers=10) as executor: for _ in range(10): - executor.submit(session.get_session_stage) + executor.submit(threadsafe_session.get_session_stage) assert patched_run_query.call_count == 1 -def test_action_ids_are_unique(session): +def test_action_ids_are_unique(threadsafe_session): with ThreadPoolExecutor(max_workers=10) as executor: action_ids = set() - futures = [executor.submit(session._generate_new_action_id) for _ in range(10)] + futures = [ + executor.submit(threadsafe_session._generate_new_action_id) + for _ in range(10) + ] for future in as_completed(futures): action_ids.add(future.result()) @@ -145,9 +184,9 @@ def test_action_ids_are_unique(session): @pytest.mark.parametrize("use_stream", [True, False]) -def test_file_io(session, resources_path, temp_stage, use_stream): +def test_file_io(threadsafe_session, resources_path, threadsafe_temp_stage, use_stream): stage_prefix = f"prefix_{Utils.random_alphanumeric_str(10)}" - stage_with_prefix = f"@{temp_stage}/{stage_prefix}/" + stage_with_prefix = f"@{threadsafe_temp_stage}/{stage_prefix}/" test_files = TestFiles(resources_path) resources_files = [ @@ -170,11 +209,11 @@ def get_file_hash(fd): def put_and_get_file(upload_file_path, download_dir): if use_stream: with open(upload_file_path, "rb") as fd: - results = session.file.put_stream( + results = threadsafe_session.file.put_stream( fd, stage_with_prefix, auto_compress=False, overwrite=False ) else: - results = session.file.put( + 
results = threadsafe_session.file.put( upload_file_path, stage_with_prefix, auto_compress=False, @@ -186,12 +225,12 @@ def put_and_get_file(upload_file_path, download_dir): stage_file_name = f"{stage_with_prefix}{os.path.basename(upload_file_path)}" if use_stream: - fd = session.file.get_stream(stage_file_name, download_dir) + fd = threadsafe_session.file.get_stream(stage_file_name, download_dir) with open(upload_file_path, "rb") as upload_fd: assert get_file_hash(upload_fd) == get_file_hash(fd) else: - results = session.file.get(stage_file_name, download_dir) + results = threadsafe_session.file.get(stage_file_name, download_dir) # assert file is downloaded successfully assert len(results) == 1 assert results[0].status == "DOWNLOADED" @@ -214,7 +253,7 @@ def put_and_get_file(upload_file_path, download_dir): } -def test_concurrent_add_packages(session): +def test_concurrent_add_packages(threadsafe_session): # this is a list of packages available in snowflake anaconda. If this # test fails due to packages not being available, please update the list package_list = { @@ -229,21 +268,21 @@ def test_concurrent_add_packages(session): try: with ThreadPoolExecutor(max_workers=10) as executor: futures = [ - executor.submit(session.add_packages, package) + executor.submit(threadsafe_session.add_packages, package) for package in package_list ] for future in as_completed(futures): future.result() - assert session.get_packages() == { + assert threadsafe_session.get_packages() == { package: package for package in package_list } finally: - session.clear_packages() + threadsafe_session.clear_packages() -def test_concurrent_remove_package(session): +def test_concurrent_remove_package(threadsafe_session): def remove_package(session_, package_name): try: session_.remove_package(package_name) @@ -254,11 +293,12 @@ def remove_package(session_, package_name): raise e try: - session.add_packages("numpy") + threadsafe_session.add_packages("numpy") with ThreadPoolExecutor(max_workers=10) as executor: futures = [ - executor.submit(remove_package, session, "numpy") for _ in range(10) + executor.submit(remove_package, threadsafe_session, "numpy") + for _ in range(10) ] success_count, failure_count = 0, 0 for future in as_completed(futures): @@ -271,11 +311,11 @@ def remove_package(session_, package_name): assert success_count == 1 assert failure_count == 9 finally: - session.clear_packages() + threadsafe_session.clear_packages() @pytest.mark.skipif(not is_dateutil_available, reason="dateutil is not available") -def test_concurrent_add_import(session, resources_path): +def test_concurrent_add_import(threadsafe_session, resources_path): test_files = TestFiles(resources_path) import_files = [ test_files.test_udf_py_file, @@ -290,18 +330,18 @@ def test_concurrent_add_import(session, resources_path): with ThreadPoolExecutor(max_workers=10) as executor: for file in import_files: executor.submit( - session.add_import, + threadsafe_session.add_import, file, ) - assert set(session.get_imports()) == { + assert set(threadsafe_session.get_imports()) == { os.path.abspath(file) for file in import_files } finally: - session.clear_imports() + threadsafe_session.clear_imports() -def test_concurrent_remove_import(session, resources_path): +def test_concurrent_remove_import(threadsafe_session, resources_path): test_files = TestFiles(resources_path) def remove_import(session_, import_file): @@ -314,10 +354,12 @@ def remove_import(session_, import_file): raise e try: - session.add_import(test_files.test_udf_py_file) + 
threadsafe_session.add_import(test_files.test_udf_py_file) with ThreadPoolExecutor(max_workers=10) as executor: futures = [ - executor.submit(remove_import, session, test_files.test_udf_py_file) + executor.submit( + remove_import, threadsafe_session, test_files.test_udf_py_file + ) for _ in range(10) ] @@ -332,12 +374,12 @@ def remove_import(session_, import_file): assert success_count == 1 assert failure_count == 9 finally: - session.clear_imports() + threadsafe_session.clear_imports() -def test_concurrent_sp_register(session, tmpdir): +def test_concurrent_sp_register(threadsafe_session, tmpdir): try: - session.add_packages("snowflake-snowpark-python") + threadsafe_session.add_packages("snowflake-snowpark-python") def register_and_test_sp(session_, thread_id): prefix = Utils.random_alphanumeric_str(10) @@ -373,13 +415,13 @@ def add_{thread_id}(session_: Session, x: int) -> int: with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): - executor.submit(register_and_test_sp, session, i) + executor.submit(register_and_test_sp, threadsafe_session, i) finally: - session.clear_packages() + threadsafe_session.clear_packages() -def test_concurrent_udf_register(session, tmpdir): - df = session.range(-5, 5).to_df("a") +def test_concurrent_udf_register(threadsafe_session, tmpdir): + df = threadsafe_session.range(-5, 5).to_df("a") def register_and_test_udf(session_, thread_id): prefix = Utils.random_alphanumeric_str(10) @@ -407,7 +449,7 @@ def add_{thread_id}(x: int) -> int: with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): - executor.submit(register_and_test_udf, session, i) + executor.submit(register_and_test_udf, threadsafe_session, i) @pytest.mark.xfail( @@ -415,7 +457,7 @@ def add_{thread_id}(x: int) -> int: reason="UDTFs is not supported in local testing mode", run=False, ) -def test_concurrent_udtf_register(session, tmpdir): +def test_concurrent_udtf_register(threadsafe_session, tmpdir): def register_and_test_udtf(session_, thread_id): udtf_body = f""" from typing import List, Tuple @@ -440,14 +482,14 @@ def process( ) echo_udtf = session_.udtf.register(d["UDTFEcho"], output_schema=["num"]) - df_local = session.table_function(echo_udtf(lit(1))) - df_from_file = session.table_function(echo_udtf_from_file(lit(1))) + df_local = threadsafe_session.table_function(echo_udtf(lit(1))) + df_from_file = threadsafe_session.table_function(echo_udtf_from_file(lit(1))) assert df_local.collect() == [(thread_id + 1,)] assert df_from_file.collect() == [(thread_id + 1,)] with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): - executor.submit(register_and_test_udtf, session, i) + executor.submit(register_and_test_udtf, threadsafe_session, i) @pytest.mark.xfail( @@ -455,8 +497,10 @@ def process( reason="UDAFs is not supported in local testing mode", run=False, ) -def test_concurrent_udaf_register(session: Session, tmpdir): - df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b") +def test_concurrent_udaf_register(threadsafe_session, tmpdir): + df = threadsafe_session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df( + "a", "b" + ) def register_and_test_udaf(session_, thread_id): udaf_body = f""" @@ -504,7 +548,7 @@ def finish(self): with ThreadPoolExecutor(max_workers=10) as executor: for i in range(10): - executor.submit(register_and_test_udaf, session, i) + executor.submit(register_and_test_udaf, threadsafe_session, i) @pytest.mark.xfail( @@ -512,13 +556,15 @@ def finish(self): reason="session.sql is not supported in 
local testing mode", run=False, ) -def test_auto_temp_table_cleaner(session, caplog): - session._temp_table_auto_cleaner.ref_count_map.clear() - original_auto_clean_up_temp_table_enabled = session.auto_clean_up_temp_table_enabled - session.auto_clean_up_temp_table_enabled = True +def test_auto_temp_table_cleaner(threadsafe_session, caplog): + threadsafe_session._temp_table_auto_cleaner.ref_count_map.clear() + original_auto_clean_up_temp_table_enabled = ( + threadsafe_session.auto_clean_up_temp_table_enabled + ) + threadsafe_session.auto_clean_up_temp_table_enabled = True def create_temp_table(session_, thread_id): - df = session.sql(f"select {thread_id} as A").cache_result() + df = threadsafe_session.sql(f"select {thread_id} as A").cache_result() table_name = df.table_name del df return table_name @@ -527,21 +573,24 @@ def create_temp_table(session_, thread_id): futures = [] table_names = [] for i in range(10): - futures.append(executor.submit(create_temp_table, session, i)) + futures.append(executor.submit(create_temp_table, threadsafe_session, i)) for future in as_completed(futures): table_names.append(future.result()) gc.collect() - wait_for_drop_table_sql_done(session, caplog, expect_drop=True) + wait_for_drop_table_sql_done(threadsafe_session, caplog, expect_drop=True) try: for table_name in table_names: - assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 - assert session._temp_table_auto_cleaner.num_temp_tables_created == 10 - assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 10 + assert ( + threadsafe_session._temp_table_auto_cleaner.ref_count_map[table_name] + == 0 + ) + assert threadsafe_session._temp_table_auto_cleaner.num_temp_tables_created == 10 + assert threadsafe_session._temp_table_auto_cleaner.num_temp_tables_cleaned == 10 finally: - session.auto_clean_up_temp_table_enabled = ( + threadsafe_session.auto_clean_up_temp_table_enabled = ( original_auto_clean_up_temp_table_enabled ) @@ -561,12 +610,14 @@ def create_temp_table(session_, thread_id): ("large_query_breakdown_complexity_bounds", (20, 30)), ], ) -def test_concurrent_update_on_sensitive_configs(session, config, value, caplog): +def test_concurrent_update_on_sensitive_configs( + threadsafe_session, config, value, caplog +): def change_config_value(session_): session_.conf.set(config, value) caplog.clear() - change_config_value(session) + change_config_value(threadsafe_session) assert ( f"You might have more than one threads sharing the Session object trying to update {config}" not in caplog.text @@ -575,8 +626,37 @@ def change_config_value(session_): with caplog.at_level(logging.WARNING): with ThreadPoolExecutor(max_workers=5) as executor: for _ in range(5): - executor.submit(change_config_value, session) + executor.submit(change_config_value, threadsafe_session) assert ( f"You might have more than one threads sharing the Session object trying to update {config}" in caplog.text ) + + +@pytest.mark.parametrize("is_enabled", [True, False]) +def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode): + if is_enabled and local_testing_mode: + pytest.skip("Multithreading is enabled by default in local testing mode") + + num_workers = 5 if is_enabled else 1 + new_db_parameters = db_parameters.copy() + new_db_parameters["session_parameters"] = { + _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: is_enabled + } + + with Session.builder.configs(new_db_parameters).create() as new_session: + + def run_query(session_, thread_id): + assert session_.sql(f"SELECT {thread_id} as 
A").collect()[0][0] == thread_id + + with patch.object( + new_session._conn._telemetry_client, "send_cursor_created_telemetry" + ) as mock_telemetry: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + for i in range(10): + executor.submit(run_query, new_session, i) + + # when multithreading is enabled, each worker will create a cursor + # otherwise, we will use the same cursor created by the main thread + # thus creating 0 new cursors. + assert mock_telemetry.call_count == (num_workers if is_enabled else 0) diff --git a/tests/mock/test_multithreading.py b/tests/mock/test_multithreading.py index 5e0078212d6..2adcad82835 100644 --- a/tests/mock/test_multithreading.py +++ b/tests/mock/test_multithreading.py @@ -28,15 +28,31 @@ from tests.utils import Utils -def test_table_update_merge_delete(session): +@pytest.fixture(scope="function", autouse=True) +def threadsafe_server_connection(): + options = { + "session_parameters": {"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": True} + } + s = MockServerConnection(options) + yield s + s.close() + + +@pytest.fixture(scope="function") +def threadsafe_session(threadsafe_server_connection): + with Session(threadsafe_server_connection) as s: + yield s + + +def test_table_update_merge_delete(threadsafe_session): table_name = Utils.random_table_name() num_threads = 10 data = [[v, 11 * v] for v in range(10)] - df = session.create_dataframe(data, schema=["a", "b"]) + df = threadsafe_session.create_dataframe(data, schema=["a", "b"]) df.write.save_as_table(table_name, table_type="temp") source_df = df - t = session.table(table_name) + t = threadsafe_session.table(table_name) def update_table(thread_id: int): t.update({"b": 0}, t.a == lit(thread_id)) @@ -76,18 +92,18 @@ def delete_table(thread_id: int): assert t.count() == 0 -def test_udf_register_and_invoke(session): - df = session.create_dataframe([[1], [2]], schema=["num"]) +def test_udf_register_and_invoke(threadsafe_session): + df = threadsafe_session.create_dataframe([[1], [2]], schema=["num"]) num_threads = 10 def register_udf(x: int): def echo(x: int) -> int: return x - return session.udf.register(echo, name="echo", replace=True) + return threadsafe_session.udf.register(echo, name="echo", replace=True) def invoke_udf(): - result = df.select(session.udf.call_udf("echo", df.num)).collect() + result = df.select(threadsafe_session.udf.call_udf("echo", df.num)).collect() assert result[0][0] == 1 assert result[1][0] == 2 @@ -105,19 +121,19 @@ def invoke_udf(): thread.join() -def test_sp_register_and_invoke(session): +def test_sp_register_and_invoke(threadsafe_session): num_threads = 10 def increment_by_one_fn(session_: Session, x: int) -> int: return x + 1 def register_sproc(): - session.sproc.register( + threadsafe_session.sproc.register( increment_by_one_fn, name="increment_by_one", replace=True ) def invoke_sproc(): - result = session.call("increment_by_one", 1) + result = threadsafe_session.call("increment_by_one", 1) assert result == 2 threads = [] @@ -152,8 +168,8 @@ def test_mocked_function_registry_created_once(): @pytest.mark.parametrize("test_table", [True, False]) -def test_tabular_entity_registry(test_table): - conn = MockServerConnection() +def test_tabular_entity_registry(test_table, threadsafe_server_connection): + conn = threadsafe_server_connection entity_registry = conn.entity_registry num_threads = 10 @@ -191,8 +207,8 @@ def write_read_and_drop_view(): future.result() -def test_stage_entity_registry_put_and_get(): - stage_registry = StageEntityRegistry(MockServerConnection()) 
diff --git a/tests/mock/test_multithreading.py b/tests/mock/test_multithreading.py
index 5e0078212d6..2adcad82835 100644
--- a/tests/mock/test_multithreading.py
+++ b/tests/mock/test_multithreading.py
@@ -28,15 +28,31 @@
 from tests.utils import Utils
 
 
-def test_table_update_merge_delete(session):
+@pytest.fixture(scope="function", autouse=True)
+def threadsafe_server_connection():
+    options = {
+        "session_parameters": {"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": True}
+    }
+    s = MockServerConnection(options)
+    yield s
+    s.close()
+
+
+@pytest.fixture(scope="function")
+def threadsafe_session(threadsafe_server_connection):
+    with Session(threadsafe_server_connection) as s:
+        yield s
+
+
+def test_table_update_merge_delete(threadsafe_session):
     table_name = Utils.random_table_name()
     num_threads = 10
     data = [[v, 11 * v] for v in range(10)]
-    df = session.create_dataframe(data, schema=["a", "b"])
+    df = threadsafe_session.create_dataframe(data, schema=["a", "b"])
     df.write.save_as_table(table_name, table_type="temp")
     source_df = df
-    t = session.table(table_name)
+    t = threadsafe_session.table(table_name)
 
     def update_table(thread_id: int):
         t.update({"b": 0}, t.a == lit(thread_id))
@@ -76,18 +92,18 @@ def delete_table(thread_id: int):
     assert t.count() == 0
 
 
-def test_udf_register_and_invoke(session):
-    df = session.create_dataframe([[1], [2]], schema=["num"])
+def test_udf_register_and_invoke(threadsafe_session):
+    df = threadsafe_session.create_dataframe([[1], [2]], schema=["num"])
     num_threads = 10
 
     def register_udf(x: int):
         def echo(x: int) -> int:
             return x
 
-        return session.udf.register(echo, name="echo", replace=True)
+        return threadsafe_session.udf.register(echo, name="echo", replace=True)
 
     def invoke_udf():
-        result = df.select(session.udf.call_udf("echo", df.num)).collect()
+        result = df.select(threadsafe_session.udf.call_udf("echo", df.num)).collect()
         assert result[0][0] == 1
         assert result[1][0] == 2
@@ -105,19 +121,19 @@ def invoke_udf():
         thread.join()
 
 
-def test_sp_register_and_invoke(session):
+def test_sp_register_and_invoke(threadsafe_session):
     num_threads = 10
 
     def increment_by_one_fn(session_: Session, x: int) -> int:
         return x + 1
 
     def register_sproc():
-        session.sproc.register(
+        threadsafe_session.sproc.register(
             increment_by_one_fn, name="increment_by_one", replace=True
         )
 
     def invoke_sproc():
-        result = session.call("increment_by_one", 1)
+        result = threadsafe_session.call("increment_by_one", 1)
         assert result == 2
 
     threads = []
@@ -152,8 +168,8 @@ def test_mocked_function_registry_created_once():
 
 
 @pytest.mark.parametrize("test_table", [True, False])
-def test_tabular_entity_registry(test_table):
-    conn = MockServerConnection()
+def test_tabular_entity_registry(test_table, threadsafe_server_connection):
+    conn = threadsafe_server_connection
     entity_registry = conn.entity_registry
     num_threads = 10
@@ -191,8 +207,8 @@ def write_read_and_drop_view():
             future.result()
 
 
-def test_stage_entity_registry_put_and_get():
-    stage_registry = StageEntityRegistry(MockServerConnection())
+def test_stage_entity_registry_put_and_get(threadsafe_server_connection):
+    stage_registry = StageEntityRegistry(threadsafe_server_connection)
     num_threads = 10
 
     def put_and_get_file():
@@ -219,8 +235,10 @@ def put_and_get_file():
         thread.join()
 
 
-def test_stage_entity_registry_upload_and_read(session):
-    stage_registry = StageEntityRegistry(MockServerConnection())
+def test_stage_entity_registry_upload_and_read(
+    threadsafe_session, threadsafe_server_connection
+):
+    stage_registry = StageEntityRegistry(threadsafe_server_connection)
     num_threads = 10
 
     def upload_and_read_json(thread_id: int):
@@ -236,7 +254,7 @@ def upload_and_read_json(thread_id: int):
             f"@test_stage/test_parent_dir/test_file_{thread_id}",
             "json",
             [],
-            session._analyzer,
+            threadsafe_session._analyzer,
             {"INFER_SCHEMA": "True"},
         )
@@ -249,8 +267,8 @@ def upload_and_read_json(thread_id: int):
             future.result()
 
 
-def test_stage_entity_registry_create_or_replace():
-    stage_registry = StageEntityRegistry(MockServerConnection())
+def test_stage_entity_registry_create_or_replace(threadsafe_server_connection):
+    stage_registry = StageEntityRegistry(threadsafe_server_connection)
     num_threads = 10
 
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index b621f6ca6a0..f0043d85d32 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -17,8 +17,25 @@
 @pytest.fixture
 def mock_server_connection() -> ServerConnection:
     fake_snowflake_connection = mock.create_autospec(SnowflakeConnection)
+    fake_snowflake_connection._conn = mock.MagicMock()
     fake_snowflake_connection._telemetry = None
     fake_snowflake_connection._session_parameters = {}
+    fake_snowflake_connection._thread_safe_session_enabled = True
+    fake_snowflake_connection.cursor.return_value = mock.create_autospec(
+        SnowflakeCursor
+    )
+    fake_snowflake_connection.is_closed.return_value = False
+    return ServerConnection({}, fake_snowflake_connection)
+
+
+@pytest.fixture
+def closed_mock_server_connection() -> ServerConnection:
+    fake_snowflake_connection = mock.create_autospec(SnowflakeConnection)
+    fake_snowflake_connection._conn = mock.MagicMock()
+    fake_snowflake_connection._telemetry = None
+    fake_snowflake_connection._session_parameters = {}
+    fake_snowflake_connection._thread_safe_session_enabled = True
+    fake_snowflake_connection.is_closed = mock.MagicMock(return_value=False)
     fake_snowflake_connection.cursor.return_value = mock.create_autospec(
         SnowflakeCursor
     )
diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py
index 21eb35ba30a..e9a77564e8a 100644
--- a/tests/unit/test_dataframe.py
+++ b/tests/unit/test_dataframe.py
@@ -148,10 +148,8 @@ def test_select_bad_input():
     )
 
 
-def test_join_bad_input():
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_join_bad_input(mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(
         ["int", "int2", "str"]
     )
@@ -174,20 +172,16 @@
     assert "Invalid type for join. Must be Dataframe" in str(exc_info)
 
 
-def test_with_column_renamed_bad_input():
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_with_column_renamed_bad_input(mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(["a", "b", "str"])
     with pytest.raises(TypeError) as exc_info:
         df1.with_column_renamed(123, "int4")
     assert "must be a column name or Column object." in str(exc_info)
 
 
-def test_with_column_rename_function_bad_input():
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_with_column_rename_function_bad_input(mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(["a", "b", "str"])
     with pytest.raises(TypeError) as exc_info:
         df1.rename(123, "int4")
@@ -200,10 +194,8 @@
     assert "You cannot rename a column using value 123 of type int" in str(exc_info)
 
 
-def test_create_or_replace_view_bad_input():
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_create_or_replace_view_bad_input(mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(["a", "b", "str"])
     with pytest.raises(TypeError) as exc_info:
         df1.create_or_replace_view(123)
@@ -213,10 +205,8 @@
     )
 
 
-def test_create_or_replace_dynamic_table_bad_input():
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_create_or_replace_dynamic_table_bad_input(mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(["a", "b", "str"])
     with pytest.raises(TypeError) as exc_info:
         df1.create_or_replace_dynamic_table(123, warehouse="warehouse", lag="1 minute")
@@ -249,10 +239,8 @@
     )
 
 
-def test_create_or_replace_temp_view_bad_input():
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_create_or_replace_temp_view_bad_input(mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(["a", "b", "str"])
     with pytest.raises(TypeError) as exc_info:
         df1.create_or_replace_temp_view(123)
@@ -266,10 +254,8 @@
     "join_type",
     ["inner", "leftouter", "rightouter", "fullouter", "leftsemi", "leftanti", "cross"],
 )
-def test_same_joins_should_generate_same_queries(join_type):
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_same_joins_should_generate_same_queries(join_type, mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     session._conn._telemetry_client = mock.MagicMock()
     df1 = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(
         ["a1", "b1", "str1"]
@@ -284,6 +270,7 @@
 def test_statement_params():
     mock_connection = mock.create_autospec(ServerConnection)
     mock_connection._conn = mock.MagicMock()
+    mock_connection._thread_safe_session_enabled = True
     session = snowflake.snowpark.session.Session(mock_connection)
     session._conn._telemetry_client = mock.MagicMock()
     df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
@@ -298,10 +285,8 @@
     )
 
 
-def test_dataFrame_printSchema(capfd):
-    mock_connection = mock.create_autospec(ServerConnection)
-    mock_connection._conn = mock.MagicMock()
-    session = snowflake.snowpark.session.Session(mock_connection)
+def test_dataFrame_printSchema(capfd, mock_server_connection):
+    session = snowflake.snowpark.session.Session(mock_server_connection)
     df = session.create_dataframe([[1, ""], [3, None]])
     df._plan._attributes = [
         Attribute("A", IntegerType(), False),
@@ -327,6 +312,7 @@ def test_session():
 def test_table_source_plan(sql_simplifier_enabled):
     mock_connection = mock.create_autospec(ServerConnection)
     mock_connection._conn = mock.MagicMock()
+    mock_connection._thread_safe_session_enabled = True
     session = snowflake.snowpark.session.Session(mock_connection)
     session._sql_simplifier_enabled = sql_simplifier_enabled
     t = session.table("table")
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 370ee455d62..02f4b3bb5d4 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -68,6 +68,7 @@ def quoted(s):
 def test_used_scoped_temp_object():
     fake_connection = mock.create_autospec(ServerConnection)
     fake_connection._conn = mock.Mock()
+    fake_connection._thread_safe_session_enabled = True
 
     fake_connection._get_client_side_session_parameter = (
         lambda x, y: ServerConnection._get_client_side_session_parameter(
@@ -112,6 +113,7 @@ def test_used_scoped_temp_object():
 def test_close_exception():
     fake_connection = mock.create_autospec(ServerConnection)
     fake_connection._conn = mock.Mock()
+    fake_connection._thread_safe_session_enabled = True
     fake_connection._telemetry_client = mock.Mock()
     fake_connection.is_closed = MagicMock(return_value=False)
     exception_msg = "Mock exception for session.cancel_all"
@@ -124,11 +126,8 @@ def test_close_exception():
         session.close()
 
 
-def test_close_session_in_stored_procedure_no_op():
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    fake_connection.is_closed = MagicMock(return_value=False)
-    session = Session(fake_connection)
+def test_close_session_in_stored_procedure_no_op(closed_mock_server_connection):
+    session = Session(closed_mock_server_connection)
     with mock.patch.object(
         snowflake.snowpark.session, "is_in_stored_procedure"
     ) as mock_fn, mock.patch.object(
@@ -149,13 +148,12 @@
     "warning_level, expected",
     [(logging.WARNING, True), (logging.INFO, True), (logging.ERROR, False)],
 )
-def test_close_session_in_stored_procedure_log_level(caplog, warning_level, expected):
+def test_close_session_in_stored_procedure_log_level(
+    caplog, closed_mock_server_connection, warning_level, expected
+):
     caplog.clear()
     caplog.set_level(warning_level)
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    fake_connection.is_closed = MagicMock(return_value=False)
-    session = Session(fake_connection)
+    session = Session(closed_mock_server_connection)
     with mock.patch.object(
         snowflake.snowpark.session, "is_in_stored_procedure"
     ) as mock_fn:
@@ -165,10 +163,10 @@
     assert result == expected
 
 
-def test_resolve_import_path_ignore_import_path(tmp_path_factory):
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+def test_resolve_import_path_ignore_import_path(
+    tmp_path_factory, mock_server_connection
+):
+    session = Session(mock_server_connection)
     tmp_path = tmp_path_factory.mktemp("session_test")
     a_temp_file = tmp_path / "file.txt"
@@ -203,6 +201,7 @@ def mock_get_information_schema_packages(table_name: str):
 
     fake_connection = mock.create_autospec(ServerConnection)
     fake_connection._conn = mock.Mock()
+    fake_connection._thread_safe_session_enabled = True
     fake_connection._get_current_parameter = mock_get_current_parameter
     session = Session(fake_connection)
     session.table = MagicMock(name="session.table")
@@ -213,10 +212,8 @@
     )
 
 
-def test_resolve_package_terms_not_accepted():
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+def test_resolve_package_terms_not_accepted(mock_server_connection):
+    session = Session(mock_server_connection)
 
     def get_information_schema_packages(table_name: str):
         if table_name == "information_schema.packages":
@@ -246,7 +243,7 @@ def run_query(sql: str):
     )
 
 
-def test_resolve_packages_side_effect():
+def test_resolve_packages_side_effect(mock_server_connection):
     """Python stored procedure depends on this behavior to add packages to the session."""
 
     def mock_get_information_schema_packages(table_name: str):
@@ -256,9 +253,7 @@ def mock_get_information_schema_packages(table_name: str):
         ]
         return result
 
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+    session = Session(mock_server_connection)
     session.table = MagicMock(name="session.table")
     session.table.side_effect = mock_get_information_schema_packages
@@ -280,20 +275,16 @@ def mock_get_information_schema_packages(table_name: str):
 
 @pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas")
-def test_write_pandas_wrong_table_type():
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+def test_write_pandas_wrong_table_type(mock_server_connection):
+    session = Session(mock_server_connection)
     with pytest.raises(ValueError, match="Unsupported table type."):
         session.write_pandas(
             mock.create_autospec(pandas.DataFrame), table_name="t", table_type="aaa"
         )
 
 
-def test_create_dataframe_empty_schema():
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+def test_create_dataframe_empty_schema(mock_server_connection):
+    session = Session(mock_server_connection)
     with pytest.raises(
         ValueError,
         match="The provided schema or inferred schema cannot be None or empty",
@@ -301,20 +292,16 @@
         session.create_dataframe([[1]], schema=StructType([]))
 
 
-def test_create_dataframe_wrong_type():
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+def test_create_dataframe_wrong_type(mock_server_connection):
+    session = Session(mock_server_connection)
     with pytest.raises(
         TypeError, match=r"Cannot cast \(1\) to ."
     ):
         session.create_dataframe([[1]], schema=StructType([StructField("a", str)]))
 
 
-def test_table_exists_invalid_table_name():
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+def test_table_exists_invalid_table_name(mock_server_connection):
+    session = Session(mock_server_connection)
     with pytest.raises(
         SnowparkInvalidObjectNameException,
         match="The object name 'a.b.c.d' is invalid.",
@@ -322,10 +309,8 @@
         session._table_exists(["a", "b", "c", "d"])
 
 
-def test_explain_query_error():
-    fake_connection = mock.create_autospec(ServerConnection)
-    fake_connection._conn = mock.Mock()
-    session = Session(fake_connection)
+def test_explain_query_error(mock_server_connection):
+    session = Session(mock_server_connection)
     session._run_query = MagicMock()
     session._run_query.side_effect = ProgrammingError("Can't explain.")
     assert session._explain_query("select 1") is None
@@ -451,6 +436,7 @@ def test_parse_table_name():
 
 def test_session_id():
     fake_server_connection = mock.create_autospec(ServerConnection)
+    fake_server_connection._thread_safe_session_enabled = True
    fake_server_connection.get_session_id = mock.Mock(return_value=123456)
 
     session = Session(fake_server_connection)
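Why the unit-test fixtures and mocks above all set `_thread_safe_session_enabled`: `Session.__init__` now dereferences `conn._thread_safe_session_enabled` to build its locks, and `mock.create_autospec(ServerConnection)` only exposes class-level attributes, so this instance attribute (assigned in `ServerConnection.__init__`) must be supplied by hand. A minimal sketch of the failure mode and the fix:

```python
from unittest import mock

from snowflake.snowpark._internal.server_connection import ServerConnection
from snowflake.snowpark.session import Session

fake_connection = mock.create_autospec(ServerConnection)
fake_connection._conn = mock.Mock()

# Without the next line, Session(fake_connection) raises AttributeError:
# _thread_safe_session_enabled is set in ServerConnection.__init__ and is
# therefore invisible to create_autospec.
fake_connection._thread_safe_session_enabled = True

session = Session(fake_connection)
```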