Skip to content

Commit c9be9ff

Browse files
authored
count all usage from all retries and failures (#29)
* Count all usage from failures to unmarshal and validate json * usage counting: move provider-specific logic into provider chat files
1 parent ea5dfe9 commit c9be9ff

File tree

5 files changed

+207
-11
lines changed

5 files changed

+207
-11
lines changed

pkg/instructor/anthropic_chat.go

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ func (i *InstructorAnthropic) CreateMessages(ctx context.Context, request anthro
1313

1414
resp, err := chatHandler(i, ctx, request, responseType)
1515
if err != nil {
16-
return anthropic.MessagesResponse{}, err
16+
if resp == nil {
17+
return anthropic.MessagesResponse{}, err
18+
}
19+
return *nilAnthropicRespWithUsage(resp.(*anthropic.MessagesResponse)), err
1720
}
1821

1922
response = *(resp.(*anthropic.MessagesResponse))
@@ -68,13 +71,13 @@ func (i *InstructorAnthropic) completionToolCall(ctx context.Context, request *a
6871

6972
toolInput, err := json.Marshal(c.Input)
7073
if err != nil {
71-
return "", nil, err
74+
return "", nilAnthropicRespWithUsage(&resp), err
7275
}
7376
// TODO: handle more than 1 tool use
7477
return string(toolInput), &resp, nil
7578
}
7679

77-
return "", nil, errors.New("more than 1 tool response at a time is not implemented")
80+
return "", nilAnthropicRespWithUsage(&resp), errors.New("more than 1 tool response at a time is not implemented")
7881

7982
}
8083

@@ -103,3 +106,57 @@ Make sure to return an instance of the JSON, not the schema itself.
103106

104107
return *text, &resp, nil
105108
}
109+
110+
func (i *InstructorAnthropic) emptyResponseWithUsageSum(usage *UsageSum) interface{} {
111+
return &anthropic.MessagesResponse{
112+
Usage: anthropic.MessagesUsage{
113+
InputTokens: usage.InputTokens,
114+
OutputTokens: usage.OutputTokens,
115+
},
116+
}
117+
}
118+
119+
func (i *InstructorAnthropic) emptyResponseWithResponseUsage(response interface{}) interface{} {
120+
resp, ok := response.(*anthropic.MessagesResponse)
121+
if !ok || resp == nil {
122+
return nil
123+
}
124+
125+
return &anthropic.MessagesResponse{
126+
Usage: resp.Usage,
127+
}
128+
}
129+
130+
func (i *InstructorAnthropic) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) {
131+
resp, ok := response.(*anthropic.MessagesResponse)
132+
if !ok {
133+
return response, fmt.Errorf("internal type error: expected *anthropic.MessagesResponse, got %T", response)
134+
}
135+
136+
resp.Usage.InputTokens += usage.InputTokens
137+
resp.Usage.OutputTokens += usage.OutputTokens
138+
139+
return response, nil
140+
}
141+
142+
func (i *InstructorAnthropic) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum {
143+
resp, ok := response.(*anthropic.MessagesResponse)
144+
if !ok {
145+
return usage
146+
}
147+
148+
usage.InputTokens += resp.Usage.InputTokens
149+
usage.OutputTokens += resp.Usage.OutputTokens
150+
151+
return usage
152+
}
153+
154+
func nilAnthropicRespWithUsage(resp *anthropic.MessagesResponse) *anthropic.MessagesResponse {
155+
if resp == nil {
156+
return nil
157+
}
158+
159+
return &anthropic.MessagesResponse{
160+
Usage: resp.Usage,
161+
}
162+
}

pkg/instructor/chat.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ import (
99
"github.com/go-playground/validator/v10"
1010
)
1111

12+
// UsageSum is a running total of token usage, used to accumulate the usage
// reported by the individual attempts of a chat call.
type UsageSum struct {
	// Token counts summed across attempts.
	InputTokens, OutputTokens, TotalTokens int
}
17+
1218
func chatHandler(i Instructor, ctx context.Context, request interface{}, response any) (interface{}, error) {
1319

1420
var err error
@@ -20,12 +26,15 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
2026
return nil, err
2127
}
2228

