@@ -441,6 +441,9 @@ def run(
441
441
442
442
"""
443
443
444
+ def pre_setup (self , lit_api : LitAPI , spec : Optional [LitSpec ]):
445
+ pass
446
+
444
447
def __call__ (
445
448
self ,
446
449
lit_api : LitAPI ,
@@ -487,7 +490,109 @@ def run(
487
490
raise NotImplementedError
488
491
489
492
490
- class SingleLoop (_BaseLoop ):
493
class LitLoop(_BaseLoop):
    """Loop base class bundling the shared queue helpers.

    Provides request collection (single and batched), spec-driven context
    population, and response/error publishing for concrete loop subclasses.
    """

    def __init__(self):
        # Mutable context shared with the spec; filled via `populate_context`.
        self._context = {}

    def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
        """Collect up to `max_batch_size` requests, waiting at most `batch_timeout` seconds.

        Returns a tuple of (batches, timed_out_uids) as produced by `collate_requests`.
        """
        if max_batch_size <= 1:
            raise ValueError("max_batch_size must be greater than 1")
        return collate_requests(lit_api, request_queue, max_batch_size, batch_timeout)

    def get_request(self, request_queue: Queue, timeout: float = 1.0):
        """Pop a single request tuple: (response_queue_id, uid, timestamp, encoded_input)."""
        return request_queue.get(timeout=timeout)

    def populate_context(self, lit_spec: LitSpec, request: Any):
        """Give the spec (when present) a chance to enrich `self._context` from the request."""
        if lit_spec and hasattr(lit_spec, "populate_context"):
            lit_spec.populate_context(self._context, request)

    def put_response(
        self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
    ) -> None:
        """Publish `(uid, (response_data, status))` on the caller's response queue."""
        response_queues[response_queue_id].put((uid, (response_data, status)))

    def put_error_response(
        self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
    ) -> None:
        """Publish an error result for `uid` with `LitAPIStatus.ERROR` status."""
        response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))
526
+
527
+
528
class DefaultLoop(LitLoop):
    """Loop base for the built-in (non-continuous) loops with hook validation."""

    def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
        """Validate that the LitAPI hooks match the stream/batch configuration.

        Raises:
            ValueError: when `stream=True` but the required hooks are not
                generator functions (`yield`).
        """
        # we will sanitize regularly if no spec
        # in case, we have spec then:
        # case 1: spec implements a streaming API
        # Case 2: spec implements a non-streaming API
        if spec:
            # TODO: Implement sanitization
            lit_api._spec = spec
            return

        # Hoist the generator checks so both validations share them.
        predict_is_gen = inspect.isgeneratorfunction(lit_api.predict)
        encode_is_gen = inspect.isgeneratorfunction(lit_api.encode_response)
        # `unbatch` is acceptable if untouched (default impl) or a generator.
        uses_default_unbatch = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__
        unbatch_ok = uses_default_unbatch or inspect.isgeneratorfunction(lit_api.unbatch)

        if lit_api.stream and lit_api.max_batch_size > 1:
            if not (predict_is_gen and encode_is_gen and unbatch_ok):
                raise ValueError(
                    """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
                `lit_api.unbatch` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output

                def unbatch(self, outputs):
                    for output in outputs:
                        unbatched_output = ...
                        yield unbatched_output
             """
                )

        if lit_api.stream and not (predict_is_gen and encode_is_gen):
            raise ValueError(
                """When `stream=True` both `lit_api.predict` and
             `lit_api.encode_response` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output
             """
            )
