Skip to content

Commit

Permalink
Update Material and SymbolicHandler classes
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsferreira committed Mar 28, 2024
1 parent 046de86 commit 4f38983
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
4 changes: 2 additions & 2 deletions hyper_surrogate/materials.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any
from typing import Any, Iterable

import sympy as sym

from hyper_surrogate.symbolic import SymbolicHandler


class Material(SymbolicHandler):
def __init__(self, parameters: list) -> None:
def __init__(self, parameters: Iterable[Any]) -> None:
super().__init__()
self.parameters = parameters

Expand Down
4 changes: 2 additions & 2 deletions hyper_surrogate/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def lambdify(self, symbolic_tensor: sym.Matrix, *args: Iterable[Any]) -> Any:

return sym.lambdify((self.c_symbols(), *args), symbolic_tensor, modules="numpy")

def evaluate(self, lambdified_tensor: Any, *args: Any) -> Any:
def _evaluate(self, lambdified_tensor: Any, *args: Any) -> Any:
"""
Evaluate a lambdified tensor with specific values.
Expand All @@ -196,4 +196,4 @@ def evaluate_iterator(self, lambdified_tensor: Any, numerical_c_tensors: np.ndar
Any: The evaluated tensor.
"""
for numerical_c_tensor in numerical_c_tensors:
yield self.evaluate(lambdified_tensor, numerical_c_tensor.flatten(), *args)
yield self._evaluate(lambdified_tensor, numerical_c_tensor.flatten(), *args)
27 changes: 27 additions & 0 deletions tests/test_materials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import logging

import numpy as np
import pytest
import sympy as sym

from hyper_surrogate.materials import Material


@pytest.fixture
def material():
return Material([])


def test_material_dummy_sef(material):
assert material.sef == sym.Symbol("sef")


def test_pk2_symbol(material):
logging.info(material.pk2_symb)
assert material.pk2_symb == material.pk2_tensor(material.sef)
assert material.pk2_symb == sym.Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]])


def test_cmat_symbol(material):
assert material.cmat_symb == material.cmat_tensor(material.pk2_symb)
assert material.cmat_symb == sym.ImmutableDenseNDimArray(np.zeros((3, 3, 3, 3)))

0 comments on commit 4f38983

Please sign in to comment.