Skip to content

Commit a327156

Browse files
authored
Merge pull request #181 from B612-Asteroid-Institute/kk/variant_fix
Kk/variant fix
2 parents a2bb925 + 3fe87e7 commit a327156

File tree

5 files changed

+62
-16
lines changed

5 files changed

+62
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [
1010
]
1111
description = "Core libraries for the ADAM platform"
1212
readme = "README.md"
13-
requires-python = ">=3.11"
13+
requires-python = ">=3.11,<3.14"
1414
classifiers = [
1515
"Operating System :: OS Independent",
1616
"Development Status :: 4 - Beta",

src/adam_core/dynamics/tests/test_propagation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ def test_propagate_2body_preserves_physical_parameters():
500500
np.testing.assert_allclose(have_H, expected_H)
501501
np.testing.assert_allclose(have_G, expected_G)
502502

503+
503504
@pytest.mark.profile
504505
def test_profile_propagate_2body_matrix(propagated_orbits, tmp_path):
505506
"""Profile the propagate_2body function with different combinations of orbits and times.

src/adam_core/orbits/tests/test_variants.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def test_VariantOrbits_collapse_by_object_id():
135135
orbit_id=["obj1", "obj1"],
136136
object_id=["obj1", "obj1"],
137137
variant_id=["0", "1"],
138-
physical_parameters=PhysicalParameters.from_kwargs(H_v=[15.0, 15.0], G=[0.15, 0.15]),
138+
physical_parameters=PhysicalParameters.from_kwargs(
139+
H_v=[15.0, 15.0], G=[0.15, 0.15]
140+
),
139141
coordinates=CartesianCoordinates.from_kwargs(
140142
x=[1.0, 1.1],
141143
y=[1.0, 1.1],
@@ -156,7 +158,9 @@ def test_VariantOrbits_collapse_by_object_id():
156158
orbit_id=["obj1", "obj1"],
157159
object_id=["obj1", "obj1"],
158160
variant_id=["0", "1"],
159-
physical_parameters=PhysicalParameters.from_kwargs(H_v=[15.0, 15.0], G=[0.15, 0.15]),
161+
physical_parameters=PhysicalParameters.from_kwargs(
162+
H_v=[15.0, 15.0], G=[0.15, 0.15]
163+
),
160164
coordinates=CartesianCoordinates.from_kwargs(
161165
x=[1.0, 1.1],
162166
y=[1.0, 1.1],

src/adam_core/propagator/propagator.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def propagate_orbits(
798798
chunk_size: int = 100,
799799
max_processes: Optional[int] = 1,
800800
seed: Optional[int] = None,
801-
) -> Orbits:
801+
) -> Union[Orbits, VariantOrbits]:
802802
"""
803803
Propagate each orbit in orbits to each time in times.
804804
@@ -830,15 +830,21 @@ def propagate_orbits(
830830
831831
Returns
832832
-------
833-
propagated : `~adam_core.orbits.orbits.Orbits`
833+
propagated : `~adam_core.orbits.orbits.Orbits` or `~adam_core.orbits.variants.VariantOrbits`
834834
Propagated orbits.
835835
"""
836+
if covariance is True and isinstance(orbits, VariantOrbits):
837+
raise AssertionError("Covariance is not supported for VariantOrbits")
838+
836839
if max_processes is None:
837840
max_processes = mp.cpu_count()
838841

839842
if max_processes > 1:
840843
propagated_list: List[Orbits] = []
841-
variants_list: List[VariantOrbits] = []
844+
covariance_variants_list: List[VariantOrbits] = []
845+
# When the input is VariantOrbits, do not treat them as covariance.
846+
propagated_variants_input_list: List[VariantOrbits] = []
847+
input_is_variants: Optional[bool] = None
842848

843849
if RAY_INSTALLED is False:
844850
raise ImportError(
@@ -856,13 +862,18 @@ def propagate_orbits(
856862
times = ray.get(times_ref)
857863

858864
if not isinstance(orbits, ObjectRef):
865+
input_is_variants = isinstance(orbits, VariantOrbits)
859866
orbits_ref = ray.put(orbits)
860867
else:
861868
orbits_ref = orbits
862869
# We need to dereference the orbits ObjectRef so we can
863870
# check its length for chunking and determine
864871
# if we need to propagate variants
865872
orbits = ray.get(orbits_ref)
873+
input_is_variants = isinstance(orbits, VariantOrbits)
874+
875+
if covariance is True and input_is_variants:
876+
raise AssertionError("Covariance is not supported for VariantOrbits")
866877

867878
# Create futures inputs
868879
futures_inputs = []
@@ -910,7 +921,10 @@ def propagate_orbits(
910921
if isinstance(result, Orbits):
911922
propagated_list.append(result)
912923
elif isinstance(result, VariantOrbits):
913-
variants_list.append(result)
924+
if input_is_variants:
925+
propagated_variants_input_list.append(result)
926+
else:
927+
covariance_variants_list.append(result)
914928
else:
915929
raise ValueError(
916930
f"Unexpected result type from propagation worker: {type(result)}"
@@ -923,22 +937,33 @@ def propagate_orbits(
923937
if isinstance(result, Orbits):
924938
propagated_list.append(result)
925939
elif isinstance(result, VariantOrbits):
926-
variants_list.append(result)
940+
if input_is_variants:
941+
propagated_variants_input_list.append(result)
942+
else:
943+
covariance_variants_list.append(result)
927944
else:
928945
raise ValueError(
929946
f"Unexpected result type from propagation worker: {type(result)}"
930947
)
931948

932949
# Concatenate propagated orbits
933-
propagated = qv.concatenate(propagated_list)
934-
if len(variants_list) > 0:
935-
propagated_variants = qv.concatenate(variants_list)
936-
# sort by variant_id and time
937-
propagated_variants = propagated_variants.sort_by(
938-
["variant_id", "coordinates.time.days", "coordinates.time.nanos"]
939-
)
940-
else:
950+
if input_is_variants:
951+
propagated = qv.concatenate(propagated_variants_input_list)
941952
propagated_variants = None
953+
else:
954+
propagated = qv.concatenate(propagated_list)
955+
if len(covariance_variants_list) > 0:
956+
propagated_variants = qv.concatenate(covariance_variants_list)
957+
# sort by variant_id and time
958+
propagated_variants = propagated_variants.sort_by(
959+
[
960+
"variant_id",
961+
"coordinates.time.days",
962+
"coordinates.time.nanos",
963+
]
964+
)
965+
else:
966+
propagated_variants = None
942967

943968
else:
944969
propagated = self._propagate_orbits(orbits, times)

src/adam_core/propagator/tests/test_propagator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,22 @@ def test_propagator_multiple_workers_ray():
233233
assert len(have) == len(orbits) * len(times)
234234

235235

236+
@pytest.mark.skipif(not RAY_INSTALLED, reason="Ray not installed")
237+
def test_propagate_orbits_multiple_workers_ray_variant_orbits_input():
238+
"""
239+
Regression test: VariantOrbits should be supported as an input to propagate_orbits
240+
under the ray parallel dispatcher.
241+
"""
242+
base = make_real_orbits(4)
243+
variants = VariantOrbits.create(base, method="sigma-point")
244+
times = Timestamp.from_iso8601(["2020-01-01T00:00:00", "2020-01-01T00:00:01"])
245+
246+
prop = MockPropagator()
247+
have = prop.propagate_orbits(variants, times, max_processes=2)
248+
assert isinstance(have, VariantOrbits)
249+
assert len(have) == len(variants) * len(times)
250+
251+
236252
def test_propagate_different_origins():
237253
"""
238254
Test that we are returning propagated orbits with their original origins

0 commit comments

Comments
 (0)