diff --git a/dags/common/vm_resource.py b/dags/common/vm_resource.py index 64b894bf..94a162f9 100644 --- a/dags/common/vm_resource.py +++ b/dags/common/vm_resource.py @@ -333,6 +333,10 @@ class DockerImage(enum.Enum): "gcr.io/tpu-prod-env-multipod/maxtext_gpu_jax_stable_stack:" f"{datetime.datetime.today().strftime('%Y-%m-%d')}" ) + MAXTEXT_GPU_JAX_STABLE = ( + "gcr.io/tpu-prod-env-multipod/maxtext_gpu_jax_stable:" + f"{datetime.datetime.today().strftime('%Y-%m-%d')}" + ) MAXTEXT_GPU_STABLE_STACK_NIGHTLY_JAX = ( "gcr.io/tpu-prod-env-multipod/maxtext_gpu_stable_stack_nightly_jax:" f"{datetime.datetime.today().strftime('%Y-%m-%d')}" diff --git a/dags/multipod/maxtext_gpu_end_to_end.py b/dags/multipod/maxtext_gpu_end_to_end.py index 0dc6f325..f8683ee0 100644 --- a/dags/multipod/maxtext_gpu_end_to_end.py +++ b/dags/multipod/maxtext_gpu_end_to_end.py @@ -210,7 +210,7 @@ def run_maxtext_tests(dag: models.DAG): run_model_cmds=(test_script,), num_slices=nnodes, cluster=XpkClusters.GPU_A3_CLUSTER, - docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value, + docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE.value, test_owner=test_owner.YUWEI_Y, ).run_with_quarantine(quarantine_task_group) pinned_a3plus_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config( @@ -228,7 +228,7 @@ def run_maxtext_tests(dag: models.DAG): run_model_cmds=(test_script,), num_slices=nnodes, cluster=XpkClusters.GPU_A3PLUS_CLUSTER, - docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value, + docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE.value, test_owner=test_owner.YUWEI_Y, ).run_with_quarantine(quarantine_task_group) pinned_a3_gpu >> stable_a3_gpu >> pinned_a3plus_gpu >> stable_a3plus_gpu