Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

graph utility methods #167

Open
grahamgower opened this issue Dec 3, 2020 · 3 comments
Open

graph utility methods #167

grahamgower opened this issue Dec 3, 2020 · 3 comments
Labels
enhancement New feature or request

Comments

@grahamgower
Copy link
Member

We should consider including methods for graph-theoretic properties that are useful downstream. E.g.

Are there other things that would be useful?

@grahamgower grahamgower added the enhancement New feature or request label Dec 3, 2020
@jeromekelleher
Copy link
Member

Traversal functions?

@grahamgower
Copy link
Member Author

Here's a pruning operation (just barely tested, so beware bugs).

def pruned(graph, keep):
    """
    Return a copy of ``graph``, pruned to contain only ancestry for the
    specific deme/time combinations specified in the dictionary ``keep``.

    Starting from each deme/time pair in ``keep``, ancestry is traced
    backwards in time through pulses and migrations into the deme, and
    through the deme's ancestors. Each retained deme is truncated so that
    it ends at the most recent time for which its ancestry is needed, and
    each retained migration is truncated to end no more recently than the
    destination deme needs it.

    :param graph: The graph to be pruned.
    :type graph: demes.Graph
    :param keep: Dictionary mapping deme IDs to times.
    :type keep: typing.Dict[str, typing.Union[int, float]]
    :return: the new pruned graph
    :rtype: demes.Graph
    :raises ValueError: if a time in ``keep`` does not satisfy
        ``deme.start_time > time >= deme.end_time`` for its deme.
    """
    for deme_id, t in keep.items():
        deme = graph[deme_id]
        if not (deme.start_time > t >= deme.end_time):
            raise ValueError(f"time={t} out of bounds for deme {deme_id}")

    # Index pulses and migrations by destination deme. Positions in the
    # original lists are stored (rather than the objects themselves) so that
    # each one is retained at most once, and so that the original ordering
    # can be reproduced in the pruned graph.
    pulses = list(graph.pulses)
    migrations = list(graph.migrations)
    pulses_into = collections.defaultdict(list)
    migrations_into = collections.defaultdict(list)
    for index, pulse in enumerate(pulses):
        pulses_into[pulse.dest].append(index)
    for index, migration in enumerate(migrations):
        migrations_into[migration.dest].append(index)

    # Most recent time for which ancestry must be retained, per kept deme.
    demes_keep = dict(keep)
    # Indices (into ``pulses``) of the pulses to retain.
    kept_pulses = set()
    # Index (into ``migrations``) -> truncated end time of the migration.
    migration_end_times = {}
    queue = list(demes_keep)

    def require(deme_id, t):
        # Only ever lower a deme's keep time. Raising it would discard
        # ancestry that was already requested, and re-queueing only on a
        # strict decrease guarantees that the traversal terminates.
        if t < demes_keep.get(deme_id, float("inf")):
            demes_keep[deme_id] = t
            queue.append(deme_id)

    while queue:
        deme_id = queue.pop()
        deme = graph[deme_id]
        keep_time = demes_keep[deme_id]

        for index in pulses_into.get(deme_id, []):
            pulse = pulses[index]
            if pulse.time <= keep_time:
                # The pulse is more recent than the ancestry we're keeping.
                continue
            # Keep the pulse even if its source deme is already retained.
            kept_pulses.add(index)
            require(pulse.source, pulse.time)

        for index in migrations_into.get(deme_id, []):
            migration = migrations[index]
            if migration.start_time <= keep_time:
                # The migration is more recent than the ancestry we're keeping.
                continue
            end_time = max(migration.end_time, keep_time)
            # If this deme is visited again with a lower keep time, extend
            # the retained migration rather than adding a duplicate.
            end_time = min(end_time, migration_end_times.get(index, end_time))
            migration_end_times[index] = end_time
            require(migration.source, end_time)

        for anc in deme.ancestors:
            require(anc, deme.start_time)

    pulses_keep = [copy.deepcopy(pulses[index]) for index in sorted(kept_pulses)]
    migrations_keep = []
    for index in sorted(migration_end_times):
        migration = copy.deepcopy(migrations[index])
        migration.end_time = migration_end_times[index]
        migrations_keep.append(migration)

    g = demes.Graph(
        description=graph.description,
        time_units=graph.time_units,
        generation_time=graph.generation_time,
    )
    for deme in graph.demes:
        end_time = demes_keep.get(deme.id)
        if end_time is None:
            continue
        # Drop epochs that lie entirely after the keep time, then truncate
        # the last remaining epoch so the deme ends at the keep time.
        epochs = [copy.deepcopy(e) for e in deme.epochs if e.start_time > end_time]
        epochs[-1].end_time = end_time
        g.deme(
            id=deme.id,
            description=deme.description,
            ancestors=deme.ancestors,
            proportions=deme.proportions,
            epochs=epochs,
        )
    g.migrations = migrations_keep
    g.pulses = pulses_keep

    return g

@grahamgower
Copy link
Member Author

Needs testing.

def connected_subgraphs(graph: demes.Graph):
    """
    Return a list of connected subgraphs of the given graph.

    Two demes are treated as directly connected if one is an ancestor of
    the other, or if a pulse or a migration links them. Directly connected
    groups are merged until no two groups share a deme, and one subgraph
    is built per remaining group.

    :param graph: The graph.
    :return: A list of subgraphs.
    :rtype: list[demes.Graph]
    """

    # Find all groups of directly connected demes.
    connected = [set(deme.ancestors) | {deme.name} for deme in graph.demes]
    connected.extend({pulse.source, pulse.dest} for pulse in graph.pulses)
    connected.extend(
        {migration.source, migration.dest} for migration in graph.migrations
    )

    def first_overlapping_pair(groups):
        # Return indices (i, j), i < j, of the first two groups that share
        # a deme, or None if all groups are pairwise disjoint.
        for i, a in enumerate(groups):
            for j in range(i + 1, len(groups)):
                if a & groups[j]:
                    return i, j
        return None

    # Merge groups until no two of them overlap.
    while True:
        pair = first_overlapping_pair(connected)
        if pair is None:
            # Couldn't merge any groups, so we're done.
            break
        i, j = pair
        merged = connected[i] | connected[j]
        connected = [group for k, group in enumerate(connected) if k not in (i, j)]
        connected.append(merged)

    subgraphs = []
    for group in connected:
        # Take a fresh dict for every group: the lists below are replaced
        # in place, so sharing one dict between groups would let the first
        # group's filtering leak into later groups if Builder.fromdict()
        # keeps a reference to its argument (NOTE(review): confirm).
        data = graph.asdict()
        data["demes"] = [deme for deme in data["demes"] if deme["name"] in group]
        data["migrations"] = [
            migration
            for migration in data.get("migrations", [])
            # We need only check the source deme, because source and dest
            # must both be in the same group.
            if migration["source"] in group
        ]
        data["pulses"] = [
            pulse for pulse in data.get("pulses", []) if pulse["source"] in group
        ]
        subgraphs.append(demes.Builder.fromdict(data).resolve())

    # Internal consistency checks: every deme lands in exactly one subgraph.
    assert sum(len(g.demes) for g in subgraphs) == len(graph.demes)

    deme_names = [{deme.name for deme in g.demes} for g in subgraphs]
    for i, names1 in enumerate(deme_names):
        for names2 in deme_names[i + 1 :]:
            assert names1.isdisjoint(names2)

    if len(subgraphs) == 1:
        # A fully connected graph must round-trip unchanged.
        graph.assert_close(subgraphs[0])

    return subgraphs

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

2 participants