Skip to content

Commit 20e48aa

Browse files
committed
Expanding plot extents slightly due to change in sigma2d default. Adding test for disjoint parameter summaries.
1 parent 4b7b92d commit 20e48aa

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

chainconsumer/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def get_extents(data, weight, plot=False, wide_extents=True):
99
icdf = (1 - cdf)[::-1]
1010
icdf = icdf / icdf.max()
1111
cdf = 1 - icdf[::-1]
12-
threshold = 1e-3 if plot else 1e-5
12+
threshold = 1e-4 if plot else 1e-5
1313
if plot and not wide_extents:
1414
threshold = 0.05
1515
i1 = np.where(cdf > threshold)[0][0]

tests/test_analysis.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,20 @@ def test_summary_specific(self):
9191
diff = np.abs(expected - actual)
9292
assert np.all(diff < tolerance)
9393

94+
def test_summary_disjoint(self):
95+
tolerance = 5e-2
96+
consumer = ChainConsumer()
97+
consumer.add_chain(self.data, parameters="A")
98+
consumer.add_chain(self.data, parameters="B")
99+
consumer.configure(bins=0.8)
100+
summary = consumer.analysis.get_summary(parameters="A")
101+
assert len(summary) == 2 # Two chains
102+
assert summary[1] == {} # Second chain doesnt have param A
103+
actual = summary[0]["A"]
104+
expected = np.array([3.5, 5.0, 6.5])
105+
diff = np.abs(expected - actual)
106+
assert np.all(diff < tolerance)
107+
94108
def test_output_text(self):
95109
consumer = ChainConsumer()
96110
consumer.add_chain(self.data, parameters=["a"])

tests/test_plotter.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,35 @@ def test_plotter_extents1(self):
1414
c.add_chain(self.data, parameters=["x"])
1515
c.configure()
1616
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
17-
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
18-
assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1)
17+
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
18+
assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2)
1919

2020
def test_plotter_extents2(self):
2121
c = ChainConsumer()
2222
c.add_chain(self.data, parameters=["x"])
2323
c.add_chain(self.data + 5, parameters=["y"])
2424
c.configure()
2525
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
26-
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
27-
assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1)
26+
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
27+
assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2)
2828

2929
def test_plotter_extents3(self):
3030
c = ChainConsumer()
3131
c.add_chain(self.data, parameters=["x"])
3232
c.add_chain(self.data + 5, parameters=["x"])
3333
c.configure()
3434
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
35-
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
36-
assert np.isclose(maxv, (10.0 + 1.5 * 3.1), atol=0.1)
35+
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
36+
assert np.isclose(maxv, (10.0 + 1.5 * 3.7), atol=0.2)
3737

3838
def test_plotter_extents4(self):
3939
c = ChainConsumer()
4040
c.add_chain(self.data, parameters=["x"])
4141
c.add_chain(self.data + 5, parameters=["y"])
4242
c.configure()
4343
minv, maxv = c.plotter._get_parameter_extents("x", c.chains[:1])
44-
assert np.isclose(minv, (5.0 - 1.5 * 3.1), atol=0.1)
45-
assert np.isclose(maxv, (5.0 + 1.5 * 3.1), atol=0.1)
44+
assert np.isclose(minv, (5.0 - 1.5 * 3.7), atol=0.2)
45+
assert np.isclose(maxv, (5.0 + 1.5 * 3.7), atol=0.2)
4646

4747
def test_plotter_extents5(self):
4848
x, y = np.linspace(-3, 3, 200), np.linspace(-5, 5, 200)
@@ -55,4 +55,4 @@ def test_plotter_extents5(self):
5555
c.configure()
5656
minv, maxv = c.plotter._get_parameter_extents("x", c.chains)
5757
assert np.isclose(minv, -3, atol=0.001)
58-
assert np.isclose(maxv, 3, atol=0.001)
58+
assert np.isclose(maxv, 3, atol=0.001)

0 commit comments

Comments
 (0)