
dataset's `[` implementation returns inconsistent output shapes when `.getitem()` is implemented #1307

Open
@sebffischer


In the example below, the dataset that implements `.getitem()` has a `[` method that returns an element without a batch dimension when the index has length 1, but includes the batch dimension for longer indices. The dataset that implements `.getbatch()` keeps the batch dimension in both cases. I think this should be consistent, with `[` always returning the batch dimension.

```r
library(torch)

# Dataset implementing .getbatch(): the batch dimension is always kept
ds_batch = dataset("batch",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getbatch = function(i) {
    self$x[i,.., drop = FALSE]
  },
  .length = function() nrow(self$x)
)()

print(ds_batch[1L]$shape)
#> [1]  1 10
print(ds_batch[1:2]$shape)
#> [1]  2 10

# Dataset implementing .getitem(): the batch dimension is dropped for a single index
ds_item = dataset("batch",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  .getitem = function(i) {
    self$x[i]
  },
  .length = function() nrow(self$x)
)()

print(ds_item[1L]$shape)
#> [1] 10
print(ds_item[1:2]$shape)
#> [1]  2 10
```

Created on 2025-04-17 with reprex v2.1.1
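As a possible workaround under the current behaviour, `.getbatch()` can be written in terms of a per-item function and the items stacked with `torch_stack()`. The stacking then adds the batch dimension even for a single index. This is only a sketch that assumes each item is a single tensor. The class name `consistent` and the helper method `get_one()` are hypothetical names chosen for illustration and are not part of the torch API.

```r
library(torch)

# Sketch (assumption): build .getbatch() from a per-item helper so that
# `[` returns a batch dimension even for an index of length 1.
ds_consistent = dataset("consistent",
  initialize = function() {
    self$x = torch_randn(100, 10)
  },
  # Hypothetical helper: returns one item *without* a batch dimension
  get_one = function(i) {
    self$x[i, ..]
  },
  # Stack the individual items along a new first (batch) dimension
  .getbatch = function(i) {
    torch_stack(lapply(i, self$get_one), dim = 1L)
  },
  .length = function() nrow(self$x)
)()

print(ds_consistent[1L]$shape)
#> [1]  1 10
print(ds_consistent[1:2]$shape)
#> [1]  2 10
```

If items are lists of tensors rather than single tensors, the stacking would need to be applied to each element separately, for example through a collate function. This sketch does not handle that case.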
