Skip to content

Commit 5d26db9

Browse files
authored
Merge pull request #66 from jakirkham/dist_more_metrics
Support more metrics in cdist and pdist
2 parents 33cb8b5 + 80534f2 commit 5d26db9

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

dask_distance/__init__.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,15 @@ def cdist(XA, XB, metric="euclidean", **kwargs):
5353
"hamming": hamming,
5454
"jaccard": jaccard,
5555
"kulsinski": kulsinski,
56+
"mahalanobis": mahalanobis,
5657
"minkowski": minkowski,
5758
"rogerstanimoto": rogerstanimoto,
5859
"russellrao": russellrao,
5960
"sokalmichener": sokalmichener,
6061
"sokalsneath": sokalsneath,
62+
"seuclidean": seuclidean,
6163
"sqeuclidean": sqeuclidean,
64+
"wminkowski": wminkowski,
6265
"yule": yule,
6366
}
6467

@@ -93,8 +96,22 @@ def cdist(XA, XB, metric="euclidean", **kwargs):
9396

9497
metric = func_mappings[metric]
9598

96-
if metric == minkowski:
97-
kwargs["p"] = kwargs.get("p", 2)
99+
if metric == mahalanobis:
100+
if "VI" not in kwargs:
101+
kwargs["VI"] = (
102+
dask.array.linalg.inv(
103+
dask.array.cov(dask.array.vstack([XA, XB]).T)
104+
).T
105+
)
106+
elif metric == minkowski:
107+
kwargs.setdefault("p", 2)
108+
elif metric == seuclidean:
109+
if "V" not in kwargs:
110+
kwargs["V"] = (
111+
dask.array.var(dask.array.vstack([XA, XB]), axis=0, ddof=1)
112+
)
113+
elif metric == wminkowski:
114+
kwargs.setdefault("p", 2)
98115

99116
result = metric(XA, XB, **kwargs)
100117

@@ -124,6 +141,13 @@ def pdist(X, metric="euclidean", **kwargs):
124141
other tradeoffs.
125142
"""
126143

144+
if metric == "mahalanobis":
145+
if "VI" not in kwargs:
146+
kwargs["VI"] = dask.array.linalg.inv(dask.array.cov(X.T)).T
147+
elif metric == "seuclidean":
148+
if "V" not in kwargs:
149+
kwargs["V"] = dask.array.var(X, axis=0, ddof=1)
150+
127151
result = cdist(X, X, metric, **kwargs)
128152

129153
result = dask.array.triu(result, 1)

tests/test_dask_distance.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,15 @@ def test_1d_dist(funcname, kw, seed, size, chunks):
106106
("correlation", {}),
107107
("cosine", {}),
108108
("euclidean", {}),
109+
("mahalanobis", {"VI": None}),
110+
("mahalanobis", {}),
109111
("minkowski", {}),
110112
("minkowski", {"p": 3}),
113+
("seuclidean", {"V": None}),
114+
("seuclidean", {}),
111115
("sqeuclidean", {}),
116+
("wminkowski", {}),
117+
("wminkowski", {"p": 1.6}),
112118
(lambda u, v: (abs(u - v) ** 3).sum() ** (1.0 / 3.0), {}),
113119
]
114120
)
@@ -133,6 +139,19 @@ def test_2d_cdist(metric, kw, seed, u_shape, u_chunks, v_shape, v_chunks):
133139
d_u = da.from_array(a_u, chunks=u_chunks)
134140
d_v = da.from_array(a_v, chunks=v_chunks)
135141

142+
if metric == "mahalanobis":
143+
if "VI" not in kw:
144+
kw["VI"] = 2 * np.random.random(2 * u_shape[-1:]) - 1
145+
elif kw["VI"] is None:
146+
kw.pop("VI")
147+
elif metric == "seuclidean":
148+
if "V" not in kw:
149+
kw["V"] = 2 * np.random.random(u_shape[-1:]) - 1
150+
elif kw["V"] is None:
151+
kw.pop("V")
152+
elif metric == "wminkowski":
153+
kw["w"] = np.random.random(u_shape[-1:])
154+
136155
a_r = spdist.cdist(a_u, a_v, metric, **kw)
137156
d_r = dask_distance.cdist(d_u, d_v, metric, **kw)
138157

@@ -148,9 +167,15 @@ def test_2d_cdist(metric, kw, seed, u_shape, u_chunks, v_shape, v_chunks):
148167
("correlation", {}),
149168
("cosine", {}),
150169
("euclidean", {}),
170+
("mahalanobis", {"VI": None}),
171+
("mahalanobis", {}),
151172
("minkowski", {}),
152173
("minkowski", {"p": 3}),
174+
("seuclidean", {"V": None}),
175+
("seuclidean", {}),
153176
("sqeuclidean", {}),
177+
("wminkowski", {}),
178+
("wminkowski", {"p": 1.6}),
154179
(lambda u, v: (abs(u - v) ** 3).sum() ** (1.0 / 3.0), {}),
155180
]
156181
)
@@ -172,6 +197,19 @@ def test_2d_pdist(metric, kw, seed, u_shape, u_chunks):
172197
a_u = 2 * np.random.random(u_shape) - 1
173198
d_u = da.from_array(a_u, chunks=u_chunks)
174199

200+
if metric == "mahalanobis":
201+
if "VI" not in kw:
202+
kw["VI"] = 2 * np.random.random(2 * u_shape[-1:]) - 1
203+
elif kw["VI"] is None:
204+
kw.pop("VI")
205+
elif metric == "seuclidean":
206+
if "V" not in kw:
207+
kw["V"] = 2 * np.random.random(u_shape[-1:]) - 1
208+
elif kw["V"] is None:
209+
kw.pop("V")
210+
elif metric == "wminkowski":
211+
kw["w"] = np.random.random(u_shape[-1:])
212+
175213
a_r = spdist.pdist(a_u, metric, **kw)
176214
d_r = dask_distance.pdist(d_u, metric, **kw)
177215

0 commit comments

Comments
 (0)