Skip to content

Commit 91ffa0d

Browse files
Add Validation Support (#26)
* instructor validator * update simple example * validate in stream * add with validator option * add example for validator * update example * rename withValidator varible to validate * rename with validator to withValidation * remove required param for WithValidation --------- Co-authored-by: Robby <[email protected]>
1 parent 038c15c commit 91ffa0d

File tree

10 files changed

+183
-12
lines changed

10 files changed

+183
-12
lines changed

examples/validator/main.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
8+
"github.com/instructor-ai/instructor-go/pkg/instructor"
9+
openai "github.com/sashabaranov/go-openai"
10+
)
11+
12+
type User struct {
13+
FirstName string `json:"first_name" jsonschema:"title=First Name,description=The first name of the user" validate:"required"`
14+
LastName string `json:"last_name" jsonschema:"title=Last Name,description=The last name of the user" validate:"required"`
15+
Age uint8 `json:"age" jsonschema:"title=Age,description=The age of the user" validate:"gte=0,lte=130"`
16+
Email string `json:"email" jsonschema:"title=Email,description=The email address of the user" validate:"required,email"`
17+
Gender string `json:"gender" jsonschema:"title=Gender,description=The gender of the user" validate:"oneof=male female prefer_not_to"`
18+
FavouriteColor string `json:"favourite_color" jsonschema:"title=Favourite Color,description=The favourite color of the user" validate:"iscolor"`
19+
Addresses []*Address `json:"addresses" jsonschema:"title=Addresses,description=The addresses of the user" validate:"required,dive,required"`
20+
}
21+
22+
type Address struct {
23+
Street string `json:"street" jsonschema:"title=Street,description=The street address" validate:"required"`
24+
City string `json:"city" jsonschema:"title=City,description=The city" validate:"required"`
25+
Planet string `json:"planet" jsonschema:"title=Planet,description=The planet" validate:"required"`
26+
Phone string `json:"phone" jsonschema:"title=Phone,description=The phone number" validate:"required"`
27+
}
28+
29+
func (u User) String() string {
30+
result := fmt.Sprintf("First Name: %s\nLast Name: %s\nAge: %d\nEmail: %s\nGender: %s\nFavourite Color: %s\nAddresses:\n",
31+
u.FirstName, u.LastName, u.Age, u.Email, u.Gender, u.FavouriteColor)
32+
for _, address := range u.Addresses {
33+
result += fmt.Sprintf(" %s\n", address)
34+
}
35+
return result
36+
}
37+
38+
func (a Address) String() string {
39+
return fmt.Sprintf("Street: %s, City: %s, Planet: %s, Phone: %s", a.Street, a.City, a.Planet, a.Phone)
40+
}
41+
42+
func main() {
43+
ctx := context.Background()
44+
45+
client := instructor.FromOpenAI(
46+
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
47+
instructor.WithMode(instructor.ModeJSON),
48+
instructor.WithMaxRetries(3),
49+
instructor.WithValidation(),
50+
)
51+
52+
var user User
53+
_, err := client.CreateChatCompletion(
54+
ctx,
55+
openai.ChatCompletionRequest{
56+
Model: openai.GPT4o,
57+
Messages: []openai.ChatCompletionMessage{
58+
{
59+
Role: openai.ChatMessageRoleUser,
60+
Content: "Meet Jane Doe: a 30-year-old adventurer who can be reached at [email protected]. " +
61+
"Jane loves the vibrant hue of #FF5733. She resides in Metropolis at 456 Oak St, on the wonderful planet Earth. " +
62+
"To chat with her, dial (555) 555-1234. Jane also spends her weekends at her cottage located at 789 Pine St, " +
63+
"in Smallville, on the same planet. You can contact her there at (555) 555-5678.",
64+
},
65+
},
66+
},
67+
&user,
68+
)
69+
if err != nil {
70+
panic(err)
71+
}
72+
73+
fmt.Println(user)
74+
/*
75+
First Name: Jane
76+
Last Name: Doe
77+
Age: 30
78+
79+
Gender: female
80+
Favourite Color: #FF5733
81+
Addresses:
82+
Street: 456 Oak St, City: Metropolis, Planet: Earth, Phone: (555) 555-1234
83+
Street: 789 Pine St, City: Smallville, Planet: Earth, Phone: (555) 555-5678
84+
*/
85+
}

go.mod

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ go 1.21.8
44

