diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..156b6b26
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,3 @@
+include requirements.txt
+include stats_requirements.txt
+include notebook_requirements.txt
diff --git a/README.md b/README.md
index fe94ab36..e58d0ff3 100644
--- a/README.md
+++ b/README.md
@@ -35,3 +35,11 @@ This package contains 4 main external modules. First, `splicemachine.spark.conte
4.2) [`splicemachine.notebooks`](https://pysplice.readthedocs.io/en/latest/splicemachine.notebook.html): houses utilities for use in Jupyter Notebooks running in the Kubernetes cloud environment
+## Docs
+The docs are managed by readthedocs and Sphinx. See the latest docs [here](https://pysplice.readthedocs.io/en/latest/).
+
+### Building the docs
+```
+cd docs
+make html
+```
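+
+Building locally assumes Sphinx is installed. The doc dependencies aren't pinned in this repo, so as a minimal sketch, installing Sphinx first should be enough before running the commands above:
+```
+pip install sphinx
+```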
diff --git a/dist/splicemachine-2.7.0.dev0.tar.gz b/dist/splicemachine-2.7.0.dev0.tar.gz
deleted file mode 100644
index 138e3804..00000000
Binary files a/dist/splicemachine-2.7.0.dev0.tar.gz and /dev/null differ
diff --git a/docs/_build/doctrees/environment.pickle b/docs/_build/doctrees/environment.pickle
deleted file mode 100644
index dd4997b7..00000000
Binary files a/docs/_build/doctrees/environment.pickle and /dev/null differ
diff --git a/docs/_build/doctrees/getting-started.doctree b/docs/_build/doctrees/getting-started.doctree
deleted file mode 100644
index 42b8b68f..00000000
Binary files a/docs/_build/doctrees/getting-started.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/index.doctree b/docs/_build/doctrees/index.doctree
deleted file mode 100644
index 8365b3ab..00000000
Binary files a/docs/_build/doctrees/index.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/spark.doctree b/docs/_build/doctrees/spark.doctree
deleted file mode 100644
index a3b4b4f3..00000000
Binary files a/docs/_build/doctrees/spark.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/splicemachine.doctree b/docs/_build/doctrees/splicemachine.doctree
deleted file mode 100644
index e1dbf42e..00000000
Binary files a/docs/_build/doctrees/splicemachine.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/splicemachine.features.doctree b/docs/_build/doctrees/splicemachine.features.doctree
deleted file mode 100644
index 99c77275..00000000
Binary files a/docs/_build/doctrees/splicemachine.features.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/splicemachine.mlflow_support.doctree b/docs/_build/doctrees/splicemachine.mlflow_support.doctree
deleted file mode 100644
index d97be221..00000000
Binary files a/docs/_build/doctrees/splicemachine.mlflow_support.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/splicemachine.notebook.doctree b/docs/_build/doctrees/splicemachine.notebook.doctree
deleted file mode 100644
index 547e1bb4..00000000
Binary files a/docs/_build/doctrees/splicemachine.notebook.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/splicemachine.spark.doctree b/docs/_build/doctrees/splicemachine.spark.doctree
deleted file mode 100644
index 466f61f0..00000000
Binary files a/docs/_build/doctrees/splicemachine.spark.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/splicemachine.stats.doctree b/docs/_build/doctrees/splicemachine.stats.doctree
deleted file mode 100644
index 5e9256f0..00000000
Binary files a/docs/_build/doctrees/splicemachine.stats.doctree and /dev/null differ
diff --git a/docs/_build/epub/.buildinfo b/docs/_build/epub/.buildinfo
deleted file mode 100644
index ff9e2d6e..00000000
--- a/docs/_build/epub/.buildinfo
+++ /dev/null
@@ -1,4 +0,0 @@
-# Sphinx build info version 1
-# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 22e6ad6f3e6c1f1a4b7bbda07fbd3f3c
-tags: 490e2b0d4a1bebf665648774830bc9b4
diff --git a/docs/_build/epub/META-INF/container.xml b/docs/_build/epub/META-INF/container.xml
deleted file mode 100644
index 326cf15f..00000000
--- a/docs/_build/epub/META-INF/container.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<container version="1.0" xmlns="urn:oasis:names:tc:opendocument:xmlns:container">
-  <rootfiles>
-    <rootfile full-path="content.opf" media-type="application/oebps-package+xml"/>
-  </rootfiles>
-</container>
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from abc import ABCMeta, abstractmethod
-
-import copy
-import threading
-
-from pyspark import since
-from pyspark.ml.param.shared import *
-from pyspark.ml.common import inherit_doc
-from pyspark.sql.functions import udf
-from pyspark.sql.types import StructField, StructType
-
-
-class _FitMultipleIterator(object):
- """
- Used by default implementation of Estimator.fitMultiple to produce models in a thread safe
- iterator. This class handles the simple case of fitMultiple where each param map should be
- fit independently.
-
- :param fitSingleModel: Function: (int => Model) which fits an estimator to a dataset.
- `fitSingleModel` may be called up to `numModels` times, with a unique index each time.
- Each call to `fitSingleModel` with an index should return the Model associated with
- that index.
- :param numModel: Number of models this iterator should produce.
-
- See Estimator.fitMultiple for more info.
- """
- def __init__(self, fitSingleModel, numModels):
- """
-
- """
- self.fitSingleModel = fitSingleModel
- self.numModel = numModels
- self.counter = 0
- self.lock = threading.Lock()
-
- def __iter__(self):
- return self
-
- def __next__(self):
- with self.lock:
- index = self.counter
- if index >= self.numModel:
- raise StopIteration("No models remaining.")
- self.counter += 1
- return index, self.fitSingleModel(index)
-
- def next(self):
- """For python2 compatibility."""
- return self.__next__()
-
-
-@inherit_doc
-class Estimator(Params):
- """
- Abstract class for estimators that fit models to data.
-
- .. versionadded:: 1.3.0
- """
-
- __metaclass__ = ABCMeta
-
- @abstractmethod
- def _fit(self, dataset):
- """
- Fits a model to the input dataset. This is called by the default implementation of fit.
-
- :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
- :returns: fitted model
- """
- raise NotImplementedError()
-
- @since("2.3.0")
- def fitMultiple(self, dataset, paramMaps):
- """
- Fits a model to the input dataset for each param map in `paramMaps`.
-
- :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`.
- :param paramMaps: A Sequence of param maps.
- :return: A thread safe iterable which contains one model for each param map. Each
- call to `next(modelIterator)` will return `(index, model)` where model was fit
- using `paramMaps[index]`. `index` values may not be sequential.
- """
- estimator = self.copy()
-
- def fitSingleModel(index):
- return estimator.fit(dataset, paramMaps[index])
-
- return _FitMultipleIterator(fitSingleModel, len(paramMaps))
-
- @since("1.3.0")
- def fit(self, dataset, params=None):
- """
- Fits a model to the input dataset with optional parameters.
-
- :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
- :param params: an optional param map that overrides embedded params. If a list/tuple of
- param maps is given, this calls fit on each param map and returns a list of
- models.
- :returns: fitted model(s)
- """
- if params is None:
- params = dict()
- if isinstance(params, (list, tuple)):
- models = [None] * len(params)
- for index, model in self.fitMultiple(dataset, params):
- models[index] = model
- return models
- elif isinstance(params, dict):
- if params:
- return self.copy(params)._fit(dataset)
- else:
- return self._fit(dataset)
- else:
- raise ValueError("Params must be either a param map or a list/tuple of param maps, "
- "but got %s." % type(params))
-
-
-@inherit_doc
-class Transformer(Params):
- """
- Abstract class for transformers that transform one dataset into another.
-
- .. versionadded:: 1.3.0
- """
-
- __metaclass__ = ABCMeta
-
- @abstractmethod
- def _transform(self, dataset):
- """
- Transforms the input dataset.
-
- :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
- :returns: transformed dataset
- """
- raise NotImplementedError()
-
- @since("1.3.0")
- def transform(self, dataset, params=None):
- """
- Transforms the input dataset with optional parameters.
-
- :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame`
- :param params: an optional param map that overrides embedded params.
- :returns: transformed dataset
- """
- if params is None:
- params = dict()
- if isinstance(params, dict):
- if params:
- return self.copy(params)._transform(dataset)
- else:
- return self._transform(dataset)
- else:
- raise ValueError("Params must be a param map but got %s." % type(params))
-
-
-@inherit_doc
-class Model(Transformer):
- """
- Abstract class for models that are fitted by estimators.
-
- .. versionadded:: 1.4.0
- """
-
- __metaclass__ = ABCMeta
-
-
-@inherit_doc
-class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
- """
- Abstract class for transformers that take one input column, apply transformation,
- and output the result as a new column.
-
- .. versionadded:: 2.3.0
- """
-
- def setInputCol(self, value):
- """
- Sets the value of :py:attr:`inputCol`.
- """
- return self._set(inputCol=value)
-
- def setOutputCol(self, value):
- """
- Sets the value of :py:attr:`outputCol`.
- """
- return self._set(outputCol=value)
-
- @abstractmethod
- def createTransformFunc(self):
- """
- Creates the transform function using the given param map. The input param map already takes
- account of the embedded param map. So the param values should be determined
- solely by the input param map.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def outputDataType(self):
- """
- Returns the data type of the output column.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def validateInputType(self, inputType):
- """
- Validates the input type. Throw an exception if it is invalid.
- """
- raise NotImplementedError()
-
- def transformSchema(self, schema):
- inputType = schema[self.getInputCol()].dataType
- self.validateInputType(inputType)
- if self.getOutputCol() in schema.names:
- raise ValueError("Output column %s already exists." % self.getOutputCol())
- outputFields = copy.copy(schema.fields)
- outputFields.append(StructField(self.getOutputCol(),
- self.outputDataType(),
- nullable=False))
- return StructType(outputFields)
-
- def _transform(self, dataset):
- self.transformSchema(dataset.schema)
- transformUDF = udf(self.createTransformFunc(), self.outputDataType())
- transformedDataset = dataset.withColumn(self.getOutputCol(),
- transformUDF(dataset[self.getInputCol()]))
- return transformedDataset
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import array
-import sys
-if sys.version > '3':
- basestring = str
- xrange = range
- unicode = str
-
-from abc import ABCMeta
-import copy
-import numpy as np
-
-from py4j.java_gateway import JavaObject
-
-from pyspark.ml.linalg import DenseVector, Vector, Matrix
-from pyspark.ml.util import Identifiable
-
-
-__all__ = ['Param', 'Params', 'TypeConverters']
-
-
-class Param(object):
- """
- A param with self-contained documentation.
-
- .. versionadded:: 1.3.0
- """
-
- def __init__(self, parent, name, doc, typeConverter=None):
- if not isinstance(parent, Identifiable):
- raise TypeError("Parent must be an Identifiable but got type %s." % type(parent))
- self.parent = parent.uid
- self.name = str(name)
- self.doc = str(doc)
- self.typeConverter = TypeConverters.identity if typeConverter is None else typeConverter
-
- def _copy_new_parent(self, parent):
- """Copy the current param to a new parent, must be a dummy param."""
- if self.parent == "undefined":
- param = copy.copy(self)
- param.parent = parent.uid
- return param
- else:
- raise ValueError("Cannot copy from non-dummy parent %s." % parent)
-
- def __str__(self):
- return str(self.parent) + "__" + self.name
-
- def __repr__(self):
- return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
-
- def __hash__(self):
- return hash(str(self))
-
- def __eq__(self, other):
- if isinstance(other, Param):
- return self.parent == other.parent and self.name == other.name
- else:
- return False
-
-
-class TypeConverters(object):
- """
- Factory methods for common type conversion functions for `Param.typeConverter`.
-
- .. versionadded:: 2.0.0
- """
-
- @staticmethod
- def _is_numeric(value):
- vtype = type(value)
- return vtype in [int, float, np.float64, np.int64] or vtype.__name__ == 'long'
-
- @staticmethod
- def _is_integer(value):
- return TypeConverters._is_numeric(value) and float(value).is_integer()
-
- @staticmethod
- def _can_convert_to_list(value):
- vtype = type(value)
- return vtype in [list, np.ndarray, tuple, xrange, array.array] or isinstance(value, Vector)
-
- @staticmethod
- def _can_convert_to_string(value):
- vtype = type(value)
- return isinstance(value, basestring) or vtype in [np.unicode_, np.string_, np.str_]
-
- @staticmethod
- def identity(value):
- """
- Dummy converter that just returns value.
- """
- return value
-
- @staticmethod
- def toList(value):
- """
- Convert a value to a list, if possible.
- """
- if type(value) == list:
- return value
- elif type(value) in [np.ndarray, tuple, xrange, array.array]:
- return list(value)
- elif isinstance(value, Vector):
- return list(value.toArray())
- else:
- raise TypeError("Could not convert %s to list" % value)
-
- @staticmethod
- def toListFloat(value):
- """
- Convert a value to list of floats, if possible.
- """
- if TypeConverters._can_convert_to_list(value):
- value = TypeConverters.toList(value)
- if all(map(lambda v: TypeConverters._is_numeric(v), value)):
- return [float(v) for v in value]
- raise TypeError("Could not convert %s to list of floats" % value)
-
- @staticmethod
- def toListListFloat(value):
- """
- Convert a value to list of list of floats, if possible.
- """
- if TypeConverters._can_convert_to_list(value):
- value = TypeConverters.toList(value)
- return [TypeConverters.toListFloat(v) for v in value]
- raise TypeError("Could not convert %s to list of list of floats" % value)
-
- @staticmethod
- def toListInt(value):
- """
- Convert a value to list of ints, if possible.
- """
- if TypeConverters._can_convert_to_list(value):
- value = TypeConverters.toList(value)
- if all(map(lambda v: TypeConverters._is_integer(v), value)):
- return [int(v) for v in value]
- raise TypeError("Could not convert %s to list of ints" % value)
-
- @staticmethod
- def toListString(value):
- """
- Convert a value to list of strings, if possible.
- """
- if TypeConverters._can_convert_to_list(value):
- value = TypeConverters.toList(value)
- if all(map(lambda v: TypeConverters._can_convert_to_string(v), value)):
- return [TypeConverters.toString(v) for v in value]
- raise TypeError("Could not convert %s to list of strings" % value)
-
- @staticmethod
- def toVector(value):
- """
- Convert a value to a MLlib Vector, if possible.
- """
- if isinstance(value, Vector):
- return value
- elif TypeConverters._can_convert_to_list(value):
- value = TypeConverters.toList(value)
- if all(map(lambda v: TypeConverters._is_numeric(v), value)):
- return DenseVector(value)
- raise TypeError("Could not convert %s to vector" % value)
-
- @staticmethod
- def toMatrix(value):
- """
- Convert a value to a MLlib Matrix, if possible.
- """
- if isinstance(value, Matrix):
- return value
- raise TypeError("Could not convert %s to matrix" % value)
-
- @staticmethod
- def toFloat(value):
- """
- Convert a value to a float, if possible.
- """
- if TypeConverters._is_numeric(value):
- return float(value)
- else:
- raise TypeError("Could not convert %s to float" % value)
-
- @staticmethod
- def toInt(value):
- """
- Convert a value to an int, if possible.
- """
- if TypeConverters._is_integer(value):
- return int(value)
- else:
- raise TypeError("Could not convert %s to int" % value)
-
- @staticmethod
- def toString(value):
- """
- Convert a value to a string, if possible.
- """
- if isinstance(value, basestring):
- return value
- elif type(value) in [np.string_, np.str_]:
- return str(value)
- elif type(value) == np.unicode_:
- return unicode(value)
- else:
- raise TypeError("Could not convert %s to string type" % type(value))
-
- @staticmethod
- def toBoolean(value):
- """
- Convert a value to a boolean, if possible.
- """
- if type(value) == bool:
- return value
- else:
- raise TypeError("Boolean Param requires value of type bool. Found %s." % type(value))
-
-
-class Params(Identifiable):
- """
- Components that take parameters. This also provides an internal
- param map to store parameter values attached to the instance.
-
- .. versionadded:: 1.3.0
- """
-
- __metaclass__ = ABCMeta
-
- def __init__(self):
- super(Params, self).__init__()
- #: internal param map for user-supplied values param map
- self._paramMap = {}
-
- #: internal param map for default values
- self._defaultParamMap = {}
-
- #: value returned by :py:func:`params`
- self._params = None
-
- # Copy the params from the class to the object
- self._copy_params()
-
- def _copy_params(self):
- """
- Copy all params defined on the class to current object.
- """
- cls = type(self)
- src_name_attrs = [(x, getattr(cls, x)) for x in dir(cls)]
- src_params = list(filter(lambda nameAttr: isinstance(nameAttr[1], Param), src_name_attrs))
- for name, param in src_params:
- setattr(self, name, param._copy_new_parent(self))
-
- @property
- def params(self):
- """
- Returns all params ordered by name. The default implementation
- uses :py:func:`dir` to get all attributes of type
- :py:class:`Param`.
- """
- if self._params is None:
- self._params = list(filter(lambda attr: isinstance(attr, Param),
- [getattr(self, x) for x in dir(self) if x != "params" and
- not isinstance(getattr(type(self), x, None), property)]))
- return self._params
-
- def explainParam(self, param):
- """
- Explains a single param and returns its name, doc, and optional
- default value and user-supplied value in a string.
- """
- param = self._resolveParam(param)
- values = []
- if self.isDefined(param):
- if param in self._defaultParamMap:
- values.append("default: %s" % self._defaultParamMap[param])
- if param in self._paramMap:
- values.append("current: %s" % self._paramMap[param])
- else:
- values.append("undefined")
- valueStr = "(" + ", ".join(values) + ")"
- return "%s: %s %s" % (param.name, param.doc, valueStr)
-
- def explainParams(self):
- """
- Returns the documentation of all params with their optionally
- default values and user-supplied values.
- """
- return "\n".join([self.explainParam(param) for param in self.params])
-
- def getParam(self, paramName):
- """
- Gets a param by its name.
- """
- param = getattr(self, paramName)
- if isinstance(param, Param):
- return param
- else:
- raise ValueError("Cannot find param with name %s." % paramName)
-
- def isSet(self, param):
- """
- Checks whether a param is explicitly set by user.
- """
- param = self._resolveParam(param)
- return param in self._paramMap
-
- def hasDefault(self, param):
- """
- Checks whether a param has a default value.
- """
- param = self._resolveParam(param)
- return param in self._defaultParamMap
-
- def isDefined(self, param):
- """
- Checks whether a param is explicitly set by user or has
- a default value.
- """
- return self.isSet(param) or self.hasDefault(param)
-
- def hasParam(self, paramName):
- """
- Tests whether this instance contains a param with a given
- (string) name.
- """
- if isinstance(paramName, basestring):
- p = getattr(self, paramName, None)
- return isinstance(p, Param)
- else:
- raise TypeError("hasParam(): paramName must be a string")
-
- def getOrDefault(self, param):
- """
- Gets the value of a param in the user-supplied param map or its
- default value. Raises an error if neither is set.
- """
- param = self._resolveParam(param)
- if param in self._paramMap:
- return self._paramMap[param]
- else:
- return self._defaultParamMap[param]
-
- def extractParamMap(self, extra=None):
- """
- Extracts the embedded default param values and user-supplied
- values, and then merges them with extra values from input into
- a flat param map, where the latter value is used if there exist
- conflicts, i.e., with ordering: default param values <
- user-supplied values < extra.
-
- :param extra: extra param values
- :return: merged param map
- """
- if extra is None:
- extra = dict()
- paramMap = self._defaultParamMap.copy()
- paramMap.update(self._paramMap)
- paramMap.update(extra)
- return paramMap
-
- def copy(self, extra=None):
- """
- Creates a copy of this instance with the same uid and some
- extra params. The default implementation creates a
- shallow copy using :py:func:`copy.copy`, and then copies the
- embedded and extra parameters over and returns the copy.
- Subclasses should override this method if the default approach
- is not sufficient.
-
- :param extra: Extra parameters to copy to the new instance
- :return: Copy of this instance
- """
- if extra is None:
- extra = dict()
- that = copy.copy(self)
- that._paramMap = {}
- that._defaultParamMap = {}
- return self._copyValues(that, extra)
-
- def set(self, param, value):
- """
- Sets a parameter in the embedded param map.
- """
- self._shouldOwn(param)
- try:
- value = param.typeConverter(value)
- except ValueError as e:
- raise ValueError('Invalid param value given for param "%s". %s' % (param.name, e))
- self._paramMap[param] = value
-
- def _shouldOwn(self, param):
- """
- Validates that the input param belongs to this Params instance.
- """
- if not (self.uid == param.parent and self.hasParam(param.name)):
- raise ValueError("Param %r does not belong to %r." % (param, self))
-
- def _resolveParam(self, param):
- """
- Resolves a param and validates the ownership.
-
- :param param: param name or the param instance, which must
- belong to this Params instance
- :return: resolved param instance
- """
- if isinstance(param, Param):
- self._shouldOwn(param)
- return param
- elif isinstance(param, basestring):
- return self.getParam(param)
- else:
- raise ValueError("Cannot resolve %r as a param." % param)
-
- @staticmethod
- def _dummy():
- """
- Returns a dummy Params instance used as a placeholder to
- generate docs.
- """
- dummy = Params()
- dummy.uid = "undefined"
- return dummy
-
- def _set(self, **kwargs):
- """
- Sets user-supplied params.
- """
- for param, value in kwargs.items():
- p = getattr(self, param)
- if value is not None:
- try:
- value = p.typeConverter(value)
- except TypeError as e:
- raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
- self._paramMap[p] = value
- return self
-
- def clear(self, param):
- """
- Clears a param from the param map if it has been explicitly set.
- """
- if self.isSet(param):
- del self._paramMap[param]
-
- def _setDefault(self, **kwargs):
- """
- Sets default params.
- """
- for param, value in kwargs.items():
- p = getattr(self, param)
- if value is not None and not isinstance(value, JavaObject):
- try:
- value = p.typeConverter(value)
- except TypeError as e:
- raise TypeError('Invalid default param value given for param "%s". %s'
- % (p.name, e))
- self._defaultParamMap[p] = value
- return self
-
- def _copyValues(self, to, extra=None):
- """
- Copies param values from this instance to another instance for
- params shared by them.
-
- :param to: the target instance
- :param extra: extra params to be copied
- :return: the target instance with param values copied
- """
- paramMap = self._paramMap.copy()
- if isinstance(extra, dict):
- for param, value in extra.items():
- if isinstance(param, Param):
- paramMap[param] = value
- else:
- raise TypeError("Expecting a valid instance of Param, but received: {}"
- .format(param))
- elif extra is not None:
- raise TypeError("Expecting a dict, but received an object of type {}."
- .format(type(extra)))
- for param in self.params:
- # copy default params
- if param in self._defaultParamMap and to.hasParam(param.name):
- to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
- # copy explicitly set params
- if param in paramMap and to.hasParam(param.name):
- to._set(**{param.name: paramMap[param]})
- return to
-
- def _resetUid(self, newUid):
- """
- Changes the uid of this instance. This updates both
- the stored uid and the parent uid of params and param maps.
- This is used by persistence (loading).
- :param newUid: new uid to use, which is converted to unicode
- :return: same instance, but with the uid and Param.parent values
- updated, including within param maps
- """
- newUid = unicode(newUid)
- self.uid = newUid
- newDefaultParamMap = dict()
- newParamMap = dict()
- for param in self.params:
- newParam = copy.copy(param)
- newParam.parent = newUid
- if param in self._defaultParamMap:
- newDefaultParamMap[newParam] = self._defaultParamMap[param]
- if param in self._paramMap:
- newParamMap[newParam] = self._paramMap[param]
- param.parent = newUid
- self._defaultParamMap = newDefaultParamMap
- self._paramMap = newParamMap
- return self
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.
-
-from pyspark.ml.param import *
-
-
-class HasMaxIter(Params):
- """
- Mixin for param maxIter: max number of iterations (>= 0).
- """
-
- maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", typeConverter=TypeConverters.toInt)
-
- def __init__(self):
- super(HasMaxIter, self).__init__()
-
- def getMaxIter(self):
- """
- Gets the value of maxIter or its default value.
- """
- return self.getOrDefault(self.maxIter)
-
-
-class HasRegParam(Params):
- """
- Mixin for param regParam: regularization parameter (>= 0).
- """
-
- regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", typeConverter=TypeConverters.toFloat)
-
- def __init__(self):
- super(HasRegParam, self).__init__()
-
- def getRegParam(self):
- """
- Gets the value of regParam or its default value.
- """
- return self.getOrDefault(self.regParam)
-
-
-class HasFeaturesCol(Params):
- """
- Mixin for param featuresCol: features column name.
- """
-
- featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasFeaturesCol, self).__init__()
- self._setDefault(featuresCol='features')
-
- def getFeaturesCol(self):
- """
- Gets the value of featuresCol or its default value.
- """
- return self.getOrDefault(self.featuresCol)
-
-
-class HasLabelCol(Params):
- """
- Mixin for param labelCol: label column name.
- """
-
- labelCol = Param(Params._dummy(), "labelCol", "label column name.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasLabelCol, self).__init__()
- self._setDefault(labelCol='label')
-
- def getLabelCol(self):
- """
- Gets the value of labelCol or its default value.
- """
- return self.getOrDefault(self.labelCol)
-
-
-class HasPredictionCol(Params):
- """
- Mixin for param predictionCol: prediction column name.
- """
-
- predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasPredictionCol, self).__init__()
- self._setDefault(predictionCol='prediction')
-
- def getPredictionCol(self):
- """
- Gets the value of predictionCol or its default value.
- """
- return self.getOrDefault(self.predictionCol)
-
-
-class HasProbabilityCol(Params):
- """
- Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
- """
-
- probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasProbabilityCol, self).__init__()
- self._setDefault(probabilityCol='probability')
-
- def getProbabilityCol(self):
- """
- Gets the value of probabilityCol or its default value.
- """
- return self.getOrDefault(self.probabilityCol)
-
-
-class HasRawPredictionCol(Params):
- """
- Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name.
- """
-
- rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasRawPredictionCol, self).__init__()
- self._setDefault(rawPredictionCol='rawPrediction')
-
- def getRawPredictionCol(self):
- """
- Gets the value of rawPredictionCol or its default value.
- """
- return self.getOrDefault(self.rawPredictionCol)
-
-
-class HasInputCol(Params):
- """
- Mixin for param inputCol: input column name.
- """
-
- inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasInputCol, self).__init__()
-
- def getInputCol(self):
- """
- Gets the value of inputCol or its default value.
- """
- return self.getOrDefault(self.inputCol)
-
-
-class HasInputCols(Params):
- """
- Mixin for param inputCols: input column names.
- """
-
- inputCols = Param(Params._dummy(), "inputCols", "input column names.", typeConverter=TypeConverters.toListString)
-
- def __init__(self):
- super(HasInputCols, self).__init__()
-
- def getInputCols(self):
- """
- Gets the value of inputCols or its default value.
- """
- return self.getOrDefault(self.inputCols)
-
-
-class HasOutputCol(Params):
- """
- Mixin for param outputCol: output column name.
- """
-
- outputCol = Param(Params._dummy(), "outputCol", "output column name.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasOutputCol, self).__init__()
- self._setDefault(outputCol=self.uid + '__output')
-
- def getOutputCol(self):
- """
- Gets the value of outputCol or its default value.
- """
- return self.getOrDefault(self.outputCol)
-
-
-class HasOutputCols(Params):
- """
- Mixin for param outputCols: output column names.
- """
-
- outputCols = Param(Params._dummy(), "outputCols", "output column names.", typeConverter=TypeConverters.toListString)
-
- def __init__(self):
- super(HasOutputCols, self).__init__()
-
- def getOutputCols(self):
- """
- Gets the value of outputCols or its default value.
- """
- return self.getOrDefault(self.outputCols)
-
-
-class HasNumFeatures(Params):
- """
- Mixin for param numFeatures: Number of features. Should be greater than 0.
- """
-
- numFeatures = Param(Params._dummy(), "numFeatures", "Number of features. Should be greater than 0.", typeConverter=TypeConverters.toInt)
-
- def __init__(self):
- super(HasNumFeatures, self).__init__()
- self._setDefault(numFeatures=262144)
-
- def getNumFeatures(self):
- """
- Gets the value of numFeatures or its default value.
- """
- return self.getOrDefault(self.numFeatures)
-
-
-class HasCheckpointInterval(Params):
- """
- Mixin for param checkpointInterval: set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
- """
-
- checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.", typeConverter=TypeConverters.toInt)
-
- def __init__(self):
- super(HasCheckpointInterval, self).__init__()
-
- def getCheckpointInterval(self):
- """
- Gets the value of checkpointInterval or its default value.
- """
- return self.getOrDefault(self.checkpointInterval)
-
-
-class HasSeed(Params):
- """
- Mixin for param seed: random seed.
- """
-
- seed = Param(Params._dummy(), "seed", "random seed.", typeConverter=TypeConverters.toInt)
-
- def __init__(self):
- super(HasSeed, self).__init__()
- self._setDefault(seed=hash(type(self).__name__))
-
- def getSeed(self):
- """
- Gets the value of seed or its default value.
- """
- return self.getOrDefault(self.seed)
-
-
-class HasTol(Params):
- """
- Mixin for param tol: the convergence tolerance for iterative algorithms (>= 0).
- """
-
- tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms (>= 0).", typeConverter=TypeConverters.toFloat)
-
- def __init__(self):
- super(HasTol, self).__init__()
-
- def getTol(self):
- """
- Gets the value of tol or its default value.
- """
- return self.getOrDefault(self.tol)
-
-
-class HasRelativeError(Params):
- """
- Mixin for param relativeError: the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]
- """
-
- relativeError = Param(Params._dummy(), "relativeError", "the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]", typeConverter=TypeConverters.toFloat)
-
- def __init__(self):
- super(HasRelativeError, self).__init__()
- self._setDefault(relativeError=0.001)
-
- def getRelativeError(self):
- """
- Gets the value of relativeError or its default value.
- """
- return self.getOrDefault(self.relativeError)
-
-
-class HasStepSize(Params):
- """
- Mixin for param stepSize: Step size to be used for each iteration of optimization (>= 0).
- """
-
- stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization (>= 0).", typeConverter=TypeConverters.toFloat)
-
- def __init__(self):
- super(HasStepSize, self).__init__()
-
- def getStepSize(self):
- """
- Gets the value of stepSize or its default value.
- """
- return self.getOrDefault(self.stepSize)
-
-
-class HasHandleInvalid(Params):
- """
- Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
- """
-
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasHandleInvalid, self).__init__()
-
- def getHandleInvalid(self):
- """
- Gets the value of handleInvalid or its default value.
- """
- return self.getOrDefault(self.handleInvalid)
-
-
-class HasElasticNetParam(Params):
- """
- Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
- """
-
- elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", typeConverter=TypeConverters.toFloat)
-
- def __init__(self):
- super(HasElasticNetParam, self).__init__()
- self._setDefault(elasticNetParam=0.0)
-
- def getElasticNetParam(self):
- """
- Gets the value of elasticNetParam or its default value.
- """
- return self.getOrDefault(self.elasticNetParam)
-
-
-class HasFitIntercept(Params):
- """
- Mixin for param fitIntercept: whether to fit an intercept term.
- """
-
- fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", typeConverter=TypeConverters.toBoolean)
-
- def __init__(self):
- super(HasFitIntercept, self).__init__()
- self._setDefault(fitIntercept=True)
-
- def getFitIntercept(self):
- """
- Gets the value of fitIntercept or its default value.
- """
- return self.getOrDefault(self.fitIntercept)
-
-
-class HasStandardization(Params):
- """
- Mixin for param standardization: whether to standardize the training features before fitting the model.
- """
-
- standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", typeConverter=TypeConverters.toBoolean)
-
- def __init__(self):
- super(HasStandardization, self).__init__()
- self._setDefault(standardization=True)
-
- def getStandardization(self):
- """
- Gets the value of standardization or its default value.
- """
- return self.getOrDefault(self.standardization)
-
-
-class HasThresholds(Params):
- """
- Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
- """
-
- thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat)
-
- def __init__(self):
- super(HasThresholds, self).__init__()
-
- def getThresholds(self):
- """
- Gets the value of thresholds or its default value.
- """
- return self.getOrDefault(self.thresholds)
-
-
-class HasThreshold(Params):
- """
- Mixin for param threshold: threshold in binary classification prediction, in range [0, 1]
- """
-
- threshold = Param(Params._dummy(), "threshold", "threshold in binary classification prediction, in range [0, 1]", typeConverter=TypeConverters.toFloat)
-
- def __init__(self):
- super(HasThreshold, self).__init__()
- self._setDefault(threshold=0.5)
-
- def getThreshold(self):
- """
- Gets the value of threshold or its default value.
- """
- return self.getOrDefault(self.threshold)
-
-
-class HasWeightCol(Params):
- """
- Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- """
-
- weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasWeightCol, self).__init__()
-
- def getWeightCol(self):
- """
- Gets the value of weightCol or its default value.
- """
- return self.getOrDefault(self.weightCol)
-
-
-class HasSolver(Params):
- """
- Mixin for param solver: the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
- """
-
- solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasSolver, self).__init__()
- self._setDefault(solver='auto')
-
- def getSolver(self):
- """
- Gets the value of solver or its default value.
- """
- return self.getOrDefault(self.solver)
-
-
-class HasVarianceCol(Params):
- """
- Mixin for param varianceCol: column name for the biased sample variance of prediction.
- """
-
- varianceCol = Param(Params._dummy(), "varianceCol", "column name for the biased sample variance of prediction.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasVarianceCol, self).__init__()
-
- def getVarianceCol(self):
- """
- Gets the value of varianceCol or its default value.
- """
- return self.getOrDefault(self.varianceCol)
-
-
-class HasAggregationDepth(Params):
- """
- Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
- """
-
- aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
-
- def __init__(self):
- super(HasAggregationDepth, self).__init__()
- self._setDefault(aggregationDepth=2)
-
- def getAggregationDepth(self):
- """
- Gets the value of aggregationDepth or its default value.
- """
- return self.getOrDefault(self.aggregationDepth)
-
-
-class HasParallelism(Params):
- """
- Mixin for param parallelism: the number of threads to use when running parallel algorithms (>= 1).
- """
-
- parallelism = Param(Params._dummy(), "parallelism", "the number of threads to use when running parallel algorithms (>= 1).", typeConverter=TypeConverters.toInt)
-
- def __init__(self):
- super(HasParallelism, self).__init__()
- self._setDefault(parallelism=1)
-
- def getParallelism(self):
- """
- Gets the value of parallelism or its default value.
- """
- return self.getOrDefault(self.parallelism)
-
-
-class HasCollectSubModels(Params):
- """
- Mixin for param collectSubModels: Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.
- """
-
- collectSubModels = Param(Params._dummy(), "collectSubModels", "Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.", typeConverter=TypeConverters.toBoolean)
-
- def __init__(self):
- super(HasCollectSubModels, self).__init__()
- self._setDefault(collectSubModels=False)
-
- def getCollectSubModels(self):
- """
- Gets the value of collectSubModels or its default value.
- """
- return self.getOrDefault(self.collectSubModels)
-
-
-class HasLoss(Params):
- """
- Mixin for param loss: the loss function to be optimized.
- """
-
- loss = Param(Params._dummy(), "loss", "the loss function to be optimized.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasLoss, self).__init__()
-
- def getLoss(self):
- """
- Gets the value of loss or its default value.
- """
- return self.getOrDefault(self.loss)
-
-
-class HasDistanceMeasure(Params):
- """
- Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
- """
-
- distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasDistanceMeasure, self).__init__()
- self._setDefault(distanceMeasure='euclidean')
-
- def getDistanceMeasure(self):
- """
- Gets the value of distanceMeasure or its default value.
- """
- return self.getOrDefault(self.distanceMeasure)
-
-
-class HasValidationIndicatorCol(Params):
- """
- Mixin for param validationIndicatorCol: name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
- """
-
- validationIndicatorCol = Param(Params._dummy(), "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasValidationIndicatorCol, self).__init__()
-
- def getValidationIndicatorCol(self):
- """
- Gets the value of validationIndicatorCol or its default value.
- """
- return self.getOrDefault(self.validationIndicatorCol)
-
-
-class HasBlockSize(Params):
- """
- Mixin for param blockSize: block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.
- """
-
- blockSize = Param(Params._dummy(), "blockSize", "block size for stacking input data in matrices. Data is stacked within partitions. If block size is more than remaining data in a partition then it is adjusted to the size of this data.", typeConverter=TypeConverters.toInt)
-
- def __init__(self):
- super(HasBlockSize, self).__init__()
-
- def getBlockSize(self):
- """
- Gets the value of blockSize or its default value.
- """
- return self.getOrDefault(self.blockSize)
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import itertools
-import sys
-from multiprocessing.pool import ThreadPool
-
-import numpy as np
-
-from pyspark import since, keyword_only
-from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java, _java2py
-from pyspark.ml.param import Params, Param, TypeConverters
-from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
-from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.sql.functions import rand
-
-__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
- 'TrainValidationSplitModel']
-
-
-def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
- """
- Creates a list of callables which can be called from different threads to fit and evaluate
- an estimator in parallel. Each callable returns an `(index, metric)` pair.
-
- :param est: Estimator, the estimator to be fit.
- :param train: DataFrame, training data set, used for fitting.
- :param eva: Evaluator, used to compute `metric`
- :param validation: DataFrame, validation data set, used for evaluation.
- :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
- :param collectSubModel: Whether to collect sub model.
- :return: (int, float, subModel), an index into `epm` and the associated metric value.
- """
- modelIter = est.fitMultiple(train, epm)
-
- def singleTask():
- index, model = next(modelIter)
- metric = eva.evaluate(model.transform(validation, epm[index]))
- return index, metric, model if collectSubModel else None
-
- return [singleTask] * len(epm)
-
-
-class ParamGridBuilder(object):
- r"""
- Builder for a param grid used in grid search-based model selection.
-
- >>> from pyspark.ml.classification import LogisticRegression
- >>> lr = LogisticRegression()
- >>> output = ParamGridBuilder() \
- ... .baseOn({lr.labelCol: 'l'}) \
- ... .baseOn([lr.predictionCol, 'p']) \
- ... .addGrid(lr.regParam, [1.0, 2.0]) \
- ... .addGrid(lr.maxIter, [1, 5]) \
- ... .build()
- >>> expected = [
- ... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
- ... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
- ... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
- ... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
- >>> len(output) == len(expected)
- True
- >>> all([m in expected for m in output])
- True
-
- .. versionadded:: 1.4.0
- """
-
- def __init__(self):
- self._param_grid = {}
-
- @since("1.4.0")
- def addGrid(self, param, values):
- """
- Sets the given parameters in this grid to fixed values.
-
- param must be an instance of Param associated with an instance of Params
- (such as Estimator or Transformer).
- """
- if isinstance(param, Param):
- self._param_grid[param] = values
- else:
- raise TypeError("param must be an instance of Param")
-
- return self
-
- @since("1.4.0")
- def baseOn(self, *args):
- """
- Sets the given parameters in this grid to fixed values.
- Accepts either a parameter dictionary or a list of (parameter, value) pairs.
- """
- if isinstance(args[0], dict):
- self.baseOn(*args[0].items())
- else:
- for (param, value) in args:
- self.addGrid(param, [value])
-
- return self
-
- @since("1.4.0")
- def build(self):
- """
- Builds and returns all combinations of parameters specified
- by the param grid.
- """
- keys = self._param_grid.keys()
- grid_values = self._param_grid.values()
-
- def to_key_value_pairs(keys, values):
- return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]
-
- return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
-
-
-class _ValidatorParams(HasSeed):
- """
- Common params for TrainValidationSplit and CrossValidator.
- """
-
- estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
- estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
- evaluator = Param(
- Params._dummy(), "evaluator",
- "evaluator used to select hyper-parameters that maximize the validator metric")
-
- @since("2.0.0")
- def getEstimator(self):
- """
- Gets the value of estimator or its default value.
- """
- return self.getOrDefault(self.estimator)
-
- @since("2.0.0")
- def getEstimatorParamMaps(self):
- """
- Gets the value of estimatorParamMaps or its default value.
- """
- return self.getOrDefault(self.estimatorParamMaps)
-
- @since("2.0.0")
- def getEvaluator(self):
- """
- Gets the value of evaluator or its default value.
- """
- return self.getOrDefault(self.evaluator)
-
- @classmethod
- def _from_java_impl(cls, java_stage):
- """
- Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
- """
-
- # Load information from java_stage to the instance.
- estimator = JavaParams._from_java(java_stage.getEstimator())
- evaluator = JavaParams._from_java(java_stage.getEvaluator())
- epms = [estimator._transfer_param_map_from_java(epm)
- for epm in java_stage.getEstimatorParamMaps()]
- return estimator, epms, evaluator
-
- def _to_java_impl(self):
- """
- Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
- """
-
- gateway = SparkContext._gateway
- cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
-
- java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
- for idx, epm in enumerate(self.getEstimatorParamMaps()):
- java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
-
- java_estimator = self.getEstimator()._to_java()
- java_evaluator = self.getEvaluator()._to_java()
- return java_estimator, java_epms, java_evaluator
-
-
-class _CrossValidatorParams(_ValidatorParams):
- """
- Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.
-
- .. versionadded:: 3.0.0
- """
-
- numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
- typeConverter=TypeConverters.toInt)
-
- @since("1.4.0")
- def getNumFolds(self):
- """
- Gets the value of numFolds or its default value.
- """
- return self.getOrDefault(self.numFolds)
-
-
-class CrossValidator(Estimator, _CrossValidatorParams, HasParallelism, HasCollectSubModels,
- MLReadable, MLWritable):
- """
-
- K-fold cross validation performs model selection by splitting the dataset into a set of
- non-overlapping randomly partitioned folds which are used as separate training and test datasets
- e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
- each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
- test set exactly once.
-
-
- >>> from pyspark.ml.classification import LogisticRegression
- >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
- >>> from pyspark.ml.linalg import Vectors
- >>> from pyspark.ml.tuning import CrossValidatorModel
- >>> import tempfile
- >>> dataset = spark.createDataFrame(
- ... [(Vectors.dense([0.0]), 0.0),
- ... (Vectors.dense([0.4]), 1.0),
- ... (Vectors.dense([0.5]), 0.0),
- ... (Vectors.dense([0.6]), 1.0),
- ... (Vectors.dense([1.0]), 1.0)] * 10,
- ... ["features", "label"])
- >>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- >>> evaluator = BinaryClassificationEvaluator()
- >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
- ... parallelism=2)
- >>> cvModel = cv.fit(dataset)
- >>> cvModel.getNumFolds()
- 3
- >>> cvModel.avgMetrics[0]
- 0.5
- >>> path = tempfile.mkdtemp()
- >>> model_path = path + "/model"
- >>> cvModel.write().save(model_path)
- >>> cvModelRead = CrossValidatorModel.read().load(model_path)
- >>> cvModelRead.avgMetrics
- [0.5, ...
- >>> evaluator.evaluate(cvModel.transform(dataset))
- 0.8333...
-
- .. versionadded:: 1.4.0
- """
-
- @keyword_only
- def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
- seed=None, parallelism=1, collectSubModels=False):
- """
- __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
- seed=None, parallelism=1, collectSubModels=False)
- """
- super(CrossValidator, self).__init__()
- self._setDefault(numFolds=3, parallelism=1)
- kwargs = self._input_kwargs
- self._set(**kwargs)
-
- @keyword_only
- @since("1.4.0")
- def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
- seed=None, parallelism=1, collectSubModels=False):
- """
- setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
- seed=None, parallelism=1, collectSubModels=False):
- Sets params for cross validator.
- """
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
- @since("2.0.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- return self._set(estimator=value)
-
- @since("2.0.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- return self._set(estimatorParamMaps=value)
-
- @since("2.0.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- return self._set(evaluator=value)
-
- @since("1.4.0")
- def setNumFolds(self, value):
- """
- Sets the value of :py:attr:`numFolds`.
- """
- return self._set(numFolds=value)
-
- def setSeed(self, value):
- """
- Sets the value of :py:attr:`seed`.
- """
- return self._set(seed=value)
-
- def setParallelism(self, value):
- """
- Sets the value of :py:attr:`parallelism`.
- """
- return self._set(parallelism=value)
-
- def setCollectSubModels(self, value):
- """
- Sets the value of :py:attr:`collectSubModels`.
- """
- return self._set(collectSubModels=value)
-
- def _fit(self, dataset):
- est = self.getOrDefault(self.estimator)
- epm = self.getOrDefault(self.estimatorParamMaps)
- numModels = len(epm)
- eva = self.getOrDefault(self.evaluator)
- nFolds = self.getOrDefault(self.numFolds)
- seed = self.getOrDefault(self.seed)
- h = 1.0 / nFolds
- randCol = self.uid + "_rand"
- df = dataset.select("*", rand(seed).alias(randCol))
- metrics = [0.0] * numModels
-
- pool = ThreadPool(processes=min(self.getParallelism(), numModels))
- subModels = None
- collectSubModelsParam = self.getCollectSubModels()
- if collectSubModelsParam:
- subModels = [[None for j in range(numModels)] for i in range(nFolds)]
-
- for i in range(nFolds):
- validateLB = i * h
- validateUB = (i + 1) * h
- condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
- validation = df.filter(condition).cache()
- train = df.filter(~condition).cache()
-
- tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
- for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
- metrics[j] += (metric / nFolds)
- if collectSubModelsParam:
- subModels[i][j] = subModel
-
- validation.unpersist()
- train.unpersist()
-
- if eva.isLargerBetter():
- bestIndex = np.argmax(metrics)
- else:
- bestIndex = np.argmin(metrics)
- bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels))
-
- @since("1.4.0")
- def copy(self, extra=None):
- """
- Creates a copy of this instance with a randomly generated uid
- and some extra params. This creates a deep copy of
- the embedded paramMap, and copies the embedded and extra parameters over.
-
- :param extra: Extra parameters to copy to the new instance
- :return: Copy of this instance
- """
- if extra is None:
- extra = dict()
- newCV = Params.copy(self, extra)
- if self.isSet(self.estimator):
- newCV.setEstimator(self.getEstimator().copy(extra))
- # estimatorParamMaps remain the same
- if self.isSet(self.evaluator):
- newCV.setEvaluator(self.getEvaluator().copy(extra))
- return newCV
-
- @since("2.3.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @classmethod
- @since("2.3.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java CrossValidator, create and return a Python wrapper of it.
- Used for ML persistence.
- """
-
- estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
- numFolds = java_stage.getNumFolds()
- seed = java_stage.getSeed()
- parallelism = java_stage.getParallelism()
- collectSubModels = java_stage.getCollectSubModels()
- # Create a new instance of this stage.
- py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- numFolds=numFolds, seed=seed, parallelism=parallelism,
- collectSubModels=collectSubModels)
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java CrossValidator. Used for ML persistence.
-
- :return: Java object equivalent to this instance.
- """
-
- estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
-
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
- _java_obj.setEstimatorParamMaps(epms)
- _java_obj.setEvaluator(evaluator)
- _java_obj.setEstimator(estimator)
- _java_obj.setSeed(self.getSeed())
- _java_obj.setNumFolds(self.getNumFolds())
- _java_obj.setParallelism(self.getParallelism())
- _java_obj.setCollectSubModels(self.getCollectSubModels())
-
- return _java_obj
-
-
-class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
- """
-
- CrossValidatorModel contains the model with the highest average cross-validation
- metric across folds and uses this model to transform input data. CrossValidatorModel
- also tracks the metrics for each param map evaluated.
-
- .. versionadded:: 1.4.0
- """
-
- def __init__(self, bestModel, avgMetrics=[], subModels=None):
- super(CrossValidatorModel, self).__init__()
- #: best model from cross validation
- self.bestModel = bestModel
- #: Average cross-validation metrics for each paramMap in
- #: CrossValidator.estimatorParamMaps, in the corresponding order.
- self.avgMetrics = avgMetrics
- #: sub model list from cross validation
- self.subModels = subModels
-
- def _transform(self, dataset):
- return self.bestModel.transform(dataset)
-
- @since("1.4.0")
- def copy(self, extra=None):
- """
- Creates a copy of this instance with a randomly generated uid
- and some extra params. This copies the underlying bestModel,
- creates a deep copy of the embedded paramMap, and
- copies the embedded and extra parameters over.
- It does not copy the extra Params into the subModels.
-
- :param extra: Extra parameters to copy to the new instance
- :return: Copy of this instance
- """
- if extra is None:
- extra = dict()
- bestModel = self.bestModel.copy(extra)
- avgMetrics = self.avgMetrics
- subModels = self.subModels
- return CrossValidatorModel(bestModel, avgMetrics, subModels)
-
- @since("2.3.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @classmethod
- @since("2.3.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java CrossValidatorModel, create and return a Python wrapper of it.
- Used for ML persistence.
- """
- sc = SparkContext._active_spark_context
- bestModel = JavaParams._from_java(java_stage.bestModel())
- avgMetrics = _java2py(sc, java_stage.avgMetrics())
- estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
-
- py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)._set(estimator=estimator)
- py_stage = py_stage._set(estimatorParamMaps=epms)._set(evaluator=evaluator)
-
- if java_stage.hasSubModels():
- py_stage.subModels = [[JavaParams._from_java(sub_model)
- for sub_model in fold_sub_models]
- for fold_sub_models in java_stage.subModels()]
-
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
-
- :return: Java object equivalent to this instance.
- """
-
- sc = SparkContext._active_spark_context
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
- self.uid,
- self.bestModel._to_java(),
- _py2java(sc, self.avgMetrics))
- estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
-
- _java_obj.set("evaluator", evaluator)
- _java_obj.set("estimator", estimator)
- _java_obj.set("estimatorParamMaps", epms)
-
- if self.subModels is not None:
- java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models]
- for fold_sub_models in self.subModels]
- _java_obj.setSubModels(java_sub_models)
- return _java_obj
-
-
-class _TrainValidationSplitParams(_ValidatorParams):
- """
- Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.
-
- .. versionadded:: 3.0.0
- """
-
- trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
- validation data. Must be between 0 and 1.", typeConverter=TypeConverters.toFloat)
-
- @since("2.0.0")
- def getTrainRatio(self):
- """
- Gets the value of trainRatio or its default value.
- """
- return self.getOrDefault(self.trainRatio)
-
-
-class TrainValidationSplit(Estimator, _TrainValidationSplitParams, HasParallelism,
- HasCollectSubModels, MLReadable, MLWritable):
- """
- Validation for hyper-parameter tuning. Randomly splits the input dataset into train and
- validation sets, and uses the evaluation metric on the validation set to select the best model.
- Similar to :class:`CrossValidator`, but only splits the set once.
-
- >>> from pyspark.ml.classification import LogisticRegression
- >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
- >>> from pyspark.ml.linalg import Vectors
- >>> from pyspark.ml.tuning import TrainValidationSplitModel
- >>> import tempfile
- >>> dataset = spark.createDataFrame(
- ... [(Vectors.dense([0.0]), 0.0),
- ... (Vectors.dense([0.4]), 1.0),
- ... (Vectors.dense([0.5]), 0.0),
- ... (Vectors.dense([0.6]), 1.0),
- ... (Vectors.dense([1.0]), 1.0)] * 10,
- ... ["features", "label"]).repartition(1)
- >>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- >>> evaluator = BinaryClassificationEvaluator()
- >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
- ... parallelism=1, seed=42)
- >>> tvsModel = tvs.fit(dataset)
- >>> tvsModel.getTrainRatio()
- 0.75
- >>> tvsModel.validationMetrics
- [0.5, ...
- >>> path = tempfile.mkdtemp()
- >>> model_path = path + "/model"
- >>> tvsModel.write().save(model_path)
- >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
- >>> tvsModelRead.validationMetrics
- [0.5, ...
- >>> evaluator.evaluate(tvsModel.transform(dataset))
- 0.833...
-
- .. versionadded:: 2.0.0
- """
-
- @keyword_only
- def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
- parallelism=1, collectSubModels=False, seed=None):
- """
- __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
- parallelism=1, collectSubModels=False, seed=None)
- """
- super(TrainValidationSplit, self).__init__()
- self._setDefault(trainRatio=0.75, parallelism=1)
- kwargs = self._input_kwargs
- self._set(**kwargs)
-
- @since("2.0.0")
- @keyword_only
- def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
- parallelism=1, collectSubModels=False, seed=None):
- """
- setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
- parallelism=1, collectSubModels=False, seed=None):
- Sets params for the train validation split.
- """
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
- @since("2.0.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- return self._set(estimator=value)
-
- @since("2.0.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- return self._set(estimatorParamMaps=value)
-
- @since("2.0.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- return self._set(evaluator=value)
-
- @since("2.0.0")
- def setTrainRatio(self, value):
- """
- Sets the value of :py:attr:`trainRatio`.
- """
- return self._set(trainRatio=value)
-
- def setSeed(self, value):
- """
- Sets the value of :py:attr:`seed`.
- """
- return self._set(seed=value)
-
- def setParallelism(self, value):
- """
- Sets the value of :py:attr:`parallelism`.
- """
- return self._set(parallelism=value)
-
- def setCollectSubModels(self, value):
- """
- Sets the value of :py:attr:`collectSubModels`.
- """
- return self._set(collectSubModels=value)
-
- def _fit(self, dataset):
- est = self.getOrDefault(self.estimator)
- epm = self.getOrDefault(self.estimatorParamMaps)
- numModels = len(epm)
- eva = self.getOrDefault(self.evaluator)
- tRatio = self.getOrDefault(self.trainRatio)
- seed = self.getOrDefault(self.seed)
- randCol = self.uid + "_rand"
- df = dataset.select("*", rand(seed).alias(randCol))
- condition = (df[randCol] >= tRatio)
- validation = df.filter(condition).cache()
- train = df.filter(~condition).cache()
-
- subModels = None
- collectSubModelsParam = self.getCollectSubModels()
- if collectSubModelsParam:
- subModels = [None for i in range(numModels)]
-
- tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
- pool = ThreadPool(processes=min(self.getParallelism(), numModels))
- metrics = [None] * numModels
- for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
- metrics[j] = metric
- if collectSubModelsParam:
- subModels[j] = subModel
-
- train.unpersist()
- validation.unpersist()
-
- if eva.isLargerBetter():
- bestIndex = np.argmax(metrics)
- else:
- bestIndex = np.argmin(metrics)
- bestModel = est.fit(dataset, epm[bestIndex])
- return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels))
-
- @since("2.0.0")
- def copy(self, extra=None):
- """
- Creates a copy of this instance with a randomly generated uid
- and some extra params. This creates a deep copy of
- the embedded paramMap, and copies the embedded and extra parameters over.
-
- :param extra: Extra parameters to copy to the new instance
- :return: Copy of this instance
- """
- if extra is None:
- extra = dict()
- newTVS = Params.copy(self, extra)
- if self.isSet(self.estimator):
- newTVS.setEstimator(self.getEstimator().copy(extra))
- # estimatorParamMaps remain the same
- if self.isSet(self.evaluator):
- newTVS.setEvaluator(self.getEvaluator().copy(extra))
- return newTVS
-
- @since("2.3.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @classmethod
- @since("2.3.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java TrainValidationSplit, create and return a Python wrapper of it.
- Used for ML persistence.
- """
-
- estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
- trainRatio = java_stage.getTrainRatio()
- seed = java_stage.getSeed()
- parallelism = java_stage.getParallelism()
- collectSubModels = java_stage.getCollectSubModels()
- # Create a new instance of this stage.
- py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- trainRatio=trainRatio, seed=seed, parallelism=parallelism,
- collectSubModels=collectSubModels)
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
- :return: Java object equivalent to this instance.
- """
-
- estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
-
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
- self.uid)
- _java_obj.setEstimatorParamMaps(epms)
- _java_obj.setEvaluator(evaluator)
- _java_obj.setEstimator(estimator)
- _java_obj.setTrainRatio(self.getTrainRatio())
- _java_obj.setSeed(self.getSeed())
- _java_obj.setParallelism(self.getParallelism())
- _java_obj.setCollectSubModels(self.getCollectSubModels())
- return _java_obj
-
-
-class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable, MLWritable):
- """
- Model from train validation split.
-
- .. versionadded:: 2.0.0
- """
-
- def __init__(self, bestModel, validationMetrics=[], subModels=None):
- super(TrainValidationSplitModel, self).__init__()
- #: best model from train validation split
- self.bestModel = bestModel
- #: evaluated validation metrics
- self.validationMetrics = validationMetrics
- #: sub models from train validation split
- self.subModels = subModels
-
- def _transform(self, dataset):
- return self.bestModel.transform(dataset)
-
- @since("2.0.0")
- def copy(self, extra=None):
- """
- Creates a copy of this instance with a randomly generated uid
- and some extra params. This copies the underlying bestModel,
- creates a deep copy of the embedded paramMap, and
- copies the embedded and extra parameters over.
- It also creates a shallow copy of the validationMetrics.
- It does not copy the extra Params into the subModels.
-
- :param extra: Extra parameters to copy to the new instance
- :return: Copy of this instance
- """
- if extra is None:
- extra = dict()
- bestModel = self.bestModel.copy(extra)
- validationMetrics = list(self.validationMetrics)
- subModels = self.subModels
- return TrainValidationSplitModel(bestModel, validationMetrics, subModels)
-
- @since("2.3.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @classmethod
- @since("2.3.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
- Used for ML persistence.
- """
-
- # Load information from java_stage to the instance.
- sc = SparkContext._active_spark_context
- bestModel = JavaParams._from_java(java_stage.bestModel())
- validationMetrics = _java2py(sc, java_stage.validationMetrics())
- estimator, epms, evaluator = super(TrainValidationSplitModel,
- cls)._from_java_impl(java_stage)
- # Create a new instance of this stage.
- py_stage = cls(bestModel=bestModel,
- validationMetrics=validationMetrics)._set(estimator=estimator)
- py_stage = py_stage._set(estimatorParamMaps=epms)._set(evaluator=evaluator)
-
- if java_stage.hasSubModels():
- py_stage.subModels = [JavaParams._from_java(sub_model)
- for sub_model in java_stage.subModels()]
-
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
- :return: Java object equivalent to this instance.
- """
-
- sc = SparkContext._active_spark_context
- _java_obj = JavaParams._new_java_obj(
- "org.apache.spark.ml.tuning.TrainValidationSplitModel",
- self.uid,
- self.bestModel._to_java(),
- _py2java(sc, self.validationMetrics))
- estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
-
- _java_obj.set("evaluator", evaluator)
- _java_obj.set("estimator", estimator)
- _java_obj.set("estimatorParamMaps", epms)
-
- if self.subModels is not None:
- java_sub_models = [sub_model._to_java() for sub_model in self.subModels]
- _java_obj.setSubModels(java_sub_models)
-
- return _java_obj
-
-
-if __name__ == "__main__":
- import doctest
-
- from pyspark.sql import SparkSession
- globs = globals().copy()
-
- # The small batch size here ensures that we see multiple batches,
- # even in these small test examples:
- spark = SparkSession.builder\
- .master("local[2]")\
- .appName("ml.tuning tests")\
- .getOrCreate()
- sc = spark.sparkContext
- globs['sc'] = sc
- globs['spark'] = spark
- (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- spark.stop()
- if failure_count:
- sys.exit(-1)
-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import json
-import sys
-import os
-import time
-import uuid
-import warnings
-
-if sys.version > '3':
- basestring = str
- unicode = str
- long = int
-
-from pyspark import SparkContext, since
-from pyspark.ml.common import inherit_doc
-from pyspark.sql import SparkSession
-from pyspark.util import VersionUtils
-
-
-def _jvm():
- """
- Returns the JVM view associated with SparkContext. Must be called
- after SparkContext is initialized.
- """
- jvm = SparkContext._jvm
- if jvm:
- return jvm
- else:
- raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
-class Identifiable(object):
- """
- Object with a unique ID.
- """
-
- def __init__(self):
- #: A unique id for the object.
- self.uid = self._randomUID()
-
- def __repr__(self):
- return self.uid
-
- @classmethod
- def _randomUID(cls):
- """
- Generate a unique unicode id for the object. The default implementation
- concatenates the class name, "_", and 12 random hex chars.
- """
- return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:])
-
-
-@inherit_doc
-class BaseReadWrite(object):
- """
- Base class for MLWriter and MLReader. Stores information about the SparkContext
- and SparkSession.
-
- .. versionadded:: 2.3.0
- """
-
- def __init__(self):
- self._sparkSession = None
-
- def session(self, sparkSession):
- """
- Sets the Spark Session to use for saving/loading.
- """
- self._sparkSession = sparkSession
- return self
-
- @property
- def sparkSession(self):
- """
- Returns the user-specified Spark Session or the default.
- """
- if self._sparkSession is None:
- self._sparkSession = SparkSession.builder.getOrCreate()
- return self._sparkSession
-
- @property
- def sc(self):
- """
- Returns the underlying `SparkContext`.
- """
- return self.sparkSession.sparkContext
-
-
-@inherit_doc
-class MLWriter(BaseReadWrite):
- """
- Utility class that can save ML instances.
-
- .. versionadded:: 2.0.0
- """
-
- def __init__(self):
- super(MLWriter, self).__init__()
- self.shouldOverwrite = False
-
- def _handleOverwrite(self, path):
- from pyspark.ml.wrapper import JavaWrapper
-
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.util.FileSystemOverwrite")
- wrapper = JavaWrapper(_java_obj)
- wrapper._call_java("handleOverwrite", path, True, self.sparkSession._jsparkSession)
-
- def save(self, path):
- """Save the ML instance to the input path."""
- if self.shouldOverwrite:
- self._handleOverwrite(path)
- self.saveImpl(path)
-
- def saveImpl(self, path):
- """
- save() handles overwriting and then calls this method. Subclasses should override this
- method to implement the actual saving of the instance.
- """
- raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
-
- def overwrite(self):
- """Overwrites if the output path already exists."""
- self.shouldOverwrite = True
- return self
-
-
-@inherit_doc
-class GeneralMLWriter(MLWriter):
- """
- Utility class that can save ML instances in different formats.
-
- .. versionadded:: 2.4.0
- """
-
- def format(self, source):
- """
- Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
- name for export).
- """
- self.source = source
- return self
-
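-# Illustrative sketch (not part of the original module), assuming a model whose writer
-# is a GeneralMLWriter (e.g. a fitted LinearRegressionModel in recent Spark versions):
-# the export format can be chosen by chaining format() before save(). The model
-# variable and output path below are placeholders.
-#
-#   lr_model.write().format("pmml").save("/tmp/lr_pmml")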
-
-@inherit_doc
-class JavaMLWriter(MLWriter):
- """
- (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
- """
-
- def __init__(self, instance):
- super(JavaMLWriter, self).__init__()
- _java_obj = instance._to_java()
- self._jwrite = _java_obj.write()
-
- def save(self, path):
- """Save the ML instance to the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
- self._jwrite.save(path)
-
- def overwrite(self):
- """Overwrites if the output path already exists."""
- self._jwrite.overwrite()
- return self
-
- def option(self, key, value):
- self._jwrite.option(key, value)
- return self
-
- def session(self, sparkSession):
- """Sets the Spark Session to use for saving."""
- self._jwrite.session(sparkSession._jsparkSession)
- return self
-
-
-@inherit_doc
-class GeneralJavaMLWriter(JavaMLWriter):
- """
- (Private) Specialization of :py:class:`GeneralMLWriter` for :py:class:`JavaParams` types
- """
-
- def __init__(self, instance):
- super(GeneralJavaMLWriter, self).__init__(instance)
-
- def format(self, source):
- """
- Specifies the format of ML export (e.g. "pmml", "internal", or the fully qualified class
- name for export).
- """
- self._jwrite.format(source)
- return self
-
-
-@inherit_doc
-class MLWritable(object):
- """
- Mixin for ML instances that provide :py:class:`MLWriter`.
-
- .. versionadded:: 2.0.0
- """
-
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
-
- def save(self, path):
- """Save this ML instance to the given path, a shortcut of 'write().save(path)'."""
- self.write().save(path)
-
-
-@inherit_doc
-class JavaMLWritable(MLWritable):
- """
- (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
- """
-
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
-
-@inherit_doc
-class GeneralJavaMLWritable(JavaMLWritable):
- """
- (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
- """
-
- def write(self):
- """Returns an GeneralMLWriter instance for this ML instance."""
- return GeneralJavaMLWriter(self)
-
-
-@inherit_doc
-class MLReader(BaseReadWrite):
- """
- Utility class that can load ML instances.
-
- .. versionadded:: 2.0.0
- """
-
- def __init__(self):
- super(MLReader, self).__init__()
-
- def load(self, path):
- """Load the ML instance from the input path."""
- raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
-
-
-@inherit_doc
-class JavaMLReader(MLReader):
- """
- (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
- """
-
- def __init__(self, clazz):
- super(JavaMLReader, self).__init__()
- self._clazz = clazz
- self._jread = self._load_java_obj(clazz).read()
-
- def load(self, path):
- """Load the ML instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
- java_obj = self._jread.load(path)
- if not hasattr(self._clazz, "_from_java"):
- raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
- % self._clazz)
- return self._clazz._from_java(java_obj)
-
- def session(self, sparkSession):
- """Sets the Spark Session to use for loading."""
- self._jread.session(sparkSession._jsparkSession)
- return self
-
- @classmethod
- def _java_loader_class(cls, clazz):
- """
- Returns the full class name of the Java ML instance. The default
- implementation replaces "pyspark" by "org.apache.spark" in
- the Python full class name.
- """
- java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
- if clazz.__name__ in ("Pipeline", "PipelineModel"):
- # Remove the last package name "pipeline" for Pipeline and PipelineModel.
- java_package = ".".join(java_package.split(".")[0:-1])
- return java_package + "." + clazz.__name__
-
- @classmethod
- def _load_java_obj(cls, clazz):
- """Load the peer Java object of the ML instance."""
- java_class = cls._java_loader_class(clazz)
- java_obj = _jvm()
- for name in java_class.split("."):
- java_obj = getattr(java_obj, name)
- return java_obj
-
-
-@inherit_doc
-class MLReadable(object):
- """
- Mixin for instances that provide :py:class:`MLReader`.
-
- .. versionadded:: 2.0.0
- """
-
- @classmethod
- def read(cls):
- """Returns an MLReader instance for this class."""
- raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)
-
- @classmethod
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
-
-
-@inherit_doc
-class JavaMLReadable(MLReadable):
- """
- (Private) Mixin for instances that provide JavaMLReader.
- """
-
- @classmethod
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
-
-@inherit_doc
-class DefaultParamsWritable(MLWritable):
- """
- Helper trait for making simple :py:class:`Params` types writable. If a :py:class:`Params`
- class stores all data as :py:class:`Param` values, then extending this trait will provide
- a default implementation of writing saved instances of the class.
- This only handles simple :py:class:`Param` types; e.g., it will not handle
- :py:class:`Dataset`. See :py:class:`DefaultParamsReadable`, the counterpart to this trait.
-
- .. versionadded:: 2.3.0
- """
-
- def write(self):
- """Returns a DefaultParamsWriter instance for this class."""
- from pyspark.ml.param import Params
-
- if isinstance(self, Params):
- return DefaultParamsWriter(self)
- else:
- raise TypeError("Cannot use DefautParamsWritable with type %s because it does not " +
- " extend Params.", type(self))
-
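-# Illustrative sketch (not part of the original module), assuming the shared HasThreshold
-# param mixin: a transformer that keeps all of its state in Params and mixes in
-# DefaultParamsReadable/DefaultParamsWritable gets save()/load() without any custom
-# persistence code. The class name and column below are placeholders.
-#
-#   from pyspark.ml import Transformer
-#   from pyspark.ml.param.shared import HasThreshold
-#   from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
-#
-#   class ScoreFilter(Transformer, HasThreshold, DefaultParamsReadable, DefaultParamsWritable):
-#       def _transform(self, dataset):
-#           # keep only rows whose score exceeds the configured threshold
-#           return dataset.filter(dataset["score"] > self.getThreshold())
-#
-#   f = ScoreFilter().setThreshold(0.7)
-#   f.save("/tmp/score_filter")
-#   f2 = ScoreFilter.load("/tmp/score_filter")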
-
-@inherit_doc
-class DefaultParamsWriter(MLWriter):
- """
- Specialization of :py:class:`MLWriter` for :py:class:`Params` types
-
- Class for writing Estimators and Transformers whose parameters are JSON-serializable.
-
- .. versionadded:: 2.3.0
- """
-
- def __init__(self, instance):
- super(DefaultParamsWriter, self).__init__()
- self.instance = instance
-
- def saveImpl(self, path):
- DefaultParamsWriter.saveMetadata(self.instance, path, self.sc)
-
- @staticmethod
- def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
- """
- Saves metadata + Params to: path + "/metadata"
-
- - class
- - timestamp
- - sparkVersion
- - uid
- - paramMap
- - defaultParamMap (since 2.4.0)
- - (optionally, extra metadata)
-
- :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
- :param paramMap: If given, this is saved in the "paramMap" field.
- """
- metadataPath = os.path.join(path, "metadata")
- metadataJson = DefaultParamsWriter._get_metadata_to_save(instance,
- sc,
- extraMetadata,
- paramMap)
- sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
-
- @staticmethod
- def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
- """
- Helper for :py:meth:`DefaultParamsWriter.saveMetadata` which extracts the JSON to save.
- This is useful for ensemble models which need to save metadata for many sub-models.
-
- .. note:: See :py:meth:`DefaultParamsWriter.saveMetadata` for details on what this includes.
- """
- uid = instance.uid
- cls = instance.__module__ + '.' + instance.__class__.__name__
-
- # User-supplied param values
- params = instance._paramMap
- jsonParams = {}
- if paramMap is not None:
- jsonParams = paramMap
- else:
- for p in params:
- jsonParams[p.name] = params[p]
-
- # Default param values
- jsonDefaultParams = {}
- for p in instance._defaultParamMap:
- jsonDefaultParams[p.name] = instance._defaultParamMap[p]
-
- basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
- "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
- "defaultParamMap": jsonDefaultParams}
- if extraMetadata is not None:
- basicMetadata.update(extraMetadata)
- return json.dumps(basicMetadata, separators=[',', ':'])
-
-
-@inherit_doc
-class DefaultParamsReadable(MLReadable):
- """
- Helper trait for making simple :py:class:`Params` types readable.
- If a :py:class:`Params` class stores all data as :py:class:`Param` values,
- then extending this trait will provide a default implementation of reading saved
- instances of the class. This only handles simple :py:class:`Param` types;
- e.g., it will not handle :py:class:`Dataset`. See :py:class:`DefaultParamsWritable`,
- the counterpart to this trait.
-
- .. versionadded:: 2.3.0
- """
-
- @classmethod
- def read(cls):
- """Returns a DefaultParamsReader instance for this class."""
- return DefaultParamsReader(cls)
-
-
-@inherit_doc
-class DefaultParamsReader(MLReader):
- """
- Specialization of :py:class:`MLReader` for :py:class:`Params` types
-
- Default :py:class:`MLReader` implementation for transformers and estimators that
- contain basic (json-serializable) params and no data. This will not handle
- more complex params or types with data (e.g., models with coefficients).
-
- .. versionadded:: 2.3.0
- """
-
- def __init__(self, cls):
- super(DefaultParamsReader, self).__init__()
- self.cls = cls
-
- @staticmethod
- def __get_class(clazz):
- """
- Loads Python class from its name.
- """
- parts = clazz.split('.')
- module = ".".join(parts[:-1])
- m = __import__(module)
- for comp in parts[1:]:
- m = getattr(m, comp)
- return m
-
- def load(self, path):
- metadata = DefaultParamsReader.loadMetadata(path, self.sc)
- py_type = DefaultParamsReader.__get_class(metadata['class'])
- instance = py_type()
- instance._resetUid(metadata['uid'])
- DefaultParamsReader.getAndSetParams(instance, metadata)
- return instance
-
- @staticmethod
- def loadMetadata(path, sc, expectedClassName=""):
- """
- Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`
-
- :param expectedClassName: If non empty, this is checked against the loaded metadata.
- """
- metadataPath = os.path.join(path, "metadata")
- metadataStr = sc.textFile(metadataPath, 1).first()
- loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
- return loadedVals
-
- @staticmethod
- def _parseMetaData(metadataStr, expectedClassName=""):
- """
- Parse metadata JSON string produced by :py:meth:`DefaultParamsWriter._get_metadata_to_save`.
- This is a helper function for :py:meth:`DefaultParamsReader.loadMetadata`.
-
- :param metadataStr: JSON string of metadata
- :param expectedClassName: If non empty, this is checked against the loaded metadata.
- """
- metadata = json.loads(metadataStr)
- className = metadata['class']
- if len(expectedClassName) > 0:
- assert className == expectedClassName, "Error loading metadata: Expected " + \
- "class name {} but found class name {}".format(expectedClassName, className)
- return metadata
-
- @staticmethod
- def getAndSetParams(instance, metadata):
- """
- Extract Params from metadata, and set them in the instance.
- """
- # Set user-supplied param values
- for paramName in metadata['paramMap']:
- param = instance.getParam(paramName)
- paramValue = metadata['paramMap'][paramName]
- instance.set(param, paramValue)
-
- # Set default param values
- majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
- major = majorAndMinorVersions[0]
- minor = majorAndMinorVersions[1]
-
- # For metadata file prior to Spark 2.4, there is no default section.
- if major > 2 or (major == 2 and minor >= 4):
- assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
- "`defaultParamMap` section not found"
-
- for paramName in metadata['defaultParamMap']:
- paramValue = metadata['defaultParamMap'][paramName]
- instance._setDefault(**{paramName: paramValue})
-
- @staticmethod
- def loadParamsInstance(path, sc):
- """
- Load a :py:class:`Params` instance from the given path, and return it.
- This assumes the instance inherits from :py:class:`MLReadable`.
- """
- metadata = DefaultParamsReader.loadMetadata(path, sc)
- pythonClassName = metadata['class'].replace("org.apache.spark", "pyspark")
- py_type = DefaultParamsReader.__get_class(pythonClassName)
- instance = py_type.load(path)
- return instance
-
-
-@inherit_doc
-class HasTrainingSummary(object):
- """
- Base class for models that provide a training summary.
-
- .. versionadded:: 3.0.0
- """
-
- @property
- @since("2.1.0")
- def hasSummary(self):
- """
- Indicates whether a training summary exists for this model
- instance.
- """
- return self._call_java("hasSummary")
-
- @property
- @since("2.1.0")
- def summary(self):
- """
- Gets summary of the model trained on the training set. An exception is thrown if
- no summary exists.
- """
- return (self._call_java("summary"))
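-
-# Illustrative sketch (not part of the original module): models that mix in
-# HasTrainingSummary expose the summary computed during fit(). The estimator and
-# training dataframe below are placeholders.
-#
-#   from pyspark.ml.classification import LogisticRegression
-#
-#   model = LogisticRegression().fit(train_df)
-#   if model.hasSummary:
-#       print(model.summary.objectiveHistory)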
-
-"""
-Copyright 2020 Splice Machine, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-from __future__ import print_function
-
-import os
-
-from py4j.java_gateway import java_import
-from pyspark.sql import DataFrame
-from pyspark.sql.types import _parse_datatype_json_string
-from splicemachine.spark.constants import CONVERSIONS
-
-
-class PySpliceContext:
- """
- This class implements a SpliceMachineContext object (similar to the SparkContext object)
- """
- _spliceSparkPackagesName = "com.splicemachine.spark.splicemachine.*"
-
- def _splicemachineContext(self):
- return self.jvm.com.splicemachine.spark.splicemachine.SplicemachineContext(self.jdbcurl)
-
- def __init__(self, sparkSession, JDBC_URL=None, _unit_testing=False):
- """
- :param JDBC_URL: (string) The JDBC URL Connection String for your Splice Machine Cluster
- :param sparkSession: (SparkSession) The SparkSession object for talking to Spark
- """
-
- if JDBC_URL:
- self.jdbcurl = JDBC_URL
- else:
- try:
- self.jdbcurl = os.environ['BEAKERX_SQL_DEFAULT_JDBC']
- except KeyError as e:
- raise KeyError(
- "Could not locate JDBC URL. If you are not running on the cloud service,"
- "please specify the JDBC_URL=<some url> keyword argument in the constructor"
- )
-
- self._unit_testing = _unit_testing
-
- if not _unit_testing: # Private Internal Argument to Override Using JVM
- self.spark_sql_context = sparkSession._wrapped
- self.spark_session = sparkSession
- self.jvm = self.spark_sql_context._sc._jvm
- java_import(self.jvm, self._spliceSparkPackagesName)
- java_import(
- self.jvm, "org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions")
- java_import(
- self.jvm, "org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils")
- java_import(self.jvm, "scala.collection.JavaConverters._")
- java_import(self.jvm, "com.splicemachine.derby.impl.*")
- java_import(self.jvm, 'org.apache.spark.api.python.PythonUtils')
- self.jvm.com.splicemachine.derby.impl.SpliceSpark.setContext(
- self.spark_sql_context._jsc)
- self.context = self._splicemachineContext()
-
- else:
- from .tests.mocked import MockedScalaContext
- self.spark_sql_context = sparkSession._wrapped
- self.spark_session = sparkSession
- self.jvm = ''
- self.context = MockedScalaContext(self.jdbcurl)
-
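-# Illustrative usage (kept as a comment so the class body stays intact): the JDBC URL
-# below is a placeholder; inside the cloud service it is picked up automatically from
-# the BEAKERX_SQL_DEFAULT_JDBC environment variable instead.
-#
-#   from pyspark.sql import SparkSession
-#   from splicemachine.spark.context import PySpliceContext
-#
-#   spark = SparkSession.builder.getOrCreate()
-#   splice = PySpliceContext(spark, JDBC_URL='jdbc:splice://<host>:1527/splicedb;user=<user>;password=<password>')
-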
- def toUpper(self, dataframe):
- """
- Returns a dataframe with all of the columns in uppercase
-
- :param dataframe: (Dataframe) The dataframe to convert to uppercase
- """
- for s in dataframe.schema:
- s.name = s.name.upper()
- # You need to re-generate the dataframe for the capital letters to take effect
- return dataframe.rdd.toDF(dataframe.schema)
-
- def replaceDataframeSchema(self, dataframe, schema_table_name):
- """
- Returns a dataframe with all column names replaced with the proper string case from the DB table
-
- :param dataframe: (Dataframe) A dataframe with column names to convert
- :param schema_table_name: (str) The schema.table with the correct column cases to pull from the database
- :return: (DataFrame) A Spark DataFrame with the replaced schema
- """
- schema = self.getSchema(schema_table_name)
- # Fastest way to replace the column case if changed
- dataframe = dataframe.rdd.toDF(schema)
- return dataframe
-
- def getConnection(self):
- """
- Return a connection to the database
- """
- return self.context.getConnection()
-
- def tableExists(self, schema_and_or_table_name, table_name=None):
- """
- Check whether or not a table exists
-
- :Example:
- .. code-block:: python
-
- splice.tableExists('schemaName.tableName')\n
- # or\n
- splice.tableExists('schemaName', 'tableName')
-
- :param schema_and_or_table_name: (str) Pass the schema name in this param when passing the table_name param,
- or pass schemaName.tableName in this param without passing the table_name param
- :param table_name: (optional) (str) Table Name, used when schema_and_or_table_name contains only the schema name
- :return: (bool) whether or not the table exists
- """
- if table_name:
- return self.context.tableExists(schema_and_or_table_name, table_name)
- else:
- return self.context.tableExists(schema_and_or_table_name)
-
- def dropTable(self, schema_and_or_table_name, table_name=None):
- """
- Drop a specified table.
-
- :Example:
- .. code-block:: python
-
- splice.dropTable('schemaName.tableName') \n
- # or\n
- splice.dropTable('schemaName', 'tableName')
-
- :param schema_and_or_table_name: (str) Pass the schema name in this param when passing the table_name param,
- or pass schemaName.tableName in this param without passing the table_name param
- :param table_name: (optional) (str) Table Name, used when schema_and_or_table_name contains only the schema name
- :return: None
- """
- if table_name:
- return self.context.dropTable(schema_and_or_table_name, table_name)
- else:
- return self.context.dropTable(schema_and_or_table_name)
-
- def df(self, sql):
- """
- Return a Spark Dataframe from the results of a Splice Machine SQL Query
-
- :Example:
- .. code-block:: python
-
- df = splice.df('SELECT * FROM MYSCHEMA.TABLE1 WHERE COL2 > 3')
-
- :param sql: (str) SQL Query (eg. SELECT * FROM table1 WHERE col2 > 3)
- :return: (Dataframe) A Spark DataFrame containing the results
- """
- return DataFrame(self.context.df(sql), self.spark_sql_context)
-
- def insert(self, dataframe, schema_table_name, to_upper=False):
- """
- Insert a dataframe into a table (schema.table).
-
- :param dataframe: (Dataframe) The dataframe you would like to insert
- :param schema_table_name: (str) The table in which you would like to insert the DF
- :param to_upper: (bool) If the dataframe columns should be converted to uppercase before insertion.
- If False, the dataframe columns are left in their original case. [Default False]
- :return: None
- """
- if to_upper:
- dataframe = self.toUpper(dataframe)
- return self.context.insert(dataframe._jdf, schema_table_name)
-
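-# Illustrative sketch (kept as a comment, not part of the original class): inserting a
-# small dataframe into an existing table; the table name and data are placeholders.
-#
-#   df = spark.createDataFrame([(1, 'a'), (2, 'b')], ['ID', 'VAL'])
-#   splice.insert(df, 'MYSCHEMA.MYTABLE', to_upper=True)
-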
- def insertWithStatus(self, dataframe, schema_table_name, statusDirectory, badRecordsAllowed):
- """
- Insert a dataframe into a table (schema.table) while tracking and limiting records that fail to insert.
- The status directory and number of badRecordsAllowed allow for duplicate primary keys to be
- written to a bad records file. If badRecordsAllowed is set to -1, all bad records will be written
- to the status directory.
-
- :param dataframe: (Dataframe) The dataframe you would like to insert
- :param schema_table_name: (str) The table in which you would like to insert the dataframe
- :param statusDirectory: (str) The status directory where bad records file will be created
- :param badRecordsAllowed: (int) The number of bad records allowed. -1 for unlimited
- :return: None
- """
- dataframe = self.replaceDataframeSchema(dataframe, schema_table_name)
- return self.context.insert(dataframe._jdf, schema_table_name, statusDirectory, badRecordsAllowed)
-
- def insertRdd(self, rdd, schema, schema_table_name):
- """
- Insert an rdd into a table (schema.table)
-
- :param rdd: (RDD) The RDD you would like to insert
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) The table in which you would like to insert the RDD
- :return: None
- """
- return self.insert(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
- def insertRddWithStatus(self, rdd, schema, schema_table_name, statusDirectory, badRecordsAllowed):
- """
- Insert an rdd into a table (schema.table) while tracking and limiting records that fail to insert. \
- The status directory and number of badRecordsAllowed allow for duplicate primary keys to be \
- written to a bad records file. If badRecordsAllowed is set to -1, all bad records will be written \
- to the status directory.
-
- :param rdd: (RDD) The RDD you would like to insert
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) The table in which you would like to insert the dataframe
- :param statusDirectory: (str) The status directory where bad records file will be created
- :param badRecordsAllowed: (int) The number of bad records allowed. -1 for unlimited
- :return: None
- """
- return self.insertWithStatus(
- self.createDataFrame(rdd, schema),
- schema_table_name,
- statusDirectory,
- badRecordsAllowed
- )
-
- def upsert(self, dataframe, schema_table_name):
- """
- Upsert the data from a dataframe into a table (schema.table).
-
- :param dataframe: (Dataframe) The dataframe you would like to upsert
- :param schema_table_name: (str) The table in which you would like to upsert the RDD
- :return: None
- """
- # make sure column names are in the correct case
- dataframe = self.replaceDataframeSchema(dataframe, schema_table_name)
- return self.context.upsert(dataframe._jdf, schema_table_name)
-
- def upsertWithRdd(self, rdd, schema, schema_table_name):
- """
- Upsert the data from an RDD into a table (schema.table).
-
- :param rdd: (RDD) The RDD you would like to upsert
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) The table in which you would like to upsert the RDD
- :return: None
- """
- return self.upsert(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
- def delete(self, dataframe, schema_table_name):
- """
- Delete records in a table based on joining by primary keys from the dataframe.
- Be careful with column naming and case sensitivity.
-
- :param dataframe: (Dataframe) The dataframe containing the primary keys you would like to delete from the table
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- return self.context.delete(dataframe._jdf, schema_table_name)
-
- def deleteWithRdd(self, rdd, schema, schema_table_name):
- """
- Delete records using an rdd based on joining by primary keys from the rdd.
- Be careful with column naming and case sensitivity.
-
- :param rdd: (RDD) The RDD containing the primary keys you would like to delete from the table
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- return self.delete(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
- def update(self, dataframe, schema_table_name):
- """
- Update data from a dataframe for a specified schema_table_name (schema.table).
- The keys are required for the update and any other columns provided will be updated
- in the rows.
-
- :param dataframe: (Dataframe) The dataframe you would like to update
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- # make sure column names are in the correct case
- dataframe = self.replaceDataframeSchema(dataframe, schema_table_name)
- return self.context.update(dataframe._jdf, schema_table_name)
-
- def updateWithRdd(self, rdd, schema, schema_table_name):
- """
- Update data from an rdd for a specified schema_table_name (schema.table).
- The keys are required for the update and any other columns provided will be updated
- in the rows.
-
- :param rdd: (RDD) The RDD you would like to use for updating the table
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- return self.update(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
- def getSchema(self, schema_table_name):
- """
- Return the schema via JDBC.
-
- :param schema_table_name: (str) Table name
- :return: (StructType) PySpark StructType representation of the table
- """
- return _parse_datatype_json_string(self.context.getSchema(schema_table_name).json())
-
- def execute(self, query_string):
- '''
- Execute a query over JDBC
-
- :Example:
- .. code-block:: python
-
- splice.execute('DELETE FROM TABLE1 WHERE col2 > 3')
-
- :param query_string: (str) SQL Query (eg. SELECT * FROM table1 WHERE col2 > 3)
- :return: None
- '''
- return self.context.execute(query_string)
-
- def executeUpdate(self, query_string):
- '''
- Execute a DML query (update, delete, drop, etc.)
-
- :Example:
- .. code-block:: python
-
- splice.executeUpdate('DROP TABLE table1')
-
- :param query_string: (string) SQL Query (eg. DROP TABLE table1)
- :return: None
- '''
- return self.context.executeUpdate(query_string)
-
- def internalDf(self, query_string):
- '''
- SQL to Dataframe translation (Lazy). Runs the query inside Splice Machine and sends the results to the Spark Adapter app
-
- :param query_string: (str) SQL Query (eg. SELECT * FROM table1 WHERE col2 > 3)
- :return: (DataFrame) pyspark dataframe contains the result of query_string
- '''
- return DataFrame(self.context.internalDf(query_string), self.spark_sql_context)
-
- def rdd(self, schema_table_name, column_projection=None):
- """
- Table with projections in Splice mapped to an RDD.
-
- :param schema_table_name: (string) Accessed table
- :param column_projection: (list of strings) Names of selected columns
- :return: (RDD[Row]) the result of the projection
- """
- if column_projection:
- colnames = ', '.join(str(col) for col in column_projection)
- else:
- colnames = '*'
- return self.df('select '+colnames+' from '+schema_table_name).rdd
-
- def internalRdd(self, schema_table_name, column_projection=None):
- """
- Table with projections in Splice mapped to an RDD.
- Runs the projection inside Splice Machine and sends the results to the Spark Adapter app as an rdd
-
- :param schema_table_name: (str) Accessed table
- :param column_projection: (list of strings) Names of selected columns
- :return: (RDD[Row]) the result of the projection
- """
- if column_projection:
- colnames = ', '.join(str(col) for col in column_projection)
- else:
- colnames = '*'
- return self.internalDf('select '+colnames+' from '+schema_table_name).rdd
-
- def truncateTable(self, schema_table_name):
- """
- Truncate a table
-
- :param schema_table_name: (str) the full table name in the format "schema.table_name" which will be truncated
- :return: None
- """
- return self.context.truncateTable(schema_table_name)
-
- def analyzeSchema(self, schema_name):
- """
- Analyze the schema
-
- :param schema_name: (str) The schema name for which stats info will be collected
- :return: None
- """
- return self.context.analyzeSchema(schema_name)
-
- def analyzeTable(self, schema_table_name, estimateStatistics=False, samplePercent=10.0):
- """
- Collect stats info on a table
-
- :param schema_table_name: (str) Full table name in the format of 'schema.table'
- :param estimateStatistics: (bool) Will use estimate statistics if True
- :param samplePercent: (float) The percentage of rows to be sampled.
- :return: None
- """
- return self.context.analyzeTable(schema_table_name, estimateStatistics, float(samplePercent))
-
- def export(self,
- dataframe,
- location,
- compression=False,
- replicationCount=1,
- fileEncoding=None,
- fieldSeparator=None,
- quoteCharacter=None):
- """
- Export a dataFrame in CSV
-
- :param dataframe: (DataFrame)
- :param location: (str) Destination directory
- :param compression: (bool) Whether to compress the output or not
- :param replicationCount: (int) Replication used for HDFS write
- :param fileEncoding: (str) fileEncoding or None, defaults to UTF-8
- :param fieldSeparator: (str) fieldSeparator or None, defaults to ','
- :param quoteCharacter: (str) quoteCharacter or None, defaults to '"'
- :return: None
- """
- return self.context.export(dataframe._jdf, location, compression, replicationCount,
- fileEncoding, fieldSeparator, quoteCharacter)
-
- def exportBinary(self, dataframe, location, compression, e_format='parquet'):
- """
- Export a dataFrame in binary format
-
- :param dataframe: (DataFrame)
- :param location: (str) Destination directory
- :param compression: (bool) Whether to compress the output or not
- :param e_format: (str) Binary format to be used, currently only 'parquet' is supported. [Default 'parquet']
- :return: None
- """
- return self.context.exportBinary(dataframe._jdf, location, compression, e_format)
-
- def bulkImportHFile(self, dataframe, schema_table_name, options):
- """
- Bulk Import HFile from a dataframe into a schema.table
-
- :param dataframe: (DataFrame)
- :param schema_table_name: (str) Full table name in the format of "schema.table"
- :param options: (Dict) Dictionary of options to be passed to --splice-properties; bulkImportDirectory is required
- :return: None
- """
- optionsMap = self.jvm.java.util.HashMap()
- for k, v in options.items():
- optionsMap.put(k, v)
- return self.context.bulkImportHFile(dataframe._jdf, schema_table_name, optionsMap)
-
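-# Illustrative sketch (kept as a comment, not part of the original class): per the
-# docstring above, the options dict must include bulkImportDirectory; the directory
-# and table name here are placeholders.
-#
-#   splice.bulkImportHFile(df, 'MYSCHEMA.MYTABLE', {'bulkImportDirectory': '/tmp/bulk_import'})
-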
- def bulkImportHFileWithRdd(self, rdd, schema, schema_table_name, options):
- """
- Bulk Import HFile from an rdd into a schema.table
-
- :param rdd: (RDD) Input data
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) Full table name in the format of "schema.table"
- :param options: (Dict) Dictionary of options to be passed to --splice-properties; bulkImportDirectory is required
- :return: None
- """
- return self.bulkImportHFile(
- self.createDataFrame(rdd, schema),
- schema_table_name,
- options
- )
-
- def splitAndInsert(self, dataframe, schema_table_name, sample_fraction):
- """
- Sample the dataframe, split the table, and insert a dataFrame into a schema.table.
- This corresponds to an INSERT INTO ... SELECT statement
-
- :param dataframe: (DataFrame) Input data
- :param schema_table_name: (str) Full table name in the format of "schema.table"
- :param sample_fraction: (float) A value between 0 and 1 that specifies the percentage of data in the dataFrame \
- that should be sampled to determine the splits. \
- For example, specify 0.005 if you want 0.5% of the data sampled.
- :return: None
- """
- return self.context.splitAndInsert(dataframe._jdf, schema_table_name, float(sample_fraction))
-
- def createDataFrame(self, rdd, schema):
- """
- Creates a dataframe from a given rdd and schema.
-
- :param rdd: (RDD) Input data
- :param schema: (StructType) The schema of the rows in the RDD
- :return: (DataFrame) The Spark DataFrame
- """
- return self.spark_session.createDataFrame(rdd, schema)
-
- def _generateDBSchema(self, dataframe, types={}):
- """
- Generate the schema for create table
- """
- # convert keys and values to uppercase in the types dictionary
- types = dict((key.upper(), val) for key, val in types.items())
- db_schema = []
- # convert dataframe to have all uppercase column names
- dataframe = self.toUpper(dataframe)
- # i contains the name and pyspark datatype of the column
- for i in dataframe.schema:
- if i.name.upper() in types:
- print('Column {} is of type {}'.format(
- i.name.upper(), i.dataType))
- dt = types[i.name.upper()]
- else:
- dt = CONVERSIONS[str(i.dataType)]
- db_schema.append((i.name.upper(), dt))
-
- return db_schema
-
- def _getCreateTableSchema(self, schema_table_name, new_schema=False):
- """
- Parse schema for new table; if it is needed, create it
- """
- # try to get schema and table, else set schema to splice
- if '.' in schema_table_name:
- schema, table = schema_table_name.upper().split('.')
- else:
- schema = self.getConnection().getCurrentSchemaName()
- table = schema_table_name.upper()
- # check for new schema
- if new_schema:
- print('Creating schema {}'.format(schema))
- self.execute('CREATE SCHEMA {}'.format(schema))
-
- return schema, table
-
- def _dropTableIfExists(self, schema_table_name, table_name=None):
- """
- Drop table if it exists
- """
- if self.tableExists(schema_and_or_table_name=schema_table_name, table_name=table_name):
- print('Table exists. Dropping table')
- self.dropTable(schema_and_or_table_name=schema_table_name, table_name=table_name)
-
- def dropTableIfExists(self, schema_table_name, table_name=None):
- """
- Drops a table if it exists
-
- :Example:
- .. code-block:: python
-
- splice.dropTableIfExists('schemaName.tableName') \n
- # or\n
- splice.dropTableIfExists('schemaName', 'tableName')
-
- :param schema_table_name: (str) Pass the schema name in this param when passing the table_name param,
- or pass schemaName.tableName in this param without passing the table_name param
- :param table_name: (optional) (str) Table Name, used when schema_table_name contains only the schema name
- :return: None
- """
- self._dropTableIfExists(schema_table_name, table_name)
-
- def _jstructtype(self, schema):
- """
- Convert python StructType to java StructType
-
- :param schema: PySpark StructType
- :return: Java Spark StructType
- """
- return self.spark_session._jsparkSession.parseDataType(schema.json())
-
- def createTable(self, dataframe, schema_table_name, primary_keys=None, create_table_options=None, to_upper=False, drop_table=False):
- """
- Creates a schema.table (schema_table_name) from a dataframe
-
- :param dataframe: (DataFrame) The Spark DataFrame to base the table off
- :param schema_table_name: (str) The schema.table to create
- :param primary_keys: (List[str]) The primary keys. Default None
- :param create_table_options: (str) The additional table-level SQL options. Default None
- :param to_upper: (bool) If the dataframe columns should be converted to uppercase before table creation. \
- If False, the table will be created with lower case columns. Default False
- :param drop_table: (bool) Whether to drop the table if it exists. Default False. If False and the table exists, the function will throw an exception
- :return: None
-
- """
- if drop_table:
- self._dropTableIfExists(schema_table_name)
- if to_upper:
- dataframe = self.toUpper(dataframe)
- primary_keys = primary_keys if primary_keys else []
- self.createTableWithSchema(schema_table_name, dataframe.schema,
- keys=primary_keys, create_table_options=create_table_options)
-
-[docs] def createTableWithSchema(self, schema_table_name, schema, keys=None, create_table_options=None):
- """
- Creates a schema.table from a schema
-
- :param schema_table_name: str The schema.table to create
- :param schema: (StructType) The schema that describes the columns of the table
- :param keys: (List[str]) The primary keys. Default None
- :param create_table_options: (str) The additional table-level SQL options. Default None
- :return: None
- """
- if keys:
- keys_seq = self.jvm.PythonUtils.toSeq(keys)
- else:
- keys_seq = self.jvm.PythonUtils.toSeq([])
- self.context.createTable(
- schema_table_name,
- self._jstructtype(schema),
- keys_seq,
- create_table_options
- )
-
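-# A minimal usage sketch, assuming a live PySpliceContext named `splice` and a Spark
-# DataFrame `df` supplied by the caller; the schema, table and key names are
-# hypothetical examples.
-def _example_create_and_load_table(splice, df):
-    # Create MYSCHEMA.MYTABLE keyed on ID, dropping any prior copy, then load the data
-    splice.createTable(df, 'MYSCHEMA.MYTABLE', primary_keys=['ID'],
-                       to_upper=True, drop_table=True)
-    splice.insert(df, 'MYSCHEMA.MYTABLE', to_upper=True)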
-
-[docs]class ExtPySpliceContext(PySpliceContext):
- """
- This class implements a SplicemachineContext object from com.splicemachine.spark2 for use outside of the K8s Cloud Service
- """
- _spliceSparkPackagesName = "com.splicemachine.spark2.splicemachine.*"
-
- def _splicemachineContext(self):
- return self.jvm.com.splicemachine.spark2.splicemachine.SplicemachineContext(
- self.jdbcurl, self.kafkaServers, self.kafkaPollTimeout)
-
- def __init__(self, sparkSession, JDBC_URL=None, kafkaServers='localhost:9092', kafkaPollTimeout=20000, _unit_testing=False):
- """
- :param JDBC_URL: (string) The JDBC URL Connection String for your Splice Machine Cluster
- :param sparkSession: (sparkContext) A SparkSession object for talking to Spark
-        :param kafkaServers: (string) Comma-separated list of Kafka broker addresses in the form host:port
-        :param kafkaPollTimeout: (int) Number of milliseconds to wait when polling Kafka
- """
- self.kafkaServers = kafkaServers
- self.kafkaPollTimeout = kafkaPollTimeout
- super().__init__(sparkSession, JDBC_URL, _unit_testing)
-
-"""
-Copyright 2020 Splice Machine, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.\n
-
-======================================================================================================================================================================================\n
-
-All functions in this module are accessible through the mlflow object and are to be referenced without the leading underscore as \n
-.. code-block:: python
-
- mlflow.function_name()
-
-For example, the function _current_exp_id() is accessible via\n
-.. code-block:: python
-
- mlflow.current_exp_id()
-
-
-All functions are accessible after running the following import\n
-.. code-block:: python
-
- from splicemachine.mlflow_support import *
-
-Importing anything directly from mlflow before running the above statement will cause problems. After running the above import, you can import additional mlflow submodules as normal\n
-.. code-block:: python
-
- from splicemachine.mlflow_support import *
- from mlflow.tensorflow import autolog
-
-======================================================================================================================================================================================\n
-"""
-import time
-from collections import defaultdict
-from contextlib import contextmanager
-from os import path
-from sys import version as py_version
-
-import gorilla
-import mlflow
-import requests
-from requests.auth import HTTPBasicAuth
-from mleap.pyspark import spark_support
-import pyspark
-import sklearn
-from sklearn.base import BaseEstimator as ScikitModel
-from tensorflow import __version__ as tf_version
-from tensorflow.keras import __version__ as keras_version
-from tensorflow.keras import Model as KerasModel
-
-from splicemachine.mlflow_support.constants import *
-from splicemachine.mlflow_support.utilities import *
-from splicemachine.spark.context import PySpliceContext
-from splicemachine.spark.constants import CONVERSIONS
-from pyspark.sql.dataframe import DataFrame as SparkDF
-from pandas.core.frame import DataFrame as PandasDF
-
-_TESTING = env_vars.get("TESTING", False)
-_TRACKING_URL = get_pod_uri("mlflow", "5001", _TESTING)
-
-_CLIENT = mlflow.tracking.MlflowClient(tracking_uri=_TRACKING_URL)
-mlflow.client = _CLIENT
-
-_GORILLA_SETTINGS = gorilla.Settings(allow_hit=True, store_hit=True)
-_PYTHON_VERSION = py_version.split('|')[0].strip()
-
-[docs]def _mlflow_patch(name):
- """
- Create a MLFlow Patch that applies the default gorilla settings
-
- :param name: destination name under mlflow package
- :return: decorator for patched function
- """
- return gorilla.patch(mlflow, name, settings=_GORILLA_SETTINGS)
-
-
-[docs]def _get_current_run_data():
- """
- Get the data associated with the current run.
-    As of MLflow 1.6, run info cannot be retrieved from the mlflow.active_run object, so it
-    is retrieved via the tracking client.
-
- :return: active run data object
- """
- return _CLIENT.get_run(mlflow.active_run().info.run_id).data
-
-
-[docs]@_mlflow_patch('get_run_ids_by_name')
-def _get_run_ids_by_name(run_name, experiment_id=None):
- """
- Gets a run id from the run name. If there are multiple runs with the same name, all run IDs are returned
-
- :param run_name: (str) The name of the run
- :param experiment_id: (int) The experiment to search in. If None, all experiments are searched. [Default None]
- :return: (List[str]) List of run ids
- """
-    exps = [_CLIENT.get_experiment(experiment_id)] if experiment_id is not None else _CLIENT.list_experiments()
- run_ids = []
- for exp in exps:
- for run in _CLIENT.search_runs(exp.experiment_id):
- if run_name == run.data.tags['mlflow.runName']:
- run_ids.append(run.data.tags['Run ID'])
- return run_ids
-
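-# A minimal usage sketch, assuming apply_patches() has run (it does at import time,
-# below) so the function above is reachable as mlflow.get_run_ids_by_name. The run
-# name is a hypothetical example.
-def _example_get_run_ids():
-    # Returns every run id whose mlflow.runName tag matches 'my_run'
-    return mlflow.get_run_ids_by_name('my_run')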
-
-[docs]@_mlflow_patch('register_splice_context')
-def _register_splice_context(splice_context):
- """
- Register a Splice Context for Spark/Database operations (artifact storage, for example)
-
- :param splice_context: (PySpliceContext) splice context to input
- :return: None
- """
- assert isinstance(splice_context, PySpliceContext), "You must pass in a PySpliceContext to this method"
- mlflow._splice_context = splice_context
-
-
-def _check_for_splice_ctx():
- """
- Check to make sure that the user has registered
- a PySpliceContext with the mlflow object before allowing
- spark operations to take place
- """
-
- if not hasattr(mlflow, '_splice_context'):
- raise SpliceMachineException(
- "You must run `mlflow.register_splice_context(pysplice_context) before "
- "you can run this mlflow operation!"
- )
-
-
-[docs]@_mlflow_patch('current_run_id')
-def _current_run_id():
- """
- Retrieve the current run id
-
- :return: (str) the current run id
- """
- return mlflow.active_run().info.run_uuid
-
-
-[docs]@_mlflow_patch('current_exp_id')
-def _current_exp_id():
- """
- Retrieve the current exp id
-
- :return: (int) the current experiment id
- """
- return mlflow.active_run().info.experiment_id
-
-
-[docs]@_mlflow_patch('lp')
-def _lp(key, value):
- """
- Add a shortcut for logging parameters in MLFlow.
-
- :param key: (str) key for the parameter
- :param value: (str) value for the parameter
- :return: None
- """
- if len(str(value)) > 250 or len(str(key)) > 250:
-        raise SpliceMachineException(f'It seems your parameter input is too long. The max length is 250 characters. '
- f'Your key is length {len(str(key))} and your value is length {len(str(value))}.')
- mlflow.log_param(key, value)
-
-
-[docs]@_mlflow_patch('lm')
-def _lm(key, value, step=None):
- """
- Add a shortcut for logging metrics in MLFlow.
-
- :param key: (str) key for the parameter
- :param value: (str or int) value for the parameter
- :param step: (int) A single integer step at which to log the specified Metrics. If unspecified, each metric is logged at step zero.
- """
- if len(str(key)) > 250:
-        raise SpliceMachineException(f'It seems your metric key is too long. The max length is 250 characters, '
-                                     f'but yours is {len(str(key))}.')
- mlflow.log_metric(key, value, step=step)
-
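-# A minimal usage sketch of the lp/lm shortcuts above, assuming apply_patches() has
-# run (it does at import time, below) and a PySpliceContext has been registered via
-# mlflow.register_splice_context. The key/value names are hypothetical examples.
-def _example_log_param_and_metric():
-    with mlflow.start_run(run_name='shortcut_demo'):
-        mlflow.lp('max_depth', 5)        # shortcut for mlflow.log_param
-        mlflow.lm('rmse', 0.42, step=1)  # shortcut for mlflow.log_metric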
-
-[docs]@_mlflow_patch('log_model')
-def _log_model(model, name='model'):
- """
- Log a trained machine learning model
-
-    :param model: (Model) the trained Spark/SKlearn/H2O/Keras model
-        to log with the current run
-    :param name: (str) the run relative name to store the model under. [Default 'model']
- """
- _check_for_splice_ctx()
- if _get_current_run_data().tags.get('splice.model_name'): # this function has already run
- raise SpliceMachineException("Only one model is permitted per run.")
-
- model_class = str(model.__class__)
- mlflow.set_tag('splice.model_type', model_class)
- mlflow.set_tag('splice.model_py_version', _PYTHON_VERSION)
-
- run_id = mlflow.active_run().info.run_uuid
- if isinstance(model, H2OModel):
- mlflow.set_tag('splice.h2o_version', h2o.__version__)
- H2OUtils.log_h2o_model(mlflow._splice_context, model, name, run_id)
-
- elif isinstance(model, SparkModel):
- mlflow.set_tag('splice.spark_version', pyspark.__version__)
- SparkUtils.log_spark_model(mlflow._splice_context, model, name, run_id)
-
- elif isinstance(model, ScikitModel):
- mlflow.set_tag('splice.sklearn_version', sklearn.__version__)
- SKUtils.log_sklearn_model(mlflow._splice_context, model, name, run_id)
-
- elif isinstance(model, KerasModel): # We can't handle keras models with a different backend
- mlflow.set_tag('splice.keras_version', keras_version)
- mlflow.set_tag('splice.tf_version', tf_version)
- KerasUtils.log_keras_model(mlflow._splice_context, model, name, run_id)
-
- else:
-        raise SpliceMachineException('Model type not supported for logging. '
-                                     'Currently we support logging Spark, H2O, SKLearn and Keras (TF backend) models. '
-                                     'You can save your model to disk, zip it and run mlflow.log_artifact to save it.')
-
- mlflow.set_tag('splice.model_name', name) # read in backend for deployment
-
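-# A minimal usage sketch, assuming a registered PySpliceContext and that
-# apply_patches() (defined below) has run; the scikit-learn estimator and names are
-# illustrative examples only.
-def _example_log_model():
-    from sklearn.linear_model import LogisticRegression
-    with mlflow.start_run(run_name='model_logging_demo'):
-        mlflow.log_model(LogisticRegression(), name='my_model')
-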
-[docs]@_mlflow_patch('start_run')
-def _start_run(run_id=None, tags=None, experiment_id=None, run_name=None, nested=False):
- """
- Start a new run
-
- :Example:
- .. code-block:: python
-
- mlflow.start_run(run_name='my_run')\n
- # or\n
- with mlflow.start_run(run_name='my_run'):
- ...
-
-
- :param tags: a dictionary containing metadata about the current run. \
- For example: \
- { \
- 'team': 'pd', \
- 'purpose': 'r&d' \
- }
- :param run_name: (str) an optional name for the run to show up in the MLFlow UI. [Default None]
- :param run_id: (str) if you want to reincarnate an existing run, pass in the run id [Default None]
- :param experiment_id: (int) if you would like to create an experiment/use one for this run [Default None]
-    :param nested: (bool) Controls whether the run is nested in a parent run. True creates a nested run [Default False]
- :return: (ActiveRun) the mlflow active run object
- """
- # Get the current running transaction ID for time travel/data governance
- _check_for_splice_ctx()
- db_connection = mlflow._splice_context.getConnection()
- prepared_statement = db_connection.prepareStatement('CALL SYSCS_UTIL.SYSCS_GET_CURRENT_TRANSACTION()')
- x = prepared_statement.executeQuery()
- x.next()
- timestamp = x.getLong(1)
- prepared_statement.close()
-
- tags = tags if tags else {}
- tags['mlflow.user'] = get_user()
- tags['DB Transaction ID'] = timestamp
-
- orig = gorilla.get_original_attribute(mlflow, "start_run")
- active_run = orig(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=nested)
-
- for key in tags:
- mlflow.set_tag(key, tags[key])
- if not run_id:
- mlflow.set_tag('Run ID', mlflow.active_run().info.run_uuid)
- if run_name:
- mlflow.set_tag('mlflow.runName', run_name)
-
- return active_run
-
-
-[docs]@_mlflow_patch('log_pipeline_stages')
-def _log_pipeline_stages(pipeline):
- """
- Log the pipeline stages of a Spark Pipeline as params for the run
-
-    :param pipeline: (PipelineModel) fitted/unfitted pipeline
- :return: None
- """
- for stage_number, pipeline_stage in enumerate(SparkUtils.get_stages(pipeline)):
- readable_stage_name = SparkUtils.readable_pipeline_stage(pipeline_stage)
- mlflow.log_param('Stage' + str(stage_number), readable_stage_name)
-
-
-[docs]@_mlflow_patch('log_feature_transformations')
-def _log_feature_transformations(unfit_pipeline):
- """
- Log feature transformations for an unfit spark pipeline
- Logs --> feature movement through the pipeline
-
- :param unfit_pipeline: (PipelineModel) unfit spark pipeline to log
- :return: None
- """
- transformations = defaultdict(lambda: [[], None]) # transformations, outputColumn
-
- for stage in SparkUtils.get_stages(unfit_pipeline):
- input_cols, output_col = SparkUtils.get_cols(stage, get_input=True), SparkUtils.get_cols(stage, get_input=False)
- if input_cols and output_col: # make sure it could parse transformer
- for column in input_cols:
- first_column_found = find_inputs_by_output(transformations, column)
- if first_column_found: # column is not original
- for f in first_column_found:
- transformations[f][1] = output_col
- transformations[f][0].append(
- SparkUtils.readable_pipeline_stage(stage))
- else:
- transformations[column][1] = output_col
- transformations[column][0].append(SparkUtils.readable_pipeline_stage(stage))
-
- for column in transformations:
- param_value = ' -> '.join([column] + transformations[column][0] +
- [transformations[column][1]])
- mlflow.log_param('Column- ' + column, param_value)
-
-
-[docs]@_mlflow_patch('log_model_params')
-def _log_model_params(pipeline_or_model):
- """
- Log the parameters of a fitted spark model or a model stage of a fitted spark pipeline
-
- :param pipeline_or_model: fitted spark pipeline/fitted spark model
- """
- model = SparkUtils.get_model_stage(pipeline_or_model)
-
- mlflow.log_param('model', SparkUtils.readable_pipeline_stage(model))
- if hasattr(model, '_java_obj'):
- verbose_parameters = SparkUtils.parse_string_parameters(model._java_obj.extractParamMap())
- elif hasattr(model, 'getClassifier'):
- verbose_parameters = SparkUtils.parse_string_parameters(
- model.getClassifier()._java_obj.extractParamMap())
- else:
- raise Exception("Could not parse model type: " + str(model))
- for param in verbose_parameters:
- try:
- value = float(verbose_parameters[param])
- mlflow.log_param(param.split('-')[0], value)
- except:
- mlflow.log_param(param.split('-')[0], verbose_parameters[param])
-
-
-[docs]@_mlflow_patch('timer')
-@contextmanager
-def _timer(timer_name, param=True):
- """
- Context manager for logging
-
- :Example:
- .. code-block:: python
-
- with mlflow.timer('my_timer'): \n
- ...
-
- :param timer_name: (str) the name of the timer
- :param param: (bool) whether or not to log the timer as a param (default=True). If false, logs as metric.
- :return: None
- """
- try:
- print(f'Starting Code Block {timer_name}...', end=' ')
- t0 = time.time()
- yield
- finally:
- t1 = time.time() - t0
- # Syntactic Sugar
- (mlflow.log_param if param else mlflow.log_metric)(timer_name, t1)
- print('Done.')
- print(
- f"Code Block {timer_name}:\nRan in {round(t1, 3)} secs\nRan in {round(t1 / 60, 3)} mins"
- )
-
-
-[docs]@_mlflow_patch('download_artifact')
-def _download_artifact(name, local_path, run_id=None):
- """
- Download the artifact at the given run id (active default) + name to the local path
-
- :param name: (str) artifact name to load (with respect to the run)
- :param local_path: (str) local path to download the model to. This path MUST include the file extension
- :param run_id: (str) the run id to download the artifact from. Defaults to active run
- :return: None
- """
- _check_for_splice_ctx()
- file_ext = path.splitext(local_path)[1]
-
- run_id = run_id or mlflow.active_run().info.run_uuid
- blob_data, f_ext = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
-
- if not file_ext: # If the user didn't provide the file (ie entered . as the local_path), fill it in for them
- local_path += f'/{name}.{f_ext}'
-
- with open(local_path, 'wb') as artifact_file:
- artifact_file.write(blob_data)
-
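-# A minimal usage sketch, assuming a registered PySpliceContext; the artifact name,
-# local path and run id are hypothetical placeholders.
-def _example_download_artifact():
-    mlflow.download_artifact('my_image.png', '/tmp/my_image.png',
-                             run_id='0123456789abcdef')
-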
-[docs]@_mlflow_patch('get_model_name')
-def _get_model_name(run_id):
- """
- Gets the model name associated with a run or None
-
- :param run_id: (str) the run_id that the model is stored under
- :return: (str or None) The model name if it exists
- """
- return _CLIENT.get_run(run_id).data.tags.get('splice.model_name')
-
-[docs]@_mlflow_patch('load_model')
-def _load_model(run_id=None, name=None):
- """
- Download and deserialize a serialized model
-
- :param run_id: the id of the run to get a model from
- (the run must have an associated model with it named spark_model)
- :param name: the name of the model in the database
- """
- _check_for_splice_ctx()
- run_id = run_id or mlflow.active_run().info.run_uuid
- name = name or _get_model_name(run_id)
- if not name:
- raise SpliceMachineException(f"Uh Oh! Looks like there isn't a model logged with this run ({run_id})!"
- "If there is, pass in the name= parameter to this function")
- model_blob, file_ext = SparkUtils.retrieve_artifact_stream(mlflow._splice_context, run_id, name)
-
- if file_ext == FileExtensions.spark:
- model = SparkUtils.load_spark_model(mlflow._splice_context, model_blob)
- elif file_ext == FileExtensions.h2o:
- model = H2OUtils.load_h2o_model(model_blob)
- elif file_ext == FileExtensions.sklearn:
- model = SKUtils.load_sklearn_model(model_blob)
- elif file_ext == FileExtensions.keras:
- model = KerasUtils.load_keras_model(model_blob)
- else:
- raise SpliceMachineException(f'Model extension {file_ext} was not a supported model type. '
- f'Supported model extensions are {FileExtensions.get_valid()}')
-
- return model
-
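-# A minimal usage sketch, assuming a registered PySpliceContext; the run id is a
-# hypothetical placeholder, and name= is only needed if the splice.model_name tag is missing.
-def _example_load_model():
-    return mlflow.load_model(run_id='0123456789abcdef')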
-
-[docs]@_mlflow_patch('log_artifact')
-def _log_artifact(file_name, name=None, run_uuid=None):
- """
- Log an artifact for the active run
-
- :Example:
- .. code-block:: python
-
- with mlflow.start_run():\n
- mlflow.log_artifact('my_image.png')
-
-    :param file_name: (str) the name of the file to log
-    :param name: (str) the run relative name to store the artifact under
- :param run_uuid: (str) the run uuid of a previous run, if none, defaults to current run
- :return: None
-
- :NOTE:
- We do not currently support logging directories. If you would like to log a directory, please zip it first and log the zip file
- """
- _check_for_splice_ctx()
- file_ext = path.splitext(file_name)[1].lstrip('.')
-
- with open(file_name, 'rb') as artifact:
- byte_stream = bytearray(bytes(artifact.read()))
-
- run_id = run_uuid or mlflow.active_run().info.run_uuid
- name = name or file_name
- insert_artifact(mlflow._splice_context, name, byte_stream, run_id, file_ext=file_ext)
-
-
-[docs]@_mlflow_patch('login_director')
-def _login_director(username, password):
- """
- Authenticate into the MLManager Director
-
- :param username: (str) database username
- :param password: (str) database password
- """
- mlflow._basic_auth = HTTPBasicAuth(username, password)
-
-
-[docs]def _initiate_job(payload, endpoint):
- """
- Send a job to the initiation endpoint
-
- :param payload: (dict) JSON payload for POST request
- :param endpoint: (str) REST endpoint to target
- :return: (str) Response text from request
- """
- if not hasattr(mlflow, '_basic_auth'):
- raise Exception(
- "You have not logged into MLManager director."
- " Please run mlflow.login_director(username, password)"
- )
- request = requests.post(
- get_pod_uri('mlflow', 5003, _testing=_TESTING) + endpoint,
- auth=mlflow._basic_auth,
- json=payload,
-
- )
-
- if request.ok:
- print("Your Job has been submitted. View its status on port 5003 (Job Dashboard)")
-        print(request.json())
-        return request.json()
- else:
- print("Error! An error occurred while submitting your job")
- print(request.text)
- return request.text
-
-
-[docs]@_mlflow_patch('deploy_aws')
-def _deploy_aws(app_name, region='us-east-2', instance_type='ml.m5.xlarge',
- run_id=None, instance_count=1, deployment_mode='replace'):
- """
- Queue Job to deploy a run to sagemaker with the
- given run id (found in MLFlow UI or through search API)
-
- :param run_id: the id of the run to deploy. Will default to the current
- run id.
- :param app_name: the name of the app in sagemaker once deployed
- :param region: the sagemaker region to deploy to (us-east-2,
- us-west-1, us-west-2, eu-central-1 supported)
- :param instance_type: the EC2 Sagemaker instance type to deploy on
-        (ml.m5.xlarge supported)
- :param instance_count: the number of instances to load balance predictions
- on
- :param deployment_mode: the method to deploy; create=application will fail
- if an app with the name specified already exists; replace=application
- in sagemaker will be replaced with this one if app already exists;
-        add=add the specified model to a preexisting application (not recommended)
- """
- # get run from mlflow
- print("Processing...")
- time.sleep(3) # give the mlflow server time to register the artifact, if necessary
-
- supported_aws_regions = ['us-east-2', 'us-west-1', 'us-west-2', 'eu-central-1']
- supported_instance_types = ['ml.m5.xlarge']
- supported_deployment_modes = ['replace', 'add']
-
- # data validation
- if region not in supported_aws_regions:
- raise Exception("Region must be in list: " + str(supported_aws_regions))
- if instance_type not in supported_instance_types:
- raise Exception("Instance type must be in list: " + str(instance_type))
- if deployment_mode not in supported_deployment_modes:
- raise Exception("Deployment mode must be in list: " + str(supported_deployment_modes))
-
- request_payload = {
- 'handler_name': 'DEPLOY_AWS', 'run_id': run_id if run_id else mlflow.active_run().info.run_uuid,
- 'region': region, 'user': get_user(),
- 'instance_type': instance_type, 'instance_count': instance_count,
- 'deployment_mode': deployment_mode, 'app_name': app_name
- }
-
- return _initiate_job(request_payload, '/api/rest/initiate')
-
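-# A minimal usage sketch, assuming mlflow.login_director has been called with valid
-# credentials; the app name and credentials are hypothetical placeholders.
-def _example_deploy_aws():
-    mlflow.login_director('my_db_user', 'my_db_password')
-    mlflow.deploy_aws('my_app', region='us-east-2', instance_type='ml.m5.xlarge')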
-
-[docs]@_mlflow_patch('deploy_azure')
-def _deploy_azure(endpoint_name, resource_group, workspace, run_id=None, region='East US',
- cpu_cores=0.1, allocated_ram=0.5, model_name=None):
- """
- Deploy a given run to AzureML.
-
- :param endpoint_name: (str) the name of the endpoint in AzureML when deployed to
- Azure Container Services. Must be unique.
- :param resource_group: (str) Azure Resource Group for model. Automatically created if
- it doesn't exist.
- :param workspace: (str) the AzureML workspace to deploy the model under.
- Will be created if it doesn't exist
- :param run_id: (str) if specified, will deploy a previous run (
-        must have a Spark model logged). Otherwise, will default to the active run
- :param region: (str) AzureML Region to deploy to: Can be East US, East US 2, Central US,
- West US 2, North Europe, West Europe or Japan East
- :param cpu_cores: (float) Number of CPU Cores to allocate to the instance.
- Can be fractional. Default=0.1
- :param allocated_ram: (float) amount of RAM, in GB, allocated to the container.
- Default=0.5
- :param model_name: (str) If specified, this will be the name of the model in AzureML.
- Otherwise, the model name will be randomly generated.
- """
- supported_regions = ['East US', 'East US 2', 'Central US',
- 'West US 2', 'North Europe', 'West Europe', 'Japan East']
-
- if region not in supported_regions:
- raise Exception("Region must be in list: " + str(supported_regions))
- if cpu_cores <= 0:
- raise Exception("Invalid CPU Count")
- if allocated_ram <= 0:
- raise Exception("Invalid Allocated RAM")
-
- request_payload = {
- 'handler_name': 'DEPLOY_AZURE',
- 'endpoint_name': endpoint_name,
- 'resource_group': resource_group,
- 'workspace': workspace,
- 'run_id': run_id if run_id else mlflow.active_run().info.run_uuid,
- 'cpu_cores': cpu_cores,
- 'allocated_ram': allocated_ram,
- 'model_name': model_name
- }
- return _initiate_job(request_payload, '/api/rest/initiate')
-
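-# A minimal usage sketch, assuming mlflow.login_director has already been called;
-# the endpoint, resource group and workspace names are hypothetical placeholders.
-def _example_deploy_azure():
-    mlflow.deploy_azure('my-endpoint', resource_group='my-rg', workspace='my-ws',
-                        region='East US', cpu_cores=0.5, allocated_ram=1.0)
-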
-[docs]@_mlflow_patch('deploy_database')
-def _deploy_db(db_schema_name,
- db_table_name,
- run_id,
- primary_key=None,
- df = None,
- create_model_table = False,
- model_cols = None,
- classes=None,
- sklearn_args={},
- verbose=False,
- pred_threshold = None,
- replace=False) -> None:
- """
- Deploy a trained (currently Spark, Sklearn, Keras or H2O) model to the Database.
- This either creates a new table or alters an existing table in the database (depending on parameters passed)
-
- :param db_schema_name: (str) the schema name to deploy to.
- :param db_table_name: (str) the table name to deploy to.
- :param run_id: (str) The run_id to deploy the model on. The model associated with this run will be deployed
- :param primary_key: (List[Tuple[str, str]]) List of column + SQL datatype to use for the primary/composite key. \n
- * If you are deploying to a table that already exists, it must already have a primary key, and this parameter will be ignored. \n
- * If you are creating the table in this function, you MUST pass in a primary key
- :param df: (Spark or Pandas DF) The dataframe used to train the model \n
- | NOTE: The columns in this df are the ones that will be used to create the table unless specified by model_cols
-    :param create_model_table: Whether or not to create the table from the dataframe. Default False. This
-        will ONLY be used if the table does not exist and a dataframe is passed in
- :param model_cols: (List[str]) The columns from the table to use for the model. If None, all columns in the table
- will be passed to the model. If specified, the columns will be passed to the model
- IN THAT ORDER. The columns passed here must exist in the table.
- :param classes: (List[str]) The classes (prediction labels) for the model being deployed.\n
- NOTE: If not supplied, the table will have default column names for each class
- :param sklearn_args: (dict{str: str}) Prediction options for sklearn models: \n
- * Available key value options: \n
- * 'predict_call': 'predict', 'predict_proba', or 'transform' \n
- * Determines the function call for the model \n
- * If blank, predict will be used (or transform if model doesn't have predict) \n
- * 'predict_args': 'return_std' or 'return_cov' - For Bayesian and Gaussian models \n
- * Only one can be specified \n
- * If the model does not have the option specified, it will be ignored.
- :param verbose: (bool) Whether or not to print out the queries being created. Helpful for debugging
- :param pred_threshold: (double) A prediction threshold for *Keras* binary classification models \n
- * If the model type isn't Keras, this parameter will be ignored \n
- NOTE: If the model type is Keras, the output layer has 1 node, and pred_threshold is None, \
- you will NOT receive a class prediction, only the output of the final layer (like model.predict()). \
- If you want a class prediction \
- for your binary classification problem, you MUST pass in a threshold.
- :param replace: (bool) whether or not to replace a currently existing model. This param does not yet work
- :return: None\n
-
- This function creates the following IF you are creating a table from the dataframe \n
- * The model table where run_id is the run_id passed in. This table will have a column for each feature in the feature vector. It will also contain:\n
- * USER which is the current user who made the request
- * EVAL_TIME which is the CURRENT_TIMESTAMP
- * the PRIMARY KEY column(s) passed in
- * PREDICTION. The prediction of the model. If the :classes: param is not filled in, this will be default values for classification models
- * A column for each class of the predictor with the value being the probability/confidence of the model if applicable\n
- IF you are deploying to an existing table, the table will be altered to include the columns above. \n
- :NOTE:
- .. code-block:: text
-
- The columns listed above are default value columns.\n
- This means that on a SQL insert into the table, \n
- you do not need to reference or insert values into them.\n
- They are automatically taken care of.\n
- Set verbose=True in the function call for more information
-
- The following will also be created for all deployments: \n
- * A trigger that runs on (after) insertion to the data table that runs an INSERT into the prediction table, \
- calling the PREDICT function, passing in the row of data as well as the schema of the dataset, and the run_id of the model to run \n
- * A trigger that runs on (after) insertion to the prediction table that calls an UPDATE to the row inserted, \
- parsing the prediction probabilities and filling in proper column values
- """
- _check_for_splice_ctx()
-
- # Get the model
- run_id = run_id if run_id else mlflow.active_run().info.run_uuid
- fitted_model = _load_model(run_id)
-
- # Param checking. Can't create model table without a dataframe
- if create_model_table and df is None: # Need to compare to None, truth value of df is ambiguous
- raise SpliceMachineException("If you'd like to create the model table as part of this deployment, you must pass in a dataframe")
- # Make sure primary_key is valid format
- if create_model_table and not primary_key:
- raise SpliceMachineException("If you'd like to create the model table as part of this deployment must provide the primary key(s)")
-
- # FIXME: We need to use the dbConnection so we can set a savepoint and rollback on failure
- classes = classes if classes else []
-
- schema_table_name = f'{db_schema_name}.{db_table_name}'
-
- feature_columns, schema_types = get_feature_columns_and_types(mlflow._splice_context, df, create_model_table,
- model_cols, schema_table_name)
-
-
- # Validate primary key is correct, or that provided table has primary keys
- primary_key = validate_primary_key(mlflow._splice_context, primary_key, db_schema_name, db_table_name) or primary_key
-
- library = get_model_library(fitted_model)
- if library == DBLibraries.MLeap:
- # Mleap needs a dataframe in order to serialize the model
- df = get_df_for_mleap(mlflow._splice_context, schema_table_name, df)
-
- model_type, classes, model_already_exists = ModelUtils[library].prep_model_for_deployment(mlflow._splice_context,
- fitted_model, classes, run_id,
- df, pred_threshold, sklearn_args)
-
-
- print(f'Deploying model {run_id} to table {schema_table_name}')
-
- # Create the schema of the table (we use this a few times)
- schema_str = ''
- for i in feature_columns:
- schema_str += f'\t{i} {CONVERSIONS[schema_types[str(i)]]},'
-
- try:
- # Create/Alter table 1: DATA
- if create_model_table:
- print('Creating model table ...', end=' ')
- create_model_deployment_table(mlflow._splice_context, run_id, schema_table_name, schema_str, classes, primary_key, model_type, verbose)
- print('Done.')
- else:
- print('Altering provided table for deployment')
- alter_model_table(mlflow._splice_context, run_id, schema_table_name, classes, model_type, verbose)
-
- # Create Trigger 1: model prediction
- print('Creating model prediction trigger ...', end=' ')
- if model_type in (H2OModelType.KEY_VALUE, SklearnModelType.KEY_VALUE, KerasModelType.KEY_VALUE):
- create_vti_prediction_trigger(mlflow._splice_context, schema_table_name, run_id, feature_columns, schema_types,
- schema_str, primary_key, classes, model_type, sklearn_args, pred_threshold, verbose)
- else:
- create_prediction_trigger(mlflow._splice_context, schema_table_name, run_id, feature_columns, schema_types,
- schema_str, primary_key, model_type, verbose)
- print('Done.')
-
- if model_type in (SparkModelType.CLASSIFICATION, SparkModelType.CLUSTERING_WITH_PROB,
- H2OModelType.CLASSIFICATION):
- # Create Trigger 2: model parsing
- print('Creating parsing trigger ...', end=' ')
- create_parsing_trigger(mlflow._splice_context, schema_table_name, primary_key, run_id, classes, model_type, verbose)
- print('Done.')
-
- add_model_to_metadata(mlflow._splice_context, run_id, schema_table_name)
-
-
- except Exception as e:
- import traceback
- exc = 'Model deployment failed. Rolling back transactions.\n'
- print(exc)
- drop_tables_on_failure(mlflow._splice_context, schema_table_name, run_id, model_already_exists)
- if not verbose:
- exc += 'For more insight into the SQL statement that generated this error, rerun with verbose=True'
- traceback.print_exc()
- raise SpliceMachineException(exc)
-
- print('Model Deployed.')
-
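-# A minimal usage sketch, assuming a registered PySpliceContext; the schema, table,
-# primary key column and run id are hypothetical placeholders, and `df` is the
-# training DataFrame supplied by the caller.
-def _example_deploy_db(df, run_id):
-    mlflow.deploy_db('MYSCHEMA', 'MODEL_TABLE', run_id,
-                     primary_key=[('MOMENT_ID', 'INT')],
-                     df=df, create_model_table=True, verbose=True)
-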
-[docs]@_mlflow_patch('get_deployed_models')
-def _get_deployed_models() -> PandasDF:
- """
- Get the currently deployed models in the database
- :return: Pandas df
- """
-
- return mlflow._splice_context.df(
- """
- SELECT * FROM MLMANAGER.LIVE_MODEL_STATUS
- """
- ).toPandas()
-
-
-def apply_patches():
- """
- Apply all the Gorilla Patches; \
-    All Gorilla patched functions MUST be prefixed with '_' before their destination name in MLflow
- """
- targets = [_register_splice_context, _lp, _lm, _timer, _log_artifact, _log_feature_transformations,
- _log_model_params, _log_pipeline_stages, _log_model, _load_model, _download_artifact,
- _start_run, _current_run_id, _current_exp_id, _deploy_aws, _deploy_azure, _deploy_db, _login_director,
- _get_run_ids_by_name, _get_deployed_models]
-
- for target in targets:
- gorilla.apply(gorilla.Patch(mlflow, target.__name__.lstrip('_'), target, settings=_GORILLA_SETTINGS))
-
-
-def main():
- mlflow.set_tracking_uri(_TRACKING_URL)
- apply_patches()
-
-
-main()
-
-import random
-from IPython.display import IFrame, HTML, display
-from pyspark import SparkContext
-from os import environ as env_vars
-
-[docs]def hide_toggle(toggle_next=False):
- """
- Function to add a toggle at the bottom of Jupyter Notebook cells to allow the entire cell to be collapsed.
-    :param toggle_next: (bool) determines whether the toggle should affect the current cell or the next cell
- Usage: from splicemachine.stats.utilities import hide_toggle
- hide_toggle()
- """
- this_cell = """$('div.cell.code_cell.rendered.selected')"""
- next_cell = this_cell + '.next()'
-
- toggle_text = 'Toggle show/hide' # text shown on toggle link
- target_cell = this_cell # target cell to control with toggle
- js_hide_current = '' # bit of JS to permanently hide code in current cell (only when toggling next cell)
-
- if toggle_next:
- target_cell = next_cell
- toggle_text += ' next cell'
- js_hide_current = this_cell + '.find("div.input").hide();'
-
- js_f_name = 'code_toggle_{}'.format(str(random.randint(1, 2 ** 64)))
-
- html = """
- <script>
- function {f_name}() {{
- {cell_selector}.find('div.input').toggle();
- }}
- {js_hide_current}
- </script>
- <a href="javascript:{f_name}()"><button style='color:black'>{toggle_text}</button></a>
- """.format(
- f_name=js_f_name,
- cell_selector=target_cell,
- js_hide_current=js_hide_current,
- toggle_text=toggle_text
- )
-
- return HTML(html)
-
-[docs]def get_mlflow_ui():
- """Display the MLflow UI as an IFrame"""
- display(HTML('<font size=\"+1\"><a target=\"_blank\" href=/mlflow>MLFlow UI</a></font>'))
- return IFrame(src='/mlflow', width='100%', height='500px')
-
-[docs]def get_spark_ui(port=None, spark_session=None):
- """
- Display the Spark Jobs UI as an IFrame at a specific port
- :param port: (int or str) The port of the desired spark session
- :param spark_session: (SparkSession) Optionally the Spark Session associated with the desired UI
-    :return: (IFrame) An IFrame displaying the Spark UI
- """
- if port:
- pass
- elif spark_session:
- port = spark_session.sparkContext.uiWebUrl.split(':')[-1]
- elif SparkContext._active_spark_context:
- port = SparkContext._active_spark_context.uiWebUrl.split(':')[-1]
- else:
- raise Exception('No parameters passed and no active Spark Session found.\n'
- 'Either pass in the active Spark Session into the "spark_session" parameter or the port of that session into the "port" parameter.\n'\
- 'You can find the port by running spark.sparkContext.uiWebUrl and taking the number after the \':\'')
- user = env_vars.get('JUPYTERHUB_USER','user')
- display(HTML(f'<font size=\"+1\"><a target=\"_blank\" href=/splicejupyter/user/{user}/sparkmonitor/{port}>Spark UI</a></font>'))
- return IFrame(src=f'/splicejupyter/user/{user}/sparkmonitor/{port}', width='100%', height='500px')
-
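-# A minimal usage sketch, assuming an active SparkSession named `spark` created by
-# the caller.
-def _example_show_spark_ui(spark):
-    return get_spark_ui(spark_session=spark)
-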
-"""
-Copyright 2020 Splice Machine, Inc.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-from __future__ import print_function
-
-import os
-
-from py4j.java_gateway import java_import
-from pyspark.sql import DataFrame
-from pyspark.sql.types import _parse_datatype_json_string
-from splicemachine.spark.constants import CONVERSIONS
-
-
-[docs]class PySpliceContext:
- """
- This class implements a SpliceMachineContext object (similar to the SparkContext object)
- """
- _spliceSparkPackagesName = "com.splicemachine.spark.splicemachine.*"
-
- def _splicemachineContext(self):
- return self.jvm.com.splicemachine.spark.splicemachine.SplicemachineContext(self.jdbcurl)
-
- def __init__(self, sparkSession, JDBC_URL=None, _unit_testing=False):
- """
- :param JDBC_URL: (string) The JDBC URL Connection String for your Splice Machine Cluster
- :param sparkSession: (sparkContext) A SparkSession object for talking to Spark
- """
-
- if JDBC_URL:
- self.jdbcurl = JDBC_URL
- else:
- try:
- self.jdbcurl = os.environ['BEAKERX_SQL_DEFAULT_JDBC']
- except KeyError as e:
- raise KeyError(
- "Could not locate JDBC URL. If you are not running on the cloud service,"
- "please specify the JDBC_URL=<some url> keyword argument in the constructor"
- )
-
- self._unit_testing = _unit_testing
-
- if not _unit_testing: # Private Internal Argument to Override Using JVM
- self.spark_sql_context = sparkSession._wrapped
- self.spark_session = sparkSession
- self.jvm = self.spark_sql_context._sc._jvm
- java_import(self.jvm, self._spliceSparkPackagesName)
- java_import(
- self.jvm, "org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions")
- java_import(
- self.jvm, "org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils")
- java_import(self.jvm, "scala.collection.JavaConverters._")
- java_import(self.jvm, "com.splicemachine.derby.impl.*")
- java_import(self.jvm, 'org.apache.spark.api.python.PythonUtils')
- self.jvm.com.splicemachine.derby.impl.SpliceSpark.setContext(
- self.spark_sql_context._jsc)
- self.context = self._splicemachineContext()
-
- else:
- from .tests.mocked import MockedScalaContext
- self.spark_sql_context = sparkSession._wrapped
- self.spark_session = sparkSession
- self.jvm = ''
- self.context = MockedScalaContext(self.jdbcurl)
-
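-    # A minimal usage sketch, assuming an active SparkSession named `spark`; the JDBC
-    # URL below is a hypothetical placeholder. On the managed cloud service it can be
-    # omitted and is read from the BEAKERX_SQL_DEFAULT_JDBC environment variable:
-    #
-    #     splice = PySpliceContext(spark, JDBC_URL='jdbc:splice://host:1527/splicedb;user=me;password=pw')
-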
-[docs] def toUpper(self, dataframe):
- """
- Returns a dataframe with all of the columns in uppercase
-
- :param dataframe: (Dataframe) The dataframe to convert to uppercase
- """
- for s in dataframe.schema:
- s.name = s.name.upper()
- # You need to re-generate the dataframe for the capital letters to take effect
- return dataframe.rdd.toDF(dataframe.schema)
-
-[docs] def replaceDataframeSchema(self, dataframe, schema_table_name):
- """
- Returns a dataframe with all column names replaced with the proper string case from the DB table
-
- :param dataframe: (Dataframe) A dataframe with column names to convert
- :param schema_table_name: (str) The schema.table with the correct column cases to pull from the database
- :return: (DataFrame) A Spark DataFrame with the replaced schema
- """
- schema = self.getSchema(schema_table_name)
- # Fastest way to replace the column case if changed
- dataframe = dataframe.rdd.toDF(schema)
- return dataframe
-
-[docs] def getConnection(self):
- """
- Return a connection to the database
- """
- return self.context.getConnection()
-
-[docs] def tableExists(self, schema_and_or_table_name, table_name=None):
- """
- Check whether or not a table exists
-
- :Example:
- .. code-block:: python
-
- splice.tableExists('schemaName.tableName')\n
- # or\n
- splice.tableExists('schemaName', 'tableName')
-
- :param schema_and_or_table_name: (str) Pass the schema name in this param when passing the table_name param,
- or pass schemaName.tableName in this param without passing the table_name param
- :param table_name: (optional) (str) Table Name, used when schema_and_or_table_name contains only the schema name
- :return: (bool) whether or not the table exists
- """
- if table_name:
- return self.context.tableExists(schema_and_or_table_name, table_name)
- else:
- return self.context.tableExists(schema_and_or_table_name)
-
-[docs] def dropTable(self, schema_and_or_table_name, table_name=None):
- """
- Drop a specified table.
-
- :Example:
- .. code-block:: python
-
- splice.dropTable('schemaName.tableName') \n
- # or\n
- splice.dropTable('schemaName', 'tableName')
-
- :param schema_and_or_table_name: (str) Pass the schema name in this param when passing the table_name param,
- or pass schemaName.tableName in this param without passing the table_name param
- :param table_name: (optional) (str) Table Name, used when schema_and_or_table_name contains only the schema name
- :return: None
- """
- if table_name:
- return self.context.dropTable(schema_and_or_table_name, table_name)
- else:
- return self.context.dropTable(schema_and_or_table_name)
-
-[docs] def df(self, sql):
- """
- Return a Spark Dataframe from the results of a Splice Machine SQL Query
-
- :Example:
- .. code-block:: python
-
- df = splice.df('SELECT * FROM MYSCHEMA.TABLE1 WHERE COL2 > 3')
-
- :param sql: (str) SQL Query (eg. SELECT * FROM table1 WHERE col2 > 3)
- :return: (Dataframe) A Spark DataFrame containing the results
- """
- return DataFrame(self.context.df(sql), self.spark_sql_context)
-
-[docs] def insert(self, dataframe, schema_table_name, to_upper=False):
- """
- Insert a dataframe into a table (schema.table).
-
- :param dataframe: (Dataframe) The dataframe you would like to insert
- :param schema_table_name: (str) The table in which you would like to insert the DF
- :param to_upper: (bool) If the dataframe columns should be converted to uppercase before table creation
- If False, the table will be created with lower case columns. [Default False]
- :return: None
- """
- if to_upper:
- dataframe = self.toUpper(dataframe)
- return self.context.insert(dataframe._jdf, schema_table_name)
-
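-    # A minimal usage sketch, assuming a live PySpliceContext named `splice` and a
-    # Spark DataFrame `new_rows` whose columns match the table; schema and table
-    # names are hypothetical examples:
-    #
-    #     existing = splice.df('SELECT * FROM MYSCHEMA.TABLE1')
-    #     splice.insert(new_rows, 'MYSCHEMA.TABLE1')
-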
-[docs] def insertWithStatus(self, dataframe, schema_table_name, statusDirectory, badRecordsAllowed):
- """
- Insert a dataframe into a table (schema.table) while tracking and limiting records that fail to insert.
- The status directory and number of badRecordsAllowed allow for duplicate primary keys to be
- written to a bad records file. If badRecordsAllowed is set to -1, all bad records will be written
- to the status directory.
-
- :param dataframe: (Dataframe) The dataframe you would like to insert
- :param schema_table_name: (str) The table in which you would like to insert the dataframe
- :param statusDirectory: (str) The status directory where bad records file will be created
-        :param badRecordsAllowed: (int) The number of bad records allowed. -1 for unlimited
- :return: None
- """
- dataframe = self.replaceDataframeSchema(dataframe, schema_table_name)
- return self.context.insert(dataframe._jdf, schema_table_name, statusDirectory, badRecordsAllowed)
-
-[docs] def insertRdd(self, rdd, schema, schema_table_name):
- """
- Insert an rdd into a table (schema.table)
-
- :param rdd: (RDD) The RDD you would like to insert
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) The table in which you would like to insert the RDD
- :return: None
- """
- return self.insert(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
-[docs] def insertRddWithStatus(self, rdd, schema, schema_table_name, statusDirectory, badRecordsAllowed):
- """
- Insert an rdd into a table (schema.table) while tracking and limiting records that fail to insert. \
- The status directory and number of badRecordsAllowed allow for duplicate primary keys to be \
- written to a bad records file. If badRecordsAllowed is set to -1, all bad records will be written \
- to the status directory.
-
- :param rdd: (RDD) The RDD you would like to insert
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) The table in which you would like to insert the dataframe
- :param statusDirectory: (str) The status directory where bad records file will be created
-        :param badRecordsAllowed: (int) The number of bad records allowed. -1 for unlimited
- :return: None
- """
- return self.insertWithStatus(
- self.createDataFrame(rdd, schema),
- schema_table_name,
- statusDirectory,
- badRecordsAllowed
- )
-
-[docs] def upsert(self, dataframe, schema_table_name):
- """
- Upsert the data from a dataframe into a table (schema.table).
-
- :param dataframe: (Dataframe) The dataframe you would like to upsert
- :param schema_table_name: (str) The table in which you would like to upsert the RDD
- :return: None
- """
- # make sure column names are in the correct case
- dataframe = self.replaceDataframeSchema(dataframe, schema_table_name)
- return self.context.upsert(dataframe._jdf, schema_table_name)
-
-[docs] def upsertWithRdd(self, rdd, schema, schema_table_name):
- """
- Upsert the data from an RDD into a table (schema.table).
-
- :param rdd: (RDD) The RDD you would like to upsert
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) The table in which you would like to upsert the RDD
- :return: None
- """
- return self.upsert(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
-[docs] def delete(self, dataframe, schema_table_name):
- """
- Delete records in a dataframe based on joining by primary keys from the data frame.
- Be careful with column naming and case sensitivity.
-
- :param dataframe: (Dataframe) The dataframe you would like to delete
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- return self.context.delete(dataframe._jdf, schema_table_name)
-
-[docs] def deleteWithRdd(self, rdd, schema, schema_table_name):
- """
- Delete records using an rdd based on joining by primary keys from the rdd.
- Be careful with column naming and case sensitivity.
-
- :param rdd: (RDD) The RDD containing the primary keys you would like to delete from the table
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- return self.delete(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
-[docs] def update(self, dataframe, schema_table_name):
- """
- Update data from a dataframe for a specified schema_table_name (schema.table).
- The keys are required for the update and any other columns provided will be updated
- in the rows.
-
- :param dataframe: (Dataframe) The dataframe you would like to update
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- # make sure column names are in the correct case
- dataframe = self.replaceDataframeSchema(dataframe, schema_table_name)
- return self.context.update(dataframe._jdf, schema_table_name)
-
-[docs] def updateWithRdd(self, rdd, schema, schema_table_name):
- """
- Update data from an rdd for a specified schema_table_name (schema.table).
- The keys are required for the update and any other columns provided will be updated
- in the rows.
-
- :param rdd: (RDD) The RDD you would like to use for updating the table
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) Splice Machine Table
- :return: None
- """
- return self.update(
- self.createDataFrame(rdd, schema),
- schema_table_name
- )
-
-[docs] def getSchema(self, schema_table_name):
- """
- Return the schema via JDBC.
-
- :param schema_table_name: (str) Table name
- :return: (StructType) PySpark StructType representation of the table
- """
- return _parse_datatype_json_string(self.context.getSchema(schema_table_name).json())
-
-[docs] def execute(self, query_string):
- '''
-        Execute a query over JDBC
-
- :Example:
- .. code-block:: python
-
- splice.execute('DELETE FROM TABLE1 WHERE col2 > 3')
-
- :param query_string: (str) SQL Query (eg. SELECT * FROM table1 WHERE col2 > 3)
- :return: None
- '''
- return self.context.execute(query_string)
-
-[docs] def executeUpdate(self, query_string):
- '''
-        Execute a DML query (update, delete, drop, etc.)
-
- :Example:
- .. code-block:: python
-
- splice.executeUpdate('DROP TABLE table1')
-
- :param query_string: (string) SQL Query (eg. DROP TABLE table1)
- :return: None
- '''
- return self.context.executeUpdate(query_string)
-
-[docs] def internalDf(self, query_string):
- '''
- SQL to Dataframe translation (Lazy). Runs the query inside Splice Machine and sends the results to the Spark Adapter app
-
- :param query_string: (str) SQL Query (eg. SELECT * FROM table1 WHERE col2 > 3)
- :return: (DataFrame) pyspark dataframe contains the result of query_string
- '''
- return DataFrame(self.context.internalDf(query_string), self.spark_sql_context)
-
-[docs] def rdd(self, schema_table_name, column_projection=None):
- """
- Table with projections in Splice mapped to an RDD.
-
- :param schema_table_name: (string) Accessed table
- :param column_projection: (list of strings) Names of selected columns
- :return: (RDD[Row]) the result of the projection
- """
- if column_projection:
- colnames = ', '.join(str(col) for col in column_projection)
- else:
- colnames = '*'
- return self.df('select '+colnames+' from '+schema_table_name).rdd
-
-[docs] def internalRdd(self, schema_table_name, column_projection=None):
- """
- Table with projections in Splice mapped to an RDD.
- Runs the projection inside Splice Machine and sends the results to the Spark Adapter app as an rdd
-
- :param schema_table_name: (str) Accessed table
- :param column_projection: (list of strings) Names of selected columns
- :return: (RDD[Row]) the result of the projection
- """
- if column_projection:
- colnames = ', '.join(str(col) for col in column_projection)
- else:
- colnames = '*'
- return self.internalDf('select '+colnames+' from '+schema_table_name).rdd
-
-[docs] def truncateTable(self, schema_table_name):
- """
- Truncate a table
-
- :param schema_table_name: (str) the full table name in the format "schema.table_name" which will be truncated
- :return: None
- """
- return self.context.truncateTable(schema_table_name)
-
-[docs] def analyzeSchema(self, schema_name):
- """
- Analyze the schema
-
- :param schema_name: (str) schema name which stats info will be collected
- :return: None
- """
- return self.context.analyzeSchema(schema_name)
-
-[docs] def analyzeTable(self, schema_table_name, estimateStatistics=False, samplePercent=10.0):
- """
- Collect stats info on a table
-
- :param schema_table_name: full table name in the format of 'schema.table'
- :param estimateStatistics: will use estimate statistics if True
-        :param samplePercent: the percentage of rows to be sampled.
- :return: None
- """
- return self.context.analyzeTable(schema_table_name, estimateStatistics, float(samplePercent))
-
-[docs] def export(self,
- dataframe,
- location,
- compression=False,
- replicationCount=1,
- fileEncoding=None,
- fieldSeparator=None,
- quoteCharacter=None):
- """
- Export a dataFrame in CSV
-
- :param dataframe: (DataFrame)
- :param location: (str) Destination directory
- :param compression: (bool) Whether to compress the output or not
- :param replicationCount: (int) Replication used for HDFS write
- :param fileEncoding: (str) fileEncoding or None, defaults to UTF-8
- :param fieldSeparator: (str) fieldSeparator or None, defaults to ','
- :param quoteCharacter: (str) quoteCharacter or None, defaults to '"'
- :return: None
- """
- return self.context.export(dataframe._jdf, location, compression, replicationCount,
- fileEncoding, fieldSeparator, quoteCharacter)
-
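-    # A minimal usage sketch, assuming a live PySpliceContext named `splice`; the
-    # destination directory is a hypothetical placeholder:
-    #
-    #     df = splice.df('SELECT * FROM MYSCHEMA.TABLE1')
-    #     splice.export(df, '/tmp/table1_export', compression=True, fieldSeparator='|')
-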
-[docs] def exportBinary(self, dataframe, location, compression, e_format='parquet'):
- """
- Export a dataFrame in binary format
-
- :param dataframe: (DataFrame)
- :param location: (str) Destination directory
- :param compression: (bool) Whether to compress the output or not
- :param e_format: (str) Binary format to be used, currently only 'parquet' is supported. [Default 'parquet']
- :return: None
- """
- return self.context.exportBinary(dataframe._jdf, location, compression, e_format)
-
-[docs] def bulkImportHFile(self, dataframe, schema_table_name, options):
- """
- Bulk Import HFile from a dataframe into a schema.table
-
- :param dataframe: (DataFrame)
- :param schema_table_name: (str) Full table name in the format of "schema.table"
- :param options: (Dict) Dictionary of options to be passed to --splice-properties; bulkImportDirectory is required
- :return: None
- """
- optionsMap = self.jvm.java.util.HashMap()
- for k, v in options.items():
- optionsMap.put(k, v)
- return self.context.bulkImportHFile(dataframe._jdf, schema_table_name, optionsMap)
-
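-    # A minimal usage sketch, assuming a live PySpliceContext named `splice` and a
-    # Spark DataFrame `df`; bulkImportDirectory is required and the path below is a
-    # hypothetical staging location writable by the cluster:
-    #
-    #     splice.bulkImportHFile(df, 'MYSCHEMA.TABLE1',
-    #                            {'bulkImportDirectory': '/tmp/bulk_import_staging'})
-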
-[docs] def bulkImportHFileWithRdd(self, rdd, schema, schema_table_name, options):
- """
- Bulk Import HFile from an rdd into a schema.table
-
- :param rdd: (RDD) Input data
- :param schema: (StructType) The schema of the rows in the RDD
- :param schema_table_name: (str) Full table name in the format of "schema.table"
- :param options: (Dict) Dictionary of options to be passed to --splice-properties; bulkImportDirectory is required
- :return: None
- """
- return self.bulkImportHFile(
- self.createDataFrame(rdd, schema),
- schema_table_name,
- options
- )
-
-[docs] def splitAndInsert(self, dataframe, schema_table_name, sample_fraction):
- """
- Sample the dataframe, split the table, and insert a dataFrame into a schema.table.
- This corresponds to an insert into from select statement
-
- :param dataframe: (DataFrame) Input data
- :param schema_table_name: (str) Full table name in the format of "schema.table"
- :param sample_fraction: (float) A value between 0 and 1 that specifies the percentage of data in the dataFrame \
- that should be sampled to determine the splits. \
- For example, specify 0.005 if you want 0.5% of the data sampled.
- :return: None
- """
- return self.context.splitAndInsert(dataframe._jdf, schema_table_name, float(sample_fraction))
-
-[docs] def createDataFrame(self, rdd, schema):
- """
- Creates a dataframe from a given rdd and schema.
-
- :param rdd: (RDD) Input data
- :param schema: (StructType) The schema of the rows in the RDD
- :return: (DataFrame) The Spark DataFrame
- """
- return self.spark_session.createDataFrame(rdd, schema)
-
- def _generateDBSchema(self, dataframe, types={}):
- """
- Generate the schema for create table
- """
- # convert keys and values to uppercase in the types dictionary
- types = dict((key.upper(), val) for key, val in types.items())
- db_schema = []
- # convert dataframe to have all uppercase column names
- dataframe = self.toUpper(dataframe)
- # i contains the name and pyspark datatype of the column
- for i in dataframe.schema:
- if i.name.upper() in types:
- print('Column {} is of type {}'.format(
- i.name.upper(), i.dataType))
- dt = types[i.name.upper()]
- else:
- dt = CONVERSIONS[str(i.dataType)]
- db_schema.append((i.name.upper(), dt))
-
- return db_schema
-
- def _getCreateTableSchema(self, schema_table_name, new_schema=False):
- """
- Parse schema for new table; if it is needed, create it
- """
- # try to get schema and table, else set schema to splice
- if '.' in schema_table_name:
- schema, table = schema_table_name.upper().split('.')
- else:
- schema = self.getConnection().getCurrentSchemaName()
- table = schema_table_name.upper()
- # check for new schema
- if new_schema:
- print('Creating schema {}'.format(schema))
- self.execute('CREATE SCHEMA {}'.format(schema))
-
- return schema, table
-
- def _dropTableIfExists(self, schema_table_name, table_name=None):
- """
- Drop table if it exists
- """
- if self.tableExists(schema_and_or_table_name=schema_table_name, table_name=table_name):
- print('Table exists. Dropping table')
- self.dropTable(schema_and_or_table_name=schema_table_name, table_name=table_name)
-
-[docs] def dropTableIfExists(self, schema_table_name, table_name=None):
- """
-        Drops a table if it exists
-
- :Example:
- .. code-block:: python
-
- splice.dropTableIfExists('schemaName.tableName') \n
- # or\n
- splice.dropTableIfExists('schemaName', 'tableName')
-
- :param schema_table_name: (str) Pass the schema name in this param when passing the table_name param,
- or pass schemaName.tableName in this param without passing the table_name param
- :param table_name: (optional) (str) Table Name, used when schema_table_name contains only the schema name
- :return: None
- """
- self._dropTableIfExists(schema_table_name, table_name)
-
- def _jstructtype(self, schema):
- """
- Convert python StructType to java StructType
-
- :param schema: PySpark StructType
- :return: Java Spark StructType
- """
- return self.spark_session._jsparkSession.parseDataType(schema.json())
-
-[docs] def createTable(self, dataframe, schema_table_name, primary_keys=None, create_table_options=None, to_upper=False, drop_table=False):
- """
- Creates a schema.table (schema_table_name) from a dataframe
-
- :param dataframe: The Spark DataFrame to base the table off
- :param schema_table_name: str The schema.table to create
- :param primary_keys: List[str] the primary keys. Default None
- :param create_table_options: str The additional table-level SQL options default None
- :param to_upper: bool If the dataframe columns should be converted to uppercase before table creation. \
- If False, the table will be created with lower case columns. Default False
- :param drop_table: bool whether to drop the table if it exists. Default False. If False and the table exists, the function will throw an exception
- :return: None
-
- """
- if drop_table:
- self._dropTableIfExists(schema_table_name)
- if to_upper:
- dataframe = self.toUpper(dataframe)
- primary_keys = primary_keys if primary_keys else []
- self.createTableWithSchema(schema_table_name, dataframe.schema,
- keys=primary_keys, create_table_options=create_table_options)
-
-[docs] def createTableWithSchema(self, schema_table_name, schema, keys=None, create_table_options=None):
- """
- Creates a schema.table from a schema
-
- :param schema_table_name: str The schema.table to create
- :param schema: (StructType) The schema that describes the columns of the table
- :param keys: (List[str]) The primary keys. Default None
- :param create_table_options: (str) The additional table-level SQL options. Default None
- :return: None
- """
- if keys:
- keys_seq = self.jvm.PythonUtils.toSeq(keys)
- else:
- keys_seq = self.jvm.PythonUtils.toSeq([])
- self.context.createTable(
- schema_table_name,
- self._jstructtype(schema),
- keys_seq,
- create_table_options
- )
-
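-
-# --- Illustrative usage sketch (not part of the original module) ---
-# Assumes an already-constructed PySpliceContext `splice` and a Spark DataFrame `df`;
-# the helper name, the DEMO.CUSTOMERS table and the CUSTOMER_ID key are placeholders.
-def _example_create_and_load_table(splice, df):
-    # Drop any previous copy, create the table with a primary key, then sample-split and insert the data
-    splice.dropTableIfExists('DEMO.CUSTOMERS')
-    splice.createTable(df, 'DEMO.CUSTOMERS', primary_keys=['CUSTOMER_ID'], to_upper=True)
-    splice.splitAndInsert(df, 'DEMO.CUSTOMERS', sample_fraction=0.005)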
-
-[docs]class ExtPySpliceContext(PySpliceContext):
- """
- This class implements a SplicemachineContext object from com.splicemachine.spark2 for use outside of the K8s Cloud Service
- """
- _spliceSparkPackagesName = "com.splicemachine.spark2.splicemachine.*"
-
- def _splicemachineContext(self):
- return self.jvm.com.splicemachine.spark2.splicemachine.SplicemachineContext(
- self.jdbcurl, self.kafkaServers, self.kafkaPollTimeout)
-
- def __init__(self, sparkSession, JDBC_URL=None, kafkaServers='localhost:9092', kafkaPollTimeout=20000, _unit_testing=False):
- """
- :param JDBC_URL: (string) The JDBC URL Connection String for your Splice Machine Cluster
- :param sparkSession: (sparkContext) A SparkSession object for talking to Spark
- :param kafkaServers: (string) Comma-separated list of Kafka broker addresses in the form host:port
- :param kafkaPollTimeout: (int) Number of milliseconds to wait when polling Kafka
- """
- self.kafkaServers = kafkaServers
- self.kafkaPollTimeout = kafkaPollTimeout
- super().__init__(sparkSession, JDBC_URL, _unit_testing)
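-
-# Illustrative sketch (not part of the original module): constructing an ExtPySpliceContext
-# for use outside the K8s cloud service. The JDBC URL and Kafka broker address are placeholders.
-def _example_ext_context():
-    from pyspark.sql import SparkSession
-    spark = SparkSession.builder.getOrCreate()
-    jdbc_url = 'jdbc:splice://<host>:1527/splicedb;user=<user>;password=<password>'
-    return ExtPySpliceContext(spark, JDBC_URL=jdbc_url, kafkaServers='<kafka-host>:9092')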
-
-import warnings
-from multiprocessing.pool import ThreadPool
-import random
-from collections import defaultdict, OrderedDict
-
-import numpy as np
-import pandas as pd
-import scipy.stats as st
-import graphviz
-from numpy.linalg import eigh
-from tqdm import tqdm
-from IPython.display import HTML
-import pyspark_dist_explore as dist_explore
-from pyspark.sql import functions as F, Row
-from pyspark.sql.types import DoubleType, ArrayType, IntegerType, StringType
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param
-from pyspark.ml import Pipeline, Transformer
-from pyspark.ml.classification import LogisticRegressionModel
-from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
-from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler, Bucketizer, PCA
-from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator, BinaryClassificationEvaluator
-from pyspark import keyword_only
-from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel
-
-
-[docs]def get_confusion_matrix(spark, TP, TN, FP, FN):
- """
- function that shows you a device called a confusion matrix... will be helpful when evaluating.
- It allows you to see how well your model performs
- :param TP: True Positives
- :param TN: True Negatives
- :param FP: False Positives
- :param FN: False Negatives
- """
-
- row = Row('', 'True', 'False')
- confusion_matrix = spark._wrapped.createDataFrame([row('True', TP, FN),
- row('False', FP, TN)])
- return confusion_matrix
-
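-# Illustrative sketch (not part of the original module): building and displaying a
-# confusion matrix from raw counts. The counts below are made up.
-def _example_confusion_matrix():
-    from pyspark.sql import SparkSession
-    spark = SparkSession.builder.getOrCreate()
-    cm = get_confusion_matrix(spark, TP=90.0, TN=80.0, FP=20.0, FN=10.0)
-    cm.show()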
-
-[docs]class SpliceBaseEvaluator(object):
- """
- Base ModelEvaluator
- """
-
- def __init__(self, spark, evaluator, supported_metrics, predictionCol="prediction",
- labelCol="label"):
- """
- Constructor for SpliceBaseEvaluator
- :param spark: the SparkSession object (e.g. from Zeppelin or Jupyter)
- :param evaluator: evaluator class from spark
- :param supported_metrics: supported metrics list
- :param predictionCol: prediction column
- :param labelCol: label column
- """
- self.spark = spark
- self.ev = evaluator
- self.prediction_col = predictionCol
- self.label = labelCol
- self.supported_metrics = supported_metrics
- self.avgs = defaultdict(list)
-
-[docs] def input(self, predictions_dataframe):
- """
- Input a dataframe
- :param ev: evaluator class
- :param predictions_dataframe: input df
- :return: none
- """
- for metric in self.supported_metrics:
- evaluator = self.ev(
- labelCol=self.label, predictionCol=self.prediction_col, metricName=metric)
- self.avgs[metric].append(evaluator.evaluate(predictions_dataframe))
- print("Current {metric}: {metric_val}".format(metric=metric,
- metric_val=self.avgs
- [metric][-1]))
-
-[docs] def get_results(self, as_dict=False):
- """
- Get Results
- :param dict: whether to get results in a dict or not
- :return: dictionary
- """
- computed_avgs = {}
- for key in self.avgs:
- computed_avgs[key] = np.mean(self.avgs[key])
-
- if as_dict:
- return computed_avgs
-
- metrics_row = Row(*self.supported_metrics)
- computed_row = metrics_row(*[float(computed_avgs[i]) for i in self.supported_metrics])
- return self.spark._wrapped.createDataFrame([computed_row])
-
-
-[docs]class SpliceBinaryClassificationEvaluator(SpliceBaseEvaluator):
- def __init__(self, spark, predictionCol="prediction", labelCol="label", confusion_matrix=True):
- self.avg_tp = []
- self.avg_tn = []
- self.avg_fn = []
- self.avg_fp = []
- self.confusion_matrix = confusion_matrix
-
- supported = ["areaUnderROC", "areaUnderPR", 'TPR', 'SPC', 'PPV', 'NPV', 'FPR', 'FDR', 'FNR', 'ACC', 'F1', 'MCC']
- SpliceBaseEvaluator.__init__(self, spark, BinaryClassificationEvaluator, supported, predictionCol=predictionCol,
- labelCol=labelCol)
-
-[docs] def input(self, predictions_dataframe):
- """
- Evaluate actual vs. predicted values in a dataframe
- :param predictions_dataframe: the dataframe containing the label and the prediction
- """
- for metric in self.supported_metrics:
- if metric in ['areaUnderROC', 'areaUnderPR']:
- evaluator = self.ev(labelCol=self.label, rawPredictionCol=self.prediction_col, metricName=metric)
-
- self.avgs[metric].append(evaluator.evaluate(predictions_dataframe))
- print("Current {metric}: {metric_val}".format(metric=metric,
- metric_val=self.avgs
- [metric][-1]))
-
- pred_v_lab = predictions_dataframe.select(self.label,
- self.prediction_col) # Select the actual and the predicted labels
-
- # Add confusion stats
- self.avg_tp.append(pred_v_lab[(pred_v_lab[self.label] == 1)
- & (pred_v_lab[self.prediction_col] == 1)].count())
- self.avg_tn.append(pred_v_lab[(pred_v_lab[self.label] == 0)
- & (pred_v_lab[self.prediction_col] == 0)].count())
- # A false positive is predicted 1 with label 0; a false negative is predicted 0 with label 1
- self.avg_fp.append(pred_v_lab[(pred_v_lab[self.label] == 0)
- & (pred_v_lab[self.prediction_col] == 1)].count())
- self.avg_fn.append(pred_v_lab[(pred_v_lab[self.label] == 1)
- & (pred_v_lab[self.prediction_col] == 0)].count())
-
- TP = np.mean(self.avg_tp)
- TN = np.mean(self.avg_tn)
- FP = np.mean(self.avg_fp)
- FN = np.mean(self.avg_fn)
-
- self.avgs['TPR'].append(float(TP) / (TP + FN))
- self.avgs['SPC'].append(float(TN) / (TN + FP))
- self.avgs['TNR'].append(float(TN) / (TN + FP))
- self.avgs['PPV'].append(float(TP) / (TP + FP))
- self.avgs['NPV'].append(float(TN) / (TN + FN))
- self.avgs['FNR'].append(float(FN) / (FN + TP))
- self.avgs['FPR'].append(float(FP) / (FP + TN))
- self.avgs['FDR'].append(float(FP) / (FP + TP))
- self.avgs['FOR'].append(float(FN) / (FN + TN))
- self.avgs['ACC'].append(float(TP + TN) / (TP + FN + FP + TN))
- self.avgs['F1'].append(float(2 * TP) / (2 * TP + FP + FN))
- self.avgs['MCC'].append(float(TP * TN - FP * FN) / np.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)))
-
- if self.confusion_matrix:
- get_confusion_matrix(
- self.spark,
- float(TP),
- float(TN),
- float(FP),
- float(FN)
- ).show()
-
-[docs] def plotROC(self, fittedEstimator, ax):
- """
- Plots the receiver operating characteristic curve for the trained classifier
- :param fittedEstimator: fitted logistic regression model
- :param ax: matplotlib axis object
- :return: axis with ROC plot
- """
- if fittedEstimator.__class__ == LogisticRegressionModel:
- trainingSummary = fittedEstimator.summary
- roc = trainingSummary.roc.toPandas()
- ax.plot(roc['FPR'], roc['TPR'], label='Training set areaUnderROC: \n' + str(trainingSummary.areaUnderROC))
- ax.set_xlabel('False Positive Rate')
- ax.set_ylabel('True Positive Rate')
- ax.set_title('ROC Curve')
- ax.legend()
- return ax
- else:
- raise NotImplementedError("Only supported for Logistic Regression Models")
-
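-# Illustrative sketch (not part of the original module): evaluating binary predictions
-# fold by fold. `predictions` is assumed to be a DataFrame holding the "label" column and
-# the model's "prediction" output (used as the raw prediction for the AUC metrics).
-def _example_binary_evaluation(spark, predictions):
-    evaluator = SpliceBinaryClassificationEvaluator(spark)
-    evaluator.input(predictions)               # may be called once per fold; metrics are averaged
-    return evaluator.get_results(as_dict=True)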
-
-[docs]class SpliceRegressionEvaluator(SpliceBaseEvaluator):
- """
- Splice Regression Evaluator
- """
-
- def __init__(self, spark, predictionCol="prediction", labelCol="label"):
- supported = ['rmse', 'mse', 'r2', 'mae']
- SpliceBaseEvaluator.__init__(self, spark, RegressionEvaluator, supported, predictionCol=predictionCol,
- labelCol=labelCol)
-
-
-[docs]class SpliceMultiClassificationEvaluator(SpliceBaseEvaluator):
- def __init__(self, spark, predictionCol="prediction", labelCol="label"):
- supported = ["f1", "weightedPrecision", "weightedRecall", "accuracy"]
- SpliceBaseEvaluator.__init__(self, spark, MulticlassClassificationEvaluator, supported,
- predictionCol=predictionCol, labelCol=labelCol)
-
-
-[docs]class DecisionTreeVisualizer(object):
- """
- Visualize a decision tree, either in a code-like text format or with graphviz
- """
-
-[docs] @staticmethod
- def feature_importance(spark, model, dataset, featuresCol="features"):
- """
- Return a dataframe containing the relative importance of each feature
- :param spark: SparkSession
- :param model: fitted model exposing a featureImportances attribute
- :param dataset: the DataFrame used to fit the model
- :param featuresCol: name of the features column, defaults to "features"
- :return: dataframe containing importance
- """
- import pandas as pd
- featureImp = model.featureImportances
- list_extract = []
- for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
- list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][
- i]
- varlist = pd.DataFrame(list_extract)
- varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
- return spark._wrapped.createDataFrame((varlist.sort_values('score', ascending=False)))
-
-[docs] @staticmethod
- def visualize(
- model,
- feature_column_names,
- label_names,
- size=None,
- horizontal=False,
- tree_name='tree',
- visual=False,
- ):
- """
- Visualize a decision tree, either in a code-like text format or with graphviz
- :param model: the fitted decision tree classifier
- :param feature_column_names: (List[str]) column names for features
- You can access these feature names by using your VectorAssembler (in PySpark) and calling its .getInputCols() function
- :param label_names: (List[str]) labels vector (below avg, above avg)
- :param size: tuple(int,int) The size of the graph. If unspecified, graphviz will automatically assign a size
- :param horizontal: (Bool) if the tree should be rendered horizontally
- :param tree_name: the name you would like to call the tree
- :param visual: bool, true if you want a graphviz pdf containing your file
- :return dot: The graphviz object
- """
-
- tree_to_json = DecisionTreeVisualizer.replacer(model.toDebugString,
- ['feature ' + str(i) for i in
- range(len(feature_column_names) - 1, -1, -1)],
- reversed(feature_column_names))
-
- tree_to_json = DecisionTreeVisualizer.replacer(tree_to_json,
- [f'Predict: {str(i)}.0' for i in
- range(len(label_names) - 1, -1, -1)],
- reversed(label_names))
- if not visual:
- return tree_to_json
-
- dot = graphviz.Digraph(comment='Decision Tree')
- if size:
- dot.attr(size=size)
- if horizontal:
- dot.attr(rankdir="LR")
- dot.node_attr.update(color='lightblue2', style='filled')
- json_d = DecisionTreeVisualizer.tree_json(tree_to_json)
-
- DecisionTreeVisualizer.add_node(dot, '', '', json_d,
- realroot=True)
- dot.render(tree_name)
- print(f'Generated pdf file of tree. You can view it in your Jupyter directory under {dot.filepath}.pdf\n')
- dot.view()
- return (dot)
-
-[docs] @staticmethod
- def replacer(string, bad, good):
- """
- Replace every string in "bad" with the corresponding string in "good"
- :param string: string to replace in
- :param bad: array of strings to replace
- :param good: array of strings to replace with
- :return:
- """
-
- for (b, g) in zip(bad, good):
- string = string.replace(b, g)
- return string
-
-[docs] @staticmethod
- def add_node(
- dot,
- parent,
- node_hash,
- root,
- realroot=False,
- ):
- """
- Traverse through the .debugString json and generate a graphviz tree
- :param dot: dot file object
- :param parent: not used currently
- :param node_hash: unique node id
- :param root: the root of tree
- :param realroot: whether or not it is the real root, or a recursive root
- :return:
- """
-
- node_id = str(hash(root['name'])) + str(random.randint(0, 100))
- if root:
- dot.node(node_id, root['name'])
- if not realroot:
- dot.edge(node_hash, node_id)
- if root.get('children'):
- if not root['children'][0].get('children'):
- DecisionTreeVisualizer.add_node(dot, root['name'],
- node_id, root['children'][0])
- else:
- DecisionTreeVisualizer.add_node(dot, root['name'],
- node_id, root['children'][0])
- DecisionTreeVisualizer.add_node(dot, root['name'],
- node_id, root['children'][1])
-
-[docs] @staticmethod
- def parse(lines):
- """
- Recursively parse the lines of a tree debug string into nested blocks
- :param lines: list of lines from the debug string
- :return: block json
- """
-
- block = []
- while lines:
-
- if lines[0].startswith('If'):
- bl = ' '.join(lines.pop(0).split()[1:]).replace('(', ''
- ).replace(')', '')
- block.append({'name': bl,
- 'children': DecisionTreeVisualizer.parse(lines)})
-
- if lines[0].startswith('Else'):
- be = ' '.join(lines.pop(0).split()[1:]).replace('('
- , '').replace(')', '')
- block.append({'name': be,
- 'children': DecisionTreeVisualizer.parse(lines)})
- elif not lines[0].startswith(('If', 'Else')):
- block2 = lines.pop(0)
- block.append({'name': block2})
- else:
- break
- return block
-
-[docs] @staticmethod
- def tree_json(tree):
- """
- Generate a JSON representation of a decision tree
- :param tree: tree debug string
- :return: json
- """
-
- data = []
- for line in tree.splitlines():
- if line.strip():
- line = line.strip()
- data.append(line)
- else:
- break
- if not line:
- break
- res = [{'name': 'Root',
- 'children': DecisionTreeVisualizer.parse(data[1:])}]
- return res[0]
-
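-# Illustrative sketch (not part of the original module): visualizing a fitted pyspark
-# decision tree model. The feature and label names are placeholders and would normally
-# come from your VectorAssembler / StringIndexer.
-def _example_visualize_tree(fitted_dt_model):
-    return DecisionTreeVisualizer.visualize(
-        fitted_dt_model,
-        feature_column_names=['AGE', 'INCOME', 'TENURE'],
-        label_names=['below avg', 'above avg'],
-        tree_name='example_tree',
-        visual=False,  # set True to also render a graphviz PDF
-    )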
-
-[docs]def inspectTable(spliceMLCtx, sql, topN=5):
- """Inspect the values of the columns of the table (dataframe) returned from the sql query
- :param spliceMLCtx: SpliceMLContext
- :param sql: sql string to execute
- :param topN: the number of most frequent elements of a column to return, defaults to 5
- """
- df = spliceMLCtx.df(sql)
- df = df.repartition(50)
-
- for _col, _type in df.dtypes:
- print("------Inspecting column {} -------- ".format(_col))
-
- val_counts = df.groupby(_col).count()
- val_counts.show()
- val_counts.orderBy(F.desc('count')).limit(topN).show()
-
- if _type == 'double' or _type == 'int':
- df.select(_col).describe().show()
-
-
-
-# Custom Transformers
-[docs]class Rounder(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
- """Transformer to round predictions for ordinal regression
- Follows: https://spark.apache.org/docs/latest/ml-pipeline.html#transformers
- :param Transformer: Inherited Class
- :param HasInputCol: Inherited Class
- :param HasOutputCol: Inherited Class
- :return: Transformed Dataframe with rounded predictionCol
- Example:
- --------
- >>> from pyspark.sql.session import SparkSession
- >>> from splicemachine.stats.stats import Rounder
- >>> spark = SparkSession.builder.getOrCreate()
- >>> dataset = spark.createDataFrame(
- ... [(0.2, 0.0),
- ... (1.2, 1.0),
- ... (1.6, 2.0),
- ... (1.1, 0.0),
- ... (3.1, 0.0)],
- ... ["prediction", "label"])
- >>> dataset.show()
- +----------+-----+
- |prediction|label|
- +----------+-----+
- | 0.2| 0.0|
- | 1.2| 1.0|
- | 1.6| 2.0|
- | 1.1| 0.0|
- | 3.1| 0.0|
- +----------+-----+
- >>> rounder = Rounder(predictionCol = "prediction", labelCol = "label", clipPreds = True)
- >>> rounder.transform(dataset).show()
- +----------+-----+
- |prediction|label|
- +----------+-----+
- | 0.0| 0.0|
- | 1.0| 1.0|
- | 2.0| 2.0|
- | 1.0| 0.0|
- | 2.0| 0.0|
- +----------+-----+
- >>> rounderNoClip = Rounder(predictionCol = "prediction", labelCol = "label", clipPreds = False)
- >>> rounderNoClip.transform(dataset).show()
- +----------+-----+
- |prediction|label|
- +----------+-----+
- | 0.0| 0.0|
- | 1.0| 1.0|
- | 2.0| 2.0|
- | 1.0| 0.0|
- | 3.0| 0.0|
- +----------+-----+
- """
-
- @keyword_only
- def __init__(self, predictionCol="prediction", labelCol="label", clipPreds=True, maxLabel=None, minLabel=None):
- """initialize self
- :param predictionCol: column containing predictions, defaults to "prediction"
- :param labelCol: column containing labels, defaults to "label"
- :param clipPreds: if True, clip predictions to lie between the min and max label values, defaults to True
- :param maxLabel: optional: the maximum value for the prediction column, otherwise uses the maximum of the labelCol, defaults to None
- :param minLabel: optional: the minimum value for the prediction column, otherwise uses the minimum of the labelCol, defaults to None
- """
- """initialize self
- :param predictionCol: column containing predictions, defaults to "prediction"
- :param labelCol: column containing labels, defaults to "label"
- """
- super(Rounder, self).__init__()
- self.labelCol = labelCol
- self.predictionCol = predictionCol
- self.clipPreds = clipPreds
- self.maxLabel = maxLabel
- self.minLabel = minLabel
-
-[docs] @keyword_only
- def setParams(self, predictionCol="prediction", labelCol="label"):
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
-[docs] def _transform(self, dataset):
- """
- Rounds the predictions to the nearest integer value, and also clips them at the max/min value observed in label
- :param dataset: dataframe with predictions to be rounded
- :return: DataFrame with rounded predictions
- """
- labelCol = self.labelCol
- predictionCol = self.predictionCol
-
- if self.clipPreds:
- max_label = self.maxLabel if self.maxLabel else dataset.agg({labelCol: 'max'}).collect()[0][0]
- min_label = self.minLabel if self.minLabel else dataset.agg({labelCol: 'min'}).collect()[0][0]
- clip = F.udf(lambda x: float(max_label) if x > max_label else (float(min_label) if x < min_label else x),
- DoubleType())
-
- dataset = dataset.withColumn(predictionCol, F.round(clip(F.col(predictionCol))))
- else:
- dataset = dataset.withColumn(predictionCol, F.round(F.col(predictionCol)))
-
- return dataset
-
-
-[docs]class OneHotDummies(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
- """
- Transformer to generate dummy columns for categorical variables as a part of a preprocessing pipeline
- Follows: https://spark.apache.org/docs/latest/ml-pipeline.html#transformers
- :param Transformer: Inherited Classes
- :param HasInputCol: Inherited Classes
- :param HasOutputCol: Inherited Classes
- :return: pyspark DataFrame
- """
-
- @keyword_only
- def __init__(self, inputCol=None, outputCol=None):
- """
- Assigns variables to parameters passed
- :param inputCol: Sparse vector returned by OneHotEncoders, defaults to None
- :param outputCol: string base to append to output columns names, defaults to None
- """
- super(OneHotDummies, self).__init__()
- # kwargs = self._input_kwargs
- # self.setParams(**kwargs)
- self.inputCol = inputCol
- self.outputCol = outputCol
- self.outcols = []
-
-[docs] @keyword_only
- def setParams(self, inputCol=None, outputCol=None):
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
-[docs] def _transform(self, dataset):
-
- """iterates through the number of categorical values of a categorical variable and adds dummy columns for each of those categories
- For a string categorical column, include this transformer in the following workflow: StringIndexer -> OneHotEncoder -> OneHotDummies -> PCA/ Learning Algorithm
- :param dataset: PySpark DataFrame where inputCol is the column returned by OneHotEncoders
- :return: original DataFrame with M additional columns where M = # of categories for this variable
- """
- out_col_suffix = self.outputCol # this is what I want to append to the column name
- col_name = self.inputCol
-
- out_col_base = col_name + out_col_suffix # this is the base for the n outputted columns
-
- # helper functions
- get_num_categories = F.udf(lambda x: int(x.size), IntegerType())
- get_active_index = F.udf(lambda x: int(x.indices[0]), IntegerType())
- check_active_index = F.udf(lambda active, i: int(active == i), IntegerType())
-
- num_categories = dataset.select(
- get_num_categories(col_name).alias('num_categories')).distinct() # this returns a dataframe
- if num_categories.count() == 1: # making sure all the sparse vectors have the same number of categories
- num_categories_int = num_categories.collect()[0]['num_categories'] # now this is an int
-
- dataset = dataset.withColumn('active_index', get_active_index(col_name))
- column_names = []
- for i in range(num_categories_int): # Now I'm going to make a column for each category
- column_name = out_col_base + '_' + str(i)
- dataset = dataset.withColumn(column_name, check_active_index(F.col('active_index'), F.lit(i)))
- column_names.append(column_name)
-
- dataset = dataset.drop('active_index')
- self.outcols = column_names
- return dataset
-
-
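-# Illustrative sketch (not part of the original module): the workflow named in the
-# _transform docstring above, StringIndexer -> OneHotEncoder -> OneHotDummies, applied
-# to a hypothetical string column "STATE".
-def _example_onehot_dummies(df):
-    stages = [
-        StringIndexer(inputCol='STATE', outputCol='STATE_ind', handleInvalid='skip'),
-        OneHotEncoder(inputCol='STATE_ind', outputCol='STATE_vec', dropLast=False),
-        OneHotDummies(inputCol='STATE_vec', outputCol='_dummy'),
-    ]
-    return Pipeline(stages=stages).fit(df).transform(df)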
-
-
-[docs]class IndReconstructer(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
- """Transformer to reconstruct String Index from OneHotDummy Columns. This can be used as a part of a Pipeline Ojbect
- Follows: https://spark.apache.org/docs/latest/ml-pipeline.html#transformers
- :param Transformer: Inherited Class
- :param HasInputCol: Inherited Class
- :param HasOutputCol: Inherited Class
- :return: Transformed PySpark Dataframe With Original String Indexed Variables
- """
-
- @keyword_only
- def __init__(self, inputCol=None, outputCol=None):
- super(IndReconstructer, self).__init__()
- # kwargs = self._input_kwargs
- # self.setParams(**kwargs)
- self.inputCol = inputCol
- self.outputCol = outputCol
-
-[docs] @keyword_only
- def setParams(self, inputCol=None, outputCol=None):
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
-[docs] def _transform(self, dataset):
- """
- iterates through the oneHotDummy columns for a categorical variable and returns the index of the column that is closest to one. This corresponds to the stringIndexed value of this feature for this row.
- :param dataset: dataset with OneHotDummy columns
- :return: DataFrame with column corresponding to a categorical indexed column
- """
- inColBase = self.inputCol
- outCol = self.outputCol
-
- closestToOne = F.udf(lambda x: abs(x - 1), DoubleType())
- dummies = dataset.select(*[closestToOne(i).alias(i) if inColBase in i else i for i in dataset.columns if
- inColBase in i or i == 'SUBJECT'])
- dummies = dummies.withColumn('least_val',
- F.lit(F.least(*[F.col(i) for i in dataset.columns if inColBase in i])))
-
- dummies = dummies.select(
- *[(F.col(i) == F.col('least_val')).alias(i + 'isind') if inColBase in i else i for i in dataset.columns if
- inColBase in i or i == 'SUBJECT'])
- getActive = F.udf(lambda row: [idx for idx, val in enumerate(row) if val][0], IntegerType())
- dummies = dummies.withColumn(outCol, getActive(
- F.struct(*[F.col(x) for x in dummies.columns if x != 'SUBJECT']).alias('struct')))
- dataset = dataset.join(dummies.select(['SUBJECT', outCol]), 'SUBJECT')
-
- return dataset
-
-
-[docs]class OverSampler(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
- """Transformer to oversample datapoints with minority labels
- Follows: https://spark.apache.org/docs/latest/ml-pipeline.html#transformers
- :param Transformer: Inherited Class
- :param HasInputCol: Inherited Class
- :param HasOutputCol: Inherited Class
- :return: PySpark Dataframe with labels in approximately equal ratios
- Example:
- -------
- >>> from pyspark.sql import functions as F
- >>> from pyspark.sql.session import SparkSession
- >>> from pyspark.ml.linalg import Vectors
- >>> from splicemachine.stats.stats import OverSampler
- >>> spark = SparkSession.builder.getOrCreate()
- >>> df = spark.createDataFrame(
- ... [(Vectors.dense([0.0]), 0.0),
- ... (Vectors.dense([0.5]), 0.0),
- ... (Vectors.dense([0.4]), 1.0),
- ... (Vectors.dense([0.6]), 1.0),
- ... (Vectors.dense([1.0]), 1.0)] * 10,
- ... ["features", "Class"])
- >>> df.groupBy(F.col("Class")).count().orderBy("count").show()
- +-----+-----+
- |Class|count|
- +-----+-----+
- | 0.0| 20|
- | 1.0| 30|
- +-----+-----+
- >>> oversampler = OverSampler(labelCol = "Class", strategy = "auto")
- >>> oversampler.transform(df).groupBy("Class").count().show()
- +-----+-----+
- |Class|count|
- +-----+-----+
- | 0.0| 29|
- | 1.0| 30|
- +-----+-----+
- """
-
- @keyword_only
- def __init__(self, labelCol=None, strategy="auto", randomState=None):
- """Initialize self
- :param labelCol: Label Column name, defaults to None
- :param strategy: strategy used to resample the dataset, defaults to "auto".
- • Currently only "auto" is supported, which corresponds to random sampling with replacement
- :param randomState: sets the seed of sample algorithm
- """
- super(OverSampler, self).__init__()
- self.labelCol = labelCol
- self.strategy = strategy
- self.withReplacement = True if strategy == "auto" else False
- self.randomState = np.random.randn() if not randomState else randomState
-
-[docs] @keyword_only
- def setParams(self, labelCol=None, strategy="auto"):
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
-[docs] def _transform(self, dataset):
- """
- Oversamples
- :param dataset: dataframe to be oversampled
- :return: DataFrame with the resampled data points
- """
- if self.strategy == "auto":
-
- pd_value_counts = dataset.groupBy(F.col(self.labelCol)).count().toPandas()
-
- label_type = dataset.schema[self.labelCol].dataType.simpleString()
- types_dic = {'int': int, "string": str, "double": float}
-
- maxidx = pd_value_counts['count'].idxmax()
-
- self.majorityLabel = types_dic[label_type](pd_value_counts[self.labelCol].loc[maxidx])
- majorityData = dataset.filter(F.col(self.labelCol) == self.majorityLabel)
-
- returnData = None
-
- if len(pd_value_counts) == 1:
- raise ValueError(
- f'Error! Number of labels = {len(pd_value_counts)}. Cannot Oversample with this number of classes')
- elif len(pd_value_counts) == 2:
- minidx = pd_value_counts['count'].idxmin()
- minorityLabel = types_dic[label_type](pd_value_counts[self.labelCol].loc[minidx])
- ratio = pd_value_counts['count'].loc[maxidx] / pd_value_counts['count'].loc[minidx] * 1.0
-
- returnData = majorityData.union(
- dataset.filter(F.col(self.labelCol) == minorityLabel).sample(withReplacement=self.withReplacement,
- fraction=ratio, seed=self.randomState))
-
- else:
- minority_labels = list(pd_value_counts.drop(maxidx)[self.labelCol])
-
- ratios = {types_dic[label_type](minority_label): pd_value_counts['count'].loc[maxidx] / float(
- pd_value_counts[pd_value_counts[self.labelCol] == minority_label]['count']) for minority_label in
- minority_labels}
-
- for (minorityLabel, ratio) in ratios.items():
- minorityData = dataset.filter(F.col(self.labelCol) == minorityLabel).sample(
- withReplacement=self.withReplacement, fraction=ratio, seed=self.randomState)
- if not returnData:
- returnData = majorityData.union(minorityData)
- else:
- returnData = returnData.union(minorityData)
-
- return returnData
- else:
- raise NotImplementedError("Only auto is currently implemented")
-
-
-[docs]class OverSampleCrossValidator(CrossValidator):
- """Class to perform Cross Validation model evaluation while over-sampling minority labels.
- Example:
- -------
- >>> from pyspark.sql.session import SparkSession
- >>> from pyspark.ml.classification import LogisticRegression
- >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
- >>> from pyspark.ml.linalg import Vectors
- >>> from splicemachine.stats.stats import OverSampleCrossValidator
- >>> spark = SparkSession.builder.getOrCreate()
- >>> dataset = spark.createDataFrame(
- ... [(Vectors.dense([0.0]), 0.0),
- ... (Vectors.dense([0.5]), 0.0),
- ... (Vectors.dense([0.4]), 1.0),
- ... (Vectors.dense([0.6]), 1.0),
- ... (Vectors.dense([1.0]), 1.0)] * 10,
- ... ["features", "label"])
- >>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- >>> PRevaluator = BinaryClassificationEvaluator(metricName = 'areaUnderPR')
- >>> AUCevaluator = BinaryClassificationEvaluator(metricName = 'areaUnderROC')
- >>> ACCevaluator = MulticlassClassificationEvaluator(metricName="accuracy")
- >>> cv = OverSampleCrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=AUCevaluator, altEvaluators = [PRevaluator, ACCevaluator],parallelism=2,seed = 1234)
- >>> cvModel = cv.fit(dataset)
- >>> print(cvModel.avgMetrics)
- [(0.5, [0.5888888888888888, 0.3888888888888889]), (0.806878306878307, [0.8556863149300125, 0.7055555555555556])]
- >>> print(AUCevaluator.evaluate(cvModel.transform(dataset)))
- 0.8333333333333333
- """
-
- def __init__(self, estimator, estimatorParamMaps, evaluator, numFolds=3, seed=None, parallelism=3,
- collectSubModels=False, labelCol='label', altEvaluators=None, overSample=True):
- """ Initialize Self
- :param estimator: Machine Learning Model, defaults to None
- :param estimatorParamMaps: paramMap to search, defaults to None
- :param evaluator: primary model evaluation metric, defaults to None
- :param numFolds: number of folds to perform, defaults to 3
- :param seed: random state, defaults to None
- :param parallelism: number of threads to use, defaults to 3
- :param collectSubModels: to return submodels, defaults to False
- :param labelCol: target variable column label, defaults to 'label'
- :param altEvaluators: additional metrics to evaluate, defaults to None
- If passed, the metrics of the alternate evaluators are accessed in the CrossValidatorModel.avgMetrics attribute
- :param overSample: Boolean: to perform oversampling of minority labels, defaults to True
- """
- self.label = labelCol
- self.altEvaluators = altEvaluators
- self.toOverSample = overSample
- super(OverSampleCrossValidator, self).__init__(estimator=estimator, estimatorParamMaps=estimatorParamMaps,
- evaluator=evaluator, numFolds=numFolds, seed=seed,
- parallelism=parallelism, collectSubModels=collectSubModels)
-
-
-
-
-
-
-
-[docs] def _parallelFitTasks(self, est, train, eva, validation, epm, collectSubModel, altEvaluators):
- """
- Creates a list of callables which can be called from different threads to fit and evaluate
- an estimator in parallel. Each callable returns an (index, metric, altmetrics, subModel) tuple, where altmetrics is None if no altEvaluators are supplied and subModel is None unless collectSubModel is True.
- :param est: Estimator, the estimator to be fit.
- :param train: DataFrame, training data set, used for fitting.
- :param eva: Evaluator, used to compute `metric`
- :param validation: DataFrame, validation data set, used for evaluation.
- :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
- :param collectSubModel: Whether to collect sub model.
- :param altEvaluators: optional list of additional Evaluators used to compute alternate metrics
- :return: a list of callables, one per ParamMap in `epm`
- """
- modelIter = est.fitMultiple(train, epm)
-
- def singleTask():
- index, model = next(modelIter)
- metric = eva.evaluate(model.transform(validation, epm[index]))
- altmetrics = None
- if altEvaluators:
- altmetrics = [altEva.evaluate(model.transform(validation, epm[index])) for altEva in altEvaluators]
- return index, metric, altmetrics, model if collectSubModel else None
-
- return [singleTask] * len(epm)
-
-[docs] def _fit(self, dataset):
- """Performs k-fold crossvaldidation on simple oversampled dataset
- :param dataset: full dataset
- :return: CrossValidatorModel containing the fitted BestModel with the average of the primary and alternate metrics in a list of tuples in the format: [(paramComb1_average_primary_metric, [paramComb1_average_altmetric1,paramComb1_average_altmetric2]), (paramComb2_average_primary_metric, [paramComb2_average_altmetric1,paramComb2_average_altmetric2])]
- """
- est = self.getOrDefault(self.estimator)
- epm = self.getOrDefault(self.estimatorParamMaps)
- numModels = len(epm)
- eva = self.getOrDefault(self.evaluator)
- nFolds = self.getOrDefault(self.numFolds)
- seed = self.getOrDefault(self.seed)
-
- # Getting Label and altEvaluators
- label = self.getLabel()
- altEvaluators = self.getAltEvaluators()
- altMetrics = [[0.0] * len(altEvaluators)] * numModels if altEvaluators else None
- h = 1.0 / nFolds
- randCol = self.uid + "_rand"
- df = dataset.select("*", F.rand(seed).alias(randCol))
- metrics = [0.0] * numModels
-
- pool = ThreadPool(processes=min(self.getParallelism(), numModels))
- subModels = None
- collectSubModelsParam = self.getCollectSubModels()
- if collectSubModelsParam:
- subModels = [[None for j in range(numModels)] for i in range(nFolds)]
-
- for i in range(nFolds):
- # Getting the splits such that no data is reused
- validateLB = i * h
- validateUB = (i + 1) * h
- condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
- validation = df.filter(condition).cache()
- train = df.filter(~condition).cache()
-
- # Oversampling the minority class(s) here
- if self.toOverSample:
- withReplacement = True
- oversampler = OverSampler(labelCol=self.label, strategy="auto")
-
- # Oversampling
- train = oversampler.transform(train)
- # Getting the individual tasks so this can be parallelized
- tasks = self._parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam, altEvaluators)
- # Calling the parallel process
- for j, metric, fold_alt_metrics, subModel in pool.imap_unordered(lambda f: f(), tasks):
- metrics[j] += (metric / nFolds)
- if fold_alt_metrics:
- altMetrics[j] = [altMetrics[j][i] + fold_alt_metrics[i] / nFolds for i in range(len(altEvaluators))]
-
- if collectSubModelsParam:
- subModels[i][j] = subModel
-
- validation.unpersist()
- train.unpersist()
-
- if eva.isLargerBetter():
- bestIndex = np.argmax(metrics)
- else:
- bestIndex = np.argmin(metrics)
- bestModel = est.fit(dataset, epm[bestIndex])
- metrics = [(metric, altMetrics[idx]) for idx, metric in enumerate(metrics)]
- return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels))
-
-
-## Pipeline Functions
-[docs]def get_string_pipeline(df, cols_to_exclude, steps=['StringIndexer', 'OneHotEncoder', 'OneHotDummies']):
- """Generates a list of preprocessing stages
- :param df: DataFrame including only the training data
- :param cols_to_exclude: Column names we don't want to include in the preprocessing (i.e. SUBJECT/ target column)
- :param steps: preprocessing steps to take
- :return: (stages, Numeric_Columns)
- stages: list of pipeline stages to be used in preprocessing
- Numeric_Columns: list of columns that contain numeric features
- """
-
- String_Columns = []
- Numeric_Columns = []
- for _col, _type in df.dtypes: # This is a tuple of (<col name>, data type)
- if _col in cols_to_exclude:
- continue
- if _type == 'string':
- String_Columns.append(_col)
- elif _type == 'double' or _type == 'int' or _type == 'float':
- Numeric_Columns.append(_col)
- else:
- print("Unhandled Data type = {}".format((_col, _type)))
- continue
-
- stages = []
- if 'StringIndexer' in steps:
- # String Indexing
- str_indexers = [StringIndexer(inputCol=c, outputCol=c + '_ind', handleInvalid='skip') for c in String_Columns]
- indexed_string_vars = [c + '_ind' for c in String_Columns]
- stages = stages + str_indexers
-
- if 'OneHotEncoder' in steps:
- # One hot encoding
- str_hot = [OneHotEncoder(inputCol=c + '_ind', outputCol=c + '_vec', dropLast=False) for c in String_Columns]
- encoded_str_vars = [c + '_vec' for c in String_Columns]
- stages = stages + str_hot
-
- if 'OneHotDummies' in steps:
- # Converting the sparse vector to dummy columns
- str_dumbers = [OneHotDummies(inputCol=c + '_vec', outputCol='_dummy') for c in String_Columns]
- str_dumb_cols = [c for dummy in str_dumbers for c in dummy.getOutCols()]
- stages = stages + str_dumbers
-
- if len(stages) == 0:
- ERROR = """
- Parameter <steps> must include 'StringIndexer', 'OneHotEncoder', 'OneHotDummies'
- """
- print(ERROR)
- raise Exception(ERROR)
-
- return stages, Numeric_Columns
-
-
-[docs]def vector_assembler_pipeline(df, columns, doPCA=False, k=10):
- """After preprocessing String Columns, this function can be used to assemble a feature vector to be used for learning
- creates the following stages: VectorAssembler -> StandardScaler [-> PCA, if doPCA]
- :param df: DataFrame containing preprocessed Columns
- :param columns: list of Column names of the preprocessed columns
- :param doPCA: whether to include PCA as the final stage, defaults to False
- :param k: Number of Principal Components to use, defaults to 10
- :return: List of vector assembling stages
- """
-
- assembler = VectorAssembler(inputCols=columns, outputCol='featuresVec')
- scaler = StandardScaler(inputCol="featuresVec", outputCol="features", withStd=True,
- withMean=True) # centering and standardizing the data
-
- if doPCA:
- pca_obj = PCA(k=k, inputCol="features", outputCol="pcaFeatures")
- stages = [assembler, scaler, pca_obj]
- else:
- stages = [assembler, scaler]
- return stages
-
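-# Illustrative sketch (not part of the original module): chaining the two helpers above
-# into one preprocessing Pipeline. The excluded column names are placeholders, and only
-# the numeric columns are assembled here for simplicity.
-def _example_preprocessing(df):
-    string_stages, numeric_cols = get_string_pipeline(
-        df, cols_to_exclude=['SUBJECT', 'label'], steps=['StringIndexer', 'OneHotEncoder'])
-    assembler_stages = vector_assembler_pipeline(df, numeric_cols, doPCA=True, k=5)
-    return Pipeline(stages=string_stages + assembler_stages).fit(df).transform(df)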
-
-[docs]def postprocessing_pipeline(df, cols_to_exclude):
- """Assemble postprocessing pipeline to reconstruct original categorical indexed values from OneHotDummy Columns
- :param df: DataFrame Including the original string Columns
- :param cols_to_exclude: list of columns to exclude
- :return: (reconstructers, String_Columns)
- reconstructers: list of IndReconstructer stages
- String_Columns: list of columns that are being reconstructed
- """
- String_Columns = []
- Numeric_Columns = []
- for _col, _type in df.dtypes: # This is a tuple of (<col name>, data type)
- if _col in cols_to_exclude:
- continue
- if _type == 'string':
- String_Columns.append(_col)
- elif _type == 'double' or _type == 'int' or _type == 'float':
- Numeric_Columns.append(_col)
- else:
- print("Unhandled Data type = {}".format((_col, _type)))
- continue
-
- # Extracting the Value of the OneHotEncoded Variable
- reconstructors = [IndReconstructer(inputCol=c, outputCol=c + '_activeInd') for c in String_Columns]
- return reconstructors, String_Columns
-
-
-# Distribution fitting Functions
-[docs]def make_pdf(dist, params, size=10000):
- """Generate distributions's Probability Distribution Function
- :param dist: scipy.stats distribution object: https://docs.scipy.org/doc/scipy/reference/stats.html
- :param params: distribution parameters
- :param size: how many data points to generate , defaults to 10000
- :return: series of probability density function for this distribution
- """
- # Separate parts of parameters
- arg = params[:-2]
- loc = params[-2]
- scale = params[-1]
-
- # Get sane start and end points of distribution
- start = dist.ppf(0.01, *arg, loc=loc, scale=scale) if arg else dist.ppf(0.01, loc=loc, scale=scale)
- end = dist.ppf(0.99, *arg, loc=loc, scale=scale) if arg else dist.ppf(0.99, loc=loc, scale=scale)
-
- # Build PDF and turn into pandas Series
- x = np.linspace(start, end, size)
- y = dist.pdf(x, loc=loc, scale=scale, *arg)
- pdf = pd.Series(y, x)
-
- return pdf
-
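-# Illustrative sketch (not part of the original module): generating the PDF of a standard
-# normal distribution. The (loc, scale) parameters are made up.
-def _example_make_pdf():
-    params = (0.0, 1.0)  # loc=0, scale=1 for scipy.stats.norm
-    pdf = make_pdf(st.norm, params, size=1000)
-    return pdf           # a pandas Series of densities indexed by x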
-
-[docs]def best_fit_distribution(data, col_name, bins, ax):
- """Model data by finding best fit distribution to data
- :param data: DataFrame with one column containing the feature whose distribution is to be investigated
- :param col_name: column name for feature
- :param bins: number of bins to use in generating the histogram of this data
- :param ax: axis to plot histogram on
- :return: (best_distribution.name, best_params, best_sse)
- best_distribution.name: string of the best distribution name
- best_params: parameters for this distribution
- best_sse: sum of squared errors for this distribution against the empirical pdf
- """
- # Get histogram of original data
-
- output = dist_explore.pandas_histogram(data, bins=bins)
- output.reset_index(level=0, inplace=True)
- output['index'] = output['index'].apply(lambda x: np.mean([float(i.strip()) for i in x.split('-')]))
- output[col_name] = output[col_name] / np.sum(output[col_name]) / (output['index'][1] - (output['index'][0]))
-
- x = output['index']
- y = output[col_name]
- # DISTRIBUTIONS = [
- # st.alpha,st.anglit,st.arcsine,st.beta,st.betaprime,st.bradford,st.burr,st.cauchy,st.chi,st.chi2,st.cosine,
- # st.dgamma,st.dweibull,st.erlang,st.expon,st.exponnorm,st.exponweib,st.exponpow,st.f,st.fatiguelife,st.fisk,
- # st.foldcauchy,st.foldnorm,st.frechet_r,st.frechet_l,st.genlogistic,st.genpareto,st.gennorm,st.genexpon,
- # st.genextreme,st.gausshyper,st.gamma,st.gengamma,st.genhalflogistic,st.gilbrat,st.gompertz,st.gumbel_r,
- # st.gumbel_l,st.halfcauchy,st.halflogistic,st.halfnorm,st.halfgennorm,st.hypsecant,st.invgamma,st.invgauss,
- # st.invweibull,st.johnsonsb,st.johnsonsu,st.ksone,st.kstwobign,st.laplace,st.levy,st.levy_l,st.levy_stable,
- # st.logistic,st.loggamma,st.loglaplace,st.lognorm,st.lomax,st.maxwell,st.mielke,st.nakagami,st.ncx2,st.ncf,
- # st.nct,st.norm,st.pareto,st.pearson3,st.powerlaw,st.powerlognorm,st.powernorm,st.rdist,st.reciprocal,
- # st.rayleigh,st.rice,st.recipinvgauss,st.semicircular,st.t,st.triang,st.truncexpon,st.truncnorm,st.tukeylambda,
- # st.uniform,st.vonmises,st.vonmises_line,st.wald,st.weibull_min,st.weibull_max,st.wrapcauchy
- # ]
-
- DISTRIBUTIONS = [
- st.beta, st.expon,
- st.halfnorm,
- st.norm,
- st.lognorm,
- st.uniform
- ]
-
- # Best holders
- best_distribution = st.norm
- best_params = (0.0, 1.0)
- best_sse = np.inf
-
- # Estimate distribution parameters from data
- for distribution in tqdm(DISTRIBUTIONS):
-
- # Try to fit the distribution
- try:
- # Ignore warnings from data that can't be fit
- with warnings.catch_warnings():
- warnings.filterwarnings('ignore')
-
- # fit dist to data
- params = distribution.fit(data.collect())
-
- # Separate parts of parameters
- arg = params[:-2]
- loc = params[-2]
- scale = params[-1]
-
- # Calculate fitted PDF and error with fit in distribution
-
- pdf = distribution.pdf(x, loc=loc, scale=scale, *arg)
- sse = np.sum(np.power(y.values - pdf, 2.0))
-
- # if axis pass in add to plot
- try:
- if ax:
- if sse < 0.05:
- # Don't want to plot really bad ones
- ax = pdf.plot(legend=True, label=distribution.name)
- # ax.plot(x,pdf, label = distribution.name)
- ax.legend()
- except Exception:
- pass
-
- # identify if this distribution is better
- if best_sse > sse > 0:
- best_distribution = distribution
- best_params = params
- best_sse = sse
-
- except Exception:
- pass
-
- return (best_distribution.name, best_params, best_sse)
-
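-# Illustrative sketch (not part of the original module): fitting the candidate
-# distributions to a single numeric column of a Spark DataFrame. The column name is a
-# placeholder, and a matplotlib axis is created for the overlaid PDFs.
-def _example_best_fit(df, col_name='AGE', bins=30):
-    import matplotlib.pyplot as plt
-    fig, ax = plt.subplots()
-    name, params, sse = best_fit_distribution(df.select(col_name), col_name, bins, ax)
-    return name, params, sse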
-
-## PCA Functions
-
-[docs]def estimateCovariance(df, features_col='features'):
- """Compute the covariance matrix for a given dataframe.
- Note: The multi-dimensional covariance array should be calculated using outer products. Don't forget to normalize the data by first subtracting the mean.
- :param df: PySpark dataframe
- :param features_col: name of the column with the features, defaults to 'features'
- :return: np.ndarray: A multi-dimensional array where the number of rows and columns both equal the length of the arrays in the input dataframe.
- """
- m = df.select(df[features_col]).rdd.map(lambda x: x[0]).mean()
-
- dfZeroMean = df.select(df[features_col]).rdd.map(lambda x: x[0]).map(lambda x: x - m) # subtract the mean
-
- return dfZeroMean.map(lambda x: np.outer(x, x)).sum() / df.count()
-
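-# Illustrative check (not part of the original module): on a small example, the estimate
-# above should agree with numpy's population covariance (np.cov with bias=True).
-def _example_check_covariance(spark):
-    from pyspark.ml.linalg import Vectors
-    rows = [(Vectors.dense([1.0, 2.0]),), (Vectors.dense([2.0, 1.0]),), (Vectors.dense([3.0, 3.0]),)]
-    df = spark.createDataFrame(rows, ['features'])
-    cov = estimateCovariance(df)
-    expected = np.cov(np.array([[1.0, 2.0], [2.0, 1.0], [3.0, 3.0]]).T, bias=True)
-    return np.allclose(cov, expected)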
-
-[docs]def pca_with_scores(df, k=10):
- """Computes the top `k` principal components, corresponding scores, and all eigenvalues.
- Note:
- All eigenvalues should be returned in sorted order (largest to smallest). `eigh` returns
- each eigenvectors as a column. This function should also return eigenvectors as columns.
- :param df: A Spark dataframe with a 'features' column, which (column) consists of DenseVectors.
- :param k: The number of principal components to return., defaults to 10
- :return:(eigenvectors, `RDD` of scores, eigenvalues).
- Eigenvectors: multi-dimensional array where the number of
- rows equals the length of the arrays in the input `RDD` and the number of columns equals`k`.
- `RDD` of scores: has the same number of rows as `data` and consists of arrays of length `k`.
- Eigenvalues is an array of length d (the number of features).
- """
- cov = estimateCovariance(df)
- col = cov.shape[1]
- eigVals, eigVecs = eigh(cov)
- inds = np.argsort(eigVals)
- eigVecs = eigVecs.T[inds[-1:-(col + 1):-1]]
- components = eigVecs[0:k]
- eigVals = eigVals[inds[-1:-(col + 1):-1]] # sort eigenvals
- score = df.select(df['features']).rdd.map(lambda x: x[0]).map(lambda x: np.dot(x, components.T))
- # Return the `k` principal components, `k` scores, and all eigenvalues
-
- return components.T, score, eigVals
-
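-# Illustrative sketch (not part of the original module): extracting the top two principal
-# components from a DataFrame with a 'features' column of DenseVectors (e.g. the output of
-# a VectorAssembler).
-def _example_pca(df):
-    components, score_rdd, eigenvalues = pca_with_scores(df, k=2)
-    explained = sum(eigenvalues[0:2]) / sum(eigenvalues)  # fraction of variance retained
-    return components, score_rdd.take(5), explained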
-
-[docs]def varianceExplained(df, k=10):
- """returns the proportion of variance explained by `k` principal componenets. Calls the above PCA procedure
- :param df: PySpark DataFrame
- :param k: number of principal components , defaults to 10
- :return: (proportion, principal_components, scores, eigenvalues)
- """
- components, scores, eigenvalues = pca_with_scores(df, k)
- return sum(eigenvalues[0:k]) / sum(eigenvalues), components, scores, eigenvalues
-
-
-# PCA reconstruction Functions
-
-[docs]def reconstructPCA(sql, df, pc, mean, std, originalColumns, fits, pcaColumn='pcaFeatures'):
- """Reconstruct data from lower dimensional space after performing PCA
- :param sql: SQLContext
- :param df: PySpark DataFrame: inputted PySpark DataFrame
- :param pc: numpy.ndarray: principal components projected onto
- :param mean: numpy.ndarray: mean of original columns
- :param std: numpy.ndarray: standard deviation of original columns
- :param originalColumns: list: original column names
- :param fits: fits of features returned from best_fit_distribution
- :param pcaColumn: column in df that contains PCA features, defaults to 'pcaFeatures'
- :return: dataframe containing reconstructed data
- """
-
- cols = df.columns
- cols.remove(pcaColumn)
-
- pddf = df.toPandas()
- first_series = pddf['pcaFeatures'].apply(lambda x: np.array(x.toArray())).as_matrix().reshape(-1, 1)
- first_features = np.apply_along_axis(lambda x: x[0], 1, first_series)
- # undo-ing PCA
- first_reconstructed = np.dot(first_features, pc)
- # undo-ing the scaling
- first_reconstructed = np.multiply(first_reconstructed, std) + mean
- first_reconstructedDF = pd.DataFrame(first_reconstructed, columns=originalColumns)
- for _col in cols:
- first_reconstructedDF[_col] = pddf[_col]
-
- # This is a pyspark Dataframe containing the reconstructed data, including the dummy columns for the string variables-- next step is to reverse the one-hot-encoding for the string columns
- first_reconstructed = sql.createDataFrame(first_reconstructedDF)
-
- cols_to_exclude = ['DATE_OF_STUDY']
- postPipeStages, String_Columns = postprocessing_pipeline(df, cols_to_exclude)
-
- postPipe = Pipeline(stages=postPipeStages)
- out = postPipe.fit(first_reconstructed).transform(first_reconstructed)
- for _col in String_Columns:
- out = out.join(df.select([_col, _col + '_ind']) \
- .withColumnRenamed(_col + '_ind', _col + '_activeInd'), _col + '_activeInd') \
- .dropDuplicates()
- cols_to_drop = [_col for _col in out.columns if any([base in _col for base in String_Columns]) and '_' in _col]
-
- reconstructedDF = out.drop(
- *cols_to_drop) # This is the equivalent as the first translated reconstructed dataframe above
- clip = F.udf(lambda x: x if x > 0 else 0.0, DoubleType())
- for _key in fits.keys():
- if fits[_key]['dist'] == 'EMPIRICAL':
- reconstructedDF = reconstructedDF.withColumn(_key, F.round(clip(F.col(_key))))
- else:
- reconstructedDF = reconstructedDF.withColumn(_key, clip(F.col(_key)))
-
- return reconstructedDF
-
-
-[docs]class MarkovChain(object):
- def __init__(self, transition_prob):
- """
- Initialize the MarkovChain instance.
- Parameters
- ----------
- transition_prob: dict
- A dict object representing the transition
- probabilities in Markov Chain.
- Should be of the form:
- {'state1': {'state1': 0.1, 'state2': 0.4},
- 'state2': {...}}
- """
- self.transition_prob = transition_prob
- self.states = list(transition_prob.keys()) # states that have transitions to the next layer
- # For states in the form <stateN_M> where N is the visit (layer) and M is the cluster in the N-th Layer
- self.max_num_steps = max([int(i.split('state')[1][0]) for i in self.states])
-
-
-
-[docs] def next_state(self, current_state):
- """Returns the state of the random variable at the next time
- instance.
- :param current_state: The current state of the system.
- :raises: Exception if random choice fails
- :return: next state
- """
-
- try:
-
- # if not current_state in self.states:
- # print('We have reached node {} where we do not know where they go from here... \n try reducing the number of clusters at level {} \n otherwise we might be at the terminating layer'.format(current_state, int(current_state.split('state')[1][0])))
- # raise Exception('Unknown transition')
-
- next_possible_states = list(self.transition_prob[current_state].keys())
- return np.random.choice(
- next_possible_states,
- p=[self.transition_prob[current_state][next_state]
- for next_state in next_possible_states]
- )[:]
- except Exception as e:
- raise e
-
-[docs] def generate_states(self, current_state, no=10, last=True):
- """
- Generates the next states of the system.
- Parameters
- ----------
- current_state: str
- The state of the current random variable.
- no: int
- The number of future states to generate.
- last: bool
- Do we want to return just the last value
- """
- try:
- if no > self.max_num_steps:
- print('Number of steps exceeds the max number of possible next steps')
- raise Exception('<no> should not exceed {}. The value of <no> was: {}'.format(self.max_num_steps, no))
-
- future_states = []
- for i in range(no):
- try:
- next_state = self.next_state(current_state)
- except Exception as e:
- raise e
- future_states.append(next_state)
- current_state = next_state
- if last:
- return future_states[-1]
- else:
- return future_states
- except Exception as e:
- raise e
-
-[docs] def rep_states(self, current_state, no=10, num_reps=10):
- """running generate states a bunch of times and returning the final state that happens the most
- Arguments:
- current_state str -- The state of the current random variable
- no int -- number of time steps in the future to run
- num_reps int -- number of times to run the simulation forward
- Returns:
- state -- the most commonly reached state at the end of these runs
- """
- if no > self.max_num_steps:
- print('Number of steps exceeds the max number of possible next steps')
- raise Exception('<no> should not exceed {}. The value of <no> was: {}'.format(self.max_num_steps, no))
-
- endstates = []
- for _ in range(num_reps):
- endstates.append(self.generate_states(current_state, no=no, last=True))
- return max(set(endstates), key=endstates.count)
-
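-
-# Illustrative sketch (not part of the original module): a two-layer chain using the
-# <stateN_M> naming convention described in __init__. The probabilities are made up.
-def _example_markov_chain():
-    transition_prob = {
-        'state1_0': {'state2_0': 0.7, 'state2_1': 0.3},
-        'state1_1': {'state2_0': 0.2, 'state2_1': 0.8},
-    }
-    chain = MarkovChain(transition_prob)
-    return chain.generate_states('state1_0', no=1, last=True)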