-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Array and Map literals for the java scala codebase #50
base: main
Are you sure you want to change the base?
Changes from 1 commit
e755e2e
04d5618
bd38bac
c83ab5f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -414,4 +414,21 @@ object JavaUtils { | |
} | ||
} | ||
|
||
def toScala(x: Any): Any = { | ||
import collection.JavaConverters._ | ||
x match { | ||
case y: java.util.Map[_, _] => | ||
mapAsScalaMap(y).map { | ||
case (d, v) => toScala(d) -> toScala(v) | ||
}.toMap | ||
case y: java.lang.Iterable[_] => | ||
iterableAsScalaIterable(y).toList.map { | ||
item: Any => toScala(item) | ||
} | ||
case y: java.util.Iterator[_] => | ||
toScala(y) | ||
case _ => | ||
x | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Utility function that recursively convert a java map/array to a scala map/array |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,22 @@ | ||
package com.snowflake.snowpark.internal.analyzer | ||
|
||
import com.snowflake.snowpark.internal.Utils | ||
import com.snowflake.snowpark.types._ | ||
import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat | ||
|
||
import java.math.{BigDecimal => JBigDecimal} | ||
import java.sql.{Date, Timestamp} | ||
import java.util.TimeZone | ||
import java.math.{BigDecimal => JBigDecimal} | ||
|
||
import com.snowflake.snowpark.types._ | ||
import com.snowflake.snowpark.types.convertToSFType | ||
import javax.xml.bind.DatatypeConverter | ||
import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat | ||
|
||
object DataTypeMapper { | ||
// milliseconds per day | ||
private val MILLIS_PER_DAY = 24 * 3600 * 1000L | ||
// microseconds per millisecond | ||
private val MICROS_PER_MILLIS = 1000L | ||
|
||
private[analyzer] def stringToSql(str: String): String = | ||
// Escapes all backslashes, single quotes and new line. | ||
// Escapes all backslashes, single quotes and new line. | ||
"'" + str | ||
.replaceAll("\\\\", "\\\\\\\\") | ||
.replaceAll("'", "''") | ||
|
@@ -25,63 +25,77 @@ object DataTypeMapper { | |
/* | ||
* Convert a value with DataType to a snowflake compatible sql | ||
*/ | ||
private[analyzer] def toSql(value: Any, dataType: Option[DataType]): String = { | ||
dataType match { | ||
case None => "NULL" | ||
case Some(dt) => | ||
(value, dt) match { | ||
case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null => | ||
"NULL" | ||
case (_, IntegerType) if value == null => "NULL :: int" | ||
case (_, ShortType) if value == null => "NULL :: smallint" | ||
case (_, ByteType) if value == null => "NULL :: tinyint" | ||
case (_, LongType) if value == null => "NULL :: bigint" | ||
case (_, FloatType) if value == null => "NULL :: float" | ||
case (_, StringType) if value == null => "NULL :: string" | ||
case (_, DoubleType) if value == null => "NULL :: double" | ||
case (_, BooleanType) if value == null => "NULL :: boolean" | ||
case (_, BinaryType) if value == null => "NULL :: binary" | ||
case _ if value == null => "NULL" | ||
case (v: String, StringType) => stringToSql(v) | ||
case (v: Byte, ByteType) => v + s" :: tinyint" | ||
case (v: Short, ShortType) => v + s" :: smallint" | ||
case (v: Any, IntegerType) => v + s" :: int" | ||
case (v: Long, LongType) => v + s" :: bigint" | ||
case (v: Boolean, BooleanType) => s"$v :: boolean" | ||
// Float type doesn't have a suffix | ||
case (v: Float, FloatType) => | ||
val castedValue = v match { | ||
case _ if v.isNaN => "'NaN'" | ||
case Float.PositiveInfinity => "'Infinity'" | ||
case Float.NegativeInfinity => "'-Infinity'" | ||
case _ => s"'$v'" | ||
} | ||
s"$castedValue :: FLOAT" | ||
case (v: Double, DoubleType) => | ||
v match { | ||
case _ if v.isNaN => "'NaN'" | ||
case Double.PositiveInfinity => "'Infinity'" | ||
case Double.NegativeInfinity => "'-Infinity'" | ||
case _ => v + "::DOUBLE" | ||
} | ||
case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" | ||
case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" | ||
case (v: Int, DateType) => | ||
s"DATE '${SnowflakeDateTimeFormat | ||
.fromSqlFormat(Utils.DateInputFormat) | ||
.format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'" | ||
case (v: Long, TimestampType) => | ||
s"TIMESTAMP '${SnowflakeDateTimeFormat | ||
.fromSqlFormat(Utils.TimestampInputFormat) | ||
.format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'" | ||
case (v: Array[Byte], BinaryType) => | ||
s"'${DatatypeConverter.printHexBinary(v)}' :: binary" | ||
case _ => | ||
throw new UnsupportedOperationException( | ||
s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType") | ||
private[analyzer] def toSql(literal: TLiteral): String = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now takes an instance of |
||
literal match { | ||
case Literal(value, dataType) => (value, dataType) match { | ||
case (_, None) => "NULL" | ||
case (value, Some(dt)) => | ||
(value, dt) match { | ||
case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null => | ||
"NULL" | ||
case (_, IntegerType) if value == null => "NULL :: int" | ||
case (_, ShortType) if value == null => "NULL :: smallint" | ||
case (_, ByteType) if value == null => "NULL :: tinyint" | ||
case (_, LongType) if value == null => "NULL :: bigint" | ||
case (_, FloatType) if value == null => "NULL :: float" | ||
case (_, StringType) if value == null => "NULL :: string" | ||
case (_, DoubleType) if value == null => "NULL :: double" | ||
case (_, BooleanType) if value == null => "NULL :: boolean" | ||
case (_, BinaryType) if value == null => "NULL :: binary" | ||
case _ if value == null => "NULL" | ||
case (v: String, StringType) => stringToSql(v) | ||
case (v: Byte, ByteType) => v + s" :: tinyint" | ||
case (v: Short, ShortType) => v + s" :: smallint" | ||
case (v: Any, IntegerType) => v + s" :: int" | ||
case (v: Long, LongType) => v + s" :: bigint" | ||
case (v: Boolean, BooleanType) => s"$v :: boolean" | ||
// Float type doesn't have a suffix | ||
case (v: Float, FloatType) => | ||
val castedValue = v match { | ||
case _ if v.isNaN => "'NaN'" | ||
case Float.PositiveInfinity => "'Infinity'" | ||
case Float.NegativeInfinity => "'-Infinity'" | ||
case _ => s"'$v'" | ||
} | ||
s"$castedValue :: FLOAT" | ||
case (v: Double, DoubleType) => | ||
v match { | ||
case _ if v.isNaN => "'NaN'" | ||
case Double.PositiveInfinity => "'Infinity'" | ||
case Double.NegativeInfinity => "'-Infinity'" | ||
case _ => v + "::DOUBLE" | ||
} | ||
case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" | ||
case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" | ||
case (v: Int, DateType) => | ||
s"DATE '${ | ||
SnowflakeDateTimeFormat | ||
.fromSqlFormat(Utils.DateInputFormat) | ||
.format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT")) | ||
}'" | ||
case (v: Long, TimestampType) => | ||
s"TIMESTAMP '${ | ||
SnowflakeDateTimeFormat | ||
.fromSqlFormat(Utils.TimestampInputFormat) | ||
.format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3) | ||
}'" | ||
case _ => | ||
throw new UnsupportedOperationException( | ||
s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType") | ||
} | ||
} | ||
case arrayLiteral: ArrayLiteral => | ||
if (arrayLiteral.dataTypeOption == Some(BinaryType)) { | ||
val bytes = arrayLiteral.value.asInstanceOf[Seq[Byte]].toArray | ||
s"'${DatatypeConverter.printHexBinary(bytes)}' :: binary" | ||
} else { | ||
"ARRAY_CONSTRUCT" + arrayLiteral.elementsLiterals.map(toSql).mkString("(", ", ", ")") | ||
} | ||
case mapLiteral: MapLiteral => | ||
"OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals.flatMap { case (keyLiteral, valueLiteral) => | ||
Seq(toSql(keyLiteral), toSql(valueLiteral)) | ||
}.mkString("(", ", ", ")") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I use the |
||
} | ||
|
||
} | ||
|
||
private[analyzer] def schemaExpression(dataType: DataType, isNullable: Boolean): String = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,11 @@ package com.snowflake.snowpark.internal.analyzer | |
|
||
import com.snowflake.snowpark.internal.ErrorMessage | ||
import com.snowflake.snowpark.types._ | ||
|
||
import java.math.{BigDecimal => JavaBigDecimal} | ||
import java.sql.{Date, Timestamp} | ||
import java.time.{Instant, LocalDate} | ||
|
||
import scala.math.BigDecimal | ||
|
||
private[snowpark] object Literal { | ||
// Snowflake max precision for decimal is 38 | ||
private lazy val bigDecimalRoundContext = new java.math.MathContext(DecimalType.MAX_PRECISION) | ||
|
@@ -16,7 +15,7 @@ private[snowpark] object Literal { | |
decimal.round(bigDecimalRoundContext) | ||
} | ||
|
||
def apply(v: Any): Literal = v match { | ||
def apply(v: Any): TLiteral = v match { | ||
case i: Int => Literal(i, Option(IntegerType)) | ||
case l: Long => Literal(l, Option(LongType)) | ||
case d: Double => Literal(d, Option(DoubleType)) | ||
|
@@ -36,7 +35,8 @@ private[snowpark] object Literal { | |
case t: Timestamp => Literal(DateTimeUtils.javaTimestampToMicros(t), Option(TimestampType)) | ||
case ld: LocalDate => Literal(DateTimeUtils.localDateToDays(ld), Option(DateType)) | ||
case d: Date => Literal(DateTimeUtils.javaDateToDays(d), Option(DateType)) | ||
case a: Array[Byte] => Literal(a, Option(BinaryType)) | ||
case s: Seq[Any] => ArrayLiteral(s) | ||
case m: Map[Any, Any] => MapLiteral(m) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We build the Array and Map Literal here, the data types are infered recursively in the classes themselves. |
||
case null => Literal(null, None) | ||
case v: Literal => v | ||
case _ => | ||
|
@@ -45,10 +45,48 @@ private[snowpark] object Literal { | |
|
||
} | ||
|
||
private[snowpark] case class Literal private (value: Any, dataTypeOption: Option[DataType]) | ||
extends Expression { | ||
private[snowpark] trait TLiteral extends Expression { | ||
def value: Any | ||
def dataTypeOption: Option[DataType] | ||
|
||
override def children: Seq[Expression] = Seq.empty | ||
|
||
override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression = | ||
this | ||
} | ||
|
||
private[snowpark] case class Literal (value: Any, dataTypeOption: Option[DataType]) extends TLiteral | ||
|
||
private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral { | ||
val elementsLiterals: Seq[TLiteral] = value.map(Literal(_)) | ||
val dataTypeOption = inferArrayType | ||
|
||
private[analyzer] def inferArrayType(): Option[DataType] = { | ||
elementsLiterals.flatMap(_.dataTypeOption).distinct match { | ||
case Seq() => None | ||
case Seq(ByteType) => Some(BinaryType) | ||
case Seq(dt) => Some(ArrayType(dt)) | ||
case Seq(_, _*) => Some(ArrayType(VariantType)) | ||
} | ||
} | ||
} | ||
|
||
private[snowpark] case class MapLiteral(value: Map[Any, Any]) extends TLiteral { | ||
val entriesLiterals = value.map { case (k, v) => Literal(k) -> Literal(v) } | ||
val dataTypeOption = inferMapType | ||
|
||
private[analyzer] def inferMapType(): Option[MapType] = { | ||
entriesLiterals.keys.flatMap(_.dataTypeOption).toSeq.distinct match { | ||
case Seq() => None | ||
case Seq(StringType) => | ||
val valuesTypes = entriesLiterals.values.flatMap(_.dataTypeOption).toSeq.distinct | ||
valuesTypes match { | ||
case Seq() => None | ||
case Seq(dt) => Some(MapType(StringType, dt)) | ||
case Seq(_, _*) => Some(MapType(StringType, VariantType)) | ||
} | ||
case _ => | ||
throw ErrorMessage.PLAN_CANNOT_CREATE_LITERAL(value.getClass.getCanonicalName, s"$value") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All keys of a Map must be of String type otherwise an exception is thrown. Maybe the exception could be more precise here ? |
||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ package com.snowflake.snowpark.internal | |
import com.snowflake.snowpark.FileOperationCommand._ | ||
import com.snowflake.snowpark.Row | ||
import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject} | ||
import com.snowflake.snowpark.types.{DataType, convertToSFType} | ||
import com.snowflake.snowpark.types.{ArrayType, DataType, MapType, convertToSFType} | ||
|
||
package object analyzer { | ||
// constant string | ||
|
@@ -446,7 +446,9 @@ package object analyzer { | |
val types = output.map(_.dataType) | ||
val rows = data.map { row => | ||
val cells = row.toSeq.zip(types).map { | ||
case (v, dType) => DataTypeMapper.toSql(v, Option(dType)) | ||
case (v: Seq[Any], _: ArrayType) => DataTypeMapper.toSql(ArrayLiteral(v)) | ||
case (v: Map[Any, Any], _: MapType) => DataTypeMapper.toSql(MapLiteral(v)) | ||
case (v, dType) => DataTypeMapper.toSql(Literal(v, Option(dType))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part of the code is a bit puzzling me, I think Array and Map are not supported by the Values function, so maybe I should throw something here ? In any case, the design of the current code is a bit flawed here. Take for example a list of ints that represent DateType, when creating the |
||
} | ||
cells.mkString(_LeftParenthesis, _Comma, _RightParenthesis) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
infinite loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm really sorry about that, I copied an old function from our code base and it is indeed an infinite loop 💥
Fixed it by converting to scala iterator and calling the method
toScala
to each of its elements.