diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8a380f682ab..78c6bc1c7d4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -61,6 +61,7 @@
 - Added support for `Index.all` and `Index.any`.
 - Added support for `Series.dt.is_year_start` and `Series.dt.is_year_end`.
 - Added support for `Series.dt.is_quarter_start` and `Series.dt.is_quarter_end`.
+- Added support for lazy `DatetimeIndex`.
 - Added support for `Series.argmax` and `Series.argmin`.
 - Added support for `Series.dt.is_leap_year`.
 
diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py
index cea2016e54c..dcf5db871a2 100644
--- a/src/snowflake/snowpark/modin/pandas/__init__.py
+++ b/src/snowflake/snowpark/modin/pandas/__init__.py
@@ -39,7 +39,6 @@
     CategoricalDtype,
     CategoricalIndex,
     DateOffset,
-    DatetimeIndex,
     DatetimeTZDtype,
     ExcelWriter,
     Flags,
@@ -156,6 +155,7 @@
 import snowflake.snowpark.modin.plugin.extensions.pd_overrides  # isort: skip  # noqa: E402,F401
 from snowflake.snowpark.modin.plugin.extensions.pd_overrides import (  # isort: skip  # noqa: E402,F401
     Index,
+    DatetimeIndex,
 )
 import snowflake.snowpark.modin.plugin.extensions.dataframe_extensions  # isort: skip  # noqa: E402,F401
 import snowflake.snowpark.modin.plugin.extensions.dataframe_overrides  # isort: skip  # noqa: E402,F401
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
new file mode 100644
index 00000000000..0ce6664e0f2
--- /dev/null
+++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
@@ -0,0 +1,152 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# Code in this file may constitute partial or total reimplementation, or modification of
+# existing code originally distributed by the Modin project, under the Apache License,
+# Version 2.0.
+
+"""
+Module houses ``DatetimeIndex`` class, that is a distributed version of
+``pandas.DatetimeIndex``.
+""" + +from __future__ import annotations + +import numpy as np +import pandas as native_pd +from pandas._libs import lib +from pandas._typing import ArrayLike, Dtype, Frequency, Hashable, TimeAmbiguous + +from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( + SnowflakeQueryCompiler, +) +from snowflake.snowpark.modin.plugin.extensions.index import Index + +_CONSTRUCTOR_DEFAULTS = { + "freq": lib.no_default, + "tz": lib.no_default, + "normalize": lib.no_default, + "closed": lib.no_default, + "ambiguous": "raise", + "dayfirst": False, + "yearfirst": False, + "dtype": None, + "copy": False, + "name": None, +} + + +class DatetimeIndex(Index): + + # Equivalent index type in native pandas + _NATIVE_INDEX_TYPE = native_pd.DatetimeIndex + + def __new__(cls, *args, **kwargs): + """ + Create new instance of DatetimeIndex. This overrides behavior of Index.__new__. + Args: + *args: arguments. + **kwargs: keyword arguments. + + Returns: + New instance of DatetimeIndex. + """ + return object.__new__(cls) + + def __init__( + self, + data: ArrayLike | SnowflakeQueryCompiler | None = None, + freq: Frequency | lib.NoDefault = _CONSTRUCTOR_DEFAULTS["freq"], + tz=_CONSTRUCTOR_DEFAULTS["tz"], + normalize: bool | lib.NoDefault = _CONSTRUCTOR_DEFAULTS["normalize"], + closed=_CONSTRUCTOR_DEFAULTS["closed"], + ambiguous: TimeAmbiguous = _CONSTRUCTOR_DEFAULTS["ambiguous"], + dayfirst: bool = _CONSTRUCTOR_DEFAULTS["dayfirst"], + yearfirst: bool = _CONSTRUCTOR_DEFAULTS["yearfirst"], + dtype: Dtype | None = _CONSTRUCTOR_DEFAULTS["dtype"], + copy: bool = _CONSTRUCTOR_DEFAULTS["copy"], + name: Hashable | None = _CONSTRUCTOR_DEFAULTS["name"], + ) -> None: + """ + Immutable ndarray-like of datetime64 data. + + Parameters + ---------- + data : array-like (1-dimensional) or snowflake query compiler + Datetime-like data to construct index with. + freq : str or pandas offset object, optional + One of pandas date offset strings or corresponding objects. The string + 'infer' can be passed in order to set the frequency of the index as the + inferred frequency upon creation. + tz : pytz.timezone or dateutil.tz.tzfile or datetime.tzinfo or str + Set the Timezone of the data. + normalize : bool, default False + Normalize start/end dates to midnight before generating date range. + closed : {'left', 'right'}, optional + Set whether to include `start` and `end` that are on the + boundary. The default includes boundary points on either end. + ambiguous : 'infer', bool-ndarray, 'NaT', default 'raise' + When clocks moved backward due to DST, ambiguous times may arise. + For example in Central European Time (UTC+01), when going from 03:00 + DST to 02:00 non-DST, 02:30:00 local time occurs both at 00:30:00 UTC + and at 01:30:00 UTC. In such a situation, the `ambiguous` parameter + dictates how ambiguous times should be handled. + + - 'infer' will attempt to infer fall dst-transition hours based on + order + - bool-ndarray where True signifies a DST time, False signifies a + non-DST time (note that this flag is only applicable for ambiguous + times) + - 'NaT' will return NaT where there are ambiguous times + - 'raise' will raise an AmbiguousTimeError if there are ambiguous times. + dayfirst : bool, default False + If True, parse dates in `data` with the day first order. + yearfirst : bool, default False + If True parse dates in `data` with the year first order. + dtype : numpy.dtype or DatetimeTZDtype or str, default None + Note that the only NumPy dtype allowed is `datetime64[ns]`. 
+        copy : bool, default False
+            Make a copy of input ndarray.
+        name : label, default None
+            Name to be stored in the index.
+
+        Examples
+        --------
+        >>> idx = pd.DatetimeIndex(["1/1/2020 10:00:00+00:00", "2/1/2020 11:00:00+00:00"], tz="America/Los_Angeles")
+        >>> idx
+        DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
+        """
+        if isinstance(data, SnowflakeQueryCompiler):
+            # Raise an error if the underlying type is not TimestampType.
+            current_dtype = data.index_dtypes[0]
+            if not current_dtype == np.dtype("datetime64[ns]"):
+                raise ValueError(
+                    "DatetimeIndex can only be created from a query compiler with TimestampType."
+                )
+        kwargs = {
+            "freq": freq,
+            "tz": tz,
+            "normalize": normalize,
+            "closed": closed,
+            "ambiguous": ambiguous,
+            "dayfirst": dayfirst,
+            "yearfirst": yearfirst,
+            "dtype": dtype,
+            "copy": copy,
+            "name": name,
+        }
+        self._init_index(data, _CONSTRUCTOR_DEFAULTS, **kwargs)
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py
index c0f793554dd..c817a787af8 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -23,6 +23,7 @@
 
 from __future__ import annotations
 
+from functools import cached_property
 from typing import Any, Callable, Hashable, Iterator, Literal
 
 import modin
@@ -32,7 +33,7 @@
 from pandas._typing import ArrayLike, DtypeObj, NaPosition
 from pandas.core.arrays import ExtensionArray
 from pandas.core.dtypes.base import ExtensionDtype
-from pandas.core.dtypes.common import pandas_dtype
+from pandas.core.dtypes.common import is_datetime64_any_dtype, pandas_dtype
 
 from snowflake.snowpark.modin.pandas import DataFrame, Series
 from snowflake.snowpark.modin.pandas.base import BasePandasDataset
@@ -47,8 +48,68 @@
 )
 from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
 
+_CONSTRUCTOR_DEFAULTS = {
+    "dtype": None,
+    "copy": False,
+    "name": None,
+    "tupleize_cols": True,
+}
+
 
 class Index(metaclass=TelemetryMeta):
+
+    # Equivalent index type in native pandas
+    _NATIVE_INDEX_TYPE = native_pd.Index
+
+    def __new__(
+        cls,
+        data: ArrayLike | SnowflakeQueryCompiler | None = None,
+        dtype: str | np.dtype | ExtensionDtype | None = _CONSTRUCTOR_DEFAULTS["dtype"],
+        copy: bool = _CONSTRUCTOR_DEFAULTS["copy"],
+        name: object = _CONSTRUCTOR_DEFAULTS["name"],
+        tupleize_cols: bool = _CONSTRUCTOR_DEFAULTS["tupleize_cols"],
+    ) -> Index:
+        """
+        Override __new__ method to control new instance creation of Index.
+        Depending on the data type, it will create an Index or DatetimeIndex instance.
+
+        Parameters
+        ----------
+        data : array-like (1-dimensional)
+        dtype : str, numpy.dtype, or ExtensionDtype, optional
+            Data type for the output Index. If not specified, this will be
+            inferred from `data`.
+            See the :ref:`user guide <basics.dtypes>` for more usages.
+        copy : bool, default False
+            Copy input data.
+        name : object
+            Name to be stored in the index.
+        tupleize_cols : bool (default: True)
+            When True, attempt to create a MultiIndex if possible.
+
+        Returns
+        -------
+        New instance of Index or DatetimeIndex.
+        A DatetimeIndex object is returned if the column/data has a datetime type.
+ """ + from snowflake.snowpark.modin.plugin.extensions.datetime_index import ( + DatetimeIndex, + ) + + orig_data = data + data = data._query_compiler if isinstance(data, BasePandasDataset) else data + + if isinstance(data, SnowflakeQueryCompiler): + dtype = data.index_dtypes[0] + if dtype == np.dtype("datetime64[ns]"): + return DatetimeIndex(orig_data) + return object.__new__(cls) + else: + index = native_pd.Index(data, dtype, copy, name, tupleize_cols) + if isinstance(index, native_pd.DatetimeIndex): + return DatetimeIndex(orig_data) + return object.__new__(cls) + def __init__( self, data: ArrayLike @@ -56,10 +117,10 @@ def __init__( | Series | SnowflakeQueryCompiler | None = None, - dtype: str | np.dtype | ExtensionDtype | None = None, - copy: bool = False, - name: object = None, - tupleize_cols: bool = True, + dtype: str | np.dtype | ExtensionDtype | None = _CONSTRUCTOR_DEFAULTS["dtype"], + copy: bool = _CONSTRUCTOR_DEFAULTS["copy"], + name: object = _CONSTRUCTOR_DEFAULTS["name"], + tupleize_cols: bool = _CONSTRUCTOR_DEFAULTS["tupleize_cols"], ) -> None: """ Immutable sequence used for indexing and alignment. @@ -97,19 +158,33 @@ def __init__( >>> pd.Index([1, 2, 3], dtype="uint8") Index([1, 2, 3], dtype='int64') """ + kwargs = { + "dtype": dtype, + "copy": copy, + "name": name, + "tupleize_cols": tupleize_cols, + } + self._init_index(data, _CONSTRUCTOR_DEFAULTS, **kwargs) + + def _init_index( + self, + data: ArrayLike | SnowflakeQueryCompiler | None, + ctor_defaults: dict, + **kwargs: Any, + ): self._parent = data if isinstance(data, BasePandasDataset) else None data = data._query_compiler if isinstance(data, BasePandasDataset) else data + if isinstance(data, SnowflakeQueryCompiler): + # Raise warning if `data` is query compiler with non-default arguments. + for arg_name, arg_value in kwargs.items(): + assert ( + arg_value == ctor_defaults[arg_name] + ), f"Non-default argument '{arg_name}={arg_value}' when constructing Index with query compiler" if isinstance(data, SnowflakeQueryCompiler): qc = data else: qc = DataFrame( - index=native_pd.Index( - data=data, - dtype=dtype, - copy=copy, - name=name, - tupleize_cols=tupleize_cols, - ) + index=self._NATIVE_INDEX_TYPE(data=data, **kwargs) )._query_compiler self._query_compiler = qc.drop(columns=qc.columns) @@ -135,7 +210,7 @@ def __getattr__(self, key: str) -> Any: return object.__getattribute__(self, key) except AttributeError as err: if not key.startswith("_"): - native_index = native_pd.Index([]) + native_index = self._NATIVE_INDEX_TYPE([]) if hasattr(native_index, key): # Any methods that not supported by the current Index.py but exist in a # native pandas index object should raise a not implemented error for now. @@ -164,6 +239,13 @@ def to_pandas( statement_params=statement_params, **kwargs ) + @cached_property + def __constructor__(self): + """ + Returns: Type of the instance. + """ + return type(self) + @property def values(self) -> ArrayLike: """ @@ -327,7 +409,7 @@ def unique(self, level: Hashable | None = None) -> Index: raise IndexError( f"Too many levels: Index has only 1 level, {level} is not a valid level number." 
             )
-        return Index(
+        return self.__constructor__(
             data=self._query_compiler.groupby_agg(
                 by=self._query_compiler.get_index_names(axis=0),
                 agg_func={},
@@ -421,6 +503,15 @@ def astype(self, dtype: str | type | ExtensionDtype, copy: bool = True) -> Index
             column: dtype for column in self._query_compiler.get_index_names()
         }
         new_query_compiler = self._query_compiler.astype_index(col_dtypes)
+
+        if is_datetime64_any_dtype(dtype):
+            # local import to avoid circular dependency.
+            from snowflake.snowpark.modin.plugin.extensions.datetime_index import (
+                DatetimeIndex,
+            )
+
+            return DatetimeIndex(data=new_query_compiler)
+
         return Index(data=new_query_compiler)
 
     @property
@@ -507,7 +598,7 @@ def set_names(
         # TODO: SNOW-1458122 implement set_names
         WarningMessage.index_to_pandas_warning("set_names")
         if not inplace:
-            return Index(
+            return self.__constructor__(
                 self.to_pandas().set_names(names, level=level, inplace=inplace)
             )
         return self.to_pandas().set_names(names, level=level, inplace=inplace)
@@ -770,7 +861,7 @@ def copy(
         False
         """
         WarningMessage.ignored_argument(operation="copy", argument="deep", message="")
-        return Index(self._query_compiler.copy(), name=name)
+        return self.__constructor__(self._query_compiler.copy(), name=name)
 
     @index_not_implemented()
     def delete(self) -> None:
@@ -826,7 +917,7 @@ def drop(
         """
         # TODO: SNOW-1458146 implement drop
         WarningMessage.index_to_pandas_warning("drop")
-        return Index(self.to_pandas().drop(labels=labels, errors=errors))
+        return self.__constructor__(self.to_pandas().drop(labels=labels, errors=errors))
 
     @index_not_implemented()
     def drop_duplicates(self) -> None:
@@ -974,13 +1065,13 @@ def equals(self, other: Any) -> bool:
         if self is other:
             return True
 
-        if not isinstance(other, (Index, native_pd.Index)):
+        if not isinstance(other, (type(self), self._NATIVE_INDEX_TYPE)):
             return False
 
-        if isinstance(other, native_pd.Index):
+        if isinstance(other, self._NATIVE_INDEX_TYPE):
             # Same as DataFrame/Series equals. Convert native Index to Snowpark pandas
             # Index for comparison.
-            other = Index(other)
+            other = self.__constructor__(other)
 
         return self._query_compiler.index_equals(other._query_compiler)
 
@@ -1786,7 +1877,7 @@ def sort_values(
             key=key,
             include_indexer=return_indexer,
         )
-        index = Index(res)
+        index = self.__constructor__(res)
         if return_indexer:
             # When `return_indexer` is True, `res` is a query compiler with one index column
             # and one data column.
@@ -1866,7 +1957,7 @@ def intersection(self, other: Any, sort: bool = False) -> Index:
         """
         # TODO: SNOW-1458151 implement intersection
         WarningMessage.index_to_pandas_warning("intersection")
-        return Index(
+        return self.__constructor__(
             self.to_pandas().intersection(
                 other=try_convert_index_to_native(other), sort=sort
             )
@@ -1919,7 +2010,7 @@ def union(self, other: Any, sort: bool = False) -> Index:
         # TODO: SNOW-1458149 implement union w/o sort
         # TODO: SNOW-1468240 implement union w/ sort
         WarningMessage.index_to_pandas_warning("union")
-        return Index(
+        return self.__constructor__(
             self.to_pandas().union(other=try_convert_index_to_native(other), sort=sort)
         )
 
@@ -1958,7 +2049,7 @@ def difference(self, other: Any, sort: Any = None) -> Index:
         """
         # TODO: SNOW-1458152 implement difference
        WarningMessage.index_to_pandas_warning("difference")
-        return Index(
+        return self.__constructor__(
            self.to_pandas().difference(try_convert_index_to_native(other), sort=sort)
         )
 
@@ -1990,7 +2081,7 @@ def _get_indexer_strict(self, key: Any, axis_name: str) -> tuple[Index, np.ndarray]:
         """
         WarningMessage.index_to_pandas_warning("_get_indexer_strict")
         tup = self.to_pandas()._get_indexer_strict(key=key, axis_name=axis_name)
-        return Index(tup[0]), tup[1]
+        return self.__constructor__(tup[0]), tup[1]
 
     def get_level_values(self, level: int | str) -> Index:
         """
@@ -2025,7 +2116,7 @@ def get_level_values(self, level: int | str) -> Index:
         Index(['a', 'b', 'c'], dtype='object')
         """
         WarningMessage.index_to_pandas_warning("get_level_values")
-        return Index(self.to_pandas().get_level_values(level=level))
+        return self.__constructor__(self.to_pandas().get_level_values(level=level))
 
     @index_not_implemented()
     def isin(self) -> None:
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py
index d9cf43f6cfd..3515baaee3a 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py
@@ -40,9 +40,13 @@
 if TYPE_CHECKING:  # pragma: no cover
     import csv
 
+from snowflake.snowpark.modin.plugin.extensions.datetime_index import (  # noqa: F401
+    DatetimeIndex,
+)
 from snowflake.snowpark.modin.plugin.extensions.index import Index  # noqa: F401
 
 register_pd_accessor("Index")(Index)
+register_pd_accessor("DatetimeIndex")(DatetimeIndex)
 
 
 @_inherit_docstrings(native_pd.read_csv, apilink="pandas.read_csv")
diff --git a/tests/integ/modin/frame/test_loc.py b/tests/integ/modin/frame/test_loc.py
index 381a3034c12..f258f261b51 100644
--- a/tests/integ/modin/frame/test_loc.py
+++ b/tests/integ/modin/frame/test_loc.py
@@ -3860,7 +3860,7 @@ def test_df_partial_string_indexing(ops):
 @sql_count_checker(query_count=1)
 def test_df_partial_string_indexing_with_timezone():
     native_df = native_pd.DataFrame(
-        [0], index=pd.DatetimeIndex(["2019-01-01"], tz="America/Los_Angeles")
+        [0], index=native_pd.DatetimeIndex(["2019-01-01"], tz="America/Los_Angeles")
     )
     snowpark_df = pd.DataFrame(native_df)
 
diff --git a/tests/integ/modin/index/conftest.py b/tests/integ/modin/index/conftest.py
index b7b93b7a03c..3029bd79f55 100644
--- a/tests/integ/modin/index/conftest.py
+++ b/tests/integ/modin/index/conftest.py
@@ -12,6 +12,10 @@
         index=native_pd.Index([[1, 2], [2, 3], [3, 4]]),
     ),
     native_pd.DataFrame([1]),
+    native_pd.DataFrame(
+        data={"col1": [1, 2, 3], "col2": [3, 4, 5]},
+        index=native_pd.DatetimeIndex(["2024-01-01", "2024-02-01", "2024-03-01"]),
+    ),
 ]
 
 NATIVE_INDEX_TEST_DATA = [
@@ -23,6 +27,15 @@
     native_pd.Index([1]),
     native_pd.Index(["a", "b", 1, 2]),
     native_pd.Index(["a", "b", "c", "d"]),
+    native_pd.DatetimeIndex(
+        ["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"],
+        tz="America/Los_Angeles",
+    ),
+    native_pd.DatetimeIndex(
+        ["2020-01-01 10:00:00+05:00", "2020-02-01 11:00:00+05:00"],
+        tz="America/Los_Angeles",
+    ),
+    native_pd.DatetimeIndex([1262347200000000000, 1262347400000000000]),
 ]
 
 NATIVE_INDEX_UNIQUE_TEST_DATA = [
@@ -34,4 +47,12 @@
     native_pd.Index([5, None, 7, None]),
     native_pd.Index([1]),
     native_pd.Index(["a", "b", 1, 2, None, "a", 2], name="mixed index"),
+    native_pd.DatetimeIndex(
+        ["2020-01-01 10:00:00+00:00", "2020-02-01 11:00:00+00:00"],
+        tz="America/Los_Angeles",
+    ),
+    native_pd.DatetimeIndex(
+        ["2020-01-01 10:00:00+00:00", "2020-01-01 10:00:00+00:00"],
+        tz="America/Los_Angeles",
+    ),
 ]
diff --git a/tests/integ/modin/index/test_astype.py b/tests/integ/modin/index/test_astype.py
index 97013c5d211..de0464d068a 100644
--- a/tests/integ/modin/index/test_astype.py
+++ b/tests/integ/modin/index/test_astype.py
@@ -40,6 +40,7 @@
         (native_pd.Index(["a", "b", "c", 1, 2, 4], dtype="O"), str),
         (native_pd.Index([1, 2, 3, 4], dtype="O"), np.int64),
         (native_pd.Index([1.11, 2.1111, 3.0002, 4.111], dtype=object), np.float64),
+        (native_pd.Index(["2024-01-01 10:00:00"], dtype=object), "datetime64[ns]"),
     ],
 )
 def test_index_astype(index, type):
diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py
new file mode 100644
index 00000000000..83f8630dcdc
--- /dev/null
+++ b/tests/integ/modin/index/test_datetime_index_methods.py
@@ -0,0 +1,96 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+import modin.pandas as pd
+import pandas as native_pd
+import pytest
+
+import snowflake.snowpark.modin.plugin  # noqa: F401
+from tests.integ.modin.sql_counter import sql_count_checker
+from tests.integ.modin.utils import (
+    assert_frame_equal,
+    assert_index_equal,
+    assert_series_equal,
+)
+
+
+@sql_count_checker(query_count=3)
+def test_datetime_index_construction():
+    # create from native pandas datetime index.
+    index = native_pd.DatetimeIndex(["2021-01-01", "2021-01-02", "2021-01-03"])
+    snow_index = pd.Index(index)
+    assert isinstance(snow_index, pd.DatetimeIndex)
+
+    # create from query compiler with timestamp type.
+    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=index)
+    snow_index = df.index
+    assert isinstance(snow_index, pd.DatetimeIndex)
+
+    # create from snowpark pandas datetime index.
+    snow_index = pd.Index(pd.DatetimeIndex([123]))
+    assert isinstance(snow_index, pd.DatetimeIndex)
+
+
+@pytest.mark.skip(reason="SNOW-1616989: Fix datetime index construction from int")
+@sql_count_checker(query_count=1)
+def test_datetime_index_construction_from_int():
+    snow_index = pd.DatetimeIndex(pd.Index([1, 2, 3]))
+    native_index = native_pd.DatetimeIndex(native_pd.Index([1, 2, 3]))
+    assert_index_equal(snow_index, native_index)
+
+
+@sql_count_checker(query_count=0)
+def test_datetime_index_construction_negative():
+    # Try to create a datetime index from a query compiler with int type.
+    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
+    msg = "DatetimeIndex can only be created from a query compiler with TimestampType"
+    with pytest.raises(ValueError, match=msg):
+        pd.DatetimeIndex(df._query_compiler)
+
+
+@sql_count_checker(query_count=0)
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {"freq": "M"},
+        {"tz": "UTC"},
+        {"normalize": True},
+        {"closed": "left"},
+        {"ambiguous": "infer"},
+        {"dayfirst": True},
+        {"yearfirst": True},
+        {"dtype": "int"},
+        {"copy": True},
+        {"name": "abc"},
+    ],
+)
+def test_non_default_args(kwargs):
+    idx = pd.DatetimeIndex(["2014-01-01 10:00:00"])
+
+    name = list(kwargs.keys())[0]
+    value = list(kwargs.values())[0]
+    msg = f"Non-default argument '{name}={value}' when constructing Index with query compiler"
+    with pytest.raises(AssertionError, match=msg):
+        pd.DatetimeIndex(data=idx._query_compiler, **kwargs)
+
+
+@sql_count_checker(query_count=6)
+def test_index_parent():
+    """
+    Check whether the parent field in Index is updated properly.
+    """
+    native_idx1 = native_pd.DatetimeIndex(["2024-01-01"], name="id1")
+    native_idx2 = native_pd.DatetimeIndex(["2024-01-01", "2024-01-02"], name="id2")
+
+    # DataFrame case.
+    df = pd.DataFrame({"A": [1]}, index=native_idx1)
+    snow_idx = df.index
+    assert_frame_equal(snow_idx._parent, df)
+    assert_index_equal(snow_idx, native_idx1)
+
+    # Series case.
+    s = pd.Series([1, 2], index=native_idx2, name="zyx")
+    snow_idx = s.index
+    assert_series_equal(snow_idx._parent, s)
+    assert_index_equal(snow_idx, native_idx2)
diff --git a/tests/integ/modin/index/test_equals.py b/tests/integ/modin/index/test_equals.py
index ccc2165ae9e..70c611a90ca 100644
--- a/tests/integ/modin/index/test_equals.py
+++ b/tests/integ/modin/index/test_equals.py
@@ -13,14 +13,49 @@
 @pytest.mark.parametrize(
     "lhs, rhs, expected",
     [
-        ([], [], True),  # empty indices
-        ([None], [None], True),  # none indices
-        ([1, 2, 3], [1, 2, 3], True),
-        ([1, 2, None], [1, 2, None], True),  # nulls are considered equal
-        ([1, 2, 3], [1.0, 2.0, 3.0], True),  # type is ignored
-        ([1, 2, 3], [1, 3, 2], False),  # different order
-        ([1, 2, 3], [1, 2, 3, 4], False),  # extra value in right
-        ([1, 2, 3, 4], [1, 2, 3], False),  # extra value in left
+        (native_pd.Index([]), native_pd.Index([]), True),  # empty indices
+        (native_pd.Index([None]), native_pd.Index([None]), True),  # none indices
+        (native_pd.Index([1, 2, 3]), native_pd.Index([1, 2, 3]), True),
+        (
+            native_pd.Index([1, 2, None]),
+            native_pd.Index([1, 2, None]),
+            True,
+        ),  # nulls are equal
+        (
+            native_pd.Index([1, 2, 3]),
+            native_pd.Index([1.0, 2.0, 3.0]),
+            True,
+        ),  # type is ignored
+        (
+            native_pd.Index([1, 2, 3]),
+            native_pd.Index([1, 3, 2]),
+            False,
+        ),  # different order
+        (
+            native_pd.Index([1, 2, 3]),
+            native_pd.Index([1, 2, 3, 4]),
+            False,
+        ),  # extra value in right
+        (
+            native_pd.Index([1, 2, 3, 4]),
+            native_pd.Index([1, 2, 3]),
+            False,
+        ),  # extra value in left
+        (
+            native_pd.DatetimeIndex(["2024-01-01 03:00:00+00:00"]),
+            native_pd.DatetimeIndex(["2024-01-01 03:00:00+00:00"]),
+            True,
+        ),  # same
+        (
+            native_pd.DatetimeIndex(["2024-01-01 04:00:00+00:00"]),
+            native_pd.DatetimeIndex(["2024-01-01 05:00:00+00:00"]),
+            False,
+        ),  # different
+        (
+            native_pd.DatetimeIndex(["2024-01-01 04:00:00+00:00"]),
+            native_pd.DatetimeIndex(["2024-01-01 04:00:00+05:00"]),
+            False,
+        ),  # different tz
     ],
 )
 def test_index_equals(lhs, rhs, expected):
diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py
index ce0fc05f80b..1853e1c8bf6 100644
--- a/tests/integ/modin/index/test_index_methods.py
+++ b/tests/integ/modin/index/test_index_methods.py
@@ -318,7 +318,11 @@ def test_df_index_to_frame(native_df, index, name):
 @pytest.mark.parametrize("native_index", NATIVE_INDEX_TEST_DATA)
 def test_index_dtype(native_index):
     snow_index = pd.Index(native_index)
-    assert snow_index.dtype == native_index.dtype
+    if isinstance(native_index, native_pd.DatetimeIndex):
+        # Snowpark pandas does not include timezone info in the dtype; it is always datetime64[ns].
+        assert snow_index.dtype == "datetime64[ns]"
+    else:
+        assert snow_index.dtype == native_index.dtype
 
 
 @sql_count_checker(query_count=0)
@@ -364,3 +368,18 @@ def test_index_parent():
     snow_idx = s.index
     assert_series_equal(snow_idx._parent, s)
     assert_index_equal(snow_idx, native_idx2)
+
+
+@sql_count_checker(query_count=0)
+@pytest.mark.parametrize(
+    "kwargs",
+    [{"dtype": "str"}, {"copy": True}, {"name": "abc"}, {"tupleize_cols": False}],
+)
+def test_non_default_args(kwargs):
+    idx = pd.Index([1, 2, 3, 4], name="name", dtype="int64")
+
+    name = list(kwargs.keys())[0]
+    value = list(kwargs.values())[0]
+    msg = f"Non-default argument '{name}={value}' when constructing Index with query compiler"
+    with pytest.raises(AssertionError, match=msg):
+        pd.Index(data=idx._query_compiler, **kwargs)
diff --git a/tests/integ/modin/series/test_loc.py b/tests/integ/modin/series/test_loc.py
index 073abf9bad7..32c1bf64c4a 100644
--- a/tests/integ/modin/series/test_loc.py
+++ b/tests/integ/modin/series/test_loc.py
@@ -1753,7 +1753,7 @@ def test_series_non_partial_string_indexing_cases(ops, error):
 def test_series_partial_string_indexing_behavior_diff():
     native_series_minute = native_pd.Series(
         [1, 2, 3],
-        pd.DatetimeIndex(
+        native_pd.DatetimeIndex(
             ["2011-12-31 23:59:00", "2012-01-01 00:00:00", "2012-01-01 00:02:00"]
         ),
     )
@@ -1771,7 +1771,7 @@ def test_series_partial_string_indexing_behavior_diff():
         snow_res,
         native_pd.Series(
             [1],
-            pd.DatetimeIndex(["2011-12-31 23:59:00"]),
+            native_pd.DatetimeIndex(["2011-12-31 23:59:00"]),
         ),
         check_dtype=False,
     )
diff --git a/tests/integ/modin/tools/test_to_datetime.py b/tests/integ/modin/tools/test_to_datetime.py
index 15130724465..d08495b31e9 100644
--- a/tests/integ/modin/tools/test_to_datetime.py
+++ b/tests/integ/modin/tools/test_to_datetime.py
@@ -15,7 +15,8 @@
 import pandas._testing as tm
 import pytest
 import pytz
-from modin.pandas import DatetimeIndex, NaT, Series, Timestamp, to_datetime
+from modin.pandas import NaT, Series, Timestamp, to_datetime
+from pandas import DatetimeIndex
 from pandas.core.arrays import DatetimeArray
 
 import snowflake.snowpark.modin.plugin  # noqa: F401
diff --git a/tests/unit/modin/test_class.py b/tests/unit/modin/test_class.py
index 7dfa6d43bae..1d4b3881b7b 100644
--- a/tests/unit/modin/test_class.py
+++ b/tests/unit/modin/test_class.py
@@ -24,7 +24,6 @@ def test_class_equivalence():
     assert pd.CategoricalDtype is native_pd.CategoricalDtype
     assert pd.CategoricalIndex is native_pd.CategoricalIndex
    assert pd.DateOffset is native_pd.DateOffset
-    assert pd.DatetimeIndex is native_pd.DatetimeIndex
     assert pd.DatetimeTZDtype is native_pd.DatetimeTZDtype
     assert pd.ExcelWriter is native_pd.ExcelWriter
     assert pd.Flags is native_pd.Flags
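
For reviewers, a minimal usage sketch of the lazy `DatetimeIndex` behavior this patch enables, distilled from the new tests above. It is illustrative only and assumes a configured Snowpark session, since Snowpark pandas evaluates index contents lazily against Snowflake:

```python
import modin.pandas as pd
import pandas as native_pd

import snowflake.snowpark.modin.plugin  # noqa: F401

# A DataFrame built over datetime data now exposes a lazy Snowpark pandas
# DatetimeIndex instead of falling back to native pandas.
df = pd.DataFrame(
    {"a": [1, 2, 3]},
    index=native_pd.DatetimeIndex(["2024-01-01", "2024-02-01", "2024-03-01"]),
)
assert isinstance(df.index, pd.DatetimeIndex)

# pd.Index.__new__ dispatches to DatetimeIndex when the data is datetime-like...
idx = pd.Index(native_pd.DatetimeIndex(["2021-01-01", "2021-01-02"]))
assert isinstance(idx, pd.DatetimeIndex)

# ...and Index.astype to a datetime dtype returns a DatetimeIndex as well.
converted = pd.Index(["2024-01-01 10:00:00"]).astype("datetime64[ns]")
assert isinstance(converted, pd.DatetimeIndex)
```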