Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions pkg/generator/schema_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,17 @@ func (g *schemaGenerator) generateRootType() error {
}

func (g *schemaGenerator) generateReferencedType(t *schemas.Type) (codegen.Type, error) {
if schemaOutput, ok := g.outputs[g.schema.ID]; ok {
if decl, ok := schemaOutput.declsByName[t.Ref]; ok {
if decl != nil {
return decl.Type, nil
defName, fileName, err := g.extractRefNames(t)
if err != nil {
return nil, err
}

if fileName == "" {
if schemaOutput, ok := g.outputs[g.schema.ID]; ok {
if decl, ok := schemaOutput.declsByName[defName]; ok {
if decl != nil {
return &codegen.NamedType{Decl: decl}, nil
}
}
}
}
Expand All @@ -92,11 +99,6 @@ func (g *schemaGenerator) generateReferencedType(t *schemas.Type) (codegen.Type,
return codegen.EmptyInterfaceType{}, nil
}

defName, fileName, err := g.extractRefNames(t)
if err != nil {
return nil, err
}

schema := g.schema
sg := g

Expand Down Expand Up @@ -689,11 +691,11 @@ func (g *schemaGenerator) generateStructType(t *schemas.Type, scope nameScope) (
}

if len(t.AnyOf) > 0 {
return g.generateAnyOfType(t.AnyOf, scope)
return g.generateAnyOfType(t, scope)
}

if len(t.AllOf) > 0 {
return g.generateAllOfType(t.AllOf, scope)
return g.generateAllOfType(t, scope)
}

// Checking .Not here because `false` is unmarshalled to .Not = Type{}.
Expand Down Expand Up @@ -853,13 +855,13 @@ func (g *schemaGenerator) addStructField(
return nil
}

func (g *schemaGenerator) generateAnyOfType(anyOf []*schemas.Type, scope nameScope) (codegen.Type, error) {
if len(anyOf) == 0 {
func (g *schemaGenerator) generateAnyOfType(t *schemas.Type, scope nameScope) (codegen.Type, error) {
if len(t.AnyOf) == 0 {
return nil, errEmptyInAnyOf
}

isCycle := false
rAnyOf, hasNull := g.resolveRefs(anyOf, false)
rAnyOf, hasNull := g.resolveRefs(t.AnyOf, false)

for i, typ := range rAnyOf {
typ.SetSubSchemaTypeElem()
Expand Down Expand Up @@ -890,22 +892,26 @@ func (g *schemaGenerator) generateAnyOfType(anyOf []*schemas.Type, scope nameSco
return codegen.EmptyInterfaceType{}, nil
}

anyOfType, err := schemas.AnyOf(rAnyOf)
anyOfType, err := schemas.AnyOf(rAnyOf, t)
if err != nil {
return nil, fmt.Errorf("could not merge anyOf types: %w", err)
}

anyOfType.AnyOf = nil

return g.generateTypeInline(anyOfType, scope)
}

func (g *schemaGenerator) generateAllOfType(allOf []*schemas.Type, scope nameScope) (codegen.Type, error) {
rAllOf, _ := g.resolveRefs(allOf, true)
func (g *schemaGenerator) generateAllOfType(t *schemas.Type, scope nameScope) (codegen.Type, error) {
rAllOf, _ := g.resolveRefs(t.AllOf, true)

allOfType, err := schemas.AllOf(rAllOf)
allOfType, err := schemas.AllOf(rAllOf, t)
if err != nil {
return nil, fmt.Errorf("could not merge allOf types: %w", err)
}

allOfType.AllOf = nil

return g.generateTypeInline(allOfType, scope)
}

Expand Down Expand Up @@ -960,11 +966,11 @@ func (g *schemaGenerator) generateTypeInline(t *schemas.Type, scope nameScope) (
}

if len(t.AnyOf) > 0 {
return g.generateAnyOfType(t.AnyOf, scope)
return g.generateAnyOfType(t, scope)
}

if len(t.AllOf) > 0 {
return g.generateAllOfType(t.AllOf, scope)
return g.generateAllOfType(t, scope)
}

if len(t.Type) == 2 && typeIsNullable {
Expand Down Expand Up @@ -1263,6 +1269,20 @@ func (g *schemaGenerator) resolveRef(t *schemas.Type) (*schemas.Type, error) {
return nil, fmt.Errorf("%w: %w", errCannotResolveRef, err)
}

// After resolving the ref type we lose info about the original schema
// so rewrite all nested refs to include the original schema id
_, fileName, err := g.extractRefNames(t)
if err != nil {
return nil, fmt.Errorf("%w: %w", errCannotResolveRef, err)
}

if fileName != "" {
err = ntyp.Decl.SchemaType.ConvertAllRefs(fileName)
if err != nil {
return nil, fmt.Errorf("convert refs: %w", err)
}
}

ntyp.Decl.SchemaType.Dereferenced = true

g.schemaTypesByRef[t.Ref] = ntyp.Decl.SchemaType
Expand Down
74 changes: 68 additions & 6 deletions pkg/schemas/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"fmt"
"reflect"
"slices"
"strings"

"dario.cat/mergo"
)
Expand Down Expand Up @@ -238,6 +239,12 @@ func (value *Type) SetSubSchemaTypeElem() {
value.subSchemaTypeElem = true
}

func (value *Type) ConvertAllRefs(absolutePath string) error {
val := reflect.ValueOf(value).Elem()

return updateAllRefsValues(&val, absolutePath)
}

// UnmarshalJSON accepts booleans as schemas where `true` is equivalent to `{}`
// and `false` is equivalent to `{"not": {}}`.
func (value *Type) UnmarshalJSON(raw []byte) error {
Expand Down Expand Up @@ -280,8 +287,8 @@ func (value *Type) UnmarshalJSON(raw []byte) error {
return nil
}

func AllOf(types []*Type) (*Type, error) {
typ, err := MergeTypes(types)
func AllOf(types []*Type, baseType *Type) (*Type, error) {
typ, err := MergeTypes(types, baseType)
if err != nil {
return nil, err
}
Expand All @@ -291,8 +298,8 @@ func AllOf(types []*Type) (*Type, error) {
return typ, nil
}

func AnyOf(types []*Type) (*Type, error) {
typ, err := MergeTypes(types)
func AnyOf(types []*Type, baseType *Type) (*Type, error) {
typ, err := MergeTypes(types, baseType)
if err != nil {
return nil, err
}
Expand All @@ -303,14 +310,14 @@ func AnyOf(types []*Type) (*Type, error) {
return typ, nil
}

func MergeTypes(types []*Type) (*Type, error) {
func MergeTypes(types []*Type, baseType *Type) (*Type, error) {
if len(types) == 0 {
return nil, ErrEmptyTypesList
}

result := &Type{}

if isPrimitiveTypeList(types) {
if isPrimitiveTypeList(types, result.Type) {
return result, nil
}

Expand All @@ -319,6 +326,10 @@ func MergeTypes(types []*Type) (*Type, error) {
mergo.WithTransformers(typeListTransformer{}),
}

if err := mergo.Merge(result, baseType, opts...); err != nil {
return nil, fmt.Errorf("%w: %w", ErrCannotMergeTypes, err)
}

for _, t := range types {
if err := mergo.Merge(result, t, opts...); err != nil {
return nil, fmt.Errorf("%w: %w", ErrCannotMergeTypes, err)
Expand All @@ -328,6 +339,57 @@ func MergeTypes(types []*Type) (*Type, error) {
return result, nil
}

func updateAllRefsValues(structValue *reflect.Value, refPath string) error {
switch structValue.Kind() { //nolint:exhaustive
case reflect.Struct:
for i := range structValue.NumField() {
field := structValue.Field(i)
name := structValue.Type().Field(i).Name

switch field.Kind() { //nolint:exhaustive
case reflect.String:
fieldVal := field.String()
if name == "Ref" && fieldVal != "" && field.CanSet() {
if strings.HasPrefix(fieldVal, "#") {
field.SetString(refPath + fieldVal)
}
}

default:
if err := updateAllRefsValues(&field, refPath); err != nil {
return fmt.Errorf("struct error: %w", err)
}
}
}

case reflect.Ptr:
elem := structValue.Elem()
if !structValue.IsNil() {
if err := updateAllRefsValues(&elem, refPath); err != nil {
return fmt.Errorf("ptr error: %w", err)
}
}

case reflect.Map:
for _, key := range structValue.MapKeys() {
val := structValue.MapIndex(key)
if err := updateAllRefsValues(&val, refPath); err != nil {
return fmt.Errorf("map error: %w", err)
}
}

case reflect.Slice, reflect.Array:
for i := range structValue.Len() {
field := structValue.Index(i)
if err := updateAllRefsValues(&field, refPath); err != nil {
return fmt.Errorf("slice error: %w", err)
}
}
}

return nil
}

type typeListTransformer struct{}

func (t typeListTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error {
Expand Down
6 changes: 5 additions & 1 deletion pkg/schemas/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func CleanNameForSorting(name string) string {
return name
}

func isPrimitiveTypeList(types []*Type) bool {
func isPrimitiveTypeList(types []*Type, baseType TypeList) bool {
if len(baseType) > 0 && !IsPrimitiveType(baseType[0]) {
return false
}

for _, typ := range types {
if len(typ.Type) == 0 {
continue
Expand Down
Loading