Skip to content

Commit 567e980

Browse files
author
Jan Michelfeit
committed
#625 revert change of compute_state_entropy() from tensors to numpy
1 parent 1f50696 commit 567e980

File tree

2 files changed

+8
-23
lines changed

2 files changed

+8
-23
lines changed

src/imitation/util/util.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,10 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:
362362

363363

364364
def compute_state_entropy(
365-
obs: np.ndarray,
366-
all_obs: np.ndarray,
365+
obs: th.Tensor,
366+
all_obs: th.Tensor,
367367
k: int,
368-
) -> np.ndarray:
368+
) -> th.Tensor:
369369
"""Compute the state entropy given by KNN distance.
370370
371371
Args:
@@ -379,19 +379,15 @@ def compute_state_entropy(
379379
assert obs.shape[1:] == all_obs.shape[1:]
380380
with th.no_grad():
381381
non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
382-
distances_tensor = np.linalg.norm(
382+
distances_tensor = th.linalg.vector_norm(
383383
obs[:, None] - all_obs[None, :],
384-
axis=non_batch_dimensions,
384+
dim=non_batch_dimensions,
385385
ord=2,
386386
)
387387

388388
# Note that we take the k+1'th value because the closest neighbor to
389389
# a point is itself, which we want to skip.
390-
knn_dists = kth_value(distances_tensor, k+1)
390+
assert distances_tensor.shape[-1] > k
391+
knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
391392
state_entropy = knn_dists
392-
return np.expand_dims(state_entropy, axis=1)
393-
394-
395-
def kth_value(x: np.ndarray, k: int):
396-
assert k > 0
397-
return np.partition(x, k - 1, axis=-1)[..., k - 1]
393+
return state_entropy.unsqueeze(1)

tests/util/test_util.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from imitation.util import sacred as sacred_util
1313
from imitation.util import util
14-
from imitation.util.util import kth_value
1514

1615

1716
def test_endless_iter():
@@ -146,13 +145,3 @@ def test_compute_state_entropy_2d():
146145
np.sqrt(20**2 + 2**2),
147146
)
148147

149-
150-
def test_kth_value():
151-
arr1 = np.arange(0, 10, 1)
152-
np.random.shuffle(arr1)
153-
arr2 = np.arange(0, 100, 10)
154-
np.random.shuffle(arr2)
155-
arr = np.stack([arr1, arr2])
156-
157-
result = kth_value(arr, 3)
158-
np.testing.assert_array_equal(result, np.array([2, 20]))

0 commit comments

Comments
 (0)