|
8 | 8 |
|
9 | 9 | import numpy as np
|
10 | 10 | import pandas as pd
|
11 |
| -import pyemd |
12 | 11 | from scipy.spatial.distance import jensenshannon
|
13 |
| -from scipy.stats import entropy, kruskal, ks_2samp |
| 12 | +from scipy.stats import entropy, kruskal, ks_2samp, wasserstein_distance |
14 | 13 |
|
15 | 14 | from .. import utils
|
16 | 15 | from ..metrics import significance as pv
|
@@ -304,19 +303,30 @@ class EarthMoversDistance(CategoricalDistanceMetric):
|
304 | 303 | """
|
305 | 304 |
|
306 | 305 | def distance_pdf(self, p: pd.Series, q: pd.Series, bin_edges: Optional[np.ndarray]) -> float:
|
307 |
| - distance_matrix = 1 - np.eye(len(p)) |
308 |
| - |
309 |
| - if bin_edges is not None: |
310 |
| - # Use pair-wise euclidean distances between bin centers for scale data |
311 |
| - bin_centers = np.mean([bin_edges[:-1], bin_edges[1:]], axis=0) |
312 |
| - xx, yy = np.meshgrid(bin_centers, bin_centers) |
313 |
| - distance_matrix = np.abs(xx - yy) |
314 |
| - |
315 |
| - p = np.array(p).astype(np.float64) |
316 |
| - q = np.array(q).astype(np.float64) |
317 |
| - distance_matrix = distance_matrix.astype(np.float64) |
| 306 | + p_sum = p.sum() |
| 307 | + q_sum = q.sum() |
| 308 | + |
| 309 | + if p_sum == 0 and q_sum == 0: |
| 310 | + return 0.0 |
| 311 | + elif p_sum == 0 or q_sum == 0: |
| 312 | + return 1.0 |
| 313 | + |
| 314 | + # normalise counts for consistency with scipy.stats.wasserstein |
| 315 | + with np.errstate(divide="ignore", invalid="ignore"): |
| 316 | + p_normalised = np.nan_to_num(p / p_sum).astype(np.float64) |
| 317 | + q_normalised = np.nan_to_num(q / q_sum).astype(np.float64) |
| 318 | + |
| 319 | + if bin_edges is None: |
| 320 | + # if bins not given, histograms are assumed to be counts of nominal categories, |
| 321 | + # and therefore distances betwen bins are meaningless. Set to all distances to |
| 322 | + # unity to model this. |
| 323 | + distance = 0.5 * np.sum(np.abs(p_normalised - q_normalised)) |
| 324 | + else: |
| 325 | + # otherwise, use pair-wise euclidean distances between bin centers for scale data |
| 326 | + bin_centers = bin_edges[:-1] + np.diff(bin_edges) / 2.0 |
| 327 | + distance = wasserstein_distance(bin_centers, bin_centers, u_weights=p_normalised, v_weights=q_normalised) |
318 | 328 |
|
319 |
| - return pyemd.emd(p, q, distance_matrix) |
| 329 | + return distance |
320 | 330 |
|
321 | 331 | @property
|
322 | 332 | def id(self) -> str:
|
|
0 commit comments