Blueprint Proposal: Distributed Training on AWS Trainium with Flyte #158

@samhita-alla

Description

Community Note

  • Please vote on this issue by adding a 👍 reaction to the original issue to help the community and maintainers prioritize this request.
  • Please do not leave "+1" or other comments that do not add relevant new information or questions; they generate extra noise for issue followers and do not help prioritize the request.
  • If you are interested in working on this issue or have submitted a pull request, please leave a comment.

What is the outcome that you are trying to reach?

Distributed model training becomes especially challenging when the entire pipeline must be built end-to-end, not just the training script.

A typical setup often involves:

  • Data ingestion and preprocessing for large datasets
  • A distributed training execution layer
  • Experiment tracking and observability
  • Proper containerization for environment isolation
  • Orchestration with Kubernetes (EKS)
  • Durability, fault tolerance, caching, and reproducibility

While each of these components can be set up independently, stitching them into a single, reliable, reproducible pipeline takes significant engineering effort. Existing examples focus narrowly on the training step, leaving orchestration, retries, caching, and multi-step workflows for teams to solve on their own.
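To make the durability concerns above concrete, here is a minimal, framework-free Python sketch of two of the semantics an orchestrator is expected to supply declaratively: input-keyed result caching and bounded retries. This is a toy illustration of the behavior, not Flyte's API; all names here (`task`, `preprocess`) are hypothetical.

```python
import functools

# Toy sketch (not Flyte's API) of two pipeline semantics an orchestrator
# provides declaratively: result caching keyed on inputs, and bounded
# retries on transient failures.

_cache: dict = {}

def task(cache: bool = False, retries: int = 0):
    """Decorator mimicking a cached, retryable pipeline step."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            key = (fn.__name__, args)
            if cache and key in _cache:
                return _cache[key]          # cache hit: skip recomputation
            last_err = None
            for _attempt in range(retries + 1):
                try:
                    result = fn(*args)
                    if cache:
                        _cache[key] = result
                    return result
                except Exception as err:    # retry on failure, up to the budget
                    last_err = err
            raise last_err
        return wrapper
    return decorator

calls = {"n": 0}

@task(cache=True, retries=2)
def preprocess(shard: str) -> str:
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("transient I/O error")  # first attempt fails
    return shard.upper()

print(preprocess("shard-0"))  # retried once, then succeeds: SHARD-0
print(preprocess("shard-0"))  # served from cache, no recomputation
print(calls["n"])             # the body ran exactly twice
```

With an orchestrator like Flyte, the same guarantees are configured on the task definition rather than hand-rolled, which is exactly the engineering effort the blueprint aims to remove.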

Describe the solution you would like

I propose a fully open-source blueprint for distributed training on AWS Trainium that runs on EKS, with Flyte as the orchestration layer. For the model, I can either fine-tune BERT on an Arabic reviews dataset or pre-train BERT on the FineWeb dataset. Happy to proceed with whichever is preferred.

This blueprint would go beyond just running a training job. It would demonstrate how to manage the entire lifecycle in a reproducible and scalable way, including:

  • Multi-node distributed training on Trainium
  • Data streaming and preprocessing for large datasets
  • Integration with PyTorch Neuron/Optimum Neuron for Trainium acceleration
  • Built-in experiment tracking and observability
  • Automatic caching, retries, and fault tolerance
  • Integration with AWS-native services like S3 (storage), CloudWatch (logs), and ECR (images)
  • Fully open-source stack: PyTorch, streaming library, Neuron SDK, Flyte, Kubeflow Training Operator, Kubernetes/EKS
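As a rough sketch of what the Kubernetes side of the Trainium integration looks like, the fragment below shows a worker pod requesting Neuron devices on a trn1 node group. It assumes the AWS Neuron device plugin is installed on the cluster (it registers `aws.amazon.com/neuron` as an extended resource); the pod name, image path, and device count are illustrative placeholders, not part of the proposal.

```yaml
# Illustrative worker pod spec for a Trainium (trn1) node group.
# Assumes the Neuron device plugin is installed, which exposes the
# aws.amazon.com/neuron extended resource to the scheduler.
apiVersion: v1
kind: Pod
metadata:
  name: bert-train-worker            # hypothetical name
spec:
  nodeSelector:
    node.kubernetes.io/instance-type: trn1.32xlarge
  containers:
    - name: trainer
      image: <account>.dkr.ecr.<region>.amazonaws.com/bert-neuron:latest  # ECR image placeholder
      resources:
        limits:
          aws.amazon.com/neuron: 16  # illustrative device count
```

In the blueprint itself, a pod template along these lines would be generated by the orchestration layer (Flyte plus the Kubeflow Training Operator) rather than written by hand for each worker.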

By taking this orchestration-first approach, we give AI practitioners a ready-to-use blueprint that is scalable, reproducible, and easy to adapt, reducing operational burden while keeping the infrastructure open source and AWS-native.

Describe alternatives you have considered

Additional context
