diff --git a/mujoco_playground/experimental/learning/aloha.ipynb b/mujoco_playground/experimental/learning/aloha.ipynb deleted file mode 100644 index 963d6c3a3..000000000 --- a/mujoco_playground/experimental/learning/aloha.ipynb +++ /dev/null @@ -1,199 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n", - "xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n", - "os.environ[\"XLA_FLAGS\"] = xla_flags\n", - "os.environ[\"MUJOCO_GL\"] = \"egl\"" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "import functools\n", - "import json\n", - "from datetime import datetime\n", - "\n", - "import jax\n", - "import matplotlib.pyplot as plt\n", - "import mediapy as media\n", - "from brax.training.agents.ppo import networks as ppo_networks\n", - "from brax.training.agents.ppo import train as ppo\n", - "from IPython.display import clear_output, display\n", - "import numpy as np\n", - "\n", - "from mujoco_playground import BraxEnvWrapper, manipulation\n", - "\n", - "# Enable persistent compilation cache.\n", - "jax.config.update(\"jax_compilation_cache_dir\", \"/tmp/jax_cache\")\n", - "jax.config.update(\"jax_persistent_cache_min_entry_size_bytes\", -1)\n", - "jax.config.update(\"jax_persistent_cache_min_compile_time_secs\", 0)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "env_name = \"AlohaSinglePegInsertion\"\n", - "env_cfg = manipulation.get_default_config(env_name)\n", - "env = manipulation.load(env_name, config=env_cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from mujoco_playground.learning import manipulation_params\n", - "\n", - "ppo_params = manipulation_params.brax_ppo_config(env_name)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "x_data, y_data, y_dataerr = [], [], []\n", - "times = [datetime.now()]\n", - "\n", - "\n", - "def progress(num_steps, metrics):\n", - " clear_output(wait=True)\n", - "\n", - " times.append(datetime.now())\n", - " x_data.append(num_steps)\n", - " y_data.append(metrics[\"eval/episode_reward\"])\n", - " y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n", - "\n", - " plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n", - " plt.ylim([0, 15_000])\n", - " plt.xlabel(\"# environment steps\")\n", - " plt.ylabel(\"reward per episode\")\n", - " plt.title(f\"y={y_data[-1]:.3f}\")\n", - " plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n", - "\n", - " display(plt.gcf())\n", - "\n", - "\n", - "training_params = dict(ppo_params)\n", - "del training_params[\"network_factory\"]\n", - "train_fn = functools.partial(ppo.train, **training_params, progress_fn=progress)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "network_factory = functools.partial(\n", - " ppo_networks.make_ppo_networks,\n", - " policy_hidden_layer_sizes=ppo_params.network_factory.policy_hidden_layer_sizes,\n", - ")\n", - "make_inference_fn, params, metrics = train_fn(environment=BraxEnvWrapper(env))\n", - "print(f\"time to jit: {times[1] - times[0]}\")\n", - "print(f\"time to train: {times[-1] - times[1]}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "jit_reset = jax.jit(env.reset)\n", - "jit_step = jax.jit(env.step)\n", - "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# rng = jax.random.PRNGKey(42)\n", - "# rollout = []\n", - "# n_episodes = 1\n", - "\n", - "# for _ in range(n_episodes):\n", - "# state = jit_reset(rng)\n", - "# rollout.append(state)\n", - "# for i in range(env_cfg.episode_length):\n", - "# act_rng, rng = jax.random.split(rng)\n", - "# ctrl, _ = jit_inference_fn(state.obs, act_rng)\n", - "# state = jit_step(state, ctrl)\n", - "# rollout.append(state)\n", - "\n", - "render_every = 1\n", - "frames = env.render(\n", - " rollout[::render_every],\n", - " camera=\"teleoperator_pov\",\n", - " height=480 * 2,\n", - " width=640 * 2,\n", - ")\n", - "rewards = [s.reward for s in rollout]\n", - "media.show_video(frames, fps=1.0 / env.dt / render_every)\n", - "\n", - "plt.plot(np.convolve(rewards, np.ones(100) / 100, mode=\"valid\"))\n", - "plt.xlabel(\"time step\")\n", - "plt.ylabel(\"reward\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.15" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}