diff --git a/pkg/generator/schema_generator.go b/pkg/generator/schema_generator.go index db5dc923..356fa522 100644 --- a/pkg/generator/schema_generator.go +++ b/pkg/generator/schema_generator.go @@ -3,6 +3,7 @@ package generator import ( "errors" "fmt" + "slices" "strings" "github.com/google/go-cmp/cmp" @@ -1085,10 +1086,23 @@ func (g *schemaGenerator) generateEnumType( var enumType codegen.Type - if len(t.Type) == 1 { + schemaType := make([]string, len(t.Type)) + copy(schemaType, t.Type) + + nullIdx := slices.Index(schemaType, schemas.TypeNameNull) + + if len(schemaType) == 2 && nullIdx != -1 { + if nullIdx == 0 { + schemaType = schemaType[1:] + } else { + schemaType = schemaType[:1] + } + } + + if len(schemaType) == 1 { var err error if enumType, err = codegen.PrimitiveTypeFromJSONSchemaType( - t.Type[0], + schemaType[0], t.Format, false, g.config.MinSizedInts, @@ -1097,11 +1111,11 @@ func (g *schemaGenerator) generateEnumType( &t.ExclusiveMinimum, &t.ExclusiveMaximum, ); err != nil { - return nil, fmt.Errorf("invalid type %q: %w", t.Type[0], err) + return nil, fmt.Errorf("invalid type %q: %w", schemaType[0], err) } // Enforce integer type for enum values. - if t.Type[0] == "integer" { + if schemaType[0] == "integer" { for i, v := range t.Enum { switch v := v.(type) { case float64: @@ -1113,9 +1127,9 @@ func (g *schemaGenerator) generateEnumType( } } - wrapInStruct = t.Type[0] == schemas.TypeNameNull // Null uses interface{}, which cannot have methods. + wrapInStruct = schemaType[0] == schemas.TypeNameNull // Null uses interface{}, which cannot have methods. } else { - if len(t.Type) > 1 { + if len(schemaType) > 1 { // TODO: Support multiple types. g.warner("Enum defined with multiple types; ignoring it and using enum values instead") } @@ -1126,21 +1140,21 @@ func (g *schemaGenerator) generateEnumType( var valueType string if v == nil { - valueType = interfaceTypeName - } else { - switch v.(type) { - case string: - valueType = "string" + continue + } - case float64: - valueType = float64Type + switch v.(type) { + case string: + valueType = "string" - case bool: - valueType = "bool" + case float64: + valueType = float64Type - default: - return nil, fmt.Errorf("%w %v", errEnumNonPrimitiveVal, v) - } + case bool: + valueType = "bool" + + default: + return nil, fmt.Errorf("%w %v", errEnumNonPrimitiveVal, v) } if primitiveType == "" {