Skip to content

Commit 90569c9

Browse files
authored
Added correct casting and mod operation (#172)
1 parent c9ce6de commit 90569c9

File tree

5 files changed

+33
-19
lines changed

5 files changed

+33
-19
lines changed

dask_sql/input_utils/hive.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def wrapped_read_function(location, column_information, **kwargs):
108108
df = df.rename(columns=dict(zip(df.columns, column_information.keys())))
109109

110110
for col, expected_type in column_information.items():
111-
df = cast_column_type(df, col, expected_type)
111+
df[col] = cast_column_type(df[col], expected_type)
112112

113113
return df
114114

@@ -146,8 +146,9 @@ def wrapped_read_function(location, column_information, **kwargs):
146146

147147
partition_id = 0
148148
for partition_key, partition_type in partition_information.items():
149-
table[partition_key] = partition_values[partition_id]
150-
table = cast_column_type(table, partition_key, partition_type)
149+
table[partition_key] = cast_column_type(
150+
partition_values[partition_id], partition_type
151+
)
151152

152153
partition_id += 1
153154

dask_sql/mappings.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -255,24 +255,20 @@ def similar_type(lhs: type, rhs: type) -> bool:
255255
return False
256256

257257

258-
def cast_column_type(
259-
df: dd.DataFrame, column_name: str, expected_type: type
260-
) -> dd.DataFrame:
258+
def cast_column_type(column: dd.Series, expected_type: type) -> dd.Series:
261259
"""
262260
Cast the type of the given column to the expected type,
263261
if they are far "enough" away.
264262
This means, a float will never be converted into a double
265263
or a tinyint into another int - but a string to an integer etc.
266264
"""
267-
current_type = df[column_name].dtype
265+
current_type = column.dtype
268266

269-
logger.debug(
270-
f"Column {column_name} has type {current_type}, expecting {expected_type}..."
271-
)
267+
logger.debug(f"Column has type {current_type}, expecting {expected_type}...")
272268

273269
if similar_type(current_type, expected_type):
274270
logger.debug("...not converting.")
275-
return df
271+
return column
276272

277273
current_float = pd.api.types.is_float_dtype(current_type)
278274
expected_integer = pd.api.types.is_integer_dtype(expected_type)
@@ -282,9 +278,9 @@ def cast_column_type(
282278
# because NA is a different type. It works with np.NaN though.
283279
# For our use case, that does not matter, as the conversion to integer later
284280
# will convert both NA and np.NaN to NA.
285-
df[column_name] = da.trunc(df[column_name].fillna(value=np.NaN))
281+
column = da.trunc(column.fillna(value=np.NaN))
286282

287-
logger.debug(f"Need to cast {column_name} from {current_type} to {expected_type}")
288-
df[column_name] = df[column_name].astype(expected_type)
283+
logger.debug(f"Need to cast from {current_type} to {expected_type}")
284+
column = column.astype(expected_type)
289285

290-
return df
286+
return column

dask_sql/physical/rel/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,6 @@ def fix_dtype_to_row_type(
106106
expected_type = sql_to_python_type(field_type)
107107
field_name = cc.get_backend_by_frontend_index(index)
108108

109-
df = cast_column_type(df, field_name, expected_type)
109+
df[field_name] = cast_column_type(df[field_name], expected_type)
110110

111111
return DataContainer(df, dc.column_container)

dask_sql/physical/rex/core/call.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dask.utils import random_state_data
1515

1616
from dask_sql.datacontainer import DataContainer
17-
from dask_sql.mappings import sql_to_python_type
17+
from dask_sql.mappings import cast_column_type, sql_to_python_type
1818
from dask_sql.physical.rex import RexConverter
1919
from dask_sql.physical.rex.base import BaseRexPlugin
2020
from dask_sql.utils import (
@@ -179,6 +179,21 @@ def case(self, *operands) -> SeriesOrScalar:
179179
return then if where else other
180180

181181

182+
class CastOperation(Operation):
183+
"""The cast operator"""
184+
185+
needs_rex = True
186+
187+
def __init__(self):
188+
super().__init__(self.cast)
189+
190+
def cast(self, operand, rex=None) -> SeriesOrScalar:
191+
output_type = str(rex.getType())
192+
output_type = sql_to_python_type(output_type.upper())
193+
194+
return cast_column_type(operand, output_type)
195+
196+
182197
class IsFalseOperation(Operation):
183198
"""The is false operator"""
184199

@@ -650,7 +665,7 @@ class RexCallPlugin(BaseRexPlugin):
650665
"is distinct from": NotOperation().of(IsNotDistinctOperation()),
651666
"is not distinct from": IsNotDistinctOperation(),
652667
# special operations
653-
"cast": lambda x: x,
668+
"cast": CastOperation(),
654669
"case": CaseOperation(),
655670
"like": LikeOperation(),
656671
"similar to": SimilarOperation(),
@@ -680,7 +695,7 @@ class RexCallPlugin(BaseRexPlugin):
680695
"floor": CeilFloorOperation("floor"),
681696
"log10": Operation(da.log10),
682697
"ln": Operation(da.log),
683-
# "mod": Operation(da.mod), # needs cast
698+
"mod": Operation(da.mod),
684699
"power": Operation(da.power),
685700
"radians": Operation(da.radians),
686701
"round": TensorScalarOperation(lambda x, *ops: x.round(*ops), np.round),

tests/integration/test_rex.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def test_math_operations(c, df):
311311
, FLOOR(b) AS "floor"
312312
, LOG10(b) AS "log10"
313313
, LN(b) AS "ln"
314+
, MOD(b, 4) AS "mod"
314315
, POWER(b, 2) AS "power"
315316
, POWER(b, a) AS "power2"
316317
, RADIANS(b) AS "radians"
@@ -339,6 +340,7 @@ def test_math_operations(c, df):
339340
expected_df["floor"] = np.floor(df.b)
340341
expected_df["log10"] = np.log10(df.b)
341342
expected_df["ln"] = np.log(df.b)
343+
expected_df["mod"] = np.mod(df.b, 4)
342344
expected_df["power"] = np.power(df.b, 2)
343345
expected_df["power2"] = np.power(df.b, df.a)
344346
expected_df["radians"] = df.b / 180 * np.pi

0 commit comments

Comments (0)