Skip to content

Commit f7d236e

Browse files
douglasdavisveprbl
andauthored
Store histref in a partial function for propagation through graph. (#157)
Co-authored-by: Dmitry Kalinkin <[email protected]>
1 parent bf85767 commit f7d236e

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

afile

Whitespace-only changes.

src/dask_histogram/core.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,6 @@ def _blocked_multi(
410410
repacker: Callable,
411411
*flattened_inputs: tuple[Any],
412412
) -> bh.Histogram:
413-
414413
data_list, weights, samples, histref = repacker(flattened_inputs)
415414

416415
weights = weights or (None for _ in range(len(data_list)))
@@ -439,7 +438,6 @@ def _blocked_multi_df(
439438
repacker: Callable,
440439
*flattened_inputs: tuple[Any],
441440
) -> bh.Histogram:
442-
443441
data_list, weights, samples, histref = repacker(flattened_inputs)
444442

445443
weights = weights or (None for _ in range(len(data_list)))
@@ -1027,8 +1025,9 @@ def _partitioned_histogram(
10271025
if len(data) == 1 and data_is_dak:
10281026
from dask_awkward.lib.core import partitionwise_layer as dak_pwl
10291027

1030-
f = partial(_blocked_dak, weights=weights, sample=sample, histref=histref)
1031-
g = dak_pwl(f, name, data[0])
1028+
f = partial(_blocked_dak, histref=histref)
1029+
1030+
g = dak_pwl(f, name, data[0], weights, sample)
10321031

10331032
# Single object, not a dataframe
10341033
elif len(data) == 1 and not data_is_df:

tests/test_boost.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,6 @@ def test_155_boost_factory():
573573
import boost_histogram as bh
574574

575575
dak = pytest.importorskip("dask_awkward")
576-
import numpy as np
577576

578577
import dask_histogram as dh
579578

@@ -584,3 +583,56 @@ def test_155_boost_factory():
584583
axes=(axis,),
585584
).compute()
586585
assert np.all(hist.values() == [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0])
586+
587+
588+
def test_155_2():
589+
import boost_histogram as bh
590+
591+
import dask_histogram as dh
592+
593+
dak = pytest.importorskip("dask_awkward")
594+
595+
arr = dak.from_lists([list(range(10))] * 3)
596+
axis = bh.axis.Regular(10, 0.0, 10.0)
597+
hist = dh.factory(
598+
arr,
599+
axes=(axis,),
600+
weights=arr,
601+
).compute()
602+
assert np.all(
603+
hist.values() == [0.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0]
604+
)
605+
606+
607+
def test_155_3_2d():
608+
import boost_histogram as bh
609+
610+
dak = pytest.importorskip("dask_awkward")
611+
612+
import dask_histogram as dh
613+
614+
arr1 = dak.from_lists([list(range(10))] * 3)
615+
arr2 = dak.from_lists([list(reversed(range(10)))] * 3)
616+
axis1 = bh.axis.Regular(10, 0.0, 10.0)
617+
axis2 = bh.axis.Regular(10, 0.0, 10.0)
618+
hist = dh.factory(
619+
arr1,
620+
arr2,
621+
axes=(axis1, axis2),
622+
weights=arr1,
623+
).compute()
624+
should_be = (
625+
[
626+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
627+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0],
628+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0],
629+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0, 0.0],
630+
[0.0, 0.0, 0.0, 0.0, 0.0, 12.0, 0.0, 0.0, 0.0, 0.0],
631+
[0.0, 0.0, 0.0, 0.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0],
632+
[0.0, 0.0, 0.0, 18.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
633+
[0.0, 0.0, 21.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
634+
[0.0, 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
635+
[27.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
636+
],
637+
)
638+
assert np.all(hist.values() == should_be)

0 commit comments

Comments
 (0)