Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1331032 add retrieve func defaults from source #1957

Merged
143 changes: 143 additions & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
Variant,
VariantType,
VectorType,
_FractionalType,
_IntegralType,
_NumericType,
)

Expand Down Expand Up @@ -525,6 +527,52 @@ def merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType
return a


def python_value_str_to_object(value, tp: DataType) -> Any:
    """
    Convert the string form of a Python default value into a Python object,
    guided by the Snowpark data type ``tp``.

    Args:
        value: The string form of the value (for example ``"1"`` or
            ``"['1', '2']"``). ``None`` is accepted and returned unchanged for
            every supported type; it shows up for ``None`` elements nested in
            an evaluated list or dict.
        tp: The data type that decides how ``value`` is interpreted.

    Returns:
        - ``StringType``: ``value`` itself, unchanged.
        - numeric, boolean, binary, time, date and timestamp types: the result
          of evaluating ``value`` as a Python expression.
        - ``ArrayType`` / ``MapType``: the evaluated list / dict with every
          element (or key and value) converted recursively, falling back to
          ``StringType`` when the element/key/value type is not set.
        - ``GeometryType`` / ``GeographyType`` / ``VariantType``: ``value``
          itself, or ``None`` when it is the text ``"None"``.

    Raises:
        TypeError: If ``tp`` is not one of the supported data types.
    """
    if isinstance(tp, StringType):
        return value

    # SECURITY NOTE(review): eval() runs ``value`` as arbitrary Python code in
    # this module's global namespace. This is only acceptable while ``value``
    # comes from trusted source code; never feed it untrusted input.
    if isinstance(
        tp,
        (
            _IntegralType,
            _FractionalType,
            BooleanType,
            BinaryType,
            TimeType,
            DateType,
            TimestampType,
        ),
    ):
        # A real None (not the text "None") cannot be passed to eval().
        if value is None:
            return None
        return eval(value)

    if isinstance(tp, ArrayType):
        curr_list = None if value is None else eval(value)
        if curr_list is None:
            return None
        element_tp = tp.element_type or StringType()
        return [python_value_str_to_object(val, element_tp) for val in curr_list]

    if isinstance(tp, MapType):
        curr_dict = None if value is None else eval(value)
        if curr_dict is None:
            return None
        key_tp = tp.key_type or StringType()
        val_tp = tp.value_type or StringType()
        return {
            python_value_str_to_object(k, key_tp): python_value_str_to_object(v, val_tp)
            for k, v in curr_dict.items()
        }

    if isinstance(tp, (GeometryType, GeographyType, VariantType)):
        if value is None or value.strip() == "None":
            return None
        return value

    raise TypeError(
        f"Unsupported data type: {tp}, value {value} by python_value_str_to_object()"
    )


def python_type_str_to_object(
tp_str: str, is_return_type_for_sproc: bool = False
) -> Type:
Expand Down Expand Up @@ -708,6 +756,101 @@ def snow_type_to_dtype_str(snow_type: DataType) -> str:
raise TypeError(f"invalid DataType {snow_type}")


