This is a package for training control policies through motion capture tracking using deep reinforcement learning.
- Python 3.11 or 3.12
- uv package manager (recommended) or pip
- CUDA 12.x or 13.x (for GPU support, optional)
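As a quick sanity check of the Python requirement above, the following minimal sketch compares the running interpreter against the listed versions. The helper name `python_version_supported` is hypothetical and not part of this package.

```python
import sys

# (major, minor) pairs taken from the prerequisites list above.
SUPPORTED_VERSIONS = {(3, 11), (3, 12)}


def python_version_supported(version_info=sys.version_info):
    """Return True if the (major, minor) pair is one of the listed versions."""
    return (version_info[0], version_info[1]) in SUPPORTED_VERSIONS


if __name__ == "__main__":
    major, minor = sys.version_info[0], sys.version_info[1]
    status = "listed as supported" if python_version_supported() else "not listed as supported"
    print(f"Python {major}.{minor} is {status}")
```

Running this with the interpreter you plan to use for the virtual environment shows whether it falls in the listed range before anything is installed.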
If you don't have uv installed:
# Linux/macOS
curl -LsSf https://astral.sh/uv/install.sh | sh
# Or using pip
pip install uv
- Clone the repository:
git clone https://github.com/talmolab/track-mjx.git
cd track-mjx
- Create and activate a virtual environment:
uv venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
- Install the package with your CUDA version:
# For CUDA 12.x:
uv pip install -e ".[cuda12]"
# For CUDA 13.x:
uv pip install -e ".[cuda13]"
# For CPU-only:
uv pip install -e .
# For development (includes testing and documentation tools):
uv pip install -e ".[cuda13,dev]"
- Verify the installation:
python -c "import jax; print(f'JAX version: {jax.__version__}'); print(f'Available devices: {jax.devices()}')"
- Test the environment:
Execute the tests in notebooks/test_setup.ipynb. This checks whether MuJoCo, GPU support, and JAX appear to be working.
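Before opening the notebook, a lighter check is to confirm that the key modules can be located at all. The sketch below is not taken from the notebook; it assumes the relevant import names are `jax` and `mujoco`, and the helper `missing_modules` is hypothetical.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed import names for JAX and MuJoCo.
    missing = missing_modules(["jax", "mujoco"])
    if missing:
        print("Not importable:", ", ".join(missing))
    else:
        print("jax and mujoco were both found")
```

This only verifies that the modules are installed in the active environment; it does not exercise GPU support, which the notebook covers.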
If you prefer using pip instead of uv:
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
pip install -e ".[cuda13]" # or cuda12
CUDA version mismatch:
- Check your CUDA version: nvcc --version or nvidia-smi
- Ensure you install the matching JAX CUDA version (cuda12 or cuda13)
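To make the matching step concrete, here is a minimal sketch that extracts the CUDA major version from the text printed by nvcc --version and maps it to one of the two extras named above. The function names are hypothetical, and the sample line only illustrates the usual "release X.Y" format. Note that nvidia-smi reports the highest CUDA version the driver supports, which may differ from the installed toolkit.

```python
import re


def cuda_major_from_nvcc(output):
    """Return the CUDA major version found in `nvcc --version` text, or None."""
    match = re.search(r"release\s+(\d+)\.(\d+)", output)
    return int(match.group(1)) if match else None


def jax_extra_for_cuda(major):
    """Map a CUDA major version to the extras named in this README, or None."""
    if major in (12, 13):
        return f"cuda{major}"
    return None


if __name__ == "__main__":
    # Illustrative sample of the line nvcc prints; real output has more lines.
    sample = "Cuda compilation tools, release 12.4, V12.4.131"
    print(jax_extra_for_cuda(cuda_major_from_nvcc(sample)))
```

A result of None means neither extra applies, in which case the CPU-only install is the fallback.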
Import errors:
- Verify the virtual environment is activated
- Try reinstalling:
uv pip install --force-reinstall -e ".[cuda13]"
GPU not detected:
- Verify CUDA installation: nvidia-smi
- Check that JAX can see GPUs:
python -c "import jax; print(jax.devices())"
Expected output:
- GPU: Should show cuda or gpu devices
- CPU: Should show cpu device
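The expected output above can also be checked programmatically. The sketch below classifies the platform names reported by JAX devices; `describe_platforms` is a hypothetical helper, and it treats both "gpu" and "cuda" as GPU platforms to match the wording above.

```python
def describe_platforms(platforms):
    """Classify a list of JAX device platform names as 'gpu', 'cpu-only', or 'other'."""
    names = {p.lower() for p in platforms}
    if names & {"gpu", "cuda"}:
        return "gpu"
    if names == {"cpu"}:
        return "cpu-only"
    return "other"


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("JAX is not importable; check that the virtual environment is active")
    else:
        # Each JAX device exposes its platform name via the `platform` attribute.
        print(describe_platforms([d.platform for d in jax.devices()]))
```

A "cpu-only" result on a machine with a GPU usually points back to the CUDA version mismatch steps earlier in this section.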
- Clone the repository:
git clone https://github.com/talmolab/track-mjx.git && cd track-mjx
- Create a new development environment via conda (this will create the necessary base environment):
conda env create -f environment.yml
- Activate the environment:
conda activate track-mjx
- Install the package with the desired CUDA version:
If your machine supports up to CUDA 13:
pip install -e ".[cuda13]"
If your machine supports up to CUDA 12:
pip install -e ".[cuda12]"
- Test the environment:
Execute the tests in notebooks/test_setup.ipynb. This checks whether MuJoCo, GPU support, and JAX appear to be working.
The main training entrypoint is defined in track_mjx/train.py and relies on the config in track_mjx/config/rodent-full-clips.yaml.
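As a rough illustration of how the entrypoint, config, and data file fit together, the sketch below checks that the expected paths exist and assembles the same launch command shown later in this section. Both helpers are hypothetical and not part of track_mjx; the paths are the ones named in this README.

```python
import sys
from pathlib import Path


def build_train_command(data_path, config_name="rodent-full-clips.yaml"):
    """Return the argument list for launching training with the given data file."""
    return [
        sys.executable, "-m", "track_mjx.train",
        f"data_path={data_path}",
        "--config-name", config_name,
    ]


def missing_paths(repo_root, data_path):
    """Return the expected files (entrypoint, config, data) that do not exist."""
    root = Path(repo_root)
    candidates = [
        root / "track_mjx" / "train.py",
        root / "track_mjx" / "config" / "rodent-full-clips.yaml",
        root / data_path,
    ]
    return [str(p) for p in candidates if not p.exists()]


if __name__ == "__main__":
    data = "data/rodent/rodent_reference_clips.h5"
    missing = missing_paths(".", data)
    if missing:
        print("Missing before training:", *missing, sep="\n  ")
    else:
        print(" ".join(build_train_command(data)))
```

Run from the repository root, this reports a missing data file before a training job is started; the printed command can then be passed to subprocess.run or copied into a terminal.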
To download the data, run notebooks/download_and_run_rodent.ipynb, or execute the following command in a terminal:
uv run python -c "from huggingface_hub import hf_hub_download; hf_hub_download(repo_id='talmolab/MIMIC-MJX', repo_type='dataset', filename='data/rodent/rodent_reference_clips.h5', local_dir='.')"
Using uv:
uv run python -m track_mjx.train data_path="data/rodent/rodent_reference_clips.h5" --config-name rodent-full-clips.yaml
Using conda:
conda activate track_mjx
python -m track_mjx.train data_path="data/rodent/rodent_reference_clips.h5" --config-name rodent-full-clips.yaml
This package is distributed under a BSD 3-Clause License and can be used without restrictions. See LICENSE for details.