Skip to content

Commit 6bb4d7e

Browse files
committed
Fix support for MLFlow 2.0.1 and fix linting
1 parent f57c92d commit 6bb4d7e

File tree

4 files changed

+126
-28
lines changed

4 files changed

+126
-28
lines changed

.flake8

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
2-
exclude = .git,__pycache__,data,tools
3-
ignore = E101, E111, E114, E115, E116, E117, E121, E122, E123, E124, E125, E126, E127, E128, E129, E131, E133, E2, E3, E5, E501, E701, E702, E703, E704, W1, W2, W3, W503, W504
2+
max-line-length = 100
3+
extend-ignore = E203

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ repos:
1111
rev: 5.0.4
1212
hooks:
1313
- id: flake8
14+
args: ['--config=.flake8']
1415
additional_dependencies: ['flake8-coding==1.3.2', 'flake8-copyright==0.2.3', 'flake8-debugger==4.1.2', 'flake8-mypy==17.8.0']
1516
- repo: https://github.com/pre-commit/pre-commit-hooks
1617
rev: v4.3.0

comet_for_mlflow/comet_for_mlflow.py

+47-26
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,19 @@
4141
from comet_ml.exceptions import CometRestApiException
4242
from comet_ml.offline import upload_single_offline_experiment
4343
from mlflow.entities.run_tag import RunTag
44-
from mlflow.entities.view_type import ViewType
4544
from mlflow.tracking import _get_store
4645
from mlflow.tracking._model_registry.utils import _get_store as get_model_registry_store
4746
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
4847
from tabulate import tabulate
4948
from tqdm import tqdm
5049

