Skip to content

Commit 51a88e4

Browse files
authored
Merge pull request tofu-tf#286 from zawodskoj/refinements
Add support for deriving typeclasses with refinements
2 parents fa76852 + e6b700b commit 51a88e4

File tree

3 files changed

+77
-16
lines changed

3 files changed

+77
-16
lines changed

core/src/main/scala/derevo/Derevo.scala

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ class Derevo(val c: blackbox.Context) {
1010
type Newtype = NewtypeP[Tree]
1111
type NameAndTypes = NameAndTypesP[c.Type]
1212

13-
val DelegatingSymbol = typeOf[delegating].typeSymbol
14-
val PhantomSymbol = typeOf[phantom].typeSymbol
15-
val PassTypeArgsSymbol = typeOf[PassTypeArgs].typeSymbol
13+
val DelegatingSymbol = typeOf[delegating].typeSymbol
14+
val PhantomSymbol = typeOf[phantom].typeSymbol
15+
val PassTypeArgsSymbol = typeOf[PassTypeArgs].typeSymbol
16+
val KeepRefinementsSymbol = typeOf[KeepRefinements].typeSymbol
1617

1718
val instanceDefs = Vector(
1819
)
@@ -179,9 +180,9 @@ class Derevo(val c: blackbox.Context) {
179180
}
180181

181182
val (mode, call) = tree match {
182-
case q"$obj(..$args)" => (nameAndTypes(obj), tree)
183+
case q"$obj.$method(..$args)" => (nameAndTypes(obj), tree)
183184

184-
case q"$obj.$method($args)" => (nameAndTypes(obj), tree)
185+
case q"$obj(..$args)" => (nameAndTypes(obj), tree)
185186

186187
case q"$obj" =>
187188
val call = newType.fold(q"$obj.instance")(t => q"$obj.newtype[${t.underlying}].instance")
@@ -207,14 +208,33 @@ class Derevo(val c: blackbox.Context) {
207208

208209
val callWithT = if (mode.passArgs) q"$call[$outTyp]" else call
209210

211+
def fixFirstTypeParam = {
212+
val nothingT = c.typeOf[Nothing]
213+
214+
c.typecheck(call, silent = true) match {
215+
case q"$method[$nothing, ..$remainingTpes](..$args)" if nothing.tpe == nothingT =>
216+
q"$method[$outTyp, ..$remainingTpes](..$args)"
217+
case q"$method[$nothing, ..$remainingTpes]" if nothing.tpe == nothingT =>
218+
q"$method[$outTyp, ..$remainingTpes]"
219+
case _ => tree
220+
}
221+
}
222+
210223
if (allTparams.isEmpty || allTparams.length <= mode.drop) {
211-
val resTc = if (newType.isDefined) mode.newtype else mode.to
212-
val resT = mkAppliedType(resTc, tq"$typRef")
224+
if (mode.keepRefinements) {
225+
q"""
226+
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
227+
implicit val $tn = $fixFirstTypeParam
228+
"""
229+
} else {
230+
val resTc = if (newType.isDefined) mode.newtype else mode.to
231+
val resT = mkAppliedType(resTc, tq"$typRef")
213232

214-
q"""
215-
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
216-
implicit val $tn: $resT = $callWithT
217-
"""
233+
q"""
234+
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
235+
implicit val $tn: $resT = $callWithT
236+
"""
237+
}
218238
} else {
219239

220240
val implicits =
@@ -231,10 +251,17 @@ class Derevo(val c: blackbox.Context) {
231251
}
232252
else Nil
233253

234-
q"""
235-
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
236-
implicit def $tn[..$tparams](implicit ..$implicits): $resT = $callWithT
237-
"""
254+
if (mode.keepRefinements) {
255+
q"""
256+
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
257+
implicit def $tn[..$tparams](implicit ..$implicits) = $fixFirstTypeParam
258+
"""
259+
} else {
260+
q"""
261+
@java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
262+
implicit def $tn[..$tparams](implicit ..$implicits): $resT = $callWithT
263+
"""
264+
}
238265
}
239266
}
240267

@@ -266,7 +293,12 @@ class Derevo(val c: blackbox.Context) {
266293
case _ => false
267294
}
268295

269-
nt.copy(passArgs = passArgs)
296+
val keepRefinements = objTyp.baseType(KeepRefinementsSymbol) match {
297+
case TypeRef(_, _, _) => true
298+
case _ => false
299+
}
300+
301+
nt.copy(passArgs = passArgs, keepRefinements = keepRefinements)
270302
}
271303

272304
trait DerivationMatcher {
@@ -319,6 +351,7 @@ object Derevo {
319351
newtype: typ,
320352
drop: Int,
321353
cascade: Boolean,
354+
keepRefinements: Boolean = false,
322355
passArgs: Boolean = false,
323356
)
324357
}

core/src/main/scala/derevo/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package derevo {
88
/**
99
*/
1010
trait PassTypeArgs
11+
trait KeepRefinements
1112

1213
class delegating(to: String, args: Any*) extends StaticAnnotation
1314
class phantom extends StaticAnnotation
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package derevo.tests
2+
3+
import derevo._
4+
5+
trait Trait[T]
6+
7+
object refinedTrait extends Derivation[Trait] with KeepRefinements {
8+
def apply[T, A](a: A): Trait[T] { type Refinement = A } = new Trait[T] { type Refinement = A }
9+
def singleTparam[T](dummy: Int): Trait[T] { type Refinement = String } = new Trait[T] { type Refinement = String }
10+
}
11+
12+
object refinedTraitWithInstanceMethod extends Derivation[Trait] with KeepRefinements {
13+
def instance[T]: Trait[T] { type Refinement = String } = new Trait[T] { type Refinement = String }
14+
}
15+
16+
object Test {
17+
@derive(refinedTrait("123")) case class Foo()
18+
@derive(refinedTrait.singleTparam(123)) case class Bar()
19+
@derive(refinedTraitWithInstanceMethod) case class Baz()
20+
@derive(refinedTrait(123)) case class PolymorphicFoo[@phantom Arg](arg: Arg)
21+
22+
val nonRefined: Trait[Foo] = implicitly
23+
val refinedFoo: Trait[Foo] { type Refinement = String } = implicitly
24+
val refinedBar: Trait[Bar] { type Refinement = String } = implicitly
25+
val refinedBaz: Trait[Baz] { type Refinement = String } = implicitly
26+
val refinedPolymorphic: Trait[PolymorphicFoo[Foo]] { type Refinement = Int } = implicitly
27+
}

0 commit comments

Comments
 (0)