Skip to content

Commit 3ae5095

Browse files
authored
BUG: Mermaid nodes with spaces in name (#7835)
* standardize mermaid node names * add test with space in variable name * remove use of fixture * use the "when in doubt, cast" strategy
1 parent 4271195 commit 3ae5095

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

pymc/model_graph.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -795,36 +795,42 @@ def model_to_graphviz(
795795
)
796796

797797

798+
def _create_mermaid_node_name(name: str) -> str:
799+
return name.replace(":", "_").replace(" ", "_")
800+
801+
798802
def _build_mermaid_node(node: NodeInfo) -> list[str]:
799803
var = node.var
800804
node_type = node.node_type
805+
name = cast(str, var.name)
806+
node_name = _create_mermaid_node_name(name)
801807
if node_type == NodeType.DATA:
802808
return [
803-
f"{var.name}[{var.name} ~ Data]",
804-
f"{var.name}@{{ shape: db }}",
809+
f"{node_name}[{var.name} ~ Data]",
810+
f"{node_name}@{{ shape: db }}",
805811
]
806812
elif node_type == NodeType.OBSERVED_RV:
807813
return [
808-
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
809-
f"{var.name}@{{ shape: rounded }}",
810-
f"style {var.name} fill:#757575",
814+
f"{node_name}([{name} ~ {random_variable_symbol(var)}])",
815+
f"{node_name}@{{ shape: rounded }}",
816+
f"style {node_name} fill:#757575",
811817
]
812818

813819
elif node_type == NodeType.FREE_RV:
814820
return [
815-
f"{var.name}([{var.name} ~ {random_variable_symbol(var)}])",
816-
f"{var.name}@{{ shape: rounded }}",
821+
f"{node_name}([{name} ~ {random_variable_symbol(var)}])",
822+
f"{node_name}@{{ shape: rounded }}",
817823
]
818824
elif node_type == NodeType.DETERMINISTIC:
819825
return [
820-
f"{var.name}([{var.name} ~ Deterministic])",
821-
f"{var.name}@{{ shape: rect }}",
826+
f"{node_name}([{name} ~ Deterministic])",
827+
f"{node_name}@{{ shape: rect }}",
822828
]
823829
elif node_type == NodeType.POTENTIAL:
824830
return [
825-
f"{var.name}([{var.name} ~ Potential])",
826-
f"{var.name}@{{ shape: diam }}",
827-
f"style {var.name} fill:#f0f0f0",
831+
f"{node_name}([{name} ~ Potential])",
832+
f"{node_name}@{{ shape: diam }}",
833+
f"style {node_name} fill:#f0f0f0",
828834
]
829835

830836
return []
@@ -842,8 +848,8 @@ def _build_mermaid_edges(edges) -> list[str]:
842848
"""Return a list of Mermaid edge definitions."""
843849
edge_lines = []
844850
for child, parent in edges:
845-
child_id = str(child).replace(":", "_")
846-
parent_id = str(parent).replace(":", "_")
851+
child_id = _create_mermaid_node_name(child)
852+
parent_id = _create_mermaid_node_name(parent)
847853
edge_lines.append(f"{parent_id} --> {child_id}")
848854
return edge_lines
849855

tests/test_model_graph.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,20 @@ def test_model_to_mermaid(simple_model):
652652
%% Plates:
653653
""")
654654
assert model_to_mermaid(simple_model) == expected_mermaid_string.strip()
655+
656+
657+
def test_model_to_mermaid_with_variable_with_space():
658+
with pm.Model() as variable_with_space:
659+
pm.Normal("plant growth")
660+
661+
expected_mermaid_string = dedent("""
662+
graph TD
663+
%% Nodes:
664+
plant_growth([plant growth ~ Normal])
665+
plant_growth@{ shape: rounded }
666+
667+
%% Edges:
668+
669+
%% Plates:
670+
""")
671+
assert model_to_mermaid(variable_with_space) == expected_mermaid_string.strip()

0 commit comments

Comments
 (0)