Skip to content

Commit 32c1b95

Browse files
committed
feat: enhance CLI with additional options for processing and tracking, refactor some part of the code in utils
1 parent 04e4eff commit 32c1b95

File tree

3 files changed

+348
-90
lines changed

3 files changed

+348
-90
lines changed

src/prismtoolbox/cli/main.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
from typing import Annotated
3+
14
import sys
25
import typer
36

@@ -23,15 +26,24 @@
2326
app.add_typer(app_preprocessing, name="preprocessing", help="Preprocessing commands of the PrismToolBox.")
2427

2528
@app.callback()
26-
def main(verbose: int = typer.Option(0, "--verbose", "-v", count=True, help="Increase verbosity")):
29+
def main(
30+
ctx: typer.Context,
31+
verbose: Annotated[int, typer.Option("-v", count=True, help="Increase verbosity.", show_default=False)] = 0,
32+
skip_errors: Annotated[bool, typer.Option(help="Skip slides which raised an error during processing.")] = False,
33+
skip_existing: Annotated[bool, typer.Option(help="Skip existing files.")] = True,
34+
tracking_file: Annotated[bool, typer.Option(help="Enable tracking file for processing status.")] = False,
35+
):
2736
if verbose >= 2:
2837
level = logging.DEBUG
2938
elif verbose == 1:
3039
level = logging.INFO
3140
else:
3241
level = logging.WARNING
3342
logging.basicConfig(level=level)
34-
43+
ctx.ensure_object(dict)
44+
ctx.obj["skip_errors"] = skip_errors
45+
ctx.obj["skip_existing"] = skip_existing
46+
ctx.obj["tracking_file"] = tracking_file
3547

3648
@app.command()
3749
def version():
Lines changed: 139 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
1+
from typing import Annotated
2+
13
import os
24
import typer
35
import logging
6+
import pandas as pd
47

5-
from enum import Enum
6-
from typing import Annotated
78
from pathlib import Path
89

9-
from .utils import load_config_file
10+
from .utils import (
11+
Engine,
12+
ContoursExtension,
13+
PatchExtension,
14+
PatchMode,
15+
load_config_file,
16+
update_tracking_df,
17+
extract_contours_from_slide,
18+
extract_patches_from_slide,
19+
)
1020

1121
log = logging.getLogger(__name__)
1222

