Skip to content

Commit c6ce50f

Browse files
Improve type hints for distributions and cleanup (#2200)
* Remove some unused variables * Update changelog.rst * Update changelog.rst * Improve type hints for distributions * fix: add missing seed in PPO identity test --------- Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 58ac622 commit c6ce50f

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

docs/misc/changelog.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Changelog
44
==========
55

66

7-
Release 2.8.0a0 (WIP)
7+
Release 2.8.0a1 (WIP)
88
--------------------------
99

1010
Breaking Changes:
@@ -35,6 +35,8 @@ Deprecations:
3535
Others:
3636
^^^^^^^
3737
- Updated to Python 3.10+ annotations
38+
- Remove some unused variables (@unexploredtest)
39+
- Improve type hints for distributions
3840

3941
Documentation:
4042
^^^^^^^^^^^^^^
@@ -1939,4 +1941,4 @@ And all the contributors:
19391941
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
19401942
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
19411943
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
1942-
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti
1944+
@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti @unexploredtest

stable_baselines3/common/distributions.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from gymnasium import spaces
99
from torch import nn
1010
from torch.distributions import Bernoulli, Categorical, Normal
11+
from torch.distributions import Distribution as TorchDistribution
1112

1213
from stable_baselines3.common.preprocessing import get_action_dim
1314

@@ -25,9 +26,10 @@
2526
class Distribution(ABC):
2627
"""Abstract base class for distributions."""
2728

29+
distribution: TorchDistribution | list[TorchDistribution]
30+
2831
def __init__(self):
2932
super().__init__()
30-
self.distribution = None
3133

3234
@abstractmethod
3335
def proba_distribution_net(self, *args, **kwargs) -> nn.Module | tuple[nn.Module, nn.Parameter]:
@@ -44,11 +46,11 @@ def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribut
4446
"""
4547

4648
@abstractmethod
47-
def log_prob(self, x: th.Tensor) -> th.Tensor:
49+
def log_prob(self, actions: th.Tensor) -> th.Tensor:
4850
"""
4951
Returns the log likelihood
5052
51-
:param x: the taken action
53+
:param actions: the taken action
5254
:return: The log likelihood of the distribution
5355
"""
5456

@@ -129,11 +131,11 @@ class DiagGaussianDistribution(Distribution):
129131
:param action_dim: Dimension of the action space.
130132
"""
131133

134+
distribution: Normal
135+
132136
def __init__(self, action_dim: int):
133137
super().__init__()
134138
self.action_dim = action_dim
135-
self.mean_actions = None
136-
self.log_std = None
137139

138140
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
139141
"""
@@ -267,6 +269,8 @@ class CategoricalDistribution(Distribution):
267269
:param action_dim: Number of discrete actions
268270
"""
269271

272+
distribution: Categorical
273+
270274
def __init__(self, action_dim: int):
271275
super().__init__()
272276
self.action_dim = action_dim
@@ -318,6 +322,8 @@ class MultiCategoricalDistribution(Distribution):
318322
:param action_dims: List of sizes of discrete action spaces
319323
"""
320324

325+
distribution: list[Categorical] # type: ignore[assignment]
326+
321327
def __init__(self, action_dims: list[int]):
322328
super().__init__()
323329
self.action_dims = action_dims
@@ -375,6 +381,8 @@ class BernoulliDistribution(Distribution):
375381
:param action_dim: Number of binary actions
376382
"""
377383

384+
distribution: Bernoulli
385+
378386
def __init__(self, action_dims: int):
379387
super().__init__()
380388
self.action_dims = action_dims
@@ -446,6 +454,7 @@ class StateDependentNoiseDistribution(Distribution):
446454
_latent_sde: th.Tensor
447455
exploration_mat: th.Tensor
448456
exploration_matrices: th.Tensor
457+
distribution: Normal
449458

450459
def __init__(
451460
self,
@@ -459,8 +468,6 @@ def __init__(
459468
super().__init__()
460469
self.action_dim = action_dim
461470
self.latent_sde_dim = None
462-
self.mean_actions = None
463-
self.log_std = None
464471
self.use_expln = use_expln
465472
self.full_std = full_std
466473
self.epsilon = epsilon
@@ -704,7 +711,9 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
704711
:return: KL(dist_true||dist_pred)
705712
"""
706713
# KL Divergence for different distribution types is out of scope
707-
assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
714+
assert (
715+
dist_true.__class__ == dist_pred.__class__
716+
), f"Error: input distributions should be the same type, {dist_true.__class__} != {dist_pred.__class__}"
708717

709718
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
710719
# so we need to implement it ourselves!
@@ -723,4 +732,6 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
723732

724733
# Use the PyTorch kl_divergence implementation
725734
else:
735+
assert isinstance(dist_true.distribution, TorchDistribution)
736+
assert isinstance(dist_pred.distribution, TorchDistribution)
726737
return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.8.0a0
1+
2.8.0a1

tests/test_identity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_continuous(model_class):
4545
elif model_class in [A2C]:
4646
kwargs["policy_kwargs"]["log_std_init"] = -0.5
4747
elif model_class == PPO:
48-
kwargs = dict(n_steps=512, n_epochs=5)
48+
kwargs = dict(n_steps=512, n_epochs=5, seed=0)
4949

5050
model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps)
5151

0 commit comments

Comments
 (0)