Skip to content

Commit 0e39eaa

Browse files
committed
Add norm
1 parent 5c7b6cb commit 0e39eaa

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

train_robot_dis.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def forward(self, x):
6060
hidden_dim = 128
6161
num_layers = 1
6262
num_heads = 4
63-
num_bins = 512
63+
num_bins = 64
6464
sequence_length = 100
6565

6666
# Read the robot data from the CSV file
@@ -84,6 +84,9 @@ def forward(self, x):
8484
data = data[joints]
8585
trajectory_dim = len(joints)
8686

87+
# Normalize the LKnee joint data (per joint) / center and scale
88+
data = (data - data.mean()) / data.std()
89+
8790
# Plot the LKnee joint data
8891
plt.figure(figsize=(12, 6))
8992
plt.plot(data)
@@ -105,6 +108,10 @@ def forward(self, x):
105108
num_samples = real_trajectories.size(0)
106109
real_trajectories = real_trajectories[torch.randperm(real_trajectories.size(0))]
107110

111+
112+
# Limit to -1 to 1
113+
real_trajectories = torch.tanh(real_trajectories)
114+
108115
# Subplot each joint, showing the first n batches
109116
n = 1
110117
plt.figure(figsize=(12, 6))
@@ -122,20 +129,20 @@ def forward(self, x):
122129
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
123130

124131
# Create batches of data
125-
batch_size = 32
132+
batch_size = 16
126133

127134
# Training loop
128-
for epoch in tqdm(range(20)): # Number of training epochs
135+
for epoch in tqdm(range(5)): # Number of training epochs
129136
for batch in range(num_samples // batch_size):
130137
targets = real_trajectories[batch * batch_size: (batch + 1) * batch_size].to(device)
131138

132139
optimizer.zero_grad()
133140

134-
# Map -1 pi - pi to 0 - 1
135-
targets_scaled = (targets + np.pi) / (2 * np.pi)
141+
# Map the data to the range 0 to 1
142+
targets_scaled = (targets + 1) / 2
136143

137144
# Discretize into num_bins
138-
targets_binned = (targets_scaled * num_bins).long()
145+
targets_binned = (targets_scaled * (num_bins - 1)).long()
139146

140147
# Make floating point tensors
141148
targets_binned = targets_binned.unsqueeze(-1).to(device)
@@ -175,7 +182,8 @@ def sample_trajectory(steps=20):
175182
probabilities.append(predicted_bin[:, -1, 0].squeeze(0).softmax(-1).cpu().detach().numpy())
176183

177184
# Sample top bin as the next velocity
178-
sampled_bin = torch.multinomial(predicted_bin[:, -1].softmax(-1).squeeze(), 1, replacement=True)
185+
#sampled_bin = torch.multinomial(predicted_bin[:, -1].softmax(-1).squeeze(), 1, replacement=True)
186+
_, sampled_bin = torch.topk(predicted_bin[:, -1].softmax(-1).squeeze(), 1)
179187

180188
# Only keep the last predicted bin
181189
sampled_bin = sampled_bin[:, -1]
@@ -204,15 +212,11 @@ def sample_trajectory(steps=20):
204212
sampled_trajectory = sample_trajectory(steps=99)
205213
# Coverting the sampled trajectory to a numpy array
206214
sampled_trajectory = np.array(sampled_trajectory)
207-
# Convert back to radians
208-
sampled_trajectory = (sampled_trajectory / num_bins) * (2 * np.pi) - np.pi
209215
plt.figure(figsize=(12, 6))
210216
# plot the sampled trajectory for each joint in a subplot
211217
for j in range(trajectory_dim):
212218
plt.subplot(3, 4, j + 1)
213219
plt.plot(sampled_trajectory[:, j], label="Sampled Trajectory")
214-
# Fix limits to -pi to pi
215-
plt.ylim(-np.pi, np.pi)
216220
plt.title(f"Joint {joints[j]}")
217221
plt.legend()
218222
plt.show()

0 commit comments

Comments
 (0)