@@ -128,6 +128,9 @@ class RandomForestClassifier(skRandomForestClassifier, DiffprivlibMixin): # pyl
128128
129129 """
130130
131+ _parameter_constraints = DiffprivlibMixin ._copy_parameter_constraints (
132+ skRandomForestClassifier , "n_estimators" , "n_jobs" , "verbose" , "random_state" , "warm_start" )
133+
131134 def __init__ (self , n_estimators = 10 , * , epsilon = 1.0 , bounds = None , classes = None , n_jobs = 1 , verbose = 0 , accountant = None ,
132135 random_state = None , max_depth = 5 , warm_start = False , shuffle = False , ** unused_args ):
133136 super ().__init__ (
@@ -145,7 +148,11 @@ def __init__(self, n_estimators=10, *, epsilon=1.0, bounds=None, classes=None, n
145148 self .shuffle = shuffle
146149 self .accountant = BudgetAccountant .load_default (accountant )
147150
148- self .base_estimator = DecisionTreeClassifier ()
151+ # Todo: Remove when scikit-learn v1.2 is a min requirement
152+ if hasattr (self , "estimator" ):
153+ self .estimator = DecisionTreeClassifier ()
154+ else :
155+ self .base_estimator = DecisionTreeClassifier ()
149156 self .estimator_params = ("max_depth" , "epsilon" , "bounds" , "classes" )
150157
151158 self ._warn_unused_args (unused_args )
@@ -170,6 +177,7 @@ def fit(self, X, y, sample_weight=None):
170177 self : object
171178 Fitted estimator.
172179 """
180+ self ._validate_params ()
173181 self .accountant .check (self .epsilon , 0 )
174182
175183 if sample_weight is not None :
@@ -250,6 +258,7 @@ def fit(self, X, y, sample_weight=None):
250258 # that case. However, for joblib 0.12+ we respect any
251259 # parallel_backend contexts set at a higher level,
252260 # since correctness does not rely on using threads.
261+ # Todo: Remove when scikit-learn v1.1 is a min requirement
253262 try :
254263 trees = Parallel (n_jobs = self .n_jobs , verbose = self .verbose , prefer = "threads" )(
255264 delayed (_parallel_build_trees )(
@@ -332,9 +341,12 @@ class DecisionTreeClassifier(skDecisionTreeClassifier, DiffprivlibMixin):
332341
333342 """
334343
344+ _parameter_constraints = DiffprivlibMixin ._copy_parameter_constraints (
345+ skDecisionTreeClassifier , "max_depth" , "random_state" )
346+
335347 def __init__ (self , max_depth = 5 , * , epsilon = 1 , bounds = None , classes = None , random_state = None , accountant = None ,
336348 ** unused_args ):
337- # TODO : Remove try...except when sklearn v1.0 is min- requirement
349+ # Todo : Remove when scikit-learn v1.0 is a min requirement
338350 try :
339351 super ().__init__ ( # pylint: disable=unexpected-keyword-arg
340352 criterion = None ,
@@ -391,6 +403,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
391403 self : DecisionTreeClassifier
392404 Fitted estimator.
393405 """
406+ self ._validate_params ()
394407 random_state = check_random_state (self .random_state )
395408
396409 self .accountant .check (self .epsilon , 0 )
0 commit comments