Skip to content

Commit

Permalink
Respecting required field of inner case classes in Coder macro (#4645)
Browse files Browse the repository at this point in the history
* Respecting required  field of inner case classes in Coder macro

* Fix scala 2.12 compilation

* fix scalafix+compile errors

* Added negative unsupported test scenarios

* Failing on all inner classes

* Update scio-test/src/test/scala/com/spotify/scio/coders/CoderTest.scala

Co-authored-by: Michel Davit <[email protected]>

* Addressing the comment

Co-authored-by: Michel Davit <[email protected]>
  • Loading branch information
shnapz and RustedBones authored Jan 13, 2023
1 parent ea7cb3d commit 7242316
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,32 @@ object LowPriorityCoderDerivation {
new CaseClassConstructor(caseClass.getClass.getName)
}

private class CaseClassConstructor[T] private (private val className: String)
extends Serializable {
// We can call rawConstruct on an empty CaseClass instance
@transient lazy val ctx: CaseClass[Coder, T] = ClosureCleaner
.instantiateClass(Class.forName(className))
.asInstanceOf[CaseClass[Coder, T]]
private class CaseClassConstructor[T] private (
private val className: String
) extends Serializable {

@transient lazy val ctxClass: Class[_] = Class.forName(className)

@transient lazy val ctx: CaseClass[Coder, T] = {
ClosureCleaner.outerFieldOf(ctxClass) match {
/* The field "$outer" is added by scala compiler to a case class if it is declared inside
another class. And the constructor of that compiled class requires outer field to be not
null.
If "$outer" is present it's an inner class and this scenario is officially not supported
by Scio */
case Some(_) =>
throw new Throwable(
s"Found an $$outer field in $ctxClass. Possibly it is an attempt to use inner case " +
"class in a Scio transformation. Inner case classes are not supported in Scio " +
"auto-derived macros. Move the case class to the package level or define a custom " +
"coder."
)
/* If "$outer" field is absent then T is not an inner class, we create an empty instance
of ctx */
case None =>
ClosureCleaner.instantiateClass(ctxClass).asInstanceOf[CaseClass[Coder, T]]
}
}

def rawConstruct(fieldValues: Seq[Any]): T = ctx.rawConstruct(fieldValues)
}
Expand Down Expand Up @@ -88,6 +108,7 @@ trait LowPriorityCoderDerivation {
def join[T: ClassTag](ctx: CaseClass[Coder, T]): Coder[T] = {
val typeName = ctx.typeName.full
val constructor = CaseClassConstructor(ctx)

if (ctx.isValueClass) {
val p = ctx.parameters.head
Coder.xmap(p.typeclass.asInstanceOf[Coder[Any]])(
Expand All @@ -99,6 +120,7 @@ trait LowPriorityCoderDerivation {
} else {
Coder.ref(typeName) {
val cs = Array.ofDim[(String, Coder[Any])](ctx.parameters.length)

ctx.parameters.foreach { p =>
cs.update(p.index, p.label -> p.typeclass.asInstanceOf[Coder[Any]])
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,20 @@ object CoderAssertions {
def assert(value: T)(implicit c: Coder[T], eq: Equality[T]): Assertion
}

def roundtrip[T](opts: PipelineOptions = DefaultPipelineOptions): CoderAssertion[T] =
def roundtripWithCustomAssert[T](
opts: PipelineOptions = DefaultPipelineOptions
)(customAssertEquality: (T, T) => Assertion): CoderAssertion[T] =
new CoderAssertion[T] {
override def assert(value: T)(implicit c: Coder[T], eq: Equality[T]): Assertion = {
val beamCoder = CoderMaterializer.beamWithDefault(c, o = opts)
val result = roundtripWithCoder(beamCoder, value)
customAssertEquality(value, result)
}
}

def roundtrip[T](
opts: PipelineOptions = DefaultPipelineOptions
): CoderAssertion[T] =
new CoderAssertion[T] {
override def assert(value: T)(implicit c: Coder[T], eq: Equality[T]): Assertion = {
val beamCoder = CoderMaterializer.beamWithDefault(c, o = opts)
Expand Down Expand Up @@ -110,4 +123,9 @@ object CoderAssertions {

result should ===(value)
}

private def roundtripWithCoder[T](beamCoder: BCoder[T], value: T): T = {
val bytes = CoderUtils.encodeToByteArray(beamCoder, value)
CoderUtils.decodeFromByteArray(beamCoder, bytes)
}
}
98 changes: 97 additions & 1 deletion scio-test/src/test/scala/com/spotify/scio/coders/CoderTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,39 @@ final case class AnyValExample(value: String) extends AnyVal
// Non deterministic
final case class NonDeterministic(a: Double, b: Double)

class ClassWrapper() {
case class InnerCaseClass(str: String)

def runWithImplicit(implicit
c: Coder[InnerCaseClass]
): Unit =
InnerCaseClass("51") coderShould roundtrip()

def run(): Unit =
InnerCaseClass("51") coderShould roundtrip()
}

object TopLevelObject {
case class InnerCaseClass(str: String)
}

final class CoderTest extends AnyFlatSpec with Matchers {

val userId: UserId = UserId(Seq[Byte](1, 2, 3, 4))
val user: User = User(userId, "johndoe", "[email protected]")

/*
* Case class nested inside another class. Do not move outside
* */
case class InnerCaseClass(str: String)

/*
* Object nested inside another class. Do not move outside
* */
object InnerObject {
case class InnerCaseClass(str: String)
}

def materialize[T](coder: Coder[T]): BCoder[T] =
CoderMaterializer.beam(PipelineOptionsFactory.create(), coder)

Expand All @@ -135,7 +164,7 @@ final class CoderTest extends AnyFlatSpec with Matchers {
4.5 coderShould roundtrip()
}

it should "support Scala collections" in {
"Coders" should "support Scala collections" in {
import scala.collection.BitSet

val nil: Seq[String] = Nil
Expand Down Expand Up @@ -168,6 +197,73 @@ final class CoderTest extends AnyFlatSpec with Matchers {
CoderProperties.structuralValueConsistentWithEquals(bmc, m, m)
}

"Coders" should "not support inner case classes" in {
{
the[Throwable] thrownBy {
InnerObject coderShould roundtrip()
}
}.getMessage should include(
"Found an $outer field in class com.spotify.scio.coders.CoderTest$$"
)

val cw = new ClassWrapper()
try {
cw.runWithImplicit
throw new Throwable("Is expected to throw when passing implicit from outer class")
} catch {
case e: NullPointerException =>
// In this case outer field is called "$cw" and it is hard to wrap it with proper exception
// so we allow it to fail with NullPointerException
e.getMessage should be(null)
}

{
the[Throwable] thrownBy {
cw.InnerCaseClass("49") coderShould roundtrip()
}
}.getMessage should startWith(
"Found an $outer field in class com.spotify.scio.coders.CoderTest$$"
)

{
the[Throwable] thrownBy {
cw.run()
}
}.getMessage should startWith(
"Found an $outer field in class com.spotify.scio.coders.ClassWrapper$$"
)

{
the[Throwable] thrownBy {
InnerCaseClass("42") coderShould roundtrip()
}
}.getMessage should startWith(
"Found an $outer field in class com.spotify.scio.coders.CoderTest$$"
)

case class ClassInsideMethod(str: String)

{
the[Throwable] thrownBy {
ClassInsideMethod("50") coderShould roundtrip()
}
}.getMessage should startWith(
"Found an $outer field in class com.spotify.scio.coders.CoderTest$$"
)

{
the[Throwable] thrownBy {
InnerObject.InnerCaseClass("42") coderShould roundtrip()
}
}.getMessage should startWith(
"Found an $outer field in class com.spotify.scio.coders.CoderTest$$"
)
}

"Coders" should "support inner classes in objects" in {
TopLevelObject.InnerCaseClass("42") coderShould roundtrip()
}

it should "support tuples" in {
import shapeless.syntax.std.tuple._
val t22 = (
Expand Down

0 comments on commit 7242316

Please sign in to comment.