Skip to content
9 changes: 8 additions & 1 deletion pkg/codegen/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,17 @@ func (p *Package) AddDecl(d Decl) {
}

func (p *Package) hasDecl(d Decl) bool {
n1, ok1 := d.(Named)

for _, pd := range p.Decls {
if pd == d || reflect.DeepEqual(pd, d) {
return true
}

n2, ok2 := pd.(Named)
if ok1 && ok2 && n1.GetName() == n2.GetName() {
return true
}
}

return false
Expand Down Expand Up @@ -277,7 +284,7 @@ type AliasType struct {
}

func (p AliasType) Generate(out *Emitter) error {
out.Printf("type %s = %s", p.Alias, p.Name)
out.Printlnf("type %s = %s", p.Alias, p.Name)

return nil
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/generator/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type Config struct {
// DisableCustomTypesForMaps configures the generator to avoid creating a custom type for maps,
// and to use the map type directly.
DisableCustomTypesForMaps bool
// AliasSingleAllOfAnyOfRefs will convert types with a single nested anyOf or allOf ref type into a type alias.
AliasSingleAllOfAnyOfRefs bool
// PreferOmitzero will use omit omitzero instead of omitempty, note this requires Go 1.24
PreferOmitzero bool
}

type SchemaMapping struct {
Expand Down
1 change: 1 addition & 0 deletions pkg/generator/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

type formatter interface {
addImport(out *codegen.File, declType *codegen.TypeDecl)
getName() string

generate(output *output, declType *codegen.TypeDecl, validators []validator) func(*codegen.Emitter) error
enumMarshal(declType *codegen.TypeDecl) func(*codegen.Emitter) error
Expand Down
9 changes: 9 additions & 0 deletions pkg/generator/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ func (g *Generator) AddFile(fileName string, schema *schemas.Schema) error {
return err
}

if schema.ID != "" {
if _, processed := o.processedSchemas[schema.ID]; processed {
return nil
}

o.processedSchemas[schema.ID] = true
}

return newSchemaGenerator(g, schema, fileName, o).generateRootType()
}

Expand Down Expand Up @@ -213,6 +221,7 @@ func (g *Generator) beginOutput(
declsBySchema: map[*schemas.Type]*codegen.TypeDecl{},
declsByName: map[string]*codegen.TypeDecl{},
unmarshallersByTypeDecl: map[*codegen.TypeDecl]bool{},
processedSchemas: map[string]bool{},
}
g.outputs[id] = output

Expand Down
4 changes: 4 additions & 0 deletions pkg/generator/json_formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,7 @@ func (jf *jsonFormatter) addImport(out *codegen.File, declType *codegen.TypeDecl
}
}
}

func (yf *jsonFormatter) getName() string {
return "json"
}
1 change: 1 addition & 0 deletions pkg/generator/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type output struct {
declsByName map[string]*codegen.TypeDecl
declsBySchema map[*schemas.Type]*codegen.TypeDecl
unmarshallersByTypeDecl map[*codegen.TypeDecl]bool
processedSchemas map[string]bool
warner func(string)
}

Expand Down
148 changes: 119 additions & 29 deletions pkg/generator/schema_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,17 @@ func (g *schemaGenerator) generateRootType() error {
}

func (g *schemaGenerator) generateReferencedType(t *schemas.Type) (codegen.Type, error) {
if schemaOutput, ok := g.outputs[g.schema.ID]; ok {
if decl, ok := schemaOutput.declsByName[t.Ref]; ok {
if decl != nil {
return decl.Type, nil
defName, fileName, err := g.extractRefNames(t)
if err != nil {
return nil, err
}

if fileName == "" {
if schemaOutput, ok := g.outputs[g.schema.ID]; ok {
if decl, ok := schemaOutput.declsByName[defName]; ok {
if decl != nil {
return &codegen.NamedType{Decl: decl}, nil
}
}
}
}
Expand All @@ -92,11 +99,6 @@ func (g *schemaGenerator) generateReferencedType(t *schemas.Type) (codegen.Type,
return codegen.EmptyInterfaceType{}, nil
}

defName, fileName, err := g.extractRefNames(t)
if err != nil {
return nil, err
}

schema := g.schema
sg := g

Expand Down Expand Up @@ -374,6 +376,7 @@ func (g *schemaGenerator) generateDeclaredType(t *schemas.Type, scope nameScope)
return &codegen.NamedType{Decl: &decl}, nil
}

//nolint:gocyclo // todo: reduce cyclomatic complexity
func (g *schemaGenerator) structFieldValidators(
validators []validator,
f codegen.StructField,
Expand All @@ -391,12 +394,23 @@ func (g *schemaGenerator) structFieldValidators(
validators = g.structFieldValidators(validators, f, v.Type, v.IsNillable())

case codegen.PrimitiveType:
if v.Type == schemas.TypeNameString {
switch {
case v.Type == schemas.TypeNameString:
hasPattern := len(f.SchemaType.Pattern) != 0
if f.SchemaType.MinLength != 0 || f.SchemaType.MaxLength != 0 || hasPattern {
if f.SchemaType.MinLength != 0 || f.SchemaType.MaxLength != 0 || hasPattern || f.SchemaType.Const != nil {
// Double escape the escape characters so we don't effectively parse the escapes within the value.
escapedPattern := f.SchemaType.Pattern

var constVal *string

if f.SchemaType.Const != nil {
if s, ok := f.SchemaType.Const.(string); ok {
constVal = &s
} else {
g.warner(fmt.Sprintf("Ignoring non string const value: %v", f.SchemaType.Const))
}
}

replaceJSONCharactersBy := []string{"\\b", "\\f", "\\n", "\\r", "\\t"}

replaceJSONCharacters := []string{"\b", "\f", "\n", "\r", "\t"}
Expand All @@ -411,19 +425,22 @@ func (g *schemaGenerator) structFieldValidators(
minLength: f.SchemaType.MinLength,
maxLength: f.SchemaType.MaxLength,
pattern: escapedPattern,
constVal: constVal,
isNillable: isNillable,
})
}

if hasPattern {
g.output.file.Package.AddImport("regexp", "")
}
} else if strings.Contains(v.Type, "int") || v.Type == float64Type {

case strings.Contains(v.Type, "int") || v.Type == float64Type:
if f.SchemaType.MultipleOf != nil ||
f.SchemaType.Maximum != nil ||
f.SchemaType.ExclusiveMaximum != nil ||
f.SchemaType.Minimum != nil ||
f.SchemaType.ExclusiveMinimum != nil {
f.SchemaType.ExclusiveMinimum != nil ||
f.SchemaType.Const != nil {
validators = append(validators, &numericValidator{
jsonName: f.JSONName,
fieldName: f.Name,
Expand All @@ -433,13 +450,34 @@ func (g *schemaGenerator) structFieldValidators(
exclusiveMaximum: f.SchemaType.ExclusiveMaximum,
minimum: f.SchemaType.Minimum,
exclusiveMinimum: f.SchemaType.ExclusiveMinimum,
constVal: f.SchemaType.Const,
roundToInt: strings.Contains(v.Type, "int"),
})
}

if f.SchemaType.MultipleOf != nil && v.Type == float64Type {
g.output.file.Package.AddImport("math", "")
}

case v.Type == "bool":
if f.SchemaType.Const != nil {
var constVal *bool

if f.SchemaType.Const != nil {
if b, ok := f.SchemaType.Const.(bool); ok {
constVal = &b
} else {
g.warner(fmt.Sprintf("Ignoring non boolean const value: %v", f.SchemaType.Const))
}
}

validators = append(validators, &booleanValidator{
jsonName: f.JSONName,
fieldName: f.Name,
isNillable: isNillable,
constVal: constVal,
})
}
}

case *codegen.ArrayType:
Expand Down Expand Up @@ -505,7 +543,7 @@ func (g *schemaGenerator) generateUnmarshaler(decl *codegen.TypeDecl, validators

g.output.file.Package.AddDecl(&codegen.Method{
Impl: formatter.generate(g.output, decl, validators),
Name: decl.GetName() + "_validator",
Name: decl.GetName() + "_validator_" + formatter.getName(),
})
}
}
Expand Down Expand Up @@ -689,11 +727,11 @@ func (g *schemaGenerator) generateStructType(t *schemas.Type, scope nameScope) (
}

if len(t.AnyOf) > 0 {
return g.generateAnyOfType(t.AnyOf, scope)
return g.generateAnyOfType(t, scope)
}

if len(t.AllOf) > 0 {
return g.generateAllOfType(t.AllOf, scope)
return g.generateAllOfType(t, scope)
}

// Checking .Not here because `false` is unmarshalled to .Not = Type{}.
Expand Down Expand Up @@ -812,11 +850,16 @@ func (g *schemaGenerator) addStructField(

tags := ""

if isRequired || g.DisableOmitempty() {
switch {
case isRequired || g.DisableOmitempty():
for _, tag := range g.config.Tags {
tags += fmt.Sprintf(`%s:"%s" `, tag, name)
}
} else {
case g.config.PreferOmitzero:
for _, tag := range g.config.Tags {
tags += fmt.Sprintf(`%s:"%s,omitzero" `, tag, name)
}
default:
for _, tag := range g.config.Tags {
tags += fmt.Sprintf(`%s:"%s,omitempty" `, tag, name)
}
Expand Down Expand Up @@ -853,15 +896,32 @@ func (g *schemaGenerator) addStructField(
return nil
}

func (g *schemaGenerator) generateAnyOfType(anyOf []*schemas.Type, scope nameScope) (codegen.Type, error) {
if len(anyOf) == 0 {
func (g *schemaGenerator) generateAnyOfType(t *schemas.Type, scope nameScope) (codegen.Type, error) {
if len(t.AnyOf) == 0 {
return nil, errEmptyInAnyOf
}

if g.config.AliasSingleAllOfAnyOfRefs && len(t.AnyOf) == 1 && t.IsEmptyObject() {
childType := t.AnyOf[0]
if childType.Ref != "" {
resolvedType, err := g.resolveRef(childType)
if err == nil {
return g.generateTypeInline(resolvedType, scope)
} else {
g.warner(fmt.Sprintf("Could not resolve ref %q: %v", childType.Ref, err))
}
}
}

isCycle := false
rAnyOf, hasNull := g.resolveRefs(anyOf, false)
rAnyOf, hasNull := g.resolveRefs(t.AnyOf, false)

for i, typ := range rAnyOf {
// infer type from base if not set
if len(typ.Type) == 0 {
typ.Type = append(schemas.TypeList{}, t.Type...)
}

typ.SetSubSchemaTypeElem()

ic, cleanupCycle, cycleErr := g.detectCycle(typ)
Expand Down Expand Up @@ -890,22 +950,38 @@ func (g *schemaGenerator) generateAnyOfType(anyOf []*schemas.Type, scope nameSco
return codegen.EmptyInterfaceType{}, nil
}

anyOfType, err := schemas.AnyOf(rAnyOf)
anyOfType, err := schemas.AnyOf(rAnyOf, t)
if err != nil {
return nil, fmt.Errorf("could not merge anyOf types: %w", err)
}

anyOfType.AnyOf = nil

return g.generateTypeInline(anyOfType, scope)
}

func (g *schemaGenerator) generateAllOfType(allOf []*schemas.Type, scope nameScope) (codegen.Type, error) {
rAllOf, _ := g.resolveRefs(allOf, true)
func (g *schemaGenerator) generateAllOfType(t *schemas.Type, scope nameScope) (codegen.Type, error) {
if g.config.AliasSingleAllOfAnyOfRefs && len(t.AllOf) == 1 && t.IsEmptyObject() {
subType := t.AllOf[0]
if subType.Ref != "" {
resolvedType, err := g.resolveRef(subType)
if err == nil {
return g.generateTypeInline(resolvedType, scope)
} else {
g.warner(fmt.Sprintf("Could not resolve subtype ref %q: %v", subType.Ref, err))
}
}
}

allOfType, err := schemas.AllOf(rAllOf)
rAllOf, _ := g.resolveRefs(t.AllOf, true)

allOfType, err := schemas.AllOf(rAllOf, t)
if err != nil {
return nil, fmt.Errorf("could not merge allOf types: %w", err)
}

allOfType.AllOf = nil

return g.generateTypeInline(allOfType, scope)
}

Expand Down Expand Up @@ -960,11 +1036,11 @@ func (g *schemaGenerator) generateTypeInline(t *schemas.Type, scope nameScope) (
}

if len(t.AnyOf) > 0 {
return g.generateAnyOfType(t.AnyOf, scope)
return g.generateAnyOfType(t, scope)
}

if len(t.AllOf) > 0 {
return g.generateAllOfType(t.AllOf, scope)
return g.generateAllOfType(t, scope)
}

if len(t.Type) == 2 && typeIsNullable {
Expand Down Expand Up @@ -1192,13 +1268,13 @@ func (g *schemaGenerator) generateEnumType(
if wrapInStruct {
g.output.file.Package.AddDecl(&codegen.Method{
Impl: formatter.enumMarshal(&enumDecl),
Name: enumDecl.GetName() + "_enum",
Name: enumDecl.GetName() + "_enum_" + formatter.getName(),
})
}

g.output.file.Package.AddDecl(&codegen.Method{
Impl: formatter.enumUnmarshal(enumDecl, enumType, valueConstant, wrapInStruct),
Name: enumDecl.GetName() + "_enum_unmarshal",
Name: enumDecl.GetName() + "_enum_unmarshal_" + formatter.getName(),
})
}
}
Expand Down Expand Up @@ -1263,6 +1339,20 @@ func (g *schemaGenerator) resolveRef(t *schemas.Type) (*schemas.Type, error) {
return nil, fmt.Errorf("%w: %w", errCannotResolveRef, err)
}

// After resolving the ref type we lose info about the original schema
// so rewrite all nested refs to include the original schema id
_, fileName, err := g.extractRefNames(t)
if err != nil {
return nil, fmt.Errorf("%w: %w", errCannotResolveRef, err)
}

if fileName != "" {
err = ntyp.Decl.SchemaType.ConvertAllRefs(fileName)
if err != nil {
return nil, fmt.Errorf("convert refs: %w", err)
}
}

ntyp.Decl.SchemaType.Dereferenced = true

g.schemaTypesByRef[t.Ref] = ntyp.Decl.SchemaType
Expand Down
Loading