Skip to content

Commit cc7e24c

Browse files
Merge pull request #736 from Starlitnightly/master
Enhance _snapshot method in StructureWatcher to capture detailed colu…
2 parents 1e37ed6 + 4b9a620 commit cc7e24c

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

dynamo/prediction/fate.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,18 @@ def _fate(
253253
t_stack, prediction_stack = np.hstack(t), np.hstack(prediction)
254254
n_cell, n_feature = init_states.shape
255255

256-
t_len = int(len(t_stack) / n_cell)
257-
avg = np.zeros((n_feature, t_len))
256+
if len(t_stack) > 0 and len(prediction_stack) > 0:
257+
t_len = int(len(t_stack) / n_cell)
258+
avg = np.zeros((n_feature, t_len))
258259

259-
for i in range(t_len):
260-
avg[:, i] = np.mean(prediction_stack[:, np.arange(n_cell) * t_len + i], 1)
260+
for i in range(t_len):
261+
avg[:, i] = np.mean(prediction_stack[:, np.arange(n_cell) * t_len + i], 1)
261262

262-
prediction = [avg]
263-
t = [np.sort(np.unique(t))]
263+
prediction = [avg]
264+
t = [np.sort(np.unique(t))]
265+
else:
266+
# If stack is empty (e.g. failed integrations), keep original or handle as needed
267+
pass
264268

265269
return t, prediction
266270

dynamo/tools/_track.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,24 @@ def _get_column_type(self, series):
7777

7878
def _snapshot(self, adata):
7979
"""Captures keys, types, and shapes for comparison."""
80+
obs_types = {}
81+
for c in adata.obs.columns:
82+
try:
83+
obs_types[c] = self._get_column_type(adata.obs[c])
84+
except Exception:
85+
obs_types[c] = "unknown (error)"
86+
87+
var_types = {}
88+
for c in adata.var.columns:
89+
try:
90+
var_types[c] = self._get_column_type(adata.var[c])
91+
except Exception:
92+
var_types[c] = "unknown (error)"
93+
8094
snapshot = {
8195
"shape": adata.shape,
82-
"obs": {c: self._get_column_type(adata.obs[c]) for c in adata.obs.columns},
83-
"var": {c: self._get_column_type(adata.var[c]) for c in adata.var.columns},
96+
"obs": obs_types,
97+
"var": var_types,
8498
"uns": set(adata.uns.keys()),
8599
# For complex slots, store a dict of {key: description_string}
86100
"obsm": {k: self._get_type_desc(v) for k, v in adata.obsm.items()},

0 commit comments

Comments
 (0)