diff --git a/heatmap/heatmap.py b/heatmap/heatmap.py index f386537..d933f26 100644 --- a/heatmap/heatmap.py +++ b/heatmap/heatmap.py @@ -95,6 +95,8 @@ def value_to_size(val): ax.set_xlabel(kwargs.get('xlabel', '')) ax.set_ylabel(kwargs.get('ylabel', '')) + + main_ax = ax # Add color legend on the right side of the plot if color_min < color_max: @@ -118,12 +120,17 @@ def value_to_size(val): ax.set_xticks([]) # Remove horizontal ticks ax.set_yticks(np.linspace(min(bar_y), max(bar_y), 3)) # Show vertical ticks for min, middle and max ax.yaxis.tick_right() # Show vertical ticks on the right + + colorbar_ax = ax + return main_ax, colorbar_ax + + return main_ax def corrplot(data, size_scale=500, marker='s'): corr = pd.melt(data.reset_index(), id_vars='index').replace(np.nan, 0) corr.columns = ['x', 'y', 'value'] - heatmap( + return heatmap( corr['x'], corr['y'], color=corr['value'], color_range=[-1, 1], palette=sns.diverging_palette(20, 220, n=256),