From a3b796ebfa7307f51b7e6d90205008663ecceca3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Thu, 31 Oct 2024 13:32:41 +0100 Subject: [PATCH] add kwargs --- minigrid/envs/babyai/mixed_seq_levels.py | 86 +++++++++++++++++------- 1 file changed, 61 insertions(+), 25 deletions(-) diff --git a/minigrid/envs/babyai/mixed_seq_levels.py b/minigrid/envs/babyai/mixed_seq_levels.py index 4ff381fa0..2037889d4 100644 --- a/minigrid/envs/babyai/mixed_seq_levels.py +++ b/minigrid/envs/babyai/mixed_seq_levels.py @@ -38,6 +38,11 @@ def __init__( num_rows=1, num_cols=1, num_dists=8, + instr_kinds=['action', 'seq1'], + locations=False, + unblocking=False, + implicit_unlock=False, + **kwargs, ): action = self._rand_elem(['goto', 'pickup', 'open', 'putnext', 'pick up seq go to']) @@ -52,10 +57,11 @@ def __init__( num_cols=num_cols, num_dists=num_dists, action_kinds=[action], - instr_kinds=['action', 'seq1'], - locations=False, - unblocking=False, - implicit_unlock=False + instr_kinds=instr_kinds, + locations=locations, + unblocking=unblocking, + implicit_unlock=implicit_unlock, + **kwargs, ) # ['goto', 'pickup', 'open', 'putnext', 'pick up seq go to'], @@ -272,6 +278,11 @@ def __init__( num_rows=1, num_cols=1, num_dists=8, + instr_kinds=['action', 'seq1'], + locations=False, + unblocking=False, + implicit_unlock=False, + **kwargs, ): action = self._rand_elem(['goto', 'pickup', 'open', 'putnext', 'pick up seq go to']) @@ -286,10 +297,11 @@ def __init__( num_cols=num_cols, num_dists=num_dists, action_kinds=[action], - instr_kinds=['action', 'seq1'], - locations=False, - unblocking=False, - implicit_unlock=False + instr_kinds=instr_kinds, + locations=locations, + unblocking=unblocking, + implicit_unlock=implicit_unlock, + **kwargs, ) def gen_mission(self): @@ -492,6 +504,11 @@ def __init__( num_cols=1, num_dists=8, language='french', + instr_kinds=['action', 'seq1'], + locations=False, + unblocking=False, + implicit_unlock=False, + **kwargs, ): action = self._rand_elem(['goto', 'pickup', 'open', 'putnext', 'pick up seq go to']) @@ -506,11 +523,12 @@ def __init__( num_cols=num_cols, num_dists=num_dists, action_kinds=[action], - instr_kinds=['action', 'seq1'], - locations=False, - unblocking=False, - implicit_unlock=False, - language=language + language=language, + instr_kinds=instr_kinds, + locations=locations, + unblocking=unblocking, + implicit_unlock=implicit_unlock, + **kwargs, ) # ['goto', 'pickup', 'open', 'putnext', 'pick up seq go to'], @@ -722,6 +740,11 @@ def __init__( num_rows=1, num_cols=1, num_dists=8, + instr_kinds=['seq1'], + locations=False, + unblocking=False, + implicit_unlock=False, + **kwargs, ): action = 'pick up seq pick up ' @@ -734,10 +757,11 @@ def __init__( num_cols=num_cols, num_dists=num_dists, action_kinds=[action], - instr_kinds=['seq1'], - locations=False, - unblocking=False, - implicit_unlock=False + instr_kinds=instr_kinds, + locations=locations, + unblocking=unblocking, + implicit_unlock=implicit_unlock, + **kwargs, ) def gen_mission(self): @@ -856,6 +880,11 @@ def __init__( num_rows=1, num_cols=1, num_dists=8, + instr_kinds=['seq1'], + locations=False, + unblocking=False, + implicit_unlock=False, + **kwargs, ): action = 'pick up seq pick up ' @@ -868,10 +897,11 @@ def __init__( num_cols=num_cols, num_dists=num_dists, action_kinds=[action], - instr_kinds=['seq1'], - locations=False, - unblocking=False, - implicit_unlock=False + instr_kinds=instr_kinds, + locations=locations, + unblocking=unblocking, + implicit_unlock=implicit_unlock, + **kwargs, ) def gen_mission(self): @@ -986,6 +1016,11 @@ def __init__( num_rows=1, num_cols=1, num_dists=8, + instr_kinds=['seq1'], + locations=False, + unblocking=False, + implicit_unlock=False, + **kwargs, ): action = 'pick up seq pick up ' @@ -998,10 +1033,11 @@ def __init__( num_cols=num_cols, num_dists=num_dists, action_kinds=[action], - instr_kinds=['seq1'], - locations=False, - unblocking=False, - implicit_unlock=False + instr_kinds=instr_kinds, + locations=locations, + unblocking=unblocking, + implicit_unlock=implicit_unlock, + **kwargs, ) def gen_mission(self):