
Commit ff39c92

Ngone51 authored and cloud-fan committed
[SPARK-30252][SQL] Disallow negative scale of Decimal
### What changes were proposed in this pull request?

This PR proposes to disallow negative `scale` of `Decimal` in Spark. It brings two behavior changes:

1) For literals like `1.23E4BD` or `1.23E4` (with `spark.sql.legacy.exponentLiteralAsDecimal.enabled`=true, see [SPARK-29956](https://issues.apache.org/jira/browse/SPARK-29956)), the `(precision, scale)` is now set to (5, 0) rather than (3, -2).
2) A negative `scale` check is added inside each decimal method that allows setting `scale` explicitly. If the check fails, an `AnalysisException` is thrown.

Users can still set `spark.sql.legacy.allowNegativeScaleOfDecimal.enabled` to restore the previous behavior.

### Why are the changes needed?

According to the SQL standard,

> 4.4.2 Characteristics of numbers
>
> An exact numeric type has a precision P and a scale S. P is a positive integer that determines the number of significant digits in a particular radix R, where R is either 2 or 10. S is a non-negative integer.

the scale of a Decimal should always be non-negative. Other mainstream databases, such as Presto and PostgreSQL, also don't allow negative scale.

Presto:
```
presto:default> create table t (i decimal(2, -1));
Query 20191213_081238_00017_i448h failed: line 1:30: mismatched input '-'. Expecting: <integer>, <type>
create table t (i decimal(2, -1))
```

PostgreSQL:
```
postgres=# create table t(i decimal(2, -1));
ERROR:  NUMERIC scale -1 must be between 0 and precision 2
LINE 1: create table t(i decimal(2, -1));
                                     ^
```

In fact, Spark itself already doesn't allow creating a table with a negative-scale decimal type using SQL:

```
scala> spark.sql("create table t(i decimal(2, -1))");
org.apache.spark.sql.catalyst.parser.ParseException:
no viable alternative at input 'create table t(i decimal(2, -'(line 1, pos 28)

== SQL ==
create table t(i decimal(2, -1))
----------------------------^^^

  at org.apache.spark.sql.catalyst.parser.ParseException.withCommand(ParseDriver.scala:263)
  at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:130)
  at org.apache.spark.sql.execution.SparkSqlParser.parse(SparkSqlParser.scala:48)
  at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parsePlan(ParseDriver.scala:76)
  at org.apache.spark.sql.SparkSession.$anonfun$sql$1(SparkSession.scala:605)
  at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
  at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:605)
  ... 35 elided
```

However, it is still possible to create such a table or `DataFrame` using the Spark SQL programming API:

```
scala> val tb = CatalogTable(
  TableIdentifier("test", None),
  CatalogTableType.MANAGED,
  CatalogStorageFormat.empty,
  StructType(StructField("i", DecimalType(2, -1) ) :: Nil))
```

```
scala> spark.sql("SELECT 1.23E4BD")
res2: org.apache.spark.sql.DataFrame = [1.23E+4: decimal(3,-2)]
```

These two different behaviors could confuse users. On the other hand, even if a user creates such a table or `DataFrame` with a negative-scale decimal type, the data can't be written out with formats like `parquet` or `orc`, because those formats have their own checks for negative scale and fail on it:

```
scala> spark.sql("SELECT 1.23E4BD").write.saveAsTable("parquet")
19/12/13 17:37:04 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
java.lang.IllegalArgumentException: Invalid DECIMAL scale: -2
	at org.apache.parquet.Preconditions.checkArgument(Preconditions.java:53)
	at org.apache.parquet.schema.Types$BasePrimitiveBuilder.decimalMetadata(Types.java:495)
	at org.apache.parquet.schema.Types$BasePrimitiveBuilder.build(Types.java:403)
	at org.apache.parquet.schema.Types$BasePrimitiveBuilder.build(Types.java:309)
	at org.apache.parquet.schema.Types$Builder.named(Types.java:290)
	at org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter.convertField(ParquetSchemaConverter.scala:428)
	at org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter.convertField(ParquetSchemaConverter.scala:334)
	at org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter.$anonfun$convert$2(ParquetSchemaConverter.scala:326)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
	at scala.collection.Iterator.foreach(Iterator.scala:941)
	at scala.collection.Iterator.foreach$(Iterator.scala:941)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
	at scala.collection.IterableLike.foreach(IterableLike.scala:74)
	at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
	at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99)
	at scala.collection.TraversableLike.map(TraversableLike.scala:238)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
	at org.apache.spark.sql.types.StructType.map(StructType.scala:99)
	at org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter.convert(ParquetSchemaConverter.scala:326)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport.init(ParquetWriteSupport.scala:97)
	at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:388)
	at org.apache.parquet.hadoop.ParquetOutputFormat.getRecordWriter(ParquetOutputFormat.java:349)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter.<init>(ParquetOutputWriter.scala:37)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$$anon$1.newInstance(ParquetFileFormat.scala:150)
	at org.apache.spark.sql.execution.datasources.SingleDirectoryDataWriter.newOutputWriter(FileFormatDataWriter.scala:124)
	at org.apache.spark.sql.execution.datasources.SingleDirectoryDataWriter.<init>(FileFormatDataWriter.scala:109)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:264)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$write$15(FileFormatWriter.scala:205)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:127)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:441)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:444)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
```

So it would be better to disallow negative scale entirely and make the behaviors above consistent.

### Does this PR introduce any user-facing change?

Yes. With `spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=false` (the default), users can no longer create Decimal values with negative scale.

### How was this patch tested?

Added new tests in `ExpressionParserSuite` and `DecimalSuite`; updated `SQLQueryTestSuite`.

Closes #26881 from Ngone51/nonnegative-scale.

Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
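A minimal spark-shell sketch of the new default behavior described above (the printed type is inferred from the description, not captured from a real session; `spark` is the shell's built-in session):

```scala
// Exponent decimal literals no longer carry a negative scale:
// before this patch 1.23E4BD was decimal(3,-2); with it, the type should be decimal(5,0).
println(spark.sql("SELECT 1.23E4BD").schema.head.dataType)
// expected: DecimalType(5,0)
```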
1 parent af70542 commit ff39c92


24 files changed: +178 -99 lines changed


docs/sql-migration-guide.md

Lines changed: 2 additions & 0 deletions
@@ -257,6 +257,8 @@ license: |
 - Since Spark 3.0, the unary arithmetic operator plus(`+`) only accepts string, numeric and interval type values as inputs. Besides, `+` with a integral string representation will be coerced to double value, e.g. `+'1'` results `1.0`. In Spark version 2.4 and earlier, this operator is ignored. There is no type checking for it, thus, all type values with a `+` prefix are valid, e.g. `+ array(1, 2)` is valid and results `[1, 2]`. Besides, there is no type coercion for it at all, e.g. in Spark 2.4, the result of `+'1'` is string `1`.
 
 - Since Spark 3.0, day-time interval strings are converted to intervals with respect to the `from` and `to` bounds. If an input string does not match to the pattern defined by specified bounds, the `ParseException` exception is thrown. For example, `interval '2 10:20' hour to minute` raises the exception because the expected format is `[+|-]h[h]:[m]m`. In Spark version 2.4, the `from` bound was not taken into account, and the `to` bound was used to truncate the resulted interval. For instance, the day-time interval string from the showed example is converted to `interval 10 hours 20 minutes`. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.fromDayTimeString.enabled` to `true`.
+
+- Since Spark 3.0, negative scale of decimal is not allowed by default, e.g. data type of literal like `1E10BD` is `DecimalType(11, 0)`. In Spark version 2.4 and earlier, it was `DecimalType(2, -9)`. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.allowNegativeScaleOfDecimal.enabled` to `true`.
 
 - Since Spark 3.0, the `date_add` and `date_sub` functions only accepts int, smallint, tinyint as the 2nd argument, fractional and string types are not valid anymore, e.g. `date_add(cast('1964-05-23' as date), '12.34')` will cause `AnalysisException`. In Spark version 2.4 and earlier, if the 2nd argument is fractional or string value, it will be coerced to int value, and the result will be a date value of `1964-06-04`.
 
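The migration note can be exercised from spark-shell; a hedged sketch (the printed types come from the note itself, and the legacy flag is assumed to be settable at session level, as the tests in this commit do):

```scala
println(spark.sql("SELECT 1E10BD").schema.head.dataType)   // Spark 3.0 default: DecimalType(11,0)

spark.conf.set("spark.sql.legacy.allowNegativeScaleOfDecimal.enabled", "true")
println(spark.sql("SELECT 1E10BD").schema.head.dataType)   // legacy mode: DecimalType(2,-9)
spark.conf.set("spark.sql.legacy.allowNegativeScaleOfDecimal.enabled", "false")
```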

python/pyspark/sql/tests/test_types.py

Lines changed: 8 additions & 4 deletions
@@ -204,10 +204,14 @@ def test_create_dataframe_from_dict_respects_schema(self):
         self.assertEqual(df.columns, ['b'])
 
     def test_negative_decimal(self):
-        df = self.spark.createDataFrame([(1, ), (11, )], ["value"])
-        ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
-        actual = list(map(lambda r: int(r.value), ret))
-        self.assertEqual(actual, [0, 10])
+        try:
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=true")
+            df = self.spark.createDataFrame([(1, ), (11, )], ["value"])
+            ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
+            actual = list(map(lambda r: int(r.value), ret))
+            self.assertEqual(actual, [0, 10])
+        finally:
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=false")
 
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]

python/pyspark/sql/types.py

Lines changed: 0 additions & 2 deletions
@@ -867,8 +867,6 @@ def _parse_datatype_json_string(json_string):
     >>> complex_maptype = MapType(complex_structtype,
     ...                           complex_arraytype, False)
     >>> check_datatype(complex_maptype)
-    >>> # Decimal with negative scale.
-    >>> check_datatype(DecimalType(1,-1))
     """
     return _parse_datatype_json_value(json.loads(json_string))
 

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 5 additions & 2 deletions
@@ -63,9 +63,12 @@ object Literal {
     case s: String => Literal(UTF8String.fromString(s), StringType)
     case c: Char => Literal(UTF8String.fromString(c.toString), StringType)
     case b: Boolean => Literal(b, BooleanType)
-    case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d))
+    case d: BigDecimal =>
+      val decimal = Decimal(d)
+      Literal(decimal, DecimalType.fromDecimal(decimal))
     case d: JavaBigDecimal =>
-      Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale()))
+      val decimal = Decimal(d)
+      Literal(decimal, DecimalType.fromDecimal(decimal))
     case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale))
     case i: Instant => Literal(instantToMicros(i), TimestampType)
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
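Both `BigDecimal` branches now go through `Decimal(d)` first, so the literal's data type is derived from the already-normalized decimal. A rough sketch of the effect (an assumption based on the diff, with the default non-legacy settings in place):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal

val v = BigDecimal("1.23E+4")   // underlying java.math.BigDecimal: precision 3, scale -2
println(Literal(v).dataType)
// Decimal(v) normalizes the negative scale away, so this should print DecimalType(5,0)
// rather than the old DecimalType(3,-2).
```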

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
@@ -1953,6 +1953,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED =
+    buildConf("spark.sql.legacy.allowNegativeScaleOfDecimal.enabled")
+      .internal()
+      .doc("When set to true, negative scale of Decimal type is allowed. For example, " +
+        "the type of number 1E10BD under legacy mode is DecimalType(2, -9), but is " +
+        "Decimal(11, 0) in non legacy mode.")
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED =
     buildConf("spark.sql.legacy.createHiveTableByDefault.enabled")
      .internal()

@@ -2681,6 +2690,9 @@ class SQLConf extends Serializable with Logging {
   def exponentLiteralAsDecimalEnabled: Boolean =
     getConf(SQLConf.LEGACY_EXPONENT_LITERAL_AS_DECIMAL_ENABLED)
 
+  def allowNegativeScaleOfDecimalEnabled: Boolean =
+    getConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED)
+
   def createHiveTableByDefaultEnabled: Boolean =
     getConf(SQLConf.LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT_ENABLED)
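The flag is a runtime legacy conf, so a session can toggle it the same way the Python test above does. A hedged Scala equivalent, assuming a spark-shell session where `spark` is in scope:

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType
import spark.implicits._

spark.sql("SET spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=true")
try {
  // With the legacy flag on, casting to a negative-scale decimal still works:
  // 1 and 11 round to 0 and 10, as asserted in the Python test.
  Seq(1, 11).toDF("value")
    .select(col("value").cast(DecimalType(1, -1)))
    .show()
} finally {
  spark.sql("SET spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=false")
}
```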

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 12 additions & 1 deletion
@@ -23,6 +23,7 @@ import java.math.{BigInteger, MathContext, RoundingMode}
 import scala.util.Try
 
 import org.apache.spark.annotation.Unstable
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * A mutable implementation of BigDecimal that can hold a Long if values are small enough.

@@ -89,6 +90,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    * and return it, or return null if it cannot be set due to overflow.
    */
   def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
+    DecimalType.checkNegativeScale(scale)
     if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
       // We can't represent this compactly as a long without risking overflow
       if (precision < 19) {

@@ -113,6 +115,7 @@
    * Set this Decimal to the given BigDecimal value, with a given precision and scale.
    */
   def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+    DecimalType.checkNegativeScale(scale)
     this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
     if (decimalVal.precision > precision) {
       throw new ArithmeticException(

@@ -136,10 +139,16 @@
       // result. For example, the precision of 0.01 equals to 1 based on the definition, but
       // the scale is 2. The expected precision should be 2.
       this._precision = decimal.scale
+      this._scale = decimal.scale
+    } else if (decimal.scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) {
+      this._precision = decimal.precision - decimal.scale
+      this._scale = 0
+      // set scale to 0 to correct unscaled value
+      this.decimalVal = decimal.setScale(0)
     } else {
       this._precision = decimal.precision
+      this._scale = decimal.scale
     }
-    this._scale = decimal.scale
     this
   }
 

@@ -375,6 +384,7 @@
     if (precision == this.precision && scale == this.scale) {
       return true
     }
+    DecimalType.checkNegativeScale(scale)
     // First, update our longVal if we can, or transfer over to using a BigDecimal
     if (decimalVal.eq(null)) {
       if (scale < _scale) {

@@ -583,6 +593,7 @@
    * Creates a decimal from unscaled, precision and scale without checking the bounds.
    */
   def createUnsafe(unscaled: Long, precision: Int, scale: Int): Decimal = {
+    DecimalType.checkNegativeScale(scale)
     val dec = new Decimal()
     dec.longVal = unscaled
     dec._precision = precision
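In short: `set(BigDecimal)` silently folds a negative scale into the precision, while the entry points that take an explicit `scale` now reject it. A hedged sketch of both paths (assumes the default, non-legacy configuration; not part of the patch):

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.Decimal

val d = Decimal(BigDecimal("1.23E+4"))
println((d.precision, d.scale))   // the negative raw scale is folded away: expected (5, 0)

try {
  Decimal(123L, 3, -2)            // explicit negative scale goes through checkNegativeScale
} catch {
  case e: AnalysisException => println(e.getMessage)
  // expected: "Negative scale is not allowed: -2. You can use
  // spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=true ..."
}
```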

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 17 additions & 7 deletions
@@ -24,6 +24,7 @@ import scala.reflect.runtime.universe.typeTag
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * The data type representing `java.math.BigDecimal` values.

@@ -41,6 +42,8 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 @Stable
 case class DecimalType(precision: Int, scale: Int) extends FractionalType {
 
+  DecimalType.checkNegativeScale(scale)
+
   if (scale > precision) {
     throw new AnalysisException(
       s"Decimal scale ($scale) cannot be greater than precision ($precision).")

@@ -141,20 +144,26 @@ object DecimalType extends AbstractDataType {
   }
 
   private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match {
-    case v: Short => fromBigDecimal(BigDecimal(v))
-    case v: Int => fromBigDecimal(BigDecimal(v))
-    case v: Long => fromBigDecimal(BigDecimal(v))
+    case v: Short => fromDecimal(Decimal(BigDecimal(v)))
+    case v: Int => fromDecimal(Decimal(BigDecimal(v)))
+    case v: Long => fromDecimal(Decimal(BigDecimal(v)))
     case _ => forType(literal.dataType)
   }
 
-  private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = {
-    DecimalType(Math.max(d.precision, d.scale), d.scale)
-  }
+  private[sql] def fromDecimal(d: Decimal): DecimalType = DecimalType(d.precision, d.scale)
 
   private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
     DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
   }
 
+  private[sql] def checkNegativeScale(scale: Int): Unit = {
+    if (scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) {
+      throw new AnalysisException(s"Negative scale is not allowed: $scale. " +
+        s"You can use spark.sql.legacy.allowNegativeScaleOfDecimal.enabled=true " +
+        s"to enable legacy mode to allow it.")
+    }
+  }
+
   /**
    * Scale adjustment implementation is based on Hive's one, which is itself inspired to
    * SQLServer's one. In particular, when a result precision is greater than

@@ -164,7 +173,8 @@ object DecimalType extends AbstractDataType {
    * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true.
    */
   private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
-    // Assumption:
+    // Assumptions:
+    checkNegativeScale(scale)
     assert(precision >= scale)
 
     if (precision <= MAX_PRECISION) {
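With the constructor-level check, the programmatic schema from the PR description now fails fast instead of producing an unwritable negative-scale column. A hedged sketch under the default settings:

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}

try {
  // Mirrors the StructType from the PR description above.
  StructType(StructField("i", DecimalType(2, -1)) :: Nil)
} catch {
  case e: AnalysisException => println(e.getMessage)
  // expected: "Negative scale is not allowed: -1. ..."
}
```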

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala

Lines changed: 9 additions & 6 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 

@@ -273,12 +274,14 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
   }
 
   test("SPARK-24468: operations on decimals with negative scale") {
-    val a = AttributeReference("a", DecimalType(3, -10))()
-    val b = AttributeReference("b", DecimalType(1, -1))()
-    val c = AttributeReference("c", DecimalType(35, 1))()
-    checkType(Multiply(a, b), DecimalType(5, -11))
-    checkType(Multiply(a, c), DecimalType(38, -9))
-    checkType(Multiply(b, c), DecimalType(37, 0))
+    withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+      val a = AttributeReference("a", DecimalType(3, -10))()
+      val b = AttributeReference("b", DecimalType(1, -1))()
+      val c = AttributeReference("c", DecimalType(35, 1))()
+      checkType(Multiply(a, b), DecimalType(5, -11))
+      checkType(Multiply(a, c), DecimalType(38, -9))
+      checkType(Multiply(b, c), DecimalType(37, 0))
+    }
   }
 
   /** strength reduction for integer/decimal comparisons */

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala

Lines changed: 7 additions & 4 deletions
@@ -22,6 +22,7 @@ import java.util.Locale
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {

@@ -147,13 +148,15 @@ class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
     val options = new CSVOptions(Map.empty[String, String], false, "GMT")
     val inferSchema = new CSVInferSchema(options)
 
-    // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
-    assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") ==
-      DecimalType(4, -9))
+    withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+      // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
+      assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") ==
+        DecimalType(4, -9))
+    }
 
     // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20.
     val value = "12345678901234567890.01234567890123456789"
-    assert(inferSchema.inferField(DecimalType(3, -10), value) == DoubleType)
+    assert(inferSchema.inferField(DecimalType(3, 0), value) == DoubleType)
 
     // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType
     assert(inferSchema.inferField(NullType, s"${Long.MaxValue}1") == DecimalType(20, 0))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 17 additions & 14 deletions
@@ -1048,15 +1048,9 @@ class CastSuite extends CastSuiteBase {
     assert(cast(Decimal(9.95), DecimalType(2, 1)).nullable)
     assert(cast(Decimal(9.95), DecimalType(3, 1)).nullable === false)
 
-    assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable)
-    assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false)
-    assert(cast(Decimal("995"), DecimalType(2, -1)).nullable)
-    assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false)
-
     assert(cast(true, DecimalType.SYSTEM_DEFAULT).nullable === false)
     assert(cast(true, DecimalType(1, 1)).nullable)
 
-
     checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
     checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
     checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))

@@ -1095,17 +1089,9 @@
 
     checkEvaluation(cast(Decimal("1003"), DecimalType.SYSTEM_DEFAULT), Decimal(1003))
     checkEvaluation(cast(Decimal("1003"), DecimalType(4, 0)), Decimal(1003))
-    checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000))
-    checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000))
-    checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null)
-    checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null)
     checkEvaluation(cast(Decimal("1003"), DecimalType(3, 0)), null)
 
     checkEvaluation(cast(Decimal("995"), DecimalType(3, 0)), Decimal(995))
-    checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000))
-    checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000))
-    checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null)
-    checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null)
 
     checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null)
     checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null)

@@ -1119,6 +1105,23 @@
 
     checkEvaluation(cast(true, DecimalType(2, 1)), Decimal(1))
     checkEvaluation(cast(true, DecimalType(1, 1)), null)
+
+    withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+      assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable)
+      assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false)
+      assert(cast(Decimal("995"), DecimalType(2, -1)).nullable)
+      assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false)
+
+      checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000))
+      checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000))
+      checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null)
+      checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null)
+
+      checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000))
+      checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000))
+      checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null)
+      checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null)
+    }
   }
 
   test("SPARK-28470: Cast should honor nullOnOverflow property") {
