|
113 | 113 | },
|
114 | 114 | {
|
115 | 115 | "cell_type": "code",
|
116 |
| - "execution_count": 7, |
| 116 | + "execution_count": 9, |
117 | 117 | "metadata": {},
|
118 | 118 | "outputs": [],
|
119 | 119 | "source": [
|
|
144 | 144 | "# plt.matshow(\n",
|
145 | 145 | "# jnp.exp(log_kernel)\n",
|
146 | 146 | "# )\n",
|
147 |
| - "# plt.colorbar()" |
| 147 | + "# plt.colorbar()\n", |
| 148 | + "\n", |
| 149 | + "value_and_grad_func = jax.jit(\n", |
| 150 | + " jax.value_and_grad(\n", |
| 151 | + " lambda trace, pose: b3d.update_choices_get_score(trace, Pytree.const((\"object_pose_0\",)), pose),\n", |
| 152 | + " argnums=1\n", |
| 153 | + " )\n", |
| 154 | + ")" |
148 | 155 | ]
|
149 | 156 | },
|
150 | 157 | {
|
151 | 158 | "cell_type": "code",
|
152 |
| - "execution_count": 29, |
| 159 | + "execution_count": 269, |
153 | 160 | "metadata": {},
|
154 | 161 | "outputs": [
|
155 | 162 | {
|
156 | 163 | "name": "stdout",
|
157 | 164 | "output_type": "stream",
|
158 | 165 | "text": [
|
159 |
| - "-164.9128\n" |
| 166 | + "-161.01285\n" |
160 | 167 | ]
|
161 | 168 | }
|
162 | 169 | ],
|
|
207 | 214 | },
|
208 | 215 | {
|
209 | 216 | "cell_type": "code",
|
210 |
| - "execution_count": 20, |
| 217 | + "execution_count": 270, |
211 | 218 | "metadata": {},
|
212 |
| - "outputs": [], |
| 219 | + "outputs": [ |
| 220 | + { |
| 221 | + "name": "stdout", |
| 222 | + "output_type": "stream", |
| 223 | + "text": [ |
| 224 | + "-81.195724\n", |
| 225 | + "[[ 0.4 36.234486 ]\n", |
| 226 | + " [ 0.37310344 48.27931 ]\n", |
| 227 | + " [ 0.4 39.67586 ]\n", |
| 228 | + " [ 0.4 8.703448 ]\n", |
| 229 | + " [ 0.38655174 37.955173 ]]\n", |
| 230 | + "852.72345\n", |
| 231 | + "[[0.01 0.13793103]\n", |
| 232 | + " [0.01 0.16413793]\n", |
| 233 | + " [0.01 0.15103447]\n", |
| 234 | + " [0.01 0.16413793]\n", |
| 235 | + " [0.01 0.13793103]]\n" |
| 236 | + ] |
| 237 | + } |
| 238 | + ], |
213 | 239 | "source": [
|
214 |
| - "addr = \"object_pose_0\"\n", |
215 |
| - "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
216 |
| - " trace, key, 0.5, 1000.0, addr, 1000\n", |
| 240 | + "# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n", |
| 241 | + "outlier_probability_sweep = jnp.linspace(0.01, 0.4, 30)\n", |
| 242 | + "blur_sweep = jnp.linspace(0.1, 50.0, 30)\n", |
| 243 | + "key = b3d.split_key(key)\n", |
| 244 | + "addresses= Pytree.const(( \"outlier_probability_0\", \"blur\"))\n", |
| 245 | + "sweeps = [ outlier_probability_sweep, blur_sweep]\n", |
| 246 | + "scores = b3d.utils.grid_trace(\n", |
| 247 | + " trace,\n", |
| 248 | + " addresses,\n", |
| 249 | + " sweeps,\n", |
| 250 | + ")\n", |
| 251 | + "print(scores.max())\n", |
| 252 | + "index = jnp.unravel_index(scores.argmax(), scores.shape)\n", |
| 253 | + "\n", |
| 254 | + "sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n", |
| 255 | + " jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n", |
| 256 | + ")\n", |
| 257 | + "sampled_parameters = jnp.vstack(\n", |
| 258 | + " [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n", |
| 259 | + ").T\n", |
| 260 | + "\n", |
| 261 | + "print(sampled_parameters[:5])\n", |
| 262 | + "\n", |
| 263 | + "trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n", |
| 264 | + "viz_trace(trace,T)\n", |
| 265 | + "\n", |
| 266 | + "# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n", |
| 267 | + "color_sweep_sweep = jnp.linspace(0.01, 0.4, 30)\n", |
| 268 | + "depth_sweep = jnp.linspace(0.02, 0.4, 30)\n", |
| 269 | + "key = b3d.split_key(key)\n", |
| 270 | + "addresses= Pytree.const(( \"color_variance_0\", \"depth_variance_0\"))\n", |
| 271 | + "sweeps = [ color_sweep_sweep, depth_sweep]\n", |
| 272 | + "scores = b3d.utils.grid_trace(\n", |
| 273 | + " trace,\n", |
| 274 | + " addresses,\n", |
| 275 | + " sweeps,\n", |
| 276 | + ")\n", |
| 277 | + "print(scores.max())\n", |
| 278 | + "index = jnp.unravel_index(scores.argmax(), scores.shape)\n", |
| 279 | + "\n", |
| 280 | + "sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n", |
| 281 | + " jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n", |
217 | 282 | ")\n",
|
218 |
| - "viz_trace(trace, 0)" |
| 283 | + "sampled_parameters = jnp.vstack(\n", |
| 284 | + " [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n", |
| 285 | + ").T\n", |
| 286 | + "\n", |
| 287 | + "print(sampled_parameters[:5])\n", |
| 288 | + "\n", |
| 289 | + "trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n", |
| 290 | + "viz_trace(trace,T)" |
219 | 291 | ]
|
220 | 292 | },
|
221 | 293 | {
|
222 | 294 | "cell_type": "code",
|
223 |
| - "execution_count": 37, |
| 295 | + "execution_count": 274, |
224 | 296 | "metadata": {},
|
225 | 297 | "outputs": [],
|
226 | 298 | "source": [
|
|
254 | 326 | },
|
255 | 327 | {
|
256 | 328 | "cell_type": "code",
|
257 |
| - "execution_count": 26, |
| 329 | + "execution_count": 278, |
258 | 330 | "metadata": {},
|
259 |
| - "outputs": [], |
| 331 | + "outputs": [ |
| 332 | + { |
| 333 | + "name": "stdout", |
| 334 | + "output_type": "stream", |
| 335 | + "text": [ |
| 336 | + "427.17706\n" |
| 337 | + ] |
| 338 | + } |
| 339 | + ], |
260 | 340 | "source": [
|
| 341 | + "\n", |
261 | 342 | "addr = \"object_pose_0\"\n",
|
262 | 343 | "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
|
263 |
| - " trace, key, 0.5, 1000.0, addr, 1000\n", |
| 344 | + " trace, key, 0.1, 2000.0, addr, 700\n", |
264 | 345 | ")\n",
|
265 | 346 | "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
|
266 |
| - " trace, key, 0.1, 1000.0, addr, 1000\n", |
| 347 | + " trace, key, 0.1, 1000.0, addr, 700\n", |
267 | 348 | ")\n",
|
| 349 | + "\n", |
268 | 350 | "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n",
|
269 |
| - " trace, key, 0.05, 1000.0, addr, 1000\n", |
| 351 | + " trace, key, 0.05, 2000.0, addr, 700\n", |
270 | 352 | ")\n",
|
| 353 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 354 | + " trace, key, 0.05, 1000.0, addr, 700\n", |
| 355 | + ")\n", |
| 356 | + "\n", |
| 357 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 358 | + " trace, key, 0.01, 2000.0, addr, 700\n", |
| 359 | + ")\n", |
| 360 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 361 | + " trace, key, 0.01, 1000.0, addr, 700\n", |
| 362 | + ")\n", |
| 363 | + "\n", |
| 364 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 365 | + " trace, key, 0.5, 2000.0, addr, 700\n", |
| 366 | + ")\n", |
| 367 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 368 | + " trace, key, 0.1, 2000.0, addr, 700\n", |
| 369 | + ")\n", |
| 370 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 371 | + " trace, key, 0.1, 2000.0, addr, 700\n", |
| 372 | + ")\n", |
| 373 | + "print(trace.get_score())\n", |
| 374 | + "\n", |
271 | 375 | "viz_trace(trace, T)"
|
272 | 376 | ]
|
273 | 377 | },
|
274 | 378 | {
|
275 | 379 | "cell_type": "code",
|
276 |
| - "execution_count": 27, |
| 380 | + "execution_count": 287, |
277 | 381 | "metadata": {},
|
278 |
| - "outputs": [], |
| 382 | + "outputs": [ |
| 383 | + { |
| 384 | + "name": "stderr", |
| 385 | + "output_type": "stream", |
| 386 | + "text": [ |
| 387 | + "Loss 405.0047607421875: 100%|██████████| 100/100 [00:00<00:00, 118.42it/s]\n" |
| 388 | + ] |
| 389 | + }, |
| 390 | + { |
| 391 | + "name": "stdout", |
| 392 | + "output_type": "stream", |
| 393 | + "text": [ |
| 394 | + "406.33615\n" |
| 395 | + ] |
| 396 | + } |
| 397 | + ], |
279 | 398 | "source": [
|
280 |
| - "value_and_grad_func = jax.jit(\n", |
281 |
| - " jax.value_and_grad(\n", |
282 |
| - " lambda trace, pose: b3d.update_choices_get_score(trace, Pytree.const((\"object_pose_0\",)), pose),\n", |
283 |
| - " argnums=1\n", |
284 |
| - " )\n", |
285 |
| - ")" |
| 399 | + "pbar = tqdm(range(100))\n", |
| 400 | + "for _ in pbar:\n", |
| 401 | + " loss, grad_pose = value_and_grad_func(trace, trace.get_choices()[\"object_pose_0\"])\n", |
| 402 | + " pbar.set_description(f\"Loss {loss}\")\n", |
| 403 | + " trace = b3d.update_choices(trace, Pytree.const((\"object_pose_0\",)), trace.get_choices()[\"object_pose_0\"] + grad_pose * 0.001)\n", |
| 404 | + "print(trace.get_score())\n", |
| 405 | + "viz_trace(trace, T)" |
286 | 406 | ]
|
287 | 407 | },
|
288 | 408 | {
|
289 | 409 | "cell_type": "code",
|
290 |
| - "execution_count": 40, |
| 410 | + "execution_count": null, |
| 411 | + "metadata": {}, |
| 412 | + "outputs": [], |
| 413 | + "source": [] |
| 414 | + }, |
| 415 | + { |
| 416 | + "cell_type": "code", |
| 417 | + "execution_count": 211, |
291 | 418 | "metadata": {},
|
292 | 419 | "outputs": [
|
293 | 420 | {
|
294 | 421 | "name": "stdout",
|
295 | 422 | "output_type": "stream",
|
296 | 423 | "text": [
|
297 |
| - "519.4532\n", |
298 |
| - "[[0.3462069 3.5413792 ]\n", |
299 |
| - " [0.4 3.5413792 ]\n", |
300 |
| - " [0.37310344 3.5413792 ]\n", |
301 |
| - " [0.38655174 3.5413792 ]\n", |
302 |
| - " [0.30586204 3.5413792 ]]\n" |
| 424 | + "-147.7747\n", |
| 425 | + "[[ 0.4 10.424138]\n", |
| 426 | + " [ 0.4 29.351725]\n", |
| 427 | + " [ 0.4 41.396553]\n", |
| 428 | + " [ 0.4 50. ]\n", |
| 429 | + " [ 0.4 32.793102]]\n", |
| 430 | + "29.873737\n", |
| 431 | + "[[0.03689655 0.4 ]\n", |
| 432 | + " [0.02344828 0.4 ]\n", |
| 433 | + " [0.03689655 0.4 ]\n", |
| 434 | + " [0.03689655 0.4 ]\n", |
| 435 | + " [0.03689655 0.4 ]]\n" |
303 | 436 | ]
|
304 | 437 | }
|
305 | 438 | ],
|
306 |
| - "source": [ |
307 |
| - "# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n", |
308 |
| - "outlier_probability_sweep = jnp.linspace(0.01, 0.4, 30)\n", |
309 |
| - "blur_sweep = jnp.linspace(0.1, 50.0, 30)\n", |
310 |
| - "key = b3d.split_key(key)\n", |
311 |
| - "addresses= Pytree.const(( \"outlier_probability_0\", \"blur\"))\n", |
312 |
| - "sweeps = [ outlier_probability_sweep, blur_sweep]\n", |
313 |
| - "scores = b3d.utils.grid_trace(\n", |
314 |
| - " trace,\n", |
315 |
| - " addresses,\n", |
316 |
| - " sweeps,\n", |
317 |
| - ")\n", |
318 |
| - "print(scores.max())\n", |
319 |
| - "index = jnp.unravel_index(scores.argmax(), scores.shape)\n", |
320 |
| - "\n", |
321 |
| - "sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n", |
322 |
| - " jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n", |
323 |
| - ")\n", |
324 |
| - "sampled_parameters = jnp.vstack(\n", |
325 |
| - " [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n", |
326 |
| - ").T\n", |
327 |
| - "\n", |
328 |
| - "print(sampled_parameters[:5])\n", |
329 |
| - "\n", |
330 |
| - "trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n", |
331 |
| - "viz_trace(trace,T)" |
332 |
| - ] |
| 439 | + "source": [] |
333 | 440 | },
|
334 | 441 | {
|
335 | 442 | "cell_type": "code",
|
336 |
| - "execution_count": 41, |
| 443 | + "execution_count": null, |
| 444 | + "metadata": {}, |
| 445 | + "outputs": [], |
| 446 | + "source": [] |
| 447 | + }, |
| 448 | + { |
| 449 | + "cell_type": "code", |
| 450 | + "execution_count": 235, |
337 | 451 | "metadata": {},
|
338 | 452 | "outputs": [
|
339 | 453 | {
|
340 | 454 | "name": "stdout",
|
341 | 455 | "output_type": "stream",
|
342 | 456 | "text": [
|
343 |
| - "518.2212\n", |
344 |
| - "[[0.01 0.11172414]\n", |
345 |
| - " [0.01 0.13793103]\n", |
346 |
| - " [0.01 0.12482759]\n", |
347 |
| - " [0.01 0.11172414]\n", |
348 |
| - " [0.01 0.12482759]]\n" |
| 457 | + "-31.535\n" |
349 | 458 | ]
|
350 | 459 | }
|
351 | 460 | ],
|
352 | 461 | "source": [
|
353 |
| - "# hierachical bayes inference of \"outlier_probability_0\", \"blur\"\n", |
354 |
| - "color_sweep_sweep = jnp.linspace(0.01, 0.4, 30)\n", |
355 |
| - "depth_sweep = jnp.linspace(0.02, 0.4, 30)\n", |
356 |
| - "key = b3d.split_key(key)\n", |
357 |
| - "addresses= Pytree.const(( \"color_variance_0\", \"depth_variance_0\"))\n", |
358 |
| - "sweeps = [ color_sweep_sweep, depth_sweep]\n", |
359 |
| - "scores = b3d.utils.grid_trace(\n", |
360 |
| - " trace,\n", |
361 |
| - " addresses,\n", |
362 |
| - " sweeps,\n", |
| 462 | + "\n", |
| 463 | + "addr = \"object_pose_0\"\n", |
| 464 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 465 | + " trace, key, 0.5, 1000.0, addr, 200\n", |
| 466 | + ")\n", |
| 467 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 468 | + " trace, key, 0.1, 2000.0, addr, 700\n", |
| 469 | + ")\n", |
| 470 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 471 | + " trace, key, 0.1, 1000.0, addr, 700\n", |
363 | 472 | ")\n",
|
364 |
| - "print(scores.max())\n", |
365 |
| - "index = jnp.unravel_index(scores.argmax(), scores.shape)\n", |
366 | 473 | "\n",
|
367 |
| - "sampled_indices = jax.vmap(jnp.unravel_index, in_axes=(0, None))(\n", |
368 |
| - " jax.random.categorical(key, scores.reshape(-1), shape=(1000,)), scores.shape\n", |
| 474 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 475 | + " trace, key, 0.05, 2000.0, addr, 700\n", |
| 476 | + ")\n", |
| 477 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 478 | + " trace, key, 0.05, 1000.0, addr, 700\n", |
369 | 479 | ")\n",
|
370 |
| - "sampled_parameters = jnp.vstack(\n", |
371 |
| - " [sweep[indices] for indices, sweep in zip(sampled_indices, sweeps)]\n", |
372 |
| - ").T\n", |
373 | 480 | "\n",
|
374 |
| - "print(sampled_parameters[:5])\n", |
| 481 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 482 | + " trace, key, 0.01, 2000.0, addr, 700\n", |
| 483 | + ")\n", |
| 484 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 485 | + " trace, key, 0.01, 1000.0, addr, 700\n", |
| 486 | + ")\n", |
375 | 487 | "\n",
|
376 |
| - "trace = b3d.update_choices(trace, addresses, *sampled_parameters[0])\n", |
377 |
| - "viz_trace(trace,0)" |
| 488 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 489 | + " trace, key, 0.5, 2000.0, addr, 700\n", |
| 490 | + ")\n", |
| 491 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 492 | + " trace, key, 0.1, 2000.0, addr, 700\n", |
| 493 | + ")\n", |
| 494 | + "trace, key = b3d.bayes3d.enumerative_proposals.gvmf_and_select_best_move(\n", |
| 495 | + " trace, key, 0.1, 2000.0, addr, 700\n", |
| 496 | + ")\n", |
| 497 | + "print(trace.get_score())\n", |
| 498 | + "\n", |
| 499 | + "viz_trace(trace, T)" |
378 | 500 | ]
|
379 | 501 | },
|
| 502 | + { |
| 503 | + "cell_type": "code", |
| 504 | + "execution_count": null, |
| 505 | + "metadata": {}, |
| 506 | + "outputs": [], |
| 507 | + "source": [] |
| 508 | + }, |
| 509 | + { |
| 510 | + "cell_type": "code", |
| 511 | + "execution_count": null, |
| 512 | + "metadata": {}, |
| 513 | + "outputs": [], |
| 514 | + "source": [] |
| 515 | + }, |
380 | 516 | {
|
381 | 517 | "cell_type": "code",
|
382 | 518 | "execution_count": 42,
|
|
0 commit comments