@@ -106,10 +106,40 @@ class GRFTree(BaseEstimator):
The minimum weighted fraction of the sum total of weights (of all
the input samples) required to be at a leaf node. Samples have
equal weight when sample_weight is not provided.
- min_var_leaf : ,
-
- min_var_leaf_on_val : ,
-
+ min_var_leaf : None or float in (0, 1], default=None
+ A constraint on the minimum degree of identification of the parameter of interest. This avoids performing
+ splits where the variance of the treatment is small, the correlation of the instrument with the
+ treatment is small, or the variance of the instrument is small. Generically, for any linear moment problem
+ this translates to conditions on the leaf jacobian matrix J(leaf) that are proxies for a well-conditioned
+ matrix, which leads to smaller variance of the local estimate. The proxy of the well-conditioning is
+ different for different criteria, primarily for computational efficiency reasons.
+ - If `criterion='het'`, then the diagonal entries of J(leaf) are constrained to have absolute
+ value at least `min_var_leaf`:
+ for all i in {1, ..., n_outputs}: abs(J(leaf)[i, i]) > `min_var_leaf`
+ In the context of a causal tree, when residual treatment is passed
+ at fit time, this translates to a requirement on Var(T[i]) for every treatment coordinate i.
+ In the context of an IV tree, with residual instruments and residual treatments passed at fit time,
+ this translates to Cov(T[i], Z[i]) > `min_var_leaf` for each coordinate i of the instrument and the
+ treatment.
+ - If `criterion='mse'`, because the criterion stores more information about the leaf jacobian for
+ every candidate split, we impose further constraints on the pairwise determinants of the
+ leaf jacobian, as they come at small extra computational cost, i.e.
+ for all i neq j: sqrt(abs(J(leaf)[i, i] * J(leaf)[j, j] - J(leaf)[i, j] * J(leaf)[j, i])) > `min_var_leaf`
+ In the context of a causal tree, when residual treatment is passed at fit time, this
+ translates to a constraint on the Pearson correlation coefficient of any two coordinates
+ of the treatment within the leaf, i.e.:
+ for all i neq j: sqrt( Var(T[i]) * Var(T[j]) * (1 - rho(T[i], T[j])^2) ) > `min_var_leaf`
+ where rho(X, Y) is the Pearson correlation coefficient of two random variables X, Y. Thus this
+ constraint also enforces that no two treatments are very co-linear within a leaf. This
+ extra constraint primarily has bite in the case of more than two input treatments.
+ min_var_leaf_on_val : bool, default=False
+ Whether the `min_var_leaf` constraint should also be enforced on the validation set of the
+ honest split. If `min_var_leaf=None` then this flag does nothing. Setting this to True should
+ be done with caution, as it partially violates the honesty structure, since parts of the variables
+ other than the X variable (e.g. the variables that go into the jacobian J of the linear model) are
+ used to inform the split structure of the tree. However, this is a benign dependence: for instance,
+ a causal tree or an IV tree does not use the label y. It only uses the treatment T and the instrument
+ Z and their local correlation structures to decide whether a split is feasible.
max_features : int, float or {"auto", "sqrt", "log2"}, default=None
The number of features to consider when looking for the best split:
- If int, then consider `max_features` features at each split.
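To make the `min_var_leaf` identification proxies described above concrete, here is a minimal sketch of the two checks applied to a leaf jacobian, assuming it is available as a dense numpy array; the helper name and signature are illustrative and not part of GRFTree's API.

import numpy as np

def passes_min_var_leaf(J_leaf, min_var_leaf, criterion="het"):
    # Illustrative check mirroring the constraints described above.
    # J_leaf is the (n_outputs, n_outputs) leaf jacobian.
    if min_var_leaf is None:
        return True
    n_outputs = J_leaf.shape[0]
    # Both criteria bound the absolute value of the diagonal entries.
    if not np.all(np.abs(np.diag(J_leaf)) > min_var_leaf):
        return False
    if criterion == "mse":
        # 'mse' additionally bounds the pairwise determinants of the leaf jacobian.
        for i in range(n_outputs):
            for j in range(n_outputs):
                if i != j:
                    det = abs(J_leaf[i, i] * J_leaf[j, j] - J_leaf[i, j] * J_leaf[j, i])
                    if np.sqrt(det) <= min_var_leaf:
                        return False
    return True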
@@ -144,27 +174,44 @@ class GRFTree(BaseEstimator):
left child, and ``N_t_R`` is the number of samples in the right child.
``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
if ``sample_weight`` is passed.
- min_balancedness_tol:,
-
- honest: ,
+ min_balancedness_tol: float in [0, .5], default=.45
+ How imbalanced a split we can tolerate. This enforces that each split leaves at least
+ a (.5 - min_balancedness_tol) fraction of the samples on each side of the split, or of the
+ total sample weight when sample_weight is not None. The default value ensures
+ that at least 5% of the parent node weight falls on each side of the split. Set it to 0.0 for no
+ balancedness constraint and to .5 for perfectly balanced splits. For the formal inference theory
+ to be valid, this has to be any positive constant bounded away from zero.
+ honest: bool, default=True
+ Whether the data should be split in two equally sized samples, such that one half-sample
+ is used to determine the optimal split at each node and the other is used to determine
+ the value of every node.

Attributes
----------
feature_importances_ : ndarray of shape (n_features,)
- The feature importances.
+ The feature importances based on the amount of parameter heterogeneity they create.
The higher, the more important the feature.
- The importance of a feature is computed as the
- (normalized) total reduction of the criterion brought
- by that feature. It is also known as the Gini importance [4]_.
- Warning: impurity-based feature importances can be misleading for
- high cardinality features (many unique values). See
- :func:`sklearn.inspection.permutation_importance` as an alternative.
+ The importance of a feature is computed as the (normalized) total heterogeneity that the feature
+ creates. Each split at which the feature was chosen adds:
+ parent_weight * (left_weight * right_weight) * mean((value_left[k] - value_right[k])**2) / parent_weight**2
+ to the importance of the feature. Each such quantity is also weighted by the depth of the split.
+ By default, splits below `max_depth=4` are not used in this calculation, and each split
+ at depth `depth` is re-weighted by 1 / (1 + `depth`)**2.0. See the method ``feature_importances``,
+ which allows one to change these defaults.
max_features_ : int
The inferred value of max_features.
n_features_ : int
The number of features when ``fit`` is performed.
n_outputs_ : int
The number of outputs when ``fit`` is performed.
+ n_relevant_outputs_ : int
+ The first n_relevant_outputs_ were the ones we cared about when ``fit`` was performed.
+ n_y_ : int
+ The raw label dimension when ``fit`` is performed.
+ n_samples_ : int
+ The number of training samples when ``fit`` is performed.
+ honest_ : int
+ Whether honesty was enabled when ``fit`` was performed.
tree_ : Tree instance
The underlying Tree object. Please refer to
``help(econml.tree._tree.Tree)`` for attributes of Tree object.
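As a rough illustration of the quantities described above, the sketch below (hypothetical helpers, not econml API) computes the heterogeneity that a single split contributes to a feature's importance, including the depth re-weighting, together with the balancedness condition that `min_balancedness_tol` enforces.

import numpy as np

def split_importance_contribution(value_left, value_right, left_weight, right_weight,
                                  depth, depth_decay=2.0):
    # Heterogeneity added by one split, following the formula in the docstring above,
    # down-weighted by 1 / (1 + depth)**depth_decay.
    parent_weight = left_weight + right_weight
    heterogeneity = (parent_weight * (left_weight * right_weight)
                     * np.mean((np.asarray(value_left) - np.asarray(value_right)) ** 2)
                     / parent_weight ** 2)
    return heterogeneity / (1.0 + depth) ** depth_decay

def satisfies_balancedness(left_weight, right_weight, min_balancedness_tol=0.45):
    # Each child must receive at least a (.5 - min_balancedness_tol) fraction of the parent weight.
    parent_weight = left_weight + right_weight
    return min(left_weight, right_weight) / parent_weight >= 0.5 - min_balancedness_tol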
@@ -209,6 +256,7 @@ def get_depth(self):
"""Return the depth of the decision tree.
The depth of a tree is the maximum distance between the root
and any leaf.
+
Returns
-------
self.tree_.max_depth : int
@@ -219,6 +267,7 @@ def get_depth(self):

def get_n_leaves(self):
"""Return the number of leaves of the decision tree.
+
Returns
-------
self.tree_.n_leaves : int
@@ -240,6 +289,28 @@ def init(self,):
return self

def fit(self, X, y, n_y, n_outputs, n_relevant_outputs, sample_weight=None, check_input=True):
+ """
+ Parameters
+ ----------
+ X : (n, d) array
+ The features to split on.
+ y : (n, m) array
+ All the variables required to calculate the criterion function, evaluate splits and
+ estimate local values, i.e. all the values that go into the moment function except X.
+ n_y, n_outputs, n_relevant_outputs : auxiliary info passed to the criterion objects that
+ helps them parse the variable y into its separate variable components.
+ - In the case when `isinstance(criterion, LinearMomentGRFCriterion)`, the first
+ n_y columns of y are the raw outputs, the next n_outputs columns contain the A part
+ of the moment and the next n_outputs * n_outputs columns contain the J part of the moment
+ in row-contiguous format. The first n_relevant_outputs parameters of the linear moment
+ are the ones that we care about. The rest are nuisance parameters.
+ sample_weight : (n,) array, default=None
+ The sample weights.
+ check_input : bool, default=True
+ Whether to check the input parameters for validity. Should be set to False to improve
+ running time in parallel execution, if the variables have already been checked by the
+ forest class that spawned this tree.
+ """

random_state = self.random_state_

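To illustrate the column layout of `y` that the docstring above describes for a linear moment criterion, here is a small sketch with made-up data; the forest classes normally assemble this array internally, so the names below are purely illustrative.

import numpy as np

# Illustrative sizes: n samples, n_y raw labels, n_outputs linear moment parameters.
n, n_y, n_outputs = 100, 1, 2
raw_y = np.random.normal(size=(n, n_y))                # raw outputs
A = np.random.normal(size=(n, n_outputs))              # the "A" part of the moment
J = np.random.normal(size=(n, n_outputs, n_outputs))   # pointwise jacobian contributions

# First n_y columns: raw outputs; next n_outputs: A; last n_outputs**2: J, row-contiguous.
y_full = np.hstack([raw_y, A, J.reshape(n, n_outputs * n_outputs)])
assert y_full.shape == (n, n_y + n_outputs + n_outputs ** 2)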
@@ -421,7 +492,8 @@ def fit(self, X, y, n_y, n_outputs, n_relevant_outputs, sample_weight=None, chec
return self

def _validate_X_predict(self, X, check_input):
- """Validate X whenever one tries to predict, apply, predict_proba"""
+ """Validate X whenever one tries to predict, apply, or call any other
+ prediction-related method."""
if check_input:
X = check_array(X, dtype=DTYPE, accept_sparse=False)

@@ -435,6 +507,10 @@ def _validate_X_predict(self, X, check_input):
return X

def get_train_test_split_inds(self,):
+ """Regenerate the train/test split of the input sample indices that was used for the train
+ and evaluation halves of the honest tree construction. Uses the same random seed
+ that was used at ``fit`` time and re-generates the indices.
+ """
check_is_fitted(self)
random_state = check_random_state(self.random_seed_)
inds = np.arange(self.n_samples_, dtype=np.intp)
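The property the docstring above relies on is that re-seeding with the stored seed reproduces the split deterministically. A minimal sketch of that idea, assuming (as the `honest` parameter suggests) a shuffle followed by an equal halving; the exact scheme inside GRFTree may differ.

import numpy as np
from sklearn.utils import check_random_state

def regenerate_split(random_seed, n_samples):
    # Shuffle the row indices with the stored seed and cut them into two halves.
    inds = np.arange(n_samples, dtype=np.intp)
    check_random_state(random_seed).shuffle(inds)
    return inds[:n_samples // 2], inds[n_samples // 2:]

# Re-generating with the same seed yields the same train/validation indices.
train1, val1 = regenerate_split(123, 10)
train2, val2 = regenerate_split(123, 10)
assert np.array_equal(train1, train2) and np.array_equal(val1, val2)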
@@ -445,23 +521,20 @@ def get_train_test_split_inds(self,):
return inds, inds

def predict(self, X, check_input=True):
- """Predict class or regression value for X.
- For a classification model, the predicted class for each sample in X is
- returned. For a regression model, the predicted value based on X is
- returned.
+ """Return the fitted local parameters for each X, i.e. theta(X).
+
Parameters
----------
- X : {array-like, sparse matrix} of shape (n_samples, n_features)
+ X : {array-like} of shape (n_samples, n_features)
The input samples. Internally, it will be converted to
- ``dtype=np.float32`` and if a sparse matrix is provided
- to a sparse ``csr_matrix``.
+ ``dtype=np.float64``.
check_input : bool, default=True
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
Returns
-------
- y : array-like of shape (n_samples, n_outputs)
- The predicted classes, or the predict values.
+ theta(X) : array-like of shape (n_samples, n_relevant_outputs)
+ The estimated target parameters for each row of X.
"""
check_is_fitted(self)
X = self._validate_X_predict(X, check_input)
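A brief usage sketch of the new `predict` semantics; it assumes `tree` is an already-fitted GRFTree instance (so the fragment is not self-contained), and the attribute names follow the Attributes section above.

import numpy as np

# Hypothetical fitted tree: predict returns the local parameter estimates theta(X).
X_test = np.random.normal(size=(5, tree.n_features_))
theta = tree.predict(X_test)
# One row of estimates per sample, one column per relevant output.
assert theta.shape == (5, tree.n_relevant_outputs_)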