Skip to content

Commit

Permalink
Add state to plots
Browse files Browse the repository at this point in the history
  • Loading branch information
gurayerus committed Sep 18, 2024
1 parent ede10c9 commit 3ef095c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 64 deletions.
6 changes: 3 additions & 3 deletions src/NiChart_Viewer/src/pages/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@

# FIXME: temp path for running fast
# Should be set as the images are created
st.session_state.dir_t1img = st.session_state.path_root + '/test/test_input/test3_nifti+roi'
st.session_state.dir_dlmuse = st.session_state.path_root + '/test/test_input/test3_nifti+roi'
st.session_state.dir_t1img = st.session_state.path_root + '/test/test3_nifti+roi'
st.session_state.dir_dlmuse = st.session_state.path_root + '/test/test3_nifti+roi'
st.session_state.suffix_t1img = '_T1.nii.gz'
st.session_state.suffix_dlmuse = '_T1_DLMUSE.nii.gz'

Expand All @@ -64,7 +64,7 @@
st.session_state.path_csv_spare = ''

## FIXME : this is for quickly loading a test example
st.session_state.path_csv_spare = st.session_state.path_root + '/test/test_input/test3_nifti+roi/sMRI_Results_n4.csv'
st.session_state.path_csv_spare = st.session_state.path_root + '/test4_adni3/output/out_combined/MyStudy_All.csv'

st.session_state.instantiated = True

