Skip to content

Commit 90a1a47

Browse files
perheromissis
andauthored
Subschema improvements (part 1) (#480)
Co-authored-by: omissis <[email protected]>
1 parent 4b5de3c commit 90a1a47

28 files changed

+1177
-40
lines changed

pkg/generator/schema_generator.go

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,17 @@ func (g *schemaGenerator) generateRootType() error {
7272
}
7373

7474
func (g *schemaGenerator) generateReferencedType(t *schemas.Type) (codegen.Type, error) {
75-
if schemaOutput, ok := g.outputs[g.schema.ID]; ok {
76-
if decl, ok := schemaOutput.declsByName[t.Ref]; ok {
77-
if decl != nil {
78-
return decl.Type, nil
75+
defName, fileName, err := g.extractRefNames(t)
76+
if err != nil {
77+
return nil, err
78+
}
79+
80+
if fileName == "" {
81+
if schemaOutput, ok := g.outputs[g.schema.ID]; ok {
82+
if decl, ok := schemaOutput.declsByName[defName]; ok {
83+
if decl != nil {
84+
return &codegen.NamedType{Decl: decl}, nil
85+
}
7986
}
8087
}
8188
}
@@ -92,11 +99,6 @@ func (g *schemaGenerator) generateReferencedType(t *schemas.Type) (codegen.Type,
9299
return codegen.EmptyInterfaceType{}, nil
93100
}
94101

95-
defName, fileName, err := g.extractRefNames(t)
96-
if err != nil {
97-
return nil, err
98-
}
99-
100102
schema := g.schema
101103
sg := g
102104

@@ -689,11 +691,11 @@ func (g *schemaGenerator) generateStructType(t *schemas.Type, scope nameScope) (
689691
}
690692

691693
if len(t.AnyOf) > 0 {
692-
return g.generateAnyOfType(t.AnyOf, scope)
694+
return g.generateAnyOfType(t, scope)
693695
}
694696

695697
if len(t.AllOf) > 0 {
696-
return g.generateAllOfType(t.AllOf, scope)
698+
return g.generateAllOfType(t, scope)
697699
}
698700

699701
// Checking .Not here because `false` is unmarshalled to .Not = Type{}.
@@ -853,13 +855,13 @@ func (g *schemaGenerator) addStructField(
853855
return nil
854856
}
855857

856-
func (g *schemaGenerator) generateAnyOfType(anyOf []*schemas.Type, scope nameScope) (codegen.Type, error) {
857-
if len(anyOf) == 0 {
858+
func (g *schemaGenerator) generateAnyOfType(t *schemas.Type, scope nameScope) (codegen.Type, error) {
859+
if len(t.AnyOf) == 0 {
858860
return nil, errEmptyInAnyOf
859861
}
860862

861863
isCycle := false
862-
rAnyOf, hasNull := g.resolveRefs(anyOf, false)
864+
rAnyOf, hasNull := g.resolveRefs(t.AnyOf, false)
863865

864866
for i, typ := range rAnyOf {
865867
typ.SetSubSchemaTypeElem()
@@ -890,22 +892,26 @@ func (g *schemaGenerator) generateAnyOfType(anyOf []*schemas.Type, scope nameSco
890892
return codegen.EmptyInterfaceType{}, nil
891893
}
892894

893-
anyOfType, err := schemas.AnyOf(rAnyOf)
895+
anyOfType, err := schemas.AnyOf(rAnyOf, t)
894896
if err != nil {
895897
return nil, fmt.Errorf("could not merge anyOf types: %w", err)
896898
}
897899

900+
anyOfType.AnyOf = nil
901+
898902
return g.generateTypeInline(anyOfType, scope)
899903
}
900904

901-
func (g *schemaGenerator) generateAllOfType(allOf []*schemas.Type, scope nameScope) (codegen.Type, error) {
902-
rAllOf, _ := g.resolveRefs(allOf, true)
905+
func (g *schemaGenerator) generateAllOfType(t *schemas.Type, scope nameScope) (codegen.Type, error) {
906+
rAllOf, _ := g.resolveRefs(t.AllOf, true)
903907

904-
allOfType, err := schemas.AllOf(rAllOf)
908+
allOfType, err := schemas.AllOf(rAllOf, t)
905909
if err != nil {
906910
return nil, fmt.Errorf("could not merge allOf types: %w", err)
907911
}
908912

913+
allOfType.AllOf = nil
914+
909915
return g.generateTypeInline(allOfType, scope)
910916
}
911917

@@ -960,11 +966,11 @@ func (g *schemaGenerator) generateTypeInline(t *schemas.Type, scope nameScope) (
960966
}
961967

962968
if len(t.AnyOf) > 0 {
963-
return g.generateAnyOfType(t.AnyOf, scope)
969+
return g.generateAnyOfType(t, scope)
964970
}
965971

966972
if len(t.AllOf) > 0 {
967-
return g.generateAllOfType(t.AllOf, scope)
973+
return g.generateAllOfType(t, scope)
968974
}
969975

970976
if len(t.Type) == 2 && typeIsNullable {
@@ -1263,6 +1269,20 @@ func (g *schemaGenerator) resolveRef(t *schemas.Type) (*schemas.Type, error) {
12631269
return nil, fmt.Errorf("%w: %w", errCannotResolveRef, err)
12641270
}
12651271

1272+
// After resolving the ref type we lose info about the original schema
1273+
// so rewrite all nested refs to include the original schema id
1274+
_, fileName, err := g.extractRefNames(t)
1275+
if err != nil {
1276+
return nil, fmt.Errorf("%w: %w", errCannotResolveRef, err)
1277+
}
1278+
1279+
if fileName != "" {
1280+
err = ntyp.Decl.SchemaType.ConvertAllRefs(fileName)
1281+
if err != nil {
1282+
return nil, fmt.Errorf("convert refs: %w", err)
1283+
}
1284+
}
1285+
12661286
ntyp.Decl.SchemaType.Dereferenced = true
12671287

12681288
g.schemaTypesByRef[t.Ref] = ntyp.Decl.SchemaType

pkg/schemas/model.go

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"fmt"
2929
"reflect"
3030
"slices"
31+
"strings"
3132

3233
"dario.cat/mergo"
3334
)
@@ -238,6 +239,12 @@ func (value *Type) SetSubSchemaTypeElem() {
238239
value.subSchemaTypeElem = true
239240
}
240241

242+
func (value *Type) ConvertAllRefs(absolutePath string) error {
243+
val := reflect.ValueOf(value).Elem()
244+
245+
return updateAllRefsValues(&val, absolutePath)
246+
}
247+
241248
// UnmarshalJSON accepts booleans as schemas where `true` is equivalent to `{}`
242249
// and `false` is equivalent to `{"not": {}}`.
243250
func (value *Type) UnmarshalJSON(raw []byte) error {
@@ -280,8 +287,8 @@ func (value *Type) UnmarshalJSON(raw []byte) error {
280287
return nil
281288
}
282289

283-
func AllOf(types []*Type) (*Type, error) {
284-
typ, err := MergeTypes(types)
290+
func AllOf(types []*Type, baseType *Type) (*Type, error) {
291+
typ, err := MergeTypes(types, baseType)
285292
if err != nil {
286293
return nil, err
287294
}
@@ -291,8 +298,8 @@ func AllOf(types []*Type) (*Type, error) {
291298
return typ, nil
292299
}
293300

294-
func AnyOf(types []*Type) (*Type, error) {
295-
typ, err := MergeTypes(types)
301+
func AnyOf(types []*Type, baseType *Type) (*Type, error) {
302+
typ, err := MergeTypes(types, baseType)
296303
if err != nil {
297304
return nil, err
298305
}
@@ -303,14 +310,14 @@ func AnyOf(types []*Type) (*Type, error) {
303310
return typ, nil
304311
}
305312

306-
func MergeTypes(types []*Type) (*Type, error) {
313+
func MergeTypes(types []*Type, baseType *Type) (*Type, error) {
307314
if len(types) == 0 {
308315
return nil, ErrEmptyTypesList
309316
}
310317

311318
result := &Type{}
312319

313-
if isPrimitiveTypeList(types) {
320+
if isPrimitiveTypeList(types, result.Type) {
314321
return result, nil
315322
}
316323

@@ -319,6 +326,10 @@ func MergeTypes(types []*Type) (*Type, error) {
319326
mergo.WithTransformers(typeListTransformer{}),
320327
}
321328

329+
if err := mergo.Merge(result, baseType, opts...); err != nil {
330+
return nil, fmt.Errorf("%w: %w", ErrCannotMergeTypes, err)
331+
}
332+
322333
for _, t := range types {
323334
if err := mergo.Merge(result, t, opts...); err != nil {
324335
return nil, fmt.Errorf("%w: %w", ErrCannotMergeTypes, err)
@@ -328,6 +339,57 @@ func MergeTypes(types []*Type) (*Type, error) {
328339
return result, nil
329340
}
330341

342+
func updateAllRefsValues(structValue *reflect.Value, refPath string) error {
343+
switch structValue.Kind() { //nolint:exhaustive
344+
case reflect.Struct:
345+
for i := range structValue.NumField() {
346+
field := structValue.Field(i)
347+
name := structValue.Type().Field(i).Name
348+
349+
switch field.Kind() { //nolint:exhaustive
350+
case reflect.String:
351+
fieldVal := field.String()
352+
if name == "Ref" && fieldVal != "" && field.CanSet() {
353+
if strings.HasPrefix(fieldVal, "#") {
354+
field.SetString(refPath + fieldVal)
355+
}
356+
}
357+
358+
default:
359+
if err := updateAllRefsValues(&field, refPath); err != nil {
360+
return fmt.Errorf("struct error: %w", err)
361+
}
362+
}
363+
}
364+
365+
case reflect.Ptr:
366+
elem := structValue.Elem()
367+
if !structValue.IsNil() {
368+
if err := updateAllRefsValues(&elem, refPath); err != nil {
369+
return fmt.Errorf("ptr error: %w", err)
370+
}
371+
}
372+
373+
case reflect.Map:
374+
for _, key := range structValue.MapKeys() {
375+
val := structValue.MapIndex(key)
376+
if err := updateAllRefsValues(&val, refPath); err != nil {
377+
return fmt.Errorf("map error: %w", err)
378+
}
379+
}
380+
381+
case reflect.Slice, reflect.Array:
382+
for i := range structValue.Len() {
383+
field := structValue.Index(i)
384+
if err := updateAllRefsValues(&field, refPath); err != nil {
385+
return fmt.Errorf("slice error: %w", err)
386+
}
387+
}
388+
}
389+
390+
return nil
391+
}
392+
331393
type typeListTransformer struct{}
332394

333395
func (t typeListTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error {

pkg/schemas/types.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ func CleanNameForSorting(name string) string {
3131
return name
3232
}
3333

34-
func isPrimitiveTypeList(types []*Type) bool {
34+
func isPrimitiveTypeList(types []*Type, baseType TypeList) bool {
35+
if len(baseType) > 0 && !IsPrimitiveType(baseType[0]) {
36+
return false
37+
}
38+
3539
for _, typ := range types {
3640
if len(typ.Type) == 0 {
3741
continue

0 commit comments

Comments
 (0)