def retrieve_func_defaults_from_source(
    file_path: str,
    func_name: str,
    class_name: Optional[str] = None,
    _source: Optional[str] = None,
) -> Optional[List[Optional[str]]]:
    """
    Retrieve default values assigned to optional arguments of a function from a
    source file, or a source string (test only).

    Args:
        file_path: Path of the Python source file. Only read when ``_source``
            is empty or ``None``.
        func_name: Name of the function whose defaults are collected.
        class_name: When given, only the body of the class with this name is
            searched for ``func_name``.
        _source: Source code to parse instead of reading ``file_path``
            (test only).

    Returns:
        A list with one entry per default in the function's ``args.defaults``
        (keyword-only defaults are not included): the string form of the
        default expression, or ``None`` when the default is the constant
        ``None``. Returns ``None`` when the class or the function is not found.

    Raises:
        TypeError: If a default value uses an expression kind that is not
            supported by the parser below.
    """

    def parse_default_value(
        value: ast.expr, enquote_string: bool = False
    ) -> Optional[str]:
        # recursively parse the default value if it is tuple or list
        if isinstance(value, (ast.Tuple, ast.List)):
            return f"{[parse_default_value(e) for e in value.elts]}"
        # recursively parse the default keys and values if it is dict
        if isinstance(value, ast.Dict):
            key_val_tuples = [
                (parse_default_value(k), parse_default_value(v))
                for k, v in zip(value.keys, value.values)
            ]
            return f"{dict(key_val_tuples)}"
        # recursively parse the default value.value and extract value.attr
        if isinstance(value, ast.Attribute):
            return f"{parse_default_value(value.value)}.{value.attr}"
        # recursively parse value.value and extract value.arg
        if isinstance(value, ast.keyword):
            # propagate enquote_string so string keyword values keep their
            # quotes, e.g. encoding='utf-8' rather than encoding=utf-8
            parsed_value = parse_default_value(value.value, enquote_string)
            if value.arg is None:
                # a keyword without a name is a ``**mapping`` expansion
                return f"**{parsed_value}"
            return f"{value.arg}={parsed_value}"
        # signed literals such as -1 or +1.5 are a UnaryOp around the operand
        if isinstance(value, ast.UnaryOp) and isinstance(
            value.op, (ast.USub, ast.UAdd)
        ):
            operand = parse_default_value(value.operand, enquote_string)
            if operand is None:
                raise TypeError(f"invalid default value: {value}")
            sign = "-" if isinstance(value.op, ast.USub) else "+"
            return f"{sign}{operand}"
        # extract constant value
        if isinstance(value, ast.Constant):
            if isinstance(value.value, str) and enquote_string:
                # repr() quotes and escapes, so embedded quotes stay valid
                return repr(value.value)
            if value.value is None:
                return None
            return f"{value.value}"
        # extract value.id from Name
        if isinstance(value, ast.Name):
            return value.id
        # recursively parse value.func and extract value.args and value.keywords
        if isinstance(value, ast.Call):
            # str() turns a parsed None constant into the text "None" so that
            # it can be joined with the other arguments
            parsed_args = ", ".join(
                str(parse_default_value(arg, True)) for arg in value.args
            )
            parsed_kwargs = ", ".join(
                str(parse_default_value(kw, True)) for kw in value.keywords
            )
            combined_parsed_input = (
                f"{parsed_args}, {parsed_kwargs}"
                if parsed_args and parsed_kwargs
                else parsed_args or parsed_kwargs
            )
            return f"{parse_default_value(value.func)}({combined_parsed_input})"
        raise TypeError(f"invalid default value: {value}")

    class FuncNodeVisitor(ast.NodeVisitor):
        def __init__(self) -> None:
            # per-instance state instead of a shared mutable class attribute
            self.default_values = []
            self.func_exist = False

        def visit_FunctionDef(self, node):
            # generic_visit() is not called here, so functions nested inside
            # another function are never searched
            if node.name == func_name:
                self.default_values.extend(
                    parse_default_value(value) for value in node.args.defaults
                )
                self.func_exist = True

    if not _source:
        with open(file_path) as f:
            _source = f.read()

    if class_name:

        class ClassNodeVisitor(ast.NodeVisitor):
            def __init__(self) -> None:
                self.class_node = None

            def visit_ClassDef(self, node):
                if node.name == class_name:
                    self.class_node = node

        class_visitor = ClassNodeVisitor()
        class_visitor.visit(ast.parse(_source))
        if class_visitor.class_node is None:
            return None
        to_visit_node_for_func = class_visitor.class_node
    else:
        to_visit_node_for_func = ast.parse(_source)

    visitor = FuncNodeVisitor()
    visitor.visit(to_visit_node_for_func)
    if not visitor.func_exist:
        return None
    return visitor.default_values


def retrieve_func_type_hints_from_source(
file_path: str,
func_name: str,
Expand Down
153 changes: 153 additions & 0 deletions tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import decimal
import os
import sys
import typing
Expand Down Expand Up @@ -38,6 +39,8 @@
infer_type,
merge_type,
python_type_to_snow_type,
python_value_str_to_object,
retrieve_func_defaults_from_source,
retrieve_func_type_hints_from_source,
snow_type_to_dtype_str,
)
Expand Down Expand Up @@ -584,6 +587,156 @@ def test_decimal_regular_expression(decimal_word):
assert get_number_precision_scale(f" {decimal_word} ( 2 , 1 ) ") == (2, 1)


@pytest.mark.parametrize("test_from_file", [True, False])
@pytest.mark.parametrize("add_type_hint", [True, False])
@pytest.mark.parametrize(
    "datatype,annotated_value,extracted_value",
    [
        ("int", "None", None),
        ("int", "1", "1"),
        ("bool", "True", "True"),
        ("float", "1.0", "1.0"),
        ("decimal.Decimal", "decimal.Decimal('3.14')", "decimal.Decimal('3.14')"),
        ("decimal.Decimal", "decimal.Decimal(1.0)", "decimal.Decimal(1.0)"),
        ("str", "one", "one"),
        ("str", "None", None),
        ("bytes", "b'one'", "b'one'"),
        ("bytearray", "bytearray('one', 'utf-8')", "bytearray('one', 'utf-8')"),
        ("datetime.date", "datetime.date(2024, 4, 1)", "datetime.date(2024, 4, 1)"),
        (
            "datetime.time",
            "datetime.time(12, 0, second=20, tzinfo=datetime.timezone.utc)",
            "datetime.time(12, 0, second=20, tzinfo=datetime.timezone.utc)",
        ),
        (
            "datetime.datetime",
            "datetime.datetime(2024, 4, 1, 12, 0, 20)",
            "datetime.datetime(2024, 4, 1, 12, 0, 20)",
        ),
        ("List[int]", "[1, 2, 3]", "['1', '2', '3']"),
        ("List[str]", "['a', 'b', 'c']", "['a', 'b', 'c']"),
        (
            "List[List[int]]",
            "[[1, 2, 3], [4, 5, 6]]",
            "[\"['1', '2', '3']\", \"['4', '5', '6']\"]",
        ),
        ("Map[int, str]", "{1: 'a'}", "{'1': 'a'}"),
        ("Map[int, List[str]]", "{1: ['a', 'b']}", "{'1': \"['a', 'b']\"}"),
        ("Variant", "{'key': 'val'}", "{'key': 'val'}"),
        ("Geography", "'POINT(-122.35 37.55)'", "POINT(-122.35 37.55)"),
        ("Geometry", "'POINT(-122.35 37.55)'", "POINT(-122.35 37.55)"),
    ],
)
def test_retrieve_func_defaults_from_source(
    datatype, annotated_value, extracted_value, add_type_hint, test_from_file, tmpdir
):
    """
    Check the defaults extracted for a function without defaults and for a
    function with a single default, read either from a file or from a string.
    """
    func_name = "foo"

    def extract_defaults(code):
        # Feed the code through a temporary file or directly as a string,
        # depending on the parametrized mode.
        if not test_from_file:
            return retrieve_func_defaults_from_source("", func_name, _source=code)
        source_file = tmpdir.join("test_udf.py")
        source_file.write(code)
        return retrieve_func_defaults_from_source(source_file, func_name)

    # A function without optional arguments yields an empty list.
    source_without_defaults = f"""
def {func_name}() -> None:
    return None
"""
    assert extract_defaults(source_without_defaults) == []

    # A function with one optional argument yields exactly that default.
    datatype_str = f": {datatype}" if add_type_hint else ""
    source_with_default = f"""
def {func_name}(x, y {datatype_str} = {annotated_value}) -> None:
    return None
"""
    assert extract_defaults(source_with_default) == [extracted_value]


