Skip to content

Commit bcb1d98

Browse files
Spark 3.5: StaticInvoke compatibility in ManualTypedEncoder (8/9-arg); deterministic + arg types; fallback to constructor. Add test.
Signed-off-by: Anudeep Konaboina <[email protected]>
1 parent 45c6a1a commit bcb1d98

File tree

3 files changed

+90
-1
lines changed

3 files changed

+90
-1
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,9 @@ __pycache__
6262
.coverage*
6363
*.jar
6464
.python-version
65+
66+
# Ignore SBT lock files
67+
project/.boot/**/sbt.boot.lock
68+
project/.boot/**/sbt.components.lock
69+
project/.ivy/.sbt.ivy.lock
70+
project/.sbtboot/**/.sbt.cache.lock

core/src/main/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoder.scala

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,95 @@ import scala.reflect.{ClassTag, classTag}
1010

1111
/** Can be useful for non Scala types and for complicated case classes with implicits in the constructor. */
1212
object ManualTypedEncoder {
13+
/** Constructs StaticInvoke via reflection to handle Spark 3.4/3.5 constructor differences. */
14+
private def staticInvokeSafely(
15+
targetClass: Class[_],
16+
dataType: DataType,
17+
functionName: String,
18+
arguments: Seq[Expression],
19+
propagateNull: Boolean,
20+
returnNullable: Boolean
21+
): InvokeLike = {
22+
val ctors = classOf[StaticInvoke].getConstructors
23+
val boxedPropagateNull = Boolean.box(propagateNull)
24+
val boxedReturnNullable = Boolean.box(returnNullable)
25+
val TRUE = Boolean.box(true)
26+
27+
val ctor = ctors.maxBy(_.getParameterTypes.length)
28+
val argTypes: Seq[DataType] = arguments.map(_.dataType)
29+
val targetModuleClass: Class[_] = {
30+
val moduleName = targetClass.getName + "$"
31+
try Class.forName(moduleName)
32+
catch { case _: ClassNotFoundException => targetClass }
33+
}
34+
35+
def tryInvoke(onClass: Class[_]): InvokeLike = ctor.getParameterTypes.length match {
36+
case 9 =>
37+
// (Class, DataType, String, Seq, Seq, boolean, boolean, boolean, Option)
38+
ctor.newInstance(
39+
onClass,
40+
dataType,
41+
functionName,
42+
arguments,
43+
argTypes,
44+
boxedPropagateNull,
45+
boxedReturnNullable,
46+
TRUE,
47+
None
48+
).asInstanceOf[InvokeLike]
49+
case 8 =>
50+
// (Class, DataType, String, Seq, Seq, boolean, boolean, boolean)
51+
ctor.newInstance(
52+
onClass,
53+
dataType,
54+
functionName,
55+
arguments,
56+
argTypes,
57+
boxedPropagateNull,
58+
boxedReturnNullable,
59+
TRUE
60+
).asInstanceOf[InvokeLike]
61+
case _ =>
62+
throw new NotImplementedError("StaticInvoke constructor has unexpected shape")
63+
}
64+
ctor.getParameterTypes.length match {
65+
case 9 | 8 =>
66+
// Try on the class first (top-level case classes have static forwarders), then on module
67+
val firstError = try {
68+
return tryInvoke(targetClass)
69+
} catch { case t: Throwable => t }
70+
tryInvoke(targetModuleClass)
71+
case _ =>
72+
throw new NotImplementedError("StaticInvoke constructor has unexpected shape")
73+
}
74+
}
75+
76+
/** Detect whether a static forwarder for `apply` of given arity exists on the given class. */
77+
private def hasStaticApply(onClass: Class[_], arity: Int): Boolean = {
78+
import java.lang.reflect.Modifier
79+
onClass.getMethods.exists { m =>
80+
m.getName == "apply" && m.getParameterCount == arity && Modifier.isStatic(m.getModifiers)
81+
}
82+
}
83+
1384
/** Invokes apply from the companion object. */
1485
def staticInvoke[T: ClassTag](
1586
fields: List[RecordEncoderField],
1687
fieldNameModify: String => String = identity,
1788
isNullable: Boolean = true
18-
): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) => StaticInvoke(classTag.runtimeClass, jvmRepr, "apply", newArgs, propagateNull = true, returnNullable = false) }, fieldNameModify, isNullable)
89+
): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) =>
90+
val target = classTag.runtimeClass
91+
val moduleName = target.getName + "$"
92+
val moduleClass = try Class.forName(moduleName) catch { case _: ClassNotFoundException => null }
93+
val arity = newArgs.length
94+
if ((hasStaticApply(target, arity)) || (moduleClass != null && hasStaticApply(moduleClass, arity))) {
95+
staticInvokeSafely(target, jvmRepr, "apply", newArgs, propagateNull = true, returnNullable = false)
96+
}
97+
else {
98+
// Fall back to directly invoking the primary constructor
99+
NewInstance(target, newArgs, jvmRepr, propagateNull = true)
100+
}
101+
}, fieldNameModify, isNullable)
19102

20103
/** Invokes object constructor. */
21104
def newInstance[T: ClassTag](

core/src/test/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoderSpec.scala

Whitespace-only changes.

0 commit comments

Comments
 (0)