5
5
from scipy .stats import chi2 , norm
6
6
7
7
from causallearn .utils .KCI .KCI import KCI_CInd , KCI_UInd
8
+ from causallearn .utils .FastKCI .FastKCI import FastKCI_CInd , FastKCI_UInd
9
+ from causallearn .utils .RCIT .RCIT import RCIT as RCIT_CInd
10
+ from causallearn .utils .RCIT .RCIT import RIT as RCIT_UInd
8
11
from causallearn .utils .PCUtils import Helper
9
12
10
13
CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5
13
16
mv_fisherz = "mv_fisherz"
14
17
mc_fisherz = "mc_fisherz"
15
18
kci = "kci"
19
+ rcit = "rcit"
20
+ fastkci = "fastkci"
16
21
chisq = "chisq"
17
22
gsq = "gsq"
18
23
d_separation = "d_separation"
@@ -23,15 +28,19 @@ def CIT(data, method='fisherz', **kwargs):
23
28
Parameters
24
29
----------
25
30
data: numpy.ndarray of shape (n_samples, n_features)
26
- method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "chisq", "gsq"]
27
- kwargs: placeholder for future arguments, or for KCI specific arguments now
31
+ method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "rcit", "fastkci", " chisq", "gsq"]
32
+ kwargs: placeholder for future arguments, or for KCI, FastKCI or RCIT specific arguments now
28
33
TODO: utimately kwargs should be replaced by explicit named parameters.
29
34
check https://github.com/cmu-phil/causal-learn/pull/62#discussion_r927239028
30
35
'''
31
36
if method == fisherz :
32
37
return FisherZ (data , ** kwargs )
33
38
elif method == kci :
34
39
return KCI (data , ** kwargs )
40
+ elif method == fastkci :
41
+ return FastKCI (data , ** kwargs )
42
+ elif method == rcit :
43
+ return RCIT (data , ** kwargs )
35
44
elif method in [chisq , gsq ]:
36
45
return Chisq_or_Gsq (data , method_name = method , ** kwargs )
37
46
elif method == mv_fisherz :
@@ -43,6 +52,7 @@ def CIT(data, method='fisherz', **kwargs):
43
52
else :
44
53
raise ValueError ("Unknown method: {}" .format (method ))
45
54
55
+
46
56
class CIT_Base (object ):
47
57
# Base class for CIT, contains basic operations for input check and caching, etc.
48
58
def __init__ (self , data , cache_path = None , ** kwargs ):
@@ -193,6 +203,50 @@ def __call__(self, X, Y, condition_set=None):
193
203
self .pvalue_cache [cache_key ] = p
194
204
return p
195
205
206
+ class FastKCI (CIT_Base ):
207
+ def __init__ (self , data , ** kwargs ):
208
+ super ().__init__ (data , ** kwargs )
209
+ kci_ui_kwargs = {k : v for k , v in kwargs .items () if k in
210
+ ['K' , 'J' , 'alpha' ]}
211
+ kci_ci_kwargs = {k : v for k , v in kwargs .items () if k in
212
+ ['K' , 'J' , 'alpha' , 'use_gp' ]}
213
+ self .check_cache_method_consistent (
214
+ 'kci' , hashlib .md5 (json .dumps (kci_ci_kwargs , sort_keys = True ).encode ('utf-8' )).hexdigest ())
215
+ self .assert_input_data_is_valid ()
216
+ self .kci_ui = FastKCI_UInd (** kci_ui_kwargs )
217
+ self .kci_ci = FastKCI_CInd (** kci_ci_kwargs )
218
+
219
+ def __call__ (self , X , Y , condition_set = None ):
220
+ # Kernel-based conditional independence test.
221
+ Xs , Ys , condition_set , cache_key = self .get_formatted_XYZ_and_cachekey (X , Y , condition_set )
222
+ if cache_key in self .pvalue_cache : return self .pvalue_cache [cache_key ]
223
+ p = self .kci_ui .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ])[0 ] if len (condition_set ) == 0 else \
224
+ self .kci_ci .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ], self .data [:, condition_set ])[0 ]
225
+ self .pvalue_cache [cache_key ] = p
226
+ return p
227
+
228
+ class RCIT (CIT_Base ):
229
+ def __init__ (self , data , ** kwargs ):
230
+ super ().__init__ (data , ** kwargs )
231
+ rit_kwargs = {k : v for k , v in kwargs .items () if k in
232
+ ['approx' ]}
233
+ rcit_kwargs = {k : v for k , v in kwargs .items () if k in
234
+ ['approx' , 'num_f' , 'num_f2' , 'rcit' ]}
235
+ self .check_cache_method_consistent (
236
+ 'kci' , hashlib .md5 (json .dumps (rcit_kwargs , sort_keys = True ).encode ('utf-8' )).hexdigest ())
237
+ self .assert_input_data_is_valid ()
238
+ self .rit = RCIT_UInd (** rit_kwargs )
239
+ self .rcit = RCIT_CInd (** rcit_kwargs )
240
+
241
+ def __call__ (self , X , Y , condition_set = None ):
242
+ # Kernel-based conditional independence test.
243
+ Xs , Ys , condition_set , cache_key = self .get_formatted_XYZ_and_cachekey (X , Y , condition_set )
244
+ if cache_key in self .pvalue_cache : return self .pvalue_cache [cache_key ]
245
+ p = self .rit .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ])[0 ] if len (condition_set ) == 0 else \
246
+ self .rcit .compute_pvalue (self .data [:, Xs ], self .data [:, Ys ], self .data [:, condition_set ])[0 ]
247
+ self .pvalue_cache [cache_key ] = p
248
+ return p
249
+
196
250
class Chisq_or_Gsq (CIT_Base ):
197
251
def __init__ (self , data , method_name , ** kwargs ):
198
252
def _unique (column ):
0 commit comments