Skip to content

Commit fe1fa17

Browse files
committed
Brillplot: work with Agg backend, create 3d prettyplot base
It's important that an Axis passed to the Pymatgen Brillouin plotter is already 3D or we get some rather confusing error messages. small cleanup
1 parent 0b70c05 commit fe1fa17

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

sumo/cli/brillplot.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323

2424
from pymatgen.io.vasp.outputs import BSVasprun
2525
from pymatgen.electronic_structure.bandstructure import get_reconstructed_band_structure
26-
from pymatgen.electronic_structure.plotter import BSPlotter
26+
from pymatgen.electronic_structure.plotter import plot_brillouin_zone
2727
import matplotlib as mpl
28-
2928
mpl.use("Agg")
29+
from sumo.plotting import pretty_plot_3d
30+
3031
__author__ = "Arthur Youd"
3132
__version__ = "1.0"
3233
__maintainer__ = "Alex Ganose"
@@ -35,6 +36,7 @@
3536

3637

3738
def brillplot(filenames=None, prefix=None, directory=None,
39+
width=6, height=6, fonts=None,
3840
image_format="pdf", dpi=400):
3941
"""Generate plot of first brillouin zone from a band-structure calculation.
4042
Args:
@@ -57,17 +59,28 @@ def brillplot(filenames=None, prefix=None, directory=None,
5759
bs = vr.get_band_structure(line_mode=True)
5860
bandstructures.append(bs)
5961
bs = get_reconstructed_band_structure(bandstructures)
60-
plotter = BSPlotter(bs)
61-
plt = plotter.plot_brillouin()
62+
63+
labels = {}
64+
for k in bs.kpoints:
65+
if k.label:
66+
labels[k.label] = k.frac_coords
67+
68+
lines = []
69+
for b in bs.branches:
70+
lines.append([bs.kpoints[b['start_index']].frac_coords,
71+
bs.kpoints[b['end_index']].frac_coords])
72+
73+
plt = pretty_plot_3d(width, height, dpi=dpi, fonts=fonts)
74+
fig = plot_brillouin_zone(bs.lattice_rec, lines=lines, labels=labels,
75+
ax=plt.gca())
6276

6377
basename = "brillouin.{}".format(image_format)
6478
filename = "{}_{}".format(prefix, basename) if prefix else basename
6579
if directory:
6680
filename = os.path.join(directory, filename)
67-
plt.savefig(filename, format=image_format, dpi=dpi, bbox_inches="tight")
81+
fig.savefig(filename, format=image_format, dpi=dpi, bbox_inches="tight")
6882
return plt
6983

70-
7184
def find_vasprun_files():
7285
"""Search for vasprun files from the current directory.
7386
@@ -99,7 +112,7 @@ def find_vasprun_files():
99112

100113
def _get_parser():
101114
parser = argparse.ArgumentParser(description="""
102-
brillplot is a script to produce publication-ready
115+
brillplot is a script to produce publication-ready
103116
brillouin zone diagrams""",
104117
epilog="""
105118
Author: {}

sumo/plotting/__init__.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from matplotlib.collections import LineCollection
1313
from matplotlib import rc, rcParams
1414
from pkg_resources import resource_filename
15+
from mpl_toolkits.mplot3d import Axes3D
1516

1617
colour_cache = {}
1718

@@ -38,7 +39,6 @@ def styled_plot(*style_sheets):
3839
"""
3940

4041
def decorator(get_plot):
41-
4242
def wrapper(*args, fonts=None, style=None, no_base_style=False,
4343
**kwargs):
4444

@@ -96,6 +96,27 @@ def pretty_plot(width=None, height=None, plt=None, dpi=None):
9696
return plt
9797

9898

99+
def pretty_plot_3d(width=5, height=5, plt=None, dpi=None, fonts=None):
100+
if plt is None:
101+
plt = matplotlib.pyplot
102+
if width is None:
103+
width = matplotlib.rcParams['figure.figsize'][0]
104+
if height is None:
105+
height = matplotlib.rcParams['figure.figsize'][1]
106+
107+
if dpi is not None:
108+
matplotlib.rcParams['figure.dpi'] = dpi
109+
110+
fig = plt.figure(figsize=(width, height), dpi=dpi)
111+
112+
else:
113+
fig = plt.gcf()
114+
115+
ax = fig.add_subplot(111, projection='3d')
116+
117+
return plt
118+
119+
99120
def pretty_subplot(nrows, ncols, width=None, height=None, sharex=True,
100121
sharey=True, dpi=None, plt=None, gridspec_kw=None):
101122
"""Get a :obj:`matplotlib.pyplot` subplot object with pretty defaults.
@@ -133,7 +154,6 @@ def pretty_subplot(nrows, ncols, width=None, height=None, sharex=True,
133154
plt.subplots(nrows, ncols, sharex=sharex, sharey=sharey, dpi=dpi,
134155
figsize=(width, height), facecolor='w',
135156
gridspec_kw=gridspec_kw)
136-
137157
return plt
138158

139159

0 commit comments

Comments
 (0)