topk and argtopk #10086
dcherian left a comment
Thanks for working on this. I took a quick pass. I have one major suggestion on handling `skipna`, and some minor suggestions for reducing scope so we can get this in sooner.
| def argtopk(values, k, axis=None, skipna=None): | ||
| if is_chunked_array(values): |
Suggested change: replace `if is_chunked_array(values):` with `if is_duck_dask_array(values):`.
| else: | ||
| return self._replace_maybe_drop_dims(result) | ||
| def argtopk( |
I think `generate_aggregations.py` would make sense here, so that we add it everywhere with the same docstring. That would get us groupby support, for example, and I can eventually plug in flox when https://github.com/xarray-contrib/flox/pull/374/files is ready.
| def topk(values, k: int, axis: int): | ||
| """Extract the k largest elements from a on the given axis. | ||
| If k is negative, extract the -k smallest elements instead. | ||
| The returned elements are sorted. |
I would not sort. The user can do that if they really need to.
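For illustration, an unsorted top-k can be taken straight from `np.partition`, leaving any ordering to the caller. This is a minimal NumPy-only sketch; `topk_unsorted` is a hypothetical name, not the function in this PR, and it assumes no NaNs in the input:

```python
import numpy as np


def topk_unsorted(values, k, axis=-1):
    """Hypothetical helper: k largest (k > 0) or -k smallest (k < 0), unsorted."""
    values = np.asarray(values)
    n = values.shape[axis]
    if k == 0 or abs(k) > n:
        raise ValueError("k must be non-zero and at most the axis length")
    if k > 0:
        # After partitioning at n - k, the last k positions hold the k largest values.
        part = np.partition(values, n - k, axis=axis)
        return np.take(part, np.arange(n - k, n), axis=axis)
    # After partitioning at -k - 1, the first -k positions hold the -k smallest values.
    part = np.partition(values, -k - 1, axis=axis)
    return np.take(part, np.arange(-k), axis=axis)
```

A caller who does need ordered output can apply `np.sort` along the same axis afterwards.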
| """ | ||
| # topk accepts only an integer axis like argmin or argmax, | ||
| # not tuples, so we need to stack multiple dimensions. | ||
| if dim is ... or dim is None: |
- We have an `infix_dims` utility function for this.
- Can we punt the multiple-dims case to later, once someone asks for it? The stacking approach is bad with dask and will require some `reshape_blockwise` trickery, which isn't hard, but we may as well do it in a follow-up (see `polyfit` for an example).
| # Borrowed from nanops | ||
| xp = get_array_namespace(values) | ||
| if skipna or ( |
The way to do this is to use the fact that NaNs sort to the end. Then, given `k` and the count of non-NaN values, you know what to provide to `partition`.
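A rough sketch of that idea for the 1-D case, assuming plain NumPy and a float input. `nan_topk_1d` is a hypothetical name, and capping the result at the number of valid values is my assumption, not something decided in this thread:

```python
import numpy as np


def nan_topk_1d(values, k):
    """Hypothetical 1-D helper: k largest (k > 0) or -k smallest (k < 0) non-NaN values, unsorted."""
    values = np.asarray(values, dtype=float)
    count = int(np.count_nonzero(~np.isnan(values)))  # number of valid entries
    m = min(abs(k), count)  # cannot return more valid values than exist
    if m == 0:
        return values[:0]
    if k > 0:
        # NaNs sort to the end, so valid values occupy sorted positions [0, count).
        # Pinning positions count - m and count - 1 isolates the m largest valid values.
        kth = sorted({count - m, count - 1})
        part = np.partition(values, kth)
        return part[count - m : count]
    # For the smallest values the NaNs are already out of the way at the far end.
    part = np.partition(values, m - 1)
    return part[:m]
```

Along an axis of an N-D array the count differs per lane, so a real implementation would need per-lane handling; this sketch deliberately avoids that.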
I've made a start with implementations of `topk` and `argtopk`.

This'll work uncontroversially for DataArrays with and without NaNs, although I'm getting mostly stuck on what `skipna=True` should entail.

There's a number of choices to make, however, which are probably best illustrated with this PR.

- `nputils.py`. The question that arises here for me: I guess this would work with `cupy` too, but I don't quite see what the best way to integrate it is. Because dask provides a `topk` of its own, it appears a bit exceptional.
- `partition` and `argpartition`. This version differs from `numpy.partition` in its handling of NaNs, however. Using numpy's partition has the benefit that it's consistent with dask.
- `topk` feels mostly similar to `quantile`, since it shortens the dimension but doesn't reduce it entirely. `argmin` also supports an `axis` argument next to `dim` (though the two are mutually exclusive); is the `axis` argument desirable?
- `variable.py` borrows somewhat from `quantile`, since the result has an axis of size `k` instead of `len(q)`. Unlike `quantile`, dask's `topk` and `argtopk` do not support tuple arguments for the axis (although `topk` accepts one and produces an unexpected result), so part of the stacking and unraveling functionality of `_unravel_argminmax` is required. I've currently duplicated the relevant lines to keep the changes clearly visible.
- `apply_ufuncs` to judge whether `dask="allowed"` works gracefully with the dask `topk` and `argtopk` functions; my guess is that it should.
- `quantile` returns a result with a new dimension and coordinate called `quantile`. I've mimicked this, and `topk` and `argtopk` return a result with a new `topk` or `argtopk` dimension respectively. I was thinking no labels are required for `topk`, but since both positive k values (for largest) and negative k values (for smallest) are possible, it's probably smart to return labels `range(0, k)` and `range(-k, 0)`?
- `idxtopk` would make sense too?
- `skipna=False` is giving me some headaches. A (naive) implementation as in this PR is asymmetric. Numpy partition (and thus the dask version too) sorts NaNs towards the end of the array, such that k > 0 will return NaNs, but k < 0 will not. For the testing, I figured `da.topk(k=-1, skipna=False)` should equal `da.min(skipna=False)` and `da.topk(k=1, skipna=False)` should equal `da.max()`, but this isn't the case: k=1 will return a NaN value, since numpy partition moves the NaN to the end; k=-1 will not. I currently gravitate towards accepting this asymmetry, since e.g. `np.sort` will also move NaNs to the back, and it feels forced to fetch NaNs for k=-1 to match `.min(skipna=False)`. On the other hand, Python's `sorted` behaves differently, according to IEEE 754 NaNs are not orderable ... and I reckon you'd mostly use `skipna=False` when you want to ensure that no NaNs are present?
- `duck_array_ops`, maybe it belongs in `nanops`, as it resembles `_nan_argminmax_object` and `_nan_minmax_object`, but is again slightly different. But I didn't like the circular imports that it seems to require; in `duck_array_ops` it decides whether to use dask or numpy (via `nputils`), but the masking of NaNs is required for both.
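To make the asymmetry concrete, here is a small NumPy-only sketch (not xarray or dask code; the array is made up for illustration). It takes the single largest and the single smallest element via `np.partition` and compares them with NumPy's NaN-propagating `np.max` and `np.min`:

```python
import numpy as np

a = np.array([2.0, np.nan, 1.0, 3.0])
n = a.size

# k = 1 (largest): NaN sorts to the end, so the last position holds the NaN.
largest = np.partition(a, n - 1)[n - 1 :]

# k = -1 (smallest): the first position holds the smallest valid value.
smallest = np.partition(a, 0)[:1]

# np.max and np.min both propagate NaN, so only the k = 1 side agrees with them.
agrees_with_max = np.isnan(largest[0]) and np.isnan(np.max(a))
agrees_with_min = np.isnan(smallest[0]) == np.isnan(np.min(a))
print(largest, smallest, agrees_with_max, agrees_with_min)
```

Matching `min(skipna=False)` for k = -1 would therefore need an explicit NaN check on top of the partition.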