 from builtins import super
 from collections import defaultdict
-from os import environ as env_vars, popen as rbash, system as bash
+from os import environ as env_vars, popen as rbash, system as bash, path
 from sys import getsizeof
 from time import time, sleep
 from enum import Enum
 from typing import List, Dict, Tuple
 import re
-
 import requests
+from zipfile import ZipFile
+from io import BytesIO
 from requests.auth import HTTPBasicAuth
 
 import mlflow
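For context on the new imports: `ZipFile` and `BytesIO` are brought in so that zipped artifacts fetched from the database as raw bytes can be unpacked in memory, and `path` is used to pull file extensions off of file names. A minimal sketch of the zip pattern, where `blob_data` stands in for bytes returned by `retrieve_artifact_stream`:

```python
from io import BytesIO
from zipfile import ZipFile

# 'example.zip' is a hypothetical archive, purely for illustration
with open('example.zip', 'rb') as f:
    blob_data = f.read()

# wrap the raw bytes in a file-like object so ZipFile can read them,
# then extract the archive into the current working directory
ZipFile(BytesIO(blob_data)).extractall()
```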
@@ -188,7 +189,7 @@ class MLManager(MlflowClient):
     A class for managing your MLFlow Runs/Experiments
     """
     MLMANAGER_SCHEMA = 'MLMANAGER'
-    ARTIFACT_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.ARTIFACTS (run_uuid, name, "size", "binary") VALUES (?, ?, ?, ?)'
+    ARTIFACT_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.ARTIFACTS (run_uuid, name, "size", "binary", file_extension) VALUES (?, ?, ?, ?, ?)'
     ARTIFACT_RETRIEVAL_SQL = 'SELECT "binary" FROM ' + f'{MLMANAGER_SCHEMA}.' + 'ARTIFACTS WHERE name=\'{name}\' ' \
                              'AND run_uuid=\'{runid}\''
     MLEAP_INSERT_SQL = f'INSERT INTO {MLMANAGER_SCHEMA}.MODELS(RUN_UUID, MODEL) VALUES (?, ?)'
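For reference, the five `?` placeholders in the widened `ARTIFACT_INSERT_SQL` correspond to the prepared-statement bindings updated further down in this diff. A sketch of the full binding order, assuming `db_connection` is the JDBC connection returned by `splice_context.getConnection()` and that the run UUID occupies position 1 (only positions 2 through 5 are visible in this commit):

```python
ps = db_connection.prepareStatement(MLManager.ARTIFACT_INSERT_SQL)
ps.setString(1, run_uuid)                   # run_uuid (assumed binding, not shown in the diff)
ps.setString(2, name)                       # name
ps.setInt(3, file_size)                     # "size"
ps.setBinaryStream(4, binary_input_stream)  # "binary" BLOB
ps.setString(5, file_ext)                   # new file_extension column
ps.execute()
ps.close()
```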
@@ -471,11 +472,14 @@ def log_artifact(self, file_name, name):
         Log an artifact for the active run
         :param file_name: (str) the name of the file to log
         :param name: (str) the run-relative name to store the artifact under
+        NOTE: We do not currently support logging directories. If you would like to log a directory,
+        please zip it first and log the zip file.
         """
+        file_ext = path.splitext(file_name)[1].lstrip('.')
         with open(file_name, 'rb') as artifact:
-            byte_stream = bytearray(bytes(artifact.read()))
+            byte_stream = bytearray(bytes(artifact.read()))
 
-        self._insert_artifact(name, byte_stream)
+        self._insert_artifact(name, byte_stream, file_ext=file_ext)
 
     @check_active
     def log_artifacts(self, file_names, names):
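A usage sketch for the updated `log_artifact`, where `manager` is a hypothetical `MLManager` instance with an active run:

```python
manager.log_artifact('model.pkl', 'my_model')           # file_ext is derived as 'pkl'

# directories must be zipped first, per the note above
manager.log_artifact('model_dir.zip', 'my_model_dir')   # file_ext is derived as 'zip'
```

Note that `path.splitext` keeps only the final suffix, so a file such as `model.tar.gz` is recorded with the extension `gz`.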
@@ -513,8 +517,7 @@ def log_spark_model(self, model, name='model'):
         oos.writeObject(model._to_java())
         oos.flush()
         oos.close()
-
-        self._insert_artifact(name, baos.toByteArray())  # write the byte stream to the db as a BLOB
+        self._insert_artifact(name, baos.toByteArray(), file_ext='sparkmodel')  # write the byte stream to the db as a BLOB
 
     @staticmethod
     def _is_spark_model(spark_object):
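A usage sketch for `log_spark_model`, assuming `manager` is an `MLManager` with an active run, `pipeline` is a PySpark `Pipeline`, and `train_df` is a training DataFrame:

```python
fitted = pipeline.fit(train_df)   # produces a PipelineModel
manager.log_spark_model(fitted)   # serialized through the JVM and stored with file_ext='sparkmodel'
```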
@@ -536,11 +539,14 @@ def _is_spark_model(spark_object):
 
         raise Exception("The model supplied does not appear to be a Spark Model!")
 
-    def _insert_artifact(self, name, byte_array, mleap_model=False):
+    def _insert_artifact(self, name, byte_array, mleap_model=False, file_ext=None):
         """
         :param name: (str) the path to store the binary
             under (with respect to the current run)
         :param byte_array: (byte[]) Java byte array
+        :param mleap_model: (bool) whether or not the artifact is an MLeap model
+            (we handle MLeap models differently; this is likely to change in future releases)
+        :param file_ext: (str) the file extension of the artifact (used when downloading)
         """
         db_connection = self.splice_context.getConnection()
         file_size = getsizeof(byte_array)
@@ -560,6 +566,7 @@ def _insert_artifact(self, name, byte_array, mleap_model=False):
         prepared_statement.setString(2, name)
         prepared_statement.setInt(3, file_size)
         prepared_statement.setBinaryStream(4, binary_input_stream)
+        prepared_statement.setString(5, file_ext)
 
         prepared_statement.execute()
         prepared_statement.close()
@@ -748,13 +755,22 @@ def download_artifact(self, name, local_path, run_id=None):
         :param name: (str) artifact name to load
             (with respect to the run)
         :param local_path: (str) local path to download the
-            model to
+            model to. This path MUST include the file extension
         :param run_id: (str) the run id to download the artifact
             from. Defaults to active run
         """
+        file_ext = path.splitext(local_path)[1]
+        if not file_ext:
+            raise ValueError('local_path variable must contain the file extension!')
+
+        run_id = run_id or self.current_run_id
         blob_data = self.retrieve_artifact_stream(run_id, name)
-        with open(local_path, 'wb') as artifact_file:
-            artifact_file.write(blob_data)
+        if file_ext == '.zip':
+            zip_file = ZipFile(BytesIO(blob_data))
+            zip_file.extractall()
+        else:
+            with open(local_path, 'wb') as artifact_file:
+                artifact_file.write(blob_data)
 
     def load_spark_model(self, run_id=None, name='model'):
         """
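Two behaviors worth calling out in the new `download_artifact`: `local_path` is validated only for the presence of an extension, and a `.zip` extension switches the method from writing to `local_path` to extracting the archive into the current working directory (in the zip branch, `local_path` itself is never written). A usage sketch, with `manager` and `some_run_id` as hypothetical stand-ins:

```python
# plain artifact: written byte-for-byte to the given path
manager.download_artifact('my_model', '/tmp/model.pkl', run_id=some_run_id)

# zipped artifact: extracted into the current working directory;
# '/tmp/archive.zip' itself is never created
manager.download_artifact('my_model_dir', '/tmp/archive.zip')

# no extension: raises ValueError
manager.download_artifact('my_model', '/tmp/model')
```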
@@ -763,18 +779,19 @@ def load_spark_model(self, run_id=None, name='model'):
         :param run_id: the id of the run to get a model from
             (the run must have an associated model with it named spark_model)
         """
-        if not run_id:
-            run_id = self.current_run_id
-        else:
-            run_id = self.get_run(run_id).info.run_uuid
+        run_id = run_id or self.current_run_id
 
         spark_pipeline_blob = self.retrieve_artifact_stream(run_id, name)
+        # pipeline_zip = ZipFile(BytesIO(spark_pipeline_blob))
+        # pipeline_zip.extractall()
+        # pipeline = PipelineModel.load(name)
         bis = self.splice_context.jvm.java.io.ByteArrayInputStream(spark_pipeline_blob)
         ois = self.splice_context.jvm.java.io.ObjectInputStream(bis)
         pipeline = PipelineModel._from_java(ois.readObject())  # convert object from Java
-        # PipelineModel to Python PipelineModel
+        # PipelineModel to Python PipelineModel
         ois.close()
 
+
         if len(pipeline.stages) == 1 and self._is_spark_model(pipeline.stages[0]):
             pipeline = pipeline.stages[0]
 
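A usage sketch for `load_spark_model` (`manager`, `some_run_id`, and `test_df` are hypothetical):

```python
model = manager.load_spark_model(run_id=some_run_id)  # falls back to the active run when run_id is None
predictions = model.transform(test_df)
```

Note that a caller-supplied `run_id` is now used as-is rather than being resolved through `get_run(run_id).info.run_uuid`.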