19
19
import os
20
20
import time
21
21
import types
22
+ import uuid
22
23
import weakref
23
24
from asyncio .queues import Queue
24
25
from asyncio .tasks import wait_for
@@ -444,18 +445,30 @@ async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
444
445
    @log_async(logger=logger)
    async def generate(self, prompt: str, *args, **kwargs):
        """Dispatch a text-generation request to the wrapped model.

        Routing order (first match wins):
        1. Batching enabled (``self.allow_batching()``) -> hand the request to
           the batching scheduler via ``handle_batching_request``.
        2. Model exposes a plain ``generate`` -> call it through the JSON call
           wrapper.
        3. Model exposes ``async_generate`` -> call it, guaranteeing a string
           ``request_id`` kwarg (generated with ``uuid.uuid1()`` when absent).

        Args:
            prompt: The prompt string to generate from.
            *args: Positional arguments forwarded to the model call.
            **kwargs: Keyword arguments forwarded to the model call.
                ``request_id`` is stripped on paths that do not support it;
                ``raw_params`` is stripped on the non-batching path.

        Raises:
            AttributeError: If the model supports neither ``generate`` nor
                ``async_generate``.
        """
        if self.allow_batching():
            # The batching scheduler does not support request_id; drop it.
            kwargs.pop("request_id", None)
            return await self.handle_batching_request(
                prompt, "generate", *args, **kwargs
            )
        else:
            kwargs.pop("raw_params", None)
            if hasattr(self._model, "generate"):
                # The synchronous generate path does not support request_id.
                kwargs.pop("request_id", None)
                return await self._call_wrapper_json(
                    self._model.generate, prompt, *args, **kwargs
                )
            if hasattr(self._model, "async_generate"):
                if "request_id" not in kwargs:
                    kwargs["request_id"] = str(uuid.uuid1())
                else:
                    # The model only accepts a string request_id; coerce it.
                    kwargs["request_id"] = str(kwargs["request_id"])
                return await self._call_wrapper_json(
                    self._model.async_generate,
                    prompt,
                    *args,
                    **kwargs,
                )
            raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
