Doc: Document how to enable distributed error aggregation according to RFC #5598 for pytorch distributed tasks #1776

31 changes: 24 additions & 7 deletions examples/kfpytorch_plugin/kfpytorch_plugin/pytorch_mnist.py
@@ -350,6 +350,14 @@ def pytorch_training_wf(
# To visualize the outcomes, you can point Tensorboard on your local machine to these storage locations.
# :::
#
# :::{note}
[Review comment, Member Author] Bit of scope creep: I'm moving this section up before the "pytorch elastic" section, as this affects only `task_config=PyTorch` tasks. Tasks using `task_config=Elastic` do this by default here.

# In distributed training, be aware that the return values of the different worker pods can differ.
# If you need to control which worker's return value is passed on to subsequent tasks in the workflow,
# you can raise an
# [IgnoreOutputs exception](https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.core.base_task.IgnoreOutputs.html)
# for all remaining ranks.
# :::
#
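# A minimal, illustrative sketch of this pattern (not part of this example; it assumes
# `torch.distributed` has already been initialized inside the task body): every rank except
# rank 0 raises `IgnoreOutputs`, so only rank 0's return value is passed downstream.
#
# ```python
# import torch.distributed as dist
# from flytekit.core.base_task import IgnoreOutputs
#
# def return_only_from_rank_zero(result: dict) -> dict:
#     # Call this at the end of the task body, once training has finished.
#     if dist.get_rank() != 0:
#         # All other ranks skip output handling; only rank 0's value is used downstream.
#         raise IgnoreOutputs("not rank 0, skipping outputs")
#     return result
# ```
#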
# ## Pytorch elastic training (torchrun)
#
# Flyte supports distributed training through [torch elastic](https://pytorch.org/docs/stable/elastic/run.html) using `torchrun`.
@@ -388,10 +396,19 @@ def pytorch_training_wf(
#
# This configuration runs distributed training on two nodes, each with four worker processes.
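#
# A minimal sketch of such a configuration (illustrative only; it assumes the `Elastic` task
# config from `flytekitplugins.kfpytorch`, and argument names may differ between plugin versions):
#
# ```python
# from flytekit import task
# from flytekitplugins.kfpytorch import Elastic
#
# @task(
#     task_config=Elastic(
#         nnodes=2,          # two nodes
#         nproc_per_node=4,  # four worker processes per node
#     )
# )
# def train() -> None:
#     ...
# ```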
#
# :::{note}
# In the context of distributed training, it's important to acknowledge that return values from various workers could potentially vary.
# If you need to regulate which worker's return value gets passed on to subsequent tasks in the workflow,
# you have the option to raise an
# [IgnoreOutputs exception](https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.core.base_task.IgnoreOutputs.html)
# for all remaining ranks.
# :::
# ## Error handling for distributed PyTorch tasks
#
# Exceptions occurring in Flyte task pods are propagated to the Flyte backend by writing so-called *error files* into
# a preconfigured location in blob storage. In the case of PyTorch distributed tasks, each failed worker pod tries to write such
# an error file. By default, the backend expects and evaluates only a single error file, which leads to a race condition:
# it is not deterministic which worker pod's error file is considered. Flyte can instead aggregate the error files of all worker pods
# and use the timestamps of the exceptions to try to determine the root-cause error. To enable this behavior, add the following to your
# Helm chart values:
#
# ```yaml
# configmap:
#   k8s:
#     plugins:
#       k8s:
#         enable-distributed-error-aggregation: true
# ```
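#
# As an illustration (a hypothetical task, not part of this example; the exact `PyTorch` task
# config arguments may differ between plugin versions): if one worker pod of the task below fails,
# each failed pod writes its own error file, and with `enable-distributed-error-aggregation` set
# the backend aggregates them and uses the exception timestamps to try to determine the root cause,
# instead of racing for a single error file.
#
# ```python
# import os
#
# from flytekit import task
# from flytekitplugins.kfpytorch import PyTorch
#
# @task(task_config=PyTorch(num_workers=2))
# def flaky_distributed_training() -> None:
#     # RANK is set on the worker pods by the training operator (assumption for this sketch).
#     if os.environ.get("RANK") == "1":
#         # This exception is written to rank 1's error file and aggregated with the
#         # error files of any other failing workers.
#         raise RuntimeError("simulated failure on rank 1")
# ```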