55
require (
66
github.com/cohere-ai/cohere-go/v2 v2.8.1
7+
github.com/go-playground/validator/v10 v10.21.0
78
github.com/invopop/jsonschema v0.12.0
89
github.com/liushuangls/go-anthropic/v2 v2.1.0
910
github.com/sashabaranov/go-openai v1.24.1
@@ -12,8 +13,16 @@ require (
1213
require (
1314
github.com/bahlo/generic-list-go v0.2.0 // indirect
1415
github.com/buger/jsonparser v1.1.1 // indirect
16+
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
17+
github.com/go-playground/locales v0.14.1 // indirect
18+
github.com/go-playground/universal-translator v0.18.1 // indirect
1519
github.com/google/uuid v1.6.0 // indirect
20+
github.com/leodido/go-urn v1.4.0 // indirect
1621
github.com/mailru/easyjson v0.7.7 // indirect
1722
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
23+
golang.org/x/crypto v0.19.0 // indirect
24+
golang.org/x/net v0.21.0 // indirect
25+
golang.org/x/sys v0.17.0 // indirect
26+
golang.org/x/text v0.14.0 // indirect
1827
gopkg.in/yaml.v3 v3.0.1 // indirect
1928
)

go.sum

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,23 @@ github.com/cohere-ai/cohere-go/v2 v2.8.1 h1:7+MCdXtz8onJLRmJik/cD5XGfgDNLhte4aW4
66
github.com/cohere-ai/cohere-go/v2 v2.8.1/go.mod h1:dlDCT66i8BqZDuuskFvYzsrc+O0M4l5J9Ibckoflvt4=
77
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
88
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
9-
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
10-
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
9+
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
10+
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
11+
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
12+
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
13+
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
14+
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
15+
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
16+
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
17+
github.com/go-playground/validator/v10 v10.21.0 h1:4fZA11ovvtkdgaeev9RGWPgc1uj3H8W+rNYyH/ySBb0=
18+
github.com/go-playground/validator/v10 v10.21.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
1119
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
1220
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
1321
github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI=
1422
github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
1523
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
24+
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
25+
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
1626
github.com/liushuangls/go-anthropic/v2 v2.1.0 h1:5ntOeehozlMin0+hgnhxbTru+tmBH84ADaSPelG5fPg=
1727
github.com/liushuangls/go-anthropic/v2 v2.1.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
1828
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
@@ -21,10 +31,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
2131
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
2232
github.com/sashabaranov/go-openai v1.24.1 h1:DWK95XViNb+agQtuzsn+FyHhn3HQJ7Va8z04DQDJ1MI=
2333
github.com/sashabaranov/go-openai v1.24.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
24-
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
25-
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
34+
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
35+
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
2636
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
2737
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
38+
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
39+
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
40+
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
41+
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
42+
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
43+
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
44+
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
45+
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
2846
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
2947
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
3048
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

pkg/instructor/anthropic.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ type InstructorAnthropic struct {
1010
provider Provider
1111
mode Mode
1212
maxRetries int
13+
validate bool
1314
}
1415

1516
var _ Instructor = &InstructorAnthropic{}
@@ -24,6 +25,7 @@ func FromAnthropic(client *anthropic.Client, opts ...Options) *InstructorAnthrop
2425
provider: ProviderOpenAI,
2526
mode: *options.Mode,
2627
maxRetries: *options.MaxRetries,
28+
validate: *options.validate,
2729
}
2830
return i
2931
}
@@ -39,3 +41,6 @@ func (i *InstructorAnthropic) Mode() string {
3941
func (i *InstructorAnthropic) Provider() string {
4042
return i.provider
4143
}
44+
func (i *InstructorAnthropic) Validate() bool {
45+
return i.validate
46+
}

pkg/instructor/chat.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"encoding/json"
66
"errors"
77
"reflect"
8+
9+
"github.com/go-playground/validator/v10"
810
)
911

1012
func chatHandler(i Instructor, ctx context.Context, request interface{}, response any) (interface{}, error) {
@@ -38,6 +40,18 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
3840
continue
3941
}
4042

43+
if i.Validate() {
44+
validate = validator.New()
45+
// Validate the response structure against the defined model using the validator
46+
err = validate.Struct(response)
47+
48+
if err != nil {
49+
// TODO:
50+
// add more sophisticated retry logic (send back validator error and parse error for model to fix).
51+
continue
52+
}
53+
}
54+
4155
return resp, nil
4256
}
4357

pkg/instructor/chat_stream.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"encoding/json"
66
"reflect"
77
"strings"
8+
9+
"github.com/go-playground/validator/v10"
810
)
911

1012
type StreamWrapper[T any] struct {
@@ -36,12 +38,17 @@ func chatStreamHandler(i Instructor, ctx context.Context, request interface{}, r
3638
return nil, err
3739
}
3840

39-
parsedChan := parseStream(ctx, ch, responseType)
41+
shouldValidate := i.Validate()
42+
if shouldValidate {
43+
validate = validator.New()
44+
}
45+
46+
parsedChan := parseStream(ctx, ch, shouldValidate, responseType)
4047

4148
return parsedChan, nil
4249
}
4350

