@@ -1,7 +1,11 @@
 import os
 import time
 from typing import List, Tuple
+
+import copy
+from collections import defaultdict
 from tqdm import tqdm
+
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
@@ -51,7 +55,7 @@ def oa_completion(**kwargs):
     backoff_time = 3
     while True:
         try:
-            return openai.Completion.create(**kwargs)
+            return openai.Completions.create(**kwargs)
         except openai.error.OpenAIError:
             import traceback
 
@@ -60,7 +64,7 @@ def oa_completion(**kwargs):
         backoff_time *= 1.5
 
 
-@register_model("openai", "openai-completions", "gooseai")
+@register_model("gooseai")
 class OpenaiCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20
 
@@ -304,3 +308,211 @@ def loglikelihood_rolling(self, requests) -> List[float]:
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
         return loglikelihoods
+
+
+def oa_chat_completion(client, **kwargs):
+    """Query OpenAI API for chat completion.
+
+    Retry with back-off until they respond.
+    """
+    try:
+        import openai, tiktoken  # noqa: E401
+    except ModuleNotFoundError:
+        raise Exception(
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
+please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+        )
+
+    async def _get_completions(**kwargs):
+        chat_completions = await client.chat.completions.create(**kwargs)
+        return chat_completions
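+    # NOTE: the async helper above is not called anywhere yet; it only applies
+    # if an `openai.AsyncOpenAI()` client is passed in. The synchronous call
+    # below is what actually runs.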
+
+    backoff_time = 3
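+    # retry indefinitely on API errors, sleeping 3s, 4.5s, 6.75s, ... between attempts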
+    while True:
+        try:
+            return client.chat.completions.create(**kwargs)
+        except openai.OpenAIError:
+            import traceback
+
+            traceback.print_exc()
+            time.sleep(backoff_time)
+            backoff_time *= 1.5
+
+
+@register_model("openai-chat-completions")
+class OpenaiChatCompletionsLM(LM):
+    def __init__(
+        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
+    ) -> None:
+        """
+
+        :param model: str
+            OpenAI API model (e.g. gpt-3.5-turbo)
+        :param truncate: bool
+            Truncate input if too long (if False and input is too long, throw error)
+        """
+        super().__init__()
+        try:
+            import openai, tiktoken  # noqa: E401
+        except ModuleNotFoundError:
+            raise Exception(
+                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
+please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            )
+        self.model = model
+        self.frequency_penalty = 0
+        self.logit_bias = None
+        self.n = 1
+        self.presence_penalty = 0
+        self.temperature = 1
+        self.top_p = 1
+        self.tokenizer = tiktoken.encoding_for_model(self.model)
+        self.vocab_size = self.tokenizer.n_vocab
+        self.truncate = truncate
+        self.end_of_text_token_id = self.tokenizer.eot_token
+
+        # Read from environment variable OPENAI_API_KEY
+        self.client = openai.OpenAI()  # openai.AsyncOpenAI()
+
+    @property
+    def eot_token_id(self):
+        return self.end_of_text_token_id
+
+    @property
+    def max_length(self) -> int:
+        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
+        return 2048
+
+    @property
+    def max_gen_toks(self) -> int:
+        return 256
+
+    @property
+    def batch_size(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    @property
+    def device(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string)
+
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)
+
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
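+        # Trailing whitespace on the context is moved onto the continuation before
+        # encoding, so the token boundary matches how the tokenizer would split
+        # the concatenated string.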
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+        return context_enc, continuation_enc
+
+    def generate_until(self, requests) -> List[str]:
+        res = defaultdict(list)
+        re_ords = {}
+
+        def _collate(x):
+            toks = self.tok_encode(x[0])
+            return -len(toks), x[0]
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            # within each set of reqs for given kwargs, we reorder by token length, descending.
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
+
+        def sameuntil_chunks(xs, size):
+            ret = []
+            lastuntil = xs[0][1]
+            for x in xs:
+                if len(ret) >= size or x[1] != lastuntil:
+                    yield ret, lastuntil
+                    ret = []
+                    lastuntil = x[1]
+                ret.append(x)
+
+            if ret:
+                yield ret, lastuntil
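+        # (this helper is not referenced below; chat requests are sent one
+        # conversation at a time instead of being batched by stop sequence.)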
+
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        for key, re_ord in re_ords.items():
+            # n needs to be 1 because messages in
+            # chat completion are not batched but
+            # are regarded as a single conversation.
+            chunks = utils.chunks(re_ord.get_reordered(), n=1)
+            for chunk in chunks:
+                contexts, all_gen_kwargs = zip(*chunk)
+                inps = [{"role": "user", "content": context} for context in contexts]
+
+                gen_kwargs = all_gen_kwargs[0]
+                until = None
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    if "until" in kwargs.keys():
+                        until = kwargs.pop("until")
+                        if isinstance(until, str):
+                            until = [until]
+                        elif not isinstance(until, list):
+                            raise ValueError(
+                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                    )
+
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+
+                response = oa_chat_completion(
+                    client=self.client,
+                    messages=inps,
+                    model=self.model,
+                    frequency_penalty=self.frequency_penalty,
+                    # logit_bias=self.logit_bias,
+                    max_tokens=max_gen_toks,
+                    n=self.n,
+                    presence_penalty=self.presence_penalty,
+                    temperature=self.temperature,
+                    top_p=self.top_p,
+                )
+
+                for resp, (context, args_) in zip(response.choices, chunk):
+                    s = resp.message.content
+
+                    if until is not None:
+                        for term in until:
+                            if len(term) > 0:
+                                s = s.split(term)[0]
+
+                    res[key].append(s)
+
+                    self.cache_hook.add_partial(
+                        "generate_until", (context, {"until": until}), s
+                    )
+                    pbar.update(1)
+            # reorder this group of results back to original unsorted form
+            res[key] = re_ord.get_original(res[key])
+
+        pbar.close()
+
+        return grouper.get_original(res)
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")
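
For reference, a minimal usage sketch (not part of the diff): the module path below is assumed from the harness's layout, and `OPENAI_API_KEY` must be set in the environment.

    # hypothetical direct instantiation; request objects normally come from the harness
    from lm_eval.models.openai_completions import OpenaiChatCompletionsLM

    lm = OpenaiChatCompletionsLM(model="gpt-3.5-turbo")
    # `generate_until` serves generation requests one conversation at a time;
    # the loglikelihood methods raise NotImplementedError, since the chat API
    # exposes no logits.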