Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
201 changes: 201 additions & 0 deletions examples/libtorch_kks/KKS_libtorch.i
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
#
# Kim-Kim-Suzuki with Gibbs energy supplied by a torch model, solved on a 2D grid.
#


# Constants for Initial Conditions
r = 30
l = 4.2

# Initial condition function for order parameter
eta_IC = '0.5*(1-tanh(2*(sqrt(x^2+y^2)-${r})/${l}))'

# Phase-field model parameters
kappa_eta = 5
w = 1
M = 5
L = 5

# Expressions for switching function and bulk Gibbs energy
h_eta = 'eta^3*(6*eta^2-15*eta+10)'


[Domain]
dim = 2
nx = 100
ny = 100

xmin = -50
xmax = 50
ymin = -50
ymax = 50

# automatically create a matching mesh
mesh_mode = DUMMY
[]

[TensorComputes]
[Initialize]
[c_IC]
type = ParsedCompute
buffer = c
expression = '0.6 + (0.3-0.6)*${eta_IC}'
extra_symbols = 'true'
enable_jit = false
[]
[eta_IC]
type = ParsedCompute
buffer = eta
expression = '${eta_IC}'
extra_symbols = 'true'
enable_jit = false
[]
[psi_init]
type = ConstantTensor
buffer = psi
real = 1
[]

[M]
type = ConstantTensor
buffer = M
real = ${M}
[]
[L]
type = ConstantTensor
buffer = L
real = ${L}
[]
[L_kappa]
type = ReciprocalLaplacianFactor
buffer = L_kappa
factor = ${fparse ${L} * ${kappa_eta} }
[]
[h_eta_IC]
type = ParsedCompute
buffer = h_eta
expression = '${h_eta}'
inputs = eta
[]
[G_func_IC]
type = LibtorchGibbsEnergy
buffer = 'G'
phase_fractions = 'h_eta'
concentrations = 'c'
domega_detas = 'dG_dh'
chem_pots = 'mu'
libtorch_model_file = 'torch_NN_gibbs_model.pt'
[]
[smooth]
type = DeAliasingTensor
method = HOULI
buffer = smooth
[]
[]
[Solve]
[h_eta]
type = ParsedCompute
buffer = h_eta
expression = '${h_eta}'
inputs = eta
[]
[G_func]
type = LibtorchGibbsEnergy
buffer = 'G'
phase_fractions = 'h_eta'
concentrations = 'c'
domega_detas = 'dG_dh'
chem_pots = 'mu'
libtorch_model_file = 'torch_NN_gibbs_model.pt'
[]
[dG_deta]
type = ParsedCompute
buffer = 'dG_deta'
inputs = 'eta dG_dh'
expression = 'dG_dh * ${h_eta} + ${w} * eta^2 * (1-eta^2)^2'
derivatives = 'eta'
[]

[etabar]
type = ForwardFFT
buffer = etabar
input = eta
[]
[AC_bulk]
type = ReciprocalAllenCahn
L = L
buffer = AC_bulk
dF_chem_deta = dG_deta
psi = psi
[]
[NL_eta]
type = ParsedCompute
buffer = NL_eta
expression = 'AC_bulk '
inputs = 'AC_bulk'
[]
[cbar]
type = ForwardFFT
buffer = cbar
input = c
[]
[div_J]
type = ReciprocalMatDiffusion
buffer = 'div_J'
chemical_potential = mu
mobility = M
psi = psi
[]
[NL_c]
type = ParsedCompute
buffer = 'NL_c'
inputs = 'div_J smooth'
expression = 'smooth * div_J'
[]
[]
[]

[TensorSolver]
type = AdamsBashforthMoulton
buffer = 'c eta'
reciprocal_buffer = 'cbar etabar'
linear_reciprocal = '0 L_kappa'
nonlinear_reciprocal = 'NL_c NL_eta'
substeps = 1e3
predictor_order = 3
corrector_order = 1
corrector_steps = 1
[]

[Postprocessors]
[total_c]
type = TensorIntegralPostprocessor
buffer = c
[]
[]

[TensorOutputs]
[xdmf]
type = XDMFTensorOutput
buffer = 'eta c mu psi dG_deta dG_dh G'
enable_hdf5 = true
transpose = false
[]
[]

[Executioner]
type = Transient
num_steps = 100
[TimeStepper]
type = IterationAdaptiveDT
growth_factor = 1.25
dt = 0.1
[]
dtmax = 10
[]

