Skip to content

Commit bdbf56f

Browse files
committed
initial commit
0 parents  commit bdbf56f

34 files changed

+16717
-0
lines changed

.gitignore

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/
130+
131+
# Ignore MAC service files
132+
.DS_Store

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2023 AIRI - Artificial Intelligence Research Institute
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# When to Switch: Planning and Learning For Partially Observable Multi-Agent Pathfinding
2+
3+
This repository provides the implementation of the "When to Switch" paper, offering various policies and algorithms
4+
designed to address the challenging problem of finding non-conflicting paths for a set of agents in an environment that
5+
is only partially observable to each agent (PO-MAPF).
6+
The repository includes two main policies: one is based on search-based re-planning (**RePlan**), and the other is based
7+
on reinforcement learning (**EPOM**).
8+
Additionally, the repository features three implementations of mixed policies, which switch between **RePlan** and **EPOM**.
9+
10+
## Installation
11+
12+
Install all dependencies using:
13+
14+
```bash
15+
pip install -r docker/requirements.txt
16+
```
17+
18+
## Inference Example
19+
20+
21+
To download pretrained weights, use this [link](https://drive.google.com/file/d/1LMu2YOxzQbWDDacQaV7R-Pizkvjpp8R_/view?usp=sharing)
22+
23+
Execute **EPOM**, **RePlan**, **ASwitcher**, **LSwitcher**, and **HSwitcher** to generate animations using pre-trained
24+
weights with the following command:
25+
26+
```bash
27+
python example.py
28+
```
29+
30+
31+
The animations will be stored in the ```renders``` folder.
32+
33+
## Training EPOM
34+
35+
To train **EPOM**, execute ```train_epom.py``` with the ```learning/train.yaml``` config file:
36+
37+
```bash
38+
python train_epom.py --config_path="learning/train.yaml"
39+
```
40+
41+
## Training LSwitcher
42+
43+
To train **LSwitcher** estimator for the **RePlan** or **EPOM** algorithm, use the commands below:
44+
45+
```bash
46+
python train_lswitcher.py --algo="RePlan"
47+
```
48+
49+
```bash
50+
python train_lswitcher.py --algo="EPOM"
51+
```
52+
53+
## Citation
54+
55+
If you use this repository in your research or wish to reference it, please cite our TNNLS paper:
56+
57+
```bibtex
58+
@article{skrynnik2023switch,
59+
title = {When to Switch: Planning and Learning for Partially Observable Multi-Agent Pathfinding},
60+
author = {Skrynnik, Alexey and Andreychuk, Anton and Yakovlev, Konstantin and Panov, Aleksandr I},
61+
journal = {IEEE Transactions on Neural Networks and Learning Systems},
62+
year = {2023},
63+
publisher = {IEEE}
64+
}
65+
```

agents/assistant_switcher.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from pathlib import Path
2+
3+
from agents.utils_agents import run_algorithm
4+
5+
try:
6+
from typing import Literal
7+
except ImportError:
8+
from typing_extensions import Literal
9+
10+
from pydantic import Extra
11+
12+
from agents.replan import RePlanConfig
13+
from agents.utils_switching import SwitcherBaseConfig, SwitcherBase
14+
15+
16+
class ASwitcherConfig(SwitcherBaseConfig, extra=Extra.forbid):
17+
name: Literal['ASwitcher'] = 'ASwitcher'
18+
planning: RePlanConfig = RePlanConfig(name='RePlanCPP', fix_loops=True, add_none_if_loop=True, no_path_random=False,
19+
use_best_move=False, fix_nones=False)
20+
21+
22+
class AssistantSwitcher(SwitcherBase):
23+
24+
def get_learning_use_mask(self, planning_actions, learning_actions, observations):
25+
return [a is None for a in planning_actions]
26+
27+
28+
def example_assistant_switcher(map_name='sc1-AcrosstheCape', max_episode_steps=512, seed=None, num_agents=64,
29+
main_dir='./', animate=False):
30+
from agents.epom import EpomConfig
31+
algo = AssistantSwitcher(ASwitcherConfig(learning=EpomConfig(path_to_weights=str(main_dir / Path('weights/epom')))))
32+
return run_algorithm(algo, map_name, max_episode_steps, seed, num_agents, animate)
33+
34+
35+
if __name__ == '__main__':
36+
print(example_assistant_switcher(main_dir='../'))

agents/epom.py

+135
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import json
2+
from copy import deepcopy
3+
from os.path import join
4+
from pathlib import Path
5+
6+
try:
7+
from typing import Literal
8+
except ImportError:
9+
from typing_extensions import Literal
10+
11+
import torch
12+
13+
from pydantic import Extra
14+
15+
from sample_factory.algorithms.appo.actor_worker import transform_dict_observations
16+
from sample_factory.algorithms.appo.learner import LearnerWorker
17+
from sample_factory.algorithms.appo.model import create_actor_critic
18+
from sample_factory.algorithms.appo.model_utils import get_hidden_size
19+
from sample_factory.envs.create_env import create_env
20+
from sample_factory.utils.utils import AttrDict
21+
22+
from agents.utils_agents import AlgoBase, run_algorithm
23+
from learning.epom_config import Environment
24+
from learning.grid_memory import MultipleGridMemory
25+
from pomapf_env.wrappers import MatrixObservationWrapper
26+
27+
from train_epom import validate_config, register_custom_components
28+
29+
30+
class EpomConfig(AlgoBase, extra=Extra.forbid):
31+
name: Literal['EPOM'] = 'EPOM'
32+
path_to_weights: str = "weights/epom"
33+
34+
35+
class EPOM:
36+
def __init__(self, algo_cfg):
37+
self.algo_cfg: EpomConfig = algo_cfg
38+
39+
path = algo_cfg.path_to_weights
40+
device = algo_cfg.device
41+
register_custom_components()
42+
43+
self.path = path
44+
self.env = None
45+
config_path = join(path, 'cfg.json')
46+
with open(config_path, "r") as f:
47+
config = json.load(f)
48+
exp, flat_config = validate_config(config['full_config'])
49+
algo_cfg = flat_config
50+
51+
env = create_env(algo_cfg.env, cfg=algo_cfg, env_config={})
52+
actor_critic = create_actor_critic(algo_cfg, env.observation_space, env.action_space)
53+
env.close()
54+
55+
if device == 'cpu' or not torch.cuda.is_available():
56+
device = torch.device('cpu')
57+
else:
58+
device = torch.device('cuda')
59+
self.device = device
60+
61+
actor_critic.model_to_device(device)
62+
policy_id = algo_cfg.policy_index
63+
checkpoints = join(path, f'checkpoint_p{policy_id}')
64+
checkpoints = LearnerWorker.get_checkpoints(checkpoints)
65+
checkpoint_dict = LearnerWorker.load_checkpoint(checkpoints, device)
66+
actor_critic.load_state_dict(checkpoint_dict['model'])
67+
68+
self.ppo = actor_critic
69+
self.device = device
70+
self.cfg = algo_cfg
71+
72+
self.rnn_states = None
73+
self.mgm = MultipleGridMemory()
74+
self._step = 0
75+
76+
def after_reset(self):
77+
torch.manual_seed(self.algo_cfg.seed)
78+
self.mgm.clear()
79+
self._step = 0
80+
81+
def get_additional_info(self):
82+
result = {"rl_used": 1.0, }
83+
return result
84+
85+
def get_name(self):
86+
return Path(self.path).name
87+
88+
def act(self, observations, rewards=None, dones=None, infos=None):
89+
observations = deepcopy(observations)
90+
if self.rnn_states is None or len(self.rnn_states) != len(observations):
91+
self.rnn_states = torch.zeros([len(observations), get_hidden_size(self.cfg)], dtype=torch.float32,
92+
device=self.device)
93+
env_cfg: Environment = Environment(**self.cfg.full_config['environment'])
94+
self.mgm.update(observations)
95+
gm_radius = env_cfg.grid_memory_obs_radius
96+
self.mgm.modify_observation(observations, obs_radius=gm_radius if gm_radius else env_cfg.grid_config.obs_radius)
97+
observations = MatrixObservationWrapper.to_matrix(observations)
98+
99+
with torch.no_grad():
100+
101+
obs_torch = AttrDict(transform_dict_observations(observations))
102+
for key, x in obs_torch.items():
103+
obs_torch[key] = torch.from_numpy(x).to(self.device).float()
104+
policy_outputs = self.ppo(obs_torch, self.rnn_states, with_action_distribution=True)
105+
106+
self.rnn_states = policy_outputs.rnn_states
107+
actions = policy_outputs.actions
108+
109+
self._step += 1
110+
result = actions.cpu().numpy()
111+
return result
112+
113+
def clear_hidden(self, agent_idx):
114+
if self.rnn_states is not None:
115+
self.rnn_states[agent_idx] = torch.zeros([get_hidden_size(self.cfg)], dtype=torch.float32,
116+
device=self.device)
117+
118+
def after_step(self, dones):
119+
for agent_idx, done_flag in enumerate(dones):
120+
if done_flag:
121+
self.clear_hidden(agent_idx)
122+
123+
if all(dones):
124+
self.rnn_states = None
125+
self.mgm.clear()
126+
127+
128+
def example_epom(map_name='sc1-AcrosstheCape', max_episode_steps=512, seed=None, num_agents=64, main_dir='./',
129+
animate=False):
130+
algo = EPOM(EpomConfig(path_to_weights=str(main_dir / Path('weights/epom'))))
131+
return run_algorithm(algo, map_name, max_episode_steps, seed, num_agents, animate)
132+
133+
134+
if __name__ == '__main__':
135+
print(example_epom(main_dir='../'))

0 commit comments

Comments
 (0)