Skip to content

Commit aaea2ae

Browse files
authored
fix: allow not params refine text, and load normalizers to handle chinese numbers (#865)
1 parent a933b66 commit aaea2ae

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

examples/api/main.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323

2424

2525
from pydantic import BaseModel
26-
26+
from fastapi.exceptions import RequestValidationError
27+
from fastapi.responses import JSONResponse
28+
from tools.normalizer.en import normalizer_en_nemo_text
29+
from tools.normalizer.zh import normalizer_zh_tn
2730

2831
logger = get_logger("Command")
2932

@@ -35,14 +38,23 @@ async def startup_event():
3538
global chat
3639

3740
chat = ChatTTS.Chat(get_logger("ChatTTS"))
41+
chat.normalizer.register("en", normalizer_en_nemo_text())
42+
chat.normalizer.register("zh", normalizer_zh_tn())
43+
3844
logger.info("Initializing ChatTTS...")
39-
if chat.load():
45+
if chat.load(source="huggingface"):
4046
logger.info("Models loaded successfully.")
4147
else:
4248
logger.error("Models load failed.")
4349
sys.exit(1)
4450

4551

52+
@app.exception_handler(RequestValidationError)
53+
async def validation_exception_handler(request, exc: RequestValidationError):
54+
logger.error(f"Validation error: {exc.errors()}")
55+
return JSONResponse(status_code=422, content={"detail": exc.errors()})
56+
57+
4658
class ChatTTSParams(BaseModel):
4759
text: list[str]
4860
stream: bool = False
@@ -52,7 +64,7 @@ class ChatTTSParams(BaseModel):
5264
use_decoder: bool = True
5365
do_text_normalization: bool = True
5466
do_homophone_replacement: bool = False
55-
params_refine_text: ChatTTS.Chat.RefineTextParams
67+
params_refine_text: ChatTTS.Chat.RefineTextParams = None
5668
params_infer_code: ChatTTS.Chat.InferCodeParams
5769

5870

0 commit comments

Comments
 (0)