3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@
- Added support for `FileOperation.remove` to remove files in a stage.
- Added a new function `snowflake.snowpark.functions.vectorized` that allows users to mark a function as vectorized UDF.
- Added support for parameter `use_vectorized_scanner` in function `Session.write_pandas()`.
+- Added support for parameter `session_init_statement` in UDTF ingestion of `DataFrameReader.jdbc` (PrPr).
- Added support for the following scalar functions in `functions.py`:
- `getdate`
- `getvariable`
@@ -27,6 +28,8 @@

#### Bug Fixes

+- Fixed a bug where `query_timeout` did not take effect in UDTF ingestion of `DataFrameReader.jdbc` (PrPr).

#### Deprecations

#### Dependency Updates
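
The two UDTF-ingestion entries above (`session_init_statement` and the `query_timeout` fix) surface through the DB-API reader exercised by this PR's tests. A minimal usage sketch, assuming an existing Snowpark `session`; the credentials, table name, and integration name `MY_EAI` are illustrative placeholders, not values from this PR:

```python
def create_connection():
    # User-supplied DB-API connection factory; credentials are placeholders.
    import pymysql

    return pymysql.connect(
        user="user", password="...", host="mysql.example.com", database="db"
    )


df = session.read.dbapi(
    create_connection,
    table="MY_TABLE",  # placeholder table
    # Run on the connection before the partition query executes.
    session_init_statement=["SET SESSION max_execution_time = 60000"],
    # Seconds; the default 0 leaves the driver's own timeout untouched.
    query_timeout=60,
    # Supplying udtf_configs routes ingestion through the UDTF path.
    udtf_configs={"external_access_integration": "MY_EAI"},
)
```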
@@ -60,7 +60,7 @@ def read(self, partition: str) -> Iterator[List[Any]]:
                cursor.execute(statement)
            except BaseException as exc:
                raise SnowparkDataframeReaderException(
-                    f"Failed to execute session init statement: '{statement}' due to exception '{exc!r}'"
+                    f"Failed to execute session init statement: '{statement}' due to exception '{exc}'"
                )
        # use server side cursor to fetch data if supported by the driver
        # some drivers do not support execute twice on server side cursor (e.g. psycopg2)
@@ -139,14 +139,21 @@ def udtf_ingestion(
        fetch_size: int = 1000,
        imports: Optional[List[str]] = None,
        packages: Optional[List[str]] = None,
+        session_init_statement: Optional[List[str]] = None,
+        query_timeout: Optional[int] = 0,
        _emit_ast: bool = True,
    ) -> "snowflake.snowpark.DataFrame":
        from snowflake.snowpark._internal.data_source.utils import UDTF_PACKAGE_MAP

        udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
        with measure_time() as udtf_register_time:
            session.udtf.register(
-                self.udtf_class_builder(fetch_size=fetch_size, schema=schema),
+                self.udtf_class_builder(
+                    fetch_size=fetch_size,
+                    schema=schema,
+                    session_init_statement=session_init_statement,
+                    query_timeout=query_timeout,
+                ),
                name=udtf_name,
                output_schema=StructType(
                    [
@@ -166,14 +173,22 @@
        return self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast)

    def udtf_class_builder(
-        self, fetch_size: int = 1000, schema: StructType = None
+        self,
+        fetch_size: int = 1000,
+        schema: StructType = None,
+        session_init_statement: List[str] = None,
+        query_timeout: int = 0,
    ) -> type:
        create_connection = self.create_connection
+        prepare_connection = self.prepare_connection

        class UDTFIngestion:
            def process(self, query: str):
-                conn = create_connection()
+                conn = prepare_connection(create_connection(), query_timeout)
                cursor = conn.cursor()
+                if session_init_statement is not None:
+                    for statement in session_init_statement:
+                        cursor.execute(statement)
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
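
Every driver in this PR repeats the shape above: `udtf_class_builder` captures its configuration in a closure and returns a class that can be registered as a UDTF with a no-argument constructor. A stripped-down, driver-agnostic sketch of that pattern (not the actual Snowpark class; any DB-API connection factory works):

```python
from typing import Iterator, List, Optional


def build_udtf_ingestion(
    create_connection,
    fetch_size: int = 1000,
    session_init_statement: Optional[List[str]] = None,
) -> type:
    # Configuration lives in this closure, so the returned class needs no
    # constructor arguments when instantiated inside the UDTF sandbox.
    class UDTFIngestion:
        def process(self, query: str) -> Iterator[tuple]:
            conn = create_connection()
            cursor = conn.cursor()
            if session_init_statement is not None:
                for statement in session_init_statement:
                    cursor.execute(statement)
            cursor.execute(query)
            while True:
                rows = cursor.fetchmany(fetch_size)
                if not rows:
                    break
                yield from rows

    return UDTFIngestion
```

For a local smoke test, `import sqlite3` and `build_udtf_ingestion(lambda: sqlite3.connect(":memory:"))().process("SELECT 1")` already exercises the fetch loop.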
@@ -69,14 +69,21 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
        return StructType(all_columns)

    def udtf_class_builder(
-        self, fetch_size: int = 1000, schema: StructType = None
+        self,
+        fetch_size: int = 1000,
+        schema: StructType = None,
+        session_init_statement: List[str] = None,
+        query_timeout: int = 0,
    ) -> type:
        create_connection = self.create_connection

        class UDTFIngestion:
            def process(self, query: str):
                conn = create_connection()
                cursor = conn.cursor()
+                if session_init_statement is not None:
+                    for statement in session_init_statement:
+                        cursor.execute(statement)

                # First get schema information
                describe_query = f"DESCRIBE QUERY SELECT * FROM ({query})"
@@ -105,13 +105,18 @@ def prepare_connection(
conn: "Connection",
query_timeout: int = 0,
) -> "Connection":
conn.call_timeout = query_timeout * 1000
if query_timeout > 0:
conn.call_timeout = query_timeout * 1000
if conn.outputtypehandler is None:
conn.outputtypehandler = output_type_handler
return conn

def udtf_class_builder(
self, fetch_size: int = 1000, schema: StructType = None
self,
fetch_size: int = 1000,
schema: StructType = None,
session_init_statement: List[str] = None,
query_timeout: int = 0,
) -> type:
create_connection = self.create_connection

@@ -138,9 +143,14 @@ def convert_to_hex(value):
        class UDTFIngestion:
            def process(self, query: str):
                conn = create_connection()
+                if query_timeout > 0:
+                    conn.call_timeout = query_timeout * 1000
                if conn.outputtypehandler is None:
                    conn.outputtypehandler = oracledb_output_type_handler
                cursor = conn.cursor()
+                if session_init_statement is not None:
+                    for statement in session_init_statement:
+                        cursor.execute(statement)
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
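
The `* 1000` above is a unit conversion: `query_timeout` is expressed in seconds, while python-oracledb's `Connection.call_timeout` attribute is in milliseconds. The guard-and-convert step in isolation (a sketch; `conn` is any `oracledb` connection):

```python
def apply_query_timeout(conn, query_timeout: int = 0):
    # call_timeout is in milliseconds; when query_timeout is 0, keep the
    # driver default of no timeout instead of setting call_timeout = 0.
    if query_timeout > 0:
        conn.call_timeout = query_timeout * 1000
    return conn
```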
@@ -253,7 +253,11 @@ def prepare_connection(
        return conn

    def udtf_class_builder(
-        self, fetch_size: int = 1000, schema: StructType = None
+        self,
+        fetch_size: int = 1000,
+        schema: StructType = None,
+        session_init_statement: List[str] = None,
+        query_timeout: int = 0,
    ) -> type:
        create_connection = self.create_connection

@@ -275,10 +279,15 @@ def prepare_connection_in_udtf(

        class UDTFIngestion:
            def process(self, query: str):
-                conn = prepare_connection_in_udtf(create_connection())
+                conn = prepare_connection_in_udtf(create_connection(), query_timeout)
                cursor = conn.cursor(
                    f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}"
                )
+                if session_init_statement is not None:
+                    session_init_cur = conn.cursor()
+                    for statement in session_init_statement:
+                        session_init_cur.execute(statement)
+                        session_init_cur.fetchall()
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
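
This driver is the one the earlier comment warns about: the main query runs on a named (server-side) cursor, and psycopg2 named cursors support only a single execute, so the init statements go through a separate client-side cursor first. A minimal sketch of that split, assuming psycopg2 and placeholder connection parameters:

```python
import psycopg2


def stream_query(conn_params: dict, query: str, init_statements, fetch_size=1000):
    conn = psycopg2.connect(**conn_params)  # conn_params is a placeholder
    # Plain client-side cursor for the one-off session setup statements.
    init_cur = conn.cursor()
    for statement in init_statements:
        init_cur.execute(statement)
    # A named cursor is server-side: rows stream instead of being buffered
    # client-side, but it supports only one execute() call.
    cursor = conn.cursor("snowpark_sketch_cursor")
    cursor.execute(query)
    while True:
        rows = cursor.fetchmany(fetch_size)
        if not rows:
            break
        yield from rows
```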
@@ -184,7 +184,11 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
        return StructType(fields)

    def udtf_class_builder(
-        self, fetch_size: int = 1000, schema: StructType = None
+        self,
+        fetch_size: int = 1000,
+        schema: StructType = None,
+        session_init_statement: List[str] = None,
+        query_timeout: int = 0,
    ) -> type:
        create_connection = self.create_connection

@@ -194,6 +198,9 @@ def process(self, query: str):

                conn = create_connection()
                cursor = pymysql.cursors.SSCursor(conn)
+                if session_init_statement is not None:
+                    for statement in session_init_statement:
+                        cursor.execute(statement)
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
@@ -203,14 +210,6 @@ def process(self, query: str):

        return UDTFIngestion

-    def prepare_connection(
-        self,
-        conn: "Connection",
-        query_timeout: int = 0,
-    ) -> "Connection":
-        conn.read_timeout = query_timeout if query_timeout != 0 else None
-        return conn
-
    @staticmethod
    def infer_type_from_data(data: List[tuple], number_of_columns: int) -> List[Type]:
        # TODO: SNOW-2112938 investigate whether different types can be fit into one column
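
Two MySQL-specific points are visible here: the init statements share the unbuffered `SSCursor` with the main query, which is safe because statements such as `SET` leave no pending result set to drain, and the old `prepare_connection` mapping of `query_timeout` onto pymysql's `read_timeout` is removed. A small sketch of the streaming-cursor usage, assuming pymysql and placeholder credentials:

```python
import pymysql

conn = pymysql.connect(
    user="user", password="...", host="mysql.example.com", database="db"
)
# SSCursor streams rows from the server rather than buffering them all.
cursor = pymysql.cursors.SSCursor(conn)
cursor.execute("SET SESSION max_execution_time = 60000")  # no result set
cursor.execute("SELECT * FROM my_table")  # hypothetical table
while True:
    rows = cursor.fetchmany(1000)
    if not rows:
        break
    for row in rows:
        print(row)
```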
@@ -78,9 +78,14 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
        return StructType(fields)

    def udtf_class_builder(
-        self, fetch_size: int = 1000, schema: StructType = None
+        self,
+        fetch_size: int = 1000,
+        schema: StructType = None,
+        session_init_statement: List[str] = None,
+        query_timeout: int = 0,
    ) -> type:
        create_connection = self.create_connection
+        prepare_connection = self.prepare_connection

        def binary_converter(value):
            return value.hex() if value is not None else None
@@ -89,7 +94,7 @@ class UDTFIngestion:
            def process(self, query: str):
                import pyodbc

-                conn = create_connection()
+                conn = prepare_connection(create_connection(), query_timeout)
                if (
                    conn.get_output_converter(pyodbc.SQL_BINARY) is None
                    and conn.get_output_converter(pyodbc.SQL_VARBINARY) is None
@@ -101,6 +106,9 @@ def process(self, query: str):
                        pyodbc.SQL_LONGVARBINARY, binary_converter
                    )
                cursor = conn.cursor()
+                if session_init_statement is not None:
+                    for statement in session_init_statement:
+                        cursor.execute(statement)
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
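
pyodbc returns SQL binary columns as raw `bytes` unless an output converter is registered, so this driver installs a hex-encoding converter, but only where none exists, to avoid clobbering one the user registered. The mechanism in isolation (a sketch; the DSN is a placeholder):

```python
import pyodbc


def hex_converter(value):
    # Output converters receive the raw column bytes (or None).
    return value.hex() if value is not None else None


conn = pyodbc.connect("DSN=placeholder")  # hypothetical data source
for sql_type in (
    pyodbc.SQL_BINARY,
    pyodbc.SQL_VARBINARY,
    pyodbc.SQL_LONGVARBINARY,
):
    if conn.get_output_converter(sql_type) is None:
        conn.add_output_converter(sql_type, hex_converter)
```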
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/dataframe_reader.py
@@ -1859,6 +1859,8 @@ def create_oracledb_connection():
            fetch_size=fetch_size,
            imports=udtf_configs.get("imports", None),
            packages=udtf_configs.get("packages", None),
+            session_init_statement=session_init_statement,
+            query_timeout=query_timeout,
            _emit_ast=_emit_ast,
        )
        end_time = time.perf_counter()
43 changes: 41 additions & 2 deletions tests/integ/datasource/test_databricks.py
@@ -17,7 +17,10 @@
    random_name_for_temp_object,
    TempObjectType,
)
-from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
+from snowflake.snowpark.exceptions import (
+    SnowparkDataframeReaderException,
+    SnowparkSQLException,
+)
from snowflake.snowpark.types import (
    StructType,
    StructField,
@@ -205,7 +208,9 @@ def local_create_databricks_connection():

def test_unit_udtf_ingestion():
    dbx_driver = DatabricksDriver(create_databricks_connection, DBMS_TYPE.DATABRICKS_DB)
-    udtf_ingestion_class = dbx_driver.udtf_class_builder()
+    udtf_ingestion_class = dbx_driver.udtf_class_builder(
+        session_init_statement=["select 1"]
+    )
    udtf_ingestion_instance = udtf_ingestion_class()

    dsp = DataSourcePartitioner(
@@ -258,3 +263,37 @@ def test_unsupported_type():
        create_databricks_connection, DBMS_TYPE.DATABRICKS_DB
    ).to_snow_type([("test_col", "unsupported_type", True)])
    assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])
+
+
+def test_session_init(session):
+    with pytest.raises(
+        SnowparkDataframeReaderException,
+        match="syntax error command",
+    ):
+        session.read.dbapi(
+            create_databricks_connection,
+            table=TEST_TABLE_NAME,
+            session_init_statement=["syntax error command"],
+        )
+
+
+def test_session_init_udtf(session):
+    udtf_configs = {
+        "external_access_integration": DATABRICKS_TEST_EXTERNAL_ACCESS_INTEGRATION
+    }
+
+    def create_databricks_udtf_connection():
+        import databricks.sql
+
+        return databricks.sql.connect(**DATABRICKS_CONNECTION_PARAMETERS)
+
+    with pytest.raises(
+        SnowparkSQLException,
+        match="syntax error command",
+    ):
+        session.read.dbapi(
+            create_databricks_udtf_connection,
+            table=TEST_TABLE_NAME,
+            session_init_statement=["syntax error command"],
+            udtf_configs=udtf_configs,
+        ).collect()
52 changes: 51 additions & 1 deletion tests/integ/datasource/test_mysql.py
@@ -15,6 +15,10 @@
)
from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE
from snowflake.snowpark.types import StructType, StructField, StringType
+from snowflake.snowpark.exceptions import (
+    SnowparkDataframeReaderException,
+    SnowparkSQLException,
+)
from tests.resources.test_data_source_dir.test_mysql_data import (
    mysql_real_data,
    MysqlType,
@@ -261,7 +265,9 @@ def test_pymysql_driver_udtf_class_builder():
    driver = PymysqlDriver(create_connection_mysql, DBMS_TYPE.MYSQL_DB)

    # Get the UDTF class with a small fetch size to test batching
-    UDTFClass = driver.udtf_class_builder(fetch_size=2)
+    UDTFClass = driver.udtf_class_builder(
+        fetch_size=2, session_init_statement=["select 1"]
+    )

    # Instantiate the UDTF class
    udtf_instance = UDTFClass()
@@ -297,3 +303,47 @@ def test_unsupported_type():
[("test_col", "unsupported_type", None, None, 0, 0, True)]
)
assert schema == StructType([StructField("TEST_COL", StringType(), nullable=True)])


def test_session_init(session):
with pytest.raises(
SnowparkDataframeReaderException,
match="Mock error to test init_statement",
):
session.read.dbapi(
create_connection_mysql,
table=TEST_TABLE_NAME,
session_init_statement=[
"SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Mock error to test init_statement'"
],
)


def test_session_init_udtf(session):
udtf_configs = {
"external_access_integration": MYSQL_TEST_EXTERNAL_ACCESS_INTEGRATION
}

def create_connection_udtf_mysql():
import pymysql # noqa: F811

conn = pymysql.connect(
user=MYSQL_CONNECTION_PARAMETERS["username"],
password=MYSQL_CONNECTION_PARAMETERS["password"],
host=MYSQL_CONNECTION_PARAMETERS["host"],
database=MYSQL_CONNECTION_PARAMETERS["database"],
)
return conn

with pytest.raises(
SnowparkSQLException,
match="Mock error to test init_statement",
):
session.read.dbapi(
create_connection_udtf_mysql,
table=TEST_TABLE_NAME,
session_init_statement=[
"SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Mock error to test init_statement'"
],
udtf_configs=udtf_configs,
).collect()