@@ -795,36 +795,42 @@ def model_to_graphviz(
795
795
)
796
796
797
797
798
+ def _create_mermaid_node_name (name : str ) -> str :
799
+ return name .replace (":" , "_" ).replace (" " , "_" )
800
+
801
+
798
802
def _build_mermaid_node (node : NodeInfo ) -> list [str ]:
799
803
var = node .var
800
804
node_type = node .node_type
805
+ name = cast (str , var .name )
806
+ node_name = _create_mermaid_node_name (name )
801
807
if node_type == NodeType .DATA :
802
808
return [
803
- f"{ var . name } [{ var .name } ~ Data]" ,
804
- f"{ var . name } @{{ shape: db }}" ,
809
+ f"{ node_name } [{ var .name } ~ Data]" ,
810
+ f"{ node_name } @{{ shape: db }}" ,
805
811
]
806
812
elif node_type == NodeType .OBSERVED_RV :
807
813
return [
808
- f"{ var . name } ([{ var . name } ~ { random_variable_symbol (var )} ])" ,
809
- f"{ var . name } @{{ shape: rounded }}" ,
810
- f"style { var . name } fill:#757575" ,
814
+ f"{ node_name } ([{ name } ~ { random_variable_symbol (var )} ])" ,
815
+ f"{ node_name } @{{ shape: rounded }}" ,
816
+ f"style { node_name } fill:#757575" ,
811
817
]
812
818
813
819
elif node_type == NodeType .FREE_RV :
814
820
return [
815
- f"{ var . name } ([{ var . name } ~ { random_variable_symbol (var )} ])" ,
816
- f"{ var . name } @{{ shape: rounded }}" ,
821
+ f"{ node_name } ([{ name } ~ { random_variable_symbol (var )} ])" ,
822
+ f"{ node_name } @{{ shape: rounded }}" ,
817
823
]
818
824
elif node_type == NodeType .DETERMINISTIC :
819
825
return [
820
- f"{ var . name } ([{ var . name } ~ Deterministic])" ,
821
- f"{ var . name } @{{ shape: rect }}" ,
826
+ f"{ node_name } ([{ name } ~ Deterministic])" ,
827
+ f"{ node_name } @{{ shape: rect }}" ,
822
828
]
823
829
elif node_type == NodeType .POTENTIAL :
824
830
return [
825
- f"{ var . name } ([{ var . name } ~ Potential])" ,
826
- f"{ var . name } @{{ shape: diam }}" ,
827
- f"style { var . name } fill:#f0f0f0" ,
831
+ f"{ node_name } ([{ name } ~ Potential])" ,
832
+ f"{ node_name } @{{ shape: diam }}" ,
833
+ f"style { node_name } fill:#f0f0f0" ,
828
834
]
829
835
830
836
return []
@@ -842,8 +848,8 @@ def _build_mermaid_edges(edges) -> list[str]:
842
848
"""Return a list of Mermaid edge definitions."""
843
849
edge_lines = []
844
850
for child , parent in edges :
845
- child_id = str (child ). replace ( ":" , "_" )
846
- parent_id = str (parent ). replace ( ":" , "_" )
851
+ child_id = _create_mermaid_node_name (child )
852
+ parent_id = _create_mermaid_node_name (parent )
847
853
edge_lines .append (f"{ parent_id } --> { child_id } " )
848
854
return edge_lines
849
855
0 commit comments