@@ -641,8 +641,10 @@ def _deploy_db(db_schema_name,
         Will ONLY be used if the table does not exist and a dataframe is passed in
     :param model_cols: (List[str]) The columns from the table to use for the model. If None, all columns in the table
         will be passed to the model. If specified, the columns will be passed to the model
-        IN THAT ORDER. The columns passed here must exist in the table.
-    :param classes: (List[str]) The classes (prediction labels) for the model being deployed.\n
+        IN THAT ORDER. The columns passed here must exist in the table. If creating the
+        table from a dataframe, the table will be created from the columns in the DF, not
+        model_cols. model_cols is only used at prediction time
+    :param classes: (List[str]) The classes (prediction labels) for the model being deployed.
         NOTE: If not supplied, the table will have default column names for each class
     :param sklearn_args: (dict{str: str}) Prediction options for sklearn models: \n
         * Available key value options: \n
@@ -703,6 +705,8 @@ def _deploy_db(db_schema_name,
 
     schema_table_name = f'{db_schema_name}.{db_table_name}'
 
+    # Feature columns are all of the columns of the table, model_cols are the subset of feature columns that are used \
+    # in predictions. schema_types contains all columns from feature_columns
     feature_columns, schema_types = get_feature_columns_and_types(mlflow._splice_context, df, create_model_table,
                                                                   model_cols, schema_table_name)
 
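For reference, the relationship described by the new comment can be pictured with a small, hypothetical example (the column names and Spark type strings below are invented; the real values come from get_feature_columns_and_types, which is not part of this diff):

    # Hypothetical values, for illustration only
    feature_columns = ['AGE', 'INCOME', 'REGION']     # every column of the target table
    schema_types = {'AGE': 'IntegerType', 'INCOME': 'DoubleType', 'REGION': 'StringType'}
    model_cols = ['AGE', 'INCOME']                    # subset actually sent to the model at prediction time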
@@ -725,7 +729,9 @@ def _deploy_db(db_schema_name,
     # Create the schema of the table (we use this a few times)
     schema_str = ''
     for i in feature_columns:
-        schema_str += f'\t{i} {CONVERSIONS[schema_types[str(i)]]},'
+        spark_data_type = schema_types[str(i)]
+        assert spark_data_type in CONVERSIONS, f'Type {spark_data_type} not supported for table creation. Remove column and try again'
+        schema_str += f'\t{i} {CONVERSIONS[spark_data_type]},'
 
     try:
         # Create/Alter table 1: DATA
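The new assert turns an unsupported Spark type into an immediate, readable failure instead of a KeyError during table creation. A minimal standalone sketch of the pattern, using a made-up CONVERSIONS mapping and made-up columns (the real mapping is defined elsewhere in the library):

    # Hypothetical subset of a Spark-type -> SQL-type mapping
    CONVERSIONS = {'IntegerType': 'INTEGER', 'DoubleType': 'DOUBLE', 'StringType': 'VARCHAR(5000)'}

    schema_str = ''
    for col, spark_data_type in [('AGE', 'IntegerType'), ('NOTES', 'BinaryType')]:
        # Fails fast on 'BinaryType', which has no SQL mapping in this sketch
        assert spark_data_type in CONVERSIONS, \
            f'Type {spark_data_type} not supported for table creation. Remove column and try again'
        schema_str += f'\t{col} {CONVERSIONS[spark_data_type]},'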
@@ -739,11 +745,13 @@ def _deploy_db(db_schema_name,
 
     # Create Trigger 1: model prediction
     print('Creating model prediction trigger ...', end=' ')
+    # If model_cols were passed in, we'll use them here. Otherwise, use all of the columns (stored in feature_columns)
+    model_cols = model_cols or feature_columns
     if model_type in (H2OModelType.KEY_VALUE, SklearnModelType.KEY_VALUE, KerasModelType.KEY_VALUE):
-        create_vti_prediction_trigger(mlflow._splice_context, schema_table_name, run_id, feature_columns, schema_types,
+        create_vti_prediction_trigger(mlflow._splice_context, schema_table_name, run_id, model_cols, schema_types,
                                       schema_str, primary_key, classes, model_type, sklearn_args, pred_threshold, verbose)
     else:
-        create_prediction_trigger(mlflow._splice_context, schema_table_name, run_id, feature_columns, schema_types,
+        create_prediction_trigger(mlflow._splice_context, schema_table_name, run_id, model_cols, schema_types,
                                   schema_str, primary_key, model_type, verbose)
     print('Done.')
 
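The model_cols fallback added above is a plain Python "or" default: a None or empty model_cols means every table column is used at prediction time. A quick, self-contained illustration with hypothetical values:

    feature_columns = ['AGE', 'INCOME', 'REGION']

    model_cols = None
    model_cols = model_cols or feature_columns   # -> ['AGE', 'INCOME', 'REGION']

    model_cols = ['AGE', 'INCOME']
    model_cols = model_cols or feature_columns   # unchanged -> ['AGE', 'INCOME']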