Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 2.8.0a0 (WIP)
Release 2.8.0a1 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -35,6 +35,8 @@ Deprecations:
Others:
^^^^^^^
- Updated to Python 3.10+ annotations
- Removed some unused variables (@unexploredtest)
- Improved type hints for distributions

Documentation:
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1939,4 +1941,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti @unexploredtest
27 changes: 19 additions & 8 deletions stable_baselines3/common/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gymnasium import spaces
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
from torch.distributions import Distribution as TorchDistribution

from stable_baselines3.common.preprocessing import get_action_dim

Expand All @@ -25,9 +26,10 @@
class Distribution(ABC):
"""Abstract base class for distributions."""

distribution: TorchDistribution | list[TorchDistribution]

def __init__(self):
super().__init__()
self.distribution = None

@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> nn.Module | tuple[nn.Module, nn.Parameter]:
Expand All @@ -44,11 +46,11 @@ def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribut
"""

@abstractmethod
def log_prob(self, x: th.Tensor) -> th.Tensor:
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Returns the log likelihood

:param x: the taken action
:param actions: the taken action
:return: The log likelihood of the distribution
"""

Expand Down Expand Up @@ -129,11 +131,11 @@ class DiagGaussianDistribution(Distribution):
:param action_dim: Dimension of the action space.
"""

distribution: Normal

def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None

def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
"""
Expand Down Expand Up @@ -267,6 +269,8 @@ class CategoricalDistribution(Distribution):
:param action_dim: Number of discrete actions
"""

distribution: Categorical

def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
Expand Down Expand Up @@ -318,6 +322,8 @@ class MultiCategoricalDistribution(Distribution):
:param action_dims: List of sizes of discrete action spaces
"""

distribution: list[Categorical] # type: ignore[assignment]

def __init__(self, action_dims: list[int]):
super().__init__()
self.action_dims = action_dims
Expand Down Expand Up @@ -375,6 +381,8 @@ class BernoulliDistribution(Distribution):
:param action_dim: Number of binary actions
"""

distribution: Bernoulli

def __init__(self, action_dims: int):
super().__init__()
self.action_dims = action_dims
Expand Down Expand Up @@ -446,6 +454,7 @@ class StateDependentNoiseDistribution(Distribution):
_latent_sde: th.Tensor
exploration_mat: th.Tensor
exploration_matrices: th.Tensor
distribution: Normal

def __init__(
self,
Expand All @@ -459,8 +468,6 @@ def __init__(
super().__init__()
self.action_dim = action_dim
self.latent_sde_dim = None
self.mean_actions = None
self.log_std = None
self.use_expln = use_expln
self.full_std = full_std
self.epsilon = epsilon
Expand Down Expand Up @@ -704,7 +711,9 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
:return: KL(dist_true||dist_pred)
"""
# KL Divergence for different distribution types is out of scope
assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
assert (
dist_true.__class__ == dist_pred.__class__
), f"Error: input distributions should be the same type, {dist_true.__class__} != {dist_pred.__class__}"

# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
Expand All @@ -723,4 +732,6 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor

# Use the PyTorch kl_divergence implementation
else:
assert isinstance(dist_true.distribution, TorchDistribution)
assert isinstance(dist_pred.distribution, TorchDistribution)
return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.8.0a0
2.8.0a1
2 changes: 1 addition & 1 deletion tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_continuous(model_class):
elif model_class in [A2C]:
kwargs["policy_kwargs"]["log_std_init"] = -0.5
elif model_class == PPO:
kwargs = dict(n_steps=512, n_epochs=5)
kwargs = dict(n_steps=512, n_epochs=5, seed=0)

model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps)

Expand Down