Skip to content

Commit 14d20ad

Browse files
committed
update functions related to three_d_plot
1 parent 25bf43f commit 14d20ad

File tree

4 files changed

+16
-13
lines changed

4 files changed

+16
-13
lines changed

spateo/plotting/static/three_d_plot/three_dims_plots.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def three_d_plot(
182182
model: Union[PolyData, UnstructuredGrid, MultiBlock],
183183
key: Union[str, list] = None,
184184
filename: Optional[str] = None,
185-
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany"]] = False,
185+
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany", "html"]] = False,
186186
off_screen: bool = False,
187187
window_size: tuple = (512, 512),
188188
background: str = "white",
@@ -346,7 +346,7 @@ def three_d_multi_plot(
346346
model: Union[PolyData, UnstructuredGrid, MultiBlock],
347347
key: Union[str, list] = None,
348348
filename: Optional[str] = None,
349-
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany"]] = False,
349+
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany", "html"]] = False,
350350
off_screen: bool = False,
351351
shape: Union[str, list, tuple] = None,
352352
window_size: Optional[tuple] = None,
@@ -576,7 +576,7 @@ def three_d_animate(
576576
stable_kwargs: Optional[dict] = None,
577577
key: Optional[str] = None,
578578
filename: str = "animate.mp4",
579-
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany"]] = False,
579+
jupyter: Union[bool, Literal["panel", "none", "pythreejs", "static", "ipygany", "html"]] = False,
580580
off_screen: bool = False,
581581
window_size: tuple = (512, 512),
582582
background: str = "white",

spateo/plotting/static/three_d_plot/three_dims_plotter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def create_plotter(
7373

7474

7575
def _set_jupyter(
76-
jupyter: Union[bool, Literal["trame", "panel", "none", "static"]] = False,
76+
jupyter: Union[bool, Literal["trame", "panel", "none", "static", "html"]] = False,
7777
off_screen: bool = False,
7878
):
7979
if jupyter is False:
@@ -82,7 +82,7 @@ def _set_jupyter(
8282
elif jupyter is True:
8383
off_screen1, off_screen2 = True, off_screen
8484
jupyter_backend = "static"
85-
elif jupyter in ["trame", "panel", "none", "static"]:
85+
elif jupyter in ["trame", "panel", "none", "static", "html"]:
8686
off_screen1, off_screen2 = True, off_screen
8787
jupyter_backend = jupyter
8888
else:
@@ -535,7 +535,7 @@ def output_plotter(
535535
filename: Optional[str] = None,
536536
view_up: tuple = (0.5, 0.5, 1),
537537
framerate: int = 15,
538-
jupyter: Union[bool, Literal["trame", "panel", "none", "static"]] = False,
538+
jupyter: Union[bool, Literal["trame", "panel", "none", "static", "html"]] = False,
539539
):
540540
"""
541541
Output plotter as image, gif file or mp4 file.

spateo/tdr/models/models_migration/line_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ def construct_lines(
9595
plot_cmap: Recommended colormap parameter values for plotting.
9696
"""
9797

98-
padding = np.array([2] * edges.shape[0], int)
99-
edges_w_padding = np.vstack((padding, edges.T)).T
100-
model = pv.PolyData(points, edges_w_padding)
101-
98+
# padding = np.array([2] * edges.shape[0], int)
99+
# edges_w_padding = np.vstack((padding, edges.T)).T
100+
# model = pv.PolyData(points, edges_w_padding)
101+
model = pv.PolyData(points, lines=pv.CellArray.from_regular_cells(np.array(edges)))
102102
labels = np.asarray([label] * points.shape[0]) if isinstance(label, str) else np.asarray(label)
103103
assert len(labels) == points.shape[0], "The number of labels is not equal to the number of points."
104104
plot_cmap = None

spateo/tdr/models/models_migration/morphopath_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def construct_genesis(
157157
def construct_trajectory_X(
158158
cells_states: Union[np.ndarray, List[np.ndarray]],
159159
init_states: Optional[np.ndarray] = None,
160-
n_sampling: Optional[int] = None,
160+
n_sampling: Optional[Union[int, np.ndarray]] = None,
161161
sampling_method: str = "trn",
162162
key_added: str = "trajectory",
163163
label: Optional[Union[str, list, np.ndarray]] = None,
@@ -202,13 +202,15 @@ def construct_trajectory_X(
202202
else:
203203
raise ValueError("`label` value is wrong.")
204204

205-
if not (n_sampling is None):
205+
if (n_sampling is not None) and (isinstance(n_sampling, int)):
206206
index_arr = sample(
207207
arr=index_arr,
208208
n=n_sampling,
209209
method=sampling_method,
210210
X=init_states,
211211
)
212+
elif (n_sampling is not None) and (isinstance(n_sampling, np.ndarray)):
213+
index_arr = n_sampling
212214
else:
213215
if index_arr.shape[0] > 200:
214216
lm.main_warning(
@@ -244,6 +246,7 @@ def construct_trajectory_X(
244246

245247
streamlines, plot_cmap = construct_lines(
246248
points=np.concatenate(trajectories_points, axis=0),
249+
# points=trajectories_points,
247250
edges=np.concatenate(trajectories_edges, axis=0),
248251
key_added=key_added,
249252
label=np.asarray(trajectories_labels),
@@ -271,7 +274,7 @@ def construct_trajectory_X(
271274
def construct_trajectory(
272275
adata: AnnData,
273276
fate_key: str = "fate_develop",
274-
n_sampling: Optional[int] = None,
277+
n_sampling: Optional[Union[int, np.ndarray]] = None,
275278
sampling_method: str = "trn",
276279
key_added: str = "trajectory",
277280
label: Optional[Union[str, list, np.ndarray]] = None,

0 commit comments

Comments
 (0)