Commit

Merge pull request #1008 from EleutherAI/openai_completions
[Refactor] Openai completions
lintangsutawika authored Dec 1, 2023
2 parents c3d97e4 + 5b42436 commit f1b64f6
Showing 3 changed files with 220 additions and 5 deletions.
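
For orientation, here is a minimal sketch of how the chat-completions backend added in this commit might be driven from Python. It assumes the v0.4-style `lm_eval.simple_evaluate` entry point, the `openai` extra installed, and `OPENAI_API_KEY` set in the environment; the task name, limit, and model string are illustrative choices, not part of the diff.

import lm_eval

results = lm_eval.simple_evaluate(
    model="openai-chat-completions",   # name registered by the new OpenaiChatCompletionsLM class
    model_args="model=gpt-3.5-turbo",  # forwarded to OpenaiChatCompletionsLM(model=...)
    tasks=["gsm8k"],                   # illustrative generation task; loglikelihood tasks raise NotImplementedError
    limit=8,                           # small smoke-test slice; remove for a full run
)
print(results["results"])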
216 changes: 214 additions & 2 deletions lm_eval/models/openai_completions.py
@@ -1,7 +1,11 @@
import os
import time
from typing import List, Tuple

import copy
from collections import defaultdict
from tqdm import tqdm

from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
@@ -51,7 +55,7 @@ def oa_completion(**kwargs):
    backoff_time = 3
    while True:
        try:
-            return openai.Completion.create(**kwargs)
+            return openai.Completions.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

@@ -60,7 +64,7 @@ def oa_completion(**kwargs):
            backoff_time *= 1.5


@register_model("openai", "openai-completions", "gooseai")
@register_model("gooseai")
class OpenaiCompletionsLM(LM):
    REQ_CHUNK_SIZE = 20

@@ -304,3 +308,211 @@ def loglikelihood_rolling(self, requests) -> List[float]:
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods


def oa_chat_completion(client, **kwargs):
    """Query OpenAI API for chat completion.
    Retry with back-off until they respond
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    async def _get_completions(**kwargs):
        chat_completions = await client.chat.completions.create(**kwargs)
        return chat_completions

    backoff_time = 3
    while True:
        try:
            return client.chat.completions.create(**kwargs)
        except openai.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


@register_model("openai-chat-completions")
class OpenaiChatCompletionsLM(LM):
    def __init__(
        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
    ) -> None:
        """
        :param model: str
            OpenAI API model (e.g. gpt-3.5-turbo)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()
        try:
            import openai, tiktoken  # noqa: E401
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
            )
        self.model = model
        self.frequency_penalty = 0
        self.logit_bias = None
        self.n = 1
        self.presence_penalty = 0
        self.temperature = 1
        self.top_p = 1
        self.tokenizer = tiktoken.encoding_for_model(self.model)
        self.vocab_size = self.tokenizer.n_vocab
        self.truncate = truncate
        self.end_of_text_token_id = self.tokenizer.eot_token

        # Read from environment variable OPENAI_API_KEY
        self.client = openai.OpenAI()  # openai.AsyncOpenAI()

    @property
    def eot_token_id(self):
        return self.end_of_text_token_id

    @property
    def max_length(self) -> int:
        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
        return 2048

    @property
    def max_gen_toks(self) -> int:
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str) -> List[int]:
        return self.tokenizer.encode(string)

    def tok_decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def _encode_pair(
        self, context: str, continuation: str
    ) -> Tuple[List[int], List[int]]:
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def generate_until(self, requests) -> List[str]:
        res = defaultdict(list)
        re_ords = {}

        def _collate(x):
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
        for key, reqs in grouper.get_grouped().items():
            # within each set of reqs for given kwargs, we reorder by token length, descending.
            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

        def sameuntil_chunks(xs, size):
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        for key, re_ord in re_ords.items():
            # n must be 1 because the messages in a chat-completion request
            # are treated as a single conversation, not as a batch of
            # independent prompts.
            chunks = utils.chunks(re_ord.get_reordered(), n=1)
            for chunk in chunks:
                contexts, all_gen_kwargs = zip(*chunk)
                inps = [{"role": "user", "content": context} for context in contexts]

                gen_kwargs = all_gen_kwargs[0]
                until = None
                if isinstance(gen_kwargs, dict):
                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                    if "until" in kwargs.keys():
                        until = kwargs.pop("until")
                        if isinstance(until, str):
                            until = [until]
                        elif not isinstance(until, list):
                            raise ValueError(
                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                            )
                else:
                    raise ValueError(
                        f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
                    )

if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks

response = oa_chat_completion(
client=self.client,
messages=inps,
model=self.model,
frequency_penalty=self.frequency_penalty,
# logit_bias=self.logit_bias,
max_tokens=max_gen_toks,
n=self.n,
presence_penalty=self.presence_penalty,
temperature=self.temperature,
top_p=self.top_p,
)

for resp, (context, args_) in zip(response.choices, chunk):
s = resp.message.content

if until is not None:
for term in until:
if len(term) > 0:
s = s.split(term)[0]

res[key].append(s)

self.cache_hook.add_partial(
"generate_until", (context, {"until": until}), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])

pbar.close()

return grouper.get_original(res)

    def loglikelihood(self, requests):
        raise NotImplementedError("No support for logits.")

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError("No support for logits.")
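
Both the legacy `oa_completion` and the new `oa_chat_completion` above retry failed API calls with multiplicative back-off and never give up. Below is a standalone sketch of the same pattern with a bounded number of attempts; the helper name, defaults, and bounded-retry behaviour are illustrative choices, not part of this file.

import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    fn: Callable[[], T],
    retries: int = 5,              # unlike the code above, give up after a fixed number of attempts
    initial_backoff: float = 3.0,
    backoff_multiplier: float = 1.5,
) -> T:
    """Call fn(), sleeping progressively longer between failures."""
    backoff_time = initial_backoff
    for attempt in range(retries):
        try:
            return fn()
        except Exception:
            if attempt == retries - 1:
                raise              # out of attempts, surface the last error
            time.sleep(backoff_time)
            backoff_time *= backoff_multiplier
    raise RuntimeError("unreachable")


# Usage sketch: wrap the client call in a zero-argument closure.
# response = retry_with_backoff(lambda: client.chat.completions.create(model="gpt-3.5-turbo", messages=msgs))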
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -70,7 +70,7 @@ promptsource = [
]
gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
anthropic = ["anthropic"]
openai = ["openai", "tiktoken"]
openai = ["openai>=1.3.5", "tiktoken"]
vllm = ["vllm"]
all = [
"lm_eval[dev]",
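
The version bump above tracks the migration from the pre-1.0 module-level OpenAI API to the 1.x client-object API used by the new chat class. A rough sketch of the 1.x calling convention follows; the model string and prompt are illustrative.

import openai

# openai>=1.0 routes calls through a client instance; the key is read from OPENAI_API_KEY.
client = openai.OpenAI()

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=16,
)
print(resp.choices[0].message.content)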
7 changes: 5 additions & 2 deletions tests/tests_master/test_models.py
@@ -1,13 +1,16 @@
import hashlib
import json
import openai
import os
import pickle
import pytest
import unittest.mock as mock

import lm_eval.models as models

from openai import OpenAI

client = OpenAI()


LOGLIKELIHOOD_TEST_CASES = [
("The quick brown fox jumps over the lazy", " dog"),
@@ -172,7 +175,7 @@ def openai_mock_completion(**kwargs):
    if os.path.exists(fname):
        with open(fname, "rb") as fh:
            return pickle.load(fh)
-    ret = openai.Completion.create(**kwargs)
+    ret = client.completions.create(**kwargs)
    ret.api_key = ""
    with open(fname, "wb") as fh:
        pickle.dump(ret, fh)
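
The mock above replays completion responses from disk so the test suite can run without live API calls. A minimal sketch of such a hash-keyed pickle cache is shown below; the cache directory, helper name, and key derivation are illustrative assumptions, since the helper's setup is truncated in this diff.

import hashlib
import json
import os
import pickle

from openai import OpenAI

client = OpenAI()
CACHE_DIR = "tests/testdata"  # illustrative location


def cached_completion(**kwargs):
    """Return a cached completions response if one exists, otherwise call the API and cache it."""
    key = hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()
    fname = os.path.join(CACHE_DIR, f"completion_{key}.pkl")
    if os.path.exists(fname):
        with open(fname, "rb") as fh:
            return pickle.load(fh)
    ret = client.completions.create(**kwargs)
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(fname, "wb") as fh:
        pickle.dump(ret, fh)
    return ret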
