-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
graph utility methods #167
Labels
enhancement
New feature or request
Comments
Traversal functions? |
# Note from the original thread: here's a pruning operation (just barely
# tested, so beware bugs).
def pruned(graph, keep):
    """
    Return a copy of ``graph``, pruned to contain only ancestry for the
    specific deme/time combinations specified in the dictionary ``keep``.

    Starting from the requested deme/time pairs, the graph is walked
    backwards in time through pulses, migrations and ancestor links. Every
    deme reached this way is retained, truncated at the most recent time at
    which it contributes ancestry. Retained pulses and migrations appear in
    the same relative order as in ``graph``.

    :param graph: The graph to be pruned.
    :type graph: demes.Graph
    :param keep: Dictionary mapping deme IDs to times.
    :type keep: typing.Dict[str, typing.Union[int, float]]
    :return: the new pruned graph
    :rtype: demes.Graph
    :raises ValueError: if a time in ``keep`` does not satisfy
        ``deme.start_time > time >= deme.end_time`` for its deme.
    """
    for deme_id, t in keep.items():
        deme = graph[deme_id]
        if not (deme.start_time > t >= deme.end_time):
            raise ValueError(f"time={t} out of bounds for deme {deme_id}")

    # Index pulses and migrations by destination deme. Positions in the
    # graph's own lists are used as keys so each event is recorded at most
    # once, however many times its destination deme is revisited.
    pulses = list(graph.pulses)
    migrations = list(graph.migrations)
    pulses_into = collections.defaultdict(list)
    for index, pulse in enumerate(pulses):
        pulses_into[pulse.dest].append(index)
    migrations_into = collections.defaultdict(list)
    for index, migration in enumerate(migrations):
        migrations_into[migration.dest].append(index)

    # demes_keep maps a deme ID to the most recent time down to which the
    # deme must be retained; a smaller value retains more of the deme.
    demes_keep = dict(keep)
    queue = list(demes_keep)
    pulses_kept = set()  # indices into ``pulses``
    migration_end_times = {}  # index into ``migrations`` -> truncated end_time

    def require(deme_id, time):
        """Retain ``deme_id`` down to ``time``; revisit it if that adds ancestry."""
        # Only ever lower the keep time. Raising it would discard ancestry
        # that was requested by ``keep`` or by another path through the graph.
        if demes_keep.get(deme_id, float("inf")) > time:
            demes_keep[deme_id] = time
            queue.append(deme_id)

    while queue:
        deme_id = queue.pop()
        deme = graph[deme_id]
        # Always read the current keep time: a stale queue entry is then
        # simply re-processed with up-to-date information.
        time = demes_keep[deme_id]

        for index in pulses_into.get(deme_id, []):
            pulse = pulses[index]
            if pulse.time <= time:
                # The pulse is not older than the retained part of the deme.
                continue
            pulses_kept.add(index)
            require(pulse.source, pulse.time)

        for index in migrations_into.get(deme_id, []):
            migration = migrations[index]
            if migration.start_time <= time:
                continue
            # Truncate the migration to the retained part of the destination
            # deme; a later visit with an older keep time may extend it again.
            end_time = max(migration.end_time, time)
            previous = migration_end_times.get(index)
            if previous is None or end_time < previous:
                migration_end_times[index] = end_time
            require(migration.source, end_time)

        for anc in deme.ancestors:
            # Ancestors are needed up to the moment this deme splits off.
            require(anc, deme.start_time)

    g = demes.Graph(
        description=graph.description,
        time_units=graph.time_units,
        generation_time=graph.generation_time,
    )
    for deme in graph.demes:
        end_time = demes_keep.get(deme.id)
        if end_time is None:
            continue
        # Keep the epochs that start before the cut-off and shorten the last
        # one so that the deme ends exactly at ``end_time``.
        # NOTE(review): only ``end_time`` is rewritten; whether an epoch with
        # a changing size needs other fields adjusted when truncated is not
        # handled here -- confirm against the demes epoch model.
        epochs = [copy.deepcopy(e) for e in deme.epochs if e.start_time > end_time]
        epochs[-1].end_time = end_time
        g.deme(
            id=deme.id,
            description=deme.description,
            ancestors=deme.ancestors,
            proportions=deme.proportions,
            epochs=epochs,
        )

    migrations_keep = []
    for index, migration in enumerate(migrations):
        end_time = migration_end_times.get(index)
        if end_time is None:
            continue
        kept = copy.deepcopy(migration)
        kept.end_time = end_time
        migrations_keep.append(kept)
    g.migrations = migrations_keep
    g.pulses = [
        copy.deepcopy(pulse)
        for index, pulse in enumerate(pulses)
        if index in pulses_kept
    ]
    return g
# Note from the original thread: needs testing.
def connected_subgraphs(graph: demes.Graph):
    """
    Return a list of connected subgraphs of the given graph.

    Two demes end up in the same subgraph when they are linked, directly or
    through a chain of other demes, by an ancestor relationship, a pulse or
    a migration.

    :param graph: The graph.
    :return: A list of subgraphs.
    :rtype: list[demes.Graph]
    """
    # Find all groups of directly connected demes.
    connected = [set(deme.ancestors) | {deme.name} for deme in graph.demes]
    connected.extend({pulse.source, pulse.dest} for pulse in graph.pulses)
    connected.extend(
        {migration.source, migration.dest} for migration in graph.migrations
    )

    # Merge groups: repeatedly fuse the first pair of groups that share a
    # deme, until no two groups overlap.
    while len(connected) > 1:
        overlapping = next(
            (
                (i, j)
                for i, a in enumerate(connected)
                for j in range(i + 1, len(connected))
                if a & connected[j]
            ),
            None,
        )
        if overlapping is None:
            # Couldn't merge any groups, so we're done.
            break
        i, j = overlapping
        merged = connected[i] | connected[j]
        connected = [group for k, group in enumerate(connected) if k not in (i, j)]
        connected.append(merged)

    subgraphs = []
    for group in connected:
        # Start every subgraph from a fresh dict. Builder.fromdict may keep a
        # reference to the dict it is given (not verified here), in which
        # case filtering a shared dict would leak into the next group.
        b = demes.Builder.fromdict(graph.asdict())
        b.data["demes"] = [deme for deme in b.data["demes"] if deme["name"] in group]
        b.data["migrations"] = [
            migration
            for migration in b.data.get("migrations", [])
            # We need only check the source deme, because source and dest
            # must both be in the same group.
            if migration["source"] in group
        ]
        b.data["pulses"] = [
            pulse for pulse in b.data.get("pulses", []) if pulse["source"] in group
        ]
        subgraphs.append(b.resolve())

    # Internal consistency checks: every deme lands in exactly one subgraph.
    assert sum(len(g.demes) for g in subgraphs) == len(graph.demes)
    deme_names = [{deme.name for deme in g.demes} for g in subgraphs]
    for i, names1 in enumerate(deme_names):
        # Compare against the other *subgraphs*, not the raw name groups.
        for names2 in deme_names[i + 1 :]:
            assert len(names1 & names2) == 0
    if len(subgraphs) == 1:
        graph.assert_close(subgraphs[0])
    return subgraphs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
We should consider including methods for graph-theoretic properties that are useful downstream. E.g.
roots()
to return the IDs of the graph's root demes.
components()
to return a list of graphs, one for each collection of connected demes (see https://en.wikipedia.org/wiki/Component_(graph_theory)).
Are there other things that would be useful?
The text was updated successfully, but these errors were encountered: