|
15 | 15 | # |
16 | 16 | # You should have received a copy of the GNU Lesser General Public License |
17 | 17 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
18 | | - |
| 18 | +import shutil |
19 | 19 | import tempfile |
20 | | -import unittest |
21 | 20 |
|
| 21 | +import pytest |
| 22 | +from numpy.testing import assert_ |
22 | 23 |
|
23 | 24 | from pyemma._base.serialization.cli import main |
24 | 25 | from pyemma.coordinates import source, tica, cluster_kmeans |
25 | 26 |
|
26 | 27 |
|
27 | | -class TestListModelCLI(unittest.TestCase): |
28 | | - @classmethod |
29 | | - def setUpClass(cls): |
| 28 | +@pytest.fixture |
| 29 | +def model_file(): |
| 30 | + file = None |
| 31 | + try: |
30 | 32 | from pyemma.datasets import get_bpti_test_data |
31 | | - |
32 | 33 | d = get_bpti_test_data() |
33 | 34 | trajs, top = d['trajs'], d['top'] |
34 | 35 | s = source(trajs, top=top) |
35 | 36 |
|
36 | 37 | t = tica(s, lag=1) |
37 | 38 |
|
38 | 39 | c = cluster_kmeans(t) |
39 | | - cls.model_file = tempfile.mktemp() |
40 | | - c.save(cls.model_file, save_streaming_chain=True) |
41 | | - |
42 | | - @classmethod |
43 | | - def tearDownClass(cls): |
44 | | - import os |
45 | | - os.unlink(cls.model_file) |
46 | | - |
47 | | - def test_recursive(self): |
48 | | - """ check the whole chain has been printed""" |
49 | | - from pyemma.util.contexts import Capturing |
50 | | - with Capturing() as out: |
51 | | - main(['--recursive', self.model_file]) |
52 | | - assert out |
53 | | - all_out = '\n'.join(out) |
54 | | - self.assertIn('FeatureReader', all_out) |
55 | | - self.assertIn('TICA', all_out) |
56 | | - self.assertIn('Kmeans', all_out) |
57 | | - |
58 | | - |
59 | | -if __name__ == '__main__': |
60 | | - unittest.main() |
| 40 | + file = tempfile.mktemp() |
| 41 | + c.save(file, save_streaming_chain=True) |
| 42 | + |
| 43 | + yield file |
| 44 | + finally: |
| 45 | + if file is not None: |
| 46 | + shutil.rmtree(file, ignore_errors=True) |
| 47 | + |
| 48 | + |
| 49 | +def test_recursive(model_file): |
| 50 | + """ check the whole chain has been printed""" |
| 51 | + from pyemma.util.contexts import Capturing |
| 52 | + with Capturing() as out: |
| 53 | + main(['--recursive', model_file]) |
| 54 | + assert out |
| 55 | + all_out = '\n'.join(out) |
| 56 | + assert_('FeatureReader' in all_out) |
| 57 | + assert_('TICA' in all_out) |
| 58 | + assert_('Kmeans' in all_out) |
0 commit comments