Skip to content

Commit

Permalink
Merge pull request #36 from tartavull/aux-tooling
Browse files Browse the repository at this point in the history
Added a simple Tool to 'Wiggle' Models
  • Loading branch information
mginoya authored Oct 30, 2024
2 parents 06c8e23 + 856dca6 commit b836c51
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 5 deletions.
12 changes: 7 additions & 5 deletions alfredo/agents/aant/aant.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def __init__(self,

# reward dictionary must be provided
if rewards:
self._rewards = rewards
self._rewards = rewards
else:
raise Exception("reward_Structure must be in kwargs")
self._rewards = {}
# raise Exception("reward_Structure must be in kwargs")

# TODO: clean this up in the future &
# make n_frames a function of input dt
Expand Down Expand Up @@ -115,9 +116,10 @@ def step(self, state: State, action: jax.Array) -> State:
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

# Add all additional parameters to compute rewards
self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd'])
self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd'])
if self._rewards:
# Add all additional parameters to compute rewards
self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd'])
self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd'])

# Compute all rewards and accumulate total reward
total_reward = 0.0
Expand Down
141 changes: 141 additions & 0 deletions alfredo/tools/tWiggleAgent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import functools
import os
import re
import sys
import importlib
import inspect

import brax
import jax
from brax import envs
from brax.envs.base import PipelineEnv
from brax.base import State, System
from brax.io import html, json, model
from jax import numpy as jp

from alfredo.agents import *

def generate_wiggle_traj(env: PipelineEnv, dt=0.1, motion_time=1.0):
"""
Generate html visual of wiggle trajectory.
Primarily used for debugging new models
Parameters:
- env (PipelineEnv):
- dt (float): The time step duration for which each action is applied.
- motion_time (float): The total time duration for jogging from -1 to 1.
Returns:
- HTML string
"""

# Generate Wiggle
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

rollout = []
rng = jax.random.PRNGKey(seed=0)
state = jit_env_reset(rng=rng)

wiggle_actions = generate_wiggle_actions(env.action_size, dt, motion_time)

for wa in wiggle_actions:
print(f"commanding: {wa}")
rollout.append(state.pipeline_state)
act_rng, rng = jax.random.split(rng)

state = jit_env_step(state, wa)


traj_html_str = html.render(env.sys.replace(dt=env.dt), rollout)

return traj_html_str

def generate_wiggle_actions(action_size, dt=0.1, motion_time=1.0):
"""
Generate action vectors to gradually jog each actuator from
-1 to 1 (normalized control values).
Parameters:
- action_size (int): The number of actuators in the model.
- dt (float): The time step duration for which each action is applied.
- motion_time (float): The total time duration for jogging from -1 to 1.
Returns:
- List of action vectors for jogging each actuator.
"""

actions = []

# Calculate the number of steps required for the full jog
total_steps = int(motion_time / dt)

# Calculate the increment based on the total steps
increment = 2.0 / total_steps # Since we are jogging from -1 to 1

# Generate action sequences for each actuator
for i in range(action_size):
# Jogging forward
for j in range(total_steps):
action_vector = jp.zeros(action_size)
action_vector = action_vector.at[i].set(-1.0 + increment * (j + 1)) # Gradual increase
actions.append(action_vector)

# Jogging backward
for j in range(total_steps):
action_vector = jp.zeros(action_size)
action_vector = action_vector.at[i].set(1.0 - increment * (j + 1)) # Gradual decrease
actions.append(action_vector)

return actions

if __name__ == '__main__':

backend = "positional"

# Load desired model xml and trained param set
# get filepaths from commandline args
cwd = os.getcwd()

# Get the filepath to the env and agent xmls
import alfredo.scenes as scenes
import alfredo.agents as agents

agent_name = sys.argv[-2]
module_name = f"alfredo.agents.{agent_name}"

agents_fp = os.path.dirname(agents.__file__)
agent_xml_path = f"{agents_fp}/{agent_name}/{agent_name}.xml"

scenes_fp = os.path.dirname(scenes.__file__)
env_xml_path = f"{scenes_fp}/{sys.argv[-1]}"

print(f"agent description file: {agent_xml_path}")
print(f"environment description file: {env_xml_path}")

# Find & create Agent Brax environment
env_init_params = {"backend": backend,
"env_xml_path": env_xml_path,
"agent_xml_path": agent_xml_path}

module = importlib.import_module(module_name)

classes_in_module = [member for name, member in inspect.getmembers(module, inspect.isclass)
if member.__module__.startswith(module.__name__)]

if len(classes_in_module) == 1:
agentClass = classes_in_module[0]
env = agentClass(**env_init_params)
else:
raise ImportError(f"Agent Class not Found")

traj_html_str = generate_wiggle_traj(env, dt=env.dt)

cwd = os.getcwd()
save_fp = f"{cwd}/vis-store/{agent_name}_wiggle_traj.html"
save_fp = save_fp.replace(" ", "_")

with open(save_fp, "w") as file:
file.write(traj_html_str)
print(f"saved wiggle traj visualization to {save_fp}")

0 comments on commit b836c51

Please sign in to comment.