2121import numpy as np
2222import scipy .stats
2323import xmltodict
24- from matplotlib import patches
2524from mpl_toolkits .mplot3d import Axes3D
2625
2726from microstructpy import _misc
@@ -619,21 +618,11 @@ def plot_seeds(seeds, phases, domain, plot_files=[], plot_axes=True,
619618
620619 # Plot seeds
621620 edge_kwargs .setdefault ('edgecolors' , {2 : 'k' , 3 : 'none' }[n_dim ])
622- seeds .plot (facecolors = seed_colors , ** edge_kwargs )
623-
624- # Add legend
625- custom_seeds = [None for _ in phases ]
626- for seed in seeds :
627- phase_num = seed .phase
628- if custom_seeds [phase_num ] is None :
629- c = _phase_color (phase_num , phases )
630- lbl = phase_names [phase_num ]
631- phase_patch = patches .Patch (fc = c , ec = 'k' , label = lbl )
632- custom_seeds [phase_num ] = phase_patch
633-
634621 if given_names and color_by == 'material' :
635- handles = [h for h in custom_seeds if h is not None ]
636- ax .legend (handles = handles , loc = 4 )
622+ seeds .plot (material = phase_names , facecolors = seed_colors , loc = 4 ,
623+ ** edge_kwargs )
624+ else :
625+ seeds .plot (facecolors = seed_colors , ** edge_kwargs )
637626
638627 # Set limits
639628 lims = domain .limits
@@ -671,6 +660,14 @@ def _phase_color(i, phases):
671660 return phases [i ].get ('color' , 'C' + str (i % 10 ))
672661
673662
663+ def _phase_color_by (i , phases , color_by = 'material' , colormap = 'viridis' ):
664+ if color_by == 'material' :
665+ return phases [i ].get ('color' , 'C' + str (i % 10 ))
666+ elif color_by == 'material number' :
667+ n = len (phases )
668+ return _cm_color (i / (n - 1 ), colormap )
669+
670+
674671def _cm_color (f , colormap = 'viridis' ):
675672 return plt .get_cmap (colormap )(f )
676673
@@ -739,7 +736,10 @@ def plot_poly(pmesh, phases, plot_files=['polymesh.png'], plot_axes=True,
739736 # Plot polygons
740737 fcs = _poly_colors (pmesh , phases , color_by , colormap , n_dim )
741738 if n_dim == 2 :
742- pmesh .plot (facecolors = fcs )
739+ if given_names and color_by == 'material' :
740+ pmesh .plot (facecolors = fcs , material = phase_names )
741+ else :
742+ pmesh .plot (facecolors = fcs )
743743
744744 edge_color = edge_kwargs .pop ('edgecolors' , (0 , 0 , 0 , 1 ))
745745 facet_colors = []
@@ -750,22 +750,14 @@ def plot_poly(pmesh, phases, plot_files=['polymesh.png'], plot_axes=True,
750750 facet_colors .append ('none' )
751751
752752 edge_kwargs .setdefault ('capstyle' , 'round' )
753- pmesh .plot_facets (color = facet_colors , ** edge_kwargs )
753+ pmesh .plot_facets (color = facet_colors , index_by = 'facet' , ** edge_kwargs )
754754 else :
755755 edge_kwargs .setdefault ('edgecolors' , 'k' )
756- pmesh .plot (facecolors = fcs , ** edge_kwargs )
757-
758- # add legend
759- if given_names :
760- custom_seeds = [None for _ in phases ]
761- for phase_num in pmesh .phase_numbers :
762- if custom_seeds [phase_num ] is None :
763- c = phase_colors [phase_num ]
764- lbl = phase_names [phase_num ]
765- phase_patch = patches .Patch (fc = c , ec = 'k' , label = lbl )
766- custom_seeds [phase_num ] = phase_patch
767- handles = [h for h in custom_seeds if h is not None ]
768- ax .legend (handles = handles , loc = 4 )
756+ if given_names and color_by == 'material' :
757+ pmesh .plot (facecolors = fcs , index_by = 'seed' , material = phase_names ,
758+ ** edge_kwargs )
759+ else :
760+ pmesh .plot (facecolors = fcs , index_by = 'seed' , ** edge_kwargs )
769761
770762 # format axes
771763 lims = np .array ([np .min (pmesh .points , 0 ), np .max (pmesh .points , 0 )]).T
@@ -790,35 +782,37 @@ def plot_poly(pmesh, phases, plot_files=['polymesh.png'], plot_axes=True,
790782def _poly_colors (pmesh , phases , color_by , colormap , n_dim ):
791783 if n_dim == 2 :
792784 if color_by == 'material' :
793- return [_phase_color (n , phases ) for n in pmesh .phase_numbers ]
785+ r_colors = [_phase_color (n , phases ) for n in pmesh .phase_numbers ]
794786 elif color_by == 'seed number' :
795787 n = max (pmesh .seed_numbers ) + 1
796- return [_cm_color (s / (n - 1 ), colormap ) for s in
797- pmesh .seed_numbers ]
788+ r_colors = [_cm_color (s / (n - 1 ), colormap ) for s in
789+ pmesh .seed_numbers ]
798790 elif color_by == 'material number' :
799791 n = len (phases )
800- return [_cm_color (p / (n - 1 ), colormap ) for p in
801- pmesh .phase_numbers ]
792+ r_colors = [_cm_color (p / (n - 1 ), colormap ) for p in
793+ pmesh .phase_numbers ]
794+ n_seeds = max (pmesh .seed_numbers ) + 1
795+ s_colors = ['none' for i in range (n_seeds )]
796+ for seed_num , r_c in zip (pmesh .seed_numbers , r_colors ):
797+ s_colors [seed_num ] = r_c
798+ return s_colors
802799 else :
803- poly_fcs = []
804- for n_pair in pmesh .facet_neighbors :
805- if min (n_pair ) < 0 :
806- n_int = max (n_pair )
807- if color_by == 'material' :
808- phase_num = pmesh .phase_numbers [n_int ]
809- color = _phase_color (phase_num , phases )
810- elif color_by == 'seed number' :
811- n_seed = max (pmesh .seed_numbers ) + 1
812- seed_num = pmesh .seed_numbers [n_int ]
813- color = _cm_color (seed_num / (n_seed - 1 ), colormap )
814- elif color_by == 'material number' :
815- n_phases = len (phases )
816- phase_num = pmesh .phase_numbers [n_int ]
817- color = _cm_color (phase_num / (n_phases - 1 ), colormap )
800+ s2p = {s : p for s , p in zip (pmesh .seed_numbers , pmesh .phase_numbers )}
801+ n = max (s2p .keys ()) + 1
802+ colors = []
803+ for s in range (n ):
804+ if color_by == 'material' :
805+ phase_num = s2p [s ]
806+ color = _phase_color (phase_num , phases )
807+ elif color_by == 'seed number' :
808+ color = _cm_color (s / (n - 1 ), colormap )
809+ elif color_by == 'material number' :
810+ n_phases = len (phases )
811+ color = _cm_color (s2p [s ] / (n_phases - 1 ), colormap )
818812 else :
819813 color = 'none'
820- poly_fcs .append (color )
821- return poly_fcs
814+ colors .append (color )
815+ return colors
822816
823817
824818# --------------------------------------------------------------------------- #
@@ -883,27 +877,56 @@ def plot_tri(tmesh, phases, seeds, pmesh, plot_files=[], plot_axes=True,
883877 ax ._axis3don = False
884878 fig .add_axes (ax )
885879
886- # determine triangle element colors
887- fcs = _tri_colors (tmesh , seeds , pmesh , phases , color_by , colormap , n_dim )
888- phase_nums = range (len (phases ))
889-
890880 # plot triangle mesh
891881 edge_kwargs .setdefault ('linewidths' , {2 : 0.5 , 3 : 0.1 }[n_dim ])
892882 edge_kwargs .setdefault ('edgecolors' , 'k' )
893- tmesh .plot (facecolors = fcs , ** edge_kwargs )
894-
895- # add legend
896- if any ([given_names [phase_num ] for phase_num in phase_nums ]):
897- custom_seeds = [None for _ in phases ]
898- for seed_num in tmesh .element_attributes :
899- phase_num = seeds [seed_num ].phase
900- if custom_seeds [phase_num ] is None :
901- c = phase_colors [phase_num ]
902- lbl = phase_names [phase_num ]
903- phase_patch = patches .Patch (fc = c , ec = 'k' , label = lbl )
904- custom_seeds [phase_num ] = phase_patch
905- handles = [h for h in custom_seeds if h is not None ]
906- ax .legend (handles = handles , loc = 4 )
883+ if given_names and color_by in ('material' , 'material number' ):
884+ n = len (phases )
885+ cs = [_phase_color_by (i , phases , color_by , colormap ) for i in range (n )]
886+
887+ old_e_att = np .copy (tmesh .element_attributes )
888+ old_f_att = np .copy (tmesh .facet_attributes )
889+ tmesh .element_attributes = [seeds [i ].phase for i in old_e_att ]
890+
891+ # Determine which facets are visible
892+ visible_facets = [i for i , fn in enumerate (pmesh .facet_neighbors )
893+ if min (fn ) < 0 ]
894+ f_frontier = set (visible_facets )
895+ f_expl = set ()
896+ r_expl = set (range (- 6 , 0 ))
897+ while f_frontier :
898+ new_facets = set ()
899+ for f_num in f_frontier :
900+ regions = pmesh .facet_neighbors [f_num ]
901+ new_regions = set (regions ) - r_expl
902+ for r in new_regions :
903+ phase = phases [pmesh .phase_numbers [r ]]
904+ p_type = phase .get ('material_type' , 'solid' )
905+ if p_type in _misc .kw_void :
906+ new_facets |= set (pmesh .regions [r ])
907+ r_expl .update (r )
908+ f_expl |= f_frontier
909+
910+ f_frontier |= new_facets
911+ f_frontier -= f_expl
912+
913+ plot_facets = [i for i , fn in enumerate (pmesh .facet_neighbors ) if
914+ len (set (fn ) - r_expl ) == 1 ]
915+ plot_regions = [list (set (pmesh .facet_neighbors [i ]) - r_expl )[0 ] for i
916+ in plot_facets ]
917+ poly_phases = [pmesh .phase_numbers [r ] for r in plot_regions ]
918+ p_dict = {f : p for f , p in zip (plot_facets , poly_phases )}
919+ plot_phases = [p_dict .get (f , n ) for f in tmesh .facet_attributes ]
920+ tmesh .facet_attributes = plot_phases
921+
922+ tmesh .plot (facecolors = cs , index_by = 'attribute' , material = phase_names ,
923+ ** edge_kwargs )
924+
925+ tmesh .element_attributes = old_e_att
926+ tmesh .facet_attributes = old_f_att
927+ else :
928+ fcs = _poly_colors (pmesh , phases , color_by , colormap , n_dim )
929+ tmesh .plot (facecolors = fcs , index_by = 'attribute' , ** edge_kwargs )
907930
908931 # format axes
909932 lims = np .array ([np .min (tmesh .points , 0 ), np .max (tmesh .points , 0 )]).T
0 commit comments