[Outputs]
csv = true
perf_graph = true
execute_on = 'INITIAL TIMESTEP_END'
[]
63 changes: 63 additions & 0 deletions examples/libtorch_kks/generate_torch_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn


class GibbsEnergy(nn.Module):
def __init__(self, E: torch.Tensor, c0_a: torch.Tensor, c0_b: torch.Tensor):
"""
Initialize the GibbsEnergy model.

Args:
E (torch.Tensor): Energy parameter.
c0_a (torch.Tensor): Concentration parameter a.
c0_b (torch.Tensor): Concentration parameter b.
"""
super(GibbsEnergy, self).__init__()
self.register_buffer("E", nn.Parameter(E))
self.register_buffer("c0_a", nn.Parameter(c0_a))
self.register_buffer("c0_b", nn.Parameter(c0_b))

def forward(self, x) -> torch.Tensor:
"""
Forward pass of the GibbsEnergy model.

Args:
x (torch.Tensor): Input tensor with shape (batch_size, 2).

Returns:
torch.Tensor: Computed Gibbs energy.
"""
h_eta = x[:, 0]
c = x[:, 1]

c_a = c + (1 - h_eta)*(self.c0_a - self.c0_b)
c_b = c - h_eta*(self.c0_a - self.c0_b)

return 0.5 * self.E * (h_eta * torch.square(c_a - self.c0_a)
+ (1-h_eta) * torch.square(c_b - self.c0_b))


def main():
# Initialize the model with specific values
G_torch = GibbsEnergy(torch.tensor(
[2.0]), torch.tensor([0.3]), torch.tensor([0.7]))

# Set the model to evaluation mode
G_torch.eval()

# Sample input tensor
x = torch.tensor([[0.0, 0.3], [1.0, 0.7]])

try:
# Trace the model with the sample input
scripted_model = torch.jit.trace(G_torch, x)

# Save the traced model
scripted_model.save('torch_NN_gibbs_model.pt')
print("Model saved successfully.")
except Exception as e:
print(f"An error occurred: {e}")


if __name__ == "__main__":
main()
38 changes: 38 additions & 0 deletions include/tensor_computes/LibtorchGibbsEnergy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/**********************************************************************/
/* DO NOT MODIFY THIS HEADER */
/* Swift, a Fourier spectral solver for MOOSE */
/* */
/* Copyright 2024 Battelle Energy Alliance, LLC */
/* ALL RIGHTS RESERVED */
/**********************************************************************/

#pragma once

#include "TensorOperator.h"
// moose headers
#include "DataFileUtils.h"
// libtorch headers
#include <torch/script.h>

class LibtorchGibbsEnergy : public TensorOperator<>
{
public:
static InputParameters validParams();

LibtorchGibbsEnergy(const InputParameters & parameters);

virtual void computeBuffer() override;

protected:
unsigned int _n_phases;
std::vector<const torch::Tensor *> _phase_fractions;
std::vector<torch::Tensor *> _domega_detas;

unsigned int _n_components;
std::vector<const torch::Tensor *> _concentrations;
std::vector<torch::Tensor *> _chemical_potentials;

Moose::DataFileUtils::Path _file_path;
// We need to use a pointer here because forward is not const qualified
std::unique_ptr<torch::jit::script::Module> _surrogate;
};
3 changes: 3 additions & 0 deletions src/base/SwiftApp.C
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ SwiftApp::registerAll(Factory & f, ActionFactory & af, Syntax & syntax)
registerMooseObjectTask("add_tensor_predictor", TensorPredictor, false);
addTaskDependency("add_tensor_predictor", "create_tensor_solver");

// Register data file path
registerAppDataFilePath("swift");

// make sure all this gets run before `add_mortar_variable`
addTaskDependency("add_mortar_variable", "add_tensor_predictor");
}
Expand Down
4 changes: 0 additions & 4 deletions src/tensor_computes/FFTGradient.C
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ FFTGradient::FFTGradient(const InputParameters & parameters)
void
FFTGradient::computeBuffer()
{
std::cout << "_x " << _direction << " = " << _x << std::endl;
std::cout << "_y " << _direction << " = " << _y << std::endl;
std::cout << "_z " << _direction << " = " << _z << std::endl;

_u = _domain.ifft((_input_is_reciprocal ? _input : _domain.fft(_input)) *
_domain.getReciprocalAxis(_direction) * _i);
}
Loading