@pytest.mark.parametrize(
    "value_str,datatype,expected_value",
    [
        ("1", IntegerType(), 1),
        ("True", BooleanType(), True),
        ("1.0", FloatType(), 1.0),
        ("decimal.Decimal('3.14')", DecimalType(), decimal.Decimal("3.14")),
        ("decimal.Decimal(1.0)", DecimalType(), decimal.Decimal(1.0)),
        ("one", StringType(), "one"),
        (None, StringType(), None),
        ("None", StringType(), "None"),
        ("POINT(-122.35 37.55)", GeographyType(), "POINT(-122.35 37.55)"),
        ("POINT(-122.35 37.55)", GeometryType(), "POINT(-122.35 37.55)"),
        ('{"key": "val"}', VariantType(), '{"key": "val"}'),
        ("b'one'", BinaryType(), b"one"),
        ("bytearray('one', 'utf-8')", BinaryType(), bytearray("one", "utf-8")),
        ("datetime.date(2024, 4, 1)", DateType(), date(2024, 4, 1)),
        (
            "datetime.time(12, 0, second=20, tzinfo=datetime.timezone.utc)",
            TimeType(),
            time(12, 0, second=20, tzinfo=timezone.utc),
        ),
        (
            "datetime.datetime(2024, 4, 1, 12, 0, 20)",
            TimestampType(),
            datetime(2024, 4, 1, 12, 0, 20),
        ),
        ("['1', '2', '3']", ArrayType(IntegerType()), [1, 2, 3]),
        ("['a', 'b', 'c']", ArrayType(StringType()), ["a", "b", "c"]),
        ("['a', 'b', 'c']", ArrayType(), ["a", "b", "c"]),
        (
            "[\"['1', '2', '3']\", \"['4', '5', '6']\"]",
            ArrayType(ArrayType(IntegerType())),
            [[1, 2, 3], [4, 5, 6]],
        ),
        ("{'1': 'a'}", MapType(), {"1": "a"}),
        ("{'1': 'a'}", MapType(IntegerType(), StringType()), {1: "a"}),
        (
            "{'1': \"['a', 'b']\"}",
            MapType(IntegerType(), ArrayType(StringType())),
            {1: ["a", "b"]},
        ),
    ],
)
def test_python_value_str_to_object(value_str, datatype, expected_value):
    """Each value string converts to the expected object for its data type."""
    converted = python_value_str_to_object(value_str, datatype)
    assert converted == expected_value


@pytest.mark.parametrize(
    "datatype",
    [
        IntegerType(),
        BooleanType(),
        FloatType(),
        DecimalType(),
        BinaryType(),
        DateType(),
        TimeType(),
        TimestampType(),
        ArrayType(),
        MapType(),
        VariantType(),
        GeographyType(),
        GeometryType(),
    ],
)
def test_python_value_str_to_object_for_none(datatype):
    """
    The text "None" converts to None for every listed data type.

    StringType() is excluded here and tested in test_python_value_str_to_object.
    """
    converted = python_value_str_to_object("None", datatype)
    assert converted is None


def test_python_value_str_to_object_negative():
    """An unsupported data type raises TypeError with the full error message."""
    import re

    # ``match`` is a regular expression: escape the message so that the
    # trailing "()" is compared literally instead of being an empty group.
    expected_message = re.escape(
        "Unsupported data type: invalid type, value thanksgiving by python_value_str_to_object()"
    )
    with pytest.raises(TypeError, match=expected_message):
        python_value_str_to_object("thanksgiving", "invalid type")


def test_retrieve_func_type_hints_from_source():
func_name = "foo"

Expand Down
Loading