
Commit 87e04fe — Update README.md
1 parent 13b0fae

File tree: 6 files changed, +79 −16 lines

.gitignore — 1 addition, 0 deletions

@@ -145,6 +145,7 @@ checkpoints/
 dataset/ycb/YCB_Video_Dataset
 
 # evaluation
+eval/*/debug/*
 eval/ycb/YCB_Video_toolbox/
 eval/ycb/eval_results*
 eval/dttd/eval_results*

README.md — 18 additions, 6 deletions

@@ -91,9 +91,12 @@ bash eval.sh
 ```
 You can customize your own eval command, for example:
 ```bash
-python3 eval.py --dataset_root ../dataset/dttd_iphone/DTTD_IPhone_Dataset/root --model ../checkpoints/m8p4.pth --output eval_results --visualize
+python eval.py --dataset_root ./dataset/dttd_iphone/DTTD_IPhone_Dataset/root \
+    --model ./checkpoints/m2p1.pth \
+    --base_latent 256 --embed_dim 512 --fusion_block_num 1 --layer_num_m 2 --layer_num_p 1 \
+    --visualize --output eval_results_m8p4_model_filtered_best
 ```
-
+To load a model with the filter-enhanced MLP, add the `--filter` flag.
 To visualize the attention map and/or the distribution of the reduced geometric embeddings, add the `--debug` flag.
 
 ### Eval
@@ -102,16 +105,25 @@ This is the [ToolBox](https://github.com/yuxng/YCB_Video_toolbox) that we used f
 ### Train
 To run training of our method, you can use:
 ```bash
-python train.py --dataset dttd_iphone --output_dir ./result/train_result --device 0 --batch_size 1 --lr 1e-6 --min_lr 1e-7 --warm_epoch 1 --pretrain ./checkpoints/m8p4_filter_modelrecon.pth
+python train.py --device 0 \
+    --dataset iphone --dataset_root ./dataset/dttd_iphone/DTTD_IPhone_Dataset/root --dataset_config ./dataset/dttd_iphone/dataset_config \
+    --output_dir ./result/result \
+    --base_latent 256 --embed_dim 512 --fusion_block_num 1 --layer_num_m 8 --layer_num_p 4 \
+    --recon_w 0.3 --recon_choice depth \
+    --loss adds --optim_batch 4 \
+    --start_epoch 0 \
+    --lr 1e-5 --min_lr 1e-6 --lr_rate 0.3 --decay_margin 0.033 --decay_rate 0.82 --nepoch 60 --warm_epoch 1 \
+    --filter_enhance
 ```
+To train a smaller model, set `--layer_num_m 2 --layer_num_p 1`.
 To enable the depth-robustifying modules, add `--filter_enhance` and/or `--recon_choice model`.
 
-To adjust the weight of Chamfer Distance Loss to 0.5, you can set flags `--reon_weight 0.5`.
+To adjust the weight of the Chamfer Distance Loss to 0.5, set `--recon_w 0.5`.
 
 Our model is also applicable to the YCBV_Dataset and DTTD_v1; use the following commands to train on those datasets (make sure you have downloaded the dataset you want to train on):
 ```bash
-python train.py --dataset ycb --output_dir ./result/train_result --device 0 --batch_size 1 --lr 1e-6 --min_lr 1e-7 --warm_epoch 1
-python train.py --dataset dttd --output_dir ./result/train_result --device 0 --batch_size 1 --lr 1e-6 --min_lr 1e-7 --warm_epoch 1
+python train.py --dataset ycb --output_dir ./result/train_result --device 0 --batch_size 1 --lr 8e-5 --min_lr 8e-6 --warm_epoch 1
+python train.py --dataset dttd --output_dir ./result/train_result --device 0 --batch_size 1 --lr 1e-5 --min_lr 1e-6 --warm_epoch 1
 ```
 
 ### Citation
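As a side note, the model flags added to the README commands above map one-to-one onto the `PoseNet` constructor that this commit also touches (see the `model/posefusion.py` diff below). A minimal, hedged sketch of that mapping — the import path assumes the repository is on `PYTHONPATH`, and the `num_points`/`num_obj` values are placeholders, not values from this commit:

```python
# Hedged sketch: README flags -> PoseNet constructor arguments.
from model.posefusion import PoseNet

net = PoseNet(num_points=1000, num_obj=21,        # placeholder dataset sizes
              base_latent=256, embedding_dim=512,  # --base_latent / --embed_dim
              fusion_block_num=1,                  # --fusion_block_num
              layer_num_m=8, layer_num_p=4,        # --layer_num_m / --layer_num_p
              recon_choice='depth',                # --recon_choice (README also mentions 'model'; code checks 'both')
              filter_enhance=True,                 # --filter_enhance
              require_adl=False)
```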

model/model_utils.py — 36 additions, 1 deletion

@@ -83,17 +83,52 @@ def __init__(self, max_seq_length, hidden_size, hidden_dropout_prob):
         self.out_dropout = nn.Dropout(hidden_dropout_prob)
         self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
 
-    def forward(self, input_tensor):
+    def forward(self, input_tensor, return_filtered=False):
         input_tensor = input_tensor.transpose(2, 1).contiguous()
         batch, seq_len, hidden = input_tensor.shape
         x = torch.fft.rfft(input_tensor, dim=1, norm='ortho')
         weight = torch.view_as_complex(self.complex_weight)
         x = x * weight
         sequence_emb_fft = torch.fft.irfft(x, n=seq_len, dim=1, norm='ortho')
+        if return_filtered: return sequence_emb_fft
         hidden_states = self.out_dropout(sequence_emb_fft)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
         # hidden_states = sequence_emb_fft + input_tensor
         return hidden_states.transpose(2, 1).contiguous()
+
+    def visualize_frequency_domain(self, input_tensor):
+        import matplotlib.pyplot as plt
+        import matplotlib
+        font = {'family': 'Times New Roman',
+                'weight': 'bold',
+                'size': 22}
+
+        matplotlib.rc('font', **font)
+        output_tensor = self.forward(input_tensor, return_filtered=True)
+        print(output_tensor.shape)
+
+        input_tensor = input_tensor.cpu().detach().numpy()[0].T
+        output_tensor = output_tensor.cpu().detach().numpy()[0]
+        sequence_length = input_tensor.shape[0]
+
+        def save(input_tensor, fn, title):
+            # Perform SVD
+            U, s, V = np.linalg.svd(input_tensor, full_matrices=False)
+
+            projected_tensor = input_tensor @ U[0, :]
+
+            plt.figure(figsize=(10, 6))
+            plt.hist(projected_tensor, bins=1000, edgecolor='red')
+
+            plt.xlabel('Geometric Features Reduced to 1-dimensional')
+            plt.ylabel('Probability Density')
+            plt.title(title)
+            plt.grid(True)
+            plt.savefig(fn)
+
+        save(input_tensor, 'before_filtered.png', 'Before')
+        save(output_tensor, 'after_filtered.png', 'After')
+
 
 # Transformer Customization
 future_mask = torch.triu(torch.zeros([1024, 1024]).fill_(float("-inf")), 1)
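For readers skimming the diff, here is a minimal, self-contained sketch of what the modified `forward` does: an rFFT along the point/sequence dimension, multiplication by a learnable complex weight, an inverse FFT, and the new `return_filtered` early exit used by `visualize_frequency_domain`. The class name, weight initialization, and shapes below are illustrative assumptions, not the repository's exact code:

```python
# Sketch of the frequency-domain filtering in FilterLayer.forward (illustrative only).
import torch
import torch.nn as nn

class TinyFilterLayer(nn.Module):
    def __init__(self, max_seq_length, hidden_size, dropout_prob=0.0):
        super().__init__()
        # one learnable complex weight per rFFT frequency bin and channel (assumed init)
        self.complex_weight = nn.Parameter(
            torch.randn(1, max_seq_length // 2 + 1, hidden_size, 2) * 0.02)
        self.dropout = nn.Dropout(dropout_prob)
        self.norm = nn.LayerNorm(hidden_size, eps=1e-12)

    def forward(self, x, return_filtered=False):
        # x: (batch, hidden, seq_len), as produced by the point-cloud branch
        x = x.transpose(2, 1).contiguous()                 # -> (batch, seq_len, hidden)
        seq_len = x.shape[1]
        spec = torch.fft.rfft(x, dim=1, norm='ortho')      # to the frequency domain
        spec = spec * torch.view_as_complex(self.complex_weight)  # learned filter
        filtered = torch.fft.irfft(spec, n=seq_len, dim=1, norm='ortho')
        if return_filtered:                                # raw filtered signal, for plots
            return filtered
        out = self.norm(self.dropout(filtered) + x)        # residual + LayerNorm
        return out.transpose(2, 1).contiguous()            # back to (batch, hidden, seq_len)

if __name__ == "__main__":
    layer = TinyFilterLayer(max_seq_length=1000, hidden_size=256)
    emb = torch.randn(2, 256, 1000)                        # fake geometric embeddings
    print(layer(emb).shape)                                # torch.Size([2, 256, 1000])
```

The `return_filtered` branch is what lets `visualize_frequency_domain` histogram the filtered embeddings before dropout and the residual LayerNorm are applied.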

model/posefusion.py — 24 additions, 7 deletions

@@ -197,7 +197,7 @@ def _make_layer(self, base_latent, embed_dim, n_layer1, n_layer2, require_adl):
 class PoseNet(nn.Module):
     def __init__(self, num_points, num_obj, \
                  base_latent=256, embedding_dim=512, fusion_block_num=1, layer_num_m=2, layer_num_p=4, \
-                 filter_enhance=True, require_adl=True):
+                 recon_choice='depth', filter_enhance=True, require_adl=False):
         super(PoseNet, self).__init__()
         self.num_points = num_points
         self.num_obj = num_obj
@@ -211,8 +211,10 @@ def __init__(self, num_points, num_obj, \
 
         # unimodal embedding
         self.cnn = ModifiedResnet(base_latent)
+        self.recon_choice = recon_choice
         self.ptnet = PointCloudAE(256, num_points, base_latent)
-        self.filter_enhance = None if not filter_enhance else FilterLayer(num_points, base_latent, 0.0)
+        self.modelnet = PointCloudAE(256, num_points, base_latent) if recon_choice == 'both' else None
+        self.filter_enhance = FilterLayer(num_points, base_latent, 0.0) if filter_enhance else None
 
         # modality and position interaction
         self.fusion = PoseFusion(base_latent, embedding_dim, \
@@ -236,25 +238,35 @@ def forward(self, img, x, choose, obj, recon_ref=None):
         out_img = self.cnn(img)
         bs, di, _, _ = out_img.size()
         emb = out_img.view(bs, di, -1)
+        robust_loss = 0
 
         # selection of rgb color embedding
         choose = choose.repeat(1, di, 1)
         rgb_emb = torch.gather(emb, 2, choose).contiguous()
 
         # depth map / point cloud (embedding)
-        pt_feat, pt_emb, pt_recon, extra_loss = self.ptnet(x, None, recon_ref)
+        if self.recon_choice == 'both':
+            object_geo = recon_ref[1]
+            recon_ref = recon_ref[0]
+        pt_feat, pt_emb, pt_recon, cdl_0 = self.ptnet(x, None, recon_ref)
+        robust_loss += cdl_0
         pt_emb = self.ptnet.latent(pt_feat, pt_emb)
         if self.filter_enhance is not None:
             pt_emb = self.filter_enhance(pt_emb)
-
+        if self.recon_choice == 'both':
+            _, obj_emb, _, cdl_1 = self.modelnet(x, None, object_geo)
+            robust_loss += cdl_1
+            pt_emb += obj_emb
         feat = self.fusion(rgb_emb, pt_emb)
+
         if self.require_adl:
-            extra_loss += feat[1]
+            adl = feat[1]
             feat = feat[0]
+            robust_loss += adl
 
         out_rx, out_tx, out_cx = self.posepred(feat, obj)
 
-        return out_rx, out_tx, out_cx, rgb_emb.detach(), pt_recon.detach(), extra_loss
+        return out_rx, out_tx, out_cx, rgb_emb.detach(), pt_recon.detach(), robust_loss
 
     def get_attention_map(self, img, x, choose):
         # rgb color embedding
@@ -275,4 +287,9 @@ def get_attention_map(self, img, x, choose):
         _, _, attn1, attn2 = self.fusion.layers[0](rgb_emb, pt_emb, require_attn=True)
 
         return attn1, attn2
-
+
+    def get_freq_domain(self, x):
+        pt_feat, pt_emb, _, _ = self.ptnet(x, None, None)
+        pt_emb = self.ptnet.latent(pt_feat, pt_emb)
+        assert self.filter_enhance is not None, "filter enhanced MLP is not applied."
+        freq_domain = self.filter_enhance.visualize_frequency_domain(pt_emb)
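The forward-pass change above folds up to three terms into the single `robust_loss` scalar that replaces `extra_loss`: the Chamfer Distance term from `self.ptnet` (`cdl_0`), a second one from the new `self.modelnet` when `recon_choice == 'both'` (in which case `recon_ref` is expected to carry both the depth reference and the object geometry), and the extra term returned by the fusion block when `require_adl` is set (`adl`). Below is a hedged, standalone sketch of that accumulation pattern; `chamfer_like`, `accumulate_robust_loss`, and the tensor shapes are illustrative stand-ins, not the repository's API:

```python
# Sketch of how PoseNet.forward now assembles `robust_loss` (names are illustrative).
import torch

def chamfer_like(pred: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # stand-in for the Chamfer Distance used by PointCloudAE; inputs are (B, N, 3)
    d = torch.cdist(pred, ref)                                 # pairwise distances
    return d.min(dim=2).values.mean() + d.min(dim=1).values.mean()

def accumulate_robust_loss(depth_recon, depth_ref,
                           model_recon=None, model_ref=None,
                           adl=None):
    """Sum of the reconstruction terms (and the optional adl term) returned as one scalar."""
    robust_loss = chamfer_like(depth_recon, depth_ref)          # cdl_0
    if model_recon is not None and model_ref is not None:       # recon_choice == 'both'
        robust_loss = robust_loss + chamfer_like(model_recon, model_ref)  # cdl_1
    if adl is not None:                                          # require_adl == True
        robust_loss = robust_loss + adl
    return robust_loss

if __name__ == "__main__":
    a, b = torch.randn(2, 500, 3), torch.randn(2, 500, 3)
    print(accumulate_robust_loss(a, b, model_recon=a, model_ref=b,
                                 adl=torch.tensor(0.1)).item())
```

In the repository the actual reconstruction losses come from `PointCloudAE`, and their weight in the training objective is controlled by the `--recon_w` flag shown in the README diff.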

run/run_train_densefusion.sh — file deleted (1 line removed)

run/run_train_densefusion_GADD.sh — file deleted (1 line removed)
