Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions examples/plugins/ray_existing_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# For a local test, you can run the following commands:s
# Run the following command to start the Ray cluster:
# `ray start --head --port=6379 --dashboard-host=0.0.0.0`
# `export RAY_CLUSTER_ADDRESS=ray://127.0.0.1:10001`
# Then run flyte locally with:
# `flyte run --local ray_existing_example.py hello_ray_nested`

import asyncio
import typing
import os

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte.remote
import flyte.storage


@ray.remote
def f(x):
return x * x

image = (
flyte.Image.from_debian_base(name="ray")
.with_apt_packages("wget")
.with_pip_packages("ray[default]==2.49.0", "flyteplugins-ray", "pip")
)

task_env = flyte.TaskEnvironment(
name="hello_ray", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)


@task_env.task()
async def hello_ray():
await asyncio.sleep(20)
print("Hello from the Ray task!")


@task_env.task
async def hello_ray_nested(n: int = 3, cluster_address: str = "ray://localhost:10001") -> typing.List[int]:
# Get cluster address from environment variable
ray.init(address=cluster_address)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it should just work

print("running ray task")
t = asyncio.create_task(hello_ray())
futures = [f.remote(i) for i in range(n)]
res = ray.get(futures)
await t
return res


if __name__ == "__main__":
flyte.init_from_config("../../config.yaml")
run = flyte.run(hello_ray_nested)
print("run name:", run.name)
print("run url:", run.url)
run.wait(run)

action_details = flyte.remote.ActionDetails.get(run_name=run.name, name="a0")
for log in action_details.pb2.attempts[-1].log_info:
print(f"{log.name}: {log.uri}")
Loading