Skip to content

Commit 77493ee

Browse files
committed
Add curated schema for complement_nb
1 parent 20fea4e commit 77493ee

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed

lale/lib/sklearn/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
* lale.lib.sklearn. `BernoulliNB`_
2828
* lale.lib.sklearn. `CalibratedClassifierCV`_
2929
* lale.lib.sklearn. `CCA`_
30+
* lale.lib.sklearn. `ComplementNB`_
3031
* lale.lib.sklearn. `DecisionTreeClassifier`_
3132
* lale.lib.sklearn. `DummyClassifier`_
3233
* lale.lib.sklearn. `ExtraTreesClassifier`_
@@ -102,6 +103,7 @@
102103
.. _`Birch`: lale.lib.sklearn.birch.html
103104
.. _`CalibratedClassifierCV`: lale.lib.sklearn.calibrated_classifier_cv.html
104105
.. _`ColumnTransformer`: lale.lib.sklearn.column_transformer.html
106+
.. _`ComplementNB`: lale.lib.sklearn.complement_nb.html
105107
.. _`DecisionTreeClassifier`: lale.lib.sklearn.decision_tree_classifier.html
106108
.. _`DecisionTreeRegressor`: lale.lib.sklearn.decision_tree_regressor.html
107109
.. _`DummyClassifier`: lale.lib.sklearn.dummy_classifier.html
@@ -165,6 +167,7 @@
165167
from .calibrated_classifier_cv import CalibratedClassifierCV
166168
from .cca import CCA
167169
from .column_transformer import ColumnTransformer
170+
from .complement_nb import ComplementNB
168171
from .decision_tree_classifier import DecisionTreeClassifier
169172
from .decision_tree_regressor import DecisionTreeRegressor
170173
from .dummy_classifier import DummyClassifier

lale/lib/sklearn/complement_nb.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from sklearn.naive_bayes import ComplementNB as Op
2+
3+
from lale.docstrings import set_docstrings
4+
from lale.operators import make_operator
5+
6+
_hyperparams_schema = {
7+
"$schema": "http://json-schema.org/draft-04/schema#",
8+
"description": "The Complement Naive Bayes classifier described in Rennie et al. (2003).",
9+
"allOf": [
10+
{
11+
"type": "object",
12+
"required": ["alpha", "fit_prior", "class_prior", "norm"],
13+
"relevantToOptimizer": [],
14+
"additionalProperties": False,
15+
"properties": {
16+
"alpha": {
17+
"type": "number",
18+
"default": 1.0,
19+
"description": "Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).",
20+
},
21+
"fit_prior": {
22+
"type": "boolean",
23+
"default": True,
24+
"description": "Only used in edge case with a single class in the training set.",
25+
},
26+
"class_prior": {
27+
"anyOf": [
28+
{"type": "array", "items": {"type": "number"}},
29+
{"enum": [None]},
30+
],
31+
"default": None,
32+
"description": "Prior probabilities of the classes. Not used.",
33+
},
34+
"norm": {
35+
"type": "boolean",
36+
"default": False,
37+
"description": "Whether or not a second normalization of the weights is performed",
38+
},
39+
},
40+
},
41+
],
42+
}
43+
_input_fit_schema = {
44+
"$schema": "http://json-schema.org/draft-04/schema#",
45+
"description": "Fit Naive Bayes classifier according to X, y",
46+
"type": "object",
47+
"required": ["X", "y"],
48+
"properties": {
49+
"X": {
50+
"type": "array",
51+
"items": {"type": "array", "items": {"type": "number"}},
52+
"description": "Training vectors, where n_samples is the number of samples and n_features is the number of features.",
53+
},
54+
"y": {
55+
"type": "array",
56+
"items": {"type": "number"},
57+
"description": "Target values.",
58+
},
59+
"sample_weight": {
60+
"anyOf": [{"type": "array", "items": {"type": "number"}}, {"enum": [None]}],
61+
"default": None,
62+
"description": "Weights applied to individual samples (1",
63+
},
64+
},
65+
}
66+
_input_predict_schema = {
67+
"$schema": "http://json-schema.org/draft-04/schema#",
68+
"description": "Perform classification on an array of test vectors X.",
69+
"type": "object",
70+
"required": ["X"],
71+
"properties": {
72+
"X": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}}
73+
},
74+
}
75+
_output_predict_schema = {
76+
"$schema": "http://json-schema.org/draft-04/schema#",
77+
"description": "Predicted target values for X",
78+
"type": "array",
79+
"items": {"type": "number"},
80+
}
81+
_input_predict_proba_schema = {
82+
"$schema": "http://json-schema.org/draft-04/schema#",
83+
"description": "Return probability estimates for the test vector X.",
84+
"type": "object",
85+
"required": ["X"],
86+
"properties": {
87+
"X": {"type": "array", "items": {"type": "array", "items": {"type": "number"}}}
88+
},
89+
}
90+
_output_predict_proba_schema = {
91+
"$schema": "http://json-schema.org/draft-04/schema#",
92+
"description": "Returns the probability of the samples for each class in the model",
93+
"type": "array",
94+
"items": {"type": "array", "items": {"type": "number"}},
95+
}
96+
_combined_schemas = {
97+
"$schema": "http://json-schema.org/draft-04/schema#",
98+
"description": """`Complement Naive Bayes`_ classifier described in Rennie et al. (2003).
99+
100+
.. _`Complement Naive Bayes`: https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.ComplementNB
101+
""",
102+
"documentation_url": "https://lale.readthedocs.io/en/latest/modules/lale.lib.sklearn.complement_nb.html",
103+
"import_from": "sklearn.naive_bayes",
104+
"type": "object",
105+
"tags": {"pre": [], "op": ["estimator"], "post": []},
106+
"properties": {
107+
"hyperparams": _hyperparams_schema,
108+
"input_fit": _input_fit_schema,
109+
"input_predict": _input_predict_schema,
110+
"output_predict": _output_predict_schema,
111+
"input_predict_proba": _input_predict_proba_schema,
112+
"output_predict_proba": _output_predict_proba_schema,
113+
},
114+
}
115+
ComplementNB = make_operator(Op, _combined_schemas)
116+
117+
set_docstrings(ComplementNB)

test/test_core_classifiers.py

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def test_classifier(self):
131131
"lale.lib.sklearn.BernoulliNB",
132132
"lale.lib.sklearn.CalibratedClassifierCV",
133133
"lale.lib.sklearn.CCA",
134+
"lale.lib.sklearn.ComplementNB",
134135
"lale.lib.sklearn.DummyClassifier",
135136
"lale.lib.sklearn.RandomForestClassifier",
136137
"lale.lib.sklearn.DecisionTreeClassifier",

0 commit comments

Comments
 (0)