Skip to content

Commit f6aa500

Browse files
authored
Merge pull request #202 from OliverSchacht/causallearn-pr
Add two variants of the KCI test
2 parents d450dd8 + b803533 commit f6aa500

File tree

8 files changed

+1069
-3
lines changed

8 files changed

+1069
-3
lines changed

causallearn/utils/FastKCI/FastKCI.py

+534
Large diffs are not rendered by default.

causallearn/utils/FastKCI/__init__.py

Whitespace-only changes.

causallearn/utils/RCIT/RCIT.py

+403
Large diffs are not rendered by default.

causallearn/utils/RCIT/__init__.py

Whitespace-only changes.

causallearn/utils/cit.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from scipy.stats import chi2, norm
66

77
from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd
8+
from causallearn.utils.FastKCI.FastKCI import FastKCI_CInd, FastKCI_UInd
9+
from causallearn.utils.RCIT.RCIT import RCIT as RCIT_CInd
10+
from causallearn.utils.RCIT.RCIT import RIT as RCIT_UInd
811
from causallearn.utils.PCUtils import Helper
912

1013
CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5
@@ -13,6 +16,8 @@
1316
mv_fisherz = "mv_fisherz"
1417
mc_fisherz = "mc_fisherz"
1518
kci = "kci"
19+
rcit = "rcit"
20+
fastkci = "fastkci"
1621
chisq = "chisq"
1722
gsq = "gsq"
1823
d_separation = "d_separation"
@@ -23,15 +28,19 @@ def CIT(data, method='fisherz', **kwargs):
2328
Parameters
2429
----------
2530
data: numpy.ndarray of shape (n_samples, n_features)
26-
method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "chisq", "gsq"]
27-
kwargs: placeholder for future arguments, or for KCI specific arguments now
31+
method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "rcit", "fastkci", "chisq", "gsq"]
32+
kwargs: placeholder for future arguments, or for KCI, FastKCI or RCIT specific arguments now
2833
TODO: ultimately kwargs should be replaced by explicit named parameters.
2934
check https://github.com/cmu-phil/causal-learn/pull/62#discussion_r927239028
3035
'''
3136
if method == fisherz:
3237
return FisherZ(data, **kwargs)
3338
elif method == kci:
3439
return KCI(data, **kwargs)
40+
elif method == fastkci:
41+
return FastKCI(data, **kwargs)
42+
elif method == rcit:
43+
return RCIT(data, **kwargs)
3544
elif method in [chisq, gsq]:
3645
return Chisq_or_Gsq(data, method_name=method, **kwargs)
3746
elif method == mv_fisherz:
@@ -43,6 +52,7 @@ def CIT(data, method='fisherz', **kwargs):
4352
else:
4453
raise ValueError("Unknown method: {}".format(method))
4554

55+
4656
class CIT_Base(object):
4757
# Base class for CIT, contains basic operations for input check and caching, etc.
4858
def __init__(self, data, cache_path=None, **kwargs):
@@ -193,6 +203,50 @@ def __call__(self, X, Y, condition_set=None):
193203
self.pvalue_cache[cache_key] = p
194204
return p
195205

206+
class FastKCI(CIT_Base):
    """CIT wrapper around the FastKCI unconditional / conditional testers.

    All keyword arguments are forwarded to ``CIT_Base``. The subset
    ``K``, ``J``, ``alpha`` is additionally passed to ``FastKCI_UInd`` and
    the subset ``K``, ``J``, ``alpha``, ``use_gp`` to ``FastKCI_CInd``.
    Computed p-values are stored in ``self.pvalue_cache``.
    """

    def __init__(self, data, **kwargs):
        super().__init__(data, **kwargs)
        uncond_keys = ['K', 'J', 'alpha']
        cond_keys = ['K', 'J', 'alpha', 'use_gp']
        uncond_options = {name: kwargs[name] for name in kwargs if name in uncond_keys}
        cond_options = {name: kwargs[name] for name in kwargs if name in cond_keys}

        # The cache is keyed on the conditional-test options, which are a
        # superset of the unconditional ones.
        # NOTE(review): the method name registered here is 'kci', the same
        # literal the plain KCI wrapper would use -- confirm this is intended.
        serialized = json.dumps(cond_options, sort_keys=True).encode('utf-8')
        parameters_hash = hashlib.md5(serialized).hexdigest()
        self.check_cache_method_consistent('kci', parameters_hash)
        self.assert_input_data_is_valid()

        self.kci_ui = FastKCI_UInd(**uncond_options)
        self.kci_ci = FastKCI_CInd(**cond_options)

    def __call__(self, X, Y, condition_set=None):
        """Return the p-value for X vs. Y given ``condition_set``.

        Uses the unconditional tester when the formatted conditioning set is
        empty and the conditional tester otherwise; results are cached by the
        key returned from ``get_formatted_XYZ_and_cachekey``.
        """
        Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
        if cache_key in self.pvalue_cache:
            return self.pvalue_cache[cache_key]

        x_columns = self.data[:, Xs]
        y_columns = self.data[:, Ys]
        if len(condition_set) == 0:
            # compute_pvalue's result is indexed; element 0 is used as the p-value.
            p = self.kci_ui.compute_pvalue(x_columns, y_columns)[0]
        else:
            z_columns = self.data[:, condition_set]
            p = self.kci_ci.compute_pvalue(x_columns, y_columns, z_columns)[0]

        self.pvalue_cache[cache_key] = p
        return p
227+
228+
class RCIT(CIT_Base):
    """CIT wrapper around the RIT (unconditional) and RCIT (conditional) testers.

    All keyword arguments are forwarded to ``CIT_Base``. ``approx`` is
    additionally passed to the unconditional tester, and ``approx``,
    ``num_f``, ``num_f2``, ``rcit`` to the conditional tester.
    Computed p-values are stored in ``self.pvalue_cache``.
    """

    def __init__(self, data, **kwargs):
        super().__init__(data, **kwargs)
        uncond_keys = ['approx']
        cond_keys = ['approx', 'num_f', 'num_f2', 'rcit']
        uncond_options = {name: kwargs[name] for name in kwargs if name in uncond_keys}
        cond_options = {name: kwargs[name] for name in kwargs if name in cond_keys}

        # The cache is keyed on the conditional-test options, which are a
        # superset of the unconditional ones.
        # NOTE(review): the method name registered here is 'kci', not 'rcit'
        # -- confirm this is intended.
        serialized = json.dumps(cond_options, sort_keys=True).encode('utf-8')
        parameters_hash = hashlib.md5(serialized).hexdigest()
        self.check_cache_method_consistent('kci', parameters_hash)
        self.assert_input_data_is_valid()

        self.rit = RCIT_UInd(**uncond_options)
        self.rcit = RCIT_CInd(**cond_options)

    def __call__(self, X, Y, condition_set=None):
        """Return the p-value for X vs. Y given ``condition_set``.

        Dispatches to ``self.rit`` when the formatted conditioning set is
        empty and to ``self.rcit`` otherwise; results are cached by the key
        returned from ``get_formatted_XYZ_and_cachekey``.
        """
        Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
        if cache_key in self.pvalue_cache:
            return self.pvalue_cache[cache_key]

        x_columns = self.data[:, Xs]
        y_columns = self.data[:, Ys]
        if len(condition_set) == 0:
            # compute_pvalue's result is indexed; element 0 is used as the p-value.
            p = self.rit.compute_pvalue(x_columns, y_columns)[0]
        else:
            z_columns = self.data[:, condition_set]
            p = self.rcit.compute_pvalue(x_columns, y_columns, z_columns)[0]

        self.pvalue_cache[cache_key] = p
        return p
249+
196250
class Chisq_or_Gsq(CIT_Base):
197251
def __init__(self, data, method_name, **kwargs):
198252
def _unique(column):

setup.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
'matplotlib',
2323
'networkx',
2424
'pydot',
25-
'tqdm'
25+
'tqdm',
26+
'momentchi2'
2627
],
2728
url='https://github.com/py-why/causal-learn',
2829
packages=setuptools.find_packages(),

tests/TestCIT_FastKCI.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
import causallearn.utils.cit as cit
6+
7+
8+
class TestCIT_FastKCI(unittest.TestCase):
    """Smoke test for the 'fastkci' CIT method on synthetic Gaussian data."""

    def test_Gaussian_dist(self):
        """Every p-value returned over the K / J / use_gp grid lies in [0, 1]."""
        np.random.seed(10)
        n_samples = 1200
        # Chain X -> Y -> Z with additive noise; X_prime is an independent draw.
        X = np.random.randn(n_samples, 1)
        X_prime = np.random.randn(n_samples, 1)
        Y = X + 0.5 * np.random.randn(n_samples, 1)
        Z = Y + 0.5 * np.random.randn(n_samples, 1)
        data = np.hstack((X, X_prime, Y, Z))

        pvalue01, pvalue03, pvalue032 = [], [], []
        for K in [3, 10]:
            for J in [8, 16]:
                for use_gp in [True, False]:
                    tester = cit.CIT(data, 'fastkci', K=K, J=J, use_gp=use_gp)
                    # Call order is kept fixed: (0,1), (0,3), then (0,3 | {2}).
                    pvalue01.append(round(tester(0, 1), 4))
                    pvalue03.append(round(tester(0, 3), 4))
                    pvalue032.append(round(tester(0, 3, {2}), 4))

        checks = [
            (pvalue01, "pvalue01 contains invalid values"),
            (pvalue03, "pvalue03 contains invalid values"),
            (pvalue032, "pvalue032 contains invalid values"),
        ]
        for values, message in checks:
            arr = np.array(values)
            in_unit_interval = (0.0 <= arr) & (arr <= 1.0)
            self.assertTrue(np.all(in_unit_interval), message)

tests/TestCIT_RCIT.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
import causallearn.utils.cit as cit
6+
7+
8+
class TestCIT_RCIT(unittest.TestCase):
    """Smoke test for the 'rcit' CIT method on synthetic Gaussian data."""

    def test_Gaussian_dist(self):
        """Every p-value returned over the approx / num_f / num_f2 / rcit grid lies in [0, 1]."""
        np.random.seed(10)
        n_samples = 300
        # Chain X -> Y -> Z with additive noise; X_prime is an independent draw.
        X = np.random.randn(n_samples, 1)
        X_prime = np.random.randn(n_samples, 1)
        Y = X + 0.5 * np.random.randn(n_samples, 1)
        Z = Y + 0.5 * np.random.randn(n_samples, 1)
        data = np.hstack((X, X_prime, Y, Z))

        pvalue01, pvalue03, pvalue032 = [], [], []
        for approx in ["lpd4", "hbe", "gamma", "chi2", "perm"]:
            for num_f in [50, 100]:
                for num_f2 in [5, 10]:
                    for rcit in [True, False]:
                        tester = cit.CIT(data, 'rcit', approx=approx, num_f=num_f,
                                         num_f2=num_f2, rcit=rcit)
                        # Call order is kept fixed: (0,1), (0,3), then (0,3 | {2}).
                        pvalue01.append(round(tester(0, 1), 4))
                        pvalue03.append(round(tester(0, 3), 4))
                        pvalue032.append(round(tester(0, 3, {2}), 4))

        checks = [
            (pvalue01, "pvalue01 contains invalid values"),
            (pvalue03, "pvalue03 contains invalid values"),
            (pvalue032, "pvalue032 contains invalid values"),
        ]
        for values, message in checks:
            arr = np.array(values)
            in_unit_interval = (0.0 <= arr) & (arr <= 1.0)
            self.assertTrue(np.all(in_unit_interval), message)

0 commit comments

Comments
 (0)