Skip to content

Commit

Permalink
Improvements acquisitions with MPI.
Browse files Browse the repository at this point in the history
Proxy the detector MPI communicator to the instrument instance.
When indexing the detectors (or the samplings) of an acquisition, change the communicator of the detectors (or the samplings) to MPI.COMM_SELF.
Fix clearing the cached operator when indexing the acquisition.
Add MPI-specific tests and execute them with mpirun in the CI.
  • Loading branch information
pchanial committed Jan 31, 2023
1 parent 7d8cca6 commit 9ce7307
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 35 deletions.
14 changes: 11 additions & 3 deletions .github/workflows/build-test-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,19 @@ jobs:
CIBW_SKIP: "*-musllinux_*"
CIBW_ARCHS: ${{ matrix.platform.arch }}
CIBW_BEFORE_TEST_LINUX: |
yum install -y openmpi-devel
MPICC=/lib64/openmpi/bin/mpicc pip install mpi4py
yum install -y openmpi-devel environment-modules
source /usr/share/Modules/init/sh
module load mpi
pip install mpi4py
CIBW_BEFORE_TEST_MACOS: brew install openmpi
CIBW_TEST_EXTRAS: dev
CIBW_TEST_COMMAND: pytest {package}/tests
CIBW_TEST_COMMAND: |
if [[ ${{ matrix.platform.os }} == ubuntu-20.04 ]]; then
source /usr/share/Modules/init/sh
module load mpi
fi
mpirun -np 6 --oversubscribe --allow-run-as-root pytest -m mpi --no-cov {package}/tests
pytest {package}/tests
PYTHONFAULTHANDLER: "1"

- name: Build macosx_arm64
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- --all

- repo: https://github.com/PyCQA/isort
rev: '5.10.1'
rev: '5.12.0'
hooks:
- id: isort
args:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ addopts = "-ra --cov=pysimulators"
testpaths = [
"tests",
]
markers = [
"mpi: mark tests to be run using mpirun.",
]

