Skip to content

Commit 653023d

Browse files
committed
TEST: add multi-GPU tester
1 parent a1c3ea7 commit 653023d

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import argparse
2+
import pytest
3+
import torch
4+
import numpy as np
5+
6+
import ptychi.api as api
7+
from ptychi.api.task import PtychographyTask
8+
from ptychi.utils import get_suggested_object_size, get_default_complex_dtype
9+
10+
import test_utils as tutils
11+
12+
13+
class TestMultislicePtychoAutodiffMultiGPU(tutils.BaseTester):
14+
15+
@pytest.mark.local
16+
@tutils.BaseTester.wrap_recon_tester(name='test_multislice_ptycho_autodiff_multigpu', run_comparison=False)
17+
def test_multislice_ptycho_autodiff_multigpu(self):
18+
self.setup_ptychi(cpu_only=False, gpu_indices=(0, 1))
19+
20+
data, probe, pixel_size_m, positions_px = self.load_data_ptychodus(
21+
*self.get_default_input_data_file_paths('multislice_ptycho_AuNi'),
22+
subtract_position_mean=True
23+
)
24+
wavelength_m = 1.03e-10
25+
26+
options = api.AutodiffPtychographyOptions()
27+
28+
options.data_options.data = data
29+
options.data_options.wavelength_m = wavelength_m
30+
31+
options.object_options.initial_guess = torch.ones([2, *get_suggested_object_size(positions_px, probe.shape[-2:], extra=50)], dtype=get_default_complex_dtype())
32+
options.object_options.pixel_size_m = pixel_size_m
33+
options.object_options.slice_spacings_m = np.array([2e-5])
34+
options.object_options.slice_spacing_options.optimizable = True
35+
options.object_options.slice_spacing_options.step_size = 1e-7
36+
options.object_options.slice_spacing_options.optimizer = api.Optimizers.ADAM
37+
options.object_options.optimizable = True
38+
options.object_options.optimizer = api.Optimizers.ADAM
39+
options.object_options.step_size = 1e-3
40+
41+
options.probe_options.initial_guess = probe
42+
options.probe_options.optimizable = True
43+
options.probe_options.optimizer = api.Optimizers.ADAM
44+
options.probe_options.step_size = 1e-3
45+
46+
options.probe_position_options.position_x_px = positions_px[:, 1]
47+
options.probe_position_options.position_y_px = positions_px[:, 0]
48+
options.probe_position_options.optimizable = True
49+
options.probe_position_options.optimizer = api.Optimizers.SGD
50+
options.probe_position_options.step_size = 1e-1
51+
52+
options.reconstructor_options.forward_model_class = api.ForwardModels.PLANAR_PTYCHOGRAPHY
53+
options.reconstructor_options.loss_function = api.LossFunctions.MSE_SQRT
54+
options.reconstructor_options.use_double_precision_for_fft = True
55+
options.reconstructor_options.batch_size = 101
56+
options.reconstructor_options.num_epochs = 32
57+
options.reconstructor_options.default_device = api.Devices.GPU
58+
options.reconstructor_options.random_seed = 123
59+
options.reconstructor_options.allow_nondeterministic_algorithms = False
60+
61+
task = PtychographyTask(options)
62+
task.run()
63+
64+
recon = task.get_data_to_cpu('object', as_numpy=True)
65+
return recon
66+
67+
68+
if __name__ == '__main__':
69+
parser = argparse.ArgumentParser()
70+
parser.add_argument('--generate-gold', action='store_true')
71+
args = parser.parse_args()
72+
73+
tester = TestMultislicePtychoAutodiffMultiGPU()
74+
tester.setup_method(name="", generate_data=False, generate_gold=args.generate_gold, debug=True)
75+
tester.test_multislice_ptycho_autodiff_multigpu()

0 commit comments

Comments
 (0)