Skip to content

Commit

Permalink
Edit the img viewer to make it run faster by moving img operations to…
Browse files Browse the repository at this point in the history
… a function decorated using st.cache_data
  • Loading branch information
gurayerus committed Sep 15, 2024
1 parent 57b457c commit 1c8affa
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 63 deletions.
125 changes: 68 additions & 57 deletions src/NiChart_Viewer/src/pages/view_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
# st.session_state.pid = 1
# st.session_state.instantiated = True

VIEWS = ["axial", "sagittal", "coronal"]
VIEW_AXES = [0, 1, 2]
VIEW_OTHER_AXES = [(1,2), (0,2), (0,1)]

def reorient_nifti(nii_in, ref_orient = 'LPS'):

Expand All @@ -37,8 +40,6 @@ def crop_image(img, mask):
'''
Crop img to the foreground of the mask
'''


# Detect bounding box
nz = np.nonzero(mask)
mn = np.min(nz, axis=1)
Expand Down Expand Up @@ -69,34 +70,30 @@ def crop_image(img, mask):
mask = np.pad(mask, padding, mode='constant', constant_values=0)

return img, mask


def show_nifti(img, mask, view):
def detect_mask_bounds(mask):
'''
Detect the mask start, end and center in each view
'''
mask_bounds = np.zeros([3,3]).astype(int)
for i, axis in enumerate(VIEW_AXES):
mask_bounds[i,0] = 0
mask_bounds[i,1] = mask.shape[i]
slices_nz = np.where(np.sum(mask, axis = VIEW_OTHER_AXES[i]) > 0)[0]
try:
mask_bounds[i,2] = slices_nz[len(slices_nz) // 2]
except:
# Could not detect masked region. Set center to image center
mask_bounds[i,2] = mask.shape[i] // 2
return mask_bounds

def show_nifti(img, view, sel_axis_bounds):
'''
Displays the nifti img
'''

# Set parameters based on orientation
if view == 'axial':
sel_axis = 0
other_axes = (1,2)

if view == 'sagittal':
sel_axis = 2
other_axes = (0,1)

if view == 'coronal':
sel_axis = 1
other_axes = (0,2)


# Detect middle masked slice
slices_nz = np.where(np.sum(mask, axis = other_axes) > 0)[0]
sel_slice = slices_nz[len(slices_nz) // 2]

# Create a slider to select the slice index
slice_index = st.slider(f"{view}", 0, img.shape[sel_axis] - 1,
value=sel_slice, key = f'slider_{view}')
slice_index = st.slider(f"{view}", 0, sel_axis_bounds[1] - 1,
value=sel_axis_bounds[2], key = f'slider_{view}')

# Extract the slice and display it
if view == 'axial':
Expand All @@ -106,7 +103,37 @@ def show_nifti(img, mask, view):
else:
st.image(img[:, slice_index, :], use_column_width = True)

@st.cache_data
def prep_images(f_img, f_mask, sel_roi_ind):
# Read nifti
nii_img = nib.load(f_img)
nii_mask = nib.load(f_mask)

# Reorient nifti
nii_img = reorient_nifti(nii_img, ref_orient = 'IPL')
nii_mask = reorient_nifti(nii_mask, ref_orient = 'IPL')

# Extract image to matrix
img = nii_img.get_fdata()
mask = nii_mask.get_fdata()

# Crop image to ROIs and reshape
img, mask = crop_image(img, mask)

# Select target roi
mask = (mask == sel_roi_ind).astype(int)

# Merge image and mask
img = np.stack((img,)*3, axis=-1)

img_masked = img.copy()
img_masked[mask == 1] = mask_color

# Scale values
img = img / img.max()
img_masked = img_masked / img_masked.max()

return img, mask, img_masked

# # Config page
# st.set_page_config(page_title="DataFrame Demo", page_icon="📊", layout='wide')
Expand All @@ -122,8 +149,12 @@ def show_nifti(img, mask, view):
sel_roi = 'Ventricles'
mask_color = (0, 255, 0) # RGB format

# Select roi index
dict_roi = {'Ventricles':51, 'Hippocampus_R':100, 'Hippocampus_L':48}
sel_roi_ind = dict_roi[sel_roi]

# Process image and mask to prepare final 3d matrix to display
img, mask, img_masked = prep_images(f1, f2, sel_roi_ind)

# Page controls in side bar
with st.sidebar:
Expand Down Expand Up @@ -154,45 +185,25 @@ def show_nifti(img, mask, view):
st.write('---')

# Create a list of checkbox options
orient_options = ["axial", "sagittal", "coronal"]
#list_orient = st.multiselect("Select viewing planes:", orient_options, orient_options[0])
list_orient = st.multiselect("Select viewing planes:", orient_options, orient_options)
#list_orient = st.multiselect("Select viewing planes:", VIEWS, VIEWS[0])
list_orient = st.multiselect("Select viewing planes:", VIEWS, VIEWS)

# View hide overlay
is_show_overlay = st.checkbox('Show overlay', True)

# Print the selected options (optional)
if list_orient:
st.write("Selected options:", list_orient)

# Select roi index
sel_roi_ind = dict_roi[sel_roi]

# Read nifti
nii_ulay = nib.load(f1)
nii_olay = nib.load(f2)

# Reorient nifti
nii_ulay = reorient_nifti(nii_ulay, ref_orient = 'IPL')
nii_olay = reorient_nifti(nii_olay, ref_orient = 'IPL')

# Extract image to matrix
img_ulay = nii_ulay.get_fdata()
img_olay = nii_olay.get_fdata()

# Crop image to ROIs and reshape
img_ulay, img_olay = crop_image(img_ulay, img_olay)

# Select target roi
img_olay = (img_olay == sel_roi_ind).astype(int)

# Merge image and mask
img_ulay = np.stack((img_ulay,)*3, axis=-1)
img_ulay[img_olay == 1] = mask_color

# Scale values
img_ulay = img_ulay / img_ulay.max()

# Detect mask bounds and center in each view
mask_bounds = detect_mask_bounds(mask)

# Show images
blocks = st.columns(len(list_orient))
for i, tmp_orient in enumerate(list_orient):
with blocks[i]:
show_nifti(img_ulay, img_olay, tmp_orient)
if is_show_overlay == False:
show_nifti(img, tmp_orient, mask_bounds[i,:])
else:
show_nifti(img_masked, tmp_orient, mask_bounds[i,:])

24 changes: 18 additions & 6 deletions src/NiChart_Viewer/src/pages/view_plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,26 @@
import tkinter as tk
from tkinter import filedialog

def browse_file(init_dir):
'''
File selector
Returns the file name selected by the user and the parent folder
'''
root = tk.Tk()
root.withdraw() # Hide the main window
out_path = filedialog.askopenfilename(initialdir = init_dir)
out_dir = os.path.dirname(out_path)
root.destroy()
return out_path, out_dir

def browse_file_folder(is_file, init_dir):
def browse_folder(init_dir):
'''
Folder selector
Returns the folder name selected by the user
'''
root = tk.Tk()
root.withdraw() # Hide the main window
if is_file == True:
out_path = filedialog.askopenfilename(initialdir = init_dir, multiple=0)
else:
out_path = filedialog.askdirectory(initialdir = init_dir)
out_path = filedialog.askdirectory(initialdir = init_dir)
root.destroy()
return out_path

Expand Down Expand Up @@ -203,7 +215,7 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame:

# Input file name
if st.sidebar.button("Select input file"):
st.session_state.in_csv_sMRI = browse_file_folder(True, dir_root)
st.session_state.in_csv_sMRI, st.session_state.init_dir = browse_file(st.session_state.init_dir)
spare_csv = st.sidebar.text_input("Enter the name of the ROI csv file:",
value = st.session_state.in_csv_sMRI,
label_visibility="collapsed")
Expand Down
51 changes: 51 additions & 0 deletions src/NiChart_Viewer/src/pages/view_plot_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,57 @@
# st.session_state.pid = 1
# st.session_state.instantiated = True

def calc_subject_centiles(df_subj, df_cent):
'''
Calculate subject specific centile values
'''

# Filter centiles to subject's age
tmp_ind = (df2.Age - df1.Age[0]).abs().idxmin()
sel_age = df2.loc[tmp_ind, 'Age']
df_cent_sel = df_cent[df_cent.Age == sel_age]

# Find ROIs in subj data that are included in the centiles file
sel_rois = df_subj.columns[df_subj.columns.isin(df_cent_sel.ROI.unique())].tolist()
df_cent_sel = df_cent_sel[df_cent_sel.ROI.isin(sel_rois)].drop(['ROI','Age'], axis=1)

cent = df_cent_sel.columns.str.replace('centile_', '').astype(int).values
vals_cent = df_cent_sel.values
vals_subj = df_subj.loc[0,sel_rois]

cent_subj = np.zeros(vals_subj.shape[0])
for i, sval in enumerate(vals_subj):
# Find nearest x values
ind1 = np.where(vals_subj[i] < vals_cent[i,:])[0][0]-1
ind2 = ind1 + 1

print(ind1)

# Calculate slope
slope = (cent[ind2] - cent[ind1]) / (vals_cent[i, ind2] - vals_cent[i, ind1])

# Estimate subj centile
cent_subj[i] = cent[ind1] + slope * (vals_subj[i] - vals_cent[i, ind1])

df_out = pd.DataFrame(dict(ROI=sel_rois, Centiles=cent_subj))
return df_out


print("Estimated y value for x =", target_x, "is:", estimated_y)

lower_percentiles = centile_values[ind_l, :]
upper_percentiles = centile_values[ind_u, :]
proportions = (target_values[:, None] - lower_percentiles) / (upper_percentiles - lower_percentiles)

estimated_percentiles = ind_l + proportions * (ind_u - ind_l)

print("Estimated percentiles:", estimated_percentiles)

# Calculate subject centile values
for sel_roi in



def display_plot(sel_id):
'''
Displays the plot with the given mrid
Expand Down

0 comments on commit 1c8affa

Please sign in to comment.