27
27
import functools
28
28
import json
29
29
import logging
30
+ import os
30
31
31
32
from merlin .dag import ColumnSelector , DataFormats , Supports
32
33
from merlin .dag .executors import LocalExecutor , _convert_format , _data_format
@@ -66,15 +67,24 @@ def __init__(self, workflow, output_dtypes, model_config, model_device):
66
67
)
67
68
68
69
# recurse over all column groups, initializing operators for inference pipeline.
69
- # (disabled for now while we sort out whether and how we want to use C++ implementations
70
- # of NVTabular operators for performance optimization)
71
- # self._initialize_ops(self.workflow.output_node)
70
+ # (disabled everything other than operators that are specifically listed
71
+ # by the `NVT_CPP_OPS` environment variable while we sort out whether
72
+ # and how we want to use C++ implementations of NVTabular operators for
73
+ # performance optimization)
74
+ _nvt_cpp_ops = os .environ .get ("NVT_CPP_OPS" , "Categorify" ).split ("," )
75
+ self ._initialize_ops (self .workflow .output_node , restrict = _nvt_cpp_ops )
76
+
77
+ def _initialize_ops (self , workflow_node , visited = None , restrict = None ):
78
+ restrict = restrict or []
72
79
73
- def _initialize_ops (self , workflow_node , visited = None ):
74
80
if visited is None :
75
81
visited = set ()
76
82
77
- if workflow_node .op and hasattr (workflow_node .op , "inference_initialize" ):
83
+ if (
84
+ workflow_node .op
85
+ and hasattr (workflow_node .op , "inference_initialize" )
86
+ and (not restrict or workflow_node .op .label in restrict )
87
+ ):
78
88
inference_op = workflow_node .op .inference_initialize (
79
89
workflow_node .selector , self .model_config
80
90
)
@@ -96,7 +106,7 @@ def _initialize_ops(self, workflow_node, visited=None):
96
106
for parent in workflow_node .parents_with_dependencies :
97
107
if parent not in visited :
98
108
visited .add (parent )
99
- self ._initialize_ops (parent , visited )
109
+ self ._initialize_ops (parent , visited = visited , restrict = restrict )
100
110
101
111
def run_workflow (self , input_tensors ):
102
112
transformable = TensorTable (input_tensors ).to_df ()
0 commit comments