Commit f1b64f6

Merge pull request #1008 from EleutherAI/openai_completions
[Refactor] Openai completions
2 parents c3d97e4 + 5b42436 · commit f1b64f6

File tree

3 files changed: +220 -5 lines changed


lm_eval/models/openai_completions.py

Lines changed: 214 additions & 2 deletions
@@ -1,7 +1,11 @@
 import os
 import time
 from typing import List, Tuple
+
+import copy
+from collections import defaultdict
 from tqdm import tqdm
+
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
@@ -51,7 +55,7 @@ def oa_completion(**kwargs):
     backoff_time = 3
     while True:
         try:
-            return openai.Completion.create(**kwargs)
+            return openai.Completions.create(**kwargs)
         except openai.error.OpenAIError:
             import traceback

@@ -60,7 +64,7 @@ def oa_completion(**kwargs):
             backoff_time *= 1.5


-@register_model("openai", "openai-completions", "gooseai")
+@register_model("gooseai")
 class OpenaiCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20

@@ -304,3 +308,211 @@ def loglikelihood_rolling(self, requests) -> List[float]:
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
         return loglikelihoods
+
+
+def oa_chat_completion(client, **kwargs):
+    """Query OpenAI API for chat completion.
+
+    Retry with back-off until they respond
+    """
+    try:
+        import openai, tiktoken  # noqa: E401
+    except ModuleNotFoundError:
+        raise Exception(
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
+please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+        )
+
+    async def _get_completions(**kwargs):
+        chat_completions = await client.chat.completions.create(**kwargs)
+        return chat_completions
+
+    backoff_time = 3
+    while True:
+        try:
+            return client.chat.completions.create(**kwargs)
+        except openai.OpenAIError:
+            import traceback
+
+            traceback.print_exc()
+            time.sleep(backoff_time)
+            backoff_time *= 1.5
+
+
+@register_model("openai-chat-completions")
+class OpenaiChatCompletionsLM(LM):
+    def __init__(
+        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
+    ) -> None:
+        """
+
+        :param model: str
+            OpenAI API model (e.g. gpt-3.5-turbo)
+        :param truncate: bool
+            Truncate input if too long (if False and input is too long, throw error)
+        """
+        super().__init__()
+        try:
+            import openai, tiktoken  # noqa: E401
+        except ModuleNotFoundError:
+            raise Exception(
+                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
+please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            )
+        self.model = model
+        self.frequency_penalty = 0
+        self.logit_bias = None
+        self.n = 1
+        self.presence_penalty = 0
+        self.temperature = 1
+        self.top_p = 1
+        self.tokenizer = tiktoken.encoding_for_model(self.model)
+        self.vocab_size = self.tokenizer.n_vocab
+        self.truncate = truncate
+        self.end_of_text_token_id = self.tokenizer.eot_token
+
+        # Read from environment variable OPENAI_API_KEY
+        self.client = openai.OpenAI()  # openai.AsyncOpenAI()
+
+    @property
+    def eot_token_id(self):
+        return self.end_of_text_token_id
+
+    @property
+    def max_length(self) -> int:
+        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
+        return 2048
+
+    @property
+    def max_gen_toks(self) -> int:
+        return 256
+
+    @property
+    def batch_size(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    @property
+    def device(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string)
+
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)
+
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+        return context_enc, continuation_enc
+
+    def generate_until(self, requests) -> List[str]:
+        res = defaultdict(list)
+        re_ords = {}
+
+        def _collate(x):
+            toks = self.tok_encode(x[0])
+            return -len(toks), x[0]
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            # within each set of reqs for given kwargs, we reorder by token length, descending.
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
+
+        def sameuntil_chunks(xs, size):
+            ret = []
+            lastuntil = xs[0][1]
+            for x in xs:
+                if len(ret) >= size or x[1] != lastuntil:
+                    yield ret, lastuntil
+                    ret = []
+                    lastuntil = x[1]
+                ret.append(x)
+
+            if ret:
+                yield ret, lastuntil
+
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        for key, re_ord in re_ords.items():
+            # n needs to be 1 because messages in
+            # chat completion are not batch but
+            # is regarded as a single conversation.
+            chunks = utils.chunks(re_ord.get_reordered(), n=1)
+            for chunk in chunks:
+                contexts, all_gen_kwargs = zip(*chunk)
+                inps = [{"role": "user", "content": context} for context in contexts]
+
+                gen_kwargs = all_gen_kwargs[0]
+                until = None
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    if "until" in kwargs.keys():
+                        until = kwargs.pop("until")
+                        if isinstance(until, str):
+                            until = [kwargs]
+                        elif not isinstance(until, list):
+                            raise ValueError(
+                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {kwargs}"
+                    )
+
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+
+                response = oa_chat_completion(
+                    client=self.client,
+                    messages=inps,
+                    model=self.model,
+                    frequency_penalty=self.frequency_penalty,
+                    # logit_bias=self.logit_bias,
+                    max_tokens=max_gen_toks,
+                    n=self.n,
+                    presence_penalty=self.presence_penalty,
+                    temperature=self.temperature,
+                    top_p=self.top_p,
+                )
+
+                for resp, (context, args_) in zip(response.choices, chunk):
+                    s = resp.message.content
+
+                    if until is not None:
+                        for term in until:
+                            if len(term) > 0:
+                                s = s.split(term)[0]
+
+                    res[key].append(s)
+
+                    self.cache_hook.add_partial(
+                        "generate_until", (context, {"until": until}), s
+                    )
+                    pbar.update(1)
+            # reorder this group of results back to original unsorted form
+            res[key] = re_ord.get_original(res[key])
+
+        pbar.close()
+
+        return grouper.get_original(res)
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")
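
Note: below is a minimal, self-contained sketch of the retry-with-exponential-backoff pattern the new oa_chat_completion helper uses, written against the openai>=1.3.5 client API. It is not code from the PR; chat_with_backoff is an illustrative name, and it assumes OPENAI_API_KEY is set in the environment.

    import time

    import openai

    client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment


    def chat_with_backoff(client, backoff_time=3.0, **kwargs):
        # Retry until the API answers, multiplying the wait by 1.5 after each
        # failure, mirroring the loop in oa_chat_completion() above.
        while True:
            try:
                return client.chat.completions.create(**kwargs)
            except openai.OpenAIError:
                time.sleep(backoff_time)
                backoff_time *= 1.5


    response = chat_with_backoff(
        client,
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=16,
    )
    print(response.choices[0].message.content)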

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ promptsource = [
 ]
 gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
 anthropic = ["anthropic"]
-openai = ["openai", "tiktoken"]
+openai = ["openai>=1.3.5", "tiktoken"]
 vllm = ["vllm"]
 all = [
     "lm_eval[dev]",

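The version floor matters: the openai 1.x series replaced the module-level openai.Completion / openai.ChatCompletion calls with a client object (openai.OpenAI()), which the new code paths above rely on. A hedged sketch of a runtime guard one could add (not part of this PR; the parse is naive and real code would use packaging.version):

    from importlib.metadata import version

    # Naive version parse for illustration only; fails on pre-release strings.
    installed = tuple(int(part) for part in version("openai").split(".")[:3])
    assert installed >= (1, 3, 5), "the openai extra now expects openai>=1.3.5"
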
tests/tests_master/test_models.py

Lines changed: 5 additions & 2 deletions
@@ -1,13 +1,16 @@
 import hashlib
 import json
-import openai
 import os
 import pickle
 import pytest
 import unittest.mock as mock

 import lm_eval.models as models

+from openai import OpenAI
+
+client = OpenAI()
+

 LOGLIKELIHOOD_TEST_CASES = [
     ("The quick brown fox jumps over the lazy", " dog"),
@@ -172,7 +175,7 @@ def openai_mock_completion(**kwargs):
     if os.path.exists(fname):
         with open(fname, "rb") as fh:
             return pickle.load(fh)
-    ret = openai.Completion.create(**kwargs)
+    ret = client.completions.create(**kwargs)
     ret.api_key = ""
     with open(fname, "wb") as fh:
         pickle.dump(ret, fh)
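
For context, openai_mock_completion records live API responses to disk and replays them on later runs, so the suite can pass without network access. Below is a self-contained sketch of that record-and-replay pattern under the new client API; cached_completion and the cache path are illustrative names, not the test's actual ones.

    import hashlib
    import json
    import os
    import pickle

    from openai import OpenAI

    client = OpenAI()


    def cached_completion(cache_dir="tests/testdata", **kwargs):
        # Key the cache on a stable hash of the request kwargs.
        key = hashlib.sha256(
            json.dumps(kwargs, sort_keys=True).encode("utf-8")
        ).hexdigest()
        fname = os.path.join(cache_dir, f"completion-{key}.pkl")
        # Replay a recorded response if one exists.
        if os.path.exists(fname):
            with open(fname, "rb") as fh:
                return pickle.load(fh)
        # Otherwise hit the API once (openai>=1 client call, as in the diff)
        # and record the response for future runs.
        ret = client.completions.create(**kwargs)
        with open(fname, "wb") as fh:
            pickle.dump(ret, fh)
        return ret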
