Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Enable pmap progress bar with cpu backend / remove deprecated host_callback #1841

Merged
merged 8 commits into from
Aug 9, 2024

Conversation

andrewdipper
Copy link
Contributor

Switched from host_callback.id_tap to io_callback. This allows the progressbar to work when the jax backend is cpu. In addition host_callback.id_tap is deprecated. However io_callback doesn't allow identification of the device in the callback.

As a result the progressbars are not tied to the physical devices. The resulting effect is the progressbars behave the same but they are always shown in sorted order (by most to least complete) with no possible identification of the underlying device.

Left as draft due to the behavioral changes / needs more visual testing.

The original error with progressbar + pmap on cpu:
NotImplementedError: host_callback functionality isn't supported with PJRT C API. See https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for alternatives. Please file a feature request at https://github.com/google/jax/issues if none of the alternatives are sufficient.

@andrewdipper
Copy link
Contributor Author

andrewdipper commented Aug 3, 2024

There might be some race issues here - I'm not exactly sure what the io_callback guarantees are

Edit:
Given the global random generator example in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html I don't think races should be an issue.

numpyro/util.py Outdated
@@ -201,24 +202,40 @@ def progress_bar_factory(num_samples, num_chains):

remainder = num_samples % print_rate

idx_map = {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add some comments here for how this works?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added by _calc_chain_idx

tqdm_bars = {}
finished_chains = []
lock = Lock()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a comment for why this is needed?

Copy link
Contributor Author

@andrewdipper andrewdipper Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done. Also added a lock around closing the chains. It appears it hasn't been an issue but could cause multiple closes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify the purpose of Lock in more details? I'm not familiar with its usage. What happens if we dont use it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. The locking is only around idx_counter since after a chain has a chain id there isn't access to resources that are shared across threads.

numpyro/util.py Outdated
_update_tqdm, 0, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_tqdm, None, -1, 0),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not sure why -1, 0 is used here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The -1 was a mistake, changed to 0,0. The second 0 is consistent with the current implementation. I believe this is 0 instead on 1 so that subsequent calls using print_rate and remainder add to the total. But it does seem to be off by 1

@andrewdipper
Copy link
Contributor Author

Maybe hold on reviewing this - there's probably a better way that keeps a fixed mapping between progress bars and chains. In discussion here blackjax-devs/blackjax#712

@andrewdipper
Copy link
Contributor Author

This is updated - a resource counter assigns each chain a progress bar. This id is then saved in the fori loop carry.

The issue is the function returned by progress_bar_fori_loop expects the loop carry to have a spot to save the chain id. So if there's any use of progress_bar_factory externally this would break it. I'm not sure what'd be the best way to transition it smoothly

@andrewdipper andrewdipper requested a review from fehiepsi August 8, 2024 23:46
numpyro/util.py Outdated
iter_num = int(iter_num)
increment = int(increment)
chain = int(chain)
if iter_num == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because iter_num is not used, it is clearer to use the assertion chain == -1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed - done

numpyro/util.py Outdated
increment = int(increment)
chain = int(chain)
if iter_num == 0:
chain = _calc_chain_idx(iter_num)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it is unnecessary to have a separate call to _calc_chain_idx - moving its content to here seems clearer imo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

numpyro/util.py Outdated
)

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
(last_val, collection, _, _), _ = maybe_jit(loop_fn, donate_argnums=0)(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to change this if we dont return chain in loop_fn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

numpyro/util.py Outdated
lambda _: host_callback.id_tap(
_update_tqdm, print_rate, result=iter_num, tap_with_device=True
lambda _: io_callback(
_update_tqdm, jnp.array(0), iter_num, print_rate, chain
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like no need for iter_num

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah good point

@fehiepsi
Copy link
Member

fehiepsi commented Aug 9, 2024

So if there's any use of progress_bar_factory externally this would break it

Do you know when this happens?

@andrewdipper
Copy link
Contributor Author

Do you know when this happens?

I'm not sure - I'd suspect that the vast majority of times progress_bar_factory is simply used through calls to fori_collect. If it is used externally I'd guess it'd only be in specialized cases - but I don't have an example in mind.

@fehiepsi
Copy link
Member

fehiepsi commented Aug 9, 2024

LGTM. CI seems to suggest that there is no degration in performance. Thanks for the great solution, @andrewdipper!

@fehiepsi fehiepsi merged commit fb018d7 into pyro-ppl:master Aug 9, 2024
4 checks passed
@andrewdipper andrewdipper deleted the progbar branch August 12, 2024 23:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants