Skip to content

Commit 61ff213

Browse files
committed
Fixed issues with torch model device and removed enforcement of torch device in input files
1 parent b96b0eb commit 61ff213

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

examples/libtorch_kks/KKS_libtorch.i

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ h_eta = 'eta^3*(6*eta^2-15*eta+10)'
3030
ymin = -50
3131
ymax = 50
3232

33-
# run on a CUDA device (adjust this to `cpu` if not available)
34-
device_names = 'cuda'
35-
3633
# automatically create a matching mesh
3734
mesh_mode = DUMMY
3835
[]

src/tensor_computes/LibtorchGibbsEnergy.C

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ LibtorchGibbsEnergy::LibtorchGibbsEnergy(const InputParameters & parameters)
3838
_file_path(Moose::DataFileUtils::getPath(getParam<DataFileName>("libtorch_model_file"))),
3939
_surrogate(std::make_unique<torch::jit::script::Module>(torch::jit::load(_file_path.path)))
4040
{
41-
_surrogate->to(MooseTensor::floatTensorOptions().device());
41+
const auto opts = MooseTensor::floatTensorOptions();
42+
43+
auto ref = torch::empty({0}, opts);
44+
const auto dev = ref.device();
45+
const auto dt = ref.scalar_type();
46+
_surrogate->to(dev, dt, /*non_blocking=*/false);
4247
_surrogate->eval();
4348

4449
auto phase_fractions = getParam<std::vector<TensorInputBufferName>>("phase_fractions");

test/tests/kks/KKS_libtorch.i

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ h_eta = 'eta^3*(6*eta^2-15*eta+10)'
3030
ymin = -50
3131
ymax = 50
3232

33-
# run on a CUDA device (adjust this to `cpu` if not available)
34-
device_names = 'cuda'
35-
3633
# automatically create a matching mesh
3734
mesh_mode = DUMMY
3835
[]

0 commit comments

Comments
 (0)