From b122af42e3bcde287a28aded2fde9619a6b09dea Mon Sep 17 00:00:00 2001 From: Giacomo Spigler Date: Sat, 29 Jun 2024 12:42:50 +0200 Subject: [PATCH] proposed fix for RunningMeanStd overflow --- stable_baselines3/common/running_mean_std.py | 34 ++++++++++++++------ 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index ac3538c50..4b27ff947 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -44,14 +44,30 @@ def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, bat delta = batch_mean - self.mean tot_count = self.count + batch_count - new_mean = self.mean + delta * batch_count / tot_count - m_a = self.var * self.count - m_b = batch_var * batch_count - m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) - new_var = m_2 / (self.count + batch_count) + with np.errstate(over="raise"): + try: + new_mean = self.mean + delta * batch_count / tot_count + m_a = self.var * self.count + m_b = batch_var * batch_count - new_count = batch_count + self.count + # Calculate the products/divisions in an order that reduces the chance of overflow + # Original code: + # m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) + # new_var = m_2 / (self.count + batch_count) + mult1 = self.count / (self.count + batch_count) + mult2 = batch_count / (self.count + batch_count) + var_delta = np.square(delta) * mult1 * mult2 + new_var = (m_a + m_b) / (self.count + batch_count) + var_delta - self.mean = new_mean - self.var = new_var - self.count = new_count + new_count = batch_count + self.count + + self.mean = new_mean + self.var = new_var + self.count = new_count + + except FloatingPointError: + # This happens because self.count has gotten too large and the multiplication is overflowing + # We need to scale down the batch statistics + self.count /= 2 + batch_count /= 2 + self.update_from_moments(batch_mean, batch_var, batch_count)