-
Notifications
You must be signed in to change notification settings - Fork 14
support for multiple slice positions in 2D heatmaps with automatic layout #87
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,6 +127,7 @@ | |
magnitudes as the values. | ||
position : list, tuple, np.ndarray, float | ||
Position of the plane in the atlas. | ||
list of positions create multiple slices. | ||
orientation : str or tuple, optional | ||
Orientation of the plane in the atlas. Either, "frontal", | ||
"sagittal", "horizontal" or a tuple with the normal vector. | ||
|
@@ -182,6 +183,8 @@ | |
self.label_regions = label_regions | ||
self.annotate_regions = annotate_regions | ||
self.annotate_text_options_2d = annotate_text_options_2d | ||
self.slicer: Optional[Slicer] = None | ||
self.multiple_slicers: Optional[List[Slicer]] = None | ||
|
||
# create a scene | ||
self.scene = Scene( | ||
|
@@ -204,8 +207,27 @@ | |
if r.name != "root" | ||
] | ||
|
||
# prepare slicer object | ||
self.slicer = Slicer(position, orientation, thickness, self.scene.root) | ||
# prepare slicer object or objects when list(position) | ||
if isinstance(position, list): | ||
if self.format == "3D": | ||
raise ValueError( | ||
"List of positions not supported in 3D format. " | ||
"Did you mean to use a tuple as a 3D position?" | ||
) | ||
if len(position) <= 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should be OK to pass a list of length 1 (but not an empty list)? If you agree, please adapt accordingly? |
||
raise ValueError( | ||
"List of positions should contain more than one position. " | ||
"Did you mean to pass a single value?" | ||
) | ||
self.positions = position | ||
self.multiple_slicers = [ | ||
Slicer(pos, orientation, thickness, self.scene.root) | ||
for pos in position | ||
] | ||
else: | ||
self.slicer = Slicer( | ||
position, orientation, thickness, self.scene.root | ||
) | ||
|
||
def prepare_colors( | ||
self, | ||
|
@@ -275,6 +297,9 @@ | |
Creates a 2D plot or 3D rendering of the heatmap | ||
""" | ||
if self.format == "3D": | ||
assert ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a case where |
||
self.slicer is not None | ||
), "Cannot access slice, check your parameters" | ||
self.slicer.slice_scene(self.scene, self.regions_meshes) | ||
view = self.render(**kwargs) | ||
else: | ||
|
@@ -327,6 +352,9 @@ | |
camera = self.orientation | ||
else: | ||
self.orientation = np.array(self.orientation) | ||
assert ( | ||
self.slicer is not None | ||
), "Cannot access plane0: slicer is None" | ||
com = self.slicer.plane0.center_of_mass() | ||
camera = { | ||
"pos": com - self.orientation * 2 * np.linalg.norm(com), | ||
|
@@ -349,7 +377,7 @@ | |
cbar_label: Optional[str] = None, | ||
show_cbar: bool = True, | ||
**kwargs, | ||
) -> plt.Figure: | ||
) -> Union[plt.Figure, List[plt.Figure]]: | ||
""" | ||
Plots the heatmap in 2D using matplotlib. | ||
|
||
|
@@ -381,39 +409,118 @@ | |
|
||
Returns | ||
------- | ||
plt.Figure | ||
The matplotlib figure object for the plot. | ||
Union[plt.Figure, List[plt.Figure]] | ||
The matplotlib figure object for the plot, | ||
or a list of figure objects | ||
when list(positions) are provided. | ||
|
||
Notes | ||
----- | ||
This method is used to generate a standalone plot of | ||
the heatmap data. | ||
When list(positions) are provided to class constructor, | ||
it creates multiple figures with optimized grid layouts. | ||
""" | ||
|
||
f, ax = plt.subplots(figsize=(9, 9)) | ||
|
||
f, ax = self.plot_subplot( | ||
fig=f, | ||
ax=ax, | ||
show_legend=show_legend, | ||
xlabel=xlabel, | ||
ylabel=ylabel, | ||
hide_axes=hide_axes, | ||
cbar_label=cbar_label, | ||
show_cbar=show_cbar, | ||
**kwargs, | ||
) | ||
if self.multiple_slicers is not None: | ||
num_slices = len(self.multiple_slicers) | ||
max_plots_per_fig = min(25, num_slices) | ||
num_figures = ( | ||
num_slices + max_plots_per_fig - 1 | ||
) // max_plots_per_fig | ||
|
||
all_figures = [] | ||
for fig_idx in range(num_figures): | ||
# calculate grid layout for the figure | ||
start_idx = fig_idx * max_plots_per_fig | ||
end_idx = min((fig_idx + 1) * max_plots_per_fig, num_slices) | ||
current_num_slices = end_idx - start_idx | ||
|
||
nrows = int(np.ceil(np.sqrt(current_num_slices))) | ||
ncols = int(np.ceil(current_num_slices / nrows)) | ||
|
||
f, axes = plt.subplots( | ||
nrows, | ||
ncols, | ||
layout="constrained", | ||
figsize=(26.25, 15), # 7:4 aspect ratio works well | ||
) | ||
|
||
if filename is not None: | ||
plt.savefig(filename, dpi=300) | ||
# padding (left, bottom, right, top) [0-1]% | ||
f.get_layout_engine().set(rect=(0, 0.01, 1, 0.97)) | ||
axes_flat = axes.flatten() | ||
|
||
# plot each position into the figure | ||
for i in range(current_num_slices): | ||
global_i = start_idx + i | ||
projected, _ = self.multiple_slicers[ | ||
global_i | ||
].get_structures_slice_coords( | ||
self.regions_meshes, self.scene.root | ||
) | ||
|
||
plt.show() | ||
return f | ||
self.plot_subplot( | ||
f, | ||
axes_flat[i], | ||
projected, | ||
show_legend, | ||
xlabel, | ||
ylabel, | ||
hide_axes, | ||
cbar_label, | ||
show_cbar, | ||
**kwargs, | ||
) | ||
|
||
axes_flat[i].set_title( | ||
self.title | ||
if self.title is not None | ||
else f"Position {self.positions[global_i]}" | ||
) | ||
|
||
# hide any empty subplots | ||
for i in range(current_num_slices, len(axes_flat)): | ||
axes_flat[i].axis("off") | ||
|
||
# save if filename provided | ||
if filename and num_figures > 1: | ||
print("Saving ", f"fig{fig_idx+1}_{filename}") | ||
plt.savefig(f"fig{fig_idx+1}_{filename}", dpi=200) | ||
elif filename: | ||
print("Saving ", filename) | ||
plt.savefig(filename, dpi=200) | ||
|
||
all_figures.append(f) | ||
plt.show() | ||
|
||
return all_figures | ||
else: | ||
f, ax = plt.subplots(figsize=(9, 9)) | ||
|
||
f, ax = self.plot_subplot( | ||
fig=f, | ||
ax=ax, | ||
show_legend=show_legend, | ||
xlabel=xlabel, | ||
ylabel=ylabel, | ||
hide_axes=hide_axes, | ||
cbar_label=cbar_label, | ||
show_cbar=show_cbar, | ||
**kwargs, | ||
) | ||
|
||
if filename is not None: | ||
print("Saving ", filename) | ||
plt.savefig(filename, dpi=300) | ||
|
||
plt.show() | ||
return f | ||
|
||
def plot_subplot( | ||
self, | ||
fig: plt.Figure, | ||
ax: plt.Axes, | ||
projected=None, | ||
show_legend: bool = False, | ||
xlabel: str = "µm", | ||
ylabel: str = "µm", | ||
|
@@ -431,10 +538,14 @@ | |
|
||
Parameters | ||
---------- | ||
fig : plt.Figure, optional | ||
fig : plt.Figure | ||
The figure object in which the subplot is plotted. | ||
ax : plt.Axes, optional | ||
ax : plt.Axes | ||
The axes object in which the subplot is plotted. | ||
projected : dict, optional | ||
Pre-computed slice coordinates. | ||
If None, coordinates will be | ||
calculated using the self.slicer. | ||
show_legend : bool, optional | ||
If True, displays a legend for the plotted regions. | ||
Default is False. | ||
|
@@ -461,9 +572,13 @@ | |
----- | ||
This method modifies the provided figure and axes objects in-place. | ||
""" | ||
projected, _ = self.slicer.get_structures_slice_coords( | ||
self.regions_meshes, self.scene.root | ||
) | ||
if projected is None: | ||
assert ( | ||
self.slicer is not None | ||
), "Cannot plot: slicer is None and no projected data provided" | ||
projected, _ = self.slicer.get_structures_slice_coords( | ||
self.regions_meshes, self.scene.root | ||
) | ||
|
||
segments: List[Dict[str, Union[str, np.ndarray, float]]] = [] | ||
for r, coords in projected.items(): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,3 @@ | ||
import matplotlib.pyplot as plt | ||
|
||
import brainglobe_heatmap as bgh | ||
|
||
data_dict = { | ||
|
@@ -13,28 +11,14 @@ | |
"VISam": 1.0, | ||
} | ||
|
||
# Create a list of scenes to plot | ||
# Note: it's important to keep reference to the scenes to avoid a | ||
# segmentation fault | ||
scenes = [] | ||
for distance in range(7500, 10500, 500): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe still worth keeping this example too, for advanced users (like when users want more control over figure size, or more than 25 subplots... or something else) to have more flexibility? |
||
scene = bgh.Heatmap( | ||
data_dict, | ||
position=distance, | ||
orientation="frontal", | ||
thickness=10, | ||
format="2D", | ||
cmap="Reds", | ||
vmin=0, | ||
vmax=1, | ||
label_regions=False, | ||
) | ||
scenes.append(scene) | ||
|
||
# Create a figure with 6 subplots and plot the scenes | ||
fig, axs = plt.subplots(3, 2, figsize=(18, 12)) | ||
for scene, ax in zip(scenes, axs.flatten(), strict=False): | ||
scene.plot_subplot(fig=fig, ax=ax, show_cbar=True, hide_axes=False) | ||
|
||
plt.tight_layout() | ||
plt.show() | ||
f = bgh.Heatmap( | ||
data_dict, | ||
position=[7000, 7250, 7500, 8000, 8500, 9000, 9500, 10000], | ||
orientation="frontal", | ||
cmap="Reds", | ||
vmin=0, | ||
vmax=1, | ||
title="", # title=None for title with positions number | ||
label_regions=False, | ||
format="2D", | ||
).show(show_cbar=True, hide_axes=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo.