16
16
LOAD_DIABETES = "load_diabetes"
17
17
18
18
19
- ADDITIONAL_PARAM_DESCRIPTIONS = """
20
-
19
+ ADDITIONAL_PARAM_DESCRIPTIONS = {
20
+ "input_cols" : """
21
21
input_cols: Optional[Union[str, List[str]]]
22
22
A string or list of strings representing column names that contain features.
23
23
If this parameter is not specified, all columns in the input DataFrame except
24
24
the columns specified by label_cols, sample_weight_col, and passthrough_cols
25
- parameters are considered input columns.
26
-
25
+ parameters are considered input columns. Input columns can also be set after
26
+ initialization with the `set_input_cols` method.
27
+ """ ,
28
+ "label_cols" : """
27
29
label_cols: Optional[Union[str, List[str]]]
28
30
A string or list of strings representing column names that contain labels.
29
- This is a required param for estimators, as there is no way to infer these
30
- columns. If this parameter is not specified, then object is fitted without
31
- labels (like a transformer).
32
-
31
+ Label columns must be specified with this parameter during initialization
32
+ or with the `set_label_cols` method before fitting.
33
+ """ ,
34
+ "output_cols" : """
33
35
output_cols: Optional[Union[str, List[str]]]
34
36
A string or list of strings representing column names that will store the
35
37
output of predict and transform operations. The length of output_cols must
36
- match the expected number of output columns from the specific estimator or
38
+ match the expected number of output columns from the specific predictor or
37
39
transformer class used.
38
- If this parameter is not specified, output column names are derived by
39
- adding an OUTPUT_ prefix to the label column names. These inferred output
40
- column names work for estimator's predict() method, but output_cols must
41
- be set explicitly for transformers.
42
-
40
+ If you omit this parameter, output column names are derived by adding an
41
+ OUTPUT_ prefix to the label column names for supervised estimators, or
42
+ OUTPUT_<IDX>for unsupervised estimators. These inferred output column names
43
+ work for predictors, but output_cols must be set explicitly for transformers.
44
+ In general, explicitly specifying output column names is clearer, especially
45
+ if you don’t specify the input column names.
46
+ To transform in place, pass the same names for input_cols and output_cols.
47
+ be set explicitly for transformers. Output columns can also be set after
48
+ initialization with the `set_output_cols` method.
49
+ """ ,
50
+ "sample_weight_col" : """
43
51
sample_weight_col: Optional[str]
44
52
A string representing the column name containing the sample weights.
45
- This argument is only required when working with weighted datasets.
46
-
53
+ This argument is only required when working with weighted datasets. Sample
54
+ weight column can also be set after initialization with the
55
+ `set_sample_weight_col` method.
56
+ """ ,
57
+ "passthrough_cols" : """
47
58
passthrough_cols: Optional[Union[str, List[str]]]
48
59
A string or a list of strings indicating column names to be excluded from any
49
60
operations (such as train, transform, or inference). These specified column(s)
50
61
will remain untouched throughout the process. This option is helpful in scenarios
51
62
requiring automatic input_cols inference, but need to avoid using specific
52
- columns, like index columns, during training or inference.
53
-
63
+ columns, like index columns, during training or inference. Passthrough columns
64
+ can also be set after initialization with the `set_passthrough_cols` method.
65
+ """ ,
66
+ "drop_input_cols" : """
54
67
drop_input_cols: Optional[bool], default=False
55
68
If set, the response of predict(), transform() methods will not contain input columns.
56
- """
69
+ """ ,
70
+ }
57
71
58
72
ADDITIONAL_METHOD_DESCRIPTION = """
59
73
Raises:
@@ -448,7 +462,6 @@ class WrapperGeneratorBase:
448
462
is contained in.
449
463
estimator_imports GENERATED Imports needed for the estimator / fit()
450
464
call.
451
- wrapper_provider_class GENERATED Class name of wrapper provider.
452
465
------------------------------------------------------------------------------------
453
466
SIGNATURES AND ARGUMENTS
454
467
------------------------------------------------------------------------------------
@@ -545,7 +558,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
545
558
self .estimator_imports = ""
546
559
self .estimator_imports_list : List [str ] = []
547
560
self .score_sproc_imports : List [str ] = []
548
- self .wrapper_provider_class = ""
549
561
self .additional_import_statements = ""
550
562
551
563
# Test strings
@@ -630,10 +642,11 @@ def _populate_class_doc_fields(self) -> None:
630
642
class_docstring = inspect .getdoc (self .class_object [1 ]) or ""
631
643
class_docstring = class_docstring .rsplit ("Attributes\n " , 1 )[0 ]
632
644
645
+ parameters_heading = "Parameters\n ----------\n "
633
646
class_description , param_description = (
634
- class_docstring .rsplit ("Parameters \n " , 1 )
635
- if len (class_docstring .rsplit ("Parameters \n " , 1 )) == 2
636
- else (class_docstring , "---------- \n " )
647
+ class_docstring .rsplit (parameters_heading , 1 )
648
+ if len (class_docstring .rsplit (parameters_heading , 1 )) == 2
649
+ else (class_docstring , "" )
637
650
)
638
651
639
652
# Extract the first sentence of the class description
@@ -645,9 +658,11 @@ def _populate_class_doc_fields(self) -> None:
645
658
f"]\n ({ self .get_doc_link ()} )"
646
659
)
647
660
648
- # Add SnowML specific param descriptions.
649
- param_description = "Parameters\n " + param_description .strip ()
650
- param_description += ADDITIONAL_PARAM_DESCRIPTIONS
661
+ # Add SnowML specific param descriptions before third party parameters.
662
+ snowml_parameters = ""
663
+ for d in ADDITIONAL_PARAM_DESCRIPTIONS .values ():
664
+ snowml_parameters += d
665
+ param_description = f"{ parameters_heading } { snowml_parameters } \n { param_description .strip ()} "
651
666
652
667
class_docstring = f"{ class_description } \n \n { param_description } "
653
668
class_docstring = textwrap .indent (class_docstring , " " ).strip ()
@@ -718,12 +733,23 @@ def _populate_function_names_and_signatures(self) -> None:
718
733
for member in inspect .getmembers (self .class_object [1 ]):
719
734
if member [0 ] == "__init__" :
720
735
self .original_init_signature = inspect .signature (member [1 ])
736
+ elif member [0 ] == "fit" :
737
+ original_fit_signature = inspect .signature (member [1 ])
738
+ if original_fit_signature .parameters ["y" ].default is None :
739
+ # The fit does not require labels, so our label_cols argument is optional.
740
+ ADDITIONAL_PARAM_DESCRIPTIONS [
741
+ "label_cols"
742
+ ] = """
743
+ label_cols: Optional[Union[str, List[str]]]
744
+ This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
745
+ """
721
746
722
747
signature_lines = []
723
748
sklearn_init_lines = []
724
749
init_member_args = []
725
750
has_kwargs = False
726
751
sklearn_init_args_dict_list = []
752
+
727
753
for k , v in self .original_init_signature .parameters .items ():
728
754
if k == "self" :
729
755
signature_lines .append ("self" )
@@ -855,9 +881,9 @@ def generate(self) -> "WrapperGeneratorBase":
855
881
self ._populate_flags ()
856
882
self ._populate_class_names ()
857
883
self ._populate_import_statements ()
858
- self ._populate_class_doc_fields ()
859
884
self ._populate_function_doc_fields ()
860
885
self ._populate_function_names_and_signatures ()
886
+ self ._populate_class_doc_fields ()
861
887
self ._populate_file_paths ()
862
888
self ._populate_integ_test_fields ()
863
889
return self
@@ -876,13 +902,8 @@ def generate(self) -> "SklearnWrapperGenerator":
876
902
# Populate all the common values
877
903
super ().generate ()
878
904
879
- is_model_selector = WrapperGeneratorFactory ._is_class_of_type (self .class_object [1 ], "BaseSearchCV" )
880
-
881
905
# Populate SKLearn specific values
882
906
self .estimator_imports_list .extend (["import sklearn" , f"import { self .root_module_name } " ])
883
- self .wrapper_provider_class = (
884
- "SklearnModelSelectionWrapperProvider" if is_model_selector else "SklearnWrapperProvider"
885
- )
886
907
self .score_sproc_imports = ["sklearn" ]
887
908
888
909
if "random_state" in self .original_init_signature .parameters .keys ():
@@ -982,6 +1003,9 @@ def generate(self) -> "SklearnWrapperGenerator":
982
1003
if self ._is_hist_gradient_boosting_regressor :
983
1004
self .test_estimator_input_args_list .extend (["min_samples_leaf=1" , "max_leaf_nodes=100" ])
984
1005
1006
+ self .deps = (
1007
+ "f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'"
1008
+ )
985
1009
self .supported_export_method = "to_sklearn"
986
1010
self .unsupported_export_methods = ["to_xgboost" , "to_lightgbm" ]
987
1011
self ._construct_string_from_lists ()
@@ -1010,10 +1034,10 @@ def generate(self) -> "XGBoostWrapperGenerator":
1010
1034
["random_state=0" , "subsample=1.0" , "colsample_bynode=1.0" , "n_jobs=1" ]
1011
1035
)
1012
1036
self .score_sproc_imports = ["xgboost" ]
1013
- self .wrapper_provider_class = "XGBoostWrapperProvider"
1014
1037
# TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
1015
1038
self .supported_export_method = "to_xgboost"
1016
1039
self .unsupported_export_methods = ["to_sklearn" , "to_lightgbm" ]
1040
+ self .deps = "f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'"
1017
1041
self ._construct_string_from_lists ()
1018
1042
return self
1019
1043
@@ -1039,8 +1063,8 @@ def generate(self) -> "LightGBMWrapperGenerator":
1039
1063
self .estimator_imports_list .append ("import lightgbm" )
1040
1064
self .test_estimator_input_args_list .extend (["random_state=0" , "n_jobs=1" ])
1041
1065
self .score_sproc_imports = ["lightgbm" ]
1042
- self .wrapper_provider_class = "LightGBMWrapperProvider"
1043
1066
1067
+ self .deps = "f'numpy=={np.__version__}', f'lightgbm=={lightgbm.__version__}', f'cloudpickle=={cp.__version__}'"
1044
1068
self .supported_export_method = "to_lightgbm"
1045
1069
self .unsupported_export_methods = ["to_sklearn" , "to_xgboost" ]
1046
1070
self ._construct_string_from_lists ()
0 commit comments