Open
Description
In the example below, for the dataset with the `$.getitem()` implementation, the `[` method returns an element without the batch dimension for an index of length 1, and otherwise includes the batch dimension. I think it would be better to make this consistent and always return the batch dimension.
library(torch)
# Dataset that implements `.getbatch`: the whole index vector is used to
# subset the stored tensor in one call, and `drop = FALSE` is passed so the
# leading dimension is retained even for a single index.
ds_batch <- dataset(
  name = "batch",
  initialize = function() {
    # 100 rows of 10 random values each.
    self$x <- torch_randn(100, 10)
  },
  .getbatch = function(i) {
    self$x[i, .., drop = FALSE]
  },
  .length = function() {
    nrow(self$x)
  }
)()

# Index of length 1: the batch dimension is present.
print(ds_batch[1L]$shape)
#> [1] 1 10
# Index of length 2: the batch dimension is present.
print(ds_batch[1:2]$shape)
#> [1] 2 10
# Dataset that implements `.getitem` instead of `.getbatch`: the index is
# used to subset the stored tensor directly, without `drop = FALSE`.
# NOTE(review): the class name "batch" is the same string used for the
# `.getbatch` dataset -- possibly a copy-paste; kept unchanged here.
ds_item <- dataset(
  name = "batch",
  initialize = function() {
    # 100 rows of 10 random values each.
    self$x <- torch_randn(100, 10)
  },
  .getitem = function(i) {
    self$x[i]
  },
  .length = function() {
    nrow(self$x)
  }
)()

# Index of length 1: the result has NO batch dimension (the reported
# inconsistency).
print(ds_item[1L]$shape)
#> [1] 10
# Index of length 2: the batch dimension is present.
print(ds_item[1:2]$shape)
#> [1] 2 10
Created on 2025-04-17 with reprex v2.1.1
Metadata
Metadata
Assignees
Labels
No labels