Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No aliases for containers #74

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 105 additions & 38 deletions api-first-core/src/main/scala/de/zalando/apifirst/TypeFlattener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ object TypeDeduplicator extends TypeAnalyzer {

private def equal(app: StrictModel) = (t1: Type, t2: Type) =>
t1 == t2 ||
((notHierarhyRoot(t1, app) || notHierarhyRoot(t2, app)) && isSameTypeDef(t1)(t2) && isSameConstraints(t1)(t2))
((notHierarchyRoot(t1, app) || notHierarchyRoot(t2, app)) && isSameTypeDef(t1)(t2) && isSameConstraints(t1)(t2))

/**
* Removes redundant type definitions changing pointing references
Expand All @@ -30,15 +30,17 @@ object TypeDeduplicator extends TypeAnalyzer {
duplicate map replaceSingle(app) map apply getOrElse app
}

private def notHierarhyRoot(t: Type, app: StrictModel) = !app.discriminators.contains(t.name)
private def notHierarchyRoot(t: Type, app: StrictModel) = !app.discriminators.contains(t.name)

private def replaceSingle(model: StrictModel): Type => StrictModel = tpe => {
val duplicates = model.typeDefs.filter { d => equal(model)(tpe, d._2) }
val duplicateNames = sortByDiscriminatorOrPathLength(model.discriminators, duplicates)
val bestMatch :: refsToRemove = duplicateNames
val typesToRewrite = model.typeDefs filterNot { t => refsToRemove.contains(t._1) }
val callsWithCorrectRefs = model.calls map { c => replaceReferenceInCall(refsToRemove, bestMatch)(c) }
val typesWithCorrectRefs = typesToRewrite map { d => replaceReferencesInTypes(refsToRemove, bestMatch)(d._1, d._2) }
val typesWithCorrectRefs = typesToRewrite map { (d: (Reference, Type)) =>
replaceReferencesInTypes(refsToRemove, bestMatch)(d._1, d._2)
}
val newParams = model.params map replaceReferenceInParameter(refsToRemove, bestMatch)
model.copy(typeDefs = typesWithCorrectRefs, params = newParams, calls = callsWithCorrectRefs)
}
Expand All @@ -56,33 +58,46 @@ object TypeDeduplicator extends TypeAnalyzer {
call.copy(resultTypes = TypesResponseInfo(resultTypes, default))
}

private def replaceReferencesInTypes(duplicateRefs: Seq[Reference], target: Reference): (Reference, Type) => (Reference, Type) = (ref, tpe) => ref -> {
tpe match {
case c: Container =>
c.tpe match {
case r: TypeRef if duplicateRefs.contains(r.name) => c.withType(TypeRef(target))
case o => c
}

case c: Composite =>
val newDescendants = c.descendants map {
case d: TypeRef if duplicateRefs.contains(d.name) => TypeRef(target)
case o => o
}
c.withTypes(newDescendants)
private def replaceType(tpe: Type, duplicateRefs: Seq[Reference], target: Reference): Type = tpe match {
case c: Container => replaceContainerType(c, duplicateRefs, target)
case c: Composite => replaceCompositeType(c, duplicateRefs, target)
case t: TypeDef => replaceTypeDefType(t, duplicateRefs, target)
case n: TypeRef if duplicateRefs.contains(n.name) => TypeRef(target)
case _ => tpe
}

case t: TypeDef =>
val newFields = t.fields.map {
case f @ Field(_, tpe: TypeRef) if duplicateRefs.contains(tpe.name) => f.copy(tpe = TypeRef(target))
case o => o
}
val newName = if (duplicateRefs.contains(t.name)) target else t.name
t.copy(name = newName, fields = newFields)
private def replaceContainerType(c: Container, duplicateRefs: Seq[Reference], target: Reference): Type = c.tpe match {
case r: TypeRef if duplicateRefs.contains(r.name) => c.withType(TypeRef(target))
case r: Container if isRecursiveContainerType(r) => c.withType(replaceType(r, duplicateRefs, target))
case o => c
}

case n: TypeRef if duplicateRefs.contains(n.name) => TypeRef(target)
private def replaceCompositeType(c: Composite, duplicateRefs: Seq[Reference], target: Reference) = {
val newDescendants = c.descendants map {
case d: TypeRef if duplicateRefs.contains(d.name) => TypeRef(target)
case o => o
}
c.withTypes(newDescendants)
}

case _ => tpe
private def replaceTypeDefType(t: TypeDef, duplicateRefs: Seq[Reference], target: Reference) = {
val newFields = t.fields.map {
case f @ Field(_, tpe: TypeRef) if duplicateRefs.contains(tpe.name) => f.copy(tpe = TypeRef(target))
case f @ Field(_, c: Container) if isRecursiveContainerType(c) && duplicateRefs.contains(getInnerContainerType(c).name) =>
f.copy(tpe = replaceContainerType(c, duplicateRefs, target))
case o => o
}
val newName = if (duplicateRefs.contains(t.name)) target else t.name
t.copy(name = newName, fields = newFields)
}

private def getInnerContainerType(c: Container): Type = c.tpe match {
case inner: Container if isRecursiveContainerType(inner) => getInnerContainerType(inner)
case inner => inner
}

private def replaceReferencesInTypes(duplicateRefs: Seq[Reference], target: Reference): (Reference, Type) => (Reference, Type) = { (ref, tpe) =>
ref -> replaceType(tpe, duplicateRefs, target)
}

private def replaceReferenceInParameter(duplicateRefs: Seq[Reference], target: Reference): ((ParameterRef, Parameter)) => (ParameterRef, Parameter) = {
Expand Down Expand Up @@ -124,7 +139,7 @@ object TypeFlattener extends TypeAnalyzer {
}

private def flatten0(typeDefs: TypeLookupTable): TypeLookupTable = {
val flatTypeDefs = typeDefs flatMap { case (k, v) => extractComplexType(k, v) }
val flatTypeDefs = typeDefs flatMap { case (ref, tpe) => extractComplexType(ref, tpe) }
if (flatTypeDefs == typeDefs)
flatTypeDefs
else
Expand All @@ -133,20 +148,29 @@ object TypeFlattener extends TypeAnalyzer {

private def extractComplexType(ref: Reference, typeDef: Type): Seq[(Reference, Type)] = typeDef match {
case t: TypeDef if complexFields(t).nonEmpty =>
val (changedFields, extractedTypes) = t.fields.filter(complexField).map(createTypeFromField(t)).unzip
val (changedFields, extractedTypes) = complexFields(t).map(createTypeFromField(t)).unzip
val newFields = t.fields.filterNot(complexField) ++ changedFields
val newTypeDef = t.copy(fields = newFields)
(ref -> newTypeDef) +: extractedTypes
case t: TypeDef if containerFieldsWithComplexType(t).nonEmpty =>
val (changedFields, extractedTypes) = containerFieldsWithComplexType(t).map(createRecursiveTypeFromContainerField(t)).unzip
val newFields = t.fields.filterNot(containerFieldWithComplexType) ++ changedFields
val newTypeDef = t.copy(fields = newFields)
(ref -> newTypeDef) +: extractedTypes
case t: EnumTrait =>
val leafTypes = t.leaves.map { l => ref / l.fieldValue -> l }
(ref -> t) +: leafTypes.toSeq
case c: Container if isComplexType(c.tpe) =>
case c: Container if isRecursiveContainerType(c) && isRecursiveComplexType(c) =>
val (newType, newRef, extractedType) = createRecursiveTypeFromContainer(ref, c)
Seq(ref -> newType, newRef -> extractedType)
case c: Container if !isRecursiveContainerType(c) && isComplexType(c.tpe) =>
val newRef = ref / c.getClass.getSimpleName
Seq(ref -> c.withType(TypeRef(newRef)), newRef -> c.tpe)
case c: Composite =>
val (changedTypes, extractedTypes) = c.descendants.filter(isComplexType).
zipWithIndex.map(flattenType(c.getClass.getSimpleName, ref)).unzip
val newTypes = c.descendants.filterNot(isComplexType) ++ changedTypes
val (complexDescendants, simpleDescendants) = c.descendants.partition(isComplexType)
val (changedTypes, extractedTypes) = complexDescendants.zipWithIndex
.map(flattenType(c.getClass.getSimpleName, ref)).unzip
val newTypes = simpleDescendants ++ changedTypes
val newTypeDef = c.withTypes(newTypes)
(ref -> newTypeDef) +: extractedTypes
case _ => Seq(ref -> typeDef)
Expand All @@ -156,12 +180,38 @@ object TypeFlattener extends TypeAnalyzer {

private def complexField: (Field) => Boolean = f => isComplexType(f.tpe)

private def containerFieldsWithComplexType(typeDef: TypeDef): Seq[Field] = typeDef.fields filter containerFieldWithComplexType

private def containerFieldWithComplexType: (Field) => Boolean = f => isRecursiveContainerType(f.tpe) && isRecursiveComplexType(f.tpe)

private def createTypeFromField(t: TypeDef): (Field) => (Field, (Reference, Type)) = field => {
val newReference = TypeRef(t.name / field.name.simple)

val extractedType = field.tpe
(field.copy(tpe = newReference), newReference.name -> extractedType)
}

private def createRecursiveTypeFromContainerField(t: Type): (Field) => (Field, (Reference, Type)) = field => {
val reference = t.name / field.name.simple

val (newType, newReference, extractedType) = createRecursiveTypeFromContainer(reference, field.tpe)

(field.copy(tpe = newType), newReference -> extractedType)
}

private def createRecursiveTypeFromContainer(ref: Reference, t: Type): (Type, Reference, Type) = {
val newReference = ref / t.name.simple

t match {
case c: Container if isRecursiveContainerType(c) =>
val result = createRecursiveTypeFromContainer(newReference, c.tpe)
result.copy(c.withType(result._1))

case t =>
(TypeRef(newReference), newReference, t)
}
}

private def flattenType: (String, Reference) => ((Type, Int)) => (Type, (Reference, Type)) = (name, ref) => pair => {
val (typeDef, index) = pair
val newReference = TypeRef(ref / (name + index))
Expand All @@ -180,9 +230,9 @@ object ParameterDereferencer extends TypeAnalyzer {
private[apifirst] def apply(app: StrictModel): StrictModel = {
var result = app
result.params foreach {
case (name, definition) =>
case (name: ParameterRef, definition: Expr) =>
definition.typeName match {
case tpe if isComplexType(tpe) =>
case tpe if isComplexTypeParam(tpe) =>
val newName = name.name / "ref"
val newReference = TypeRef(newName)
val tps = app.typeDefs + (newName -> tpe)
Expand All @@ -197,10 +247,24 @@ object ParameterDereferencer extends TypeAnalyzer {

trait TypeAnalyzer {
def isComplexType(t: Type): Boolean = t match {
case tpe @ (_: TypeDef | _: Composite | _: Container) => true
case _: TypeDef | _: Composite => true
case t: Container => !isRecursiveContainerType(t)
case _ => false
}

def isComplexTypeParam(t: Type): Boolean = t match {
case _: TypeDef | _: Composite | _: Container => true
case _ => false
}

def isRecursiveComplexType(t: Type): Boolean = t match {
case c: Container if isRecursiveContainerType(c) => isRecursiveComplexType(c.tpe)
case tpe => isComplexType(tpe)
}

def isRecursiveContainerType(t: Type): Boolean =
t.isInstanceOf[Arr] || t.isInstanceOf[ArrResult] || t.isInstanceOf[Opt]

def isSameConstraints(one: Type)(two: Type): Boolean = (one, two) match {
case (c1: Container, c2: Container) if c1.getClass == c2.getClass =>
isSameConstraints(c1.tpe)(c2.tpe)
Expand Down Expand Up @@ -237,8 +301,11 @@ trait TypeAnalyzer {
case _ => false
}

def sameFields(t1: TypeDef, t2: TypeDef): Boolean =
t1.fields.forall(p => t2.fields.exists(e => isSameTypeDef(p.tpe)(e.tpe) && sameNames(p, e)))
def sameFields(t1: TypeDef, t2: TypeDef): Boolean = {
t1.fields.forall { field =>
t2.fields.exists { e => isSameTypeDef(field.tpe)(e.tpe) & sameNames(field, e) }
}
}

type hasSimpleName = { def name: { def simple: String } }

Expand All @@ -249,4 +316,4 @@ trait TypeAnalyzer {
c1.root == c2.root &&
c1.descendants.forall(p => c2.descendants.exists(isSameTypeDef(p)))

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ object StringUtil {
case class ScalaName(ref: Reference) {
import ScalaName._
import StringUtil._
val parts = ref.parts.flatMap(_.split("/").filter(_.nonEmpty)) match {
val parts: List[String] = ref.parts.flatMap(_.split("/").filter(_.nonEmpty)) match {
case Nil =>
throw new IllegalArgumentException(s"At least one part required to construct a name, but got $ref")
case single :: Nil => "" :: removeVars(single) :: Nil
Expand All @@ -172,6 +172,7 @@ case class ScalaName(ref: Reference) {
if (prefix.trim.isEmpty) (withSuffix, capitalize _) else (prefix :: withSuffix, camelize _)
escape(caseTransformer("/", withPrefix.mkString("/")))
}
def typeLongAlias: String = parts.reverse.mkString

def methodName: String = escape(camelize("/", parts.last))
def names: (String, String, String) = (packageName, className, methodName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ class ParameterDereferencerTest extends FunSpec with MustMatchers {
testChangeNothing()
}

it("should dereference container Opt types") {
testContainerType(Opt.apply)
it("should change nothing if parameters contain only containers") {
testContainer()
}
it("should dereference container CatchAll types") {
testContainerType(CatchAll.apply)
}
it("should dereference composition OneOf types") {

it("should dereference composition OneOf types") {
testCompositionType(OneOf.apply)
}
it("should dereference composition AllOf types") {
it("should dereference composition AllOf types") {
testCompositionType(OneOf.apply)
}
it("should dereference TypeDefs") {
Expand All @@ -48,6 +46,18 @@ class ParameterDereferencerTest extends FunSpec with MustMatchers {
checkExpectations(types)(params)
}

def testContainer(): Unit = {
val types = Map[Reference, Type](
reference1 -> Opt(Intgr(Some("Limit for search queries")), Some("some other stuff")),
reference2 -> CatchAll(Intgr(None), None)
)
val params: ParameterLookupTable = Map(
ParameterRef(reference1) -> Parameter("limit", TypeRef(reference1), None, None, "", encode = false, ParameterPlace.BODY),
ParameterRef(reference2) -> Parameter("id", TypeRef(reference2), None, None, "", encode = false, ParameterPlace.BODY)
)
checkExpectations(types)(params)
}

def testContainerType[T](constructor: (Type, TypeMeta) => Type): Unit = {
val types = Map[Reference, Type](
reference1 -> constructor(Intgr(None), TypeMeta(None)),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package de.zalando.apifirst

import de.zalando.apifirst.Domain.{ Field, Opt, Str, TypeDef, TypeMeta, TypeRef }
import de.zalando.apifirst.naming.Reference
import org.scalatest._

import scala.language.implicitConversions

class TypeAnalyzerTest extends FunSpec with Matchers {

private def analyzer = new TypeAnalyzer {}

describe("isComplexType") {
it("should be false for Opt[Str]") {
val tpe = Opt(Str(None, TypeMeta(None, List())), TypeMeta(None))
analyzer.isComplexType(tpe) shouldBe false
}

it("should be false for Str") {
val tpe = Str(None, TypeMeta(None, List()))
analyzer.isComplexType(tpe) shouldBe false
}

it("should be true for TypeDef") {
val tpe = TypeDef(
Reference("⌿definitions⌿Basic"),
Seq(
Field(Reference("⌿definitions⌿Basic⌿optional"), Opt(TypeRef(Reference("⌿definitions⌿Basic⌿optional")), TypeMeta(None, List())))
),
TypeMeta(Some("Named types: 1"), List())
)
analyzer.isComplexType(tpe) shouldBe true
}

it("should be false for Opt[TypeDef]") {
val tpe = Opt(
TypeDef(
Reference("⌿definitions⌿Basic"),
Seq(
Field(Reference("⌿definitions⌿Basic⌿optional"), Opt(TypeRef(Reference("⌿definitions⌿Basic⌿optional")), TypeMeta(None, List())))
),
TypeMeta(Some("Named types: 1"), List())
),
TypeMeta(None)
)
analyzer.isComplexType(tpe) shouldBe false
}
}

describe("isRecursiveComplexType") {
it("should be true for Opt[TypeDef]") {
val tpe = Opt(
TypeDef(
Reference("⌿definitions⌿Basic"),
Seq(
Field(Reference("⌿definitions⌿Basic⌿optional"), Opt(TypeRef(Reference("⌿definitions⌿Basic⌿optional")), TypeMeta(None, List())))
),
TypeMeta(Some("Named types: 1"), List())
),
TypeMeta(None)
)
analyzer.isRecursiveComplexType(tpe) shouldBe true
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,4 @@ class TypeDeduplicatorTest extends FunSpec with MustMatchers {
if (discriminators.isEmpty)
result.params.foreach(_._2.typeName.asInstanceOf[TypeRef].name mustBe reference2)
}

}
Loading