Skip to content

Commit 70d9b2d

Browse files
author
Joseph
committed
Increase test coverage
1 parent 3c88fe4 commit 70d9b2d

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

markov_builder/MarkovChain.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,7 @@ def eval_transition_matrix(self, rates_dict: dict) -> Tuple[List[str], sp.Matrix
283283
"""
284284

285285
if rates_dict is None:
286-
assert len(rates_dict) > 0
287-
rates_dict = self.rate_expressions
286+
Exception('rate dictionary not provided')
288287

289288
l, Q = self.get_transition_matrix(use_parameters=True)
290289
Q_evaled = np.array(Q.evalf(subs=rates_dict)).astype(np.float64)
@@ -428,7 +427,9 @@ def sample_trajectories(self, no_trajectories: int, time_range: list = [0, 1],
428427
if s_i == 0:
429428
waiting_times[state_index] = np.inf
430429
else:
431-
waiting_times[state_index] = self.rng.exponential(mean_waiting_times[state_index] / (s_i))
430+
waiting_times[state_index] =\
431+
self.rng.exponential(mean_waiting_times[state_index] /
432+
(s_i))
432433

433434
if t + min(waiting_times) > time_range[1] - time_range[0]:
434435
break
@@ -463,7 +464,15 @@ def get_equilibrium_distribution(self, param_dict: dict = None) -> Tuple[List[st
463464
A, B = self.eliminate_state_from_transition_matrix(use_parameters=True)
464465

465466
labels = self.graph.nodes()
466-
ss = -np.array(A.LUsolve(B).evalf(subs=param_dict)).astype(np.float64)
467+
try:
468+
ss = -np.array(A.LUsolve(B).evalf(subs=param_dict)).astype(np.float64)
469+
470+
except TypeError as exc:
471+
logging.warning("Couldn't evaluate equilibrium distribution as float."
472+
"Is every parameter defined?"
473+
"%s" % str(exc))
474+
raise exc
475+
467476
logging.debug("ss is %s", ss)
468477
ss = np.append(ss, 1 - ss.sum())
469478
return labels, ss
@@ -781,6 +790,9 @@ def define_auxiliary_expression(self, expression: sp.Expr, label: str = None, de
781790
self.default_values = {**self.default_values, **default_values}
782791
self.auxiliary_expression = expression
783792

793+
def get_states(self):
794+
return list(self.graph)
795+
784796
def as_latex(self, state_to_remove: str = None, include_auxiliary_expression: bool = False,
785797
column_vector=True, label_order: list = None) -> str:
786798
"""Creates a LaTeX expression describing the Markov chain, its parameters and

tests/test_MarkovChain.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import markov_builder.example_models as example_models
1313

14+
from markov_builder.rate_expressions import negative_rate_expr, positive_rate_expr
15+
1416

1517
class TestMarkovChain(unittest.TestCase):
1618

@@ -55,10 +57,13 @@ def test_transition_matrix(self):
5557
nx.drawing.nx_agraph.write_dot(mc.graph, "Beattie_dotfile.dot")
5658

5759
# Draw graph using pyvis
60+
mc.draw_graph(os.path.join(self.output_dir, "BeattieModel.html"), show_options=True)
5861
mc.draw_graph(os.path.join(self.output_dir, "BeattieModel.html"))
5962
logging.debug(mc.graph)
6063

64+
labels, Q = mc.get_transition_matrix(use_parameters=True)
6165
labels, Q = mc.get_transition_matrix()
66+
6267
logging.debug("Q^T matrix is {}, labels are {}".format(Q.T, labels))
6368

6469
system = mc.eliminate_state_from_transition_matrix(['C', 'O', 'I'])
@@ -84,7 +89,23 @@ def test_construct_examples(self):
8489
nx.drawing.nx_agraph.write_dot(mc.graph, "%s_dotfile.dot" % name)
8590
nx.drawing.nx_agraph.write_dot(mc.graph, "%s_dotfile.dot" % name)
8691

87-
def test_parameterise_rates(self):
92+
def test_parameterise_rates_no_default(self):
93+
"""Test parameterise rates using a dictionary with no default parameter values.
94+
95+
Using the Beattie model
96+
"""
97+
98+
mc = example_models.construct_four_state_chain()
99+
100+
rate_dictionary = {'k_1': positive_rate_expr,
101+
'k_2': negative_rate_expr,
102+
'k_3': positive_rate_expr,
103+
'k_4': negative_rate_expr,
104+
}
105+
106+
mc.parameterise_rates(rate_dictionary, shared_variables=('V',))
107+
108+
def test_myokit_output(self):
88109
"""
89110
Test the MarkovChain.parameterise_rates function.
90111
"""
@@ -109,7 +130,11 @@ def test_parameterise_rates(self):
109130
self.assertEqual(param_list.count('V'), 1)
110131

111132
# Generate myokit code
133+
myokit_model = mc.generate_myokit_model(eliminate_state='O')
112134
myokit_model = mc.generate_myokit_model()
135+
myokit_model = mc.generate_myokit_model(drug_binding=True)
136+
myokit_model = mc.generate_myokit_model(drug_binding=True, eliminate_state='O')
137+
113138
myokit.save(os.path.join(self.output_dir, 'beattie_model.mmt'), myokit_model)
114139

115140
# Eliminate last node
@@ -214,6 +239,7 @@ def test_latex_printing(self):
214239
for mc in models:
215240
logging.debug(f"Printing latex for {mc.name}")
216241
logging.debug(mc.as_latex())
242+
logging.debug(mc.as_latex(label_order=mc.get_states()))
217243
logging.debug(mc.as_latex(state_to_remove='O'))
218244
logging.debug(mc.as_latex(include_auxiliary_expression=True))
219245
logging.debug(mc.as_latex('O', True))
@@ -232,9 +258,13 @@ def test_sample_trajectories(self):
232258
param_dict = mc.default_values
233259
param_dict['V'] = 0
234260

261+
self.assertRaises(TypeError, mc.get_equilibrium_distribution)
262+
235263
labels, eqm_dist = mc.get_equilibrium_distribution(param_dict=param_dict)
236264
starting_distribution = [int(val) for val in n_samples * eqm_dist]
237265

266+
df = mc.sample_trajectories(n_samples, (0, 10))
267+
238268
df = mc.sample_trajectories(n_samples, (0, 250), param_dict=param_dict,
239269
starting_distribution=starting_distribution)
240270
df = df.set_index('time')

0 commit comments

Comments
 (0)