Skip to content

Commit 149b978

Browse files
renderer differentiable
1 parent 87d9de9 commit 149b978

File tree

2 files changed

+221
-90
lines changed

2 files changed

+221
-90
lines changed

notebooks/bayes3d_paper/kitti.ipynb

+221-85
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
},
114114
{
115115
"cell_type": "code",
116-
"execution_count": 7,
116+
"execution_count": 9,
117117
"metadata": {},
118118
"outputs": [],
119119
"source": [
@@ -144,19 +144,26 @@
144144
"# plt.matshow(\n",
145145
"# jnp.exp(log_kernel)\n",
146146
"# )\n",
147-
"# plt.colorbar()"
147+
"# plt.colorbar()\n",
148+
"\n",
149+
"value_and_grad_func = jax.jit(\n",
150+
" jax.value_and_grad(\n",
151+
" lambda trace, pose: b3d.update_choices_get_score(trace, Pytree.const((\"object_pose_0\",)), pose),\n",
152+
" argnums=1\n",
153+
" )\n",
154+
")"
148155
]
149156
},
150157
{
151158
"cell_type": "code",
152-
"execution_count": 29,
159+
"execution_count": 269,
153160
"metadata": {},
154161
"outputs": [
155162
{
156163
"name": "stdout",
157164
"output_type": "stream",
158165
"text": [
159-
"-164.9128\n"
166+
"-161.01285\n"
160167
]
161168
}
162169
],
@@ -207,20 +214,85 @@
207214
},
208215
{
209216
"cell_type": "code",
210-
"execution_count": 20,
217+
"execution_count": 270,
211218
"metadata": {},
212-
"outputs": [],
219+
"outputs": [
220+
{
221+
"name": "stdout",
222+
"output_type": "stream",
223+
"text": [
224+
"-81.195724\n",
225+
"[[ 0.4 36.234486 ]\n",
226+
" [ 0.37310344 48.27931 ]\n",
227+
" [ 0.4 39.67586 ]\n",
228+
" [ 0.4 8.703448 ]\n",
229+
" [ 0.38655174 37.955173 ]]\n",
230+
"852.72345\n",
231+
"[[0.01 0.13793103]\n",
232+
" [0.01 0.16413793]\n",
233+
" [0.01 0.15103447]\n",
234+
" [0.01 0.16413793]\n",
235+
" [0.01 0.13793103]]\n"
236+
]
237+
}
238+
],
213239
"source": [
214-
"addr = \"object_pose_0\"\n",
215-
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
216-
" trace, key, 0.5, 1000.0, addr, 1000\n",
240+
"# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n",
241+
"outlier_probability_sweep = jnp.linspace(0.01, 0.4, 30)\n",
242+
"blur_sweep = jnp.linspace(0.1, 50.0, 30)\n",
243+
"key = b3d.split_key(key)\n",
244+
"addresses= Pytree.const(( \"outlier_probability_0\", \"blur\"))\n",
245+
"sweeps = [ outlier_probability_sweep, blur_sweep]\n",
246+
"scores = b3d.utils.grid_trace(\n",
247+
" trace,\n",
248+
" addresses,\n",
249+
" sweeps,\n",
250+
")\n",
251+
"print(scores.max())\n",
252+
"index = jnp.unravel_index(scores.argmax(), scores.shape)\n",
253+
"\n",
254+
"sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n",
255+
" jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n",
256+
")\n",
257+
"sampled_parameters = jnp.vstack(\n",
258+
" [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n",
259+
").T\n",
260+
"\n",
261+
"print(sampled_parameters[:5])\n",
262+
"\n",
263+
"trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n",
264+
"viz_trace(trace,T)\n",
265+
"\n",
266+
"# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n",
267+
"color_sweep_sweep = jnp.linspace(0.01, 0.4, 30)\n",
268+
"depth_sweep = jnp.linspace(0.02, 0.4, 30)\n",
269+
"key = b3d.split_key(key)\n",
270+
"addresses= Pytree.const(( \"color_variance_0\", \"depth_variance_0\"))\n",
271+
"sweeps = [ color_sweep_sweep, depth_sweep]\n",
272+
"scores = b3d.utils.grid_trace(\n",
273+
" trace,\n",
274+
" addresses,\n",
275+
" sweeps,\n",
276+
")\n",
277+
"print(scores.max())\n",
278+
"index = jnp.unravel_index(scores.argmax(), scores.shape)\n",
279+
"\n",
280+
"sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n",
281+
" jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n",
217282
")\n",
218-
"viz_trace(trace, 0)"
283+
"sampled_parameters = jnp.vstack(\n",
284+
" [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n",
285+
").T\n",
286+
"\n",
287+
"print(sampled_parameters[:5])\n",
288+
"\n",
289+
"trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n",
290+
"viz_trace(trace,T)"
219291
]
220292
},
221293
{
222294
"cell_type": "code",
223-
"execution_count": 37,
295+
"execution_count": 274,
224296
"metadata": {},
225297
"outputs": [],
226298
"source": [
@@ -254,129 +326,193 @@
254326
},
255327
{
256328
"cell_type": "code",
257-
"execution_count": 26,
329+
"execution_count": 278,
258330
"metadata": {},
259-
"outputs": [],
331+
"outputs": [
332+
{
333+
"name": "stdout",
334+
"output_type": "stream",
335+
"text": [
336+
"427.17706\n"
337+
]
338+
}
339+
],
260340
"source": [
341+
"\n",
261342
"addr = \"object_pose_0\"\n",
262343
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
263-
" trace, key, 0.5, 1000.0, addr, 1000\n",
344+
" trace, key, 0.1, 2000.0, addr, 700\n",
264345
")\n",
265346
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
266-
" trace, key, 0.1, 1000.0, addr, 1000\n",
347+
" trace, key, 0.1, 1000.0, addr, 700\n",
267348
")\n",
349+
"\n",
268350
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
269-
" trace, key, 0.05, 1000.0, addr, 1000\n",
351+
" trace, key, 0.05, 2000.0, addr, 700\n",
270352
")\n",
353+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
354+
" trace, key, 0.05, 1000.0, addr, 700\n",
355+
")\n",
356+
"\n",
357+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
358+
" trace, key, 0.01, 2000.0, addr, 700\n",
359+
")\n",
360+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
361+
" trace, key, 0.01, 1000.0, addr, 700\n",
362+
")\n",
363+
"\n",
364+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
365+
" trace, key, 0.5, 2000.0, addr, 700\n",
366+
")\n",
367+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
368+
" trace, key, 0.1, 2000.0, addr, 700\n",
369+
")\n",
370+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
371+
" trace, key, 0.1, 2000.0, addr, 700\n",
372+
")\n",
373+
"print(trace.get_score())\n",
374+
"\n",
271375
"viz_trace(trace, T)"
272376
]
273377
},
274378
{
275379
"cell_type": "code",
276-
"execution_count": 27,
380+
"execution_count": 287,
277381
"metadata": {},
278-
"outputs": [],
382+
"outputs": [
383+
{
384+
"name": "stderr",
385+
"output_type": "stream",
386+
"text": [
387+
"Loss 405.0047607421875: 100%|██████████| 100/100 [00:00<00:00, 118.42it/s]\n"
388+
]
389+
},
390+
{
391+
"name": "stdout",
392+
"output_type": "stream",
393+
"text": [
394+
"406.33615\n"
395+
]
396+
}
397+
],
279398
"source": [
280-
"value_and_grad_func = jax.jit(\n",
281-
" jax.value_and_grad(\n",
282-
" lambda trace, pose: b3d.update_choices_get_score(trace, Pytree.const((\"object_pose_0\",)), pose),\n",
283-
" argnums=1\n",
284-
" )\n",
285-
")"
399+
"pbar = tqdm(range(100))\n",
400+
"for _ in pbar:\n",
401+
" loss, grad_pose = value_and_grad_func(trace, trace.get_choices()[\"object_pose_0\"])\n",
402+
" pbar.set_description(f\"Loss {loss}\")\n",
403+
" trace = b3d.update_choices(trace, Pytree.const((\"object_pose_0\",)), trace.get_choices()[\"object_pose_0\"] + grad_pose * 0.001)\n",
404+
"print(trace.get_score())\n",
405+
"viz_trace(trace, T)"
286406
]
287407
},
288408
{
289409
"cell_type": "code",
290-
"execution_count": 40,
410+
"execution_count": null,
411+
"metadata": {},
412+
"outputs": [],
413+
"source": []
414+
},
415+
{
416+
"cell_type": "code",
417+
"execution_count": 211,
291418
"metadata": {},
292419
"outputs": [
293420
{
294421
"name": "stdout",
295422
"output_type": "stream",
296423
"text": [
297-
"519.4532\n",
298-
"[[0.3462069 3.5413792 ]\n",
299-
" [0.4 3.5413792 ]\n",
300-
" [0.37310344 3.5413792 ]\n",
301-
" [0.38655174 3.5413792 ]\n",
302-
" [0.30586204 3.5413792 ]]\n"
424+
"-147.7747\n",
425+
"[[ 0.4 10.424138]\n",
426+
" [ 0.4 29.351725]\n",
427+
" [ 0.4 41.396553]\n",
428+
" [ 0.4 50. ]\n",
429+
" [ 0.4 32.793102]]\n",
430+
"29.873737\n",
431+
"[[0.03689655 0.4 ]\n",
432+
" [0.02344828 0.4 ]\n",
433+
" [0.03689655 0.4 ]\n",
434+
" [0.03689655 0.4 ]\n",
435+
" [0.03689655 0.4 ]]\n"
303436
]
304437
}
305438
],
306-
"source": [
307-
"# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n",
308-
"outlier_probability_sweep = jnp.linspace(0.01, 0.4, 30)\n",
309-
"blur_sweep = jnp.linspace(0.1, 50.0, 30)\n",
310-
"key = b3d.split_key(key)\n",
311-
"addresses= Pytree.const(( \"outlier_probability_0\", \"blur\"))\n",
312-
"sweeps = [ outlier_probability_sweep, blur_sweep]\n",
313-
"scores = b3d.utils.grid_trace(\n",
314-
" trace,\n",
315-
" addresses,\n",
316-
" sweeps,\n",
317-
")\n",
318-
"print(scores.max())\n",
319-
"index = jnp.unravel_index(scores.argmax(), scores.shape)\n",
320-
"\n",
321-
"sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n",
322-
" jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n",
323-
")\n",
324-
"sampled_parameters = jnp.vstack(\n",
325-
" [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n",
326-
").T\n",
327-
"\n",
328-
"print(sampled_parameters[:5])\n",
329-
"\n",
330-
"trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n",
331-
"viz_trace(trace,T)"
332-
]
439+
"source": []
333440
},
334441
{
335442
"cell_type": "code",
336-
"execution_count": 41,
443+
"execution_count": null,
444+
"metadata": {},
445+
"outputs": [],
446+
"source": []
447+
},
448+
{
449+
"cell_type": "code",
450+
"execution_count": 235,
337451
"metadata": {},
338452
"outputs": [
339453
{
340454
"name": "stdout",
341455
"output_type": "stream",
342456
"text": [
343-
"518.2212\n",
344-
"[[0.01 0.11172414]\n",
345-
" [0.01 0.13793103]\n",
346-
" [0.01 0.12482759]\n",
347-
" [0.01 0.11172414]\n",
348-
" [0.01 0.12482759]]\n"
457+
"-31.535\n"
349458
]
350459
}
351460
],
352461
"source": [
353-
"# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n",
354-
"color_sweep_sweep = jnp.linspace(0.01, 0.4, 30)\n",
355-
"depth_sweep = jnp.linspace(0.02, 0.4, 30)\n",
356-
"key = b3d.split_key(key)\n",
357-
"addresses= Pytree.const(( \"color_variance_0\", \"depth_variance_0\"))\n",
358-
"sweeps = [ color_sweep_sweep, depth_sweep]\n",
359-
"scores = b3d.utils.grid_trace(\n",
360-
" trace,\n",
361-
" addresses,\n",
362-
" sweeps,\n",
462+
"\n",
463+
"addr = \"object_pose_0\"\n",
464+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
465+
" trace, key, 0.5, 1000.0, addr, 200\n",
466+
")\n",
467+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
468+
" trace, key, 0.1, 2000.0, addr, 700\n",
469+
")\n",
470+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
471+
" trace, key, 0.1, 1000.0, addr, 700\n",
363472
")\n",
364-
"print(scores.max())\n",
365-
"index = jnp.unravel_index(scores.argmax(), scores.shape)\n",
366473
"\n",
367-
"sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n",
368-
" jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n",
474+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
475+
" trace, key, 0.05, 2000.0, addr, 700\n",
476+
")\n",
477+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
478+
" trace, key, 0.05, 1000.0, addr, 700\n",
369479
")\n",
370-
"sampled_parameters = jnp.vstack(\n",
371-
" [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n",
372-
").T\n",
373480
"\n",
374-
"print(sampled_parameters[:5])\n",
481+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
482+
" trace, key, 0.01, 2000.0, addr, 700\n",
483+
")\n",
484+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
485+
" trace, key, 0.01, 1000.0, addr, 700\n",
486+
")\n",
375487
"\n",
376-
"trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n",
377-
"viz_trace(trace,0)"
488+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
489+
" trace, key, 0.5, 2000.0, addr, 700\n",
490+
")\n",
491+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
492+
" trace, key, 0.1, 2000.0, addr, 700\n",
493+
")\n",
494+
"trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
495+
" trace, key, 0.1, 2000.0, addr, 700\n",
496+
")\n",
497+
"print(trace.get_score())\n",
498+
"\n",
499+
"viz_trace(trace, T)"
378500
]
379501
},
502+
{
503+
"cell_type": "code",
504+
"execution_count": null,
505+
"metadata": {},
506+
"outputs": [],
507+
"source": []
508+
},
509+
{
510+
"cell_type": "code",
511+
"execution_count": null,
512+
"metadata": {},
513+
"outputs": [],
514+
"source": []
515+
},
380516
{
381517
"cell_type": "code",
382518
"execution_count": 42,

0 commit comments

Comments
 (0)