@@ -798,7 +798,7 @@ def propagate_orbits(
798798 chunk_size : int = 100 ,
799799 max_processes : Optional [int ] = 1 ,
800800 seed : Optional [int ] = None ,
801- ) -> Orbits :
801+ ) -> Union [ Orbits , VariantOrbits ] :
802802 """
803803 Propagate each orbit in orbits to each time in times.
804804
@@ -830,15 +830,21 @@ def propagate_orbits(
830830
831831 Returns
832832 -------
833- propagated : `~adam_core.orbits.orbits.Orbits`
833+ propagated : `~adam_core.orbits.orbits.Orbits` or `~adam_core.orbits.variants.VariantOrbits`
834834 Propagated orbits.
835835 """
836+ if covariance is True and isinstance (orbits , VariantOrbits ):
837+ raise AssertionError ("Covariance is not supported for VariantOrbits" )
838+
836839 if max_processes is None :
837840 max_processes = mp .cpu_count ()
838841
839842 if max_processes > 1 :
840843 propagated_list : List [Orbits ] = []
841- variants_list : List [VariantOrbits ] = []
844+ covariance_variants_list : List [VariantOrbits ] = []
845+ # When the input is VariantOrbits, do not treat them as covariance.
846+ propagated_variants_input_list : List [VariantOrbits ] = []
847+ input_is_variants : Optional [bool ] = None
842848
843849 if RAY_INSTALLED is False :
844850 raise ImportError (
@@ -856,13 +862,18 @@ def propagate_orbits(
856862 times = ray .get (times_ref )
857863
858864 if not isinstance (orbits , ObjectRef ):
865+ input_is_variants = isinstance (orbits , VariantOrbits )
859866 orbits_ref = ray .put (orbits )
860867 else :
861868 orbits_ref = orbits
862869 # We need to dereference the orbits ObjectRef so we can
863870 # check its length for chunking and determine
864871 # if we need to propagate variants
865872 orbits = ray .get (orbits_ref )
873+ input_is_variants = isinstance (orbits , VariantOrbits )
874+
875+ if covariance is True and input_is_variants :
876+ raise AssertionError ("Covariance is not supported for VariantOrbits" )
866877
867878 # Create futures inputs
868879 futures_inputs = []
@@ -910,7 +921,10 @@ def propagate_orbits(
910921 if isinstance (result , Orbits ):
911922 propagated_list .append (result )
912923 elif isinstance (result , VariantOrbits ):
913- variants_list .append (result )
924+ if input_is_variants :
925+ propagated_variants_input_list .append (result )
926+ else :
927+ covariance_variants_list .append (result )
914928 else :
915929 raise ValueError (
916930 f"Unexpected result type from propagation worker: { type (result )} "
@@ -923,22 +937,33 @@ def propagate_orbits(
923937 if isinstance (result , Orbits ):
924938 propagated_list .append (result )
925939 elif isinstance (result , VariantOrbits ):
926- variants_list .append (result )
940+ if input_is_variants :
941+ propagated_variants_input_list .append (result )
942+ else :
943+ covariance_variants_list .append (result )
927944 else :
928945 raise ValueError (
929946 f"Unexpected result type from propagation worker: { type (result )} "
930947 )
931948
932949 # Concatenate propagated orbits
933- propagated = qv .concatenate (propagated_list )
934- if len (variants_list ) > 0 :
935- propagated_variants = qv .concatenate (variants_list )
936- # sort by variant_id and time
937- propagated_variants = propagated_variants .sort_by (
938- ["variant_id" , "coordinates.time.days" , "coordinates.time.nanos" ]
939- )
940- else :
950+ if input_is_variants :
951+ propagated = qv .concatenate (propagated_variants_input_list )
941952 propagated_variants = None
953+ else :
954+ propagated = qv .concatenate (propagated_list )
955+ if len (covariance_variants_list ) > 0 :
956+ propagated_variants = qv .concatenate (covariance_variants_list )
957+ # sort by variant_id and time
958+ propagated_variants = propagated_variants .sort_by (
959+ [
960+ "variant_id" ,
961+ "coordinates.time.days" ,
962+ "coordinates.time.nanos" ,
963+ ]
964+ )
965+ else :
966+ propagated_variants = None
942967
943968 else :
944969 propagated = self ._propagate_orbits (orbits , times )
0 commit comments