@computing_profile
def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optional[int] = None, seed=None):
    """Sample records from this table, by fraction or by exact count.

    Exactly one of ``fraction`` and ``num`` must be given.

    :param fraction: forwarded to ``self._rp.sample`` together with ``seed``.
    :param num: number of records to draw; handled by ``self._exactly_sample``.
    :param seed: forwarded unchanged to whichever sampler is used.
    :return: a ``Table`` wrapping the sampled data.
    :raises ValueError: if both or neither of ``fraction`` / ``num`` are given
        (``_exactly_sample`` may raise it as well).
    """
    if fraction is not None and num is not None:
        raise ValueError("specify only one of `fraction` or `num`, not both.")

    if fraction is not None:
        return Table(self._rp.sample(fraction=fraction, seed=seed))

    if num is not None:
        # `_exactly_sample` hands back the result of
        # `self._rp.map_partitions_with_index(...)`; wrap it so this branch
        # returns a `Table` just like the `fraction` branch does.
        return Table(self._exactly_sample(num, seed))

    raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}")


@computing_profile
def subtractByKey(self, other: 'Table', **kwargs):
    """Return a ``Table`` wrapping ``self._rp.subtract_by_key(other._rp)``.

    ``kwargs`` is accepted but not used here.
    """
    return Table(self._rp.subtract_by_key(other._rp))
def _exactly_sample(self, num: int, seed: int):
    """Draw exactly ``num`` records, without replacement, across partitions.

    First counts the records of every partition, then decides how many
    records each partition contributes (see ``_allocate_split_sizes``), and
    finally reservoir-samples that many records inside each partition.

    :param num: number of records to draw.
    :param seed: seed used for both the per-partition quota draw and the
        in-partition reservoir sampling; ``None`` means unseeded.
    :return: whatever ``self._rp.map_partitions_with_index`` returns (it is
        not wrapped here).
    :raises ValueError: if there are no partitions to count, if ``num`` is
        negative, or if ``num`` exceeds the number of records.
    """
    split_size = list(self._rp.map_partitions_with_index(
        lambda s, it: [(s, sum(1 for _ in it))]
    ).get_all())
    LOGGER.info(f"{split_size}")

    if not split_size:
        raise ValueError("no data available to sample")

    if num < 0:
        raise ValueError(f"sample size must be non-negative, got {num}")

    total = sum(v for _, v in split_size)
    if num > total:
        raise ValueError(f"not enough data to sample, own {total} but required {num}")

    sampled_size = self._allocate_split_sizes(split_size, num, seed)
    LOGGER.info(f"{sampled_size}")

    return self._rp.map_partitions_with_index(self._reservoir_sample_func(sampled_size, seed))


def _allocate_split_sizes(self, split_size, num: int, seed=None) -> dict:
    """Decide how many of ``num`` sampled records come from each partition.

    Partitions are visited in order; each quota is a hypergeometric draw
    against the records and the sample budget that are still unassigned.

    :param split_size: iterable of ``(split, record_count)`` pairs.
    :param num: total number of records to sample (``0 <= num <= total``).
    :param seed: seed for the quota draws; the same seed yields the same
        allocation. ``None`` means unseeded.
    :return: dict mapping each split to its quota, as plain ``int`` values.
    """
    split_size = list(split_size)
    total = sum(size for _, size in split_size)
    # One seeded generator hands out an independent 32-bit `random_state` for
    # every draw, so the allocation is reproducible for a given seed.
    seeder = random.Random(seed)

    allocation = {}
    for split, size in split_size:
        if size <= 0 or num == 0:
            # Nothing to take from an empty partition, or budget exhausted.
            allocation[split] = 0
        else:
            allocation[split] = int(
                hypergeom.rvs(M=total, n=size, N=num, random_state=seeder.randrange(2 ** 32))
            )
        total -= size
        num -= allocation[split]

    return allocation


def _reservoir_sample_func(self, split_sample_size: dict, seed=None):
    """Build the per-partition function that reservoir-samples records.

    :param split_sample_size: dict mapping a split to the number of records
        to keep from it.
    :param seed: base seed; it is XOR-ed with the split so each partition
        uses a different stream. When ``None``, a random seed is picked
        each time the returned function runs.
    :return: ``func(split, iterator)`` returning an iterator over the kept
        records (all of them if the partition has no more than its quota).
    """
    def func(split, iterator):
        size = split_sample_size[split]
        sample = []
        random_seed = seed

        if random_seed is None:
            random_seed = random.randint(0, sys.maxsize)
        random_state = random.Random(random_seed ^ split)

        for counter, obj in enumerate(iterator, start=1):
            if len(sample) < size:
                # Fill the reservoir with the first `size` records.
                sample.append(obj)
            else:
                # The counter-th record replaces a uniformly chosen slot
                # with probability size / counter.
                slot = random_state.randint(1, counter)
                if slot <= size:
                    sample[slot - 1] = obj

        return iter(sample)

    return func