44-
func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Type) <-chan interface{} {
51+
func parseStream(ctx context.Context, ch <-chan string, shouldValidate bool, responseType reflect.Type) <-chan interface{} {
4552

4653
parsedChan := make(chan any)
4754

@@ -58,7 +65,7 @@ func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Typ
5865
case text, ok := <-ch:
5966
if !ok {
6067
// Stream closed
61-
processRemainingBuffer(buffer, parsedChan, responseType)
68+
processRemainingBuffer(buffer, parsedChan, shouldValidate, responseType)
6269
return
6370
}
6471

@@ -69,7 +76,7 @@ func parseStream(ctx context.Context, ch <-chan string, responseType reflect.Typ
6976
inArray = startArray(buffer)
7077
}
7178

72-
processBuffer(buffer, parsedChan, responseType)
79+
processBuffer(buffer, parsedChan, shouldValidate, responseType)
7380
}
7481
}
7582
}()
@@ -93,7 +100,7 @@ func startArray(buffer *strings.Builder) bool {
93100
return true
94101
}
95102

96-
func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) {
103+
func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, shouldValidate bool, responseType reflect.Type) {
97104

98105
data := buffer.String()
99106

@@ -107,14 +114,23 @@ func processBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, respo
107114
if err != nil {
108115
break
109116
}
117+
118+
if shouldValidate {
119+
// Validate the instance
120+
err = validate.Struct(instance)
121+
if err != nil {
122+
break
123+
}
124+
}
125+
110126
parsedChan <- instance
111127

112128
buffer.Reset()
113129
buffer.WriteString(remaining)
114130
}
115131
}
116132

117-
func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, responseType reflect.Type) {
133+
func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface{}, shouldValidate bool, responseType reflect.Type) {
118134

119135
data := buffer.String()
120136

@@ -124,5 +140,6 @@ func processRemainingBuffer(buffer *strings.Builder, parsedChan chan<- interface
124140
data = data[:idx]
125141
}
126142

127-
processBuffer(buffer, parsedChan, responseType)
143+
processBuffer(buffer, parsedChan, shouldValidate, responseType)
144+
128145
}

pkg/instructor/cohere_struct.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ type InstructorCohere struct {
1010
provider Provider
1111
mode Mode
1212
maxRetries int
13+
validate bool
1314
}
1415

1516
var _ Instructor = &InstructorCohere{}
@@ -39,3 +40,6 @@ func (i *InstructorCohere) Mode() string {
3940
func (i *InstructorCohere) MaxRetries() int {
4041
return i.maxRetries
4142
}
43+
func (i *InstructorCohere) Validate() bool {
44+
return i.validate
45+
}

pkg/instructor/instructor.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@ package instructor
22

33
import (
44
"context"
5+
6+
"github.com/go-playground/validator/v10"
57
)
68

9+
var validate *validator.Validate
10+
711
type Instructor interface {
812
Provider() Provider
913
Mode() Mode
1014
MaxRetries() int
15+
Validate() bool
1116

1217
// Chat / Messages
1318

pkg/instructor/openai_struct.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ type InstructorOpenAI struct {
1010
provider Provider
1111
mode Mode
1212
maxRetries int
13+
validate bool
1314
}
1415

1516
var _ Instructor = &InstructorOpenAI{}
@@ -24,6 +25,7 @@ func FromOpenAI(client *openai.Client, opts ...Options) *InstructorOpenAI {
2425
provider: ProviderOpenAI,
2526
mode: *options.Mode,
2627
maxRetries: *options.MaxRetries,
28+
validate: *options.validate,
2729
}
2830
return i
2931
}
@@ -37,3 +39,6 @@ func (i *InstructorOpenAI) Mode() Mode {
3739
func (i *InstructorOpenAI) MaxRetries() int {
3840
return i.maxRetries
3941
}
42+
func (i *InstructorOpenAI) Validate() bool {
43+
return i.validate
44+
}

pkg/instructor/options.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@ package instructor
22

33
const (
44
DefaultMaxRetries = 3
5+
DefaultValidator = false
56
)
67

78
type Options struct {
89
Mode *Mode
910
MaxRetries *int
10-
11+
validate *bool
1112
// Provider specific options:
1213
}
1314

1415
var defaultOptions = Options{
1516
Mode: toPtr(ModeDefault),
1617
MaxRetries: toPtr(DefaultMaxRetries),
18+
validate: toPtr(DefaultValidator),
1719
}
1820

1921
func WithMode(mode Mode) Options {
@@ -24,13 +26,20 @@ func WithMaxRetries(maxRetries int) Options {
2426
return Options{MaxRetries: toPtr(maxRetries)}
2527
}
2628

29+
func WithValidation() Options {
30+
return Options{validate: toPtr(true)}
31+
}
32+
2733
func mergeOption(old, new Options) Options {
2834
if new.Mode != nil {
2935
old.Mode = new.Mode
3036
}
3137
if new.MaxRetries != nil {
3238
old.MaxRetries = new.MaxRetries
3339
}
40+
if new.validate != nil {
41+
old.validate = new.validate
42+
}
3443

3544
return old
3645
}

0 commit comments

Comments
 (0)