In predict_th, there's an assert statement that says the following:
assert rew_th.shape == state.shape[:1]
This fails whenever the original state is not itself a tensor or array (for example, a dictionary), even if self.preprocess has already converted it into a valid tensor, state_th.
Steps to reproduce
Pass a dictionary-style state to the reward network and convert it in the network's preprocess function. Even if preprocess returns a valid state_th, the assertion checks the shape of the original state, which is incorrect.
Instead, I believe it should be:
assert rew_th.shape == state_th.shape[:1]
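To illustrate the difference, here is a minimal sketch of the two assertions side by side. It is not the library's code: SketchRewardNet, its forward method, the two predict_* methods, and the "pos"/"vel" keys are hypothetical names made up for this example, and NumPy arrays stand in for torch tensors so it runs without PyTorch. Only the two assert lines are taken from this report.

```python
import numpy as np


class SketchRewardNet:
    """Hypothetical stand-in for a reward network; not the library's class."""

    def preprocess(self, state):
        # Assumed preprocessing: concatenate the dict's entries into one
        # (batch, features) array.
        return np.concatenate([state[k] for k in sorted(state)], axis=1)

    def forward(self, state_th):
        # Dummy reward: one scalar per batch element.
        return state_th.sum(axis=1)

    def predict_original(self, state):
        state_th = self.preprocess(state)
        rew_th = self.forward(state_th)
        # Current assertion: inspects the raw input, and a dict has no `.shape`.
        assert rew_th.shape == state.shape[:1]
        return rew_th

    def predict_proposed(self, state):
        state_th = self.preprocess(state)
        rew_th = self.forward(state_th)
        # Proposed assertion: inspects the preprocessed array instead.
        assert rew_th.shape == state_th.shape[:1]
        return rew_th


# A dictionary-style state with a batch of 4 observations.
state = {"pos": np.zeros((4, 2)), "vel": np.ones((4, 3))}
net = SketchRewardNet()

try:
    net.predict_original(state)
    original_error = None
except AttributeError as exc:
    # The check never gets as far as comparing shapes.
    original_error = exc

rewards = net.predict_proposed(state)
```

In this sketch the original check raises an AttributeError because a plain dict has no shape attribute, while the proposed check passes since the reward batch dimension matches that of state_th.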