Expand Down
141 changes: 80 additions & 61 deletions src/NiChart_Viewer/src/pages/view_plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,29 @@ def add_plot():
Adds a new plot (updates a dataframe with plot ids)
'''
df_p = st.session_state.plots
df_p.loc[st.session_state.pid] = [f'Plot {st.session_state.pid}']
st.session_state.pid += 1
plot_id = f'Plot{st.session_state.plot_index}'
df_p.loc[plot_id] = [plot_id,
st.session_state.plot_xvar,
st.session_state.plot_yvar,
st.session_state.plot_hvar,
st.session_state.plot_trend
]

st.session_state.plot_index += 1

# Remove a plot
def remove_plot(pid):
def remove_plot(plot_id):
'''
Removes the plot with the pid (updates the plot ids dataframe)
Removes the plot with the plot_id (updates the plot ids dataframe)
'''
df_p = st.session_state.plots
df_p = df_p[df_p.PID != pid]
df_p = df_p[df_p.PID != plot_id]
st.session_state.plots = df_p


def display_plot(pid):
def display_plot(plot_id):
'''
Displays the plot with the pid
Displays the plot with the plot_id
'''

# Create a copy of dataframe for filtered data
Expand All @@ -80,46 +87,58 @@ def display_plot(pid):

# Tab 1: to set plotting parameters
with ptabs[1]:
plot_type = st.selectbox("Plot Type", ["DistPlot", "RegPlot"], key=f"plot_type_{pid}")
# x_var = st.selectbox("X Var", df_filt.columns, key=f"x_var_{pid}", index=3)
# y_var = st.selectbox("Y Var", df_filt.columns, key=f"y_var_{pid}", index=8)

# Set index for default values
x_ind = df.columns.get_loc(st.session_state.default_x_var)
y_ind = df.columns.get_loc(st.session_state.default_y_var)
hue_ind = df.columns.get_loc(st.session_state.default_hue_var)
trend_index = st.session_state.trend_types.index(st.session_state.default_trend_type)

x_var = st.selectbox("X Var", df_filt.columns, key=f"x_var_{pid}", index = x_ind)
y_var = st.selectbox("Y Var", df_filt.columns, key=f"y_var_{pid}", index = y_ind)
st.session_state.sel_var = y_var
plot_type = st.selectbox("Plot Type", ["DistPlot", "RegPlot"], key=f"plot_type_{plot_id}")

# Get plot params
xvar = st.session_state.plots.loc[plot_id].xvar
yvar = st.session_state.plots.loc[plot_id].yvar
hvar = st.session_state.plots.loc[plot_id].hvar
trend = st.session_state.plots.loc[plot_id].trend

# Select plot params from the user
xind = df.columns.get_loc(xvar)
yind = df.columns.get_loc(yvar)
hind = df.columns.get_loc(hvar)
tind = st.session_state.trend_types.index(trend)

xvar = st.selectbox("X Var", df_filt.columns,
key=f"plot_xvar_{plot_id}", index = xind)
yvar = st.selectbox("Y Var", df_filt.columns,
key=f"plot_yvar_{plot_id}", index = yind)
hvar = st.selectbox("Hue Var", df_filt.columns,
key=f"plot_hvar_{plot_id}", index = hind)
trend = st.selectbox("Trend Line", st.session_state.trend_types,
key=f"trend_type_{plot_id}", index = tind)

# Set plot params to session_state
st.session_state.plots.loc[plot_id].xvar = xvar
st.session_state.plots.loc[plot_id].yvar = yvar
st.session_state.plots.loc[plot_id].hvar = hvar
st.session_state.plots.loc[plot_id].trend = trend

hue_var = st.selectbox("Hue Var", df_filt.columns, key=f"hue_var_{pid}", index = hue_ind)
trend_type = st.selectbox("Trend Line", st.session_state.trend_types, key=f"trend_type_{pid}", index = trend_index)

# Tab 2: to set data filtering parameters
with ptabs[2]:
df_filt = filter_dataframe(df, pid)
df_filt = filter_dataframe(df, plot_id)

# Tab 3: to set centiles
with ptabs[3]:
cent_type = st.selectbox("Centile Type", ['CN-All', 'CN-F', 'CN-M'], key=f"cent_type_{pid}")
cent_type = st.selectbox("Centile Type", ['CN-All', 'CN-F', 'CN-M'], key=f"cent_type_{plot_id}")

# Tab 4: to reset parameters or to delete plot
with ptabs[4]:
st.button('Delete Plot', key=f'p_delete_{pid}',
on_click=remove_plot, args=[pid])
st.button('Delete Plot', key=f'p_delete_{plot_id}',
on_click=remove_plot, args=[plot_id])

# Main plot
if trend_type == 'none':
scatter_plot = px.scatter(df_filt, x = x_var, y = y_var, color = hue_var)
if trend == 'none':
scatter_plot = px.scatter(df_filt, x = xvar, y = yvar, color = hvar)
else:
scatter_plot = px.scatter(df_filt, x = x_var, y = y_var, color = hue_var,
trendline = trend_type)
scatter_plot = px.scatter(df_filt, x = xvar, y = yvar, color = hvar, trendline = trend)

# Add plot
# - on_select: when clicked it will rerun and return the info
sel_info = st.plotly_chart(scatter_plot, on_select='rerun', key=f"bubble_chart_{pid}")
sel_info = st.plotly_chart(scatter_plot, on_select='rerun', key=f"bubble_chart_{plot_id}")

# Detect MRID from the click info
try:
Expand All @@ -136,7 +155,7 @@ def display_plot(pid):
# ## FIXME: this is temp (for debugging the selection of clicked subject)
# st.dataframe(df_filt)

def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame:
def filter_dataframe(df: pd.DataFrame, plot_id) -> pd.DataFrame:
"""
Adds a UI on top of a dataframe to let viewers filter columns
Expand All @@ -153,14 +172,14 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame:
# Create filters selected by the user
modification_container = st.container()
with modification_container:
widget_no = pid + '_filter'
widget_no = plot_id + '_filter'
to_filter_columns = st.multiselect("Filter dataframe on", df.columns, key = widget_no)
for vno, column in enumerate(to_filter_columns):
left, right = st.columns((1, 20))
left.write("↳")
# Treat columns with < 10 unique values as categorical
if is_categorical_dtype(df[column]) or df[column].nunique() < 10:
widget_no = pid + '_col_' + str(vno)
widget_no = plot_id + '_col_' + str(vno)
user_cat_input = right.multiselect(
f"Values for {column}",
df[column].unique(),
Expand Down Expand Up @@ -241,30 +260,30 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame:
# Tab 0: to set plotting parameters
with ptabs[1]:
# Default values for plot params
st.session_state.default_hue_var = 'Sex'

def_ind_x = 0
if st.session_state.default_x_var in df.columns:
def_ind_x = df.columns.get_loc(st.session_state.default_x_var)

def_ind_y = 0
if st.session_state.default_y_var in df.columns:
def_ind_y = df.columns.get_loc(st.session_state.default_y_var)

def_ind_hue = 0
if st.session_state.default_hue_var in df.columns:
def_ind_hue = df.columns.get_loc(st.session_state.default_hue_var)

st.session_state.default_x_var = st.selectbox("Default X Var", df.columns, key=f"x_var_init",
index = def_ind_x)
st.session_state.default_y_var = st.selectbox("Default Y Var", df.columns, key=f"y_var_init",
index = def_ind_y)
st.session_state.sel_var = st.session_state.default_y_var

st.session_state.default_hue_var = st.selectbox("Default Hue Var", df.columns, key=f"hue_var_init",
index = def_ind_hue)
trend_index = st.session_state.trend_types.index(st.session_state.default_trend_type)
st.session_state.default_trend_type = st.selectbox("Default Trend Line", st.session_state.trend_types,
st.session_state.plot_hvar = 'Sex'

plot_xvar_ind = 0
if st.session_state.plot_xvar in df.columns:
plot_xvar_ind = df.columns.get_loc(st.session_state.plot_xvar)

plot_yvar_ind = 0
if st.session_state.plot_yvar in df.columns:
plot_yvar_ind = df.columns.get_loc(st.session_state.plot_yvar)

plot_hvar_ind = 0
if st.session_state.plot_hvar in df.columns:
plot_hvar_ind = df.columns.get_loc(st.session_state.plot_hvar)

st.session_state.plot_xvar = st.selectbox("Default X Var", df.columns, key=f"plot_xvar_init",
index = plot_xvar_ind)
st.session_state.plot_yvar = st.selectbox("Default Y Var", df.columns, key=f"plot_yvar_init",
index = plot_yvar_ind)
st.session_state.sel_var = st.session_state.plot_yvar

st.session_state.plot_hvar = st.selectbox("Default Hue Var", df.columns, key=f"plot_hvar_init",
index = plot_hvar_ind)
trend_index = st.session_state.trend_types.index(st.session_state.plot_trend)
st.session_state.plot_trend = st.selectbox("Default Trend Line", st.session_state.trend_types,
key=f"trend_type_init", index = trend_index)

# Button to add a new plot
Expand All @@ -277,18 +296,18 @@ def filter_dataframe(df: pd.DataFrame, pid) -> pd.DataFrame:

# Read plot ids
df_p = st.session_state.plots
p_index = df_p.PID.tolist()
list_plots = df_p.index.tolist()
plot_per_raw = st.session_state.plot_per_raw

# Render plots
# - iterates over plots;
# - for every "plot_per_raw" plots, creates a new columns block, resets column index, and displays the plot
for i in range(0, len(p_index)):
for i, plot_ind in enumerate(list_plots):
column_no = i % plot_per_raw
if column_no == 0:
blocks = st.columns(plot_per_raw)
with blocks[column_no]:
display_plot(p_index[i])
display_plot(plot_ind)


# FIXME: this is for debugging; will be removed
Expand Down

0 comments on commit 3ef095c

Please sign in to comment.