@@ -362,10 +362,10 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:
362
362
363
363
364
364
def compute_state_entropy(
    obs: th.Tensor,
    all_obs: th.Tensor,
    k: int,
) -> th.Tensor:
    """Compute the state entropy given by KNN distance.

    For each observation in ``obs``, computes the Euclidean (L2) distance to
    every observation in ``all_obs`` (flattening all non-batch dimensions),
    then takes the distance to the k-th nearest neighbor as the entropy
    estimate.

    Args:
        obs: Batch of observations, shape ``(B, *feature_dims)``.
        all_obs: Reference set of observations, shape ``(N, *feature_dims)``.
            Feature dimensions must match ``obs`` (asserted below). Typically
            includes ``obs`` itself — see the k+1 note in the body.
        k: Which nearest neighbor to use (must be < N).

    Returns:
        Tensor of shape ``(B, 1)`` containing the KNN distance for each
        observation in ``obs``.
    """
    assert obs.shape[1:] == all_obs.shape[1:]
    with th.no_grad():
        # Pairwise differences broadcast to (B, N, *feature_dims); the norm
        # reduces over every dimension past the two batch axes, yielding a
        # (B, N) distance matrix.
        non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
        distances_tensor = th.linalg.vector_norm(
            obs[:, None] - all_obs[None, :],
            dim=non_batch_dimensions,
            ord=2,
        )

        # Note that we take the k+1'th value because the closest neighbor to
        # a point is itself, which we want to skip.
        assert distances_tensor.shape[-1] > k
        knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
        state_entropy = knn_dists
    return state_entropy.unsqueeze(1)
0 commit comments