A couple of issues in model_predictions from submission.py #105

Open
@aaprasad

Hi, I was using your example submission notebook to generate my submission files and ran into a couple of errors. I resolved them manually, but I wanted to make sure these are the right fixes and to ask whether you could update your code to account for them. Both errors came from the `model_predictions` function in `sensorium/sensorium/utility/submission.py`.

The first was `TypeError: model.forward got an unexpected keyword argument 'data_key'`, raised at line 29. I fixed it by removing `data_key=data_key, **batch_kwargs` from the `model()` call. I suspect this happens because your example models accept these arguments in `model.forward()` while mine does not, but I wanted to check.
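If the goal is to keep `submission.py` unchanged so it still works with the example models that do take `data_key`, one alternative to editing the `model()` call is to wrap your own model so that its `forward` accepts and discards those arguments. This is a minimal sketch, assuming `model_predictions` calls `model(images, data_key=data_key, **batch_kwargs)` as described above; `IgnoreDataKey` is a hypothetical name, not part of the sensorium package. It is only appropriate for a model with a single readout that takes images alone: a model with per-session readouts needs `data_key` to pick the right one, and any extra batch fields (e.g. behavior) are dropped here.

```python
import torch
from torch import nn


class IgnoreDataKey(nn.Module):
    """Hypothetical wrapper: lets an images-only model be called with the
    keyword arguments model_predictions is assumed to pass (data_key plus
    extra batch fields) by dropping them before calling the wrapped model."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, images, data_key=None, **batch_kwargs):
        # data_key and batch_kwargs are accepted but deliberately ignored.
        return self.model(images)


# Demo with a toy images-only model standing in for a real one.
base = nn.Sequential(nn.Flatten(), nn.Linear(1 * 4 * 4, 10))
wrapped = IgnoreDataKey(base)
x = torch.zeros(2, 1, 4, 4)
out = wrapped(x, data_key="some-session", behavior=torch.zeros(2, 3))
print(out.shape)  # torch.Size([2, 10])
```

Because the wrapper is itself an `nn.Module`, calls such as `.eval()` or `.to(device)` still reach the wrapped model's parameters.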

The second error was `RuntimeError: Given groups=1, weight of size [64, 3, 11, 11], expected input[128, 1, 144, 256] to have 3 channels, but got 1 channels instead`. I fixed it by adding `images = torch.cat([images, images, images], dim=1)` at line 20, which converts the grayscale batch to 3 channels. A cleaner solution might be to open the images as RGB directly in the dataloader, or to generalize the input handling so it works for models trained on either grayscale or RGB images.
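One way to generalize the `torch.cat` fix so the same code serves both grayscale- and RGB-trained models is to read the expected channel count off the model and replicate a single-channel batch only when needed. This is a minimal sketch; `first_conv_in_channels` and `match_input_channels` are hypothetical helpers, and the first one assumes the first `nn.Conv2d` yielded by `model.modules()` is the input layer, which holds for plain sequential backbones but not for every architecture.

```python
import torch
from torch import nn


def first_conv_in_channels(model: nn.Module) -> int:
    """Hypothetical helper: input channels of the first nn.Conv2d found in
    the model (assumes module registration order matches data flow)."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            return module.in_channels
    raise ValueError("model contains no nn.Conv2d layer")


def match_input_channels(images: torch.Tensor, in_channels: int) -> torch.Tensor:
    """Adapt an (N, C, H, W) batch to the channel count the model expects.

    A grayscale batch (C == 1) is replicated to every channel, equivalent to
    torch.cat([images] * in_channels, dim=1). Other mismatches raise.
    """
    channels = images.shape[1]
    if channels == in_channels:
        return images
    if channels == 1:
        return images.repeat(1, in_channels, 1, 1)
    raise ValueError(f"cannot map {channels} input channels to {in_channels}")


# Demo: a conv layer shaped like the weight in the error (3 -> 64, 11x11)
# and a grayscale batch shaped like the failing input (smaller N for speed).
model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2))
images = torch.rand(2, 1, 144, 256)
images = match_input_channels(images, first_conv_in_channels(model))
print(images.shape)  # torch.Size([2, 3, 144, 256])
out = model(images)
```

For reference, PIL's `Image.convert("RGB")` on a single-channel (`"L"`) image likewise copies the gray value into all three channels, so a PIL-based dataloader that opens images as RGB should yield the same pixels before any per-channel normalization.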

Let me know if these changes break anything else or if there's a better workaround. Thanks!
