Skip to content

Commit a19d311

Browse files
rjzamoraoliverholworthyjperez999
authored
Enable cpp code path for Categorify ops (#389)
* allow cpp code path to be used for Categorify ops * Update merlin/systems/workflow/base.py * formatting * spelling * require env-variable opt-in tfor cpp code path * Update merlin/systems/workflow/base.py --------- Co-authored-by: Oliver Holworthy <[email protected]> Co-authored-by: Julio Perez <[email protected]>
1 parent ddb775b commit a19d311

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

merlin/systems/workflow/base.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import functools
2828
import json
2929
import logging
30+
import os
3031

3132
from merlin.dag import ColumnSelector, DataFormats, Supports
3233
from merlin.dag.executors import LocalExecutor, _convert_format, _data_format
@@ -66,15 +67,24 @@ def __init__(self, workflow, output_dtypes, model_config, model_device):
6667
)
6768

6869
# recurse over all column groups, initializing operators for inference pipeline.
69-
# (disabled for now while we sort out whether and how we want to use C++ implementations
70-
# of NVTabular operators for performance optimization)
71-
# self._initialize_ops(self.workflow.output_node)
70+
# (disabled everything other than operators that are specifically listed
71+
# by the `NVT_CPP_OPS` environment variable while we sort out whether
72+
# and how we want to use C++ implementations of NVTabular operators for
73+
# performance optimization)
74+
_nvt_cpp_ops = os.environ.get("NVT_CPP_OPS", "Categorify").split(",")
75+
self._initialize_ops(self.workflow.output_node, restrict=_nvt_cpp_ops)
76+
77+
def _initialize_ops(self, workflow_node, visited=None, restrict=None):
78+
restrict = restrict or []
7279

73-
def _initialize_ops(self, workflow_node, visited=None):
7480
if visited is None:
7581
visited = set()
7682

77-
if workflow_node.op and hasattr(workflow_node.op, "inference_initialize"):
83+
if (
84+
workflow_node.op
85+
and hasattr(workflow_node.op, "inference_initialize")
86+
and (not restrict or workflow_node.op.label in restrict)
87+
):
7888
inference_op = workflow_node.op.inference_initialize(
7989
workflow_node.selector, self.model_config
8090
)
@@ -96,7 +106,7 @@ def _initialize_ops(self, workflow_node, visited=None):
96106
for parent in workflow_node.parents_with_dependencies:
97107
if parent not in visited:
98108
visited.add(parent)
99-
self._initialize_ops(parent, visited)
109+
self._initialize_ops(parent, visited=visited, restrict=restrict)
100110

101111
def run_workflow(self, input_tensors):
102112
transformable = TensorTable(input_tensors).to_df()

0 commit comments

Comments
 (0)