diff --git a/rl_coach/filters/observation/observation_stacking_filter.py b/rl_coach/filters/observation/observation_stacking_filter.py index 58ed45b50..96fef2f03 100644 --- a/rl_coach/filters/observation/observation_stacking_filter.py +++ b/rl_coach/filters/observation/observation_stacking_filter.py @@ -65,10 +65,10 @@ def __init__(self, stack_size: int, stacking_axis: int=-1): self.stack = [] self.input_observation_space = None - if stack_size <= 0: - raise ValueError("The stack shape must be a positive number") if type(stack_size) != int: - raise ValueError("The stack shape must be of int type") + raise TypeError("The stack shape must be of int type") + if stack_size < 2: + raise ValueError("Cannot stack less than 2 frames") @property def next_filter(self) -> 'InputFilter':