461
474
@@ -534,17 +547,26 @@ async def chat(self, messages: List[Dict], *args, **kwargs):
534
547
response = None
535
548
try :
536
549
if self .allow_batching ():
550
+ # not support request_id
551
+ kwargs .pop ("request_id" , None )
537
552
return await self .handle_batching_request (
538
553
messages , "chat" , * args , ** kwargs
539
554
)
540
555
else :
541
556
kwargs .pop ("raw_params" , None )
542
557
if hasattr (self ._model , "chat" ):
558
+ # not support request_id
559
+ kwargs .pop ("request_id" , None )
543
560
response = await self ._call_wrapper_json (
544
561
self ._model .chat , messages , * args , ** kwargs
545
562
)
546
563
return response
547
564
if hasattr (self ._model , "async_chat" ):
565
+ if "request_id" not in kwargs :
566
+ kwargs ["request_id" ] = str (uuid .uuid1 ())
567
+ else :
568
+ # model only accept string
569
+ kwargs ["request_id" ] = str (kwargs ["request_id" ])
548
570
response = await self ._call_wrapper_json (
549
571
self ._model .async_chat , messages , * args , ** kwargs
550
572
)
@@ -577,9 +599,10 @@ async def abort_request(self, request_id: str) -> str:
577
599
return await self ._scheduler_ref .abort_request (request_id )
578
600
return AbortRequestMessage .NO_OP .name
579
601
580
- @log_async (logger = logger )
581
602
@request_limit
603
+ @log_async (logger = logger )
582
604
async def create_embedding (self , input : Union [str , List [str ]], * args , ** kwargs ):
605
+ kwargs .pop ("request_id" , None )
583
606
if hasattr (self ._model , "create_embedding" ):
584
607
return await self ._call_wrapper_json (
585
608
self ._model .create_embedding , input , * args , ** kwargs
@@ -589,8 +612,8 @@ async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
589
612
f"Model { self ._model .model_spec } is not for creating embedding."
590
613
)
591
614
592
- @log_async (logger = logger )
593
615
@request_limit
616
+ @log_async (logger = logger )
594
617
async def rerank (
595
618
self ,
596
619
documents : List [str ],
@@ -602,6 +625,7 @@ async def rerank(
602
625
* args ,
603
626
** kwargs ,
604
627
):
628
+ kwargs .pop ("request_id" , None )
605
629
if hasattr (self ._model , "rerank" ):
606
630
return await self ._call_wrapper_json (
607
631
self ._model .rerank ,
@@ -616,8 +640,8 @@ async def rerank(
616
640
)
617
641
raise AttributeError (f"Model { self ._model .model_spec } is not for reranking." )
618
642
619
- @log_async (logger = logger , args_formatter = lambda _ , kwargs : kwargs .pop ("audio" ))
620
643
@request_limit
644
+ @log_async (logger = logger , ignore_kwargs = ["audio" ])
621
645
async def transcriptions (
622
646
self ,
623
647
audio : bytes ,
@@ -626,7 +650,9 @@ async def transcriptions(
626
650
response_format : str = "json" ,
627
651
temperature : float = 0 ,
628
652
timestamp_granularities : Optional [List [str ]] = None ,
653
+ ** kwargs ,
629
654
):
655
+ kwargs .pop ("request_id" , None )
630
656
if hasattr (self ._model , "transcriptions" ):
631
657
return await self ._call_wrapper_json (
632
658
self ._model .transcriptions ,
@@ -641,8 +667,8 @@ async def transcriptions(
641
667
f"Model { self ._model .model_spec } is not for creating transcriptions."
642
668
)
643
669
644
- @log_async (logger = logger , args_formatter = lambda _ , kwargs : kwargs .pop ("audio" ))
645
670
@request_limit
671
+ @log_async (logger = logger , ignore_kwargs = ["audio" ])
646
672
async def translations (
647
673
self ,
648
674
audio : bytes ,
@@ -651,7 +677,9 @@ async def translations(
651
677
response_format : str = "json" ,
652
678
temperature : float = 0 ,
653
679
timestamp_granularities : Optional [List [str ]] = None ,
680
+ ** kwargs ,
654
681
):
682
+ kwargs .pop ("request_id" , None )
655
683
if hasattr (self ._model , "translations" ):
656
684
return await self ._call_wrapper_json (
657
685
self ._model .translations ,
@@ -668,10 +696,7 @@ async def translations(
668
696
669
697
@request_limit
670
698
@xo .generator
671
- @log_async (
672
- logger = logger ,
673
- args_formatter = lambda _ , kwargs : kwargs .pop ("prompt_speech" , None ),
674
- )
699
+ @log_async (logger = logger , ignore_kwargs = ["prompt_speech" ])
675
700
async def speech (
676
701
self ,
677
702
input : str ,
@@ -681,6 +706,7 @@ async def speech(
681
706
stream : bool = False ,
682
707
** kwargs ,
683
708
):
709
+ kwargs .pop ("request_id" , None )
684
710
if hasattr (self ._model , "speech" ):
685
711
return await self ._call_wrapper_binary (
686
712
self ._model .speech ,
@@ -695,8 +721,8 @@ async def speech(
695
721
f"Model { self ._model .model_spec } is not for creating speech."
696
722
)
697
723
698
- @log_async (logger = logger )
699
724
@request_limit
725
+ @log_async (logger = logger )
700
726
async def text_to_image (
701
727
self ,
702
728
prompt : str ,
@@ -706,6 +732,7 @@ async def text_to_image(
706
732
* args ,
707
733
** kwargs ,
708
734
):
735
+ kwargs .pop ("request_id" , None )
709
736
if hasattr (self ._model , "text_to_image" ):
710
737
return await self ._call_wrapper_json (
711
738
self ._model .text_to_image ,
@@ -720,6 +747,10 @@ async def text_to_image(
720
747
f"Model { self ._model .model_spec } is not for creating image."
721
748
)
722
749
750
+ @log_async (
751
+ logger = logger ,
752
+ ignore_kwargs = ["image" ],
753
+ )
723
754
async def image_to_image (
724
755
self ,
725
756
image : "PIL.Image" ,
@@ -731,6 +762,7 @@ async def image_to_image(
731
762
* args ,
732
763
** kwargs ,
733
764
):
765
+ kwargs .pop ("request_id" , None )
734
766
if hasattr (self ._model , "image_to_image" ):
735
767
return await self ._call_wrapper_json (
736
768
self ._model .image_to_image ,
@@ -747,6 +779,10 @@ async def image_to_image(
747
779
f"Model { self ._model .model_spec } is not for creating image."
748
780
)
749
781
782
+ @log_async (
783
+ logger = logger ,
784
+ ignore_kwargs = ["image" ],
785
+ )
750
786
async def inpainting (
751
787
self ,
752
788
image : "PIL.Image" ,
@@ -759,6 +795,7 @@ async def inpainting(
759
795
* args ,
760
796
** kwargs ,
761
797
):
798
+ kwargs .pop ("request_id" , None )
762
799
if hasattr (self ._model , "inpainting" ):
763
800
return await self ._call_wrapper_json (
764
801
self ._model .inpainting ,
@@ -776,12 +813,13 @@ async def inpainting(
776
813
f"Model { self ._model .model_spec } is not for creating image."
777
814
)
778
815
779
- @log_async (logger = logger )
780
816
@request_limit
817
+ @log_async (logger = logger , ignore_kwargs = ["image" ])
781
818
async def infer (
782
819
self ,
783
820
** kwargs ,
784
821
):
822
+ kwargs .pop ("request_id" , None )
785
823
if hasattr (self ._model , "infer" ):
786
824
return await self ._call_wrapper_json (
787
825
self ._model .infer ,
@@ -791,15 +829,16 @@ async def infer(
791
829
f"Model { self ._model .model_spec } is not for flexible infer."
792
830
)
793
831
794
- @log_async (logger = logger )
795
832
@request_limit
833
+ @log_async (logger = logger )
796
834
async def text_to_video (
797
835
self ,
798
836
prompt : str ,
799
837
n : int = 1 ,
800
838
* args ,
801
839
** kwargs ,
802
840
):
841
+ kwargs .pop ("request_id" , None )
803
842
if hasattr (self ._model , "text_to_video" ):
804
843
return await self ._call_wrapper_json (
805
844
self ._model .text_to_video ,
0 commit comments