@@ -16,26 +26,10 @@
1626
no_args_is_help=True,
1727
add_completion=True,
1828
rich_markup_mode="rich")
19-
20-
class Engine(str, Enum):
21-
openslide = "openslide"
22-
tiffslide = "tiffslide"
23-
24-
class ContoursExtension(str, Enum):
25-
geojson = "geojson"
26-
pickle = "pickle"
27-
28-
class PatchExtension(str, Enum):
29-
h5 = "h5"
30-
geojson = "geojson"
31-
32-
class PatchMode(str, Enum):
33-
contours = "contours"
34-
roi = "roi"
35-
all = "all"
36-
29+
3730
@app_preprocessing.command(no_args_is_help=True)
3831
def contour(
32+
ctx: typer.Context,
3933
slide_directory: Annotated[str, typer.Argument(
4034
help="Path to the directory containing the files.")],
4135
results_directory: Annotated[str, typer.Argument(
@@ -46,19 +40,18 @@ def contour(
4640
annotations_directory: Annotated[str | None, typer.Option(
4741
help="Path to the directory containing the annotations."
4842
)] = None,
49-
contours_exts: Annotated[list[ContoursExtension], typer.Option(
50-
help="File extension for the contours annotations.",
43+
contour_exts: Annotated[list[ContoursExtension], typer.Option(
44+
help="File extensions for the contours annotations.",
5145
case_sensitive=False,
5246
)] = [ContoursExtension.pickle],
53-
config_file: Annotated[str, typer.Option(
47+
config_file: Annotated[str | None, typer.Option(
5448
help="Path to the configuration file for tissue extraction."
5549
)] = None,
5650
visualize: Annotated[bool, typer.Option(
5751
help="Visualize the contours extracted.",
5852
)] = False,
5953
):
6054
"""Extract tissue contours from the slides in a specified directory."""
61-
import prismtoolbox as ptb
6255

6356
# Set default parameters
6457
params_detect_tissue = {
@@ -75,7 +68,7 @@ def contour(
7568
"line_thickness": 50,
7669
}
7770

78-
if not os.path.exists(config_file):
71+
if config_file is None or not os.path.exists(config_file):
7972
log.info(f"Using default parameters for tissue extraction.")
8073
else:
8174
log.info(f"Using parameters from config file: {config_file}")
@@ -95,31 +88,62 @@ def contour(
9588
directory_visualize = os.path.join(results_directory, f"contoured_images")
9689
Path(directory_visualize).mkdir(parents=True, exist_ok=True)
9790

91+
if ctx.obj.get("tracking_file", True):
92+
# Create an empty tracking file to keep track of the processing status
93+
tracking_file_path = os.path.join(results_directory, "tracking_contouring.csv")
94+
if not os.path.exists(tracking_file_path):
95+
# Create the tracking file if it does not exist
96+
col_names = ["file_name", "nb contours"] + list(params_detect_tissue.keys()) + ["status", "error message", "timestamp"]
97+
tracking_df = pd.DataFrame(columns=col_names)
98+
tracking_df.to_csv(tracking_file_path, index=False)
99+
98100
# Iterate over the files in the directory
99101
for file_name in os.listdir(slide_directory):
100-
# Load the image
101-
WSI_object = ptb.WSI(os.path.join(slide_directory, file_name), engine=engine)
102-
print(f"Processing {WSI_object.slide_name}...")
103-
104-
if f"{WSI_object.slide_name}.pkl" in os.listdir(directory_contours):
105-
continue
106-
107-
# Extract the contours from the image
108-
WSI_object.detect_tissue(**params_detect_tissue)
109-
# Apply pathologist annotations
110-
if annotations_directory is not None:
111-
WSI_object.apply_pathologist_annotations(os.path.join(annotations_directory, f"{WSI_object.slide_name}.geojson"))
112-
# Save extracted contours
113-
for contours_ext in contours_exts:
114-
WSI_object.save_tissue_contours(directory_contours, file_format=contours_ext)
115-
if visualize:
116-
# Visualize the extracted contours on the tissue
117-
img = WSI_object.visualize(**params_visualize_WSI)
118-
img.save(os.path.join(directory_visualize, f"{WSI_object.slide_name}.jpg"))
102+
slide_path = os.path.join(slide_directory, file_name)
103+
try:
104+
already_processed, nb_contours = extract_contours_from_slide(
105+
slide_path,
106+
engine,
107+
directory_contours,
108+
params_detect_tissue,
109+
contour_exts,
110+
annotations_directory,
111+
visualize,
112+
directory_visualize,
113+
params_visualize_WSI,
114+
ctx.obj["skip_existing"]
115+
)
116+
if ctx.obj.get("tracking_file", True):
117+
# Update the tracking file with the processing status
118+
update_tracking_df(
119+
tracking_file_path,
120+
file_name,
121+
("contours", nb_contours),
122+
params_dict=params_detect_tissue,
123+
already_processed=already_processed,
124+
)
125+
except Exception as e:
126+
if ctx.obj.get("tracking_file", True):
127+
# Update the tracking file with the processing status
128+
update_tracking_df(
129+
tracking_file_path,
130+
file_name,
131+
nb_objects=("contours", 0),
132+
params_dict=params_detect_tissue,
133+
error_message=str(e)
134+
)
135+
if ctx.obj.get("skip_errors", True):
136+
log.warning(f"Skipping slide {file_name}: {e}")
137+
continue
138+
else:
139+
log.error(f"Error processing slide {file_name}:")
140+
raise
141+
119142
print("Contours extracted and saved successfully.")
120143

121144
@app_preprocessing.command(no_args_is_help=True)
122145
def patchify(
146+
ctx: typer.Context,
123147
slide_directory: Annotated[str, typer.Argument(
124148
help="Path to the directory containing the files.")],
125149
results_directory: Annotated[str, typer.Argument(
@@ -134,34 +158,36 @@ def patchify(
134158
help="Engine to use for reading the slides.",
135159
case_sensitive=False)] = Engine.openslide,
136160
patch_exts: Annotated[list[PatchExtension], typer.Option(
137-
help="File extension for the patches.",
161+
help="File extensions for the patches.",
138162
case_sensitive=False,
139163
)] = [PatchExtension.h5],
140164
mode: Annotated[PatchMode, typer.Option(
141165
help="The mode to use for patch extraction. Possible values are 'contours', 'roi', and 'all'.",
142166
case_sensitive=False,
143167
)] = PatchMode.all,
144-
config_file: Annotated[str, typer.Option(
168+
config_file: Annotated[str | None, typer.Option(
145169
help="Path to the configuration file for patch extraction."
146170
)] = None,
147171
stitch: Annotated[bool, typer.Option(
148172
help="Whether to stitch the extracted patches into a single image for visualization.",
149173
)] = True,
174+
force_patch_extraction: Annotated[bool, typer.Option(
175+
help="Force patch extraction using mode 'all', even if no contours were detected when"
176+
"using 'contours' mode (useful if problem to extract contours on some slides).",
177+
)] = False,
178+
num_workers: Annotated[int, typer.Option(
179+
help="Number of workers to use for parallel processing.",
180+
min=1,
181+
)] = 10,
150182
):
151183
"""Extract patches from the slides in a specified directory."""
152-
import prismtoolbox as ptb
153-
assert mode == "contours" and contours_directory is not None, \
154-
"If the mode is 'contours', you must provide a directory with contours annotations. " \
155-
"Please use the `contour` command to extract contours first."
156-
assert mode == "roi" and roi_csv is not None, \
157-
"If the mode is 'roi', you must provide a file with the ROI coordinates."
158-
159184
# Set default parameters
160185
params_patches = {"patch_size": 256, "patch_level": 0, "overlap": 0,
161-
"units": ["px", "px"], "contours_mode": "four_pt", "rgb_threshs": [2, 240], "percentages": [0.6, 0.9]}
186+
"units": ["px", "px"], "contours_mode": "four_pt", "rgb_threshs": [2, 240],
187+
"percentages": [0.6, 0.9]}
162188
params_stitch_WSI = {"vis_level": 4, "draw_grid": False}
163189

164-
if not os.path.exists(config_file):
190+
if config_file is None or not os.path.exists(config_file):
165191
log.info(f"Using default parameters for tissue extraction.")
166192
else:
167193
log.info(f"Using parameters from config file: {config_file}")
@@ -173,7 +199,7 @@ def patchify(
173199
params_stitch_WSI = load_config_file(config_file,
174200
dict_to_update=params_stitch_WSI,
175201
key_to_check='stitching_settings')
176-
202+
params_patches = {**{"mode": mode}, **params_patches}
177203
# Path to the directory where the patches will be saved
178204
directory_patches = os.path.join(results_directory,
179205
f"patches_{params_patches['patch_size']}_overlap"
@@ -186,34 +212,62 @@ def patchify(
186212
f"stitched_images_{params_patches['patch_size']}_overlap"
187213
f"_{params_patches['overlap']}")
188214
Path(directory_stitch).mkdir(parents=True, exist_ok=True)
189-
215+
216+
if ctx.obj.get("tracking_file", True):
217+
# Create an empty tracking file to keep track of the processing status
218+
tracking_file_path = os.path.join(results_directory, "tracking_patching.csv")
219+
if not os.path.exists(tracking_file_path):
220+
# Create the tracking file if it does not exist
221+
col_names = ["file_name", "nb patches"] + list(params_patches.keys()) + ["status", "error message", "timestamp"]
222+
tracking_df = pd.DataFrame(columns=col_names)
223+
tracking_df.to_csv(tracking_file_path, index=False)
224+
190225
# Iterate over the files in the directory
191226
for file_name in os.listdir(slide_directory):
192-
# Load the image
193-
WSI_object = ptb.WSI(os.path.join(slide_directory, file_name), engine=engine)
194-
print(f"Processing {WSI_object.slide_name}...")
195-
196-
if f"{WSI_object.slide_name}.h5" in os.listdir(directory_patches):
197-
continue
198-
199-
if mode == "roi":
200-
# Set the region of interest for the image
201-
WSI_object.set_roi(rois_df_path=roi_csv)
202-
203-
elif mode == "contours":
204-
# Load the contours for the image
205-
WSI_object.load_tissue_contours(os.path.join(contours_directory, f"{WSI_object.slide_name}.pkl"))
206-
207-
# Extract patches from the contours
208-
WSI_object.extract_patches(mode=mode, **params_patches)
209-
# Save the extracted patches
210-
for patch_ext in patch_exts:
211-
WSI_object.save_patches(directory_patches, file_format=patch_ext)
212-
if stitch:
213-
# Stitch the extracted patches
214-
if params_stitch_WSI["vis_level"] >= len(WSI_object.level_dimensions):
227+
slide_path = os.path.join(slide_directory, file_name)
228+
try:
229+
already_processed, nb_patches = extract_patches_from_slide(
230+
slide_path,
231+
engine,
232+
directory_patches,
233+
params_patches,
234+
patch_exts,
235+
roi_csv,
236+
contours_directory,
237+
stitch,
238+
directory_stitch,
239+
params_stitch_WSI,
240+
force_patch_extraction,
241+
ctx.obj["skip_existing"],
242+
num_workers=num_workers
243+
)
244+
if ctx.obj.get("tracking_file", True):
245+
# Update the tracking file with the processing status
246+
update_tracking_df(
247+
tracking_file_path,
248+
file_name,
249+
("patches", nb_patches),
250+
params_dict=params_patches,
251+
already_processed=already_processed
252+
)
253+
except Exception as e:
254+
if ctx.obj.get("tracking_file", True):
255+
# Update the tracking file with the processing status
256+
update_tracking_df(
257+
tracking_file_path,
258+
file_name,
259+
("patches", 0),
260+
params_dict=params_patches,
261+
error_message=str(e)
262+
)
263+
if ctx.obj.get("skip_errors", True):
264+
log.warning(f"Skipping slide {file_name}: {e}")
215265
continue
216-
img = WSI_object.stitch(**params_stitch_WSI)
217-
218-
img.save(os.path.join(directory_stitch, f"{WSI_object.slide_name}.jpg"))
219-
print("Done !")
266+
else:
267+
log.error(f"Error processing slide {file_name}:")
268+
raise
269+
if params_patches["mode"] != mode:
270+
log.warning(f"Mode {params_patches['mode']} was used for patch extraction instead of {mode}." \
271+
"Reverting back to mode {mode} for the next slide.")
272+
params_patches["mode"] = mode
273+
print("Patches extracted and saved successfully.")

0 commit comments

Comments
 (0)