From 4ad9c41c1696923b5b1735cd7a21fb908de1a3fd Mon Sep 17 00:00:00 2001 From: linyq Date: Fri, 22 Nov 2024 18:30:49 +0800 Subject: [PATCH] =?UTF-8?q?refactor(webui):=20=E4=BC=98=E5=8C=96=E8=A7=86?= =?UTF-8?q?=E8=A7=89=E5=88=86=E6=9E=90=E6=89=B9=E6=AC=A1=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 提取 vision_batch_size 到单独变量,提高代码可读性 - 使用 vision_batch_size 替代多次调用 config(frames.get("vision_batch_size") - 添加调试日志,记录批次数量和每批次的图片数量 --- .github/workflows/dockerImageBuild.yml | 1 + app/utils/vision_analyzer.py | 4 +++- webui/components/script_settings.py | 9 +++++---- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/dockerImageBuild.yml b/.github/workflows/dockerImageBuild.yml index 3089eb0..6c4755e 100644 --- a/.github/workflows/dockerImageBuild.yml +++ b/.github/workflows/dockerImageBuild.yml @@ -3,6 +3,7 @@ name: build_docker on: release: types: [created] # 表示在创建新的 Release 时触发 + workflow_dispatch: jobs: build_docker: diff --git a/app/utils/vision_analyzer.py b/app/utils/vision_analyzer.py index 8024729..06342d7 100644 --- a/app/utils/vision_analyzer.py +++ b/app/utils/vision_analyzer.py @@ -55,7 +55,7 @@ async def _generate_content_with_retry(self, prompt, batch): async def analyze_images(self, images: Union[List[str], List[PIL.Image.Image]], prompt: str, - batch_size: int = 5) -> List[Dict]: + batch_size: int) -> List[Dict]: """批量分析多张图片""" try: # 加载图片 @@ -82,6 +82,8 @@ async def analyze_images(self, results = [] total_batches = (len(images) + batch_size - 1) // batch_size + logger.debug(f"共 {total_batches} 个批次,每批次 {batch_size} 张图片") + with tqdm(total=total_batches, desc="分析进度") as pbar: for i in range(0, len(images), batch_size): batch = images[i:i + batch_size] diff --git a/webui/components/script_settings.py b/webui/components/script_settings.py index 4edf5c6..30c23d3 100644 --- a/webui/components/script_settings.py +++ b/webui/components/script_settings.py @@ -417,11 +417,12 @@ def update_progress(progress: float, message: str = ""): asyncio.set_event_loop(loop) # 执行异步分析 + vision_batch_size = st.session_state.get('vision_batch_size') or config.frames.get("vision_batch_size") results = loop.run_until_complete( analyzer.analyze_images( images=keyframe_files, prompt=config.app.get('vision_analysis_prompt'), - batch_size=config.frames.get("vision_batch_size", st.session_state.get('vision_batch_size', 5)) + batch_size=vision_batch_size ) ) loop.close() @@ -437,8 +438,8 @@ def update_progress(progress: float, message: str = ""): if 'error' in result: logger.warning(f"批次 {result['batch_index']} 处理出现警告: {result['error']}") continue - - batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5)) + # 获取当前批次的文件列表 + batch_files = get_batch_files(keyframe_files, result, vision_batch_size) logger.debug(f"批次 {result['batch_index']} 处理完成,共 {len(batch_files)} 张图片") logger.debug(batch_files) @@ -477,7 +478,7 @@ def update_progress(progress: float, message: str = ""): if 'error' in result: continue - batch_files = get_batch_files(keyframe_files, result, config.frames.get("vision_batch_size", 5)) + batch_files = get_batch_files(keyframe_files, result, vision_batch_size) _, _, timestamp_range = get_batch_timestamps(batch_files, prev_batch_files) frame_content = {