Skip to content

Commit 3538a8c

Browse files
maxhuettenrauch authored and ZhengLi1314 committed
fixed env seeding in test_sac_with_il.py (thu-ml#1081)
1 parent 2451cd7 commit 3538a8c

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

test/continuous/test_sac_with_il.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,8 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
78 78   # seed
79 79   np.random.seed(args.seed)
80 80   torch.manual_seed(args.seed)
   81 + train_envs.seed(args.seed)
   82 + test_envs.seed(args.seed + args.training_num)
81 83   # model
82 84   net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
83 85   actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device)
@@ -181,10 +183,12 @@ def stop_fn(mean_rewards: float) -> bool:
181 183       action_scaling=True,
182 184       action_bound_method="clip",
183 185   )
    186 + il_test_env = gym.make(args.task)
    187 + il_test_env.reset(seed=args.seed + args.training_num + args.test_num)
184 188   il_test_collector = Collector(
185 189       il_policy,
186 190       # envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed),
187     -     gym.make(args.task),
    191 +     il_test_env,
188 192   )
189 193   train_collector.reset()
190 194   result = OffpolicyTrainer(

0 commit comments

Comments
 (0)