
Colormap not preserved in scatter plot when generating baseline plot. #84

Open
@rcjackson

Description


Hello, when we generate a baseline image for a unit test that uses a scatter plot, pytest-mpl does not preserve the colormap used in the scatter plot.

For example, the plot we wish to compare against is this:
[Figure: expected output (myfig)]

When we run the test by simply importing it and showing the resulting figure, we get the figure above. However, whenever we use pytest to generate the baseline figure, we get the figure below:
[Figure: baseline generated by pytest-mpl (test_time_height_scatter)]

Main test:

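# Imports assumed from ACT's package layout at the time (not shown in the original report)
import pytest
import act.io.armfiles as arm
import act.tests.sample_files as sample_files
from act.plotting import TimeSeriesDisplay


@pytest.mark.mpl_image_compare(tolerance=30)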
def test_time_height_scatter():
    sonde_ds = arm.read_netcdf(
        sample_files.EXAMPLE_SONDE1)

    display = TimeSeriesDisplay({'sgpsondewnpnC1.b1': sonde_ds},
                                figsize=(7, 3))
    display.time_height_scatter('tdry', day_night_background=True)
    sonde_ds.close()

    return display.fig
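To check whether the problem depends on ACT at all, a minimal stand-alone test along these lines might help. It reproduces only the scatter and color bar part of `time_height_scatter` using synthetic data; the test name and the data are hypothetical, and it assumes pytest and pytest-mpl are installed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, as typically used under pytest

import matplotlib.pyplot as plt
import numpy as np
import pytest


# Hypothetical stand-alone version of the scatter part of time_height_scatter,
# using synthetic data instead of ACT's sample sonde file.
@pytest.mark.mpl_image_compare(tolerance=30)
def test_scatter_colormap_synthetic():
    rng = np.random.default_rng(0)
    time = np.arange(200)  # stand-in for the time dimension
    # Synthetic "temperature": a linear trend plus noise
    tdry = 20.0 - 0.05 * time + rng.normal(0.0, 1.0, time.size)

    fig, ax = plt.subplots(figsize=(7, 3))
    # Same colormap lookup style as time_height_scatter, via pyplot.get_cmap
    sc = ax.scatter(time, tdry, c=tdry, marker=".",
                    cmap=plt.get_cmap("rainbow"))
    fig.colorbar(sc, ax=ax)
    return fig
```

If the colors already differ between the image written by `pytest --mpl-generate-path=baseline` and the interactively shown figure for this synthetic version, that would suggest the difference comes from how the test is rendered rather than from `time_height_scatter` itself.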

time_height_scatter routine:

    def time_height_scatter(
            self, data_field=None, dsname=None, cmap='rainbow',
            alt_label=None, alt_field='alt', cb_label=None, **kwargs):
        """
        Create a time series plot of altitude and a data variable, with
        color also indicating the data value via a color bar. The color bar
        is positioned to serve both as the indicator of the color intensity
        and as the second y-axis.

        Parameters
        ----------
        data_field: str
            Name of the data field in the object to plot on the second y-axis.
        dsname: str or None
            The name of the datastream to plot.
        cmap: str
            Colormap to use for the scatter points and color bar.
        alt_label: str
            Altitude (first y-axis) label to use. If not set, will try to use
            long_name and units.
        alt_field: str
            Name of the altitude field in the object to plot on the first y-axis.
        cb_label: str
            Color bar label to use. If not set, will try to use
            long_name and units.
        **kwargs: keyword arguments
            Any other keyword arguments, passed to TimeSeriesDisplay.plot
            when the altitude time series is drawn.
        """
        if dsname is None and len(self._arm.keys()) > 1:
            raise ValueError(("You must choose a datastream when there are 2 "
                              "or more datasets in the TimeSeriesDisplay "
                              "object."))
        elif dsname is None:
            dsname = list(self._arm.keys())[0]

        # Get data and dimensions
        data = self._arm[dsname][data_field]
        altitude = self._arm[dsname][alt_field]
        dim = list(self._arm[dsname][data_field].dims)
        xdata = self._arm[dsname][dim[0]]

        if alt_label is None:
            try:
                alt_label = (altitude.attrs['long_name'] +
                             ''.join([' (', altitude.attrs['units'], ')']))
            except KeyError:
                alt_label = alt_field

        if cb_label is None:
            try:
                cb_label = (data.attrs['long_name'] +
                            ''.join([' (', data.attrs['units'], ')']))
            except KeyError:
                cb_label = data_field

        colorbar_map = plt.cm.get_cmap(cmap)
        self.fig.subplots_adjust(left=0.1, right=0.86,
                                 bottom=0.16, top=0.91)
        ax1 = self.plot(alt_field, color='black', **kwargs)
        ax1.set_ylabel(alt_label)
        ax2 = ax1.twinx()
        sc = ax2.scatter(xdata.values, data.values, c=data.values,
                         marker='.', cmap=colorbar_map)
        cbaxes = self.fig.add_axes(
            [self.fig.subplotpars.right + 0.02, self.fig.subplotpars.bottom,
             0.02, self.fig.subplotpars.top - self.fig.subplotpars.bottom])
        cbar = plt.colorbar(sc, cax=cbaxes)
        ax2.set_ylim(cbar.get_clim())
        cbar.ax.set_ylabel(cb_label)
        ax2.set_yticklabels([])

        return self.axes[0]
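pytest-mpl renders each test inside a Matplotlib style context (its README describes `classic` as the default, overridable through the `style` argument of `mpl_image_compare`), whereas showing the figure by hand uses the local rcParams. A quick way to see whether the colormap itself is lost, or only the marker appearance changes, is to compare the colors the scatter collection computes under two styles. This is a diagnostic sketch with synthetic data; `scatter_colors` is a hypothetical helper, not part of ACT or pytest-mpl.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen

import matplotlib.pyplot as plt
import numpy as np


def scatter_colors(style):
    """Draw a small colormapped scatter under `style`; return (face, edge) RGBA arrays."""
    values = np.linspace(0.0, 1.0, 5)
    with plt.style.context(style):
        fig, ax = plt.subplots()
        sc = ax.scatter(values, values, c=values, marker=".", cmap="rainbow")
        fig.canvas.draw()  # force the colormap to be applied to the collection
        faces = np.array(sc.get_facecolor())
        edges = np.array(sc.get_edgecolor())
        plt.close(fig)
    return faces, edges


# Print how many distinct face/edge colors each style produces
for style in ("default", "classic"):
    faces, edges = scatter_colors(style)
    print(style, "distinct face colors:", len(np.unique(faces, axis=0)))
    print(style, "distinct edge colors:", len(np.unique(edges, axis=0)))
```

If the face colors are colormapped under both styles but the edge colors differ, the visible difference would come from marker edges drawn over the small `'.'` markers rather than from the colormap being dropped.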

Any help you could provide on this would be appreciated.
