Skip to content

Commit 8575cae

Browse files
authored
Merge pull request #202 from mlexchange/integrate-tiled-viewer
Replace data section dropdown with `tiled_viewer`
2 parents 0128353 + 830ada4 commit 8575cae

File tree

7 files changed

+138
-109
lines changed

7 files changed

+138
-109
lines changed

callbacks/control_bar.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
import time
55
import uuid
6+
from urllib.parse import urlparse
67

78
import dash_mantine_components as dmc
89
import plotly.express as px
@@ -40,6 +41,36 @@
4041
open(EXPORT_FILE_PATH, "w").close()
4142

4243

44+
@callback(
45+
Output("image-uri", "value"),
46+
Input("tiled-image-selector", "selectedLinks"),
47+
prevent_initial_call=True,
48+
)
49+
def update_selected_image_uri(selected_links):
50+
print(f"DEBUG - Selected image links: {selected_links}") # Debug print
51+
52+
if selected_links:
53+
# Extract the 'self' key from the dictionary
54+
if isinstance(selected_links, dict) and "self" in selected_links:
55+
# Extract the full URI from the selected links
56+
full_uri = selected_links.get("self", "")
57+
58+
# Parse the URI and extract the path
59+
parsed_uri = urlparse(full_uri)
60+
extracted_path = parsed_uri.path
61+
62+
# Remove the common prefix from the path
63+
base_data_path = urlparse(tiled_datasets.data_tiled_uri).path
64+
if extracted_path.startswith(base_data_path):
65+
extracted_path = extracted_path[len(base_data_path) :]
66+
67+
# Clean up the extracted path and return
68+
return extracted_path.strip("/")
69+
# Fall back to string representation if we can't extract 'self'
70+
return str(selected_links)
71+
return ""
72+
73+
4374
@callback(
4475
Output("current-class-selection", "data", allow_duplicate=True),
4576
Output("notifications-container", "children", allow_duplicate=True),
@@ -746,10 +777,10 @@ def export_annotation(n_clicks, all_annotations, global_store):
746777
Input("save-annotations", "n_clicks"),
747778
State("annotation-store", "data"),
748779
State({"type": "annotation-class-store", "index": ALL}, "data"),
749-
State("project-name-src", "value"),
780+
State("image-uri", "value"),
750781
prevent_initial_call=True,
751782
)
752-
def save_data(n_clicks, global_store, all_annotations, image_src):
783+
def save_data(n_clicks, global_store, all_annotations, image_uri):
753784
"""This callback is responsible for saving the annotation data to the store"""
754785
if not n_clicks:
755786
raise PreventUpdate
@@ -758,7 +789,7 @@ def save_data(n_clicks, global_store, all_annotations, image_src):
758789
# TODO: save store to the server file-user system, this will be changed to DB later
759790
export_data = {
760791
"user": USER_NAME,
761-
"source": image_src,
792+
"source": image_uri,
762793
"time": time.strftime("%Y-%m-%d-%H:%M:%S"),
763794
"data": json.dumps(all_annotations),
764795
}
@@ -787,10 +818,10 @@ def toggle_save_load_modal(n_clicks, opened):
787818
@callback(
788819
Output("load-annotations-server-container", "children"),
789820
Input("open-data-management-modal-button", "n_clicks"),
790-
State("project-name-src", "value"),
821+
State("image-uri", "value"),
791822
prevent_initial_call=True,
792823
)
793-
def populate_load_annotations_dropdown_menu_options(modal_opened, image_src):
824+
def populate_load_annotations_dropdown_menu_options(modal_opened, image_uri):
794825
"""
795826
This callback populates dropdown window with all saved annotation options for the given project name.
796827
It then creates buttons with info about the save, which when clicked, loads the data from the server.
@@ -799,7 +830,7 @@ def populate_load_annotations_dropdown_menu_options(modal_opened, image_src):
799830
raise PreventUpdate
800831

801832
data = tiled_masks.DEV_load_exported_json_data(
802-
EXPORT_FILE_PATH, USER_NAME, image_src
833+
EXPORT_FILE_PATH, USER_NAME, image_uri
803834
)
804835
if not data:
805836
return "No annotations found for the selected data source."
@@ -827,11 +858,11 @@ def populate_load_annotations_dropdown_menu_options(modal_opened, image_src):
827858
Output("annotation-class-container", "children", allow_duplicate=True),
828859
Output("data-management-modal", "opened", allow_duplicate=True),
829860
Input({"type": "load-server-annotations", "index": ALL}, "n_clicks"),
830-
State("project-name-src", "value"),
861+
State("image-uri", "value"),
831862
State("image-selection-slider", "value"),
832863
prevent_initial_call=True,
833864
)
834-
def load_and_apply_selected_annotations(selected_annotation, image_src, img_idx):
865+
def load_and_apply_selected_annotations(selected_annotation, image_uri, img_idx):
835866
"""
836867
This callback is responsible for loading and applying the selected annotations when user selects them from the modal.
837868
"""
@@ -845,7 +876,7 @@ def load_and_apply_selected_annotations(selected_annotation, image_src, img_idx)
845876

846877
# TODO : when quering from the server, load (data) for user, source, time
847878
data = tiled_masks.DEV_load_exported_json_data(
848-
EXPORT_FILE_PATH, USER_NAME, image_src
879+
EXPORT_FILE_PATH, USER_NAME, image_uri
849880
)
850881
data = tiled_masks.DEV_filter_json_data_by_timestamp(
851882
data, str(selected_annotation_timestamp)
@@ -875,16 +906,6 @@ def open_controls_drawer(n_clicks, is_opened):
875906
return no_update, no_update
876907

877908

878-
@callback(Output("project-name-src", "data"), Input("refresh-tiled", "n_clicks"))
879-
def refresh_data_client(refresh_tiled):
880-
if refresh_tiled:
881-
tiled_datasets.refresh_data_client()
882-
data_options = [
883-
item for item in tiled_datasets.get_data_project_names() if "seg" not in item
884-
]
885-
return data_options
886-
887-
888909
@callback(
889910
Output("show-result-overlay-toggle", "checked"),
890911
Output("show-result-overlay-toggle", "disabled"),

callbacks/image_viewer.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def hide_show_segmentation_overlay(toggle_seg_result, opacity):
6565
Input("show-result-overlay-toggle", "checked"),
6666
Input("annotated-slices-selector", "value"),
6767
State({"type": "annotation-class-store", "index": ALL}, "data"),
68-
State("project-name-src", "value"),
68+
State("image-uri", "value"),
6969
State("annotation-store", "data"),
7070
State("image-metadata", "data"),
7171
State("screen-size", "data"),
@@ -81,7 +81,7 @@ def render_image(
8181
toggle_seg_result,
8282
slice_selection,
8383
all_annotation_class_store,
84-
project_name,
84+
image_uri,
8585
annotation_store,
8686
image_metadata,
8787
screen_size,
@@ -111,7 +111,7 @@ def render_image(
111111

112112
if image_idx:
113113
image_idx -= 1 # slider starts at 1, so subtract 1 to get the correct index
114-
tf = tiled_datasets.get_data_sequence_by_name(project_name)[image_idx]
114+
tf = tiled_datasets.get_data_sequence_by_trimmed_uri(image_uri)[image_idx]
115115
# Auto-scale data
116116
low = np.percentile(tf.ravel(), 1)
117117
high = np.percentile(tf.ravel(), 99)
@@ -135,15 +135,15 @@ def render_image(
135135
if image_idx in annotation_indices:
136136
# Will not return an error since we already checked if image_idx is in the list
137137
mapped_index = annotation_indices.index(image_idx)
138-
result = tiled_results.get_data_by_trimmed_uri(
138+
result = tiled_results.get_data_slice_by_trimmed_uri(
139139
seg_result["seg_result_trimmed_uri"], slice=mapped_index
140140
)
141141
else:
142142
result = None
143143
# if mask_idx is not given in the results,
144144
# then the result stems from inference on the full data set
145145
else:
146-
result = tiled_results.get_data_by_trimmed_uri(
146+
result = tiled_results.get_data_slice_by_trimmed_uri(
147147
seg_result["seg_result_trimmed_uri"], slice=image_idx
148148
)
149149
else:
@@ -219,8 +219,8 @@ def render_image(
219219

220220
# No update is needed for the 'children' of the control components
221221
# since we just want to trigger the loading overlay with this callback
222-
if project_name != image_metadata["name"] or image_metadata["name"] is None:
223-
curr_image_metadata = {"size": tf.shape, "name": project_name}
222+
if image_uri != image_metadata["name"] or image_metadata["name"] is None:
223+
curr_image_metadata = {"size": tf.shape, "name": image_uri}
224224
else:
225225
curr_image_metadata = dash.no_update
226226
return (
@@ -393,7 +393,7 @@ def update_viewfinder(relayout_data, annotation_store):
393393
}
394394
""",
395395
Output("image-viewer-loading", "className", allow_duplicate=True),
396-
Input("project-name-src", "value"),
396+
Input("image-uri", "value"),
397397
prevent_initial_call=True,
398398
)
399399
clientside_callback(
@@ -478,7 +478,7 @@ def locally_store_annotations(
478478
Output(
479479
{"type": "annotation-class-store", "index": ALL}, "data", allow_duplicate=True
480480
),
481-
Input("project-name-src", "value"),
481+
Input("image-uri", "value"),
482482
State({"type": "annotation-class-store", "index": ALL}, "data"),
483483
prevent_initial_call=True,
484484
)
@@ -498,17 +498,17 @@ def clear_annotations_on_dataset_change(change_project, all_annotation_class_sto
498498
Output("image-selection-slider", "value"),
499499
Output("image-selection-slider", "disabled"),
500500
Output("annotation-store", "data"),
501-
Input("project-name-src", "value"),
501+
Input("image-uri", "value"),
502502
State("annotation-store", "data"),
503503
)
504-
def update_slider_values(project_name, annotation_store):
504+
def update_slider_values(image_uri, annotation_store):
505505
"""
506506
When the data source is loaded, this callback will set the slider values and chain call
507507
"update_selection_and_image" callback which will update image and slider selection component.
508508
"""
509-
# Retrieve data shape if project_name is valid and points to a 3d array
509+
# Retrieve data shape if image_uri is valid and points to a 3d array
510510
data_shape = (
511-
tiled_datasets.get_data_shape_by_name(project_name) if project_name else None
511+
tiled_datasets.get_data_shape_by_trimmed_uri(image_uri) if image_uri else None
512512
)
513513
disable_slider = data_shape is None
514514
if not disable_slider:

callbacks/segmentation.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@
174174
Input("run-train", "n_clicks"),
175175
State("annotation-store", "data"),
176176
State({"type": "annotation-class-store", "index": ALL}, "data"),
177-
State("project-name-src", "value"),
177+
State("image-uri", "value"),
178178
State("model-parameters", "children"),
179179
State("model-list", "value"),
180180
State("job-name", "value"),
@@ -184,7 +184,7 @@ def run_train(
184184
n_clicks,
185185
global_store,
186186
all_annotations,
187-
project_name,
187+
image_uri,
188188
model_parameter_container,
189189
model_name,
190190
job_name,
@@ -208,7 +208,7 @@ def run_train(
208208
)
209209
return notification, no_update
210210
mask_uri, num_classes, mask_error_message = tiled_masks.save_annotations_data(
211-
global_store, all_annotations, project_name
211+
global_store, all_annotations, image_uri
212212
)
213213
model_parameters["num_classes"] = num_classes
214214
model_parameters["network"] = model_name
@@ -227,7 +227,7 @@ def run_train(
227227
"%Y/%m/%d %H:%M:%S"
228228
)
229229
flow_run_name = f"{job_name} {current_time}"
230-
data_uri = tiled_datasets.get_data_uri_by_name(project_name)
230+
data_uri = tiled_datasets.get_data_uri_by_trimmed_uri(image_uri)
231231
io_parameters = assemble_io_parameters_from_uris(data_uri, mask_uri)
232232
io_parameters["uid_retrieve"] = ""
233233
io_parameters["models_dir"] = RESULTS_DIR
@@ -263,7 +263,7 @@ def run_train(
263263
FLOW_NAME,
264264
parameters=TRAIN_PARAMS_EXAMPLE,
265265
flow_run_name=flow_run_name,
266-
tags=PREFECT_TAGS + ["train", project_name],
266+
tags=PREFECT_TAGS + ["train", image_uri],
267267
)
268268
job_message = f"Job has been succesfully submitted with uid: {job_uid} and mask uri: {mask_uri}"
269269
notification_color = "indigo"
@@ -287,7 +287,7 @@ def run_train(
287287
Input("run-inference", "n_clicks"),
288288
State("train-job-selector", "value"),
289289
State({"type": "annotation-class-store", "index": ALL}, "data"),
290-
State("project-name-src", "value"),
290+
State("image-uri", "value"),
291291
State("model-parameters", "children"),
292292
State("model-list", "value"),
293293
prevent_initial_call=True,
@@ -296,7 +296,7 @@ def run_inference(
296296
n_clicks,
297297
train_job_id,
298298
all_annotations,
299-
project_name,
299+
image_uri,
300300
model_parameter_container,
301301
model_name,
302302
):
@@ -324,7 +324,7 @@ def run_inference(
324324
model_parameters["network"] = model_name
325325

326326
# Set io_parameters for inference, there will be no mask
327-
data_uri = tiled_datasets.get_data_uri_by_name(project_name)
327+
data_uri = tiled_datasets.get_data_uri_by_trimmed_uri(image_uri)
328328
io_parameters = assemble_io_parameters_from_uris(data_uri, "")
329329
io_parameters["uid_retrieve"] = ""
330330
io_parameters["models_dir"] = RESULTS_DIR
@@ -380,7 +380,7 @@ def run_inference(
380380
FLOW_NAME,
381381
parameters=INFERENCE_PARAMS_EXAMPLE,
382382
flow_run_name=flow_run_name,
383-
tags=PREFECT_TAGS + ["inference", project_name],
383+
tags=PREFECT_TAGS + ["inference", image_uri],
384384
)
385385
job_message = (
386386
f"Job has been succesfully submitted with uid: {job_uid}"
@@ -445,10 +445,10 @@ def check_train_job(n_intervals):
445445
Output("infra-state", "data", allow_duplicate=True),
446446
Input("model-check", "n_intervals"),
447447
Input("train-job-selector", "value"),
448-
State("project-name-src", "value"),
448+
State("image-uri", "value"),
449449
prevent_initial_call=True,
450450
)
451-
def check_inference_job(n_intervals, train_job_id, project_name):
451+
def check_inference_job(n_intervals, train_job_id, image_uri):
452452
"""
453453
This callback populates the inference job selector dropdown with job names and ids from Prefect.
454454
The list of jobs is filtered by the selected train job in the train job selector dropdown.
@@ -472,7 +472,7 @@ def check_inference_job(n_intervals, train_job_id, project_name):
472472
if job_name is not None:
473473
data = query_flow_runs(
474474
flow_run_name=job_name,
475-
tags=PREFECT_TAGS + ["inference", project_name],
475+
tags=PREFECT_TAGS + ["inference", image_uri],
476476
)
477477
infra_state = no_update
478478

@@ -489,7 +489,7 @@ def check_inference_job(n_intervals, train_job_id, project_name):
489489

490490
def populate_segmentation_results(
491491
job_id,
492-
project_name,
492+
image_uri,
493493
job_type="training",
494494
):
495495
"""
@@ -498,7 +498,7 @@ def populate_segmentation_results(
498498
"""
499499
# Nothing has been selected is job_id is None
500500
if job_id is not None:
501-
data_uri = tiled_datasets.get_data_uri_by_name(project_name)
501+
data_uri = tiled_datasets.get_data_uri_by_trimmed_uri(image_uri)
502502
# Only returns the name if the job finished successfully
503503
job_name = get_flow_run_name(job_id)
504504
if job_name is not None:
@@ -515,7 +515,7 @@ def populate_segmentation_results(
515515
# First refresh the data client,
516516
# the root container may not yet have existed on startup
517517
tiled_results.refresh_data_client()
518-
result_container = tiled_results.get_data_by_trimmed_uri(
518+
result_container = tiled_results.get_data_slice_by_trimmed_uri(
519519
expected_result_uri
520520
)
521521
except Exception:
@@ -551,16 +551,16 @@ def populate_segmentation_results(
551551
Output("seg-results-train-store", "data"),
552552
Output("dvc-training-stats-link", "href"),
553553
Input("train-job-selector", "value"),
554-
State("project-name-src", "value"),
554+
State("image-uri", "value"),
555555
prevent_initial_call=True,
556556
)
557-
def populate_segmentation_results_train(train_job_id, project_name):
557+
def populate_segmentation_results_train(train_job_id, image_uri):
558558
"""
559559
This callback populates the segmentation results store based on the uids
560560
if the training job and the inference job.
561561
"""
562562
notification, result_store, segment_job_id = populate_segmentation_results(
563-
train_job_id, project_name, "training"
563+
train_job_id, image_uri, "training"
564564
)
565565
if segment_job_id is not None:
566566
results_link = f"/results/{segment_job_id}/dvc_metrics/report.html"
@@ -574,16 +574,16 @@ def populate_segmentation_results_train(train_job_id, project_name):
574574
Output("notifications-container", "children", allow_duplicate=True),
575575
Output("seg-results-inference-store", "data"),
576576
Input("inference-job-selector", "value"),
577-
State("project-name-src", "value"),
577+
State("image-uri", "value"),
578578
prevent_initial_call=True,
579579
)
580-
def populate_segmentation_results_inference(inference_job_id, project_name):
580+
def populate_segmentation_results_inference(inference_job_id, image_uri):
581581
"""
582582
This callback populates the segmentation results store based on the uids
583583
if the training job and the inference job.
584584
"""
585585
notification, result_store, _ = populate_segmentation_results(
586-
inference_job_id, project_name, "inference"
586+
inference_job_id, image_uri, "inference"
587587
)
588588
return (
589589
notification,

0 commit comments

Comments
 (0)