593
+
594
+
595
+ class SingleLoop (DefaultLoop ):
491
596
def __call__ (
492
597
self ,
493
598
lit_api : LitAPI ,
@@ -505,7 +610,7 @@ def __call__(
505
610
run_single_loop (lit_api , lit_spec , request_queue , response_queues , callback_runner )
506
611
507
612
508
- class BatchedLoop (_BaseLoop ):
613
+ class BatchedLoop (DefaultLoop ):
509
614
def __call__ (
510
615
self ,
511
616
lit_api : LitAPI ,
@@ -531,7 +636,7 @@ def __call__(
531
636
)
532
637
533
638
534
- class StreamingLoop (_BaseLoop ):
639
+ class StreamingLoop (DefaultLoop ):
535
640
def __call__ (
536
641
self ,
537
642
lit_api : LitAPI ,
@@ -549,7 +654,7 @@ def __call__(
549
654
run_streaming_loop (lit_api , lit_spec , request_queue , response_queues , callback_runner )
550
655
551
656
552
- class BatchedStreamingLoop (_BaseLoop ):
657
+ class BatchedStreamingLoop (DefaultLoop ):
553
658
def __call__ (
554
659
self ,
555
660
lit_api : LitAPI ,
@@ -593,41 +698,6 @@ class Output:
593
698
status : LitAPIStatus
594
699
595
700
596
- class LitLoop (_BaseLoop ):
597
- def __init__ (self ):
598
- self ._context = {}
599
-
600
- def get_batch_requests (self , lit_api : LitAPI , request_queue : Queue , max_batch_size : int , batch_timeout : float ):
601
- if max_batch_size <= 1 :
602
- raise ValueError ("max_batch_size must be greater than 1" )
603
-
604
- batches , timed_out_uids = collate_requests (
605
- lit_api ,
606
- request_queue ,
607
- max_batch_size ,
608
- batch_timeout ,
609
- )
610
- return batches , timed_out_uids
611
-
612
- def get_request (self , request_queue : Queue , timeout : float = 1.0 ):
613
- response_queue_id , uid , timestamp , x_enc = request_queue .get (timeout = timeout )
614
- return response_queue_id , uid , timestamp , x_enc
615
-
616
- def populate_context (self , lit_spec : LitSpec , request : Any ):
617
- if lit_spec and hasattr (lit_spec , "populate_context" ):
618
- lit_spec .populate_context (self ._context , request )
619
-
620
- def put_response (
621
- self , response_queues : List [Queue ], response_queue_id : int , uid : str , response_data : Any , status : LitAPIStatus
622
- ) -> None :
623
- response_queues [response_queue_id ].put ((uid , (response_data , status )))
624
-
625
- def put_error_response (
626
- self , response_queues : List [Queue ], response_queue_id : int , uid : str , error : Exception
627
- ) -> None :
628
- response_queues [response_queue_id ].put ((uid , (error , LitAPIStatus .ERROR )))
629
-
630
-
631
701
class ContinuousBatchingLoop (LitLoop ):
632
702
def __init__ (self , max_sequence_length : int = 2048 ):
633
703
super ().__init__ ()
@@ -840,15 +910,7 @@ def inference_worker(
840
910
logging .info (f"LitServe will use { lit_spec .__class__ .__name__ } spec" )
841
911
842
912
if loop == "auto" :
843
- loop = (
844
- BatchedStreamingLoop ()
845
- if stream and max_batch_size > 1
846
- else StreamingLoop ()
847
- if stream
848
- else BatchedLoop ()
849
- if max_batch_size > 1
850
- else SingleLoop ()
851
- )
913
+ loop = get_default_loop (stream , max_batch_size )
852
914
853
915
loop (
854
916
lit_api ,
@@ -863,3 +925,15 @@ def inference_worker(
863
925
workers_setup_status ,
864
926
callback_runner ,
865
927
)
928
+
929
+
930
def get_default_loop(stream: bool, max_batch_size: int) -> _BaseLoop:
    """Select the loop implementation for the given streaming/batching config.

    Args:
        stream: whether the API streams its responses.
        max_batch_size: batching is enabled when this is greater than 1.

    Returns:
        The matching loop instance (streaming x batched -> 4 variants).
    """
    batched = max_batch_size > 1
    if stream:
        return BatchedStreamingLoop() if batched else StreamingLoop()
    return BatchedLoop() if batched else SingleLoop()
0 commit comments