Skip to content

Commit cf2aadd

Browse files
devin-petersohn authored and HyukjinKwon committed
[SPARK-55980][PS] Always apply _cast_back_float in numeric arithmetic
### What changes were proposed in this pull request?

Remove the ANSI mode guard from `_cast_back_float` calls in `num_ops.py`.

### Why are the changes needed?

Simplify the code without changing behavior.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

CI

### Was this patch authored or co-authored using generative AI tooling?

Co-authored-by: Claude Opus 4

Closes #54779 from devin-petersohn/devin/always-cast-back-float.

Authored-by: Devin Petersohn <devin.petersohn@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 0e8d39e commit cf2aadd

File tree

1 file changed

+9
-34
lines changed

1 file changed

+9
-34
lines changed

python/pyspark/pandas/data_type_ops/num_ops.py

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -125,29 +125,21 @@ def add(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
125125
_sanitize_list_like(right)
126126
if not is_valid_operand_for_numeric_arithmetic(right):
127127
raise TypeError("Addition can not be applied to given types.")
128-
spark_session = left._internal.spark_frame.sparkSession
129128
new_right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
130129

131130
def wrapped_add(lc: PySparkColumn, rc: Any) -> PySparkColumn:
132-
expr = PySparkColumn.__add__(lc, rc)
133-
if is_ansi_mode_enabled(spark_session):
134-
expr = _cast_back_float(expr, left.dtype, right)
135-
return expr
131+
return _cast_back_float(PySparkColumn.__add__(lc, rc), left.dtype, right)
136132

137133
return column_op(wrapped_add)(left, new_right)
138134

139135
def sub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
140136
_sanitize_list_like(right)
141137
if not is_valid_operand_for_numeric_arithmetic(right):
142138
raise TypeError("Subtraction can not be applied to given types.")
143-
spark_session = left._internal.spark_frame.sparkSession
144139
new_right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
145140

146141
def wrapped_sub(lc: PySparkColumn, rc: Any) -> PySparkColumn:
147-
expr = PySparkColumn.__sub__(lc, rc)
148-
if is_ansi_mode_enabled(spark_session):
149-
expr = _cast_back_float(expr, left.dtype, right)
150-
return expr
142+
return _cast_back_float(PySparkColumn.__sub__(lc, rc), left.dtype, right)
151143

152144
return column_op(wrapped_sub)(left, new_right)
153145

@@ -162,10 +154,9 @@ def mod(left_op: PySparkColumn, right_op: Any) -> PySparkColumn:
162154
expr = F.when(F.lit(right_op == 0), F.lit(None)).otherwise(
163155
((left_op % right_op) + right_op) % right_op
164156
)
165-
expr = _cast_back_float(expr, left.dtype, right)
166157
else:
167158
expr = ((left_op % right_op) + right_op) % right_op
168-
return expr
159+
return _cast_back_float(expr, left.dtype, right)
169160

170161
new_right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
171162

@@ -190,44 +181,32 @@ def radd(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
190181
_sanitize_list_like(right)
191182
if not isinstance(right, numbers.Number):
192183
raise TypeError("Addition can not be applied to given types.")
193-
spark_session = left._internal.spark_frame.sparkSession
194184
new_right = transform_boolean_operand_to_numeric(right)
195185

196186
def wrapped_radd(lc: PySparkColumn, rc: Any) -> PySparkColumn:
197-
expr = PySparkColumn.__radd__(lc, rc)
198-
if is_ansi_mode_enabled(spark_session):
199-
expr = _cast_back_float(expr, left.dtype, right)
200-
return expr
187+
return _cast_back_float(PySparkColumn.__radd__(lc, rc), left.dtype, right)
201188

202189
return column_op(wrapped_radd)(left, new_right)
203190

204191
def rsub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
205192
_sanitize_list_like(right)
206193
if not isinstance(right, numbers.Number):
207194
raise TypeError("Subtraction can not be applied to given types.")
208-
spark_session = left._internal.spark_frame.sparkSession
209195
new_right = transform_boolean_operand_to_numeric(right)
210196

211197
def wrapped_rsub(lc: PySparkColumn, rc: Any) -> PySparkColumn:
212-
expr = PySparkColumn.__rsub__(lc, rc)
213-
if is_ansi_mode_enabled(spark_session):
214-
expr = _cast_back_float(expr, left.dtype, right)
215-
return expr
198+
return _cast_back_float(PySparkColumn.__rsub__(lc, rc), left.dtype, right)
216199

217200
return column_op(wrapped_rsub)(left, new_right)
218201

219202
def rmul(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
220203
_sanitize_list_like(right)
221204
if not isinstance(right, numbers.Number):
222205
raise TypeError("Multiplication can not be applied to given types.")
223-
spark_session = left._internal.spark_frame.sparkSession
224206
new_right = transform_boolean_operand_to_numeric(right)
225207

226208
def wrapped_rmul(lc: PySparkColumn, rc: Any) -> PySparkColumn:
227-
expr = PySparkColumn.__mul__(lc, rc)
228-
if is_ansi_mode_enabled(spark_session):
229-
expr = _cast_back_float(expr, left.dtype, right)
230-
return expr
209+
return _cast_back_float(PySparkColumn.__mul__(lc, rc), left.dtype, right)
231210

232211
return column_op(wrapped_rmul)(left, new_right)
233212

@@ -256,10 +235,9 @@ def safe_rmod(left_op: PySparkColumn, right_op: Any) -> PySparkColumn:
256235
result = F.when(
257236
left_op != 0, ((F.lit(right_op) % left_op) + left_op) % left_op
258237
).otherwise(F.lit(None))
259-
result = _cast_back_float(result, left.dtype, right)
260-
return result
261238
else:
262-
return ((right_op % left_op) + left_op) % left_op
239+
result = ((right_op % left_op) + left_op) % left_op
240+
return _cast_back_float(result, left.dtype, right)
263241

264242
return column_op(safe_rmod)(left, new_right)
265243

@@ -472,10 +450,7 @@ def mul(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
472450
new_right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
473451

474452
def wrapped_mul(lc: PySparkColumn, rc: Any) -> PySparkColumn:
475-
expr = PySparkColumn.__mul__(lc, rc)
476-
if is_ansi:
477-
expr = _cast_back_float(expr, left.dtype, right)
478-
return expr
453+
return _cast_back_float(PySparkColumn.__mul__(lc, rc), left.dtype, right)
479454

480455
return column_op(wrapped_mul)(left, new_right)
481456

0 commit comments

Comments (0)