
A to(device) method on the high-level interface #502

@sfriedowitz

Description


Describe the workflow you want to enable

Hello,

I've been using this package to power a personal algo-trading project. Thanks for the great work!

It would be very helpful to have a to(device) method on the top-level TabPFNRegressor and TabPFNClassifier classes. Currently it is possible to move the architecture and its associated components to another device by saving the model to disk and loading it back, but this is cumbersome and raises exceptions for models that have not been fitted yet. A high-level method would make the package much easier to use.

Thanks!

Describe your proposed solution

On both TabPFNRegressor and TabPFNClassifier, add a to(device) method (or to_torch_device, or something similar) that moves all of the model's weights and other torch components to the given device.
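For illustration, the requested method could follow the same convention as torch.nn.Module.to: move everything, record the target device, and return self so calls can be chained. The sketch below is a hypothetical mixin (DeviceMovableMixin), not part of TabPFN's API. It makes no assumptions about how TabPFNRegressor or TabPFNClassifier store their weights internally; it simply moves any attribute, or list/tuple/dict of attributes, that exposes a .to(device) method, which both torch.nn.Module and torch.Tensor do. FakeTensor and ToyRegressor are made-up stand-ins so the example runs without torch or a GPU.

```python
from typing import Any


class DeviceMovableMixin:
    """Hypothetical mixin: moves every attribute that exposes .to(device)."""

    def _move(self, value: Any, device: str) -> Any:
        # Recurse into containers so ensembles of models or tensors are covered.
        if isinstance(value, (list, tuple)):
            return type(value)(self._move(v, device) for v in value)
        if isinstance(value, dict):
            return {k: self._move(v, device) for k, v in value.items()}
        # torch.nn.Module.to and torch.Tensor.to both match this duck type.
        to = getattr(value, "to", None)
        if callable(to):
            return to(device)
        return value  # plain Python values (ints, strings, ...) are left alone

    def to(self, device: str) -> "DeviceMovableMixin":
        # Reassign each attribute, because Tensor.to returns a new tensor
        # while Module.to moves the module in place and returns it.
        for name, value in list(vars(self).items()):
            setattr(self, name, self._move(value, device))
        # Record the target so that a later fit() can build weights there.
        self.device = device
        return self  # return self to allow chaining, as torch.nn.Module.to does


# --- Hypothetical stand-ins so the sketch runs without torch or a GPU ---
class FakeTensor:
    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)  # like torch.Tensor.to: returns a new object


class ToyRegressor(DeviceMovableMixin):
    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def fit(self) -> "ToyRegressor":
        # Fitted state follows the scikit-learn trailing-underscore convention.
        self.model_ = FakeTensor(self.device)
        self.ensemble_ = [FakeTensor(self.device), FakeTensor(self.device)]
        return self


fitted = ToyRegressor().fit().to("cuda")
print(fitted.model_.device, [t.device for t in fitted.ensemble_])
unfitted = ToyRegressor().to("cuda")  # no weights yet, so nothing to move
print(unfitted.device)
```

Because an unfitted estimator in this sketch simply has no weight attributes yet, calling to() before fit() only records the target device instead of raising, which is the case the save/load workaround does not cover.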

Describe alternatives you've considered, if relevant

No response

Additional context

No response

Impact

None

Metadata


Assignees

No one assigned

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
