Skip to content
This repository was archived by the owner on Apr 15, 2022. It is now read-only.

Commit 5d43423

Browse files
Authored by: Epstein
Dbaas 3569 (#43)

* handling artifact directories as zip files
* zip checking in download
* handling spaces
* fixed zip file
* back to original model saving
* file extension checking
* back to original Pipeline storage
* file ext no .
1 parent 9507e76 commit 5d43423

File tree

1 file changed (+33 / −16 lines)

splicemachine/ml/management.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from builtins import super
22
from collections import defaultdict
3-
from os import environ as env_vars, popen as rbash, system as bash
3+
from os import environ as env_vars, popen as rbash, system as bash, path
44
from sys import getsizeof
55
from time import time, sleep
66
from enum import Enum
77
from typing import List, Dict, Tuple
88
import re
9-
109
import requests
10+
from zipfile import ZipFile
11+
from io import BytesIO
1112
from requests.auth import HTTPBasicAuth
1213

1314
import mlflow
@@ -188,7 +189,7 @@ class MLManager(MlflowClient):
188189
A class for managing your MLFlow Runs/Experiments
189190
"""
190191
MLMANAGER_SCHEMA = 'MLMANAGER'
191-
ARTIFACT_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.ARTIFACTS (run_uuid, name, "size", "binary") VALUES (?, ?, ?, ?)'
192+
ARTIFACT_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.ARTIFACTS (run_uuid, name, "size", "binary", file_extension) VALUES (?, ?, ?, ?, ?)'
192193
ARTIFACT_RETRIEVAL_SQL = 'SELECT "binary" FROM ' + f'{MLMANAGER_SCHEMA}.' + 'ARTIFACTS WHERE name=\'{name}\' ' \
193194
'AND run_uuid=\'{runid}\''
194195
MLEAP_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.MODELS(RUN_UUID, MODEL) VALUES (?, ?)'
@@ -471,11 +472,14 @@ def log_artifact(self, file_name, name):
471472
Log an artifact for the active run
472473
:param file_name: (str) the name of the file name to log
473474
:param name: (str) the name of the run relative name to store the model under
475+
NOTE: We do not currently support logging directories. If you would like to log a directory, please zip it first
476+
and log the zip file
474477
"""
478+
file_ext = path.splitext(file_name)[1].lstrip('.')
475479
with open(file_name, 'rb') as artifact:
476-
byte_stream = bytearray(bytes(artifact.read()))
480+
byte_stream = bytearray(bytes(artifact.read()))
477481

478-
self._insert_artifact(name, byte_stream)
482+
self._insert_artifact(name, byte_stream, file_ext=file_ext)
479483

480484
@check_active
481485
def log_artifacts(self, file_names, names):
@@ -513,8 +517,7 @@ def log_spark_model(self, model, name='model'):
513517
oos.writeObject(model._to_java())
514518
oos.flush()
515519
oos.close()
516-
517-
self._insert_artifact(name, baos.toByteArray()) # write the byte stream to the db as a BLOB
520+
self._insert_artifact(name, baos.toByteArray(), file_ext='sparkmodel') # write the byte stream to the db as a BLOB
518521

519522
@staticmethod
520523
def _is_spark_model(spark_object):
@@ -536,11 +539,14 @@ def _is_spark_model(spark_object):
536539

537540
raise Exception("The model supplied does not appear to be a Spark Model!")
538541

539-
def _insert_artifact(self, name, byte_array, mleap_model=False):
542+
def _insert_artifact(self, name, byte_array, mleap_model=False, file_ext=None):
540543
"""
541544
:param name: (str) the path to store the binary
542545
under (with respect to the current run)
543546
:param byte_array: (byte[]) Java byte array
547+
:param mleap_model: (bool) whether or not the artifact is an MLeap model
548+
(We handle mleap models differently, likely to change in future releases)
549+
:param file_ext: (str) the file extension of the model (used for downloading)
544550
"""
545551
db_connection = self.splice_context.getConnection()
546552
file_size = getsizeof(byte_array)
@@ -560,6 +566,7 @@ def _insert_artifact(self, name, byte_array, mleap_model=False):
560566
prepared_statement.setString(2, name)
561567
prepared_statement.setInt(3, file_size)
562568
prepared_statement.setBinaryStream(4, binary_input_stream)
569+
prepared_statement.setString(5,file_ext)
563570

564571
prepared_statement.execute()
565572
prepared_statement.close()
@@ -748,13 +755,22 @@ def download_artifact(self, name, local_path, run_id=None):
748755
:param name: (str) artifact name to load
749756
(with respect to the run)
750757
:param local_path: (str) local path to download the
751-
model to
758+
model to. This path MUST include the file extension
752759
:param run_id: (str) the run id to download the artifact
753760
from. Defaults to active run
754761
"""
762+
file_ext = path.splitext(local_path)[1]
763+
if not file_ext:
764+
raise ValueError('local_path variable must contain the file extension!')
765+
766+
run_id = run_id or self.current_run_id
755767
blob_data = self.retrieve_artifact_stream(run_id, name)
756-
with open(local_path, 'wb') as artifact_file:
757-
artifact_file.write(blob_data)
768+
if file_ext == '.zip':
769+
zip_file = ZipFile(BytesIO(blob_data))
770+
zip_file.extractall()
771+
else:
772+
with open(local_path, 'wb') as artifact_file:
773+
artifact_file.write(blob_data)
758774

759775
def load_spark_model(self, run_id=None, name='model'):
760776
"""
@@ -763,18 +779,19 @@ def load_spark_model(self, run_id=None, name='model'):
763779
:param run_id: the id of the run to get a model from
764780
(the run must have an associated model with it named spark_model)
765781
"""
766-
if not run_id:
767-
run_id = self.current_run_id
768-
else:
769-
run_id = self.get_run(run_id).info.run_uuid
782+
run_id = run_id or self.current_run_id
770783

771784
spark_pipeline_blob = self.retrieve_artifact_stream(run_id, name)
785+
# pipeline_zip = ZipFile(BytesIO(spark_pipeline_blob))
786+
# pipeline_zip.extractall()
787+
# pipeline = PipelineModel.load(name)
772788
bis = self.splice_context.jvm.java.io.ByteArrayInputStream(spark_pipeline_blob)
773789
ois = self.splice_context.jvm.java.io.ObjectInputStream(bis)
774790
pipeline = PipelineModel._from_java(ois.readObject()) # convert object from Java
775-
# PipelineModel to Python PipelineModel
791+
# PipelineModel to Python PipelineModel
776792
ois.close()
777793

794+
778795
if len(pipeline.stages) == 1 and self._is_spark_model(pipeline.stages[0]):
779796
pipeline = pipeline.stages[0]
780797

0 commit comments

Comments (0)