23
23
24
24
25
25
from pydantic import BaseModel
26
-
26
+ from fastapi .exceptions import RequestValidationError
27
+ from fastapi .responses import JSONResponse
28
+ from tools .normalizer .en import normalizer_en_nemo_text
29
+ from tools .normalizer .zh import normalizer_zh_tn
27
30
28
31
logger = get_logger ("Command" )
29
32
@@ -35,14 +38,23 @@ async def startup_event():
35
38
global chat
36
39
37
40
chat = ChatTTS .Chat (get_logger ("ChatTTS" ))
41
+ chat .normalizer .register ("en" , normalizer_en_nemo_text ())
42
+ chat .normalizer .register ("zh" , normalizer_zh_tn ())
43
+
38
44
logger .info ("Initializing ChatTTS..." )
39
- if chat .load ():
45
+ if chat .load (source = "huggingface" ):
40
46
logger .info ("Models loaded successfully." )
41
47
else :
42
48
logger .error ("Models load failed." )
43
49
sys .exit (1 )
44
50
45
51
52
+ @app .exception_handler (RequestValidationError )
53
+ async def validation_exception_handler (request , exc : RequestValidationError ):
54
+ logger .error (f"Validation error: { exc .errors ()} " )
55
+ return JSONResponse (status_code = 422 , content = {"detail" : exc .errors ()})
56
+
57
+
46
58
class ChatTTSParams (BaseModel ):
47
59
text : list [str ]
48
60
stream : bool = False
@@ -52,7 +64,7 @@ class ChatTTSParams(BaseModel):
52
64
use_decoder : bool = True
53
65
do_text_normalization : bool = True
54
66
do_homophone_replacement : bool = False
55
- params_refine_text : ChatTTS .Chat .RefineTextParams
67
+ params_refine_text : ChatTTS .Chat .RefineTextParams = None
56
68
params_infer_code : ChatTTS .Chat .InferCodeParams
57
69
58
70
0 commit comments