Skip to content

Commit f6058c1

Browse files
kmonte and tfx-copybara
authored and committed
Encode producer component id and output key when CWP is created from an OutputChannel
PiperOrigin-RevId: 641945426
1 parent 810e466 commit f6058c1

15 files changed

+202
-95
lines changed

tfx/dsl/compiler/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def _compile_node(
214214

215215
# Step 3: Node inputs
216216
node_inputs_compiler.compile_node_inputs(
217-
pipeline_ctx, tfx_node, node.inputs)
218-
217+
pipeline_ctx, tfx_node, node.inputs
218+
)
219219
# Step 4: Node outputs
220220
if (isinstance(tfx_node, base_component.BaseComponent) or
221221
compiler_utils.is_importer(tfx_node)):

tfx/dsl/compiler/compiler_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ def node_context_name(pipeline_context_name: str, node_id: str):
204204

205205
def implicit_channel_key(channel: types.BaseChannel):
206206
"""Key of a channel to the node that consumes the channel as input."""
207+
if (
208+
isinstance(channel, channel_types.ChannelWrappedPlaceholder)
209+
and channel.key
210+
):
211+
return channel.key
207212
if isinstance(channel, channel_types.PipelineInputChannel):
208213
channel = cast(channel_types.PipelineInputChannel, channel)
209214
return f"_{channel.pipeline.id}.{channel.output_key}"

tfx/dsl/compiler/node_inputs_compiler.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,20 +421,32 @@ def _compile_conditionals(
421421
contexts = context.dsl_context_registry.get_contexts(tfx_node)
422422
except ValueError:
423423
return
424-
425424
for dsl_context in contexts:
426425
if not isinstance(dsl_context, conditional.CondContext):
427426
continue
428427
cond_context = cast(conditional.CondContext, dsl_context)
429428
for channel in channel_utils.get_dependent_channels(cond_context.predicate):
429+
# Since the channels here are *always* from a CWP, which we now set the
430+
# key by default on for OutputChannel, we must re-create the input key if
431+
# an output channel is used, otherwise the wrong key may be used by
432+
# `get_input_key` (e.g. if the producer component is also used as data
433+
# input to the component.)
434+
# Note that this means we potentially have several inputs with identical
435+
# artifact queries under the hood, which should be optimized away if we
436+
# run into performance issues.
437+
if isinstance(channel, channel_types.OutputChannel):
438+
input_key = compiler_utils.implicit_channel_key(channel)
439+
else:
440+
input_key = context.get_node_context(tfx_node).get_input_key(channel)
430441
_compile_input_spec(
431442
pipeline_ctx=context,
432443
tfx_node=tfx_node,
433-
input_key=context.get_node_context(tfx_node).get_input_key(channel),
444+
input_key=input_key,
434445
channel=channel,
435446
hidden=False,
436447
min_count=1,
437-
result=result)
448+
result=result,
449+
)
438450
cond_id = context.get_conditional_id(cond_context)
439451
expr = channel_utils.encode_placeholder_with_channels(
440452
cond_context.predicate, context.get_node_context(tfx_node).get_input_key

tfx/dsl/compiler/node_inputs_compiler_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ def testCompileConditionals(self):
577577
self.assertEqual(result.inputs[cond_input_key].min_count, 1)
578578
self.assertLen(result.conditionals, 1)
579579
cond = list(result.conditionals.values())[0]
580-
self.assertProtoEquals("""
580+
self.assertProtoEquals(
581+
"""
581582
operator {
582583
compare_op {
583584
op: EQUAL
@@ -594,7 +595,7 @@ def testCompileConditionals(self):
594595
index_op {
595596
expression {
596597
placeholder {
597-
key: "%s"
598+
key: "_CondNode.x"
598599
}
599600
}
600601
}
@@ -605,7 +606,9 @@ def testCompileConditionals(self):
605606
}
606607
}
607608
}
608-
""" % cond_input_key, cond.placeholder_expression)
609+
""",
610+
cond.placeholder_expression,
611+
)
609612

610613
def testCompileInputsForDynamicProperties(self):
611614
producer = DummyNode('Producer')

