Description
I am trying to run examples.simple.main.py and I hit the following exception:
RuntimeError: tensors used as indices must be long or byte tensors
I have the following packages installed with pip3 in python3.6:
neat-python==0.92
numpy==1.15.2+mkl
gym==0.10.5
click==6.7
torch==0.4.0
torchvision==0.2.1
This error occurs within pytorch_neat.recurrent_net in the dense_from_coo function:
def dense_from_coo(shape, conns, dtype=torch.float64):
    mat = torch.zeros(shape, dtype=dtype)
    idxs, weights = conns
    if len(idxs) == 0:
        return mat
    rows, cols = np.array(idxs).transpose()
    mat[torch.tensor(rows), torch.tensor(cols)] = torch.tensor(
        weights, dtype=dtype)
    return mat
The problem is that np.array infers int32 for the indices on my setup, but torch indexing requires int64 (long) or uint8 (byte) tensors, as the error message says.
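The dtype mismatch can be seen in isolation with a small sketch. The index list below is made up for illustration, and int32 is forced explicitly to mimic a setup where NumPy picks a 32-bit default integer, so the output should be the same on any platform:

```python
import numpy as np
import torch

# Hypothetical connection indices, standing in for `idxs` in dense_from_coo.
idxs = [(0, 1), (2, 0)]

# Force int32 to mimic a setup where NumPy's default integer is 32-bit.
rows32, cols32 = np.array(idxs, dtype=np.int32).transpose()
# Request int64 explicitly, as in the proposed fix.
rows64, cols64 = np.array(idxs, dtype=np.int64).transpose()

# torch.tensor keeps the integer width of the NumPy array it is given.
t32 = torch.tensor(rows32)
t64 = torch.tensor(rows64)

print(t32.dtype)  # torch.int32
print(t64.dtype)  # torch.int64
```

Only the int64 version matches the "long" index type that the error message asks for.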
Simple solution:
def dense_from_coo(shape, conns, dtype=torch.float64):
    mat = torch.zeros(shape, dtype=dtype)
    idxs, weights = conns
    if len(idxs) == 0:
        return mat
    rows, cols = np.array(idxs, dtype=np.int64).transpose()
    mat[torch.tensor(rows), torch.tensor(cols)] = torch.tensor(
        weights, dtype=dtype)
    return mat
The difference may come from differing numpy versions, but I think requesting int64 explicitly makes sense regardless, since torch tensor indexing needs it.
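An alternative sketch, not part of the repository, is to do the cast on the torch side instead, so the index tensors are long no matter what dtype NumPy inferred. The name dense_from_coo_long and the sample connections are hypothetical and only for illustration:

```python
import numpy as np
import torch


def dense_from_coo_long(shape, conns, dtype=torch.float64):
    """Hypothetical variant of dense_from_coo that casts indices to torch.long."""
    mat = torch.zeros(shape, dtype=dtype)
    idxs, weights = conns
    if len(idxs) == 0:
        return mat
    rows, cols = np.array(idxs).transpose()
    # Cast explicitly so the index tensors are int64 on every platform.
    row_idx = torch.tensor(rows, dtype=torch.long)
    col_idx = torch.tensor(cols, dtype=torch.long)
    mat[row_idx, col_idx] = torch.tensor(weights, dtype=dtype)
    return mat


# Made-up connections: (row, col) pairs and their weights.
conns = ([(0, 1), (2, 0)], [0.5, -1.0])
dense = dense_from_coo_long((3, 2), conns)
empty = dense_from_coo_long((2, 2), ([], []))
print(dense)
```

Either approach should avoid the exception; fixing the dtype in np.array keeps the change to a single line, while casting in torch makes the requirement visible at the point where the indexing happens.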