|
1 | 1 | # |
2 | 2 | # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. |
3 | 3 | # |
| 4 | +import csv |
| 5 | +import os |
| 6 | +import tempfile |
4 | 7 | from decimal import Decimal |
5 | 8 |
|
| 9 | +import pytest |
| 10 | + |
6 | 11 | from snowflake.snowpark import DataFrame, Row |
7 | 12 | from snowflake.snowpark.functions import lit |
8 | 13 | from snowflake.snowpark.types import ( |
@@ -408,3 +413,92 @@ def test_join_basic(session): |
408 | 413 | ] |
409 | 414 | ) |
410 | 415 | ) |
| 416 | + |
| 417 | + |
| 418 | +@pytest.mark.skipif( |
| 419 | + "config.getoption('local_testing_mode', default=False)", |
| 420 | + reason="session.sql not supported by local testing mode", |
| 421 | +) |
| 422 | +@pytest.mark.parametrize( |
| 423 | + "massive_number, precision", [("9" * 38, 38), ("5" * 20, 20), ("7" * 10, 10)] |
| 424 | +) |
| 425 | +def test_numeric_type_store_precision_and_scale(session, massive_number, precision): |
| 426 | + table_name = Utils.random_table_name() |
| 427 | + try: |
| 428 | + df = session.create_dataframe( |
| 429 | + [Decimal(massive_number)], |
| 430 | + StructType([StructField("large_value", DecimalType(precision, 0), True)]), |
| 431 | + ) |
| 432 | + datatype = df.schema.fields[0].datatype |
| 433 | + assert isinstance(datatype, LongType) |
| 434 | + assert datatype._precision == precision |
| 435 | + |
| 436 | + # after save as table, the precision information is lost, because it is basically save LongType(), which |
| 437 | + # does not have precision information, thus set to default 38. |
| 438 | + df.write.save_as_table(table_name, mode="overwrite", table_type="temp") |
| 439 | + result = session.sql(f"select * from {table_name}") |
| 440 | + session.sql(f"describe table {table_name}").show() |
| 441 | + datatype = result.schema.fields[0].datatype |
| 442 | + assert isinstance(datatype, LongType) |
| 443 | + assert datatype._precision == 38 |
| 444 | + finally: |
| 445 | + session.sql(f"drop table if exists {table_name}").collect() |
| 446 | + |
| 447 | + |
| 448 | +@pytest.mark.skipif( |
| 449 | + "config.getoption('local_testing_mode', default=False)", |
| 450 | + reason="relaxed_types not supported by local testing mode", |
| 451 | +) |
| 452 | +@pytest.mark.parametrize("massive_number", ["9" * 38, "5" * 20, "7" * 10]) |
| 453 | +def test_numeric_type_store_precision_and_scale_read_file(session, massive_number): |
| 454 | + stage_name = Utils.random_stage_name() |
| 455 | + header = ("BIG_NUM",) |
| 456 | + test_data = [(massive_number,)] |
| 457 | + |
| 458 | + def write_csv(data): |
| 459 | + with tempfile.NamedTemporaryFile( |
| 460 | + mode="w+", |
| 461 | + delete=False, |
| 462 | + suffix=".csv", |
| 463 | + newline="", |
| 464 | + ) as file: |
| 465 | + writer = csv.writer(file) |
| 466 | + writer.writerow(header) |
| 467 | + for row in data: |
| 468 | + writer.writerow(row) |
| 469 | + return file.name |
| 470 | + |
| 471 | + file_path = write_csv(test_data) |
| 472 | + |
| 473 | + try: |
| 474 | + Utils.create_stage(session, stage_name, is_temporary=True) |
| 475 | + result = session.file.put( |
| 476 | + file_path, f"@{stage_name}", auto_compress=False, overwrite=True |
| 477 | + ) |
| 478 | + |
| 479 | + # Infer schema from only the short file |
| 480 | + constrained_reader = session.read.options( |
| 481 | + { |
| 482 | + "INFER_SCHEMA": True, |
| 483 | + "INFER_SCHEMA_OPTIONS": {"FILES": [result[0].target]}, |
| 484 | + "PARSE_HEADER": True, |
| 485 | + # Only load the short file |
| 486 | + "PATTERN": f".*{result[0].target}", |
| 487 | + } |
| 488 | + ) |
| 489 | + |
| 490 | + # df1 uses constrained types |
| 491 | + df1 = constrained_reader.csv(f"@{stage_name}/") |
| 492 | + datatype = df1.schema.fields[0].datatype |
| 493 | + assert isinstance(datatype, LongType) |
| 494 | + assert datatype._precision == 38 |
| 495 | + |
| 496 | + finally: |
| 497 | + Utils.drop_stage(session, stage_name) |
| 498 | + if os.path.exists(file_path): |
| 499 | + os.remove(file_path) |
| 500 | + |
| 501 | + |
| 502 | +def test_illegal_argument_intergraltype(): |
| 503 | + with pytest.raises(TypeError, match="takes 0 argument but 1 were given"): |
| 504 | + LongType(b=10) |
0 commit comments