Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.8.0a0 (WIP)
Release 2.8.0a1 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -35,6 +35,8 @@ Deprecations:
Others:
^^^^^^^
- Updated to Python 3.10+ annotations
- Removed some unused variables (@unexploredtest)
- Improved type hints for distributions

Documentation:
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1939,4 +1941,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti @unexploredtest
27 changes: 19 additions & 8 deletions stable_baselines3/common/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gymnasium import spaces
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
from torch.distributions import Distribution as TorchDistribution

from stable_baselines3.common.preprocessing import get_action_dim

Expand All @@ -25,9 +26,10 @@
class Distribution(ABC):
"""Abstract base class for distributions."""

distribution: TorchDistribution | list[TorchDistribution]

def __init__(self):
super().__init__()
self.distribution = None

@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> nn.Module | tuple[nn.Module, nn.Parameter]:
Expand All @@ -44,11 +46,11 @@ def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribut
"""

@abstractmethod
def log_prob(self, x: th.Tensor) -> th.Tensor:
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Returns the log likelihood

:param x: the taken action
:param actions: the taken actions
:return: The log likelihood of the distribution
"""

Expand Down Expand Up @@ -129,11 +131,11 @@ class DiagGaussianDistribution(Distribution):
:param action_dim: Dimension of the action space.
"""

distribution: Normal

def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None

def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
"""
Expand Down Expand Up @@ -267,6 +269,8 @@ class CategoricalDistribution(Distribution):
:param action_dim: Number of discrete actions
"""

distribution: Categorical

def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
Expand Down Expand Up @@ -318,6 +322,8 @@ class MultiCategoricalDistribution(Distribution):
:param action_dims: List of sizes of discrete action spaces
"""

distribution: list[Categorical] # type: ignore[assignment]

def __init__(self, action_dims: list[int]):
super().__init__()
self.action_dims = action_dims
Expand Down Expand Up @@ -375,6 +381,8 @@ class BernoulliDistribution(Distribution):
:param action_dim: Number of binary actions
"""

distribution: Bernoulli

def __init__(self, action_dims: int):
super().__init__()
self.action_dims = action_dims
Expand Down Expand Up @@ -446,6 +454,7 @@ class StateDependentNoiseDistribution(Distribution):
_latent_sde: th.Tensor
exploration_mat: th.Tensor
exploration_matrices: th.Tensor
distribution: Normal

def __init__(
self,
Expand All @@ -459,8 +468,6 @@ def __init__(
super().__init__()
self.action_dim = action_dim
self.latent_sde_dim = None
self.mean_actions = None
self.log_std = None
self.use_expln = use_expln
self.full_std = full_std
self.epsilon = epsilon
Expand Down Expand Up @@ -704,7 +711,9 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
:return: KL(dist_true||dist_pred)
"""
# KL Divergence for different distribution types is out of scope
assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
assert (
dist_true.__class__ == dist_pred.__class__
), f"Error: input distributions should be the same type, {dist_true.__class__} != {dist_pred.__class__}"

# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
Expand All @@ -723,4 +732,6 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor

# Use the PyTorch kl_divergence implementation
else:
assert isinstance(dist_true.distribution, TorchDistribution)
assert isinstance(dist_pred.distribution, TorchDistribution)
return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.8.0a0
2.8.0a1
Loading