[SEARCHREL-547] Include the fix for SNOW-899560 #76

Closed
2 changes: 1 addition & 1 deletion fips-pom.xml
@@ -4,7 +4,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.snowflake</groupId>
<artifactId>snowpark-fips</artifactId>
-<version>1.9.0-SNAPSHOT</version>
+<version>1.9.0</version>
<name>${project.artifactId}</name>
<description>Snowflake's DataFrame API</description>
<url>https://www.snowflake.com/</url>
2 changes: 1 addition & 1 deletion java_doc.xml
@@ -4,7 +4,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.snowflake</groupId>
<artifactId>snowpark-java</artifactId>
-<version>1.9.0-SNAPSHOT</version>
+<version>1.9.0</version>
<name>${project.artifactId}</name>
<description>Snowflake's DataFrame API</description>
<url>https://www.snowflake.com/</url>
2 changes: 1 addition & 1 deletion pom.xml
@@ -4,7 +4,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.snowflake</groupId>
<artifactId>snowpark</artifactId>
-<version>1.9.0-SNAPSHOT</version>
+<version>1.9.0-coveo-1</version>
<name>${project.artifactId}</name>
<description>Snowflake's DataFrame API</description>
<url>https://www.snowflake.com/</url>
2 changes: 1 addition & 1 deletion src/main/java/com/snowflake/snowpark_java/Functions.java
@@ -79,7 +79,7 @@ public static Column toScalar(DataFrame df) {
* @return The result column
*/
public static Column lit(Object literal) {
-return new Column(com.snowflake.snowpark.functions.lit(literal));
+return new Column(com.snowflake.snowpark.functions.lit(JavaUtils.toScala(literal)));
}

/**
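Reviewer note: without the JavaUtils.toScala hop, a java.util.Map or java.util.List passed from Java falls through Literal.apply and hits the cannot-create-literal error. A minimal sketch of the call path this change enables (hypothetical snippet; assume it lives under com.snowflake.snowpark, since Literal and MapLiteral are package-private):

import com.snowflake.snowpark.internal.JavaUtils
import com.snowflake.snowpark.internal.analyzer.{Literal, MapLiteral}

// What the Java-side lit(...) now effectively does with a Java map:
val javaMap = new java.util.HashMap[String, Object]()
javaMap.put("k", Integer.valueOf(42))

// toScala converts the Java map to an immutable Scala Map("k" -> 42) ...
val scalaValue = JavaUtils.toScala(javaMap)

// ... which Literal.apply matches as Map[Any, Any] and wraps in a MapLiteral
// instead of throwing.
assert(Literal(scalaValue).isInstanceOf[MapLiteral])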
11 changes: 11 additions & 0 deletions src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
@@ -414,4 +414,15 @@ object JavaUtils {
}
}

+  def toScala(element: Any): Any = {
+    import collection.JavaConverters._
+    element match {
+      case map: java.util.Map[_, _] => mapAsScalaMap(map).map {
+        case (k, v) => toScala(k) -> toScala(v)
+      }.toMap
+      case iterable: java.lang.Iterable[_] => iterableAsScalaIterable(iterable).map(toScala)
+      case iterator: java.util.Iterator[_] => asScalaIterator(iterator).map(toScala)
+      case _ => element
+    }
+  }
}
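A quick sanity check of the recursion, as a sketch (test scaffolding assumed): toScala converts nested containers element by element.

import com.snowflake.snowpark.internal.JavaUtils.toScala

// A Java list nested inside a Java map, as a Java caller might build it.
val inner = new java.util.ArrayList[Integer]()
inner.add(1)
inner.add(2)
val nested = new java.util.HashMap[String, Object]()
nested.put("xs", inner)

// The outer java.util.Map becomes an immutable Scala Map, and the inner
// java.util.List becomes a Scala collection of recursively converted elements.
val converted = toScala(nested).asInstanceOf[Map[Any, Any]]
assert(converted("xs").asInstanceOf[Iterable[Any]].toSeq == Seq(1, 2))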
2 changes: 1 addition & 1 deletion src/main/scala/com/snowflake/snowpark/internal/Utils.scala
@@ -14,7 +14,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object Utils extends Logging {
val Version: String = "1.9.0-SNAPSHOT"
val Version: String = "1.9.0"
// Package name of snowpark on server side
val SnowparkPackageName = "com.snowflake:snowpark"
val PackageNameDelimiter = ":"
src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
@@ -1,22 +1,22 @@
package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.Utils
+import com.snowflake.snowpark.types._
+import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat
+
+import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.util.TimeZone
-import java.math.{BigDecimal => JBigDecimal}
-
-import com.snowflake.snowpark.types._
-import com.snowflake.snowpark.types.convertToSFType
import javax.xml.bind.DatatypeConverter
-import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat

object DataTypeMapper {
  // milliseconds per day
  private val MILLIS_PER_DAY = 24 * 3600 * 1000L
  // microseconds per millisecond
  private val MICROS_PER_MILLIS = 1000L

  private[analyzer] def stringToSql(str: String): String =
    // Escapes all backslashes, single quotes and new line.
    "'" + str
      .replaceAll("\\\\", "\\\\\\\\")
      .replaceAll("'", "''")
@@ -25,63 +25,77 @@ object DataTypeMapper {
/*
* Convert a value with DataType to a snowflake compatible sql
*/
-  private[analyzer] def toSql(value: Any, dataType: Option[DataType]): String = {
-    dataType match {
-      case None => "NULL"
-      case Some(dt) =>
-        (value, dt) match {
-          case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null =>
-            "NULL"
-          case (_, IntegerType) if value == null => "NULL :: int"
-          case (_, ShortType) if value == null => "NULL :: smallint"
-          case (_, ByteType) if value == null => "NULL :: tinyint"
-          case (_, LongType) if value == null => "NULL :: bigint"
-          case (_, FloatType) if value == null => "NULL :: float"
-          case (_, StringType) if value == null => "NULL :: string"
-          case (_, DoubleType) if value == null => "NULL :: double"
-          case (_, BooleanType) if value == null => "NULL :: boolean"
-          case (_, BinaryType) if value == null => "NULL :: binary"
-          case _ if value == null => "NULL"
-          case (v: String, StringType) => stringToSql(v)
-          case (v: Byte, ByteType) => v + s" :: tinyint"
-          case (v: Short, ShortType) => v + s" :: smallint"
-          case (v: Any, IntegerType) => v + s" :: int"
-          case (v: Long, LongType) => v + s" :: bigint"
-          case (v: Boolean, BooleanType) => s"$v :: boolean"
-          // Float type doesn't have a suffix
-          case (v: Float, FloatType) =>
-            val castedValue = v match {
-              case _ if v.isNaN => "'NaN'"
-              case Float.PositiveInfinity => "'Infinity'"
-              case Float.NegativeInfinity => "'-Infinity'"
-              case _ => s"'$v'"
-            }
-            s"$castedValue :: FLOAT"
-          case (v: Double, DoubleType) =>
-            v match {
-              case _ if v.isNaN => "'NaN'"
-              case Double.PositiveInfinity => "'Infinity'"
-              case Double.NegativeInfinity => "'-Infinity'"
-              case _ => v + "::DOUBLE"
-            }
-          case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
-          case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
-          case (v: Int, DateType) =>
-            s"DATE '${SnowflakeDateTimeFormat
-              .fromSqlFormat(Utils.DateInputFormat)
-              .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'"
-          case (v: Long, TimestampType) =>
-            s"TIMESTAMP '${SnowflakeDateTimeFormat
-              .fromSqlFormat(Utils.TimestampInputFormat)
-              .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'"
-          case (v: Array[Byte], BinaryType) =>
-            s"'${DatatypeConverter.printHexBinary(v)}' :: binary"
-          case _ =>
-            throw new UnsupportedOperationException(
-              s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
+  private[analyzer] def toSql(literal: TLiteral): String = {
+    literal match {
+      case Literal(value, dataType) => (value, dataType) match {
+        case (_, None) => "NULL"
+        case (value, Some(dt)) =>
+          (value, dt) match {
+            case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null =>
+              "NULL"
+            case (_, IntegerType) if value == null => "NULL :: int"
+            case (_, ShortType) if value == null => "NULL :: smallint"
+            case (_, ByteType) if value == null => "NULL :: tinyint"
+            case (_, LongType) if value == null => "NULL :: bigint"
+            case (_, FloatType) if value == null => "NULL :: float"
+            case (_, StringType) if value == null => "NULL :: string"
+            case (_, DoubleType) if value == null => "NULL :: double"
+            case (_, BooleanType) if value == null => "NULL :: boolean"
+            case (_, BinaryType) if value == null => "NULL :: binary"
+            case _ if value == null => "NULL"
+            case (v: String, StringType) => stringToSql(v)
+            case (v: Byte, ByteType) => v + s" :: tinyint"
+            case (v: Short, ShortType) => v + s" :: smallint"
+            case (v: Any, IntegerType) => v + s" :: int"
+            case (v: Long, LongType) => v + s" :: bigint"
+            case (v: Boolean, BooleanType) => s"$v :: boolean"
+            // Float type doesn't have a suffix
+            case (v: Float, FloatType) =>
+              val castedValue = v match {
+                case _ if v.isNaN => "'NaN'"
+                case Float.PositiveInfinity => "'Infinity'"
+                case Float.NegativeInfinity => "'-Infinity'"
+                case _ => s"'$v'"
+              }
+              s"$castedValue :: FLOAT"
+            case (v: Double, DoubleType) =>
+              v match {
+                case _ if v.isNaN => "'NaN'"
+                case Double.PositiveInfinity => "'Infinity'"
+                case Double.NegativeInfinity => "'-Infinity'"
+                case _ => v + "::DOUBLE"
+              }
+            case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
+            case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
+            case (v: Int, DateType) =>
+              s"DATE '${
+                SnowflakeDateTimeFormat
+                  .fromSqlFormat(Utils.DateInputFormat)
+                  .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))
+              }'"
+            case (v: Long, TimestampType) =>
+              s"TIMESTAMP '${
+                SnowflakeDateTimeFormat
+                  .fromSqlFormat(Utils.TimestampInputFormat)
+                  .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)
+              }'"
+            case _ =>
+              throw new UnsupportedOperationException(
+                s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
+          }
+      }
+      case arrayLiteral: ArrayLiteral =>
+        if (arrayLiteral.dataTypeOption == Some(BinaryType)) {
+          val bytes = arrayLiteral.value.asInstanceOf[Seq[Byte]].toArray
+          s"'${DatatypeConverter.printHexBinary(bytes)}' :: binary"
+        } else {
+          "ARRAY_CONSTRUCT" + arrayLiteral.elementsLiterals.map(toSql).mkString("(", ", ", ")")
+        }
+      case mapLiteral: MapLiteral =>
+        "OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals.flatMap {
+          case (keyLiteral, valueLiteral) => Seq(toSql(keyLiteral), toSql(valueLiteral))
+        }.mkString("(", ", ", ")")
+    }

}

private[analyzer] def schemaExpression(dataType: DataType, isNullable: Boolean): String =
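To make the new branches concrete, this is the SQL text toSql should now emit for collection literals, read straight off the cases above (an in-package sketch, since these members are private[analyzer]; the exact casts follow from the inferred element types):

// A homogeneous Seq becomes ARRAY_CONSTRUCT with a cast per element:
DataTypeMapper.toSql(ArrayLiteral(Seq(1, 2)))
// -> "ARRAY_CONSTRUCT(1 :: int, 2 :: int)"

// A Seq[Byte] infers BinaryType and is emitted as a single hex BINARY literal:
DataTypeMapper.toSql(ArrayLiteral("AB".getBytes.toSeq))
// -> "'4142' :: binary"

// A String-keyed Map becomes OBJECT_CONSTRUCT(key, value, key, value, ...):
DataTypeMapper.toSql(MapLiteral(Map("a" -> 1)))
// -> "OBJECT_CONSTRUCT('a', 1 :: int)"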
src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala
@@ -2,12 +2,11 @@ package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.types._

import java.math.{BigDecimal => JavaBigDecimal}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}

import scala.math.BigDecimal

private[snowpark] object Literal {
// Snowflake max precision for decimal is 38
private lazy val bigDecimalRoundContext = new java.math.MathContext(DecimalType.MAX_PRECISION)
@@ -16,7 +15,7 @@ private[snowpark] object Literal {
decimal.round(bigDecimalRoundContext)
}

-def apply(v: Any): Literal = v match {
+def apply(v: Any): TLiteral = v match {
case i: Int => Literal(i, Option(IntegerType))
case l: Long => Literal(l, Option(LongType))
case d: Double => Literal(d, Option(DoubleType))
@@ -36,7 +35,8 @@ private[snowpark] object Literal {
case t: Timestamp => Literal(DateTimeUtils.javaTimestampToMicros(t), Option(TimestampType))
case ld: LocalDate => Literal(DateTimeUtils.localDateToDays(ld), Option(DateType))
case d: Date => Literal(DateTimeUtils.javaDateToDays(d), Option(DateType))
-case a: Array[Byte] => Literal(a, Option(BinaryType))
+case s: Seq[Any] => ArrayLiteral(s)
+case m: Map[Any, Any] => MapLiteral(m)
case null => Literal(null, None)
case v: Literal => v
case _ =>
@@ -45,10 +45,48 @@

}

-private[snowpark] case class Literal private (value: Any, dataTypeOption: Option[DataType])
-    extends Expression {
+private[snowpark] trait TLiteral extends Expression {
+  def value: Any
+  def dataTypeOption: Option[DataType]

  override def children: Seq[Expression] = Seq.empty

  override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression =
    this
}
+
+private[snowpark] case class Literal(value: Any, dataTypeOption: Option[DataType]) extends TLiteral
+
+private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral {
+  val elementsLiterals: Seq[TLiteral] = value.map(Literal(_))
+  val dataTypeOption = inferArrayType
+
+  private[analyzer] def inferArrayType(): Option[DataType] = {
+    elementsLiterals.flatMap(_.dataTypeOption).distinct match {
+      case Seq() => None
+      case Seq(ByteType) => Some(BinaryType)
+      case Seq(dt) => Some(ArrayType(dt))
+      case Seq(_, _*) => Some(ArrayType(VariantType))
+    }
+  }
+}
+
+private[snowpark] case class MapLiteral(value: Map[Any, Any]) extends TLiteral {
+  val entriesLiterals = value.map { case (k, v) => Literal(k) -> Literal(v) }
+  val dataTypeOption = inferMapType
+
+  private[analyzer] def inferMapType(): Option[MapType] = {
+    entriesLiterals.keys.flatMap(_.dataTypeOption).toSeq.distinct match {
+      case Seq() => None
+      case Seq(StringType) =>
+        val valuesTypes = entriesLiterals.values.flatMap(_.dataTypeOption).toSeq.distinct
+        valuesTypes match {
+          case Seq() => None
+          case Seq(dt) => Some(MapType(StringType, dt))
+          case Seq(_, _*) => Some(MapType(StringType, VariantType))
+        }
+      case _ =>
+        throw ErrorMessage.PLAN_CANNOT_CREATE_LITERAL(value.getClass.getCanonicalName, s"$value")
+    }
+  }
+}
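The inference rules are worth spelling out; a sketch of the expected dataTypeOption values, derived from inferArrayType and inferMapType as written (again in-package code, since these classes are private[snowpark]):

import com.snowflake.snowpark.types._

// A single element type is preserved:
assert(ArrayLiteral(Seq(1, 2)).dataTypeOption == Some(ArrayType(IntegerType)))

// Mixed element types widen to VARIANT:
assert(ArrayLiteral(Seq(1, "a")).dataTypeOption == Some(ArrayType(VariantType)))

// An all-Byte Seq collapses to BinaryType, pairing with the BINARY branch in toSql:
assert(ArrayLiteral(Seq(1.toByte, 2.toByte)).dataTypeOption == Some(BinaryType))

// Maps must be String-keyed; mixed value types widen to VARIANT:
assert(MapLiteral(Map("a" -> 1, "b" -> "x")).dataTypeOption ==
  Some(MapType(StringType, VariantType)))

// A non-String key throws ErrorMessage.PLAN_CANNOT_CREATE_LITERAL.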
src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala
@@ -203,8 +203,8 @@ private object SqlGenerator extends Logging {
case UnspecifiedFrame => ""
case SpecialFrameBoundaryExtractor(str) => str

-case Literal(value, dataType) =>
-  DataTypeMapper.toSql(value, dataType)
+case l: TLiteral =>
+  DataTypeMapper.toSql(l)
case attr: Attribute => quoteName(attr.name)
// unresolved expression
case UnresolvedAttribute(name) => name
src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
@@ -3,7 +3,7 @@ package com.snowflake.snowpark.internal
import com.snowflake.snowpark.FileOperationCommand._
import com.snowflake.snowpark.Row
import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject}
-import com.snowflake.snowpark.types.{DataType, convertToSFType}
+import com.snowflake.snowpark.types.{ArrayType, DataType, MapType, convertToSFType}

package object analyzer {
// constant string
@@ -446,7 +446,9 @@ package object analyzer {
val types = output.map(_.dataType)
val rows = data.map { row =>
val cells = row.toSeq.zip(types).map {
-case (v, dType) => DataTypeMapper.toSql(v, Option(dType))
+case (v: Seq[Any], _: ArrayType) => DataTypeMapper.toSql(ArrayLiteral(v))
+case (v: Map[Any, Any], _: MapType) => DataTypeMapper.toSql(MapLiteral(v))
+case (v, dType) => DataTypeMapper.toSql(Literal(v, Option(dType)))
}
cells.mkString(_LeftParenthesis, _Comma, _RightParenthesis)
}
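Net effect for locally created DataFrames: a row cell holding a Seq or Map is now routed through the collection literals when the VALUES clause is rendered, so ARRAY and MAP columns no longer fail literal conversion. A hedged sketch of the dispatch (the real code also threads the schema's DataType into the scalar Literal, elided here):

def cellToSql(v: Any): String = v match {
  case s: Seq[Any]      => DataTypeMapper.toSql(ArrayLiteral(s))
  case m: Map[Any, Any] => DataTypeMapper.toSql(MapLiteral(m))
  case other            => sys.error(s"scalar path elided in this sketch: $other")
}

// A two-column row (ARRAY, OBJECT) renders as a VALUES tuple like:
println(Seq(cellToSql(Seq(1, 2)), cellToSql(Map("a" -> 1))).mkString("(", ", ", ")"))
// -> (ARRAY_CONSTRUCT(1 :: int, 2 :: int), OBJECT_CONSTRUCT('a', 1 :: int))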