@@ -10,9 +10,10 @@ class Derevo(val c: blackbox.Context) {
1010 type Newtype = NewtypeP [Tree ]
1111 type NameAndTypes = NameAndTypesP [c.Type ]
1212
13- val DelegatingSymbol = typeOf[delegating].typeSymbol
14- val PhantomSymbol = typeOf[phantom].typeSymbol
15- val PassTypeArgsSymbol = typeOf[PassTypeArgs ].typeSymbol
13+ val DelegatingSymbol = typeOf[delegating].typeSymbol
14+ val PhantomSymbol = typeOf[phantom].typeSymbol
15+ val PassTypeArgsSymbol = typeOf[PassTypeArgs ].typeSymbol
16+ val KeepRefinementsSymbol = typeOf[KeepRefinements ].typeSymbol
1617
1718 val instanceDefs = Vector (
1819 )
@@ -179,9 +180,9 @@ class Derevo(val c: blackbox.Context) {
179180 }
180181
181182 val (mode, call) = tree match {
182- case q " $obj(.. $args) " => (nameAndTypes(obj), tree)
183+ case q " $obj. $method (.. $args) " => (nameAndTypes(obj), tree)
183184
184- case q " $obj. $method ( $args) " => (nameAndTypes(obj), tree)
185+ case q " $obj(.. $args) " => (nameAndTypes(obj), tree)
185186
186187 case q " $obj" =>
187188 val call = newType.fold(q " $obj.instance " )(t => q " $obj.newtype[ ${t.underlying}].instance " )
@@ -207,14 +208,33 @@ class Derevo(val c: blackbox.Context) {
207208
208209 val callWithT = if (mode.passArgs) q " $call[ $outTyp] " else call
209210
211+ def fixFirstTypeParam = {
212+ val nothingT = c.typeOf[Nothing ]
213+
214+ c.typecheck(call, silent = true ) match {
215+ case q " $method[ $nothing, .. $remainingTpes](.. $args) " if nothing.tpe == nothingT =>
216+ q " $method[ $outTyp, .. $remainingTpes](.. $args) "
217+ case q " $method[ $nothing, .. $remainingTpes] " if nothing.tpe == nothingT =>
218+ q " $method[ $outTyp, .. $remainingTpes] "
219+ case _ => tree
220+ }
221+ }
222+
210223 if (allTparams.isEmpty || allTparams.length <= mode.drop) {
211- val resTc = if (newType.isDefined) mode.newtype else mode.to
212- val resT = mkAppliedType(resTc, tq " $typRef" )
224+ if (mode.keepRefinements) {
225+ q """
226+ @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
227+ implicit val $tn = $fixFirstTypeParam
228+ """
229+ } else {
230+ val resTc = if (newType.isDefined) mode.newtype else mode.to
231+ val resT = mkAppliedType(resTc, tq " $typRef" )
213232
214- q """
215- @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
216- implicit val $tn: $resT = $callWithT
217- """
233+ q """
234+ @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
235+ implicit val $tn: $resT = $callWithT
236+ """
237+ }
218238 } else {
219239
220240 val implicits =
@@ -231,10 +251,17 @@ class Derevo(val c: blackbox.Context) {
231251 }
232252 else Nil
233253
234- q """
235- @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
236- implicit def $tn[.. $tparams](implicit .. $implicits): $resT = $callWithT
237- """
254+ if (mode.keepRefinements) {
255+ q """
256+ @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
257+ implicit def $tn[.. $tparams](implicit .. $implicits) = $fixFirstTypeParam
258+ """
259+ } else {
260+ q """
261+ @java.lang.SuppressWarnings(scala.Array("org.wartremover.warts.All", "scalafix:All", "all"))
262+ implicit def $tn[.. $tparams](implicit .. $implicits): $resT = $callWithT
263+ """
264+ }
238265 }
239266 }
240267
@@ -266,7 +293,12 @@ class Derevo(val c: blackbox.Context) {
266293 case _ => false
267294 }
268295
269- nt.copy(passArgs = passArgs)
296+ val keepRefinements = objTyp.baseType(KeepRefinementsSymbol ) match {
297+ case TypeRef (_, _, _) => true
298+ case _ => false
299+ }
300+
301+ nt.copy(passArgs = passArgs, keepRefinements = keepRefinements)
270302 }
271303
272304 trait DerivationMatcher {
@@ -319,6 +351,7 @@ object Derevo {
319351 newtype : typ,
320352 drop : Int ,
321353 cascade : Boolean ,
354+ keepRefinements : Boolean = false ,
322355 passArgs : Boolean = false ,
323356 )
324357}
0 commit comments