-
Notifications
You must be signed in to change notification settings - Fork 131
/
examine_data.py
39 lines (34 loc) · 1.23 KB
/
examine_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
""" Some data examination """
import numpy as np
import matplotlib.pyplot as plt
def plot_rollout():
    """ Step through one rollout sequence and display it frame by frame.

    Builds a ``DataLoader`` (batch size 1, shuffled) over a
    ``RolloutSequenceDataset`` rooted at ``datasets/carracing``, takes the
    first batch it yields and walks through that sequence. Three panels of
    a 2x2 grid show the observation, the next observation and their
    difference; the matching action is printed at every step.
    """
    from torch.utils.data import DataLoader
    from data.loaders import RolloutSequenceDataset

    rollouts = RolloutSequenceDataset(
        root='datasets/carracing', seq_len=900,
        transform=lambda x: x, buffer_size=10,
        train=False)
    loader = DataLoader(rollouts, batch_size=1, shuffle=True)
    # the buffer is loaded explicitly before the loader is iterated
    loader.dataset.load_next_buffer()

    # one placeholder image per panel, overwritten at every step
    panels = []
    for slot in (1, 2, 3):
        plt.subplot(2, 2, slot)
        panels.append(plt.imshow(np.zeros((64, 64, 3))))
    obs_panel, next_obs_panel, diff_panel = panels

    # only the first batch produced by the loader is displayed
    batch = next(iter(loader), None)
    if batch is None:
        return

    # NOTE(review): the last batch element is treated as the next
    # observations -- confirm against RolloutSequenceDataset
    frames = batch[0].numpy().squeeze()
    actions = batch[1].numpy().squeeze()
    next_frames = batch[-1].numpy().squeeze()

    for frame, act, next_frame in zip(frames, actions, next_frames):
        obs_panel.set_data(frame)
        next_obs_panel.set_data(next_frame)
        diff_panel.set_data(next_frame - frame)
        print(act)
        # short pause so the figure gets redrawn between steps
        plt.pause(.01)
# Show the rollout only when this file is run as a script, not on import.
if __name__ == '__main__':
    plot_rollout()