|
1 | 1 | import numpy as np |
2 | 2 |
|
3 | 3 |
|
4 | | -def fill_with_last(arr: np.ndarray) -> np.ndarray: |
| 4 | +def fill_with_last(arr: np.ndarray, marker=0) -> np.ndarray: |
5 | 5 | ''' |
6 | | - Given an array ARR, copy it and replace all zeros with the last |
7 | | - nonzero value. E.g., |
| 6 | + Given an array ARR, copy it and replace all elements equal to MARKER |
| 7 | + (default: 0) with the last value that was not equal to MARKER. E.g., |
8 | 8 | [3, 0, 0, 5, 0, 8, 0, 0, 0] => [3, 3, 3, 5, 5, 8, 8, 8, 8]. |
9 | 9 | TODO: Replace with a faster implementation. (This one is based on |
10 | 10 | the code that calculates unix_ts in raw_event_generator.py.) |
11 | 11 | ''' |
12 | | - groups = np.split(arr, np.argwhere(arr).ravel()) |
| 12 | + groups = np.split(arr, np.argwhere(arr != marker).ravel()) |
13 | 13 | # SLOW: |
14 | 14 | groups = [np.full(len(group), group[0]) |
15 | 15 | for group in groups if len(group)] |
16 | 16 | return np.concatenate(groups, axis=0) |
| 17 | + |
| 18 | + |
| 19 | +def fill_with_next(arr: np.ndarray, marker=0) -> np.ndarray: |
| 20 | + ''' |
| 21 | + Given an array ARR, copy it and replace all elements equal to MARKER |
| 22 | + (default: 0) with the next value that is not equal to MARKER. E.g., |
| 23 | + [0, 3, 0, 0, 5, 0, 8] => [3, 3, 5, 5, 5, 8, 8]. |
| 24 | + TODO: Replace with a faster implementation. |
| 25 | + ''' |
| 26 | + a = arr[::-1] |
| 27 | + a = fill_with_last(a, marker=marker) |
| 28 | + return a[::-1] |
0 commit comments