[tool.setuptools_scm]
version_scheme = "post-release"
Expand Down
38 changes: 29 additions & 9 deletions src/pysimulators/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,43 @@ def __getitem__(self, x):
Restrict to the first 10 pixels of the scene:
>>> new_acq = acq[..., :10]
"""

def is_colon(x):
return isinstance(x, slice) and x == slice(None)

out = copy(self)
if not isinstance(x, tuple):
out.instrument = self.instrument[x]
return out
if len(x) == 2 and x[0] is Ellipsis:
x = Ellipsis, Ellipsis, x[1]
if len(x) > 3:
x = (x,)
elif len(x) == 2 and x[0] is Ellipsis:
x = slice(None), slice(None), x[1]
elif len(x) > 3:
raise ValueError('Invalid selection.')
x = x + (3 - len(x)) * (Ellipsis,)
if x[2] is not Ellipsis and (
not isinstance(x[2], slice) or x[2] == slice(None)
):

x = tuple(slice(None) if _ is Ellipsis else _ for _ in x)
x = x + (3 - len(x)) * (slice(None),)

if all(is_colon(_) for _ in x):
return out

if any(not is_colon(_) for _ in x):
self._operator = None
gc.collect()

out.instrument = self.instrument[x[0]]
if not is_colon(x[0]):
object.__setattr__(out.instrument.detector, 'comm', MPI.COMM_SELF)
out.sampling = self.sampling[x[1]] # XXX FIX BLOCKS!!!
if not is_colon(x[1]):
object.__setattr__(out.sampling, 'comm', MPI.COMM_SELF)
out.scene = self.scene[x[2]]

if not is_colon(x[0]) and not is_colon(x[1]):
out.comm = MPI.COMM_SELF
elif not is_colon(x[0]):
out.comm = self.sampling.comm.Create_cart([1, self.sampling.comm.size])
elif not is_colon(x[1]):
out.comm = self.instrument.comm.Create_cart([self.instrument.comm.size, 1])

return out

def __str__(self):
Expand Down
4 changes: 4 additions & 0 deletions src/pysimulators/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def __iter__(self):
def __len__(self):
return len(self.detector)

@property
def comm(self):
return self.detector.comm

def pack(self, x):
return self.detector.pack(x)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_get_noise1(shape, fknee):
fsamp = 5
sigma = 0.3
scene = Scene(10)
sampling = Sampling(2e4, period=1 / fsamp)
sampling = Sampling(20_000, period=1 / fsamp)
np.random.seed(0)

class MyAcquisition1(Acquisition):
Expand All @@ -58,7 +58,7 @@ def test_get_noise2(shape):
fsamp = 5
sigma = 0.3
scene = Scene(10)
sampling = Sampling(2e4, period=1 / fsamp)
sampling = Sampling(20_000, period=1 / fsamp)
freq = np.arange(6) / 6 * fsamp
psd = np.array([0, 1, 1, 1, 1, 1], float) * sigma**2 / fsamp

Expand All @@ -80,7 +80,7 @@ def test_get_noise3(shape):
fsamp = 5
sigma = 0.3
scene = Scene(10)
sampling = Sampling(2e4, period=1 / fsamp)
sampling = Sampling(20_000, period=1 / fsamp)
freq = np.arange(4) / 6 * fsamp
psd = np.array([0, 2, 2, 1], float) * sigma**2 / fsamp

Expand Down
158 changes: 139 additions & 19 deletions tests/test_acquisitions_mpi.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,144 @@
from itertools import chain, product

import numpy as np
import pytest

from pyoperators import MPI
from pyoperators import MPI, MPIDistributionIdentityOperator
from pyoperators.utils.testing import assert_same
from pysimulators import Acquisition, Instrument, PackedTable, Sampling, Scene
from pysimulators.operators import ProjectionOperator
from pysimulators.sparse import FSRMatrix

pytestmark = pytest.mark.mpi

RANK = MPI.COMM_WORLD.rank
SIZE = MPI.COMM_WORLD.size
NPROCS_INSTRUMENT = sorted(
{int(n) for n in [1, SIZE / 3, SIZE / 2, SIZE] if int(n) == n}
)
NSCENE = 10
NSAMPLING_GLOBAL = 100
NDETECTOR_GLOBAL = 16
SCENE = Scene(NSCENE)
SAMPLING = Sampling(NSAMPLING_GLOBAL, period=1.0)
INSTRUMENT = Instrument('', PackedTable(NDETECTOR_GLOBAL))


class MyAcquisition(Acquisition):
def get_projection_operator(self):
dtype = [('index', int), ('value', float)]
data = np.recarray((len(self.instrument), len(self.sampling), 1), dtype=dtype)
for ilocal, iglobal in enumerate(self.instrument.detector.index):
data[ilocal].value = iglobal
data[ilocal, :, 0].index = [
(iglobal + int(t)) % NSCENE for t in self.sampling.time
]

matrix = FSRMatrix(
(len(self.instrument) * len(self.sampling), NSCENE),
data=data.reshape((-1, 1)),
)

return ProjectionOperator(
matrix, dtype=float, shapeout=(len(self.instrument), len(self.sampling))
)


def get_acquisition(comm, nprocs_instrument):
return MyAcquisition(
INSTRUMENT, SAMPLING, SCENE, comm=comm, nprocs_instrument=nprocs_instrument
)


@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
def test_communicators(nprocs_instrument):
sky = SCENE.ones()
nprocs_sampling = SIZE // nprocs_instrument
serial_acq = get_acquisition(MPI.COMM_SELF, 1)
assert serial_acq.comm.size == 1
assert serial_acq.instrument.comm.size == 1
assert serial_acq.sampling.comm.size == 1
assert len(serial_acq.instrument) == NDETECTOR_GLOBAL
assert len(serial_acq.sampling) == NSAMPLING_GLOBAL

parallel_acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
assert parallel_acq.comm.size == SIZE
assert parallel_acq.instrument.comm.size == nprocs_instrument
assert parallel_acq.sampling.comm.size == nprocs_sampling
assert (
parallel_acq.instrument.comm.allreduce(len(parallel_acq.instrument))
== NDETECTOR_GLOBAL
)
assert (
parallel_acq.sampling.comm.allreduce(len(parallel_acq.sampling))
== NSAMPLING_GLOBAL
)

rank = MPI.COMM_WORLD.rank
size = MPI.COMM_WORLD.size


def test():
scene = Scene(1024)
instrument = Instrument('instrument', PackedTable((32, 32)))
sampling = Sampling(1000)
acq = Acquisition(instrument, sampling, scene, nprocs_sampling=max(size // 2, 1))
print(
acq.comm.rank,
acq.instrument.detector.comm.rank,
'/',
acq.instrument.detector.comm.size,
acq.sampling.comm.rank,
'/',
acq.sampling.comm.size,
serial_H = serial_acq.get_projection_operator()
ref_tod = serial_H(sky)

parallel_H = (
parallel_acq.get_projection_operator()
* MPIDistributionIdentityOperator(parallel_acq.comm)
)
local_tod = parallel_H(sky)
actual_tod = np.vstack(
parallel_acq.instrument.comm.allgather(
np.hstack(parallel_acq.sampling.comm.allgather(local_tod))
)
)
pytest.xfail('the test is not finished.')
assert_same(actual_tod, ref_tod, atol=20)

ref_backproj = serial_H.T(ref_tod)
actual_backproj = parallel_H.T(local_tod)
assert_same(actual_backproj, ref_backproj, atol=20)


@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
@pytest.mark.parametrize(
'selection',
[
Ellipsis,
slice(None),
]
+ list(chain(*(product([slice(None), Ellipsis], repeat=n) for n in [1, 2, 3]))),
)
def test_communicators_getitem_all(nprocs_instrument, selection):
acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
assert acq.instrument.comm.size == nprocs_instrument
assert acq.sampling.comm.size == MPI.COMM_WORLD.size / nprocs_instrument
assert acq.comm.size == MPI.COMM_WORLD.size
restricted_acq = acq[selection]
assert restricted_acq.instrument.comm.size == nprocs_instrument
assert restricted_acq.sampling.comm.size == MPI.COMM_WORLD.size / nprocs_instrument
assert restricted_acq.comm.size == MPI.COMM_WORLD.size


@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
@pytest.mark.parametrize('selection', [0, slice(None, 1), np.array])
def test_communicators_getitem_instrument(nprocs_instrument, selection):
acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
if selection is np.array:
selection = np.zeros(len(acq.instrument), bool)
selection[0] = True
restricted_acq = acq[selection]
assert restricted_acq.instrument.comm.size == 1
assert restricted_acq.sampling.comm.size == acq.sampling.comm.size
assert restricted_acq.comm.size == acq.sampling.comm.size


SELECTION_GETITEM_SAMPLING = np.zeros(NSAMPLING_GLOBAL, bool)
SELECTION_GETITEM_SAMPLING[0] = True


@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
@pytest.mark.parametrize('selection', [0, slice(None, 1), np.array])
def test_communicators_getitem_sampling(nprocs_instrument, selection):
acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
if selection is np.array:
selection = np.zeros(len(acq.sampling), bool)
selection[0] = True
restricted_acq = acq[:, selection]
assert restricted_acq.instrument.comm.size == acq.instrument.comm.size
assert restricted_acq.sampling.comm.size == 1
assert restricted_acq.comm.size == acq.instrument.comm.size

0 comments on commit 9ce7307

Please sign in to comment.