Skip to content

Commit 5cb0ff0

Browse files
authored
Update sklearn api (#60)
* update interface in air_pls * update airpls * update arpls api * fix documentation in index_selector * fix typo in index_selector * fix api in constant baseline correction * update cubic splines * update api in linear correction * update init for linear correction * fix api in non negative * fix api in polynomial and in subtract reference * fix norris williams * fix api in savitzky golay * fix minmax scaler * fix norm scaler * fix point scaler * fix emsc * update msc * update rnv * update snv * update mean filter * update median filter * update savgol filter * update whittaker
1 parent 84a6b62 commit 5cb0ff0

27 files changed

+122
-376
lines changed

chemotools/baseline/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from .air_pls import AirPls
2-
from .ar_pls import ArPls
3-
from .constant_baseline_correction import ConstantBaselineCorrection
4-
from .cubic_spline_correction import CubicSplineCorrection
5-
from .linear_correction import LinearCorrection
6-
from .non_negative import NonNegative
7-
from .polynomial_correction import PolynomialCorrection
8-
from .subtract_reference import SubtractReference
1+
from ._air_pls import AirPls
2+
from ._ar_pls import ArPls
3+
from ._constant_baseline_correction import ConstantBaselineCorrection
4+
from ._cubic_spline_correction import CubicSplineCorrection
5+
from ._linear_correction import LinearCorrection
6+
from ._non_negative import NonNegative
7+
from ._polynomial_correction import PolynomialCorrection
8+
from ._subtract_reference import SubtractReference

chemotools/baseline/air_pls.py renamed to chemotools/baseline/_air_pls.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,6 @@ class AirPls(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
3030
The number of iterations used to calculate the baseline. Increasing the number of iterations can improve the
3131
accuracy of the baseline correction, but also increases the computation time.
3232
33-
Attributes
34-
----------
35-
n_features_in_ : int
36-
The number of features in the input data.
37-
38-
_is_fitted : bool
39-
A flag indicating whether the estimator has been fitted to data.
40-
4133
Methods
4234
-------
4335
fit(X, y=None)
@@ -85,13 +77,7 @@ def fit(self, X: np.ndarray, y=None) -> "AirPls":
8577
Returns the instance itself.
8678
"""
8779
# Check that X is a 2D array and has only finite values
88-
X = check_input(X)
89-
90-
# Set the number of features
91-
self.n_features_in_ = X.shape[1]
92-
93-
# Set the fitted attribute to True
94-
self._is_fitted = True
80+
X = self._validate_data(X)
9581

9682
return self
9783

@@ -113,7 +99,7 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray:
11399
"""
114100

115101
# Check that the estimator is fitted
116-
check_is_fitted(self, "_is_fitted")
102+
check_is_fitted(self, "n_features_in_")
117103

118104
# Check that X is a 2D array and has only finite values
119105
X = check_input(X)

chemotools/baseline/ar_pls.py renamed to chemotools/baseline/_ar_pls.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,6 @@ class ArPls(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
2929
nr_iterations : int, optional (default=100)
3030
The maximum number of iterations for the weight updating scheme.
3131
32-
Attributes
33-
----------
34-
n_features_in_ : int
35-
The number of input features.
36-
37-
_is_fitted : bool
38-
Whether the estimator has been fitted.
3932
4033
Methods
4134
-------
@@ -86,13 +79,7 @@ def fit(self, X: np.ndarray, y=None) -> "ArPls":
8679
"""
8780

8881
# Check that X is a 2D array and has only finite values
89-
X = check_input(X)
90-
91-
# Set the number of features
92-
self.n_features_in_ = X.shape[1]
93-
94-
# Set the fitted attribute to True
95-
self._is_fitted = True
82+
X = self._validate_data(X)
9683

9784
return self
9885

@@ -114,7 +101,7 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray:
114101
"""
115102

116103
# Check that the estimator is fitted
117-
check_is_fitted(self, "_is_fitted")
104+
check_is_fitted(self, "n_features_in_")
118105

119106
# Check that X is a 2D array and has only finite values
120107
X = check_input(X)

chemotools/baseline/constant_baseline_correction.py renamed to chemotools/baseline/_constant_baseline_correction.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ class ConstantBaselineCorrection(OneToOneFeatureMixin, BaseEstimator, Transforme
3030
end_index_ : int
3131
The index of the end of the range. It is 1 if the wavenumbers are not provided.
3232
33-
n_features_in_ : int
34-
The number of features in the input data.
35-
36-
_is_fitted : bool
37-
Whether the transformer has been fitted to data.
38-
3933
Methods
4034
-------
4135
fit(X, y=None)
@@ -46,7 +40,10 @@ class ConstantBaselineCorrection(OneToOneFeatureMixin, BaseEstimator, Transforme
4640
"""
4741

4842
def __init__(
49-
self, start: int = 0, end: int = 1, wavenumbers: np.ndarray = None,
43+
self,
44+
start: int = 0,
45+
end: int = 1,
46+
wavenumbers: np.ndarray = None,
5047
) -> None:
5148
self.start = start
5249
self.end = end
@@ -70,13 +67,7 @@ def fit(self, X: np.ndarray, y=None) -> "ConstantBaselineCorrection":
7067
The fitted transformer.
7168
"""
7269
# Check that X is a 2D array and has only finite values
73-
X = check_input(X)
74-
75-
# Set the number of features
76-
self.n_features_in_ = X.shape[1]
77-
78-
# Set the fitted attribute to True
79-
self._is_fitted = True
70+
X = self._validate_data(X)
8071

8172
# Set the start and end indices
8273
if self.wavenumbers is None:
@@ -109,7 +100,7 @@ def transform(self, X: np.ndarray, y=0, copy=True) -> np.ndarray:
109100
The transformed input data.
110101
"""
111102
# Check that the estimator is fitted
112-
check_is_fitted(self, "_is_fitted")
103+
check_is_fitted(self, ["start_index_", "end_index_"])
113104

114105
# Check that X is a 2D array and has only finite values
115106
X = check_input(X)

chemotools/baseline/cubic_spline_correction.py renamed to chemotools/baseline/_cubic_spline_correction.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
from chemotools.utils.check_inputs import check_input
77

8+
89
class CubicSplineCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
910
"""
10-
A transformer that corrects a baseline by subtracting a cubic spline through the
11+
A transformer that corrects a baseline by subtracting a cubic spline through the
1112
points defined by the indices.
1213
1314
Parameters
@@ -32,6 +33,7 @@ class CubicSplineCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixi
3233
Transform the input data by subtracting the constant baseline value.
3334
3435
"""
36+
3537
def __init__(self, indices: list = None) -> None:
3638
self.indices = indices
3739

@@ -53,13 +55,7 @@ def fit(self, X: np.ndarray, y=None) -> "CubicSplineCorrection":
5355
The fitted transformer.
5456
"""
5557
# Check that X is a 2D array and has only finite values
56-
X = check_input(X)
57-
58-
# Set the number of features
59-
self.n_features_in_ = X.shape[1]
60-
61-
# Set the fitted attribute to True
62-
self._is_fitted = True
58+
X = self._validate_data(X)
6359

6460
if self.indices is None:
6561
self.indices_ = [0, len(X[0]) - 1]
@@ -89,15 +85,17 @@ def transform(self, X: np.ndarray, y=None, copy=True):
8985
The transformed data.
9086
"""
9187
# Check that the estimator is fitted
92-
check_is_fitted(self, "_is_fitted")
88+
check_is_fitted(self, "indices_")
9389

9490
# Check that X is a 2D array and has only finite values
9591
X = check_input(X)
9692
X_ = X.copy()
9793

9894
# Check that the number of features is the same as the fitted data
9995
if X_.shape[1] != self.n_features_in_:
100-
raise ValueError(f"Expected {self.n_features_in_} features but got {X_.shape[1]}")
96+
raise ValueError(
97+
f"Expected {self.n_features_in_} features but got {X_.shape[1]}"
98+
)
10199

102100
# Calculate spline baseline correction
103101
for i, x in enumerate(X_):
@@ -106,7 +104,7 @@ def transform(self, X: np.ndarray, y=None, copy=True):
106104

107105
def _spline_baseline_correct(self, x: np.ndarray) -> np.ndarray:
108106
indices = self.indices_
109-
intensity = x[indices]
107+
intensity = x[indices]
110108
spl = CubicSpline(indices, intensity)
111-
baseline = spl(range(len(x)))
112-
return x - baseline
109+
baseline = spl(range(len(x)))
110+
return x - baseline

chemotools/baseline/linear_correction.py renamed to chemotools/baseline/_linear_correction.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,6 @@ class LinearCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
1010
A transformer that corrects a baseline by subtracting a linear baseline through the
1111
initial and final points of the spectrum.
1212
13-
Parameters
14-
----------
15-
16-
Attributes
17-
----------
18-
n_features_in_ : int
19-
The number of features in the input data.
20-
21-
_is_fitted : bool
22-
Whether the transformer has been fitted to data.
23-
2413
Methods
2514
-------
2615
fit(X, y=None)
@@ -68,13 +57,7 @@ def fit(self, X: np.ndarray, y=None) -> "LinearCorrection":
6857
The fitted transformer.
6958
"""
7059
# Check that X is a 2D array and has only finite values
71-
X = check_input(X)
72-
73-
# Set the number of features
74-
self.n_features_in_ = X.shape[1]
75-
76-
# Set the fitted attribute to True
77-
self._is_fitted = True
60+
X = self._validate_data(X)
7861

7962
return self
8063

@@ -99,7 +82,7 @@ def transform(self, X: np.ndarray, y=0, copy=True) -> np.ndarray:
9982
The transformed data.
10083
"""
10184
# Check that the estimator is fitted
102-
check_is_fitted(self, "_is_fitted")
85+
check_is_fitted(self, "n_features_in_")
10386

10487
# Check that X is a 2D array and has only finite values
10588
X = check_input(X)

chemotools/baseline/non_negative.py renamed to chemotools/baseline/_non_negative.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,6 @@ class NonNegative(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
1414
mode : str, optional
1515
The mode to use for the non-negative values. Can be "zero" or "abs".
1616
17-
Attributes
18-
----------
19-
n_features_in_ : int
20-
The number of features in the input data.
21-
22-
_is_fitted : bool
23-
Whether the transformer has been fitted to data.
24-
2517
Methods
2618
-------
2719
fit(X, y=None)
@@ -52,13 +44,7 @@ def fit(self, X: np.ndarray, y=None) -> "NonNegative":
5244
The fitted transformer.
5345
"""
5446
# Check that X is a 2D array and has only finite values
55-
X = check_input(X)
56-
57-
# Set the number of features
58-
self.n_features_in_ = X.shape[1]
59-
60-
# Set the fitted attribute to True
61-
self._is_fitted = True
47+
X = self._validate_data(X)
6248

6349
return self
6450

@@ -80,7 +66,7 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray:
8066
The transformed data.
8167
"""
8268
# Check that the estimator is fitted
83-
check_is_fitted(self, "_is_fitted")
69+
check_is_fitted(self, "n_features_in_")
8470

8571
# Check that X is a 2D array and has only finite values
8672
X = check_input(X)

chemotools/baseline/polynomial_correction.py renamed to chemotools/baseline/_polynomial_correction.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
from chemotools.utils.check_inputs import check_input
66

7+
78
class PolynomialCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
89
"""
9-
A transformer that subtracts a polynomial baseline from the input data. The polynomial is
10+
A transformer that subtracts a polynomial baseline from the input data. The polynomial is
1011
fitted to the points in the spectrum specified by the indices parameter.
1112
1213
Parameters
@@ -18,14 +19,6 @@ class PolynomialCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin
1819
The indices of the points in the spectrum to fit the polynomial to. Defaults to None,
1920
which fits the polynomial to all points in the spectrum (equivalent to detrend).
2021
21-
Attributes
22-
----------
23-
n_features_in_ : int
24-
The number of features in the input data.
25-
26-
_is_fitted : bool
27-
Whether the transformer has been fitted to data.
28-
2922
Methods
3023
-------
3124
fit(X, y=None)
@@ -37,6 +30,7 @@ class PolynomialCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin
3730
_baseline_correct_spectrum(x)
3831
Subtract the polynomial baseline from a single spectrum.
3932
"""
33+
4034
def __init__(self, order: int = 1, indices: list = None) -> None:
4135
self.order = order
4236
self.indices = indices
@@ -59,22 +53,16 @@ def fit(self, X: np.ndarray, y=None) -> "PolynomialCorrection":
5953
The fitted transformer.
6054
"""
6155
# Check that X is a 2D array and has only finite values
62-
X = check_input(X)
63-
64-
# Set the number of features
65-
self.n_features_in_ = X.shape[1]
66-
67-
# Set the fitted attribute to True
68-
self._is_fitted = True
56+
X = self._validate_data(X)
6957

7058
if self.indices is None:
7159
self.indices_ = range(0, len(X[0]))
7260
else:
7361
self.indices_ = self.indices
7462

7563
return self
76-
77-
def transform(self, X: np.ndarray, y:int=0, copy:bool=True) -> np.ndarray:
64+
65+
def transform(self, X: np.ndarray, y: int = 0, copy: bool = True) -> np.ndarray:
7866
"""
7967
Transform the input data by subtracting the polynomial baseline.
8068
@@ -95,21 +83,23 @@ def transform(self, X: np.ndarray, y:int=0, copy:bool=True) -> np.ndarray:
9583
The transformed data.
9684
"""
9785
# Check that the estimator is fitted
98-
check_is_fitted(self, "_is_fitted")
86+
check_is_fitted(self, "indices_")
9987

10088
# Check that X is a 2D array and has only finite values
10189
X = check_input(X)
10290
X_ = X.copy()
10391

10492
# Check that the number of features is the same as the fitted data
10593
if X_.shape[1] != self.n_features_in_:
106-
raise ValueError(f"Expected {self.n_features_in_} features but got {X_.shape[1]}")
94+
raise ValueError(
95+
f"Expected {self.n_features_in_} features but got {X_.shape[1]}"
96+
)
10797

10898
# Calculate polynomial baseline correction
10999
for i, x in enumerate(X_):
110100
X_[i] = self._baseline_correct_spectrum(x)
111101
return X_.reshape(-1, 1) if X_.ndim == 1 else X_
112-
102+
113103
def _baseline_correct_spectrum(self, x: np.ndarray) -> np.ndarray:
114104
"""
115105
Subtract the polynomial baseline from a single spectrum.
@@ -126,5 +116,5 @@ def _baseline_correct_spectrum(self, x: np.ndarray) -> np.ndarray:
126116
"""
127117
intensity = x[self.indices_]
128118
poly = np.polyfit(self.indices_, intensity, self.order)
129-
baseline = [np.polyval(poly, i) for i in range(0, len(x))]
130-
return x - baseline
119+
baseline = [np.polyval(poly, i) for i in range(0, len(x))]
120+
return x - baseline

0 commit comments

Comments
 (0)