50+
from .compat import (
51+
get_artifact_repository,
52+
get_mlflow_model_name,
53+
get_mlflow_run_id,
54+
search_mlflow_store_experiments,
55+
search_mlflow_store_runs,
56+
)
5157
from .file_writer import JsonLinesFile
5258
from .utils import (
5359
get_comet_project_name,
@@ -65,18 +71,10 @@
6571
pass
6672

6773

68-
try:
69-
# MLFLOW version 1.4.0
70-
from mlflow.store.artifact.artifact_repository_registry import (
71-
get_artifact_repository,
72-
)
73-
except ImportError:
74-
# MLFLOW version < 1.4.0
75-
from mlflow.store.artifact_repository_registry import get_artifact_repository
76-
7774
logging.basicConfig(level=logging.INFO, format="%(message)s")
7875
LOGGER = logging.getLogger()
7976

77+
8078
# Install a global exception hook
8179
def except_hook(exc_type, exc_value, exc_traceback):
8280
Reporting.report(
@@ -137,8 +135,7 @@ def __init__(
137135
except UnsupportedModelRegistryStoreURIException:
138136
self.model_registry_store = None
139137

140-
# Most of list_experiments returns a list anyway
141-
self.mlflow_experiments = list(self.store.list_experiments())
138+
self.mlflow_experiments = search_mlflow_store_experiments(self.store)
142139
self.len_experiments = len(self.mlflow_experiments) # We start counting at 0
143140

144141
self.summary = {
@@ -239,22 +236,28 @@ def prepare(self):
239236

240237
LOGGER.info("")
241238
LOGGER.info(
242-
"If you need support, you can contact us at http://chat.comet.ml/ or https://comet.ml/docs/quick-start/#getting-support"
239+
"""If you need support, you can contact us at http://chat.comet.ml/"""
240+
""" or https://comet.ml/docs/quick-start/#getting-support"""
243241
)
244242
LOGGER.info("")
245243

246244
def prepare_mlflow_exp(
247-
self, exp,
245+
self,
246+
exp,
248247
):
249-
runs_info = self.store.list_run_infos(exp.experiment_id, ViewType.ALL)
248+
runs_info = search_mlflow_store_runs(self.store, exp.experiment_id)
250249
len_runs = len(runs_info)
251250

252251
for run_number, run_info in enumerate(runs_info):
253252
try:
254-
run_id = run_info.run_id
253+
run_id = get_mlflow_run_id(run_info)
254+
255255
run = self.store.get_run(run_id)
256256
LOGGER.info(
257-
"## Preparing run %d/%d [%s]", run_number + 1, len_runs, run_id,
257+
"## Preparing run %d/%d [%s]",
258+
run_number + 1,
259+
len_runs,
260+
run_id,
258261
)
259262
LOGGER.debug(
260263
"## Preparing run %d/%d: %r", run_number + 1, len_runs, run
@@ -410,15 +413,25 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
410413
break
411414

412415
if matching_model:
416+
model_name = get_mlflow_model_name(matching_model)
417+
418+
prefix = "models/"
419+
if artifact_path.startswith(prefix):
420+
comet_artifact_path = artifact_path[len(prefix) :]
421+
else:
422+
comet_artifact_path = artifact_path
423+
413424
json_writer.log_artifact_as_model(
414425
local_artifact_path,
415-
artifact_path,
426+
comet_artifact_path,
416427
run_start_time,
417-
matching_model.registered_model.name,
428+
model_name,
418429
)
419430
else:
420431
json_writer.log_artifact_as_asset(
421-
local_artifact_path, artifact_path, run_start_time,
432+
local_artifact_path,
433+
artifact_path,
434+
run_start_time,
422435
)
423436

424437
return self.compress_archive(run.info.run_id)
@@ -438,12 +451,15 @@ def upload(self, prepared_data):
438451
project_note = experiment.tags.get("mlflow.note.content", None)
439452
if project_note:
440453
note_template = (
441-
u"/!\\ This project notes has been copied from MLFlow. It might be overwritten if you run comet_for_mlflow again/!\\ \n%s"
454+
"/!\\ This project notes has been copied from MLFlow."
455+
" It might be overwritten if you run comet_for_mlflow again/!\\ \n%s"
442456
% project_note
443457
)
444458
# We don't support Unicode project notes yet
445459
self.api_client.set_project_notes(
446-
self.workspace, project_name, note_template,
460+
self.workspace,
461+
project_name,
462+
note_template,
447463
)
448464

449465
all_project_names.append(project_name)
@@ -487,7 +503,8 @@ def upload(self, prepared_data):
487503
LOGGER.info("\t- %s", url)
488504

489505
LOGGER.info(
490-
"Get deeper instrumentation by adding Comet SDK to your project: https://comet.ml/docs/python-sdk/mlflow/"
506+
"Get deeper instrumentation by adding Comet SDK to your project:"
507+
" https://comet.ml/docs/python-sdk/mlflow/"
491508
)
492509
LOGGER.info("")
493510

@@ -598,19 +615,23 @@ def create_or_login(self):
598615
Reporting.report("mlflow_new_user", api_key=new_account["apiKey"])
599616

600617
LOGGER.info(
601-
"A Comet.ml account has been created for you and an email was sent to you to setup your password later."
618+
"A Comet.ml account has been created for you and an email was sent to"
619+
" you to setup your password later."
602620
)
603621
save_api_key(new_account["apiKey"])
604622
LOGGER.info(
605-
"Your Comet API Key has been saved to ~/.comet.ini, it is also available on your Comet.ml dashboard."
623+
"Your Comet API Key has been saved to ~/.comet.ini, it is also"
624+
" available on your Comet.ml dashboard."
606625
)
607626
return (
608627
new_account["apiKey"],
609628
new_account["token"],
610629
)
611630
else:
612631
LOGGER.info(
613-
"An account already exists for this account, please input your API Key below (you can find it in your Settings page, https://comet.ml/docs/quick-start/#getting-your-comet-api-key):"
632+
"An account already exists for this account, please input your API Key"
633+
" below (you can find it in your Settings page,"
634+
" https://comet.ml/docs/quick-start/#getting-your-comet-api-key):"
614635
)
615636
api_key = input("API Key: ")
616637

comet_for_mlflow/compat.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2020 Comet.ml Team.
5+
#
6+
# This file is part of Comet-For-MLFlow
7+
# (see https://github.com/comet-ml/comet-for-mlflow).
8+
#
9+
# This program is free software: you can redistribute it and/or modify
10+
# it under the terms of the GNU General Public License as published by
11+
# the Free Software Foundation, either version 3 of the License, or
12+
# (at your option) any later version.
13+
#
14+
# This program is distributed in the hope that it will be useful,
15+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
16+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17+
# GNU General Public License for more details.
18+
#
19+
# You should have received a copy of the GNU General Public License
20+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
21+
#
22+
23+
"""
24+
Contains code to support multiple versions of MLFlow
25+
"""
26+
from mlflow.entities.view_type import ViewType
27+
28+
try:
29+
# MLFLOW version 1.4.0
30+
from mlflow.store.artifact.artifact_repository_registry import ( # noqa
31+
get_artifact_repository,
32+
)
33+
except ImportError:
34+
# MLFLOW version < 1.4.0
35+
from mlflow.store.artifact_repository_registry import ( # noqa
36+
get_artifact_repository,
37+
)
38+
39+
40+
def search_mlflow_store_experiments(mlflow_store):
41+
if hasattr(mlflow_store, "search_experiments"):
42+
# MLflow supports search for up to 50000 experiments, defined in
43+
# mlflow/store/tracking/__init__.py
44+
mlflow_experiments = mlflow_store.search_experiments(max_results=50000)
45+
# TODO: Check if there are more than 50000 experiments
46+
return list(mlflow_experiments)
47+
else:
48+
return list(mlflow_store.list_experiments())
49+
50+
51+
def search_mlflow_store_runs(mlflow_store, experiment_id):
52+
if hasattr(mlflow_store, "search_runs"):
53+
# MLflow supports search for up to 50000 experiments, defined in
54+
# mlflow/store/tracking/__init__.py
55+
return mlflow_store.search_runs(
56+
[experiment_id],
57+
filter_string="",
58+
run_view_type=ViewType.ALL,
59+
max_results=50000,
60+
)
61+
else:
62+
return mlflow_store.list_run_infos(experiment_id, ViewType.ALL)
63+
64+
65+
def get_mlflow_run_id(mlflow_run):
66+
if hasattr(mlflow_run, "info"):
67+
return mlflow_run.info.run_id
68+
else:
69+
return mlflow_run.run_id
70+
71+
72+
def get_mlflow_model_name(mlflow_model):
73+
if hasattr(mlflow_model, "name"):
74+
return mlflow_model.name
75+
else:
76+
return mlflow_model.registered_model.name

0 commit comments

Comments
 (0)