Skip to content
Merged
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def convert_sf_to_sp_type(
if column_type_name == "REAL":
return DoubleType()
if (column_type_name == "FIXED" or column_type_name == "NUMBER") and scale == 0:
return LongType()
return LongType(_precision=precision)
raise NotImplementedError(
"Unsupported type: {}, precision: {}, scale: {}".format(
column_type_name, precision, scale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import datetime
import inspect
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Callable, NamedTuple, Optional, Tuple, Type, Union

import numpy as np
Expand Down Expand Up @@ -121,7 +121,7 @@ class TimedeltaType(SnowparkPandasType, LongType):
two times.
"""

snowpark_type: DataType = LongType()
snowpark_type: DataType = field(default_factory=LongType)
pandas_type: np.dtype = np.dtype("timedelta64[ns]")
types_to_convert_with_from_pandas: Tuple[Type] = ( # type: ignore[assignment]
native_pd.Timedelta,
Expand All @@ -132,9 +132,6 @@ class TimedeltaType(SnowparkPandasType, LongType):
def __init__(self) -> None:
    """Initialize by delegating to the parent chain (SnowparkPandasType, LongType)."""
    super().__init__()

def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

Expand Down
22 changes: 21 additions & 1 deletion src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,27 @@ def _fill_ast(self, ast: proto.DataType) -> None:

# Numeric types
class _IntegralType(_NumericType):
pass
def __init__(self, **kwargs) -> None:
self._precision = kwargs.pop("_precision", None)

if kwargs != {}:
raise TypeError(
f"__init__() takes 0 argument but {len(kwargs.keys())} were given"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Grammar Error in Exception Message

The error message uses the singular "argument" alongside the plural verb "were", so it is grammatically incorrect no matter how many keyword arguments are passed (e.g. "takes 0 argument but 1 were given").

# Current: "takes 0 argument but 2 were given" (incorrect grammar)

Fix: Use proper singular/plural form:

arg_word = "argument" if len(kwargs) == 1 else "arguments"
raise TypeError(
    f"__init__() takes 0 arguments but {len(kwargs)} {arg_word if len(kwargs) != 1 else 'was'} given"
)

Or simply use plural consistently:

raise TypeError(
    f"__init__() takes 0 arguments but {len(kwargs)} were given"
)

Spotted by Graphite Agent

Fix in Graphite


Is this helpful? React 👍 or 👎 to let us know.

)

def __eq__(self, other):
    """Compare two integral types for equality.

    In Snowpark Connect compatible mode every attribute (including the
    internal ``_precision``) must match; otherwise ``_precision`` is
    ignored so that, e.g., two LongTypes inferred with different
    precisions still compare equal.
    """
    if not isinstance(other, self.__class__):
        return False
    if context._is_snowpark_connect_compatible_mode:
        return self.__dict__ == other.__dict__
    lhs = {k: v for k, v in self.__dict__.items() if k != "_precision"}
    rhs = {k: v for k, v in other.__dict__.items() if k != "_precision"}
    return lhs == rhs

def __hash__(self):
    # Must be redefined explicitly: defining __eq__ on a class sets
    # __hash__ to None unless the class also provides its own __hash__.
    # NOTE(review): hashing repr(self) means instances that compare equal
    # via the precision-ignoring __eq__ branch hash equal only if their
    # reprs also ignore _precision — confirm repr does not include it.
    return hash(repr(self))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a custom hash implementation for this class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TL;DR: this fixes some test failures.
When you define `__eq__`, you must also redefine `__hash__`; otherwise Python sets `__hash__` to None and instances become unhashable.



class _FractionalType(_NumericType):
Expand Down
94 changes: 94 additions & 0 deletions tests/integ/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
import csv
import os
import tempfile
from decimal import Decimal

import pytest

from snowflake.snowpark import DataFrame, Row
from snowflake.snowpark.functions import lit
from snowflake.snowpark.types import (
Expand Down Expand Up @@ -408,3 +413,92 @@ def test_join_basic(session):
]
)
)


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="session.sql not supported by local testing mode",
)
@pytest.mark.parametrize(
    "massive_number, precision", [("9" * 38, 38), ("5" * 20, 20), ("7" * 10, 10)]
)
def test_numeric_type_store_precision_and_scale(session, massive_number, precision):
    """An integral column built in memory keeps its inferred precision,
    while a round trip through a table falls back to the default of 38."""
    table_name = Utils.random_table_name()
    try:
        schema = StructType(
            [StructField("large_value", DecimalType(precision, 0), True)]
        )
        df = session.create_dataframe([Decimal(massive_number)], schema)

        in_memory_type = df.schema.fields[0].datatype
        assert isinstance(in_memory_type, LongType)
        assert in_memory_type._precision == precision

        # Saving writes a plain LongType column, so the inferred precision
        # is lost; reading back yields the default precision of 38.
        df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
        round_trip = session.sql(f"select * from {table_name}")
        session.sql(f"describe table {table_name}").show()
        stored_type = round_trip.schema.fields[0].datatype
        assert isinstance(stored_type, LongType)
        assert stored_type._precision == 38
    finally:
        session.sql(f"drop table if exists {table_name}").collect()


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="relaxed_types not supported by local testing mode",
)
@pytest.mark.parametrize("massive_number", ["9" * 38, "5" * 20, "7" * 10])
def test_numeric_type_store_precision_and_scale_read_file(session, massive_number):
    """A numeric column read from a staged CSV is inferred as LongType
    with the default precision of 38."""
    stage_name = Utils.random_stage_name()
    header = ("BIG_NUM",)
    test_data = [(massive_number,)]

    def _dump_csv(rows):
        # Write header + rows to a temp .csv and hand back its path.
        with tempfile.NamedTemporaryFile(
            mode="w+",
            delete=False,
            suffix=".csv",
            newline="",
        ) as handle:
            writer = csv.writer(handle)
            writer.writerow(header)
            writer.writerows(rows)
        return handle.name

    file_path = _dump_csv(test_data)

    try:
        Utils.create_stage(session, stage_name, is_temporary=True)
        result = session.file.put(
            file_path, f"@{stage_name}", auto_compress=False, overwrite=True
        )

        # Restrict schema inference (and the subsequent load) to the one
        # file we just uploaded.
        constrained_reader = session.read.options(
            {
                "INFER_SCHEMA": True,
                "INFER_SCHEMA_OPTIONS": {"FILES": [result[0].target]},
                "PARSE_HEADER": True,
                # Only load the short file
                "PATTERN": f".*{result[0].target}",
            }
        )

        inferred_df = constrained_reader.csv(f"@{stage_name}/")
        inferred_type = inferred_df.schema.fields[0].datatype
        assert isinstance(inferred_type, LongType)
        assert inferred_type._precision == 38

    finally:
        Utils.drop_stage(session, stage_name)
        if os.path.exists(file_path):
            os.remove(file_path)


def test_illegal_argument_intergraltype():
    """Passing any unexpected keyword argument to an integral type's
    constructor raises TypeError."""
    # NOTE(review): the matched text ("0 argument but 1 were given") mirrors
    # the grammatically-off message raised by _IntegralType.__init__; if that
    # wording is fixed, this match string must be updated in the same change.
    with pytest.raises(TypeError, match="takes 0 argument but 1 were given"):
        LongType(b=10)
Loading