Skip to content

Commit

Permalink
Added a serial type of running
Browse files Browse the repository at this point in the history
  • Loading branch information
leschultz committed Feb 27, 2024
1 parent 47e0286 commit 05abd28
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 20 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package information
name = 'madml'
version = '2.3.9' # Need to increment every time to push to PyPI
version = '2.4.0' # Need to increment every time to push to PyPI
description = 'Application domain of machine learning in materials science.'
url = 'https://github.com/leschultz/'\
'materials_application_domain_machine_learning.git'
Expand Down
12 changes: 10 additions & 2 deletions src/madml/assess.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
y,
g=None,
splitters=None,
n_jobs=-1,
):

'''
Expand All @@ -38,12 +39,14 @@ def __init__(
y = The original target features to be split.
g = The groups of data to be split.
splitters = All the types of splitters to assess.
n_jobs = The number of cores to use.
'''

self.X = X # Features
self.y = y # Target
self.splitters = copy.deepcopy(splitters) # Splitter
self.model = copy.deepcopy(model)
self.n_jobs = n_jobs

# If user defined
self.gt_rmse = self.model.gt_rmse
Expand Down Expand Up @@ -84,7 +87,12 @@ def cv(self, split, save_inner_folds=None):
model = copy.deepcopy(self.model)

try:
model.fit(self.X[train], self.y[train], self.g[train])
model.fit(
self.X[train],
self.y[train],
self.g[train],
n_jobs=self.n_jobs,
)
except Exception:
return pd.DataFrame(), None, name

Expand Down Expand Up @@ -181,7 +189,7 @@ def test(
df_confusion = pd.concat(df_confusion)

# Full fit
self.model.fit(self.X, self.y, self.g)
self.model.fit(self.X, self.y, self.g, n_jobs=self.n_jobs)

# Refit on out-of-bag data for final classification models
self.model.domain_rmse.fit(
Expand Down
45 changes: 31 additions & 14 deletions src/madml/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,16 @@ def cv(self, split, gs_model, ds_model, X, y, g=None):

return data

def fit(self, X, y, g=None, d_input=None):
def fit(self, X, y, g=None, d_input=None, n_jobs=-1):
'''
Fit all models. Thresholds for domain classification are also set.
inputs:
X = The features.
y = The target variable.
g = The groups.
d_input = The d cutoff to use that is custom from user.
n_jobs = The number of cores to use.
outputs:
data_cv = Cross validation data used.
Expand All @@ -429,19 +431,34 @@ def fit(self, X, y, g=None, d_input=None):
except Exception:
continue

# Analyze each split in parallel
data_cv = parallel(
self.cv,
splits,
disable=self.disable_tqdm,
gs_model=self.gs_model,
ds_model=self.ds_model,
X=X,
y=y,
g=g,
)

# Combine data
# Analyze each split
if n_jobs == 1:
data_cv = []
for i in splits:
d = self.cv(
i,
gs_model=self.gs_model,
ds_model=self.ds_model,
X=X,
y=y,
g=g,
)

data_cv.append(d)

else:
data_cv = parallel(
self.cv,
splits,
disable=self.disable_tqdm,
gs_model=self.gs_model,
ds_model=self.ds_model,
X=X,
y=y,
g=g,
)

# Put data in one dataframe
data_cv = pd.concat(data_cv)

# Separate data
Expand Down
20 changes: 17 additions & 3 deletions src/madml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
import os


def parallel(func, x, message=None, disable=False, *args, **kwargs):
def parallel(
func,
x,
message=None,
disable=False,
n_jobs=-1,
*args,
**kwargs,
):
'''
Run some function in parallel.
Expand All @@ -14,6 +22,7 @@ def parallel(func, x, message=None, disable=False, *args, **kwargs):
x = The list of items to iterate on.
message = A message to print.
disable = Disable tqdm print.
n_jobs = The number of cores to run on.
args = Arguments for func.
kwargs = Keyword arguments for func.
Expand All @@ -25,8 +34,13 @@ def parallel(func, x, message=None, disable=False, *args, **kwargs):
print(message)

part_func = partial(func, *args, **kwargs)
cores = os.cpu_count()
with Pool(cores) as pool:

if n_jobs == -1:
n_jobs = os.cpu_count()
else:
n_jobs = n_jobs

with Pool(n_jobs) as pool:
data = list(tqdm(
pool.imap(part_func, x),
total=len(x),
Expand Down

0 comments on commit 05abd28

Please sign in to comment.