29+
// keep a running total of usage
30+
usage := &UsageSum{}
31+
2332
for attempt := 0; attempt < i.MaxRetries(); attempt++ {
2433

2534
text, resp, err := i.chat(ctx, request, schema)
2635
if err != nil {
2736
// no retry on non-marshalling/validation errors
28-
return nil, err
37+
return i.emptyResponseWithResponseUsage(resp), err
2938
}
3039

3140
text = extractJSON(&text)
@@ -37,6 +46,8 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
3746
//
3847
// Currently, its just recalling with no new information
3948
// or attempt to fix the error with the last generated JSON
49+
50+
i.countUsageFromResponse(resp, usage)
4051
continue
4152
}
4253

@@ -48,12 +59,14 @@ func chatHandler(i Instructor, ctx context.Context, request interface{}, respons
4859
if err != nil {
4960
// TODO:
5061
// add more sophisticated retry logic (send back validator error and parse error for model to fix).
62+
63+
i.countUsageFromResponse(resp, usage)
5164
continue
5265
}
5366
}
5467

55-
return resp, nil
68+
return i.addUsageSumToResponse(resp, usage)
5669
}
5770

58-
return nil, errors.New("hit max retry attempts")
71+
return i.emptyResponseWithUsageSum(usage), errors.New("hit max retry attempts")
5972
}

pkg/instructor/cohere_chat.go

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ func (i *InstructorCohere) Chat(
1717

1818
resp, err := chatHandler(i, ctx, request, response)
1919
if err != nil {
20-
return nil, err
20+
if resp == nil {
21+
return &cohere.NonStreamedChatResponse{}, err
22+
}
23+
return nilCohereRespWithUsage(resp.(*cohere.NonStreamedChatResponse)), err
2124
}
2225

2326
return resp.(*cohere.NonStreamedChatResponse), nil
@@ -80,6 +83,52 @@ func (i *InstructorCohere) addOrConcatJSONSystemPrompt(request *cohere.ChatReque
8083
}
8184
}
8285

86+
func (i *InstructorCohere) emptyResponseWithUsageSum(usage *UsageSum) interface{} {
87+
return &cohere.NonStreamedChatResponse{
88+
Meta: &cohere.ApiMeta{
89+
Tokens: &cohere.ApiMetaTokens{
90+
InputTokens: toPtr(float64(usage.InputTokens)),
91+
OutputTokens: toPtr(float64(usage.OutputTokens)),
92+
},
93+
},
94+
}
95+
}
96+
97+
func (i *InstructorCohere) emptyResponseWithResponseUsage(response interface{}) interface{} {
98+
resp, ok := response.(*cohere.NonStreamedChatResponse)
99+
if !ok || resp == nil {
100+
return nil
101+
}
102+
103+
return &cohere.NonStreamedChatResponse{
104+
Meta: resp.Meta,
105+
}
106+
}
107+
108+
func (i *InstructorCohere) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) {
109+
resp, ok := response.(*cohere.NonStreamedChatResponse)
110+
if !ok {
111+
return response, fmt.Errorf("internal type error: expected *cohere.NonStreamedChatResponse, got %T", response)
112+
}
113+
114+
*resp.Meta.Tokens.InputTokens += float64(usage.InputTokens)
115+
*resp.Meta.Tokens.OutputTokens += float64(usage.OutputTokens)
116+
117+
return response, nil
118+
}
119+
120+
func (i *InstructorCohere) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum {
121+
resp, ok := response.(*cohere.NonStreamedChatResponse)
122+
if !ok {
123+
return usage
124+
}
125+
126+
usage.InputTokens += int(*resp.Meta.Tokens.InputTokens)
127+
usage.OutputTokens += int(*resp.Meta.Tokens.OutputTokens)
128+
129+
return usage
130+
}
131+
83132
func createCohereTools(schema *Schema) *cohere.Tool {
84133

85134
tool := &cohere.Tool{
@@ -98,3 +147,13 @@ func createCohereTools(schema *Schema) *cohere.Tool {
98147

99148
return tool
100149
}
150+
151+
func nilCohereRespWithUsage(resp *cohere.NonStreamedChatResponse) *cohere.NonStreamedChatResponse {
152+
if resp == nil {
153+
return nil
154+
}
155+
156+
return &cohere.NonStreamedChatResponse{
157+
Meta: resp.Meta,
158+
}
159+
}

pkg/instructor/instructor.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,11 @@ type Instructor interface {
2929
request interface{},
3030
schema *Schema,
3131
) (<-chan string, error)
32+
33+
// Usage counting
34+
35+
emptyResponseWithUsageSum(usage *UsageSum) interface{}
36+
emptyResponseWithResponseUsage(response interface{}) interface{}
37+
addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error)
38+
countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum
3239
}

