Skip to content

Commit

Permalink
run black on code. Final changes for #7
Browse files Browse the repository at this point in the history
  • Loading branch information
nhsavage committed Nov 6, 2020
1 parent 33c85d9 commit 4836de1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 35 deletions.
47 changes: 25 additions & 22 deletions lib/catnip/tests/test_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 +72,26 @@ def tearDown(self):
pass

def test_get_xy_noborder(self):
'''
"""
Test basic version of _get_xy_noborder
'''
"""

indices_expected = (2, 4, 1, 4)

# set up test data
d = np.full((5, 5), True)
# fill some values
d[3,2]=False
d[3,3,]=False
d[1,3]=False
d[3, 2] = False
d[3, 3,] = False
d[1, 3] = False

indices_actual = _get_xy_noborder(d)
self.assertEqual(indices_expected, indices_actual)

def test_get_xy_noborder_false(self):
'''
"""
Test get_xy_noborder when mask all False
'''
"""

indices_expected = (0, 5, 0, 5)

Expand All @@ -102,9 +102,9 @@ def test_get_xy_noborder_false(self):
self.assertEqual(indices_expected, indices_actual)

def test_get_xy_noborder_true(self):
'''
"""
Test _get_xy_noborder when all True - error
'''
"""

# set up test data
d = np.full((5, 5), True)
Expand Down Expand Up @@ -161,29 +161,32 @@ def test_add_coord_system(self):
self.assertRaises(TypeError, add_coord_system, self.mslp_daily_cube)

def test_extract_rot_cube(self):
'''
"""
Test the shape of the extracted cube and min/max
lat and lon
'''
"""

# define box to extract
min_lat = 50
min_lon = -10
max_lat = 60
max_lon = 0

tcube = self.rcm_monthly_cube.extract_strict('air_temperature')
extracted_cube = extract_rot_cube(tcube, min_lat, min_lon,
max_lat, max_lon)
tcube = self.rcm_monthly_cube.extract_strict("air_temperature")
extracted_cube = extract_rot_cube(tcube, min_lat, min_lon, max_lat, max_lon)
self.assertEqual(np.shape(extracted_cube.data), (2, 102, 78))
self.assertEqual(np.max(extracted_cube.coord('latitude').points),
61.365291870327816)
self.assertEqual(np.min(extracted_cube.coord('latitude').points),
48.213032844268646)
self.assertEqual(np.max(extracted_cube.coord('longitude').points),
3.642576550089792)
self.assertEqual(np.min(extracted_cube.coord('longitude').points),
-16.29169201066359)
self.assertEqual(
np.max(extracted_cube.coord("latitude").points), 61.365291870327816
)
self.assertEqual(
np.min(extracted_cube.coord("latitude").points), 48.213032844268646
)
self.assertEqual(
np.max(extracted_cube.coord("longitude").points), 3.642576550089792
)
self.assertEqual(
np.min(extracted_cube.coord("longitude").points), -16.29169201066359
)

def test_add_time_coord_cats(self):
cube = self.mslp_daily_cube.copy()
Expand Down
27 changes: 14 additions & 13 deletions lib/catnip/tests/test_visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import imagehash
from PIL import Image
from catnip.visualisation import vector_plot, plot_regress

# import matplotlob after catnip vector plot as that sets the Agg
# backend
import matplotlib.pyplot as plt
Expand All @@ -48,23 +49,23 @@
#: Default maximum perceptual hash hamming distance.
_HAMMING_DISTANCE = 0


def _compare_images(figure, expected_filename):
'''
"""
Use imagehash to compare images fast and reliably.
Returns True if they match within tolerance, false
otherwise
'''
"""
img_buffer = io.BytesIO()
figure.savefig(img_buffer, format="png")
img_buffer.seek(0)
gen_phash = imagehash.phash(Image.open(img_buffer), hash_size=_HASH_SIZE)
exp_phash = imagehash.phash(
Image.open(expected_filename), hash_size=_HASH_SIZE
)
exp_phash = imagehash.phash(Image.open(expected_filename), hash_size=_HASH_SIZE)
distance = abs(gen_phash - exp_phash)
return distance <= _HAMMING_DISTANCE


class TestVisualisation(unittest.TestCase):
"""Unittest class for visualisation module"""

Expand All @@ -74,10 +75,10 @@ def setUpClass(self):
file2 = os.path.join(conf.DATA_DIR, "gcm_monthly.pp")
self.rcm_monthly_cube = iris.load(file1)
self.gcm_monthly_cube = iris.load(file2)
self.gcm_u = self.gcm_monthly_cube.extract_strict('x_wind')
self.gcm_v = self.gcm_monthly_cube.extract_strict('y_wind')
self.rcm_u = self.rcm_monthly_cube.extract_strict('x_wind')[0,...]
self.rcm_v = self.rcm_monthly_cube.extract_strict('y_wind')[0,...]
self.gcm_u = self.gcm_monthly_cube.extract_strict("x_wind")
self.gcm_v = self.gcm_monthly_cube.extract_strict("y_wind")
self.rcm_u = self.rcm_monthly_cube.extract_strict("x_wind")[0, ...]
self.rcm_v = self.rcm_monthly_cube.extract_strict("y_wind")[0, ...]

@classmethod
def tearDownClass(cls):
Expand Down Expand Up @@ -109,7 +110,7 @@ def test_vector_plot_gcm_title(self):

expected_png = os.path.join(conf.KGO_DIR, "gcm_ws_title.png")

vector_plot(self.gcm_u, self.gcm_v, title='GCM W/S')
vector_plot(self.gcm_u, self.gcm_v, title="GCM W/S")

actual_fig = plt.gcf()
self.assertTrue(_compare_images(actual_fig, expected_png))
Expand All @@ -136,13 +137,13 @@ def test_vector_plot_gcm_n10(self):
vector_plot(self.gcm_u, self.gcm_v, npts=10)

actual_fig = plt.gcf()
plt.savefig('/scratch/fris/gcm_ws_n10.png')
plt.savefig("/scratch/fris/gcm_ws_n10.png")
self.assertTrue(_compare_images(actual_fig, expected_png))

def test_vector_rot_error(self):
'''
"""
Test that passing a global field gives an Exception
'''
"""

self.assertRaises(Exception, vector_plot, self.gcm_u, self.gcm_v, unrotate=True)

Expand Down

0 comments on commit 4836de1

Please sign in to comment.