@@ -847,3 +847,82 @@ def test_dic_posterior_dependence(self):
847847 dic2 = 2 * np .mean (- 2 * p2 ) + 2 * norm .logpdf (0 , scale = 2 )
848848 assert np .isclose (bics [0 ], dic1 - dic2 , atol = 1e-3 )
849849
850+ def test_remove_last_chain (self ):
851+ tolerance = 5e-2
852+ consumer = ChainConsumer ()
853+ consumer .add_chain (self .data )
854+ consumer .add_chain (self .data * 2 )
855+ consumer .remove_chain ()
856+ consumer .configure (bins = 1.6 )
857+ summary = consumer .get_summary ()
858+ assert isinstance (summary , dict )
859+ actual = np .array (list (summary .values ())[0 ])
860+ expected = np .array ([3.5 , 5.0 , 6.5 ])
861+ diff = np .abs (expected - actual )
862+ assert np .all (diff < tolerance )
863+
864+ def test_remove_first_chain (self ):
865+ tolerance = 5e-2
866+ consumer = ChainConsumer ()
867+ consumer .add_chain (self .data * 2 )
868+ consumer .add_chain (self .data )
869+ consumer .remove_chain (chain = 0 )
870+ consumer .configure (bins = 1.6 )
871+ summary = consumer .get_summary ()
872+ assert isinstance (summary , dict )
873+ actual = np .array (list (summary .values ())[0 ])
874+ expected = np .array ([3.5 , 5.0 , 6.5 ])
875+ diff = np .abs (expected - actual )
876+ assert np .all (diff < tolerance )
877+
878+ def test_remove_chain_by_name (self ):
879+ tolerance = 5e-2
880+ consumer = ChainConsumer ()
881+ consumer .add_chain (self .data * 2 , name = "a" )
882+ consumer .add_chain (self .data , name = "b" )
883+ consumer .remove_chain (chain = "a" )
884+ consumer .configure (bins = 1.6 )
885+ summary = consumer .get_summary ()
886+ assert isinstance (summary , dict )
887+ actual = np .array (list (summary .values ())[0 ])
888+ expected = np .array ([3.5 , 5.0 , 6.5 ])
889+ diff = np .abs (expected - actual )
890+ assert np .all (diff < tolerance )
891+
892+ def test_remove_chain_recompute_params (self ):
893+ tolerance = 5e-2
894+ consumer = ChainConsumer ()
895+ consumer .add_chain (self .data * 2 , parameters = ["p1" ], name = "a" )
896+ consumer .add_chain (self .data , parameters = ["p2" ], name = "b" )
897+ consumer .remove_chain (chain = "a" )
898+ consumer .configure (bins = 1.6 )
899+ summary = consumer .get_summary ()
900+ assert isinstance (summary , dict )
901+ assert "p2" in summary
902+ assert "p1" not in summary
903+ actual = np .array (list (summary .values ())[0 ])
904+ expected = np .array ([3.5 , 5.0 , 6.5 ])
905+ diff = np .abs (expected - actual )
906+ assert np .all (diff < tolerance )
907+
908+ def test_remove_multiple_chains (self ):
909+ tolerance = 5e-2
910+ consumer = ChainConsumer ()
911+ consumer .add_chain (self .data * 2 , parameters = ["p1" ], name = "a" )
912+ consumer .add_chain (self .data , parameters = ["p2" ], name = "b" )
913+ consumer .add_chain (self .data * 3 , parameters = ["p3" ], name = "c" )
914+ consumer .remove_chain (chain = ["a" , "c" ])
915+ consumer .configure (bins = 1.6 )
916+ summary = consumer .get_summary ()
917+ assert isinstance (summary , dict )
918+ assert "p2" in summary
919+ assert "p1" not in summary
920+ assert "p3" not in summary
921+ actual = np .array (list (summary .values ())[0 ])
922+ expected = np .array ([3.5 , 5.0 , 6.5 ])
923+ diff = np .abs (expected - actual )
924+ assert np .all (diff < tolerance )
925+
926+ def test_remove_multiple_chains_fails (self ):
927+ with pytest .raises (AssertionError ):
928+ ChainConsumer ().add_chain (self .data ).remove_chain (chain = [0 , 0 ])
0 commit comments