[Draft] Enable pmap progress bar with cpu backend / remove deprecated host_callback #1841
There might be some race issues here - I'm not exactly sure what the Edit: |
numpyro/util.py
Outdated
@@ -201,24 +202,40 @@ def progress_bar_factory(num_samples, num_chains):

remainder = num_samples % print_rate

idx_map = {}
could you add some comments here for how this works?
added by _calc_chain_idx
tqdm_bars = {}
finished_chains = []
lock = Lock()
could you add a comment for why this is needed?
Done. I also added a lock around closing the chains. It doesn't appear to have been an issue so far, but without it a chain could be closed more than once.
Could you clarify the purpose of Lock in more detail? I'm not familiar with its usage. What happens if we don't use it?
Updated. The lock is only held around idx_counter, since once a chain has its chain id it no longer accesses any resources that are shared across threads.
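For readers unfamiliar with the pattern, here is a minimal sketch of what a lock around a shared counter buys you. It is not the code from this PR: `claim_bar_id`, `worker`, and `assigned` are hypothetical names, and plain threads stand in for the per-chain callbacks. Without the lock, two threads could read the same counter value before either increments it and end up sharing one progress bar.

```python
from threading import Lock, Thread

lock = Lock()
idx_counter = 0   # shared across threads, so access must be serialized
assigned = []     # ids handed out, recorded for inspection


def claim_bar_id():
    """Hand out the next free progress-bar index exactly once (hypothetical helper)."""
    global idx_counter
    with lock:
        # Read-and-increment happens atomically with respect to other threads.
        bar_id = idx_counter
        idx_counter += 1
    return bar_id


def worker():
    # After this point the thread only touches its own bar, so no lock is needed.
    assigned.append(claim_bar_id())


threads = [Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Each of the eight threads ends up with a distinct id, which is the property the lock is there to protect.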
numpyro/util.py
Outdated
_update_tqdm, 0, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_tqdm, None, -1, 0),
I'm not sure why -1, 0 is used here.
The -1 was a mistake; changed to 0, 0. The second 0 is consistent with the current implementation. I believe it is 0 instead of 1 so that the subsequent calls using print_rate and remainder add up to the total. But it does seem to be off by 1.
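To make the accounting in that reply concrete, here is a rough sketch of how the increments could add up. It assumes the bar is created with an increment of 0, advanced by print_rate at every multiple of print_rate, and topped up by remainder at the end. `bar_increments` and the sample values are made up for illustration, and the sketch does not model where the checks sit relative to the iteration index, which is where the off-by-one mentioned above would come from.

```python
def bar_increments(num_samples, print_rate):
    """Return the increments a progress bar would receive (hypothetical helper)."""
    remainder = num_samples % print_rate
    increments = [0]  # initial call: creates the bar without advancing it
    for i in range(1, num_samples + 1):
        if i % print_rate == 0:
            increments.append(print_rate)  # regular update
    if remainder:
        increments.append(remainder)  # final partial update
    return increments


incs = bar_increments(num_samples=103, print_rate=5)
total = sum(incs)
```

Under these assumptions the increments sum to num_samples, which is why starting from 0 rather than 1 keeps the total consistent.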
Maybe hold off on reviewing this - there's probably a better way that keeps a fixed mapping between progress bars and chains. See the discussion in blackjax-devs/blackjax#712.
This is updated - a resource counter assigns each chain a progress bar. This id is then saved in the fori loop carry. The issue is the function returned by |
numpyro/util.py
Outdated
iter_num = int(iter_num)
increment = int(increment)
chain = int(chain)
if iter_num == 0:
because iter_num is not used, it is clearer to use the assertion chain == -1
agreed - done
numpyro/util.py
Outdated
increment = int(increment)
chain = int(chain)
if iter_num == 0:
chain = _calc_chain_idx(iter_num)
nit: it is unnecessary to have a separate call to _calc_chain_idx - moving its content to here seems clearer imo
done
numpyro/util.py
Outdated
)

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
(last_val, collection, _, _), _ = maybe_jit(loop_fn, donate_argnums=0)(
No need to change this if we don't return chain in loop_fn.
done
numpyro/util.py
Outdated
lambda _: host_callback.id_tap(
_update_tqdm, print_rate, result=iter_num, tap_with_device=True
lambda _: io_callback(
_update_tqdm, jnp.array(0), iter_num, print_rate, chain
seems like no need for iter_num
ah good point
Do you know when this happens?
I'm not sure - I'd suspect that the vast majority of times |
LGTM. CI seems to suggest that there is no degradation in performance. Thanks for the great solution, @andrewdipper!
Switched from host_callback.id_tap to io_callback. This allows the progress bar to work when the JAX backend is cpu. In addition, host_callback.id_tap is deprecated. However, io_callback doesn't allow identification of the device in the callback.
As a result, the progress bars are not tied to the physical devices. The net effect is that the progress bars behave the same, but they are always shown in sorted order (from most to least complete), with no way to identify the underlying device.
Left as draft due to the behavioral changes / needs more visual testing.
The original error with progressbar + pmap on cpu:
NotImplementedError: host_callback functionality isn't supported with PJRT C API. See https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html for alternatives. Please file a feature request at https://github.com/google/jax/issues if none of the alternatives are sufficient.
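As a rough illustration of the switch described above (not the code in this PR), the sketch below calls io_callback from inside lax.cond in a fori_loop body and keeps the id returned by the callback in the loop carry. `_update`, `progress`, the constants, and the starting bar id of 7 are placeholders; a plain list stands in for tqdm, and the id is passed in up front here rather than assigned by a counter.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.experimental import io_callback

NUM_STEPS = 10
PRINT_RATE = 5
progress = []  # host-side stand-in for a tqdm bar


def _update(increment, bar_id):
    # Runs on the host with concrete values; records the update and echoes the id.
    progress.append((int(bar_id), int(increment)))
    return np.asarray(bar_id, dtype=np.int32)


def body(i, carry):
    total, bar_id = carry
    # Only call back to the host on every PRINT_RATE-th iteration.
    bar_id = lax.cond(
        (i + 1) % PRINT_RATE == 0,
        lambda _: io_callback(
            _update,
            jax.ShapeDtypeStruct((), jnp.int32),  # shape/dtype of the callback result
            jnp.int32(PRINT_RATE),
            bar_id,
        ),
        lambda _: bar_id,
        None,
    )
    return total + i, bar_id


@jax.jit
def run():
    return lax.fori_loop(0, NUM_STEPS, body, (jnp.int32(0), jnp.int32(7)))


total, bar_id = jax.block_until_ready(run())
```

Threading the id through the carry gives each callback a data dependency on the previous one, so the bar a chain reports to stays fixed without the callback needing to know which device it runs on.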