|
| 1 | +import argparse |
| 2 | +import pytest |
| 3 | +import torch |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +import ptychi.api as api |
| 7 | +from ptychi.api.task import PtychographyTask |
| 8 | +from ptychi.utils import get_suggested_object_size, get_default_complex_dtype |
| 9 | + |
| 10 | +import test_utils as tutils |
| 11 | + |
| 12 | + |
| 13 | +class TestMultislicePtychoAutodiffMultiGPU(tutils.BaseTester): |
| 14 | + |
| 15 | + @pytest.mark.local |
| 16 | + @tutils.BaseTester.wrap_recon_tester(name='test_multislice_ptycho_autodiff_multigpu', run_comparison=False) |
| 17 | + def test_multislice_ptycho_autodiff_multigpu(self): |
| 18 | + self.setup_ptychi(cpu_only=False, gpu_indices=(0, 1)) |
| 19 | + |
| 20 | + data, probe, pixel_size_m, positions_px = self.load_data_ptychodus( |
| 21 | + *self.get_default_input_data_file_paths('multislice_ptycho_AuNi'), |
| 22 | + subtract_position_mean=True |
| 23 | + ) |
| 24 | + wavelength_m = 1.03e-10 |
| 25 | + |
| 26 | + options = api.AutodiffPtychographyOptions() |
| 27 | + |
| 28 | + options.data_options.data = data |
| 29 | + options.data_options.wavelength_m = wavelength_m |
| 30 | + |
| 31 | + options.object_options.initial_guess = torch.ones([2, *get_suggested_object_size(positions_px, probe.shape[-2:], extra=50)], dtype=get_default_complex_dtype()) |
| 32 | + options.object_options.pixel_size_m = pixel_size_m |
| 33 | + options.object_options.slice_spacings_m = np.array([2e-5]) |
| 34 | + options.object_options.slice_spacing_options.optimizable = True |
| 35 | + options.object_options.slice_spacing_options.step_size = 1e-7 |
| 36 | + options.object_options.slice_spacing_options.optimizer = api.Optimizers.ADAM |
| 37 | + options.object_options.optimizable = True |
| 38 | + options.object_options.optimizer = api.Optimizers.ADAM |
| 39 | + options.object_options.step_size = 1e-3 |
| 40 | + |
| 41 | + options.probe_options.initial_guess = probe |
| 42 | + options.probe_options.optimizable = True |
| 43 | + options.probe_options.optimizer = api.Optimizers.ADAM |
| 44 | + options.probe_options.step_size = 1e-3 |
| 45 | + |
| 46 | + options.probe_position_options.position_x_px = positions_px[:, 1] |
| 47 | + options.probe_position_options.position_y_px = positions_px[:, 0] |
| 48 | + options.probe_position_options.optimizable = True |
| 49 | + options.probe_position_options.optimizer = api.Optimizers.SGD |
| 50 | + options.probe_position_options.step_size = 1e-1 |
| 51 | + |
| 52 | + options.reconstructor_options.forward_model_class = api.ForwardModels.PLANAR_PTYCHOGRAPHY |
| 53 | + options.reconstructor_options.loss_function = api.LossFunctions.MSE_SQRT |
| 54 | + options.reconstructor_options.use_double_precision_for_fft = True |
| 55 | + options.reconstructor_options.batch_size = 101 |
| 56 | + options.reconstructor_options.num_epochs = 32 |
| 57 | + options.reconstructor_options.default_device = api.Devices.GPU |
| 58 | + options.reconstructor_options.random_seed = 123 |
| 59 | + options.reconstructor_options.allow_nondeterministic_algorithms = False |
| 60 | + |
| 61 | + task = PtychographyTask(options) |
| 62 | + task.run() |
| 63 | + |
| 64 | + recon = task.get_data_to_cpu('object', as_numpy=True) |
| 65 | + return recon |
| 66 | + |
| 67 | + |
| 68 | +if __name__ == '__main__': |
| 69 | + parser = argparse.ArgumentParser() |
| 70 | + parser.add_argument('--generate-gold', action='store_true') |
| 71 | + args = parser.parse_args() |
| 72 | + |
| 73 | + tester = TestMultislicePtychoAutodiffMultiGPU() |
| 74 | + tester.setup_method(name="", generate_data=False, generate_gold=args.generate_gold, debug=True) |
| 75 | + tester.test_multislice_ptycho_autodiff_multigpu() |
0 commit comments