From aa8eefcceee1fefd6b677501639377a9f78e28c6 Mon Sep 17 00:00:00 2001 From: Howuhh Date: Thu, 21 Dec 2023 15:35:57 +0300 Subject: [PATCH] inital benchmarks --- README.md | 20 ++-- examples/walkthrough.ipynb | 36 +++++++- scripts/generate_benchmarks.sh | 56 ++++++++++++ scripts/ruleset_generator.py | 122 +++++++++++++++++++------ src/xminigrid/__init__.py | 2 +- src/xminigrid/benchmarks.py | 28 ++++-- src/xminigrid/rendering/text_render.py | 61 +++++++++---- training/train_meta_task.py | 6 +- training/train_single_task.py | 2 +- 9 files changed, 265 insertions(+), 68 deletions(-) create mode 100644 scripts/generate_benchmarks.sh diff --git a/README.md b/README.md index f12d6ce..62d3bd9 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,9 @@ + + + Open In Colab @@ -88,7 +91,7 @@ key = jax.random.PRNGKey(0) reset_key, ruleset_key = jax.random.split(key) # to list available benchmarks: xminigrid.registered_benchmarks() -benchmark = xminigrid.load_benchmark(name="Trivial") +benchmark = xminigrid.load_benchmark(name="trivial-1m") # choosing ruleset, see section on rules and goals ruleset = benchmark.sample_ruleset(ruleset_key) @@ -150,11 +153,11 @@ While composing rules and goals by hand is flexible, it can quickly become cumbe Besides, it's hard to express efficiently in a JAX-compatible way due to the high number of heterogeneous computations To avoid significant overhead during training and facilitate reliable comparisons between agents, -we pre-sampled several benchmarks with up to **one million unique tasks**, following the procedure used to train DeepMind +we pre-sampled several benchmarks with up to **five million unique tasks**, following the procedure used to train DeepMind AdA agent from the original XLand. These benchmarks differ in the generation configs, producing distributions with varying levels of diversity and average difficulty of the tasks. They can be used for different purposes, for example -the `Trivial` benchmark can be used to debug your agents, allowing very quick iterations. However, we would caution -against treating benchmarks as a progression from simple to complex. They are just different 🤷. +the `trivial-1m` benchmark can be used to debug your agents, allowing very quick iterations. However, we would caution +against treating benchmarks as a progression from simple to complex. They are just different 🤷. Pre-sampled benchmarks are hosted on [HuggingFace](https://huggingface.co/datasets/Howuhh/xland_minigrid/tree/main) and will be downloaded and cached on the first use: @@ -165,13 +168,16 @@ from xminigrid.benchmarks import Benchmark # downloading to path specified by XLAND_MINIGRID_DATA, # ~/.xland_minigrid by default -benchmark: Benchmark = xminigrid.load_benchmark(name="Trivial") +benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m") # reusing cached on the second use -benchmark: Benchmark = xminigrid.load_benchmark(name="Trivial") +benchmark: Benchmark = xminigrid.load_benchmark(name="trivial-1m") # users can sample or get specific rulesets benchmark.sample_ruleset(jax.random.PRNGKey(0)) benchmark.get_ruleset(ruleset_id=benchmark.num_rulesets() - 1) + +# or split them for train & test +train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8) ``` We also provide the [script](scripts/ruleset_generator.py) used to generate these benchmarks. Users can use it for their own purposes: @@ -181,7 +187,7 @@ python scripts/ruleset_generator.py --help In depth description of all available benchmarks is provided [here (soon)](). -**P.S.** Currently only one benchmark is available. We will release more after some testing and configs balancing. Stay tuned! +**P.S.** Be aware, that benchmarks can change, as we are currently testing and balancing them! ## Environments 🌍 diff --git a/examples/walkthrough.ipynb b/examples/walkthrough.ipynb index ca7951e..02d0712 100644 --- a/examples/walkthrough.ipynb +++ b/examples/walkthrough.ipynb @@ -814,7 +814,7 @@ "source": [ "While composing rules and goals by hand is flexible, it can quickly become cumbersome. Besides, it's hard to express efficiently in a JAX-compatible way due to the high number of heterogeneous computations\n", "\n", - "To avoid significant overhead during training and facilitate reliable comparisons between agents, we pre-sampled several benchmarks with up to **one million unique** tasks, following the procedure used to train [DeepMind AdA](https://sites.google.com/view/adaptive-agent/) agent from the original XLand. These benchmarks differ in the generation configs, producing distributions with varying levels of diversity and average difficulty of the tasks. They can be used for different purposes, for example the Trivial benchmark can be used to debug your agents, allowing very quick iterations. \n", + "To avoid significant overhead during training and facilitate reliable comparisons between agents, we pre-sampled several benchmarks with up to **five million unique** tasks (apart from the randomization of object positions during reset), following the procedure used to train [DeepMind AdA](https://sites.google.com/view/adaptive-agent/) agent from the original XLand. These benchmarks differ in the generation configs, producing distributions with varying levels of diversity and average difficulty of the tasks. They can be used for different purposes, for example the `trivial-1m` benchmark can be used to debug your agents, allowing quick iterations. \n", "\n", "**Generation protocol**:\n", "\n", @@ -921,7 +921,7 @@ "source": [ "print(\"Benchmarks available:\", xminigrid.registered_benchmarks())\n", "\n", - "benchmark = xminigrid.load_benchmark(name=\"Trivial\")\n", + "benchmark = xminigrid.load_benchmark(name=\"trivial-1m\")\n", "print(\"Total rulesets:\", benchmark.num_rulesets())\n", "print(\"Ruleset with id 128: \\n\", benchmark.get_ruleset(ruleset_id=128))\n", "print(\"Random ruleset: \\n\", benchmark.sample_ruleset(jax.random.PRNGKey(0)))" @@ -966,7 +966,7 @@ "outputs": [], "source": [ "# example path, can be any your valid path\n", - "bechmark_path = os.path.join(DATA_PATH, NAME2HFFILENAME[\"Trivial\"])\n", + "bechmark_path = os.path.join(DATA_PATH, NAME2HFFILENAME[\"trivial-1m\"])\n", "\n", "rulesets_clear = load_bz2_pickle(bechmark_path)\n", "loaded_benchmark = Benchmark(\n", @@ -977,6 +977,36 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "0a08d28a-60fd-4a09-acb2-7e4d42538b6e", + "metadata": {}, + "source": [ + "You also my need splitting functionality to test generalization of your agents. For this users can use `split` or `filter_split`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "653b678b-65c1-47a8-9644-4f2d519d9e55", + "metadata": {}, + "outputs": [], + "source": [ + "train, test = benchmark.shuffle(key=jax.random.PRNGKey(0)).split(prop=0.8)\n", + "\n", + "# or, by some function:\n", + "def cond_fn(goal, rules):\n", + " # 0 index in the encoding is the ID\n", + " return jnp.logical_not(\n", + " jnp.logical_and(\n", + " jnp.greater_equal(goal[0], 7),\n", + " jnp.less_equal(goal[0], 14)\n", + " )\n", + " )\n", + " \n", + "train, test = benchmark.filter_split(fn=cond_fn)" + ] + }, { "cell_type": "markdown", "id": "4786beac-9bea-4bbc-bcf9-92c21b5770d2", diff --git a/scripts/generate_benchmarks.sh b/scripts/generate_benchmarks.sh new file mode 100644 index 0000000..330f366 --- /dev/null +++ b/scripts/generate_benchmarks.sh @@ -0,0 +1,56 @@ +# This can take a lot of time. Generate only needed! +# TODO: provide same for 5M benchmarks + +# trivial +python scripts/ruleset_generator.py \ + --chain_depth=0 \ + --num_distractor_objects=3 \ + --total_rulesets=1_000_000 \ + --save_path="trivial_1m" + + +# small +python scripts/ruleset_generator.py \ + --prune_chain \ + --prune_prob=0.3 \ + --chain_depth=1 \ + --sample_distractor_rules \ + --num_distractor_rules=2 \ + --num_distractor_objects=2 \ + --total_rulesets=1_000_000 \ + --save_path="small_1m" + +# medium +python scripts/ruleset_generator.py \ + --prune_chain \ + --prune_prob=0.3 \ + --chain_depth=2 \ + --sample_distractor_rules \ + --num_distractor_rules=3 \ + --num_distractor_objects=0 \ + --total_rulesets=1_000_000 \ + --save_path="medium_1m" + + +# high +python scripts/ruleset_generator.py \ + --prune_chain \ + --prune_prob=0.1 \ + --chain_depth=3 \ + --sample_distractor_rules \ + --num_distractor_rules=4 \ + --num_distractor_objects=1 \ + --total_rulesets=1_000_000 \ + --save_path="high_1m" + + +# medium + distractors +python scripts/ruleset_generator.py \ + --prune_chain \ + --prune_prob=0.8 \ + --chain_depth=2 \ + --sample_distractor_rules \ + --num_distractor_rules=4 \ + --num_distractor_objects=2 \ + --total_rulesets=1_000_000 \ + --save_path="medium_dist_1m" diff --git a/scripts/ruleset_generator.py b/scripts/ruleset_generator.py index abd25d0..ef51d0d 100644 --- a/scripts/ruleset_generator.py +++ b/scripts/ruleset_generator.py @@ -10,26 +10,62 @@ from tqdm.auto import tqdm, trange from xminigrid.benchmarks import save_bz2_pickle from xminigrid.core.constants import Colors, Tiles -from xminigrid.core.goals import AgentHoldGoal, AgentNearGoal, TileNearGoal +from xminigrid.core.goals import ( + AgentHoldGoal, + AgentNearDownGoal, + AgentNearGoal, + AgentNearLeftGoal, + AgentNearRightGoal, + AgentNearUpGoal, + TileNearDownGoal, + TileNearGoal, + TileNearLeftGoal, + TileNearRightGoal, + TileNearUpGoal, +) from xminigrid.core.grid import pad_along_axis -from xminigrid.core.rules import AgentHoldRule, AgentNearRule, EmptyRule, TileNearRule - -COLORS = [Colors.RED, Colors.GREEN, Colors.BLUE, Colors.PURPLE, Colors.YELLOW, Colors.GREY, Colors.WHITE] +from xminigrid.core.rules import ( + AgentHoldRule, + AgentNearDownRule, + AgentNearLeftRule, + AgentNearRightRule, + AgentNearRule, + AgentNearUpRule, + EmptyRule, + TileNearDownRule, + TileNearLeftRule, + TileNearRightRule, + TileNearRule, + TileNearUpRule, +) + +COLORS = [ + Colors.RED, + Colors.GREEN, + Colors.BLUE, + Colors.PURPLE, + Colors.YELLOW, + Colors.GREY, + Colors.WHITE, + Colors.BROWN, + Colors.PINK, + Colors.ORANGE, +] # we need to distinguish between them, to avoid sampling # near(goal, goal) goal or rule as goal tiles are not pickable -NEAR_TILES_LHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.GOAL], COLORS)) +NEAR_TILES_LHS = list( + product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS) +) # these are pickable! -NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY], COLORS)) +NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS)) + +HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS)) -HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY], COLORS)) -PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY], COLORS)) # to imitate disappearance production rule +PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS)) PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)] -GOALS = (AgentHoldGoal, AgentNearGoal, TileNearGoal) -RULES = (AgentHoldRule, AgentNearRule, TileNearRule) - def encode(ruleset): flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist() @@ -41,42 +77,72 @@ def diff(list1, list2): def sample_goal(): - goal_idx = random.randint(0, 2) + goals = ( + AgentHoldGoal, + # agent near variations + AgentNearGoal, + AgentNearUpGoal, + AgentNearDownGoal, + AgentNearLeftGoal, + AgentNearRightGoal, + # tile near variations + TileNearGoal, + TileNearUpGoal, + TileNearDownGoal, + TileNearLeftGoal, + TileNearRightGoal, + ) + goal_idx = random.randint(0, 10) if goal_idx == 0: tile = random.choice(HOLD_TILES) - goal = AgentHoldGoal(tile=jnp.array(tile)) + goal = goals[0](tile=jnp.array(tile)) return goal, (tile,) - elif goal_idx == 1: + elif 1 <= goal_idx <= 5: tile = random.choice(NEAR_TILES_LHS) - goal = AgentNearGoal(tile=jnp.array(tile)) + goal = goals[goal_idx](tile=jnp.array(tile)) return goal, (tile,) - elif goal_idx == 2: + elif 6 <= goal_idx <= 10: tile_a = random.choice(NEAR_TILES_LHS) tile_b = random.choice(NEAR_TILES_RHS) - goal = TileNearGoal(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b)) + goal = goals[goal_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b)) return goal, (tile_a, tile_b) else: - raise RuntimeError(f"Unknown goal, should be one of: {GOALS}") + raise RuntimeError("Unknown goal") def sample_rule(prod_tile, used_tiles): - rule_idx = random.randint(0, 2) + rules = ( + AgentHoldRule, + # agent near variations + AgentNearRule, + AgentNearUpRule, + AgentNearDownRule, + AgentNearLeftRule, + AgentNearRightRule, + # tile near variations + TileNearRule, + TileNearUpRule, + TileNearDownRule, + TileNearLeftRule, + TileNearRightRule, + ) + rule_idx = random.randint(0, 10) + if rule_idx == 0: tile = random.choice(diff(HOLD_TILES, used_tiles)) - rule = AgentHoldRule(tile=jnp.array(tile), prod_tile=jnp.array(prod_tile)) + rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile)) return rule, (tile,) - elif rule_idx == 1: - tile = random.choice(diff(NEAR_TILES_LHS, used_tiles)) - rule = AgentNearRule(tile=jnp.array(tile), prod_tile=jnp.array(prod_tile)) + elif 1 <= rule_idx <= 5: + tile = random.choice(diff(HOLD_TILES, used_tiles)) + rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile)) return rule, (tile,) - elif rule_idx == 2: + elif 6 <= rule_idx <= 10: tile_a = random.choice(diff(NEAR_TILES_LHS, used_tiles)) tile_b = random.choice(diff(NEAR_TILES_RHS, used_tiles)) - - rule = TileNearRule(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile)) + rule = rules[rule_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile)) return rule, (tile_a, tile_b) else: - raise RuntimeError(f"Unknown rule, should be one of: {RULES}") + raise RuntimeError("Unknown rule") # See Appendix A.2 in "Human-timescale adaptation in an open-ended task space" for sampling procedure. @@ -158,7 +224,7 @@ def sample_ruleset( "init_tiles": init_tiles, # additional info (for example for biasing sampling by number of rules) # you can add other field if needed, just copy-paste this file! - "num_rules": num_levels, + "num_rules": len([r for r in rules if not isinstance(r, EmptyRule)]), } diff --git a/src/xminigrid/__init__.py b/src/xminigrid/__init__.py index f17699a..e5c38f3 100644 --- a/src/xminigrid/__init__.py +++ b/src/xminigrid/__init__.py @@ -2,7 +2,7 @@ from .registration import make, register, registered_environments # TODO: add __all__ -__version__ = "0.2.0" +__version__ = "0.3.0" # ---------- XLand-MiniGrid environments ---------- # TODO: reconsider grid sizes and time limits after the benchmarks are generated. diff --git a/src/xminigrid/benchmarks.py b/src/xminigrid/benchmarks.py index a07f392..aede8c2 100644 --- a/src/xminigrid/benchmarks.py +++ b/src/xminigrid/benchmarks.py @@ -2,7 +2,7 @@ import os import pickle import urllib.request -from typing import Dict +from typing import Callable, Dict import jax import jax.numpy as jnp @@ -15,15 +15,22 @@ HF_REPO_ID = os.environ.get("XLAND_MINIGRID_HF_REPO_ID", "Howuhh/xland_minigrid") DATA_PATH = os.environ.get("XLAND_MINIGRID_DATA", os.path.expanduser("~/.xland_minigrid")) + NAME2HFFILENAME = { - "Trivial": "xminigrid_rulesets_trivial", + # 1M pre-sampled tasks + "trivial-1M": "trivial_1m", + "small-1M": "small_1m", + "small-dist-1M": "small_dist_1m", + "medium-1M": "medium_1m", + "high-1M": "high_1m", + # 5M pre-sampled tasks (TODO) + "trivial-5M": "", + "small-5M": "", + "small-dist-5M": "", + "medium-5M": "", + "high-5M": "", } -# NAME2HFFILENAME = { -# "base-trivial-v0": "xminigrid_rulesets_trivial", -# "extended-trivial-v0": ..., -# } - # jit compatible sampling and indexing! # You can implement your custom curriculums based on this class. @@ -53,6 +60,13 @@ def split(self, prop: float) -> tuple["Benchmark", "Benchmark"]: bench2 = jtu.tree_map(lambda a: a[idx:], self) return bench1, bench2 + def filter_split(self, fn: Callable[[jax.Array, jax.Array], bool]) -> tuple["Benchmark", "Benchmark"]: + # fn(single_goal, single_rules) -> bool + mask = jax.vmap(fn)(self.goals, self.rules) + bench1 = jtu.tree_map(lambda a: a[mask], self) + bench2 = jtu.tree_map(lambda a: a[~mask], self) + return bench1, bench2 + def load_benchmark(name: str) -> Benchmark: if name not in NAME2HFFILENAME: diff --git a/src/xminigrid/rendering/text_render.py b/src/xminigrid/rendering/text_render.py index 7b6adfb..a36a453 100644 --- a/src/xminigrid/rendering/text_render.py +++ b/src/xminigrid/rendering/text_render.py @@ -47,6 +47,8 @@ Tiles.PYRAMID: "pyramid", Tiles.GOAL: "goal", Tiles.KEY: "key", + Tiles.HEX: "hexagon", + Tiles.STAR: "star", } PLAYER_STR = {0: "^", 1: ">", 2: "V", 3: "<"} @@ -78,21 +80,35 @@ def render(grid: jax.Array, agent: AgentState | None = None) -> str: return string -# WARN: This is also for debugging mainly! -def _text_encode_tile(tile): +# WARN: This is for debugging mainly! Will refactor later if needed. +def _encode_tile(tile): return f"{COLOR_NAMES[tile[1]]} {RULE_TILE_STR[tile[0]]}" def _text_encode_goal(goal): goal_id = goal[0] if goal_id == 1: - return f"AgentHold({_text_encode_tile(goal[1:3])})" + return f"AgentHold({_encode_tile(goal[1:3])})" elif goal_id == 3: - return f"AgentNear({_text_encode_tile(goal[1:3])})" + return f"AgentNear({_encode_tile(goal[1:3])})" elif goal_id == 4: - tile_a = _text_encode_tile(goal[1:3]) - tile_b = _text_encode_tile(goal[3:5]) - return f"TileNear({tile_a}, {tile_b})" + return f"TileNear({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})" + elif goal_id == 7: + return f"TileNearUpGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})" + elif goal_id == 8: + return f"TileNearRightGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})" + elif goal_id == 9: + return f"TileNearDownGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})" + elif goal_id == 10: + return f"TileNearLeftGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})" + elif goal_id == 11: + return f"AgentNearUpGoal({_encode_tile(goal[1:3])})" + elif goal_id == 12: + return f"AgentNearRightGoal({_encode_tile(goal[1:3])})" + elif goal_id == 13: + return f"AgentNearDownGoal({_encode_tile(goal[1:3])})" + elif goal_id == 14: + return f"AgentNearLeftGoal({_encode_tile(goal[1:3])})" else: raise RuntimeError(f"Rendering: Unknown goal id: {goal_id}") @@ -100,18 +116,27 @@ def _text_encode_goal(goal): def _text_encode_rule(rule): rule_id = rule[0] if rule_id == 1: - tile = _text_encode_tile(rule[1:3]) - prod_tile = _text_encode_tile(rule[3:5]) - return f"AgentHold({tile}) -> {prod_tile}" + return f"AgentHold({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}" elif rule_id == 2: - tile = _text_encode_tile(rule[1:3]) - prod_tile = _text_encode_tile(rule[3:5]) - return f"AgentNear({tile}) -> {prod_tile}" + return f"AgentNear({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}" elif rule_id == 3: - tile_a = _text_encode_tile(rule[1:3]) - tile_b = _text_encode_tile(rule[3:5]) - prod_tile = _text_encode_tile(rule[5:7]) - return f"TileNear({tile_a}, {tile_b}) -> {prod_tile}" + return f"TileNear({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}" + elif rule_id == 4: + return f"TileNearUpRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}" + elif rule_id == 5: + return f"TileNearRightRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}" + elif rule_id == 6: + return f"TileNearDownRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}" + elif rule_id == 7: + return f"TileNearLeftRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}" + elif rule_id == 8: + return f"AgentNearUpRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}" + elif rule_id == 9: + return f"AgentNearRightRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}" + elif rule_id == 10: + return f"AgentNearDownRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}" + elif rule_id == 11: + return f"AgentNearLeftRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}" else: raise RuntimeError(f"Rendering: Unknown rule id: {rule_id}") @@ -128,4 +153,4 @@ def print_ruleset(ruleset: RuleSet): print("INIT TILES:") for tile in ruleset.init_tiles.tolist(): if tile[0] != 0: - print(_text_encode_tile(tile)) + print(_encode_tile(tile)) diff --git a/training/train_meta_task.py b/training/train_meta_task.py index 7d12df4..8bf4535 100644 --- a/training/train_meta_task.py +++ b/training/train_meta_task.py @@ -33,9 +33,9 @@ class TrainConfig: project: str = "xminigrid" group: str = "default" - name: str = "meta_task_ppo" - env_id: str = "XLand-Minigrid-R1-8x8" - benchmark_id: str = "Trivial" + name: str = "meta-task-ppo" + env_id: str = "XLand-Minigrid-R1-9x9" + benchmark_id: str = "trivial-1m" # agent action_emb_dim: int = 16 rnn_hidden_dim: int = 1024 diff --git a/training/train_single_task.py b/training/train_single_task.py index ed74ed1..0c63f49 100644 --- a/training/train_single_task.py +++ b/training/train_single_task.py @@ -27,7 +27,7 @@ class TrainConfig: project: str = "xminigrid" group: str = "default" - name: str = "single_task_ppo" + name: str = "single-task-ppo" env_id: str = "MiniGrid-Empty-6x6" # agent action_emb_dim: int = 16