Skip to content

Commit

Permalink
update functions related to three_d_plot
Browse files Browse the repository at this point in the history
  • Loading branch information
YifanLu2000 committed Aug 30, 2024
1 parent 25bf43f commit 14d20ad
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
6 changes: 3 additions & 3 deletions spateo/plotting/static/three_d_plot/three_dims_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def three_d_plot(
model: Union[PolyData, UnstructuredGrid, MultiBlock],
key: Union[str, list] = None,
filename: Optional[str] = None,
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany"]] = False,
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany", "html"]] = False,
off_screen: bool = False,
window_size: tuple = (512, 512),
background: str = "white",
Expand Down Expand Up @@ -346,7 +346,7 @@ def three_d_multi_plot(
model: Union[PolyData, UnstructuredGrid, MultiBlock],
key: Union[str, list] = None,
filename: Optional[str] = None,
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany"]] = False,
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany", "html"]] = False,
off_screen: bool = False,
shape: Union[str, list, tuple] = None,
window_size: Optional[tuple] = None,
Expand Down Expand Up @@ -576,7 +576,7 @@ def three_d_animate(
stable_kwargs: Optional[dict] = None,
key: Optional[str] = None,
filename: str = "animate.mp4",
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany"]] = False,
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany", "html"]] = False,
off_screen: bool = False,
window_size: tuple = (512, 512),
background: str = "white",
Expand Down
6 changes: 3 additions & 3 deletions spateo/plotting/static/three_d_plot/three_dims_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def create_plotter(


def _set_jupyter(
jupyter: Union[bool, Literal["trame", "panel", "none", "static"]] = False,
jupyter: Union[bool, Literal["trame", "panel", "none", "static", "html"]] = False,
off_screen: bool = False,
):
if jupyter is False:
Expand All @@ -82,7 +82,7 @@ def _set_jupyter(
elif jupyter is True:
off_screen1, off_screen2 = True, off_screen
jupyter_backend = "static"
elif jupyter in ["trame", "panel", "none", "static"]:
elif jupyter in ["trame", "panel", "none", "static", "html"]:
off_screen1, off_screen2 = True, off_screen
jupyter_backend = jupyter
else:
Expand Down Expand Up @@ -535,7 +535,7 @@ def output_plotter(
filename: Optional[str] = None,
view_up: tuple = (0.5, 0.5, 1),
framerate: int = 15,
jupyter: Union[bool, Literal["trame", "panel", "none", "static"]] = False,
jupyter: Union[bool, Literal["trame", "panel", "none", "static", "html"]] = False,
):
"""
Output plotter as image, gif file or mp4 file.
Expand Down
8 changes: 4 additions & 4 deletions spateo/tdr/models/models_migration/line_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def construct_lines(
plot_cmap: Recommended colormap parameter values for plotting.
"""

padding = np.array([2] * edges.shape[0], int)
edges_w_padding = np.vstack((padding, edges.T)).T
model = pv.PolyData(points, edges_w_padding)

# padding = np.array([2] * edges.shape[0], int)
# edges_w_padding = np.vstack((padding, edges.T)).T
# model = pv.PolyData(points, edges_w_padding)
model = pv.PolyData(points, lines=pv.CellArray.from_regular_cells(np.array(edges)))
labels = np.asarray([label] * points.shape[0]) if isinstance(label, str) else np.asarray(label)
assert len(labels) == points.shape[0], "The number of labels is not equal to the number of points."
plot_cmap = None
Expand Down
9 changes: 6 additions & 3 deletions spateo/tdr/models/models_migration/morphopath_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def construct_genesis(
def construct_trajectory_X(
cells_states: Union[np.ndarray, List[np.ndarray]],
init_states: Optional[np.ndarray] = None,
n_sampling: Optional[int] = None,
n_sampling: Optional[Union[int, np.ndarray]] = None,
sampling_method: str = "trn",
key_added: str = "trajectory",
label: Optional[Union[str, list, np.ndarray]] = None,
Expand Down Expand Up @@ -202,13 +202,15 @@ def construct_trajectory_X(
else:
raise ValueError("`label` value is wrong.")

if not (n_sampling is None):
if (n_sampling is not None) and (isinstance(n_sampling, int)):
index_arr = sample(
arr=index_arr,
n=n_sampling,
method=sampling_method,
X=init_states,
)
elif (n_sampling is not None) and (isinstance(n_sampling, np.ndarray)):
index_arr = n_sampling
else:
if index_arr.shape[0] > 200:
lm.main_warning(
Expand Down Expand Up @@ -244,6 +246,7 @@ def construct_trajectory_X(

streamlines, plot_cmap = construct_lines(
points=np.concatenate(trajectories_points, axis=0),
# points=trajectories_points,
edges=np.concatenate(trajectories_edges, axis=0),
key_added=key_added,
label=np.asarray(trajectories_labels),
Expand Down Expand Up @@ -271,7 +274,7 @@ def construct_trajectory_X(
def construct_trajectory(
adata: AnnData,
fate_key: str = "fate_develop",
n_sampling: Optional[int] = None,
n_sampling: Optional[Union[int, np.ndarray]] = None,
sampling_method: str = "trn",
key_added: str = "trajectory",
label: Optional[Union[str, list, np.ndarray]] = None,
Expand Down

0 comments on commit 14d20ad

Please sign in to comment.