Skip to content

Commit

Permalink
Handle matplotlib ax
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnectedSystems committed Sep 11, 2019
1 parent d4cfaf1 commit 0418389
Showing 1 changed file with 63 additions and 33 deletions.
96 changes: 63 additions & 33 deletions wosis/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ def plot_saver(func):
@wraps(func)
def wrapper(*args, **kwargs):
save_plot_fn = kwargs.pop('save_plot_fn', None)
fig = func(*args, **kwargs)
ret = kwargs.pop('return_fig', None)

axes = func(*args, **kwargs)
if isinstance(axes, list):
fig = axes[0].get_figure()
else:
fig = axes.get_figure()
# End if

plt.tight_layout()

if save_plot_fn:
Expand All @@ -33,9 +41,8 @@ def wrapper(*args, **kwargs):
dpi=300, bbox_inches='tight')
# End if

ret = kwargs.pop('return_fig', None)
if ret:
return fig
return axes
# End wrapper()

return wrapper
Expand Down Expand Up @@ -78,7 +85,7 @@ def _set_title(title, num_text):


@plot_saver
def plot_pub_trend(search_results, title=None, no_log_scale=False):
def plot_pub_trend(search_results, title=None, no_log_scale=False, ax=None):
"""Plot publication trend across time.
Will publication trend in log scale if large number of publications found.
Expand All @@ -89,6 +96,8 @@ def plot_pub_trend(search_results, title=None, no_log_scale=False):
* search_results : MetaKnowledge RecordCollection, of search results
* title : str, title for plot
* no_log_scale : bool, avoid log scale
* ax : Axis Object or None, matplotlib axis object to add to
* kwargs: additional args passed to matplotlib
Returns
==========
Expand All @@ -110,7 +119,8 @@ def plot_pub_trend(search_results, title=None, no_log_scale=False):
if i in num_pubs.index else 0 for i in idx.year]},
index=idx)

fig, axes = plt.subplots(1)
if not ax:
fig, ax = plt.subplots(1)

# Rotate x-axis labels if there is enough space
rot = 45 if len(num_pubs.index) < 20 else 90
Expand All @@ -131,31 +141,32 @@ def plot_pub_trend(search_results, title=None, no_log_scale=False):
# End if

# force y-axis to use integer values
axes.yaxis.set_major_locator(MaxNLocator(integer=True))
num_pubs.plot(kind='bar', figsize=(9, 6), ax=axes, rot=rot, logy=log_form, legend=False)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
num_pubs.plot(kind='bar', figsize=(9, 6), ax=ax, rot=rot, logy=log_form, legend=False)

if len(num_pubs.index) > tick_threshold:
axes.set_xticks([i for i in range(0, len(num_pubs.index), 2)])
axes.set_xticklabels([i.year for i in num_pubs.index[::2]])
ax.set_xticks([i for i in range(0, len(num_pubs.index), 2)])
ax.set_xticklabels([i.year for i in num_pubs.index[::2]])

axes.set_xlabel("Year")
ax.set_xlabel("Year")
ax_label = "Num. Publications"
if log_form:
ax_label += "\n(log scale)"

axes.set_ylabel(ax_label)
ax.set_ylabel(ax_label)

return fig
return ax
# End plot_pub_trend()


@plot_saver
def plot_kw_trend(search_results, title=None, no_log_scale=False):
def plot_kw_trend(search_results, title=None, no_log_scale=False, ax=None):
"""Plot keyword trends across time.
Parameters
==========
* search_results : MetaKnowledge RecordCollection, of search results
* ax : List[Axis Object] or None, matplotlib axis object to add to
See Also
==========
Expand Down Expand Up @@ -190,7 +201,10 @@ def plot_kw_trend(search_results, title=None, no_log_scale=False):
num_pubs = pd.DataFrame({'count': [num_pubs.loc[i, 'count'] if i in num_pubs.index else 0 for i in idx.year]},
index=idx)

