88from gymnasium import spaces
99from torch import nn
1010from torch .distributions import Bernoulli , Categorical , Normal
11+ from torch .distributions import Distribution as TorchDistribution
1112
1213from stable_baselines3 .common .preprocessing import get_action_dim
1314
2526class Distribution (ABC ):
2627 """Abstract base class for distributions."""
2728
29+ distribution : TorchDistribution | list [TorchDistribution ]
30+
2831 def __init__ (self ):
2932 super ().__init__ ()
30- self .distribution = None
3133
3234 @abstractmethod
3335 def proba_distribution_net (self , * args , ** kwargs ) -> nn .Module | tuple [nn .Module , nn .Parameter ]:
@@ -44,11 +46,11 @@ def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribut
4446 """
4547
4648 @abstractmethod
47- def log_prob (self , x : th .Tensor ) -> th .Tensor :
49+ def log_prob (self , actions : th .Tensor ) -> th .Tensor :
4850 """
4951 Returns the log likelihood
5052
51- :param x : the taken action
53+ :param actions : the taken action
5254 :return: The log likelihood of the distribution
5355 """
5456
@@ -129,11 +131,11 @@ class DiagGaussianDistribution(Distribution):
129131 :param action_dim: Dimension of the action space.
130132 """
131133
134+ distribution : Normal
135+
132136 def __init__ (self , action_dim : int ):
133137 super ().__init__ ()
134138 self .action_dim = action_dim
135- self .mean_actions = None
136- self .log_std = None
137139
138140 def proba_distribution_net (self , latent_dim : int , log_std_init : float = 0.0 ) -> tuple [nn .Module , nn .Parameter ]:
139141 """
@@ -267,6 +269,8 @@ class CategoricalDistribution(Distribution):
267269 :param action_dim: Number of discrete actions
268270 """
269271
272+ distribution : Categorical
273+
270274 def __init__ (self , action_dim : int ):
271275 super ().__init__ ()
272276 self .action_dim = action_dim
@@ -318,6 +322,8 @@ class MultiCategoricalDistribution(Distribution):
318322 :param action_dims: List of sizes of discrete action spaces
319323 """
320324
325+ distribution : list [Categorical ] # type: ignore[assignment]
326+
321327 def __init__ (self , action_dims : list [int ]):
322328 super ().__init__ ()
323329 self .action_dims = action_dims
@@ -375,6 +381,8 @@ class BernoulliDistribution(Distribution):
375381 :param action_dim: Number of binary actions
376382 """
377383
384+ distribution : Bernoulli
385+
378386 def __init__ (self , action_dims : int ):
379387 super ().__init__ ()
380388 self .action_dims = action_dims
@@ -446,6 +454,7 @@ class StateDependentNoiseDistribution(Distribution):
446454 _latent_sde : th .Tensor
447455 exploration_mat : th .Tensor
448456 exploration_matrices : th .Tensor
457+ distribution : Normal
449458
450459 def __init__ (
451460 self ,
@@ -459,8 +468,6 @@ def __init__(
459468 super ().__init__ ()
460469 self .action_dim = action_dim
461470 self .latent_sde_dim = None
462- self .mean_actions = None
463- self .log_std = None
464471 self .use_expln = use_expln
465472 self .full_std = full_std
466473 self .epsilon = epsilon
@@ -704,7 +711,9 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
704711 :return: KL(dist_true||dist_pred)
705712 """
706713 # KL Divergence for different distribution types is out of scope
707- assert dist_true .__class__ == dist_pred .__class__ , "Error: input distributions should be the same type"
714+ assert (
715+ dist_true .__class__ == dist_pred .__class__
716+ ), f"Error: input distributions should be the same type, { dist_true .__class__ } != { dist_pred .__class__ } "
708717
709718 # MultiCategoricalDistribution is not a PyTorch Distribution subclass
710719 # so we need to implement it ourselves!
@@ -723,4 +732,6 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
723732
724733 # Use the PyTorch kl_divergence implementation
725734 else :
735+ assert isinstance (dist_true .distribution , TorchDistribution )
736+ assert isinstance (dist_pred .distribution , TorchDistribution )
726737 return th .distributions .kl_divergence (dist_true .distribution , dist_pred .distribution )
0 commit comments