Skip to content

Commit bcba153

Browse files
SNOW-2866776: add private var in numeric type (#4022)
1 parent 2824029 commit bcba153

File tree

4 files changed

+128
-5
lines changed

4 files changed

+128
-5
lines changed

src/snowflake/snowpark/_internal/type_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def convert_sf_to_sp_type(
307307
if column_type_name == "REAL":
308308
return DoubleType()
309309
if (column_type_name == "FIXED" or column_type_name == "NUMBER") and scale == 0:
310-
return LongType()
310+
return LongType(_precision=precision)
311311
raise NotImplementedError(
312312
"Unsupported type: {}, precision: {}, scale: {}".format(
313313
column_type_name, precision, scale

src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
import inspect
77
from abc import ABCMeta, abstractmethod
88
from dataclasses import dataclass
9-
from typing import Any, Callable, NamedTuple, Optional, Tuple, Type, Union
9+
from typing import Any, Callable, NamedTuple, Optional, Tuple, Type, Union, ClassVar
1010

1111
import numpy as np
1212
import pandas as native_pd
1313

14+
from snowflake.snowpark import context
1415
from snowflake.snowpark.column import Column
1516
from snowflake.snowpark.types import DataType, LongType
1617

@@ -121,7 +122,7 @@ class TimedeltaType(SnowparkPandasType, LongType):
121122
two times.
122123
"""
123124

124-
snowpark_type: DataType = LongType()
125+
snowpark_type: ClassVar[DataType] = LongType()
125126
pandas_type: np.dtype = np.dtype("timedelta64[ns]")
126127
types_to_convert_with_from_pandas: Tuple[Type] = ( # type: ignore[assignment]
127128
native_pd.Timedelta,
@@ -133,7 +134,15 @@ def __init__(self) -> None:
133134
super().__init__()
134135

135136
def __eq__(self, other: Any) -> bool:
136-
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
137+
def filtered(d: dict) -> dict:
138+
return {k: v for k, v in d.items() if k != "_precision"}
139+
140+
if context._is_snowpark_connect_compatible_mode:
141+
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
142+
else:
143+
return isinstance(other, self.__class__) and filtered(
144+
self.__dict__
145+
) == filtered(other.__dict__)
137146

138147
def __ne__(self, other: Any) -> bool:
139148
return not self.__eq__(other)

src/snowflake/snowpark/types.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,27 @@ def _fill_ast(self, ast: proto.DataType) -> None:
370370

371371
# Numeric types
372372
class _IntegralType(_NumericType):
373-
pass
373+
def __init__(self, **kwargs) -> None:
374+
self._precision = kwargs.pop("_precision", None)
375+
376+
if kwargs != {}:
377+
raise TypeError(
378+
f"__init__() takes 0 argument but {len(kwargs.keys())} were given"
379+
)
380+
381+
def __eq__(self, other):
382+
def filtered(d: dict) -> dict:
383+
return {k: v for k, v in d.items() if k != "_precision"}
384+
385+
if context._is_snowpark_connect_compatible_mode:
386+
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
387+
else:
388+
return isinstance(other, self.__class__) and filtered(
389+
self.__dict__
390+
) == filtered(other.__dict__)
391+
392+
def __hash__(self):
393+
return hash(repr(self))
374394

375395

376396
class _FractionalType(_NumericType):

tests/integ/test_datatypes.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
#
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
4+
import csv
5+
import os
6+
import tempfile
47
from decimal import Decimal
58

9+
import pytest
10+
611
from snowflake.snowpark import DataFrame, Row
712
from snowflake.snowpark.functions import lit
813
from snowflake.snowpark.types import (
@@ -408,3 +413,92 @@ def test_join_basic(session):
408413
]
409414
)
410415
)
416+
417+
418+
@pytest.mark.skipif(
419+
"config.getoption('local_testing_mode', default=False)",
420+
reason="session.sql not supported by local testing mode",
421+
)
422+
@pytest.mark.parametrize(
423+
"massive_number, precision", [("9" * 38, 38), ("5" * 20, 20), ("7" * 10, 10)]
424+
)
425+
def test_numeric_type_store_precision_and_scale(session, massive_number, precision):
426+
table_name = Utils.random_table_name()
427+
try:
428+
df = session.create_dataframe(
429+
[Decimal(massive_number)],
430+
StructType([StructField("large_value", DecimalType(precision, 0), True)]),
431+
)
432+
datatype = df.schema.fields[0].datatype
433+
assert isinstance(datatype, LongType)
434+
assert datatype._precision == precision
435+
436+
# after saving as a table, the precision information is lost, because this effectively saves a plain LongType(), which
437+
does not have precision information, so the precision falls back to the default of 38.
438+
df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
439+
result = session.sql(f"select * from {table_name}")
440+
session.sql(f"describe table {table_name}").show()
441+
datatype = result.schema.fields[0].datatype
442+
assert isinstance(datatype, LongType)
443+
assert datatype._precision == 38
444+
finally:
445+
session.sql(f"drop table if exists {table_name}").collect()
446+
447+
448+
@pytest.mark.skipif(
449+
"config.getoption('local_testing_mode', default=False)",
450+
reason="relaxed_types not supported by local testing mode",
451+
)
452+
@pytest.mark.parametrize("massive_number", ["9" * 38, "5" * 20, "7" * 10])
453+
def test_numeric_type_store_precision_and_scale_read_file(session, massive_number):
454+
stage_name = Utils.random_stage_name()
455+
header = ("BIG_NUM",)
456+
test_data = [(massive_number,)]
457+
458+
def write_csv(data):
459+
with tempfile.NamedTemporaryFile(
460+
mode="w+",
461+
delete=False,
462+
suffix=".csv",
463+
newline="",
464+
) as file:
465+
writer = csv.writer(file)
466+
writer.writerow(header)
467+
for row in data:
468+
writer.writerow(row)
469+
return file.name
470+
471+
file_path = write_csv(test_data)
472+
473+
try:
474+
Utils.create_stage(session, stage_name, is_temporary=True)
475+
result = session.file.put(
476+
file_path, f"@{stage_name}", auto_compress=False, overwrite=True
477+
)
478+
479+
# Infer the schema from only the uploaded file
480+
constrained_reader = session.read.options(
481+
{
482+
"INFER_SCHEMA": True,
483+
"INFER_SCHEMA_OPTIONS": {"FILES": [result[0].target]},
484+
"PARSE_HEADER": True,
485+
# Only load the uploaded file
486+
"PATTERN": f".*{result[0].target}",
487+
}
488+
)
489+
490+
# df1 uses constrained types
491+
df1 = constrained_reader.csv(f"@{stage_name}/")
492+
datatype = df1.schema.fields[0].datatype
493+
assert isinstance(datatype, LongType)
494+
assert datatype._precision == 38
495+
496+
finally:
497+
Utils.drop_stage(session, stage_name)
498+
if os.path.exists(file_path):
499+
os.remove(file_path)
500+
501+
502+
def test_illegal_argument_integraltype():
503+
with pytest.raises(TypeError, match="takes 0 argument but 1 were given"):
504+
LongType(b=10)

0 commit comments

Comments
 (0)