fig, (ax1, ax2) = plt.subplots(1, 2)
if not ax:
fig, (ax1, ax2) = plt.subplots(1, 2)
else:
ax1, ax2 = ax

# Rotate x-axis labels if there is enough space
rot = 45 if len(kw_trend) < 20 else 90
Expand Down Expand Up @@ -233,12 +247,12 @@ def plot_kw_trend(search_results, title=None, no_log_scale=False):

ax2.set_ylabel(ax_label)

return fig
return fig.axes
# End plot_kw_trend()


@plot_saver
def plot_pub_per_kw(kw_matches, corpora, kw_category, annotate=False):
def plot_pub_per_kw(kw_matches, corpora, kw_category, annotate=False, ax=None):
"""Plot publications per keyword.
Parameters
Expand All @@ -247,6 +261,7 @@ def plot_pub_per_kw(kw_matches, corpora, kw_category, annotate=False):
* corpora : Metaknowledge Collection, representing corpora
* kw_category : str, text indicating keyword category for use in plot title
* annotate : bool, display number of records in plot
* ax : Axis Object or None, matplotlib axis object to add to
Example
==========
Expand Down Expand Up @@ -278,7 +293,7 @@ def plot_pub_per_kw(kw_matches, corpora, kw_category, annotate=False):
summary.keys()), columns=['Keyword', 'Count'])
pubs_per_kw.sort_values(by='Count', inplace=True)

ax = pubs_per_kw.plot(kind='bar', title=ptitle, figsize=(8, 6))
ax = pubs_per_kw.plot(kind='bar', title=ptitle, figsize=(8, 6), ax=ax)

if annotate:
# Annotate number above bar
Expand All @@ -289,7 +304,7 @@ def plot_pub_per_kw(kw_matches, corpora, kw_category, annotate=False):
# force y-axis to use integer values
ax.yaxis.set_major_locator(MaxNLocator(integer=True))

return ax.get_figure()
return ax
# End plot_pub_per_kw()


Expand All @@ -306,7 +321,7 @@ def _prep_journal_records(search_results):


@plot_saver
def plot_pubs_per_journal(search_results, top_n=10, annotate=False, show_stats=True, title=None):
def plot_pubs_per_journal(search_results, top_n=10, annotate=False, show_stats=True, title=None, ax=None):
"""Plot horizontal bar plot of publications for each journal in descending order.
Parameters
Expand All @@ -315,6 +330,7 @@ def plot_pubs_per_journal(search_results, top_n=10, annotate=False, show_stats=T
* top_n : int, number of journals to display (default: 10)
* annotate : bool, annotate plot with values (default: False)
* print_stats : bool, print out percentage of publications the results represent
* ax : Axis Object or None, matplotlib axis object to add to
Returns
==========
Expand Down Expand Up @@ -347,7 +363,8 @@ def plot_pubs_per_journal(search_results, top_n=10, annotate=False, show_stats=T
fontsize=12,
title=plot_title,
figsize=(12, 6),
color='#2b7bba' # match seaborn blue
color='#2b7bba', # match seaborn blue
ax=ax
)
ax.set_ylabel('')

Expand All @@ -357,7 +374,7 @@ def plot_pubs_per_journal(search_results, top_n=10, annotate=False, show_stats=T
ax.annotate("{}".format(p.get_width()),
(p.get_width() + 0.01, p.get_y()), fontsize=12)

return ax.get_figure()
return ax
# End plot_pubs_per_journal()


Expand Down Expand Up @@ -406,18 +423,19 @@ def plot_journal_pub_trend(search_results, title='Journal Publication Trend', to
[ax[0].legend([so], fontsize=10, loc='center left', bbox_to_anchor=(1.0, 0.5))
for ax, so in zip(axes, pubs_across_time.columns)]

return axes[0][0].get_figure()
return axes[0][0]
# End plot_journal_pub_trend()


@plot_saver
def plot_criteria_trend(corpora_df, threshold=3):
def plot_criteria_trend(corpora_df, threshold=3, ax=None):
"""Plot criteria membership across time.
Parameters
==========
* corpora_df : Pandas DataFrame, of records from wosis.analysis.search.collate_keyword_criteria_matches
* threshold : int, plot number of papers that are members of at least this number of criterias (default: 3)
* ax : Axis Object or None, matplotlib axis object to add to
Returns
==========
Expand All @@ -439,14 +457,15 @@ def plot_criteria_trend(corpora_df, threshold=3):
current_palette = sns.color_palette()

ax = grp_count.plot(kind='bar', color=current_palette[0],
title='Papers with Keywords\nin {} or More Criteria'.format(threshold))
title='Papers with Keywords\nin {} or More Criteria'.format(threshold),
ax=ax)

return ax.get_figure()
return ax
# End plot_criteria_trend()


@plot_saver
def plot_topic_trend(topics, total_rc=None, title='Topic Trend'):
def plot_topic_trend(topics, total_rc=None, title='Topic Trend', ax=None):
"""Plot the trends of topics over time.
Parameters
Expand All @@ -455,6 +474,7 @@ def plot_topic_trend(topics, total_rc=None, title='Topic Trend'):
* total_rc : RecordCollection or None, collection used to calculate topic proportion relative to the corpora.
If `None`, plots number of publications. Defaults to None.
* title : str, title for plot
* ax : Axis Object or None, matplotlib axis object to add to
Returns
==========
Expand All @@ -468,15 +488,18 @@ def plot_topic_trend(topics, total_rc=None, title='Topic Trend'):
rc = pd.DataFrame(total_rc.timeSeries('year'))
rc = rc.set_index('year', drop=True)
rc = rc['count']
y_label = 'Perc. of Corpora (%)'
y_label = 'Percentage of Corpora (%)'
mod = 100.0
else:
rc = 1
mod = 1
y_label = 'Num. Publications'

alpha_val = 0.7 if len(topics) > 1 else 1.0
fig, ax = plt.subplots(figsize=(12,6))

if not ax:
fig, ax = plt.subplots(figsize=(12,6))

plt.title(title)
for kwm in topics:
if hasattr(kwm, 'recs'):
Expand All @@ -494,7 +517,7 @@ def plot_topic_trend(topics, total_rc=None, title='Topic Trend'):

label = label + "\n({} publications)".format(df['count'].sum())

ax = ((df['count'] / rc) * mod).plot(legend=True, ax=ax, label=label, style='-o', alpha=alpha_val)
ax = ((df['count'] / rc) * mod).plot(legend=False, ax=ax, label=label, style='-o', alpha=alpha_val)
# End for

plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
Expand All @@ -504,31 +527,38 @@ def plot_topic_trend(topics, total_rc=None, title='Topic Trend'):
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_ylabel(y_label)

return ax.get_figure()
return ax
# End plot_topic_trend()


@plot_saver
def plot_citations(citation_df, top_n=10, plot_title='Citations', annotate=True):
def plot_citations(citation_df, top_n=10, plot_title='Citations', annotate=True, ax=None):
"""Plot citations per top `n` papers.
Parameters
==========
* citation_df : DataFrame, with citations column
* top_n : int, top `n` papers to show
* annotate : bool, show number of citations. Defaults to True.
* ax : Axis Object or None, matplotlib axis object to add to
Returns
==========
* Matplotlib plot object
"""
citation_df.index = citation_df['title'].str[0:25] + "..."
ax = citation_df['citations'][:top_n][::-1].plot(kind='barh', color='blue', title=plot_title, fontsize=10, figsize=(10,6))
ax = citation_df['citations'][:top_n][::-1].plot(kind='barh',
color='blue',
title=plot_title,
fontsize=10,
figsize=(10,6),
ax=ax)

if annotate:
# Annotate number above bar
for p in ax.patches:
ax.annotate("{}".format(p.get_width()),
(p.get_width() + 0.01, p.get_y()), fontsize=12)

return ax.get_figure()
return ax
# End plot_citations()

0 comments on commit 0418389

Please sign in to comment.