|
9 | 9 | import pytest |
10 | 10 | from flyteidl.core import workflow_pb2 as _core_workflow |
11 | 11 |
|
12 | | -from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask, Resources |
| 12 | +from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask |
13 | 13 | from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings |
14 | 14 | from flytekit.core import context_manager |
15 | 15 | from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver |
|
21 | 21 | LiteralMap, |
22 | 22 | LiteralOffloadedMetadata, |
23 | 23 | ) |
24 | | -from flytekit.models.task import Resources as _resources_models |
25 | 24 | from flytekit.tools.translator import get_serializable |
26 | 25 | from flytekit.types.directory import FlyteDirectory |
27 | 26 |
|
@@ -350,59 +349,16 @@ def my_wf1() -> typing.List[typing.Optional[int]]: |
350 | 349 | assert my_wf1() == [1, None, 3, 4] |
351 | 350 |
|
352 | 351 |
|
353 | | -@task |
354 | | -def my_mappable_task(a: int) -> typing.Optional[str]: |
355 | | - return str(a) |
356 | | - |
357 | | - |
358 | | -@task( |
359 | | - container_image="original-image", |
360 | | - timeout=timedelta(seconds=10), |
361 | | - interruptible=False, |
362 | | - retries=10, |
363 | | - cache=True, |
364 | | - cache_version="original-version", |
365 | | - requests=Resources(cpu=1) |
366 | | -) |
367 | | -def my_mappable_task_1(a: int) -> typing.Optional[str]: |
368 | | - return str(a) |
369 | | - |
370 | | - |
371 | | -@pytest.mark.parametrize( |
372 | | - "task_func", |
373 | | - [my_mappable_task, my_mappable_task_1] |
374 | | -) |
375 | | -def test_map_task_override(serialization_settings, task_func): |
376 | | - array_node_map_task = map_task(task_func) |
| 352 | +def test_map_task_override(serialization_settings): |
| 353 | + @task |
| 354 | + def my_mappable_task(a: int) -> typing.Optional[str]: |
| 355 | + return str(a) |
377 | 356 |
|
378 | 357 | @workflow |
379 | 358 | def wf(x: typing.List[int]): |
380 | | - array_node_map_task(a=x).with_overrides( |
381 | | - container_image="new-image", |
382 | | - timeout=timedelta(seconds=20), |
383 | | - interruptible=True, |
384 | | - retries=5, |
385 | | - cache=True, |
386 | | - cache_version="new-version", |
387 | | - requests=Resources(cpu=2) |
388 | | - ) |
389 | | - |
390 | | - assert wf.nodes[0]._container_image == "new-image" |
391 | | - |
392 | | - od = OrderedDict() |
393 | | - wf_spec = get_serializable(od, serialization_settings, wf) |
| 359 | + map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image") |
394 | 360 |
|
395 | | - array_node = wf_spec.template.nodes[0] |
396 | | - assert array_node.metadata.timeout == timedelta() |
397 | | - sub_node_spec = array_node.array_node.node |
398 | | - assert sub_node_spec.metadata.timeout == timedelta(seconds=20) |
399 | | - assert sub_node_spec.metadata.interruptible |
400 | | - assert sub_node_spec.metadata.retries.retries == 5 |
401 | | - assert sub_node_spec.metadata.cacheable |
402 | | - assert sub_node_spec.metadata.cache_version == "new-version" |
403 | | - assert sub_node_spec.target.overrides.resources.requests == [ |
404 | | - _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2") |
405 | | - ] |
| 361 | + assert wf.nodes[0]._container_image == "random:image" |
406 | 362 |
|
407 | 363 |
|
408 | 364 | def test_serialization_metadata(serialization_settings): |
|
0 commit comments