128 changes: 128 additions & 0 deletions examples/productionizing/productionizing/customizing_resources.py
@@ -85,6 +85,7 @@ def my_workflow(x: typing.List[int]) -> int:
#
# ## Using `with_overrides`
#
# ### Override resources
# You can use the `with_overrides` method to override the resources allocated to tasks dynamically.
# Let's walk through an example of how resources can be initialized and then overridden.

@@ -142,3 +143,130 @@ def my_pipeline(x: typing.List[int]) -> int:
# Resource allocated using "with_overrides" method
# :::
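#
# In code, such an override is applied at the call site inside the workflow. As a minimal sketch (the task name `my_task` and the exact values are placeholders):
#
# ```python
# my_task(x=x).with_overrides(requests=Resources(cpu="1", mem="600Mi"), limits=Resources(cpu="2", mem="1200Mi"))
# ```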
#
# ### Override `task_config`
# Here is another example of using the `with_overrides` method, this time to override the `task_config`.
# The following example uses distributed TensorFlow training to show how a `TfJob` can be initialized and then overridden.
#
# For `task_config`, refer to the {py:func}`flytekit:flytekit.task` documentation.
#
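# As a general pattern, `task_config` attaches plugin-specific settings to a task. A minimal sketch (where `SomePluginConfig` and its `num_workers` argument are placeholders for a real plugin configuration class):
#
# ```python
# @task(task_config=SomePluginConfig(num_workers=2))
# def my_plugin_task() -> None:
#     ...
# ```
#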
# First, define the necessary functions and dependencies.
# For more detail, see the [distributed TensorFlow training example](https://docs.flyte.org/projects/cookbook/en/latest/auto_examples/kftensorflow_plugin/tf_mnist.html#run-distributed-tensorflow-training).
# Here, we focus on how to override the `task_config`.
# %%
import os
from dataclasses import dataclass
from typing import NamedTuple, Tuple

from dataclasses_json import dataclass_json
from flytekit import ImageSpec, Resources, dynamic, task, workflow
from flytekit.types.directory import FlyteDirectory

custom_image = ImageSpec(
name="kftensorflow-flyte-plugin",
packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
registry="ghcr.io/flyteorg",
)

if custom_image.is_container():
import tensorflow as tf
from flytekitplugins.kftensorflow import PS, Chief, TfJob, Worker

MODEL_FILE_PATH = "saved_model/"


@dataclass_json
@dataclass
class Hyperparameters(object):
# initialize a data class to store the hyperparameters.
batch_size_per_replica: int = 64
buffer_size: int = 10000
epochs: int = 10


def load_data(
hyperparameters: Hyperparameters,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.distribute.Strategy]:
# Fetch train and evaluation datasets
...


def get_compiled_model(strategy: tf.distribute.Strategy) -> tf.keras.Model:
# compile a model
...


def decay(epoch: int):
# define a function for decaying the learning rate
...


def train_model(
model: tf.keras.Model,
train_dataset: tf.data.Dataset,
hyperparameters: Hyperparameters,
) -> Tuple[tf.keras.Model, str]:
# define the train_model function
...


def test_model(model: tf.keras.Model, checkpoint_dir: str, eval_dataset: tf.data.Dataset) -> Tuple[float, float]:
# define the test_model function to evaluate loss and accuracy on the test dataset
...


# %% [markdown]
# To create a TensorFlow task, add the {py:class}`flytekitplugins:flytekitplugins.kftensorflow.TfJob` config to the Flyte task. This plugin runs distributed TensorFlow training on Kubernetes.
# %%
training_outputs = NamedTuple("TrainingOutputs", accuracy=float, loss=float, model_state=FlyteDirectory)

# Use reduced resources when running in the Flyte sandbox; otherwise, request GPUs for distributed training.
if os.getenv("SANDBOX"):
    resources = Resources(gpu="0", mem="1000Mi", storage="500Mi", ephemeral_storage="500Mi")
else:
    resources = Resources(gpu="2", mem="10Gi", storage="10Gi", ephemeral_storage="500Mi")


@task(
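    # Launch one chief, one parameter server (ps), and one worker replica for this TfJob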
task_config=TfJob(worker=Worker(replicas=1), ps=PS(replicas=1), chief=Chief(replicas=1)),
retries=2,
cache=True,
cache_version="2.2",
requests=resources,
limits=resources,
container_image=custom_image,
)
def mnist_tensorflow_job(hyperparameters: Hyperparameters) -> training_outputs:
train_dataset, eval_dataset, strategy = load_data(hyperparameters=hyperparameters)
model = get_compiled_model(strategy=strategy)
model, checkpoint_dir = train_model(model=model, train_dataset=train_dataset, hyperparameters=hyperparameters)
eval_loss, eval_accuracy = test_model(model=model, checkpoint_dir=checkpoint_dir, eval_dataset=eval_dataset)
return training_outputs(accuracy=eval_accuracy, loss=eval_loss, model_state=MODEL_FILE_PATH)


# %% [markdown]
# You can use `@dynamic` to generate tasks at runtime with any custom configuration you want, and the `with_overrides` method replaces the configuration the task was originally defined with.
# Here, we override the worker replica count.
# %%
@workflow
def mnist_tensorflow_workflow(
hyperparameters: Hyperparameters = Hyperparameters(batch_size_per_replica=64),
) -> training_outputs:
return mnist_tensorflow_job(hyperparameters=hyperparameters)


@dynamic
def dynamic_run(
new_worker: int,
hyperparameters: Hyperparameters = Hyperparameters(batch_size_per_replica=64),
) -> training_outputs:
return mnist_tensorflow_job(hyperparameters=hyperparameters).with_overrides(
task_config=TfJob(worker=Worker(replicas=new_worker), ps=PS(replicas=1), chief=Chief(replicas=1))
)


# %% [markdown]
# You can execute the workflow locally.
# %%
if __name__ == "__main__":
print(mnist_tensorflow_workflow())
print(dynamic_run(new_worker=4))