Skip to content

Commit

Permalink
add kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
BartekCupial committed Oct 31, 2024
1 parent 34d3feb commit a3b796e
Showing 1 changed file with 61 additions and 25 deletions.
86 changes: 61 additions & 25 deletions minigrid/envs/babyai/mixed_seq_levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def __init__(
num_rows=1,
num_cols=1,
num_dists=8,
instr_kinds=['action', 'seq1'],
locations=False,
unblocking=False,
implicit_unlock=False,
**kwargs,
):

action = self._rand_elem(['goto', 'pickup', 'open', 'putnext', 'pick up seq go to'])
Expand All @@ -52,10 +57,11 @@ def __init__(
num_cols=num_cols,
num_dists=num_dists,
action_kinds=[action],
instr_kinds=['action', 'seq1'],
locations=False,
unblocking=False,
implicit_unlock=False
instr_kinds=instr_kinds,
locations=locations,
unblocking=unblocking,
implicit_unlock=implicit_unlock,
**kwargs,
)

# ['goto', 'pickup', 'open', 'putnext', 'pick up seq go to'],
Expand Down Expand Up @@ -272,6 +278,11 @@ def __init__(
num_rows=1,
num_cols=1,
num_dists=8,
instr_kinds=['action', 'seq1'],
locations=False,
unblocking=False,
implicit_unlock=False,
**kwargs,
):

action = self._rand_elem(['goto', 'pickup', 'open', 'putnext', 'pick up seq go to'])
Expand All @@ -286,10 +297,11 @@ def __init__(
num_cols=num_cols,
num_dists=num_dists,
action_kinds=[action],
instr_kinds=['action', 'seq1'],
locations=False,
unblocking=False,
implicit_unlock=False
instr_kinds=instr_kinds,
locations=locations,
unblocking=unblocking,
implicit_unlock=implicit_unlock,
**kwargs,
)

def gen_mission(self):
Expand Down Expand Up @@ -492,6 +504,11 @@ def __init__(
num_cols=1,
num_dists=8,
language='french',
instr_kinds=['action', 'seq1'],
locations=False,
unblocking=False,
implicit_unlock=False,
**kwargs,
):

action = self._rand_elem(['goto', 'pickup', 'open', 'putnext', 'pick up seq go to'])
Expand All @@ -506,11 +523,12 @@ def __init__(
num_cols=num_cols,
num_dists=num_dists,
action_kinds=[action],
instr_kinds=['action', 'seq1'],
locations=False,
unblocking=False,
implicit_unlock=False,
language=language
language=language,
instr_kinds=instr_kinds,
locations=locations,
unblocking=unblocking,
implicit_unlock=implicit_unlock,
**kwargs,
)

# ['goto', 'pickup', 'open', 'putnext', 'pick up seq go to'],
Expand Down Expand Up @@ -722,6 +740,11 @@ def __init__(
num_rows=1,
num_cols=1,
num_dists=8,
instr_kinds=['seq1'],
locations=False,
unblocking=False,
implicit_unlock=False,
**kwargs,
):

action = 'pick up seq pick up '
Expand All @@ -734,10 +757,11 @@ def __init__(
num_cols=num_cols,
num_dists=num_dists,
action_kinds=[action],
instr_kinds=['seq1'],
locations=False,
unblocking=False,
implicit_unlock=False
instr_kinds=instr_kinds,
locations=locations,
unblocking=unblocking,
implicit_unlock=implicit_unlock,
**kwargs,
)

def gen_mission(self):
Expand Down Expand Up @@ -856,6 +880,11 @@ def __init__(
num_rows=1,
num_cols=1,
num_dists=8,
instr_kinds=['seq1'],
locations=False,
unblocking=False,
implicit_unlock=False,
**kwargs,
):

action = 'pick up seq pick up '
Expand All @@ -868,10 +897,11 @@ def __init__(
num_cols=num_cols,
num_dists=num_dists,
action_kinds=[action],
instr_kinds=['seq1'],
locations=False,
unblocking=False,
implicit_unlock=False
instr_kinds=instr_kinds,
locations=locations,
unblocking=unblocking,
implicit_unlock=implicit_unlock,
**kwargs,
)

def gen_mission(self):
Expand Down Expand Up @@ -986,6 +1016,11 @@ def __init__(
num_rows=1,
num_cols=1,
num_dists=8,
instr_kinds=['seq1'],
locations=False,
unblocking=False,
implicit_unlock=False,
**kwargs,
):

action = 'pick up seq pick up '
Expand All @@ -998,10 +1033,11 @@ def __init__(
num_cols=num_cols,
num_dists=num_dists,
action_kinds=[action],
instr_kinds=['seq1'],
locations=False,
unblocking=False,
implicit_unlock=False
instr_kinds=instr_kinds,
locations=locations,
unblocking=unblocking,
implicit_unlock=implicit_unlock,
**kwargs,
)

def gen_mission(self):
Expand Down

0 comments on commit a3b796e

Please sign in to comment.