pkg/instructor/openai_chat.go

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ func (i *InstructorOpenAI) CreateChatCompletion(
1717

1818
resp, err := chatHandler(i, ctx, request, responseType)
1919
if err != nil {
20-
return openai.ChatCompletionResponse{}, err
20+
if resp == nil {
21+
return openai.ChatCompletionResponse{}, err
22+
}
23+
return *nilOpenaiRespWithUsage(resp.(*openai.ChatCompletionResponse)), err
2124
}
2225

2326
response = *(resp.(*openai.ChatCompletionResponse))
@@ -69,7 +72,7 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha
6972
numTools := len(toolCalls)
7073

7174
if numTools < 1 {
72-
return "", nil, errors.New("recieved no tool calls from model, expected at least 1")
75+
return "", nilOpenaiRespWithUsage(&resp), errors.New("received no tool calls from model, expected at least 1")
7376
}
7477

7578
if numTools == 1 {
@@ -84,14 +87,14 @@ func (i *InstructorOpenAI) chatToolCall(ctx context.Context, request *openai.Cha
8487
var jsonObj map[string]interface{}
8588
err = json.Unmarshal([]byte(toolCall.Function.Arguments), &jsonObj)
8689
if err != nil {
87-
return "", nil, err
90+
return "", nilOpenaiRespWithUsage(&resp), err
8891
}
8992
jsonArray[i] = jsonObj
9093
}
9194

9295
resultJSON, err := json.Marshal(jsonArray)
9396
if err != nil {
94-
return "", nil, err
97+
return "", nilOpenaiRespWithUsage(&resp), err
9598
}
9699

97100
return string(resultJSON), &resp, nil
@@ -128,6 +131,53 @@ func (i *InstructorOpenAI) chatJSONSchema(ctx context.Context, request *openai.C
128131
return text, &resp, nil
129132
}
130133

134+
func (i *InstructorOpenAI) emptyResponseWithUsageSum(usage *UsageSum) interface{} {
135+
return &openai.ChatCompletionResponse{
136+
Usage: openai.Usage{
137+
PromptTokens: usage.InputTokens,
138+
CompletionTokens: usage.OutputTokens,
139+
TotalTokens: usage.TotalTokens,
140+
},
141+
}
142+
}
143+
144+
func (i *InstructorOpenAI) emptyResponseWithResponseUsage(response interface{}) interface{} {
145+
resp, ok := response.(*openai.ChatCompletionResponse)
146+
if !ok || resp == nil {
147+
return nil
148+
}
149+
150+
return &openai.ChatCompletionResponse{
151+
Usage: resp.Usage,
152+
}
153+
}
154+
155+
func (i *InstructorOpenAI) addUsageSumToResponse(response interface{}, usage *UsageSum) (interface{}, error) {
156+
resp, ok := response.(*openai.ChatCompletionResponse)
157+
if !ok {
158+
return response, fmt.Errorf("internal type error: expected *openai.ChatCompletionResponse, got %T", response)
159+
}
160+
161+
resp.Usage.PromptTokens += usage.InputTokens
162+
resp.Usage.CompletionTokens += usage.OutputTokens
163+
resp.Usage.TotalTokens += usage.TotalTokens
164+
165+
return response, nil
166+
}
167+
168+
func (i *InstructorOpenAI) countUsageFromResponse(response interface{}, usage *UsageSum) *UsageSum {
169+
resp, ok := response.(*openai.ChatCompletionResponse)
170+
if !ok {
171+
return usage
172+
}
173+
174+
usage.InputTokens += resp.Usage.PromptTokens
175+
usage.OutputTokens += resp.Usage.CompletionTokens
176+
usage.TotalTokens += resp.Usage.TotalTokens
177+
178+
return usage
179+
}
180+
131181
func createJSONMessage(schema *Schema) *openai.ChatCompletionMessage {
132182
message := fmt.Sprintf(`
133183
Please respond with JSON in the following JSON schema:
@@ -144,3 +194,13 @@ Make sure to return an instance of the JSON, not the schema itself
144194

145195
return msg
146196
}
197+
198+
func nilOpenaiRespWithUsage(resp *openai.ChatCompletionResponse) *openai.ChatCompletionResponse {
199+
if resp == nil {
200+
return nil
201+
}
202+
203+
return &openai.ChatCompletionResponse{
204+
Usage: resp.Usage,
205+
}
206+
}

0 commit comments

Comments
 (0)