tfx/dsl/compiler/testdata/composable_pipeline_async_input_v2_ir.pbtxt

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1942,6 +1942,43 @@ nodes {
19421942
}
19431943
}
19441944
inputs {
1945+
inputs {
1946+
key: "_Evaluator.blessing"
1947+
value {
1948+
channels {
1949+
producer_node_query {
1950+
id: "Evaluator"
1951+
}
1952+
context_queries {
1953+
type {
1954+
name: "pipeline"
1955+
}
1956+
name {
1957+
field_value {
1958+
string_value: "composable-pipeline"
1959+
}
1960+
}
1961+
}
1962+
context_queries {
1963+
type {
1964+
name: "node"
1965+
}
1966+
name {
1967+
field_value {
1968+
string_value: "composable-pipeline.Evaluator"
1969+
}
1970+
}
1971+
}
1972+
artifact_query {
1973+
type {
1974+
name: "ModelBlessing"
1975+
}
1976+
}
1977+
output_key: "blessing"
1978+
}
1979+
min_count: 1
1980+
}
1981+
}
19451982
inputs {
19461983
key: "blessing"
19471984
value {
@@ -2109,7 +2146,7 @@ nodes {
21092146
index_op {
21102147
expression {
21112148
placeholder {
2112-
key: "blessing"
2149+
key: "_Evaluator.blessing"
21132150
}
21142151
}
21152152
}

tfx/dsl/compiler/testdata/conditional_pipeline_input_v2_ir.pbtxt

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,55 @@ nodes {
12021202
min_count: 1
12031203
}
12041204
}
1205+
inputs {
1206+
key: "_Trainer.model"
1207+
value {
1208+
channels {
1209+
producer_node_query {
1210+
id: "Trainer"
1211+
}
1212+
context_queries {
1213+
type {
1214+
name: "pipeline"
1215+
}
1216+
name {
1217+
field_value {
1218+
string_value: "cond"
1219+
}
1220+
}
1221+
}
1222+
context_queries {
1223+
type {
1224+
name: "pipeline_run"
1225+
}
1226+
name {
1227+
runtime_parameter {
1228+
name: "pipeline-run-id"
1229+
type: STRING
1230+
}
1231+
}
1232+
}
1233+
context_queries {
1234+
type {
1235+
name: "node"
1236+
}
1237+
name {
1238+
field_value {
1239+
string_value: "cond.Trainer"
1240+
}
1241+
}
1242+
}
1243+
artifact_query {
1244+
type {
1245+
name: "Model"
1246+
base_type: MODEL
1247+
}
1248+
}
1249+
output_key: "model"
1250+
}
1251+
min_count: 1
1252+
}
1253+
}
12051254
inputs {
12061255
key: "model"
12071256
value {
@@ -1333,7 +1382,7 @@ nodes {
13331382
index_op {
13341383
expression {
13351384
placeholder {
1336-
key: "model"
1385+
key: "_Trainer.model"
13371386
}
13381387
}
13391388
}

tfx/dsl/compiler/testdata/consumer_pipeline_with_tags_input_v2_ir.pbtxt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# proto-file: tfx/proto/orchestration/pipeline.proto
2+
# proto-message: Pipeline
3+
#
4+
# This file contains the IR of an example pipeline
5+
# tfx/dsl/compiler/testdata/consumer_pipeline_with_tags.py
6+
17
pipeline_info {
28
id: "consumer-pipeline"
39
}

tfx/orchestration/kubeflow/v2/compiler_utils_test.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -266,36 +266,38 @@ def setUp(self):
266266

267267
@parameterized.named_parameters(
268268
{
269-
'testcase_name':
270-
'two_sides_placeholder',
271-
'predicate':
272-
_TEST_CHANNEL.future()[0].property('int1') <
273-
_TEST_CHANNEL.future()[0].property('int2'),
274-
'expected_cel':
275-
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'int1\'] < '
276-
'inputs.artifacts[\'key\'].artifacts[0].metadata[\'int2\'])',
269+
'testcase_name': 'two_sides_placeholder',
270+
'predicate': _TEST_CHANNEL.future()[0].property(
271+
'int1'
272+
) < _TEST_CHANNEL.future()[0].property('int2'),
273+
'expected_cel': (
274+
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['int1'] < "
275+
"inputs.artifacts['_producer.foo'].artifacts[0].metadata['int2'])"
276+
),
277277
},
278278
{
279-
'testcase_name':
280-
'left_side_placeholder_right_side_int',
281-
'predicate':
282-
_TEST_CHANNEL.future()[0].property('int') < 1,
283-
'expected_cel':
284-
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'int\'] < 1.0)',
279+
'testcase_name': 'left_side_placeholder_right_side_int',
280+
'predicate': _TEST_CHANNEL.future()[0].property('int') < 1,
281+
'expected_cel': (
282+
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['int']"
283+
' < 1.0)'
284+
),
285285
},
286286
{
287287
'testcase_name': 'left_side_placeholder_right_side_float',
288288
'predicate': _TEST_CHANNEL.future()[0].property('float') < 1.1,
289-
'expected_cel':
290-
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'float\'] < '
291-
'1.1)',
289+
'expected_cel': (
290+
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['float']"
291+
' < 1.1)'
292+
),
292293
},
293294
{
294295
'testcase_name': 'left_side_placeholder_right_side_string',
295296
'predicate': _TEST_CHANNEL.future()[0].property('str') == 'test_str',
296-
'expected_cel':
297-
'(inputs.artifacts[\'key\'].artifacts[0].metadata[\'str\'] == '
298-
'\'test_str\')',
297+
'expected_cel': (
298+
"(inputs.artifacts['_producer.foo'].artifacts[0].metadata['str']"
299+
" == 'test_str')"
300+
),
299301
},
300302
)
301303
def testComparison(self, predicate, expected_cel):
@@ -310,8 +312,9 @@ def testComparison(self, predicate, expected_cel):
310312

311313
def testArtifactUri(self):
312314
predicate = _TEST_CHANNEL.future()[0].uri == 'test_str'
313-
expected_cel = ('(inputs.artifacts[\'key\'].artifacts[0].uri == '
314-
'\'test_str\')')
315+
expected_cel = (
316+
"(inputs.artifacts['_producer.foo'].artifacts[0].uri == 'test_str')"
317+
)
315318
channel_to_key_map = {
316319
_TEST_CHANNEL: 'key',
317320
}
@@ -323,8 +326,10 @@ def testArtifactUri(self):
323326

324327
def testNegation(self):
325328
predicate = _TEST_CHANNEL.future()[0].property('int') != 1
326-
expected_cel = ('!((inputs.artifacts[\'key\'].artifacts[0]'
327-
'.metadata[\'int\'] == 1.0))')
329+
expected_cel = (
330+
"!((inputs.artifacts['_producer.foo'].artifacts[0]"
331+
".metadata['int'] == 1.0))"
332+
)
328333
channel_to_key_map = {
329334
_TEST_CHANNEL: 'key',
330335
}
@@ -337,8 +342,9 @@ def testNegation(self):
337342
def testConcat(self):
338343
predicate = _TEST_CHANNEL.future()[0].uri + 'something' == 'test_str'
339344
expected_cel = (
340-
'((inputs.artifacts[\'key\'].artifacts[0].uri + \'something\') == '
341-
'\'test_str\')')
345+
"((inputs.artifacts['_producer.foo'].artifacts[0].uri + 'something') =="
346+
" 'test_str')"
347+
)
342348
channel_to_key_map = {
343349
_TEST_CHANNEL: 'key',
344350
}
@@ -360,14 +366,6 @@ def testUnsupportedOperator(self):
360366
ValueError, 'Got unsupported placeholder operator base64_encode_op.'):
361367
compiler_utils.placeholder_to_cel(placeholder_pb)
362368

363-
def testPlaceholderWithoutKey(self):
364-
predicate = _TEST_CHANNEL.future()[0].uri == 'test_str'
365-
placeholder_pb = predicate.encode()
366-
with self.assertRaisesRegex(
367-
ValueError,
368-
'Only supports accessing placeholders with a key on KFPv2.'):
369-
compiler_utils.placeholder_to_cel(placeholder_pb)
370-
371369

372370
if __name__ == '__main__':
373371
tf.test.main()

tfx/orchestration/kubeflow/v2/testdata/expected_dummy_consumer_with_condition_task.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ inputs {
3535
}
3636
}
3737
trigger_policy {
38-
condition: "!((inputs.artifacts['input1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
38+
condition: "!((inputs.artifacts['_producer_task_1.output1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
3939
}
4040
component_ref {
4141
name: "DummyConsumerComponent"

tfx/orchestration/kubeflow/v2/testdata/legacy/expected_dummy_consumer_with_condition_task.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ inputs {
3535
}
3636
}
3737
trigger_policy {
38-
condition: "!((inputs.artifacts['input1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
38+
condition: "!((inputs.artifacts['_producer_task_1.output1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
3939
}
4040
component_ref {
4141
name: "DummyConsumerComponent"

0 commit comments

Comments (0)