
Commit 9e63df8

more comments in _base_grftree.py
1 parent 18dbeca

File tree

1 file changed: econml/grf/_base_grftree.py (+97 −24 lines)
@@ -106,10 +106,40 @@ class GRFTree(BaseEstimator):
      The minimum weighted fraction of the sum total of weights (of all
      the input samples) required to be at a leaf node. Samples have
      equal weight when sample_weight is not provided.
- min_var_leaf : ,
-
- min_var_leaf_on_val : ,
-
+ min_var_leaf : None or float in (0, 1], default=None
+     A constraint on the minimum degree of identification of the parameter of interest. This avoids performing
+     splits where the variance of the treatment is small, the correlation of the instrument with the
+     treatment is small, or the variance of the instrument is small. Generically, for any linear moment problem
+     this translates to conditions on the leaf jacobian matrix J(leaf) that are proxies for a well-conditioned
+     matrix, which leads to smaller variance of the local estimate. The proxy for well-conditioning is
+     different for each criterion, primarily for computational efficiency reasons.
+     - If `criterion='het'`, then the diagonal entries of J(leaf) are constrained to have absolute
+       value at least `min_var_leaf`:
+           for all i in {1, ..., n_outputs}: abs(J(leaf)[i, i]) > `min_var_leaf`
+       In the context of a causal tree, when residual treatment is passed at fit time, this translates
+       to a requirement on Var(T[i]) for every treatment coordinate i.
+       In the context of an IV tree, with residual instruments and residual treatments passed at fit time,
+       this translates to Cov(T[i], Z[i]) > `min_var_leaf` for each coordinate i of the instrument and the
+       treatment.
+     - If `criterion='mse'`, because the criterion stores more information about the leaf jacobian for
+       every candidate split, we impose further constraints on the pairwise determinants of the
+       leaf jacobian, as they come at small extra computational cost, i.e.
+           for all i neq j: sqrt(abs(J(leaf)[i, i] * J(leaf)[j, j] - J(leaf)[i, j] * J(leaf)[j, i])) > `min_var_leaf`
+       In the context of a causal tree, when residual treatment is passed at fit time, this
+       translates to a constraint on the Pearson correlation coefficient of any two coordinates
+       of the treatment within the leaf, i.e.:
+           for all i neq j: sqrt(Var(T[i]) * Var(T[j]) * (1 - rho(T[i], T[j])^2)) > `min_var_leaf`
+       where rho(X, Y) is the Pearson correlation coefficient of two random variables X, Y. Thus this
+       constraint also enforces that no two coordinates of the treatment are very co-linear within a leaf. This
+       extra constraint primarily has bite in the case of more than two input treatments.
+ min_var_leaf_on_val : bool, default=False
+     Whether the `min_var_leaf` constraint should also be enforced on the validation set of the
+     honest split. If `min_var_leaf=None` then this flag does nothing. Setting this to True should
+     be done with caution, as it partially violates the honesty structure, since parts of the variables
+     other than the X variable (e.g. the variables that go into the jacobian J of the linear model) are
+     used to inform the split structure of the tree. However, this is a benign dependence; for instance,
+     in a causal tree or an IV tree it does not use the label y. It only uses the treatment T and the instrument
+     Z and their local correlation structures to decide whether a split is feasible.
  max_features : int, float or {"auto", "sqrt", "log2"}, default=None
      The number of features to consider when looking for the best split:
      - If int, then consider `max_features` features at each split.
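
For intuition, the two identification proxies described in the new `min_var_leaf` text can be spelled out in plain numpy. This is an illustrative sketch of the documented formulas only, not the library's internal implementation, and the helper name and signature are invented for the example.

import numpy as np

def min_var_leaf_ok(J_leaf, min_var_leaf, criterion='het'):
    # Illustrative check of the documented proxies for a candidate leaf
    # jacobian J(leaf); the library enforces these internally, not with this code.
    J = np.asarray(J_leaf, dtype=float)
    diag = np.abs(np.diag(J))
    if criterion == 'het':
        # abs(J(leaf)[i, i]) > min_var_leaf for every output i
        return bool(np.all(diag > min_var_leaf))
    # criterion == 'mse': additionally bound the pairwise 2x2 determinants,
    # sqrt(abs(J[i, i] * J[j, j] - J[i, j] * J[j, i])) > min_var_leaf for i != j
    if not np.all(diag > min_var_leaf):
        return False
    n = J.shape[0]
    for i in range(n):
        for j in range(n):
            if i != j:
                det = J[i, i] * J[j, j] - J[i, j] * J[j, i]
                if np.sqrt(np.abs(det)) <= min_var_leaf:
                    return False
    return True

# Example: a nearly co-linear two-treatment leaf passes the 'het' proxy
# but fails the 'mse' proxy.
J = np.array([[1.0, 0.99],
              [0.99, 1.0]])
print(min_var_leaf_ok(J, 0.1, criterion='het'))  # True: diagonals are large
print(min_var_leaf_ok(J, 0.2, criterion='mse'))  # False: sqrt of pairwise determinant is about 0.14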
@@ -144,27 +174,44 @@ class GRFTree(BaseEstimator):
      left child, and ``N_t_R`` is the number of samples in the right child.
      ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
      if ``sample_weight`` is passed.
- min_balancedness_tol:,
-
- honest: ,
+ min_balancedness_tol: float in [0, .5], default=.45
+     How imbalanced a split we can tolerate. This enforces that each split leaves at least
+     a (.5 - min_balancedness_tol) fraction of samples on each side of the split, or that fraction
+     of the total weight of samples when sample_weight is not None. The default value ensures
+     that at least 5% of the parent node weight falls in each side of the split. Set it to 0.0 for no
+     balancedness constraint and to .5 for perfectly balanced splits. For the formal inference theory
+     to be valid, this has to be any positive constant bounded away from zero.
+ honest: bool, default=True
+     Whether the data should be split in two equally sized samples, such that one half-sample
+     is used to determine the optimal split at each node and the other is used to determine
+     the value of every node.

  Attributes
  ----------
  feature_importances_ : ndarray of shape (n_features,)
-     The feature importances.
+     The feature importances based on the amount of parameter heterogeneity they create.
      The higher, the more important the feature.
-     The importance of a feature is computed as the
-     (normalized) total reduction of the criterion brought
-     by that feature. It is also known as the Gini importance [4]_.
-     Warning: impurity-based feature importances can be misleading for
-     high cardinality features (many unique values). See
-     :func:`sklearn.inspection.permutation_importance` as an alternative.
+     The importance of a feature is computed as the (normalized) total heterogeneity that the feature
+     creates. Each split at which the feature was chosen adds:
+         parent_weight * (left_weight * right_weight) * mean((value_left[k] - value_right[k])**2) / parent_weight**2
+     to the importance of the feature. Each such quantity is also weighted by the depth of the split.
+     By default splits below `max_depth=4` are not used in this calculation, and each split
+     at depth `depth` is re-weighted by 1 / (1 + `depth`)**2.0. See the method ``feature_importances``
+     for a way to change these defaults.
  max_features_ : int
      The inferred value of max_features.
  n_features_ : int
      The number of features when ``fit`` is performed.
  n_outputs_ : int
      The number of outputs when ``fit`` is performed.
+ n_relevant_outputs_ : int
+     The first n_relevant_outputs_ are the ones we cared about when ``fit`` was performed.
+ n_y_ : int
+     The raw label dimension when ``fit`` is performed.
+ n_samples_ : int
+     The number of training samples when ``fit`` is performed.
+ honest_ : int
+     Whether honesty was enabled when ``fit`` was performed.
  tree_ : Tree instance
      The underlying Tree object. Please refer to
      ``help(econml.tree._tree.Tree)`` for attributes of Tree object.
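
The per-split contribution described in the new `feature_importances_` text can be mirrored in a few lines of numpy. This is a hedged sketch of the documented formula, not the code path the library actually uses; the helper name and the way splits are represented are invented for illustration.

import numpy as np

def split_contribution(parent_weight, left_weight, right_weight,
                       value_left, value_right, depth, depth_decay=2.0):
    # One split's contribution to its feature's importance, per the docstring:
    # parent_weight * (left_weight * right_weight)
    #     * mean((value_left[k] - value_right[k])**2) / parent_weight**2,
    # re-weighted by 1 / (1 + depth)**depth_decay.
    hetero = np.mean((np.asarray(value_left) - np.asarray(value_right)) ** 2)
    raw = parent_weight * (left_weight * right_weight) * hetero / parent_weight ** 2
    return raw / (1.0 + depth) ** depth_decay

# Toy accumulation over two splits on features 0 and 2, then normalization.
importances = np.zeros(3)
importances[0] += split_contribution(100., 50., 50., [1.0, 0.2], [0.4, 0.1], depth=0)
importances[2] += split_contribution(50., 30., 20., [0.9, 0.3], [0.8, 0.25], depth=1)
importances /= importances.sum()   # normalized, as in feature_importances_
print(importances)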
@@ -209,6 +256,7 @@ def get_depth(self):
      """Return the depth of the decision tree.
      The depth of a tree is the maximum distance between the root
      and any leaf.
+
      Returns
      -------
      self.tree_.max_depth : int
@@ -219,6 +267,7 @@ def get_depth(self):

  def get_n_leaves(self):
      """Return the number of leaves of the decision tree.
+
      Returns
      -------
      self.tree_.n_leaves : int
@@ -240,6 +289,28 @@ def init(self,):
      return self

  def fit(self, X, y, n_y, n_outputs, n_relevant_outputs, sample_weight=None, check_input=True):
+     """
+     Parameters
+     ----------
+     X : (n, d) array
+         The features to split on.
+     y : (n, m) array
+         All the variables required to calculate the criterion function, evaluate splits and
+         estimate local values, i.e. all the values that go into the moment function except X.
+     n_y, n_outputs, n_relevant_outputs : auxiliary info passed to the criterion objects that
+         help the object parse the variable y into its separate variable components.
+         - In the case when `isinstance(criterion, LinearMomentGRFCriterion)`, the first
+           n_y columns of y are the raw outputs, the next n_outputs columns contain the A part
+           of the moment and the next n_outputs * n_outputs columns contain the J part of the moment
+           in row-contiguous format. The first n_relevant_outputs parameters of the linear moment
+           are the ones that we care about. The rest are nuisance parameters.
+     sample_weight : (n,) array, default=None
+         The sample weights.
+     check_input : bool, default=True
+         Whether to check the input parameters for validity. Should be set to False to improve
+         running time in parallel execution, if the variables have already been checked by the
+         forest class that spawned this tree.
+     """

      random_state = self.random_state_
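
To make the column layout described in this `fit` docstring concrete, here is a small hedged numpy sketch of how `y` would be assembled when the criterion is a `LinearMomentGRFCriterion`. The array sizes are invented for the example, and in practice this packing is performed by the GRF forest classes rather than by the user.

import numpy as np

n = 100                 # number of samples (hypothetical)
n_y, n_outputs = 1, 2   # raw label dimension and moment dimension (hypothetical)

rng = np.random.RandomState(0)
raw_y = rng.normal(size=(n, n_y))               # raw outputs
A = rng.normal(size=(n, n_outputs))             # the A part of the moment
J = rng.normal(size=(n, n_outputs, n_outputs))  # per-sample jacobian blocks

# Documented layout: [ raw outputs | A | J flattened row-contiguously ]
y_packed = np.hstack([raw_y, A, J.reshape(n, n_outputs * n_outputs)])
assert y_packed.shape == (n, n_y + n_outputs + n_outputs ** 2)  # (100, 7)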

@@ -421,7 +492,8 @@ def fit(self, X, y, n_y, n_outputs, n_relevant_outputs, sample_weight=None, chec
      return self

  def _validate_X_predict(self, X, check_input):
-     """Validate X whenever one tries to predict, apply, predict_proba"""
+     """Validate X whenever one tries to predict, apply, or call any other
+     prediction-related method."""
      if check_input:
          X = check_array(X, dtype=DTYPE, accept_sparse=False)

@@ -435,6 +507,10 @@ def _validate_X_predict(self, X, check_input):
      return X

  def get_train_test_split_inds(self,):
+     """Regenerate the train/test split of input sample indices that was used for the training
+     and the evaluation half of the honest tree construction. Uses the same random seed
+     that was used at ``fit`` time and re-generates the indices.
+     """
      check_is_fitted(self)
      random_state = check_random_state(self.random_seed_)
      inds = np.arange(self.n_samples_, dtype=np.intp)
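
The idea behind this method, re-seeding a fresh random state with the seed stored at fit time so that exactly the same honest split comes back, can be demonstrated in isolation. A hedged sketch follows; the shuffle and the half/half assignment here are assumptions for illustration and may differ in detail from the tree's actual split logic.

import numpy as np
from sklearn.utils import check_random_state

def regenerate_split(random_seed, n_samples, honest=True):
    # Re-seeding with the stored seed reproduces the same permutation,
    # hence the same train/validation indices (illustrative logic only).
    random_state = check_random_state(random_seed)
    inds = np.arange(n_samples, dtype=np.intp)
    if not honest:
        return inds, inds
    random_state.shuffle(inds)
    return inds[:n_samples // 2], inds[n_samples // 2:]

# Two calls with the same seed always return identical splits.
first = regenerate_split(123, 10)
second = regenerate_split(123, 10)
assert all(np.array_equal(a, b) for a, b in zip(first, second))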
@@ -445,23 +521,20 @@ def get_train_test_split_inds(self,):
          return inds, inds

  def predict(self, X, check_input=True):
-     """Predict class or regression value for X.
-     For a classification model, the predicted class for each sample in X is
-     returned. For a regression model, the predicted value based on X is
-     returned.
+     """Return the fitted local parameters for each X, i.e. theta(X).
+
      Parameters
      ----------
-     X : {array-like, sparse matrix} of shape (n_samples, n_features)
+     X : array-like of shape (n_samples, n_features)
          The input samples. Internally, it will be converted to
-         ``dtype=np.float32`` and if a sparse matrix is provided
-         to a sparse ``csr_matrix``.
+         ``dtype=np.float64``.
      check_input : bool, default=True
          Allow to bypass several input checking.
          Don't use this parameter unless you know what you do.
      Returns
      -------
-     y : array-like of shape (n_samples, n_outputs)
-         The predicted classes, or the predict values.
+     theta(X) : array-like of shape (n_samples, n_relevant_outputs)
+         The estimated target parameters for each row of X.
      """
      check_is_fitted(self)
      X = self._validate_X_predict(X, check_input)
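
Putting the new docstrings together, here is a hedged end-to-end sketch of the prediction contract, theta(X) with one column per relevant output. It goes through a public forest wrapper rather than GRFTree directly, since the trees are normally built internally; the `econml.grf.CausalForest` name, its constructor arguments, and its `fit(X, T, y)` signature are assumptions to be checked against the installed econml version.

import numpy as np
from econml.grf import CausalForest   # assumed public wrapper around GRFTree

rng = np.random.RandomState(0)
n, d = 500, 4
X = rng.normal(size=(n, d))
T = rng.binomial(1, .5, size=(n, 1))               # single binary treatment
y = (1 + X[:, [0]]) * T + rng.normal(size=(n, 1))  # effect heterogeneous in X[:, 0]

est = CausalForest(n_estimators=100, criterion='het', honest=True, random_state=0)
est.fit(X, T, y)

theta = est.predict(X)   # local parameter estimates theta(X)
print(theta.shape)       # expected (n_samples, n_relevant_outputs), here (500, 1)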
