diff --git a/tests/pytorch/common.libsonnet b/tests/pytorch/common.libsonnet index 08b2aef66..661f6e180 100644 --- a/tests/pytorch/common.libsonnet +++ b/tests/pytorch/common.libsonnet @@ -58,20 +58,6 @@ local volumes = import 'templates/volumes.libsonnet'; PyTorchTest:: PyTorchBaseTest { local config = self, - entrypoint: [ - 'bash', - '-cxue', - ||| - if [[ ! -z "$(KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS)" ]]; then - # Trim grpc:// prefix - export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" - fi - - # Run whatever is in `command` here - docker-entrypoint.sh "${@:0}" - |||, - ], - volumeMap+: { dshm: volumes.MemoryVolumeSpec { name: 'dshm', diff --git a/tests/pytorch/nightly/common.libsonnet b/tests/pytorch/nightly/common.libsonnet index 21e9a6983..cd3e679d4 100644 --- a/tests/pytorch/nightly/common.libsonnet +++ b/tests/pytorch/nightly/common.libsonnet @@ -118,7 +118,24 @@ local volumes = import 'templates/volumes.libsonnet'; }, GpuMixin:: { local config = self, - imageTag+: '_cuda_11.8', + + # TODO: merge common setup with PyTorchTpuVmMixin + entrypoint: [ + 'bash', + '-cxue', + ||| + git clone --depth=1 https://github.com/pytorch/pytorch.git + cd pytorch + git clone https://github.com/pytorch/xla.git + cd .. + + # Run whatever is in `command` here + "${@:0}" + |||, + ], + + image: 'us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla', + imageTag: 'nightly_3.8_cuda_11.8', podTemplate+:: { spec+: { @@ -128,6 +145,7 @@ local volumes = import 'templates/volumes.libsonnet'; containerMap+:: { train+: { envMap+: { + PJRT_DEVICE: 'GPU', GPU_NUM_DEVICES: '%d' % config.accelerator.count, }, }, diff --git a/tests/pytorch/nightly/mnist.libsonnet b/tests/pytorch/nightly/mnist.libsonnet index 6de92c43e..5ce9836b9 100644 --- a/tests/pytorch/nightly/mnist.libsonnet +++ b/tests/pytorch/nightly/mnist.libsonnet @@ -91,20 +91,6 @@ local utils = import 'templates/utils.libsonnet'; v4_8:: { accelerator: tpus.v4_8, }, - local gpu = self.gpu, - gpu:: common.GpuMixin { - // Disable XLA metrics report on GPU - command+: [ - '--nometrics_debug', - ], - flags+: { - modelDir: null, - }, - }, - local v100x4 = self.v100x4, - v100x4:: gpu { - accelerator: gpus.teslaV100 { count: 4 }, - }, local tpuVm = self.tpuVm, tpuVm:: common.PyTorchTpuVmMixin { @@ -135,6 +121,5 @@ local utils = import 'templates/utils.libsonnet'; mnist + convergence + v2_8 + timeouts.Hours(1) + pjrt, mnist + convergence_ddp + v2_8 + timeouts.Hours(1) + pjrt + pjrt_ddp, mnist + convergence + v4_8 + timeouts.Hours(1) + pjrt + mixins.Experimental, - mnist + convergence + v100x4 + timeouts.Hours(6) + mixins.Experimental, ], } diff --git a/tests/pytorch/nightly/resnet50-mp.libsonnet b/tests/pytorch/nightly/resnet50-mp.libsonnet index f3f184179..a0ad7db31 100644 --- a/tests/pytorch/nightly/resnet50-mp.libsonnet +++ b/tests/pytorch/nightly/resnet50-mp.libsonnet @@ -42,12 +42,6 @@ local tpus = import 'templates/tpus.libsonnet'; flags:: { modelDir: '$(MODEL_DIR)', }, - volumeMap+: { - datasets: common.datasetsVolume, - }, - - cpu: '90.0', - memory: '400Gi', }, local fake_data = self.fake_data, @@ -63,6 +57,9 @@ local tpus = import 'templates/tpus.libsonnet'; '--num_epochs=2', '--datadir=/datasets/imagenet-mini', ], + volumeMap+: { + datasets: common.datasetsVolume, + }, }, local convergence = self.convergence, convergence:: common.Convergence { @@ -91,6 +88,9 @@ local tpus = import 'templates/tpus.libsonnet'; }, }, }, + volumeMap+: { + datasets: common.datasetsVolume, + }, }, // DDP converges worse than MP. local convergence_ddp = self.convergence_ddp, @@ -201,7 +201,6 @@ local tpus = import 'templates/tpus.libsonnet'; configs: [ // XRT - resnet50 + functional + v100x4 + timeouts.Hours(1), resnet50 + functional + v3_8 + timeouts.Hours(2) + tpuVm + mixins.Experimental, resnet50 + fake_data + v3_8 + timeouts.Hours(2) + tpuVm, resnet50 + fake_data + v3_8 + timeouts.Hours(2) + tpuVm + xrt_ddp, @@ -214,6 +213,7 @@ local tpus = import 'templates/tpus.libsonnet'; resnet50 + convergence + v4_8 + timeouts.Hours(24) + tpuVm + mixins.Experimental, resnet50 + convergence + v4_32 + timeouts.Hours(24) + tpuVm + mixins.Experimental, // PJRT + resnet50 + fake_data + v100x4 + timeouts.Hours(1), resnet50 + fake_data + v3_8 + timeouts.Hours(2) + pjrt, resnet50 + convergence + v3_8 + timeouts.Hours(24) + pjrt, resnet50 + fake_data + v3_8 + timeouts.Hours(2) + pjrt + pjrt_ddp,