Skip to content

Commit

Permalink
Merge pull request #308 from alexbrillant/issue-236
Browse files Browse the repository at this point in the history
Fix Issue #236
  • Loading branch information
alexbrillant authored Apr 4, 2020
2 parents bff3b42 + b5fd966 commit de06c15
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions neuraxle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3035,6 +3035,13 @@ class TransformHandlerOnlyMixin(NonFittableMixin):

@abstractmethod
def _transform_data_container(self, data_container: DataContainer, context: ExecutionContext) -> DataContainer:
"""
Transform data container with the given execution context.
:param data_container: data container
:param context: execution context
:return: transformed data container
"""
raise NotImplementedError('Must implement _transform_data_container in {0}'.format(self.name))

def transform(self, data_inputs) -> 'HandleOnlyMixin':
Expand Down Expand Up @@ -3111,7 +3118,13 @@ def __init__(self, cache_folder=None):
cache_folder = DEFAULT_CACHE_FOLDER
self.cache_folder = cache_folder

def transform(self, data_inputs):
def transform(self, data_inputs) -> Iterable:
"""
Using :func:`~neuraxle.base.BaseStep.handle_transform`, transform data inputs.
:param data_inputs: data inputs
:return: outputs
"""
execution_context = ExecutionContext(self.cache_folder, execution_mode=ExecutionMode.TRANSFORM)
context, data_container = self._encapsulate_data(
data_inputs, expected_outputs=None, execution_mode=ExecutionMode.TRANSFORM)
Expand All @@ -3121,18 +3134,38 @@ def transform(self, data_inputs):
return data_container.data_inputs

def fit(self, data_inputs, expected_outputs=None) -> 'HandleOnlyMixin':
"""
Using :func:`~neuraxle.base.BaseStep.handle_fit`, fit step with the given data inputs, and expected outputs.
:param data_inputs: data inputs
:return: fitted self
"""
context, data_container = self._encapsulate_data(data_inputs, expected_outputs, ExecutionMode.FIT)
new_self = self.handle_fit(data_container, context)

return new_self

def fit_transform(self, data_inputs, expected_outputs=None) -> Tuple['HandleOnlyMixin', Iterable]:
"""
Using :func:`~neuraxle.base.BaseStep.handle_fit_transform`, fit and transform step with the given data inputs, and expected outputs.
:param data_inputs: data inputs
:return: fitted self, outputs
"""
context, data_container = self._encapsulate_data(data_inputs, expected_outputs, ExecutionMode.FIT_TRANSFORM)
new_self, data_container = self.handle_fit_transform(data_container, context)

return new_self, data_container.data_inputs

def _encapsulate_data(self, data_inputs, expected_outputs=None, execution_mode=None):
def _encapsulate_data(self, data_inputs, expected_outputs=None, execution_mode=None) -> Tuple[ExecutionContext, DataContainer]:
"""
Encapsulate data with :class:`~neuraxle.data_container.DataContainer`.
:param data_inputs: data inputs
:param expected_outputs: expected outputs
:param execution_mode: execution mode
:return: execution context, data container
"""
data_container = DataContainer(data_inputs=data_inputs, expected_outputs=expected_outputs)
context = ExecutionContext(root=self.cache_folder, execution_mode=execution_mode)

Expand Down Expand Up @@ -3179,6 +3212,16 @@ class FullDumpLoader(Identity):
Identity step that can load the full dump of a pipeline step.
Used by :func:`~neuraxle.base.BaseStep.load`.
Usage example:
.. code-block:: python
saved_step = FullDumpLoader(
name=path,
stripped_saver=self.stripped_saver
).load(context_for_loading, True)
.. seealso::
:class:`ExecutionContext`
:class:`BaseStep`,
Expand Down

0 comments on commit de06c15

Please sign in to comment.