diff --git a/poetry.lock b/poetry.lock index ec2c246581..eab1cc97eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5036,10 +5036,9 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" name = "pydantic-settings" version = "2.8.1" description = "Settings management using Pydantic" -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"dev\"" files = [ {file = "pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c"}, {file = "pydantic_settings-2.8.1.tar.gz", hash = "sha256:d5c663dfbe9db9d5e1c646b2e161da12f0d734d422ee56f567d0ea2cee4e8585"}, @@ -5361,10 +5360,9 @@ six = ">=1.5" name = "python-dotenv" version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"dev\"" files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -7715,4 +7713,4 @@ vtk = ["vtk"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<3.14" -content-hash = "c948357280c9ae9288903051ab031ae58db5daca58c12221a51143b2558eca07" +content-hash = "6837994745f407980816a2bdd5213e019ea5044a315b8eaba3c0d738b92ddc2d" diff --git a/pyproject.toml b/pyproject.toml index 22e17c5633..9f56936843 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ documentation = "https://docs.flexcompute.com/projects/tidy3d/en/latest/" [tool.poetry.dependencies] python = ">=3.9,<3.14" +typing-extensions = { version = "*", python = "<3.11" } pyroots = ">=0.5.0" xarray = ">=2023.08" importlib-metadata = ">=6.0.0" @@ -33,7 +34,8 @@ numpy = "*" matplotlib = "*" shapely = "^2.0" pandas = "*" -pydantic = "^2.0" +pydantic = ">=2,<3" 
+pydantic-settings = ">=2,<3" PyYAML = "*" dask = "*" toml = "*" diff --git a/tests/test_components/test_IO.py b/tests/test_components/test_IO.py index b4fe3322df..9dd149b93d 100644 --- a/tests/test_components/test_IO.py +++ b/tests/test_components/test_IO.py @@ -21,6 +21,7 @@ # Store an example of every minor release simulation to test updater in the future SIM_DIR = "tests/sims" +SIM_STATIC = SIM.to_static() @pytest.fixture @@ -66,14 +67,14 @@ def test_simulation_load_export(split_string, tmp_path): major, minor, patch = __version__.split(".") path = os.path.join(tmp_path, f"simulation_{major}_{minor}_{patch}.json") path_hdf5 = os.path.join(tmp_path, f"simulation_{major}_{minor}_{patch}.h5") - SIM.to_file(path) - SIM.to_hdf5(path_hdf5) + SIM_STATIC.to_file(path) + SIM_STATIC.to_hdf5(path_hdf5) SIM2 = td.Simulation.from_file(path) SIM_HDF5 = td.Simulation.from_hdf5(path_hdf5) assert ( - set_datasets_to_none(SIM)._json_string == SIM2._json_string + set_datasets_to_none(SIM_STATIC)._json_string == SIM2._json_string ), "original and loaded simulations are not the same" - assert SIM == SIM_HDF5, "original and loaded from hdf5 simulations are not the same" + assert SIM_STATIC == SIM_HDF5, "original and loaded from hdf5 simulations are not the same" def test_simulation_load_export_yaml(tmp_path): @@ -101,30 +102,30 @@ def test_component_load_export_yaml(tmp_path): def test_simulation_load_export_hdf5(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5") - SIM.to_file(path) + SIM_STATIC.to_file(path) SIM2 = td.Simulation.from_file(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_hdf5_gz(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5.gz") - SIM.to_file(path) + SIM_STATIC.to_file(path) SIM2 = td.Simulation.from_file(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert 
SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_hdf5_explicit(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5") - SIM.to_hdf5(path) + SIM_STATIC.to_hdf5(path) SIM2 = td.Simulation.from_hdf5(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_hdf5_gz_explicit(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5.gz") - SIM.to_hdf5_gz(path) + SIM_STATIC.to_hdf5_gz(path) SIM2 = td.Simulation.from_hdf5_gz(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_pckl(tmp_path): @@ -189,7 +190,7 @@ def test_validation_speed(tmp_path): for i in range(n): new_structure = SIM.structures[0].copy(update={"name": str(i)}) new_structures.append(new_structure) - S = SIM.copy(update=dict(structures=new_structures)) + S = SIM.copy(update=dict(structures=tuple(new_structures))) S.to_file(path) time_start = time() @@ -220,7 +221,9 @@ def test_simulation_updater(sim_file): def test_yaml(tmp_path): path = str(tmp_path / "simulation.json") SIM.to_file(path) + SIM.to_file("simulation.json") sim = td.Simulation.from_file(path) + path1 = str(tmp_path / "simulation.yaml") sim.to_yaml(path1) sim1 = td.Simulation.from_yaml(path1) diff --git a/tests/test_components/test_apodization.py b/tests/test_components/test_apodization.py index d1c3b14440..5a0ab45202 100644 --- a/tests/test_components/test_apodization.py +++ b/tests/test_components/test_apodization.py @@ -1,9 +1,9 @@ """Tests mode objects.""" import matplotlib.pyplot as plt -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError def test_apodization(): @@ -14,27 +14,27 @@ def test_apodization(): def test_end_lt_start(): - with 
pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=2, end=1, width=0.2) def test_no_width(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1, end=2) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(end=2) def test_negative_times(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=-2, end=-1, width=0.2) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1, end=2, width=-0.2) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1, end=2, width=0) diff --git a/tests/test_components/test_base.py b/tests/test_components/test_base.py index 9058ec0897..70bacbd794 100644 --- a/tests/test_components/test_base.py +++ b/tests/test_components/test_base.py @@ -3,6 +3,8 @@ import numpy as np import pytest import tidy3d as td +from pydantic import ValidationError +from pydantic_core import PydanticSerializationError from tidy3d.components.base import Tidy3dBaseModel M = td.Medium() @@ -166,7 +168,7 @@ def test_updated_copy_path(): ) # forgot path - with pytest.raises(ValueError): + with pytest.raises(KeyError): assert sim == sim.updated_copy(permittivity=2.0) assert sim.updated_copy(size=(6, 6, 6)) == sim.updated_copy(size=(6, 6, 6), path=None) @@ -198,7 +200,7 @@ def test_attrs(tmp_path): assert obj.attrs == {"foo": "attr"} # this is still not allowed though - with pytest.raises(TypeError): + with pytest.raises(ValidationError): obj.attrs = {} # attrs can be modified @@ -215,7 +217,7 @@ def test_attrs(tmp_path): # attrs are in the json strings obj_json = obj3.json() - assert '{"foo": 
"bar"}' in obj_json + assert '{"foo":"bar"}' in obj_json # attrs are in the dict() obj_dict = obj3.dict() @@ -230,7 +232,7 @@ def test_attrs(tmp_path): # test attrs that can't be serialized obj.attrs["not_serializable"] = type - with pytest.raises(TypeError): + with pytest.raises(PydanticSerializationError): obj.json() diff --git a/tests/test_components/test_beam.py b/tests/test_components/test_beam.py index 457ad28292..1c9bbba210 100644 --- a/tests/test_components/test_beam.py +++ b/tests/test_components/test_beam.py @@ -1,8 +1,8 @@ """Tests for the various BeamProfile components.""" import numpy as np -import pydantic.v1 as pd import pytest +from pydantic import ValidationError from tidy3d.components.beam import ( AstigmaticGaussianBeamProfile, GaussianBeamProfile, @@ -94,7 +94,7 @@ def test_invalid_beam_size(): center = (0, 0, 0) size = (10, 10, 10) resolution = 100 - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): GaussianBeamProfile(center=center, size=size, resolution=resolution, freqs=FREQS) diff --git a/tests/test_components/test_boundaries.py b/tests/test_components/test_boundaries.py index 9ca0f7d915..d6d8c3ed68 100644 --- a/tests/test_components/test_boundaries.py +++ b/tests/test_components/test_boundaries.py @@ -1,8 +1,8 @@ """Tests boundary conditions.""" -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.components.boundary import ( PML, Absorber, @@ -78,11 +78,11 @@ def test_boundary_validators(): periodic = Periodic() # test `bloch_on_both_sides` - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = Boundary(plus=bloch, minus=pec) # test `periodic_with_pml` - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = Boundary(plus=periodic, minus=pml) diff --git a/tests/test_components/test_custom.py b/tests/test_components/test_custom.py index 4ba744f054..8d4d913a1d 100644 --- 
a/tests/test_components/test_custom.py +++ b/tests/test_components/test_custom.py @@ -1,13 +1,11 @@ """Tests custom sources and mediums.""" -from typing import Tuple - import dill as pickle import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td import xarray as xr +from pydantic import ValidationError from tidy3d.components.data.dataset import PermittivityDataset from tidy3d.components.data.utils import UnstructuredGridDataset, _get_numpy_array from tidy3d.components.medium import ( @@ -89,7 +87,7 @@ def make_spatial_data(value=0, dx=0, unstructured=False, seed=None, uniform=Fals CURRENT_SRC = make_custom_current_source() -def get_dataset(custom_source_obj) -> Tuple[str, td.FieldDataset]: +def get_dataset(custom_source_obj) -> tuple[str, td.FieldDataset]: """Get a dict containing dataset depending on type and its key.""" if isinstance(custom_source_obj, td.CustomFieldSource): return "field_dataset", custom_source_obj.field_dataset @@ -116,7 +114,7 @@ def test_validator_tangential_field(): """Test that it errors if no tangential field defined.""" field_dataset = FIELD_SRC.field_dataset field_dataset = field_dataset.copy(update=dict(Ex=None, Ez=None, Hx=None, Hz=None)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CustomFieldSource(size=SIZE, source_time=ST, field_dataset=field_dataset) @@ -124,7 +122,7 @@ def test_validator_non_planar(): """Test that it errors if the source geometry has a volume.""" field_dataset = FIELD_SRC.field_dataset field_dataset = field_dataset.copy(update=dict(Ex=None, Ez=None, Hx=None, Hz=None)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CustomFieldSource(size=(1, 1, 1), source_time=ST, field_dataset=field_dataset) @@ -134,7 +132,7 @@ def test_validator_freq_out_of_range_src(source): key, dataset = get_dataset(source) Ex_new = td.ScalarFieldDataArray(dataset.Ex.data, coords=dict(x=X, y=Y, z=Z, f=[0])) dataset_fail 
= dataset.copy(update=dict(Ex=Ex_new)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = source.updated_copy(size=SIZE, source_time=ST, **{key: dataset_fail}) @@ -145,7 +143,7 @@ def test_validator_freq_multiple(source): new_data = np.concatenate((dataset.Ex.data, dataset.Ex.data), axis=-1) Ex_new = td.ScalarFieldDataArray(new_data, coords=dict(x=X, y=Y, z=Z, f=[1, 2])) dataset_fail = dataset.copy(update=dict(Ex=Ex_new)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = source.copy(update={key: dataset_fail}) @@ -419,7 +417,7 @@ def test_medium_smaller_than_one_positive_sigma(unstructured): if unstructured: n_dataarray = cartesian_to_unstructured(n_dataarray.isel(f=0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = CustomMedium.from_nk(n_dataarray) # negative sigma @@ -433,7 +431,7 @@ def test_medium_smaller_than_one_positive_sigma(unstructured): n_dataarray = cartesian_to_unstructured(n_dataarray.isel(f=0), seed=1) k_dataarray = cartesian_to_unstructured(k_dataarray.isel(f=0), seed=1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = CustomMedium.from_nk(n_dataarray, k_dataarray, freq=freqs[0]) @@ -470,9 +468,9 @@ def test_medium_nk(unstructured): assert np.isclose(med.eps_model(1e14), meds.eps_model(1e14), rtol=RTOL) # gain - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): med = CustomMedium.from_nk(n=n, k=-k) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): meds = CustomMedium.from_nk(n=ns, k=-ks, freq=freqs[0]) med = CustomMedium.from_nk(n=n, k=-k, allow_gain=True) meds = CustomMedium.from_nk(n=ns, k=-ks, freq=freqs[0], allow_gain=True) @@ -497,7 +495,7 @@ def test_medium_eps_model(): med.eps_model(frequency=freqs[0]) # error with multifrequency data - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): med = make_custom_medium(make_scalar_data_multifreqs()) @@ -552,16 +550,8 @@ def verify_custom_medium_methods(mat, reduced_fields): # data fields in medium classes could be SpatialArrays or 2d tuples of spatial arrays # lets convert everything into 2d tuples of spatial arrays for uniform handling if isinstance(original, (td.SpatialDataArray, UnstructuredGridDataset)): - original = [ - [ - original, - ], - ] - reduced = [ - [ - reduced, - ], - ] + original = [[original]] + reduced = [[reduced]] for or_set, re_set in zip(original, reduced): assert len(or_set) == len(re_set) @@ -644,30 +634,30 @@ def test_custom_isotropic_medium(unstructured): conductivity = make_spatial_data(value=1, unstructured=unstructured, seed=seed) # some terms in permittivity are complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=1 + 0.1j, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=epstmp, conductivity=conductivity) # some terms in permittivity are < 1 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=0, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=epstmp, conductivity=conductivity) # some terms in conductivity are complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sigmatmp = make_spatial_data(value=0.1j, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp) # some terms in conductivity are negative sigmatmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp) mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp, allow_gain=True) verify_custom_medium_methods(mat, ["permittivity", 
"conductivity"]) assert not mat.is_spatially_uniform # inconsistent coords - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sigmatmp = make_spatial_data(value=0, dx=1, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp) @@ -723,27 +713,27 @@ def test_custom_pole_residue(unstructured): c = 1j * make_spatial_data(value=1, unstructured=unstructured, seed=seed) # some terms in eps_inf are negative - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # some terms in eps_inf are complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=0.1j, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # inconsistent coords of eps_inf with a,c - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=1, dx=1, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # mixing Cartesian and unstructured data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=1, dx=1, unstructured=(not unstructured), seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # break causality - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): atmp = make_spatial_data(value=0, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=eps_inf, poles=((atmp, c),)) @@ -759,7 +749,7 @@ def test_custom_pole_residue(unstructured): # non-dispersive but gain a = 0 * c mat = CustomPoleResidue(eps_inf=eps_inf, poles=((a, c - 0.1),)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): 
mat_medium = mat.to_medium() mat = CustomPoleResidue(eps_inf=eps_inf, poles=((a, c - 0.1),), allow_gain=True) mat_medium = mat.to_medium() @@ -783,34 +773,34 @@ def test_custom_sellmeier(unstructured): c2 = make_spatial_data(value=0, unstructured=unstructured, seed=seed) # complex b - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): btmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c1), (btmp, c2))) # complex c - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ctmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c1), (b2, ctmp))) # negative c - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ctmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c1), (b2, ctmp))) # negative b btmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = CustomSellmeier(coeffs=((b1, c1), (btmp, c2))) mat = CustomSellmeier(coeffs=((b1, c1), (btmp, c2)), allow_gain=True) assert mat.pole_residue.allow_gain # inconsistent coord - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): btmp = make_spatial_data(value=0, dx=1, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c2), (btmp, c2))) # mixing Cartesian and unstructured data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): btmp = make_spatial_data(value=0, dx=1, unstructured=(not unstructured), seed=seed) mat = CustomSellmeier(coeffs=((b1, c2), (btmp, c2))) @@ -842,32 +832,32 @@ def test_custom_lorentz(unstructured): delta2 = make_spatial_data(value=0, unstructured=unstructured, seed=seed) # complex de - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): detmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (detmp, f2, delta2))) # mixed delta > f and delta < f over spatial points - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): deltatmp = make_spatial_data(value=1, unstructured=unstructured, seed=seed) mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, f2, deltatmp))) # inconsistent coords - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ftmp = make_spatial_data(value=1, dx=1, unstructured=unstructured, seed=seed) mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, ftmp, delta2))) # mixing Cartesian and unstructured data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ftmp = make_spatial_data(value=1, dx=1, unstructured=(not unstructured), seed=seed) mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, ftmp, delta2))) # break causality with negative delta - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): deltatmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, f2, deltatmp))) # gain medium with negative delta epsilon - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): detmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (detmp, f2, delta2))) mat = CustomLorentz( @@ -898,22 +888,22 @@ def test_custom_drude(unstructured): delta2 = make_spatial_data(value=0, unstructured=unstructured, seed=seed) # complex delta - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): deltatmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = 
CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (f2, deltatmp))) # negative delta - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): deltatmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (f2, deltatmp))) # inconsistent coords - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ftmp = make_spatial_data(value=1, dx=1, unstructured=unstructured, seed=seed) mat = CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (ftmp, delta2))) # mixing Cartesian and unstructured data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ftmp = make_spatial_data(value=1, dx=1, unstructured=(not unstructured), seed=seed) mat = CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (ftmp, delta2))) @@ -936,32 +926,32 @@ def test_custom_debye(unstructured): tau2 = make_spatial_data(value=0, unstructured=unstructured, seed=seed) # complex eps - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2))) # complex tau - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): tautmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (eps2, tautmp))) # negative tau - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): tautmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (eps2, tautmp))) # inconsistent coords - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=0, dx=1, unstructured=unstructured, seed=seed) mat = CustomDebye(eps_inf=eps_inf, 
coeffs=((eps1, tau1), (epstmp, tau2))) # mixing Cartesian and unstructured data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=0, dx=1, unstructured=(not unstructured), seed=seed) mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2))) # negative delta epsilon - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2))) mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2)), allow_gain=True) @@ -1013,7 +1003,7 @@ def test_custom_anisotropic_medium(unstructured): # so that xx-component is using "nearest" freq = 2e14 dist_coeff = 0.7 - coord_test = td.Coords(x=[X[0] * dist_coeff + X[1] * (1 - dist_coeff)], y=Y[0], z=Z[0]) + coord_test = td.Coords(x=[X[0] * dist_coeff + X[1] * (1 - dist_coeff)], y=[Y[0]], z=[Z[0]]) eps_nearest = mat.eps_sigma_to_eps_complex( permittivity.interp(x=X[0], y=Y[0], z=Z[0], method="nearest"), conductivity.interp(x=X[0], y=Y[0], z=Z[0], method="nearest"), @@ -1064,11 +1054,11 @@ def test_custom_anisotropic_medium(unstructured): field_components = {f"eps_{d}{d}": make_scalar_data() for d in "xyz"} eps_dataset = PermittivityDataset(**field_components) mat_tmp = CustomMedium(eps_dataset=eps_dataset) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = CustomAnisotropicMedium(xx=mat_tmp, yy=mat_yy, zz=mat_zz) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = CustomAnisotropicMedium(xx=mat_xx, yy=mat_tmp, zz=mat_zz) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = CustomAnisotropicMedium(xx=mat_xx, yy=mat_yy, zz=mat_tmp) @@ -1159,7 +1149,7 @@ def test_warn_planewave_intersection(): medium=mat, ) with AssertLogLevel("WARNING"): - 
sim.updated_copy(structures=[box]) + sim.updated_copy(structures=(box,)) def test_warn_diffraction_monitor_intersection(): @@ -1188,10 +1178,10 @@ def test_warn_diffraction_monitor_intersection(): with AssertLogLevel(None): sim = td.Simulation( size=(1, 1, 2), - structures=[box], + structures=(box,), grid_spec=td.GridSpec.auto(wavelength=1), - monitors=[monitor], - sources=[src], + monitors=(monitor,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1204,7 +1194,7 @@ def test_warn_diffraction_monitor_intersection(): medium=mat, ) with AssertLogLevel("WARNING"): - sim.updated_copy(structures=[box]) + sim.updated_copy(structures=(box,)) @pytest.mark.parametrize( @@ -1231,7 +1221,7 @@ def test_custom_medium_duplicate_coords(custom_class, data_key): spatial_data = td.SpatialDataArray(data, coords=coords) if custom_class == CustomMedium: - with pytest.raises(pydantic.ValidationError, match="duplicate coordinates"): + with pytest.raises(ValidationError, match="duplicate coordinates"): _ = custom_class(permittivity=spatial_data) else: field_components = { @@ -1239,5 +1229,5 @@ def test_custom_medium_duplicate_coords(custom_class, data_key): } field_dataset = td.FieldDataset(**field_components) - with pytest.raises(pydantic.ValidationError, match="duplicate coordinates"): + with pytest.raises(ValidationError, match="duplicate coordinates"): _ = custom_class(size=SIZE, source_time=ST, **{data_key: field_dataset}) diff --git a/tests/test_components/test_eme.py b/tests/test_components/test_eme.py index 0e9c617607..bffe3a5c67 100644 --- a/tests/test_components/test_eme.py +++ b/tests/test_components/test_eme.py @@ -1,5 +1,5 @@ import numpy as np -import pydantic.v1 as pd +import pydantic as pd import pytest import tidy3d as td from matplotlib import pyplot as plt @@ -303,7 +303,7 @@ def test_eme_simulation(): _ = sim.updated_copy(freqs=None) # no symmetry in propagation direction - with pytest.raises(SetupError): + with 
pytest.raises(pd.ValidationError): _ = sim.updated_copy(symmetry=(0, 0, 1)) # test warning for not providing wavelength in autogrid @@ -319,8 +319,8 @@ def test_eme_simulation(): ) # test port offsets - with pytest.raises(ValidationError): - _ = sim.updated_copy(port_offsets=[sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3]) + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(port_offsets=(sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3)) # test duplicate freqs with pytest.raises(pd.ValidationError): @@ -339,7 +339,7 @@ def test_eme_simulation(): med = td.FullyAnisotropicMedium(permittivity=perm, conductivity=cond) struct = sim.structures[0].updated_copy(medium=med) with pytest.raises(pd.ValidationError): - _ = sim.updated_copy(structures=[struct]) + _ = sim.updated_copy(structures=(struct,)) # warn for time modulated FREQ_MODULATE = 1e12 AMP_TIME = 1.1 @@ -356,7 +356,7 @@ def test_eme_simulation(): _ = td.EMESimulation( size=sim.size, monitors=sim.monitors, - structures=[struct], + structures=(struct,), grid_spec=grid_spec, axis=sim.axis, eme_grid_spec=sim.eme_grid_spec, @@ -364,7 +364,8 @@ def test_eme_simulation(): ) # warn for nonlinear nonlinear = td.Medium( - permittivity=2, nonlinear_spec=td.NonlinearSpec(models=[td.NonlinearSusceptibility(chi3=1)]) + permittivity=2, + nonlinear_spec=td.NonlinearSpec(models=(td.NonlinearSusceptibility(chi3=1),)), ) struct = sim.structures[0].updated_copy(medium=nonlinear) with AssertLogLevel("WARNING"): @@ -389,34 +390,34 @@ def test_eme_simulation(): # test monitor setup monitor = sim.monitors[0].updated_copy(freqs=[sim.freqs[0], sim.freqs[0]]) - with pytest.raises(SetupError): - _ = sim.updated_copy(monitors=[monitor]) + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(monitors=(monitor,)) monitor = sim.monitors[0].updated_copy(freqs=[5e10]) - with pytest.raises(SetupError): - _ = sim.updated_copy(monitors=[monitor]) + with pytest.raises(pd.ValidationError): + _ = 
sim.updated_copy(monitors=(monitor,)) monitor = sim.monitors[0].updated_copy(num_modes=1000) - with pytest.raises(SetupError): - _ = sim.updated_copy(monitors=[monitor]) + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(monitors=(monitor,)) monitor = sim.monitors[2].updated_copy(num_modes=6) - with pytest.raises(SetupError): - _ = sim.updated_copy(monitors=[monitor]) + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(monitors=(monitor,)) # test monitor at simulation bounds monitor = sim.monitors[-1].updated_copy(center=[0, 0, -sim.size[2] / 2]) with pytest.raises(pd.ValidationError): - _ = sim.updated_copy(monitors=[monitor]) + _ = sim.updated_copy(monitors=(monitor,)) # test max sim size and freqs sim_bad = sim.updated_copy(size=(1000, 1000, 1000)) with pytest.raises(SetupError): sim_bad.validate_pre_upload() - sim_bad = sim.updated_copy(size=(1000, 500, 3), monitors=[], store_port_modes=True) + sim_bad = sim.updated_copy(size=(1000, 500, 3), monitors=(), store_port_modes=True) with pytest.raises(SetupError): sim_bad.validate_pre_upload() - sim_bad = sim.updated_copy(size=(1000, 500, 3), monitors=[], store_port_modes=False) + sim_bad = sim.updated_copy(size=(1000, 500, 3), monitors=(), store_port_modes=False) with pytest.raises(SetupError): sim_bad.validate_pre_upload() - sim_bad = sim.updated_copy(size=(500, 500, 3), monitors=[]) + sim_bad = sim.updated_copy(size=(500, 500, 3), monitors=()) with AssertLogLevel("WARNING", "slow-down"): sim_bad.validate_pre_upload() @@ -434,13 +435,13 @@ def test_eme_simulation(): large_monitor = sim.monitors[2].updated_copy(size=(td.inf, td.inf, td.inf)) _ = sim.updated_copy( size=(10, 10, 10), - monitors=[large_monitor], + monitors=(large_monitor,), freqs=list(1e14 * np.linspace(1, 2, 1)), grid_spec=sim.grid_spec.updated_copy(wavelength=1), ) sim_bad = sim.updated_copy( size=(10, 10, 10), - monitors=[large_monitor], + monitors=(large_monitor,), freqs=list(1e14 * np.linspace(1, 2, 5)), 
grid_spec=sim.grid_spec.updated_copy(wavelength=1), ) @@ -448,7 +449,7 @@ def test_eme_simulation(): sim_bad.validate_pre_upload() sim_bad = sim.updated_copy( size=(10, 10, 10), - monitors=[large_monitor], + monitors=(large_monitor,), freqs=list(1e14 * np.linspace(1, 2, 20)), grid_spec=sim.grid_spec.updated_copy(wavelength=1), ) @@ -456,7 +457,7 @@ def test_eme_simulation(): sim_bad.validate_pre_upload() sim_bad = sim.updated_copy( size=(10, 10, 10), - monitors=[large_monitor, large_monitor.updated_copy(name="lmon2")], + monitors=(large_monitor, large_monitor.updated_copy(name="lmon2")), freqs=list(1e14 * np.linspace(1, 2, 5)), grid_spec=sim.grid_spec.updated_copy(wavelength=1), ) @@ -469,21 +470,21 @@ def test_eme_simulation(): center=(0, 0, -1.5), name="modes", ) - with pytest.raises(SetupError): - _ = sim.updated_copy(monitors=[mode_monitor], port_offsets=(0.5, 0.5)) + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(monitors=(mode_monitor,), port_offsets=(0.5, 0.5)) # test eme cell interval space mode_monitor = mode_monitor.updated_copy( size=(td.inf, td.inf, td.inf), eme_cell_interval_space=8 ) - sim2 = sim.updated_copy(monitors=[mode_monitor]) + sim2 = sim.updated_copy(monitors=(mode_monitor,)) assert sim2._monitor_num_eme_cells(monitor=mode_monitor) == 2 # test monitor num modes - sim_tmp = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(num_modes=1)]) + sim_tmp = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(num_modes=1),)) assert sim_tmp._monitor_num_modes_cell(monitor=sim_tmp.monitors[0], cell_index=0) == 1 # test monitor num freqs - sim_tmp = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(freqs=[sim.freqs[0]])]) + sim_tmp = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(freqs=[sim.freqs[0]]),)) assert sim_tmp._monitor_num_freqs(monitor=sim_tmp.monitors[0]) == 1 # test sweep @@ -500,9 +501,9 @@ def test_eme_simulation(): scale_factors=np.stack((np.linspace(1, 2, 7), np.linspace(1, 2, 7))) ) ) - with 
pytest.raises(SetupError): - _ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[])) - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=())) + with pytest.raises(pd.ValidationError): _ = sim.updated_copy( sweep_spec=td.EMELengthSweep( scale_factors=np.stack( @@ -514,19 +515,19 @@ def test_eme_simulation(): ) ) # second shape of length sweep must equal number of cells - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=np.array([[1, 2], [3, 4]]))) _ = sim.updated_copy(sweep_spec=td.EMEModeSweep(num_modes=list(np.arange(1, 5)))) # test sweep size limit - with pytest.raises(SetupError): - _ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[])) + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=())) sim_bad = sim.updated_copy( sweep_spec=td.EMELengthSweep(scale_factors=list(np.linspace(1, 2, 200))) ) with pytest.raises(SetupError): sim_bad.validate_pre_upload() # can't exceed max num modes - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(sweep_spec=td.EMEModeSweep(num_modes=list(np.arange(150, 200)))) # don't warn in these two cases @@ -564,38 +565,38 @@ def test_eme_simulation(): assert sim._sweep_modes assert sim._num_sweep == 2 assert sim._monitor_num_sweep(sim.monitors[0]) == 1 - sim = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(num_sweep=None)]) + sim = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(num_sweep=None),)) assert sim._monitor_num_sweep(sim.monitors[0]) == 2 - with pytest.raises(SetupError): - _ = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(num_sweep=4)]) - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): + _ = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(num_sweep=4),)) + with 
pytest.raises(pd.ValidationError): _ = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1e-10, 2])) - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy( eme_grid_spec=td.EMEExplicitGrid( - boundaries=[-sim.size[2] / 2 + 0.001], - mode_specs=[td.EMEModeSpec(), td.EMEModeSpec()], + boundaries=(-sim.size[2] / 2 + 0.001,), + mode_specs=(td.EMEModeSpec(), td.EMEModeSpec()), ) ) - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy( eme_grid_spec=td.EMEExplicitGrid( - boundaries=[sim.size[2] / 2 - 0.001], - mode_specs=[td.EMEModeSpec(), td.EMEModeSpec()], + boundaries=(sim.size[2] / 2 - 0.001,), + mode_specs=(td.EMEModeSpec(), td.EMEModeSpec()), ) ) - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy( - monitors=[ + monitors=( td.ModeSolverMonitor( - center=[0, 0, sim.size[2] / 2 - 0.001], - size=[td.inf, td.inf, 0], + center=(0, 0, sim.size[2] / 2 - 0.001), + size=(td.inf, td.inf, 0), name="modes", freqs=sim.freqs, mode_spec=td.ModeSpec(), - ) - ] + ), + ) ) @@ -1193,7 +1194,7 @@ def test_eme_sim_data(): # test field in basis with freq sweep field_monitor_data = _get_eme_field_data(num_sweep=10) data[2] = field_monitor_data - sim_data = sim_data.updated_copy(data=data) + sim_data = sim_data.updated_copy(data=tuple(data)) field_in_basis = sim_data.field_in_basis(field=sim_data["field"], port_index=0) assert len(field_in_basis.Ex.sweep_index) == 10 assert "mode_index" in field_in_basis.Ex.coords @@ -1237,7 +1238,7 @@ def test_eme_periodicity(): # directly give it num_reps # can't have field monitor - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(num_reps=2, path="eme_grid_spec/subgrids/1") # EMEPeriodicitySweep validation @@ -1245,24 +1246,24 @@ def test_eme_periodicity(): _ = td.EMEPeriodicitySweep(num_reps=[{"a": n} for n in range(150000, 150003)]) sweep_spec = 
td.EMEPeriodicitySweep(num_reps=[{"a": n} for n in range(1, 4)]) # still can't have field monitor - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(sweep_spec=sweep_spec) # remove the field monitor, now it passes desired_cell_index_pairs = set([(i, i + 1) for i in range(6)] + [(5, 1)]) with AssertLogLevel(None): sim = sim.updated_copy( - monitors=[m for m in sim.monitors if not isinstance(m, td.EMEFieldMonitor)] + monitors=tuple(m for m in sim.monitors if not isinstance(m, td.EMEFieldMonitor)) ) sim2 = sim.updated_copy(num_reps=2, path="eme_grid_spec/subgrids/1") assert set(sim2._cell_index_pairs) == desired_cell_index_pairs # sweep can't have coeff monitor - with pytest.raises(SetupError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(sweep_spec=sweep_spec) # remove coeff monitor too, now it passes with AssertLogLevel(None): sim = sim.updated_copy( - monitors=[m for m in sim.monitors if not isinstance(m, td.EMECoefficientMonitor)] + monitors=tuple(m for m in sim.monitors if not isinstance(m, td.EMECoefficientMonitor)) ) sim2 = sim.updated_copy(sweep_spec=sweep_spec) assert set(sim2._cell_index_pairs) == desired_cell_index_pairs @@ -1281,10 +1282,10 @@ def test_eme_grid_from_structures(): names=[None, "wg", None], num_reps=[1, 2, 1], ) - sim = sim.updated_copy(eme_grid_spec=eme_grid_spec, monitors=[]) + sim = sim.updated_copy(eme_grid_spec=eme_grid_spec, monitors=()) with pytest.raises(ValidationError): _ = td.EMECompositeGrid.from_structure_groups( - structure_groups=[], + structure_groups=(), axis=2, mode_specs=[], names=[None, "wg", None], @@ -1292,7 +1293,7 @@ ) with pytest.raises(ValidationError): _ = td.EMECompositeGrid.from_structure_groups( - structure_groups=[[], [td.Box(center=(0, 0, 0), size=(1, 1, 1))], []], + structure_groups=([], [td.Box(center=(0, 0, 0), size=(1, 1, 1))], []), axis=2, mode_specs=[td.EMEModeSpec(num_modes=1)] * 2, names=[None, "wg",
None], @@ -1357,6 +1358,6 @@ def test_eme_sim_2d(): axis=2, freqs=[freq0], eme_grid_spec=eme_grid_spec, - monitors=[monitor], + monitors=(monitor,), port_offsets=(0.5, 0), ) diff --git a/tests/test_components/test_field_projection.py b/tests/test_components/test_field_projection.py index 94d1e0d552..33921738ed 100644 --- a/tests/test_components/test_field_projection.py +++ b/tests/test_components/test_field_projection.py @@ -1,9 +1,9 @@ """Test near field to far field transformations.""" import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.components.field_projection import FieldProjector from tidy3d.exceptions import DataError @@ -282,7 +282,7 @@ def test_proj_clientside(): sim = td.Simulation( size=sim_size, grid_spec=td.GridSpec.auto(wavelength=td.C_0 / f0), - monitors=[monitor], + monitors=(monitor,), run_time=1e-12, ) @@ -617,7 +617,7 @@ def test_2d_sim_with_proj_monitors_near(): # Modify only proj_distance and far_field_approx proj_monitors_near = [ - monitor.__class__( + type(monitor)( proj_distance=R_FAR / 50, # Adjust projection distance far_field_approx=False, # Disable far-field approximation **{ @@ -630,7 +630,7 @@ def test_2d_sim_with_proj_monitors_near(): ] with pytest.raises( - pydantic.ValidationError, + ValidationError, match="Exact far-field projection for 2D simulations is not yet available", ): _ = td.Simulation( diff --git a/tests/test_components/test_geometry.py b/tests/test_components/test_geometry.py index 160885de2e..278ffbd3b8 100644 --- a/tests/test_components/test_geometry.py +++ b/tests/test_components/test_geometry.py @@ -7,7 +7,7 @@ import gdstk import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic +import pydantic as pd import pytest import shapely import tidy3d as td @@ -36,10 +36,10 @@ CYLINDER = td.Cylinder(axis=2, length=1, radius=1) GROUP = td.GeometryGroup( - geometries=[ + geometries=( td.Box(center=(-0.25, 0, 0), 
size=(0.5, 1, 1)), td.Box(center=(0.25, 0, 0), size=(0.5, 1, 1)), - ] + ) ) UNION = td.ClipOperation( operation="union", @@ -221,16 +221,16 @@ def test_intersections_plane_inf(): def test_center_not_inf_validate(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Box(center=(td.inf, 0, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Box(center=(-td.inf, 0, 0)) def test_radius_not_inf_validate(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Sphere(radius=td.inf) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Cylinder(radius=td.inf, center=(0, 0, 0), axis=1, length=1) @@ -247,7 +247,7 @@ def test_slanted_cylinder_infinite_length_validate(): sidewall_angle=0.1, reference_plane="middle", ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Cylinder( radius=1, center=(0, 0, 0), @@ -256,7 +256,7 @@ def test_slanted_cylinder_infinite_length_validate(): sidewall_angle=0.1, reference_plane="top", ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Cylinder( radius=1, center=(0, 0, 0), @@ -303,7 +303,7 @@ def test_polyslab_inf_bounds(lower_bound, upper_bound): def test_polyslab_bounds(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.PolySlab(vertices=((0, 0), (1, 0), (1, 1)), slab_bounds=(0.5, -0.5), axis=2) @@ -339,15 +339,15 @@ def test_polyslab_inf_to_finite_bounds(axis): def test_validate_polyslab_vertices_valid(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): POLYSLAB.copy(update=dict(vertices=(1, 2, 3))) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): crossing_verts = ((0, 0), (1, 1), (0, 1), (1, 0)) 
POLYSLAB.copy(update=dict(vertices=crossing_verts)) def test_sidewall_failed_validation(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): POLYSLAB.copy(update=dict(sidewall_angle=1000)) @@ -389,7 +389,7 @@ def test_gdspy_cell(): def make_geo_group(): """Make a generic Geometry Group.""" - boxes = [td.Box(size=(1, 1, 1), center=(i, 0, 0)) for i in range(-5, 5)] + boxes = tuple(td.Box(size=(1, 1, 1), center=(i, 0, 0)) for i in range(-5, 5)) return td.GeometryGroup(geometries=boxes) @@ -417,8 +417,8 @@ def test_geo_group_methods(): def test_geo_group_empty(): """dont allow empty geometry list.""" - with pytest.raises(pydantic.ValidationError): - _ = td.GeometryGroup(geometries=[]) + with pytest.raises(pd.ValidationError): + _ = td.GeometryGroup(geometries=()) def test_geo_group_volume(): @@ -593,22 +593,22 @@ def test_flattening(): flat = list( flatten_groups( td.GeometryGroup( - geometries=[ + geometries=( td.Box(size=(1, 1, 1)), td.Box(size=(0, 1, 0)), td.ClipOperation( operation="union", geometry_a=td.Box(size=(0, 0, 1)), geometry_b=td.GeometryGroup( - geometries=[ + geometries=( td.Box(size=(2, 2, 2)), td.GeometryGroup( - geometries=[td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3))] + geometries=(td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3))) ), - ] + ) ), ), - ] + ) ) ) ) @@ -618,22 +618,22 @@ def test_flattening(): flat = list( flatten_groups( td.GeometryGroup( - geometries=[ + geometries=( td.Box(size=(1, 1, 1)), td.Box(size=(0, 1, 0)), td.ClipOperation( operation="intersection", geometry_a=td.Box(size=(0, 0, 1)), geometry_b=td.GeometryGroup( - geometries=[ + geometries=( td.Box(size=(2, 2, 2)), td.GeometryGroup( - geometries=[td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3))] + geometries=(td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3))) ), - ] + ) ), ), - ] + ) ) ) ) @@ -676,15 +676,15 @@ def test_geometry_traversal(): assert len(geometries) == 1 geo_tree = td.GeometryGroup( - geometries=[ + geometries=( td.Box(size=(1, 
0, 0)), td.ClipOperation( operation="intersection", geometry_a=td.GeometryGroup( - geometries=[ + geometries=( td.Box(size=(5, 0, 0)), td.Box(size=(6, 0, 0)), - ] + ) ), geometry_b=td.ClipOperation( operation="difference", @@ -693,13 +693,13 @@ def test_geometry_traversal(): ), ), td.GeometryGroup( - geometries=[ + geometries=( td.Box(size=(3, 0, 0)), td.Box(size=(4, 0, 0)), - ] + ) ), td.Box(size=(2, 0, 0)), - ] + ) ) geometries = list(traverse_geometries(geo_tree)) assert len(geometries) == 13 @@ -717,34 +717,34 @@ def test_geometry(): # _ = PolySlab(vertices=vertices_np, slab_bounds=(-1, 1), axis=1) # make sure wrong axis arguments error - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Cylinder(radius=1, center=(0, 0, 0), axis=-1, length=1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.PolySlab(radius=1, center=(0, 0, 0), axis=-1, slab_bounds=(-0.5, 0.5)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Cylinder(radius=1, center=(0, 0, 0), axis=3, length=1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.PolySlab(radius=1, center=(0, 0, 0), axis=3, slab_bounds=(-0.5, 0.5)) # make sure negative values error - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Sphere(radius=-1, center=(0, 0, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Cylinder(radius=-1, center=(0, 0, 0), axis=3, length=1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Cylinder(radius=1, center=(0, 0, 0), axis=3, length=-1) def test_geometry_sizes(): # negative in size kwargs errors for size in (-1, 1, 1), (1, -1, 1), (1, 1, -1): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Box(size=size, center=(0, 0, 0)) 
- with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Simulation(size=size, run_time=1e-12, grid_spec=td.GridSpec(wavelength=1.0)) # negative grid sizes error? - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Simulation(size=(1, 1, 1), grid_spec=td.GridSpec.uniform(dl=-1.0), run_time=1e-12) @@ -885,7 +885,7 @@ def test_polyslab_intersection_inf_bounds(): assert poly.intersections_plane(x=0)[0] == shapely.box(-1, 0.0, 1, LARGE_NUMBER) # 2) [-inf, 0] - poly = poly.updated_copy(slab_bounds=[-td.inf, 0]) + poly = poly.updated_copy(slab_bounds=(-td.inf, 0)) assert len(poly.intersections_plane(x=0)) == 1 assert poly.intersections_plane(x=0)[0] == shapely.box(-1, -LARGE_NUMBER, 1, 0) @@ -1049,7 +1049,7 @@ def test_custom_surface_geometry(tmp_path): def test_geo_group_sim(): geo_grp = td.TriangleMesh.from_stl("tests/data/two_boxes_separate.stl") geos_orig = list(geo_grp.geometries) - geo_grp_full = geo_grp.updated_copy(geometries=geos_orig + [td.Box(size=(1, 1, 1))]) + geo_grp_full = geo_grp.updated_copy(geometries=tuple(geos_orig + [td.Box(size=(1, 1, 1))])) sim = td.Simulation( size=(10, 10, 10), @@ -1066,7 +1066,7 @@ def test_geo_group_sim(): def test_finite_geometry_transformation(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Box(size=(td.inf, 0, 1)).scaled(1, 1, 1) diff --git a/tests/test_components/test_grid_spec.py b/tests/test_components/test_grid_spec.py index 73096b289c..5313bda74a 100644 --- a/tests/test_components/test_grid_spec.py +++ b/tests/test_components/test_grid_spec.py @@ -1,9 +1,9 @@ """Tests GridSpec.""" import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.exceptions import SetupError @@ -314,7 +314,7 @@ def test_zerosize_dimensions(): assert np.allclose(sim.grid.boundaries.y, [-dl / 2, dl / 2]) - with pytest.raises(SetupError): 
+ with pytest.raises(ValidationError): sim = td.Simulation( size=(5, 0, 10), boundary_spec=td.BoundarySpec.pec( @@ -330,7 +330,7 @@ def test_zerosize_dimensions(): run_time=1e-12, ) - with pytest.raises(SetupError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(5, 3, 10), boundary_spec=td.BoundarySpec.pec( @@ -529,7 +529,7 @@ def test_domain_mismatch(): def test_uniform_grid_dl_validation(dl, expect_exception): """Test the validator that checks 'dl' is between 1e-7 and 3e8 µm.""" if expect_exception: - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), grid_spec=td.GridSpec.uniform(dl=dl), diff --git a/tests/test_components/test_heat.py b/tests/test_components/test_heat.py index 24eb3dfdcc..c99685e68f 100644 --- a/tests/test_components/test_heat.py +++ b/tests/test_components/test_heat.py @@ -1,8 +1,8 @@ import numpy as np -import pydantic.v1 as pd import pytest import tidy3d as td from matplotlib import pyplot as plt +from pydantic import ValidationError from tidy3d import ( ConvectionBC, DistanceUnstructuredGrid, @@ -50,10 +50,10 @@ def make_heat_mediums(): def test_heat_medium(): _, solid_medium = make_heat_mediums() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = solid_medium.heat_spec.updated_copy(capacity=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = solid_medium.heat_spec.updated_copy(conductivity=-1) @@ -92,13 +92,13 @@ def make_heat_bcs(): def test_heat_bcs(): bc_temp, bc_flux, bc_conv = make_heat_bcs() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = TemperatureBC(temperature=-10) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = ConvectionBC(ambient_temperature=-400, transfer_coeff=0.2) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = ConvectionBC(ambient_temperature=400, 
transfer_coeff=-0.2) @@ -120,10 +120,10 @@ def make_heat_mnts(): def test_heat_mnt(): temp_mnt, _, _, _, _, _ = make_heat_mnts() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = temp_mnt.updated_copy(name=None) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = temp_mnt.updated_copy(size=(-1, 2, 3)) @@ -233,20 +233,20 @@ def make_distance_grid_spec(): def test_grid_spec(): grid_spec = make_uniform_grid_spec() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(dl=0) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(min_edges_per_circumference=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(min_edges_per_side=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(relative_min_dl=-1e-4) grid_spec = make_distance_grid_spec() _ = grid_spec.updated_copy(relative_min_dl=0) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(dl_interface=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(distance_interface=2, distance_bulk=1) @@ -256,8 +256,8 @@ def make_heat_source(): def test_heat_source(): source = make_heat_source() - with pytest.raises(pd.ValidationError): - _ = source.updated_copy(structures=[]) + with pytest.raises(ValidationError): + _ = source.updated_copy(structures=()) def make_heat_sim(): @@ -319,23 +319,25 @@ def test_heat_sim(): condition=bc_temp, placement=StructureSimulationBoundary(structure="no_mesh") ), ]: - with pytest.raises(pd.ValidationError): - _ = heat_sim.updated_copy(boundary_spec=[pl]) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(boundary_spec=(pl,)) - with pytest.raises(pd.ValidationError): - _ = 
heat_sim.updated_copy(sources=[UniformHeatSource(structures=["noname"])], rate=-10) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(sources=(UniformHeatSource(structures=["noname"]),), rate=-10) # run 2D case - _ = heat_sim.updated_copy(center=(0.7, 0, 0), size=(0, 2, 2), monitors=heat_sim.monitors[:5]) + _ = heat_sim.updated_copy( + center=(0.7, 0, 0), size=(0, 2, 2), monitors=tuple(heat_sim.monitors[:5]) + ) # test unsupported 1D heat domains - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = heat_sim.updated_copy(center=(1, 1, 1), size=(1, 0, 0)) temp_mnt = heat_sim.monitors[0] - with pytest.raises(pd.ValidationError): - heat_sim.updated_copy(monitors=[temp_mnt, temp_mnt]) + with pytest.raises(ValidationError): + heat_sim.updated_copy(monitors=(temp_mnt, temp_mnt)) _ = heat_sim.plot(x=0) plt.close() @@ -348,7 +350,7 @@ def test_heat_sim(): plt.close() # no negative symmetry - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = heat_sim.updated_copy(symmetry=(-1, 0, 1)) # no SolidSpec in the entire simulation @@ -357,16 +359,16 @@ def test_heat_sim(): ) solid_med = heat_sim.structures[1].medium - _ = heat_sim.updated_copy(structures=[], medium=solid_med, sources=[], boundary_spec=[bc_spec]) - with pytest.raises(pd.ValidationError): - _ = heat_sim.updated_copy(structures=[], sources=[], boundary_spec=[bc_spec], monitors=[]) + _ = heat_sim.updated_copy(structures=(), medium=solid_med, sources=(), boundary_spec=(bc_spec,)) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(structures=(), sources=(), boundary_spec=(bc_spec,), monitors=()) _ = heat_sim.updated_copy( - structures=[heat_sim.structures[0]], medium=solid_med, boundary_spec=[bc_spec], sources=[] + structures=(heat_sim.structures[0],), medium=solid_med, boundary_spec=(bc_spec,), sources=() ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = heat_sim.updated_copy( -
structures=[heat_sim.structures[0]], boundary_spec=[bc_spec], sources=[], monitors=[] + structures=(heat_sim.structures[0],), boundary_spec=(bc_spec,), sources=(), monitors=() ) # 1D and 2D structures @@ -378,18 +380,18 @@ def test_heat_sim(): geometry=td.Box(size=(1, 0, 1)), medium=heat_sim.medium, ) - with pytest.raises(pd.ValidationError): - _ = heat_sim.updated_copy(structures=list(heat_sim.structures) + [struct_1d]) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(structures=(*heat_sim.structures, struct_1d)) - with pytest.raises(pd.ValidationError): - _ = heat_sim.updated_copy(structures=list(heat_sim.structures) + [struct_2d]) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(structures=(*heat_sim.structures, struct_2d)) # no data expected inside a monitor for mnt_size in [(0.2, 0.2, 0.2), (0, 1, 1), (0, 2, 0), (0, 0, 0)]: temp_mnt = td.TemperatureMonitor(center=(0, 0, 0), size=mnt_size, name="test") - with pytest.raises(pd.ValidationError): - _ = heat_sim.updated_copy(monitors=[temp_mnt]) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(monitors=(temp_mnt,)) @pytest.mark.parametrize("shift_amount, log_level", ((1, None), (2, "WARNING"))) @@ -490,15 +492,15 @@ def test_sim_data(): with pytest.raises(KeyError): _ = heat_sim_data.plot_field("test3", x=0) - with pytest.raises(pd.ValidationError): - _ = heat_sim_data.updated_copy(data=[heat_sim_data.data[0]] * 2) + with pytest.raises(ValidationError): + _ = heat_sim_data.updated_copy(data=(heat_sim_data.data[0],) * 2) temp_mnt = TemperatureMonitor(size=(1, 2, 3), name="test") temp_mnt = temp_mnt.updated_copy(name="test2") - sim = heat_sim_data.simulation.updated_copy(monitors=[temp_mnt]) + sim = heat_sim_data.simulation.updated_copy(monitors=(temp_mnt,)) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = heat_sim_data.updated_copy(simulation=sim) diff --git a/tests/test_components/test_heat_charge.py 
b/tests/test_components/test_heat_charge.py index 4921cd3e29..54a1ca60e6 100644 --- a/tests/test_components/test_heat_charge.py +++ b/tests/test_components/test_heat_charge.py @@ -1,10 +1,10 @@ """Test suite for heat-charge simulation objects and data using pytest fixtures.""" import numpy as np -import pydantic.v1 as pd import pytest import tidy3d as td from matplotlib import pyplot as plt +from pydantic import ValidationError from tidy3d.components.tcad.types import ( AugerRecombination, CaugheyThomasMobility, @@ -653,22 +653,22 @@ def test_heat_charge_medium_validation(mediums): solid_medium = mediums["solid_medium"] # Test invalid capacity - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): solid_medium.heat_spec.updated_copy(capacity=-1) # Test invalid conductivity - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): solid_medium.heat_spec.updated_copy(conductivity=-1) # Test invalid charge conductivity - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): solid_medium.charge.updated_copy(conductivity=-1) def test_constant_mobility(): constant_mobility = td.ConstantMobilityModel(mu=1500) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = constant_mobility.updated_copy(mu=-1) @@ -692,23 +692,23 @@ def test_heat_charge_bcs_validation(boundary_conditions): bc_temp, bc_flux, bc_conv, bc_volt, bc_current = boundary_conditions # Invalid TemperatureBC - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.TemperatureBC(temperature=-10) # Invalid ConvectionBC: negative ambient temperature - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.ConvectionBC(ambient_temperature=-400, transfer_coeff=0.2) # Invalid ConvectionBC: negative transfer coefficient - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.ConvectionBC(ambient_temperature=400, transfer_coeff=-0.2) # 
Invalid VoltageBC: infinite voltage - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.VoltageBC(source=td.DCVoltageSource(voltage=[td.inf])) # Invalid CurrentBC: infinite current density - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.CurrentBC(source=td.DCCurrentSource(current=td.inf)) @@ -717,11 +717,11 @@ def test_heat_charge_monitors_validation(monitors): temp_mnt = monitors[0] # Invalid monitor name - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): temp_mnt.updated_copy(name=None) # Invalid monitor size (negative dimension) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): temp_mnt.updated_copy(size=(-1, 2, 3)) @@ -737,9 +737,9 @@ def test_monitor_crosses_medium(mediums, structures, heat_simulation, conduction center=(0, 0, 0), size=(td.inf, td.inf, td.inf), name="voltage" ) # A voltage monitor in a heat simulation should throw error if no ChargeConductorMedium is present - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): heat_simulation.updated_copy( - medium=solid_no_elect, structures=[solid_struct_no_elect], monitors=[volt_monitor] + medium=solid_no_elect, structures=(solid_struct_no_elect,), monitors=(volt_monitor,) ) # Temperature monitor @@ -747,9 +747,9 @@ def test_monitor_crosses_medium(mediums, structures, heat_simulation, conduction center=(0, 0, 0), size=(td.inf, td.inf, td.inf), name="temperature" ) # A temperature monitor should throw error in a conduction simulation if no SolidSpec is present - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): conduction_simulation.updated_copy( - medium=solid_no_heat, structures=[solid_struct_no_heat], monitors=[temp_monitor] + medium=solid_no_heat, structures=(solid_struct_no_heat,), monitors=(temp_monitor,) ) @@ -765,18 +765,18 @@ def test_grid_spec_validation(grid_specs): """Tests whether unstructured grids can be created 
and different validators for them.""" # Test UniformUnstructuredGrid uniform_grid = grid_specs["uniform"] - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): uniform_grid.updated_copy(dl=0) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): uniform_grid.updated_copy(min_edges_per_circumference=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): uniform_grid.updated_copy(min_edges_per_side=-1) # Test DistanceUnstructuredGrid distance_grid = grid_specs["distance"] - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): distance_grid.updated_copy(dl_interface=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): distance_grid.updated_copy(distance_interface=2, distance_bulk=1) @@ -861,16 +861,16 @@ def test_sim_data_plotting(simulation_data): heat_sim_data.plot_field("test3", x=0) # Test updating simulation data with duplicate data - with pytest.raises(pd.ValidationError): - heat_sim_data.updated_copy(data=[heat_sim_data.data[0]] * 2) + with pytest.raises(ValidationError): + heat_sim_data.updated_copy(data=(heat_sim_data.data[0],) * 2) # Test updating simulation data with invalid simulation temp_mnt = td.TemperatureMonitor(size=(1, 2, 3), name="test") temp_mnt = temp_mnt.updated_copy(name="test2") - sim = heat_sim_data.simulation.updated_copy(monitors=[temp_mnt]) + sim = heat_sim_data.simulation.updated_copy(monitors=(temp_mnt,)) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): heat_sim_data.updated_copy(simulation=sim) @@ -1023,12 +1023,12 @@ def test_charge_simulation( ) # At least one ChargeSimulationMonitor should be added - with pytest.raises(pd.ValidationError): - sim.updated_copy(monitors=[]) + with pytest.raises(ValidationError): + sim.updated_copy(monitors=()) # At least 2 VoltageBCs should be defined - with pytest.raises(pd.ValidationError): - sim.updated_copy(boundary_spec=[bc_n]) 
+ with pytest.raises(ValidationError): + sim.updated_copy(boundary_spec=(bc_n,)) # Define ChargeSimulation with no Semiconductor materials medium = td.MultiPhysicsMedium( @@ -1037,15 +1037,15 @@ def test_charge_simulation( ) new_structures = [struct.updated_copy(medium=medium) for struct in sim.structures] - with pytest.raises(pd.ValidationError): - sim.updated_copy(structures=new_structures) + with pytest.raises(ValidationError): + sim.updated_copy(structures=tuple(new_structures)) # test a voltage array is provided when a capacitance monitor is present - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): new_bc_n = bc_n.updated_copy( condition=td.VoltageBC(source=td.DCVoltageSource(voltage=1)) ) - _ = sim.updated_copy(boundary_spec=[bc_p, new_bc_n]) + _ = sim.updated_copy(boundary_spec=(bc_p, new_bc_n)) def test_doping_distributions(self): """Test doping distributions.""" @@ -1216,13 +1216,13 @@ def test_2D_doping_box(): _ = td.ConstantDoping(size=(1, 1, np.inf), concentration=1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.ConstantDoping(size=(0, 1, 1), concentration=1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.ConstantDoping(size=(1, 0, 1), concentration=1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.ConstantDoping(size=(1, 1, 0), concentration=1) _ = td.ConstantDoping.from_bounds(rmin=(-td.inf, -1, -1), rmax=(td.inf, 1, 1), concentration=1) @@ -1244,7 +1244,7 @@ def test_simulation_initialization_invalid_parameters( ): """Test simulation initialization with invalid parameters.""" # Invalid simulation size - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.HeatChargeSimulation( medium=mediums["fluid_medium"], structures=[structures["fluid_structure"]], @@ -1257,7 +1257,7 @@ def test_simulation_initialization_invalid_parameters( ) # Invalid monitor type - with 
pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.HeatChargeSimulation( medium=mediums["fluid_medium"], structures=[structures["fluid_structure"]], @@ -1324,9 +1324,7 @@ def test_dynamic_simulation_updates(heat_simulation): # Add a new monitor new_monitor = td.TemperatureMonitor(size=(1, 1, 1), name="new_temp_mnt") - updated_sim = heat_simulation.updated_copy( - monitors=tuple(list(heat_simulation.monitors) + [new_monitor]) - ) + updated_sim = heat_simulation.updated_copy(monitors=(*heat_simulation.monitors, new_monitor)) assert len(updated_sim.monitors) == len(heat_simulation.monitors) + 1 assert updated_sim.monitors[-1].name == "new_temp_mnt" diff --git a/tests/test_components/test_layerrefinement.py b/tests/test_components/test_layerrefinement.py index b0c6487e12..b46c29f7e8 100644 --- a/tests/test_components/test_layerrefinement.py +++ b/tests/test_components/test_layerrefinement.py @@ -1,9 +1,9 @@ """Tests 2d corner finder.""" import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.components.grid.corner_finder import CornerFinderSpec from tidy3d.components.grid.grid_spec import GridRefinement, LayerRefinementSpec @@ -91,7 +91,7 @@ def test_layerrefinement(): """Test LayerRefinementSpec is working as expected.""" # size along axis must be inf - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec(axis=0, size=(td.inf, 0, 0)) # classmethod @@ -116,19 +116,19 @@ def test_layerrefinement(): assert layer._is_inplane_bounded assert layer.axis == 1 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): structures = [ td.Structure(geometry=td.Box(size=(td.inf, td.inf, td.inf)), medium=td.Medium()) ] layer = LayerRefinementSpec.from_structures(structures) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = 
LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(0, td.inf)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(td.inf, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(-td.inf, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(1, -1)) @@ -500,9 +500,11 @@ def test_gap_meshing(): ) reentry_gap = td.Structure( - geometry=td.Box(size=(0.3, 0.4, 0.3)) - .rotated(axis=1, angle=np.pi / 4) - .translated(x=0, y=0, z=0.5), + geometry=td.PolySlab( + slab_bounds=(-0.2, 0.2), + axis=1, + vertices=[(-0.3, 0.52), (-0.05, 0.3), (0.2, 0.52)], + ), medium=td.Medium(), ) diff --git a/tests/test_components/test_lumped_element.py b/tests/test_components/test_lumped_element.py index f91ece5b0c..7d7f210b35 100644 --- a/tests/test_components/test_lumped_element.py +++ b/tests/test_components/test_lumped_element.py @@ -1,9 +1,9 @@ """Tests lumped elements.""" import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.components.lumped_element import NetworkConversions @@ -33,7 +33,7 @@ def test_lumped_resistor(): assert monitor.name == resistor.monitor_name # error if voltage axis is not in plane with the resistor - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.LumpedResistor( resistance=50.0, center=[0, 0, 0], @@ -43,7 +43,7 @@ def test_lumped_resistor(): ) # error if not planar - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.LumpedResistor( resistance=50.0, center=[0, 0, 0], @@ -51,7 +51,7 @@ def test_lumped_resistor(): voltage_axis=2, name="R", ) - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): _ = td.LumpedResistor( resistance=50.0, center=[0, 0, 0], @@ -86,7 +86,7 @@ def test_lumped_resistor_snapping(): # snapped version resistor_snapped = resistor.updated_copy(enable_snapping_points=True) - sim_snapped = sim.updated_copy(lumped_elements=[resistor_snapped]) + sim_snapped = sim.updated_copy(lumped_elements=(resistor_snapped,)) # whether lumped element is snapped along normal axis assert not any(np.isclose(sim.grid.boundaries.z, 0.1)) assert any(np.isclose(sim_snapped.grid.boundaries.z, 0.1)) @@ -124,7 +124,7 @@ def test_coaxial_lumped_resistor_snapping(): # snapped version resistor_snapped = resistor.updated_copy(enable_snapping_points=True) - sim_snapped = sim.updated_copy(lumped_elements=[resistor_snapped]) + sim_snapped = sim.updated_copy(lumped_elements=(resistor_snapped,)) # whether lumped element is snapped along normal axis assert not any(np.isclose(sim.grid.boundaries.z, 0.1)) assert any(np.isclose(sim_snapped.grid.boundaries.z, 0.1)) @@ -153,7 +153,7 @@ def test_coaxial_lumped_resistor(): _ = resistor.to_snapping_points() # error if inner diameter is larger - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CoaxialLumpedResistor( resistance=50.0, center=[0, 0, 0], @@ -163,7 +163,7 @@ def test_coaxial_lumped_resistor(): name="R", ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CoaxialLumpedResistor( resistance=50.0, center=[0, 0, np.inf], @@ -177,11 +177,11 @@ def test_coaxial_lumped_resistor(): def test_validators_RLC_network(): """Test that ``RLCNetwork`` is validated correctly.""" # Must have a defined value for R,L,or C - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.RLCNetwork() # Must have a valid topology - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.RLCNetwork( capacitance=0.2e-12, network_topology="left", @@ -190,13 +190,13 
@@ def test_validators_RLC_network(): def test_validators_admittance_network(): """Test that ``AdmittanceNetwork`` is validated correctly.""" - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.AdmittanceNetwork() a = (0, -1, 2) b = (1, 1, 2) # non negative a and b - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.AdmittanceNetwork( a=a, b=b, @@ -205,7 +205,7 @@ def test_validators_admittance_network(): a = (0, complex(1, 2), 2) b = (1, 1, 2) # real a and b - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.AdmittanceNetwork( a=a, b=b, @@ -265,17 +265,12 @@ def test_RLC_and_lumped_network_agreement(Rval, Lval, Cval, topology): if configuration_includes_parallel_inductor: return - network = td.AdmittanceNetwork( - a=a, - b=b, - ) + network = td.AdmittanceNetwork(a=a, b=b) (a, b) = network._as_admittance_function med_network = network._to_medium(sf) # Check conversion to geometry and to structure - linear_element = linear_element.updated_copy( - network=network, - ) + linear_element = linear_element.updated_copy(network=network) _ = linear_element.to_geometry() assert np.allclose(med_RLC.eps_model(freqs), med_network.eps_model(freqs), rtol=rtol) diff --git a/tests/test_components/test_medium.py b/tests/test_components/test_medium.py index 59a2dd2d34..cb40b48e10 100644 --- a/tests/test_components/test_medium.py +++ b/tests/test_components/test_medium.py @@ -1,13 +1,11 @@ """Tests mediums.""" -from typing import Dict - import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic +import pydantic as pd import pytest import tidy3d as td -from tidy3d.exceptions import SetupError, ValidationError +from tidy3d.exceptions import ValidationError from ..utils import AssertLogLevel @@ -54,9 +52,9 @@ def test_from_n_less_than_1(): def test_medium(): # mediums error with unacceptable values - with 
pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Medium(permittivity=0.0) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Medium(conductivity=-1.0) @@ -143,30 +141,30 @@ def test_PEC(): def test_lossy_metal(): # frequency_range shouldn't be None - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(conductivity=1) - # frequency_range shouldn't contain non-postive values - with pytest.raises(pydantic.ValidationError): + # frequency_range shouldn't contain non-positive values + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(conductivity=1, frequency_range=(0, 10)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(conductivity=1, frequency_range=(-10, 10)) # frequency_range should be finite - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(conductivity=1, frequency_range=(10, np.inf)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(conductivity=1, frequency_range=(-np.inf, 10)) # allow_gain cannot be true - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(allow_gain=True, conductivity=1, frequency_range=(10, 20)) # conductivity cannot be negative - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(conductivity=-1, frequency_range=(10, 20)) # conductivity cannot be 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.LossyMetalMedium(conductivity=0, frequency_range=(10, 20)) # default fitting @@ -212,13 +210,13 @@ def test_medium_dispersion(): m_DR = td.Drude(eps_inf=1.0, coeffs=[(1, 3), (2, 4)]) m_DB = td.Debye(eps_inf=1.0, coeffs=[(1, 3), (2, 
4)]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Sellmeier(coeffs=[(2, 0), (2, 4)]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Drude(eps_inf=1.0, coeffs=[(1, 0), (2, 4)]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Debye(eps_inf=1.0, coeffs=[(1, 0), (2, 4)]) freqs = np.linspace(0.01, 1, 1001) @@ -287,7 +285,7 @@ def test_sellmeier_from_dispersion(): assert np.allclose(-dn_df * td.C_0 / wvl**2, dn_dwvl) -def eps_compare(medium: td.Medium, expected: Dict, tol: float = 1e-5): +def eps_compare(medium: td.Medium, expected: dict, tol: float = 1e-5): for freq, val in expected.items(): assert np.abs(medium.eps_model(freq) - val) < tol @@ -418,27 +416,27 @@ def test_n_cfl(): def test_gain_medium(): """Test passive and gain medium validations.""" # non-dispersive - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Medium(conductivity=-0.1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Medium(conductivity=-1.0, allow_gain=False) _ = td.Medium(conductivity=-1.0, allow_gain=True) # pole residue, causality - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.PoleResidue(eps_inf=0.16, poles=[(1 + 1j, 2 + 2j)]) # Sellmeier - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Sellmeier(coeffs=((-1, 1),)) mS = td.Sellmeier(coeffs=((-1, 1),), allow_gain=True) # Lorentz # causality, negative gamma - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Lorentz(eps_inf=0.04, coeffs=[(1, 2, -3)]) # gain, negative Delta epsilon - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Lorentz(eps_inf=0.04, coeffs=[(-1, 2, 3)]) mL = td.Lorentz(eps_inf=0.04, coeffs=[(-1, 2, 
3)], allow_gain=True) assert mL.pole_residue.allow_gain @@ -447,7 +445,7 @@ def test_gain_medium(): _ = td.Lorentz(eps_inf=0.04, coeffs=[(1, -2, 3)]) # Drude, only causality constraint - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Drude(eps_inf=0.04, coeffs=[(1, -2)]) # anisotropic medium, warn allow_gain is ignored @@ -492,7 +490,7 @@ def test_medium2d(): _ = medium.plot(freqs=[2e14, 3e14], ax=AX) plt.close() - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Medium2D(ss=td.PECMedium(), tt=td.Medium()) @@ -535,24 +533,24 @@ def test_fully_anisotropic_media(): _ = td.FullyAnisotropicMedium(permittivity=perm, conductivity=cond) # check that tensors are provided - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FullyAnisotropicMedium(permittivity=2) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FullyAnisotropicMedium(permittivity=[3, 4, 2]) # check that permittivity >= 1 and conductivity >= 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FullyAnisotropicMedium(permittivity=[[3, 0, 0], [0, 0.5, 0], [0, 0, 1]]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FullyAnisotropicMedium(conductivity=[[-3, 0, 0], [0, 0.5, 0], [0, 0, 1]]) td.FullyAnisotropicMedium(conductivity=[[-3, 0, 0], [0, 0.5, 0], [0, 0, 1]], allow_gain=True) # check that permittivity needs to be symmetric - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FullyAnisotropicMedium(permittivity=[[3, 0.1, 0], [0.2, 2, 0], [0, 0, 1]]) # check that differently oriented permittivity and conductivity are not accepted - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FullyAnisotropicMedium(permittivity=perm, conductivity=cond2) # check creation from 
diagonal medium @@ -636,7 +634,7 @@ def test_nonlinear_medium(): med = td.Medium(nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5)) # don't use deprecated numiters - with pytest.raises(ValidationError): + with pytest.raises(pd.ValidationError): med = td.Medium( nonlinear_spec=td.NonlinearSpec(models=[td.NonlinearSusceptibility(chi3=1, numiters=2)]) ) @@ -645,15 +643,15 @@ def test_nonlinear_medium(): med = td.PoleResidue(poles=[(-1, 1)], nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5)) # unsupported material types - with pytest.raises(ValidationError): + with pytest.raises(pd.ValidationError): med = td.AnisotropicMedium( xx=med, yy=med, zz=med, nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5) ) # numiters too large - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): med = td.Medium(nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5, numiters=200)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): med = td.Medium( nonlinear_spec=td.NonlinearSpec( num_iters=200, models=[td.NonlinearSusceptibility(chi3=1.5)] @@ -661,7 +659,7 @@ def test_nonlinear_medium(): ) # duplicate models - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): med = td.Medium( nonlinear_spec=td.NonlinearSpec( models=[ @@ -672,12 +670,12 @@ def test_nonlinear_medium(): ) # active materials - with pytest.raises(ValidationError): + with pytest.raises(pd.ValidationError): med = td.Medium( nonlinear_spec=td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=-1, n0=1, freq0=1)]) ) - with pytest.raises(ValidationError): + with pytest.raises(pd.ValidationError): med = td.Medium( nonlinear_spec=td.NonlinearSpec( models=[td.KerrNonlinearity(n2=-1j, n0=1, use_complex_fields=True)] @@ -709,7 +707,7 @@ def test_nonlinear_medium(): # subsection with nonlinear materials needs to hardcode source info sim2 = sim.updated_copy(center=(-4, -4, -4), path="sources/0") sim2 = sim2.updated_copy( - 
models=[td.TwoPhotonAbsorption(beta=1)], path="structures/0/medium/nonlinear_spec" + models=(td.TwoPhotonAbsorption(beta=1),), path="structures/0/medium/nonlinear_spec" ) sim2 = sim2.subsection(region=td.Box(center=(0, 0, 0), size=(1, 1, 0))) assert sim2.structures[0].medium.nonlinear_spec.models[0].n0 == n0 @@ -718,35 +716,35 @@ def test_nonlinear_medium(): # can't detect n0 with different source freqs source_time2 = source_time.updated_copy(freq0=2 * freq0) source2 = source.updated_copy(source_time=source_time2) - with pytest.raises(SetupError): - sim.updated_copy(sources=[source, source2]) - with pytest.raises(SetupError): - sim.updated_copy(sources=[]) + with pytest.raises(pd.ValidationError): + sim.updated_copy(sources=(source, source2)) + with pytest.raises(pd.ValidationError): + sim.updated_copy(sources=()) # but if we provided it, it's ok nonlinear_spec = td.NonlinearSpec(models=[td.KerrNonlinearity(n2=1, n0=1)]) structure = structure.updated_copy(medium=medium.updated_copy(nonlinear_spec=nonlinear_spec)) - sim = sim.updated_copy(structures=[structure]) + sim = sim.updated_copy(structures=(structure,)) assert 1 == nonlinear_spec.models[0]._get_n0(n0=1, medium=medium, freqs=[1, 2]) nonlinear_spec = td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=1, n0=1)]) structure = structure.updated_copy(medium=medium.updated_copy(nonlinear_spec=nonlinear_spec)) - sim = sim.updated_copy(structures=[structure]) - with pytest.raises(SetupError): - sim = sim.updated_copy(structures=[structure], sources=[source, source2]) - with pytest.raises(SetupError): - sim = sim.updated_copy(structures=[structure], sources=[]) + sim = sim.updated_copy(structures=(structure,)) + with pytest.raises(pd.ValidationError): + sim = sim.updated_copy(structures=(structure,), sources=(source, source2)) + with pytest.raises(pd.ValidationError): + sim = sim.updated_copy(structures=(structure,), sources=()) nonlinear_spec = td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=1, n0=1, freq0=1)]) 
structure = structure.updated_copy(medium=medium.updated_copy(nonlinear_spec=nonlinear_spec)) - sim = sim.updated_copy(structures=[structure]) + sim = sim.updated_copy(structures=(structure,)) assert 1 == nonlinear_spec.models[0]._get_freq0(freq0=1, freqs=[1, 2]) # active materials with automatic detection of n0 nonlinear_spec_active = td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=-1)]) medium_active = medium.updated_copy(nonlinear_spec=nonlinear_spec_active) - with pytest.raises(ValidationError): + with pytest.raises(pd.ValidationError): structure = structure.updated_copy(medium=medium_active) - sim.updated_copy(structures=[structure]) + sim.updated_copy(structures=(structure,)) # nonlinear or time-modulation on medium2d # time-modulated @@ -760,17 +758,17 @@ def test_nonlinear_medium(): MODULATION_SPEC = td.ModulationSpec() modulation_spec = MODULATION_SPEC.updated_copy(permittivity=ST) modulated = td.Medium(permittivity=2, modulation_spec=modulation_spec) - with pytest.raises(ValidationError): + with pytest.raises(pd.ValidationError): td.Medium2D(ss=medium, tt=medium) - with pytest.raises(ValidationError): + with pytest.raises(pd.ValidationError): td.Medium2D(ss=modulated, tt=modulated) # some parameters must be real now, unless we use old implementation _ = td.TwoPhotonAbsorption(beta=1j, use_complex_fields=True) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.TwoPhotonAbsorption(beta=1j) _ = td.KerrNonlinearity(n2=1j, use_complex_fields=True) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.KerrNonlinearity(n2=1j) # consistent complex fields @@ -780,7 +778,7 @@ def test_nonlinear_medium(): td.KerrNonlinearity(n2=1, use_complex_fields=True), ] ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.NonlinearSpec( models=[ td.TwoPhotonAbsorption(beta=1, use_complex_fields=True), @@ -808,7 +806,7 @@ def 
test_nonlinear_medium(): interval=1, size=(0, 0, 0), name="aux_field_time", fields=aux_fields ) sim = sim.updated_copy(medium=med, path="structures/0") - sim = sim.updated_copy(monitors=[monitor]) + sim = sim.updated_copy(monitors=(monitor,)) with AssertLogLevel("WARNING", contains_str="stores field"): med = td.Medium( @@ -849,7 +847,7 @@ def create_mediums(n_dataset): with AssertLogLevel(None): create_mediums(n_dataset=n_dataset) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): # repeat some entries so data cannot be interpolated X2 = [X[0]] + list(X) n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data)) diff --git a/tests/test_components/test_meshgenerate.py b/tests/test_components/test_meshgenerate.py index 6b10e11a9a..7f3b8ed137 100644 --- a/tests/test_components/test_meshgenerate.py +++ b/tests/test_components/test_meshgenerate.py @@ -502,7 +502,7 @@ def test_mesh_direct_override(): assert np.isclose(sizes[len(sizes) // 2], 0.05) # default override has no effect when coarser than enclosing structure - override_coarse = override_fine.copy(update={"dl": [0.2] * 3}) + override_coarse = override_fine.copy(update={"dl": (0.2,) * 3}) sim = td.Simulation( size=(3, 3, 3), grid_spec=td.GridSpec.auto( @@ -675,13 +675,13 @@ def test_small_structure_size(): # Warning not raised if structure is higher index box2 = box.updated_copy(medium=td.Medium(permittivity=300)) with AssertLogLevel(None): - sim.updated_copy(structures=[box2]) + sim.updated_copy(structures=(box2,)) # Warning not raised if structure is covered by an override structure override = td.MeshOverrideStructure(geometry=box.geometry, dl=(box_size, td.inf, td.inf)) with AssertLogLevel(None): sim3 = sim.updated_copy( - grid_spec=sim.grid_spec.updated_copy(override_structures=[override]) + grid_spec=sim.grid_spec.updated_copy(override_structures=(override,)) ) # Also check that the structure boundaries are in the grid ind_mid_cell = 
int(sim3.grid.num_cells[0] // 2) @@ -693,7 +693,7 @@ def test_small_structure_size(): geometry=td.Box(center=(box_size, 0, 0), size=(box_size, td.inf, td.inf)), medium=medium ) with AssertLogLevel("WARNING"): - sim.updated_copy(structures=[box3, box]) + sim.updated_copy(structures=(box3, box)) def test_shapely_strtree_warnings(): diff --git a/tests/test_components/test_mode.py b/tests/test_components/test_mode.py index df08f470a3..af17998243 100644 --- a/tests/test_components/test_mode.py +++ b/tests/test_components/test_mode.py @@ -1,7 +1,7 @@ """Tests mode objects.""" import numpy as np -import pydantic.v1 as pydantic +import pydantic as pd import pytest import tidy3d as td from matplotlib import pyplot as plt @@ -32,29 +32,29 @@ def test_modes(): for opt in options: _ = td.ModeSpec(num_modes=3, track_freq=opt) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.ModeSpec(num_modes=3, track_freq="middle") - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.ModeSpec(num_modes=3, track_freq=4) def test_bend_axis_not_given(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.ModeSpec(bend_radius=1.0, bend_axis=None) def test_zero_radius(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.ModeSpec(bend_radius=0.0, bend_axis=1) def test_glancing_incidence(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.ModeSpec(angle_theta=np.pi / 2) def test_group_index_step_validation(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.ModeSpec(group_index_step=1.0) ms = td.ModeSpec(group_index_step=True) @@ -71,7 +71,7 @@ def test_angle_rotation_with_phi(): td.ModeSpec(angle_phi=np.pi, angle_rotation=True) # Case where angle_phi is not a multiple of np.pi and angle_rotation is True - with 
pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.ModeSpec(angle_phi=np.pi / 3, angle_rotation=True) @@ -110,14 +110,14 @@ def test_mode_sim(): assert sim.plane == sim.geometry # must be planar or have plane - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(size=(3, 3, 3), plane=None) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(size=(3, 3, 3), plane=td.Box(size=(3, 3, 3))) _ = sim.updated_copy(size=(3, 3, 3), plane=td.Box(size=(3, 3, 0))) # plane must intersect sim geometry - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = sim.updated_copy(size=(3, 3, 3), plane=td.Box(center=(5, 5, 5), size=(1, 1, 0))) # test warning for not providing wavelength in autogrid @@ -161,7 +161,7 @@ def test_mode_sim(): ) assert td.ModeSimulation.from_simulation(sim) == sim - assert td.ModeSimulation.from_mode_solver(sim._mode_solver) == sim.updated_copy(monitors=[]) + assert td.ModeSimulation.from_mode_solver(sim._mode_solver) == sim.updated_copy(monitors=()) _ = td.ModeSimulation.from_simulation( simulation=fdtd_sim, plane=td.Box(size=(4, 4, 0)), diff --git a/tests/test_components/test_monitor.py b/tests/test_components/test_monitor.py index 825949aef9..2aacb51b51 100644 --- a/tests/test_components/test_monitor.py +++ b/tests/test_components/test_monitor.py @@ -1,7 +1,7 @@ """Tests monitors.""" import numpy as np -import pydantic.v1 as pydantic +import pydantic as pd import pytest import tidy3d as td from tidy3d.exceptions import SetupError, ValidationError @@ -10,7 +10,7 @@ def test_stop_start(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FluxTimeMonitor(size=(1, 1, 0), name="f", start=2, stop=1) @@ -57,13 +57,13 @@ def test_downsampled(): def test_excluded_surfaces_flat(): - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(pd.ValidationError): _ = td.FluxMonitor(size=(1, 1, 0), name="f", freqs=[1e12], exclude_surfaces=("x-",)) def test_fld_mnt_freqs_none(): """Test that validation errors if freqs=[None].""" - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.FieldMonitor(center=(0, 0, 0), size=(0, 0, 0), freqs=[None], name="test") @@ -147,7 +147,7 @@ def test_fieldproj_surfaces(): def test_fieldproj_surfaces_in_simulaiton(): # test error if all projection surfaces are outside the simulation domain M = td.FieldProjectionAngleMonitor(size=(3, 3, 3), theta=[1], phi=[0], name="f", freqs=[2e12]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Simulation( size=(2, 2, 2), run_time=1e-12, @@ -159,13 +159,13 @@ def test_fieldproj_surfaces_in_simulaiton(): _ = td.Simulation( size=(2, 2, 2), run_time=1e-12, - monitors=[M], + monitors=(M,), grid_spec=td.GridSpec.uniform(0.1), ) # error when the surfaces that are in are excluded - M = M.updated_copy(exclude_surfaces=["x-", "x+"]) - with pytest.raises(pydantic.ValidationError): + M = M.updated_copy(exclude_surfaces=("x-", "x+")) + with pytest.raises(pd.ValidationError): _ = td.Simulation( size=(2, 2, 2), run_time=1e-12, @@ -176,11 +176,11 @@ def test_fieldproj_surfaces_in_simulaiton(): def test_fieldproj_kspace_range(): # make sure ux, uy are in [-1, 1] for k-space projection monitors - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.FieldProjectionKSpaceMonitor( size=(2, 0, 2), ux=[0.1, 2], uy=[0], name="f", freqs=[2e12], proj_axis=1 ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.FieldProjectionKSpaceMonitor( size=(2, 0, 2), ux=[0.1, 0.2], uy=[1.1], name="f", freqs=[2e12], proj_axis=1 ) @@ -209,12 +209,12 @@ def test_fieldproj_window(): points = np.linspace(0, 10, 100) _ = M.window_function(points, window_size, window_minus, window_plus, 2) # do 
not allow a window size larger than 1 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.FieldProjectionAngleMonitor( size=(2, 0, 2), theta=[1, 2], phi=[0], name="f", freqs=[2e12], window_size=(0.2, 1.1) ) # do not allow non-zero windows for volume monitors - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.FieldProjectionAngleMonitor( size=(2, 1, 2), theta=[1, 2], phi=[0], name="f", freqs=[2e12], window_size=(0.2, 0) ) @@ -239,7 +239,7 @@ def test_storage_sizes(proj_mnt): def test_monitor_freqs_empty(): # errors when no frequencies supplied - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.FieldMonitor( size=(td.inf, td.inf, td.inf), freqs=[], @@ -323,7 +323,7 @@ def test_diffraction_validators(): y=td.Boundary.periodic(), z=td.Boundary.pml(), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.Simulation( size=(2, 2, 2), run_time=1e-12, @@ -334,7 +334,7 @@ def test_diffraction_validators(): ) # ensure error if monitor isn't infinite in two directions - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): _ = td.DiffractionMonitor(size=[td.inf, 4, 0], freqs=[1e12], name="de") @@ -391,11 +391,11 @@ def test_monitor(): def test_monitor_plane(): # make sure flux, mode and diffraction monitors fail with non planar geometries for size in ((0, 0, 0), (1, 0, 0), (1, 1, 1)): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.ModeMonitor(size=size, freqs=FREQS, modes=[]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.ModeSolverMonitor(size=size, freqs=FREQS, modes=[]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(pd.ValidationError): td.DiffractionMonitor(size=size, freqs=FREQS, name="de") diff --git 
a/tests/test_components/test_parameter_perturbation.py b/tests/test_components/test_parameter_perturbation.py index 91f43e93dc..74fcd83ee6 100644 --- a/tests/test_components/test_parameter_perturbation.py +++ b/tests/test_components/test_parameter_perturbation.py @@ -2,9 +2,9 @@ import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from ..utils import AssertLogLevel, cartesian_to_unstructured @@ -38,14 +38,14 @@ def test_heat_perturbation(): # test complex type detection assert not perturb.is_complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): perturb = td.LinearHeatPerturbation( coeff=0.01, temperature_ref=-300, temperature_range=(200, 400), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): perturb = td.LinearHeatPerturbation( coeff=0.01, temperature_ref=300, @@ -134,7 +134,7 @@ def test_heat_perturbation(): assert test_value_out == perturb_data.data[2] # test not allowed interpolation method - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): perturb = td.CustomHeatPerturbation( perturbation_values=perturb_data, interp_method="quadratic", @@ -159,7 +159,7 @@ def test_charge_perturbation(): # test complex type detection assert not perturb.is_complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): perturb = td.LinearChargePerturbation( electron_coeff=1e-21, electron_ref=0, @@ -169,7 +169,7 @@ def test_charge_perturbation(): hole_range=(0, 0.5e20), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): perturb = td.LinearChargePerturbation( electron_coeff=1e-21, electron_ref=0, @@ -341,7 +341,7 @@ def test_sample(perturb): assert test_value_out == perturb_data[-1, -1].item() # test not allowed interpolation method - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): perturb = td.CustomChargePerturbation( perturbation_values=perturb_data, interp_method="quadratic", @@ -668,9 +668,7 @@ def test_delta_model(): coeffs_4 = np.array([7.4e-22, 1.245, 5.43e-20, 1.153, 7.25e-21, 0.991, 9.99e-18, 0.839]) averaged_vals = (coeffs_3_5 + coeffs_4) / 2 - interpolated_results = [ - value.item() for _, value in delta_model._coeffs_at_ref_freq.data_vars.items() - ] + interpolated_results = [v.item() for v in delta_model._coeffs_at_ref_freq.data.flat] error = np.abs(np.mean(averaged_vals - np.array(interpolated_results))) assert error < 1e-16 diff --git a/tests/test_components/test_perturbation_medium.py b/tests/test_components/test_perturbation_medium.py index eff3063b06..441002cac8 100644 --- a/tests/test_components/test_perturbation_medium.py +++ b/tests/test_components/test_perturbation_medium.py @@ -1,9 +1,9 @@ """Tests mediums.""" import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from ..utils import AssertLogLevel, cartesian_to_unstructured @@ -103,7 +103,7 @@ def test_perturbation_medium(unstructured): assert cmed.allow_gain == pmed.allow_gain # permittivity < 1 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = pmed.perturbed_copy(1.1 * temperature) # conductivity validators @@ -133,7 +133,7 @@ def test_perturbation_medium(unstructured): for pmed in [pmed_direct, pmed_perm, pmed_index]: cmed = pmed.perturbed_copy(0.9 * temperature) # positive conductivity assert not cmed.subpixel - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = pmed.perturbed_copy(1.1 * temperature) # negative conductivity # negative conductivity but allow gain @@ -141,11 +141,11 @@ def test_perturbation_medium(unstructured): _ = pmed.perturbed_copy(1.1 * temperature) # complex perturbation - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): pmed = 
td.PerturbationMedium(permittivity=3, permittivity_perturbation=pp_complex) # overdefinition - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.PerturbationMedium( permittivity=1.21, permittivity_perturbation=pp_real, @@ -258,18 +258,18 @@ def test_perturbation_medium(unstructured): assert cmed.allow_gain == pmed.allow_gain # eps_inf < 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = pmed.perturbed_copy(1.1 * temperature) # mismatch between base parameter and perturbations - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): pmed = td.PerturbationPoleResidue( poles=[(1j, 3), (2j, 4)], poles_perturbation=[(None, pp_real)], ) # overdefinition - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.PerturbationPoleResidue( eps_inf=1.21, poles=[(1j, 3), (2j, 4)], diff --git a/tests/test_components/test_scene.py b/tests/test_components/test_scene.py index 6f74bd6df3..c5b5a692e8 100644 --- a/tests/test_components/test_scene.py +++ b/tests/test_components/test_scene.py @@ -2,9 +2,9 @@ import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pd import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.components.scene import MAX_GEOMETRY_COUNT, MAX_NUM_MEDIUMS from ..utils import SIM_FULL, cartesian_to_unstructured @@ -44,7 +44,7 @@ def test_scene_init(): def test_validate_components_none(): - assert SCENE._validate_num_mediums(val=None) is None + assert type(SCENE)._validate_num_mediums(val=None) is None def test_plot_eps(): @@ -113,7 +113,7 @@ def test_structure_alpha(): new_structs = [ td.Structure(geometry=s.geometry, medium=SCENE_FULL.medium) for s in SCENE_FULL.structures ] - S2 = SCENE_FULL.copy(update=dict(structures=new_structs)) + S2 = SCENE_FULL.copy(update=dict(structures=tuple(new_structs))) _ = S2.plot_structures_eps(x=0, alpha=0.5) plt.close() @@ -155,7 +155,7 
@@ def test_num_mediums(): structures=structures, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): structures.append( td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=td.Medium(permittivity=i + 2)) ) @@ -191,7 +191,7 @@ def _test_names_default(): def test_names_unique(): - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Scene( structures=[ td.Structure( @@ -305,5 +305,5 @@ def test_max_geometry_validation(): medium=td.Medium(permittivity=2.0), ), ] - with pytest.raises(pd.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): + with pytest.raises(ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): _ = td.Scene(structures=not_fine) diff --git a/tests/test_components/test_sidewall.py b/tests/test_components/test_sidewall.py index f6f735eded..2c4a85a7c1 100644 --- a/tests/test_components/test_sidewall.py +++ b/tests/test_components/test_sidewall.py @@ -1,9 +1,9 @@ """test slanted polyslab can be correctly setup and visualized.""" import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from shapely import Point, Polygon from tidy3d.constants import fp_eps @@ -129,17 +129,17 @@ def test_valid_polygon(): # area = 0 vertices = ((0, 0), (1, 0), (2, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) # only two points vertices = ((0, 0), (1, 0), (1, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) # intersecting edges vertices = ((0, 0), (1, 0), (1, 1), (0, 1), (0.5, -1)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) @@ -158,13 +158,13 @@ def test_crossing_square_poly(): dilation = -1.1 angle = 0 for ref_plane in ["bottom", "middle", "top"]: - with 
pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane=ref_plane) # angle too large, self-intersecting dilation = 0 angle = np.pi / 3 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane="top") # middle plane @@ -173,13 +173,13 @@ def test_crossing_square_poly(): # angle too large for middle reference plane angle = np.arctan(2.001) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane="middle") # combines both dilation = -0.1 angle = np.pi / 4 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) @@ -193,26 +193,26 @@ def test_crossing_concave_poly(): vertices = ((-0.5, 1), (-0.5, -1), (1, -1), (0, -0.1), (0, 0.1), (1, 1)) dilation = 0.5 angle = 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) # polygon splitting dilation = -0.3 angle = 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) # polygon fully eroded dilation = -0.5 angle = 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) # # or, effectively dilation = 0 angle = -np.pi / 4 for bounds in [(0, 0.3), (0, 0.5)]: - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = setup_polyslab(vertices, dilation, angle, bounds) _ = setup_polyslab(vertices, dilation, -angle, bounds, reference_plane="top") @@ -221,7 +221,7 @@ def test_crossing_concave_poly(): bounds = (0, 0.44) _ 
= setup_polyslab(vertices, dilation, angle, bounds, reference_plane="middle") _ = setup_polyslab(vertices, dilation, -angle, bounds, reference_plane="middle") - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): # vertices degenerate bounds = (0, 0.45) _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane="middle") diff --git a/tests/test_components/test_simulation.py b/tests/test_components/test_simulation.py index c827e648fd..d2887dcbfc 100644 --- a/tests/test_components/test_simulation.py +++ b/tests/test_components/test_simulation.py @@ -5,10 +5,10 @@ import gdstk import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td from matplotlib.testing.compare import compare_images +from pydantic import ValidationError from tidy3d.components import simulation from tidy3d.components.scene import MAX_GEOMETRY_COUNT, MAX_NUM_MEDIUMS from tidy3d.components.simulation import MAX_NUM_SOURCES @@ -282,7 +282,7 @@ def test_sim_size(): s._validate_size() # check too many time steps - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): s = td.Simulation( size=(1, 1, 1), run_time=1e-7, @@ -395,7 +395,7 @@ def test_validate_monitor_simulation_frequency_range(): def test_validate_bloch_with_symmetry(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, @@ -422,7 +422,7 @@ def test_validate_normalize_index(): ) # negative normalize index - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, @@ -431,7 +431,7 @@ def test_validate_normalize_index(): ) # normalize index out of bounds - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, @@ -445,7 +445,7 @@ def test_validate_normalize_index(): ) # normalize by 
zero-amplitude source - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, @@ -513,7 +513,7 @@ def test_validate_plane_wave_boundaries(): ) # angled incidence plane wave with PMLs / absorbers should error - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, @@ -560,7 +560,7 @@ def test_validate_zero_dim_boundaries(): pol_angle=0.0, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 0), run_time=1e-12, @@ -598,7 +598,7 @@ def test_validate_symmetry_boundaries(): z=td.Boundary.pml(), ), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), symmetry=(0, 1, 0), @@ -688,7 +688,7 @@ def test_max_geometry_validation(): medium=td.Medium(permittivity=2.0), ), ] - with pytest.raises(pydantic.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): + with pytest.raises(ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): _ = td.Simulation(size=(1, 1, 1), run_time=1, grid_spec=gs, structures=not_fine) @@ -1198,7 +1198,7 @@ def test_sim_plane_wave_error(): ) # with non-transparent box, raise - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg, @@ -1208,7 +1208,7 @@ def test_sim_plane_wave_error(): ) # raise with anisotropic medium - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg_diag, @@ -1216,7 +1216,7 @@ def test_sim_plane_wave_error(): boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg_full, @@ -1282,7 +1282,7 @@ def test_sim_monitor_homogeneous(): ) # with 
non-transparent box, raise - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg, @@ -1524,7 +1524,7 @@ def test_diffraction_medium(): pol_angle=-1.0, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2, 2, 2), structures=[box_cond], @@ -1534,7 +1534,7 @@ def test_diffraction_medium(): boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2, 2, 2), structures=[box_disp], @@ -1687,7 +1687,7 @@ def test_num_mediums(monkeypatch): boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): structures.append( td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=td.Medium(permittivity=i + 2)) ) @@ -1708,7 +1708,7 @@ def test_num_sources(): _ = td.Simulation(size=(5, 5, 5), run_time=1e-12, sources=[src] * MAX_NUM_SOURCES) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation(size=(5, 5, 5), run_time=1e-12, sources=[src] * (MAX_NUM_SOURCES + 1)) @@ -1771,7 +1771,7 @@ def _test_names_default(): def test_names_unique(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, @@ -1790,7 +1790,7 @@ def test_names_unique(): boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, @@ -1813,7 +1813,7 @@ def test_names_unique(): boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), 
run_time=1e-12, @@ -1830,7 +1830,7 @@ def test_mode_object_syms(): g = td.GaussianPulse(freq0=1e12, fwidth=0.1e12) # wrong mode source - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=(1.0, -1.0, 0.5), size=(2.0, 2.0, 2.0), @@ -1842,7 +1842,7 @@ def test_mode_object_syms(): ) # wrong mode monitor - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=(1.0, -1.0, 0.5), size=(2.0, 2.0, 2.0), @@ -1896,7 +1896,7 @@ def test_tfsf_symmetry(): injection_axis=2, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=td.C_0 / 1.0), @@ -1996,7 +1996,7 @@ def test_tfsf_boundaries(): ) # cannot cross any boundary in the direction of injection - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 0.5), grid_spec=td.GridSpec.auto(wavelength=1.0), @@ -2005,7 +2005,7 @@ def test_tfsf_boundaries(): ) # cannot cross any non-periodic boundary in the transverse direction - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=(0.5, 0, 0), # also check the case when the boundary is crossed only on one side size=(0.5, 0.5, 2.0), @@ -2516,7 +2516,7 @@ def test_sim_volumetric_structures(tmp_path): assert np.isclose(sim.volumetric_structures[1].medium.xx.permittivity, 2, rtol=RTOL) # test simulation.medium can't be Medium2D - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(10, 10, 10), structures=[], @@ -2532,16 +2532,16 @@ def test_sim_volumetric_structures(tmp_path): ) # test 2d medium is added to 2d geometry - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=td.Box(center=(0, 0, 0), size=(1, 1, 1)), 
medium=box.medium) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=td.Cylinder(radius=1, length=1), medium=box.medium) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure( geometry=td.PolySlab(vertices=[(0, 0), (1, 0), (1, 1)], slab_bounds=(-1, 1)), medium=box.medium, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=td.Sphere(radius=1), medium=box.medium) # test warning for 2d geometry in simulation without Medium2D @@ -3141,7 +3141,7 @@ def test_advanced_material_intersection(): struct1 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(0, 0, 0.5)), medium=pair[0]) struct2 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(0, 0, -0.5)), medium=pair[1]) # this pair cannot intersect - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sim = sim.updated_copy(structures=[struct1, struct2]) for pair in incompatible_pairs: @@ -3168,7 +3168,7 @@ def test_num_lumped_elements(): lumped_elements=[resistor] * MAX_NUM_MEDIUMS, run_time=1e-12, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(5, 5, 5), grid_spec=grid_spec, @@ -3190,7 +3190,7 @@ def test_validate_lumped_elements(): lumped_elements=[resistor], ) # error for 1D/2D simulation with lumped elements - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 0, 3), run_time=1e-12, @@ -3198,7 +3198,7 @@ def test_validate_lumped_elements(): lumped_elements=[resistor], ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 0, 0), run_time=1e-12, @@ -3331,7 +3331,7 @@ def test_validate_sources_monitors_in_bounds(): ) # check that a source at y- simulation domain edge errors - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): sim = td.Simulation( size=(2, 2, 2), run_time=1e-12, @@ -3339,7 +3339,7 @@ def test_validate_sources_monitors_in_bounds(): sources=[mode_source], ) # check that a monitor at y+ simulation domain edge errors - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(2, 2, 2), run_time=1e-12, @@ -3468,7 +3468,7 @@ def test_fixed_angle_sim(): assert sim._is_fixed_angle - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy( boundary_spec=td.BoundarySpec( x=td.Boundary.pml(), @@ -3477,25 +3477,25 @@ def test_fixed_angle_sim(): ) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(KeyError): _ = sim.updated_copy(med=td.Medium(conductivity=0.001)) anisotropic_med = td.FullyAnisotropicMedium(permittivity=[[2, 0, 0], [0, 1, 0], [0, 0, 3]]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy(structures=[sphere.updated_copy(medium=anisotropic_med)]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy(sources=[source, source]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy( structures=[sphere.updated_copy(medium=td.Medium(conductivity=-0.1, allow_gain=True))] ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy(monitors=[td.FieldTimeMonitor(size=[td.inf, td.inf, 0], name="time")]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy(monitors=[td.FluxTimeMonitor(size=[td.inf, td.inf, 0], name="time")]) nonlinear_med = td.Medium( @@ -3507,7 +3507,7 @@ def test_fixed_angle_sim(): num_iters=20, ), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = 
sim.updated_copy(structures=[sphere.updated_copy(medium=nonlinear_med)]) time_modulated_med = td.Medium( @@ -3518,7 +3518,7 @@ def test_fixed_angle_sim(): ) ), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy(structures=[sphere.updated_copy(medium=time_modulated_med)]) diff --git a/tests/test_components/test_types.py b/tests/test_components/test_types.py index f992bbec8d..7b166b0ff5 100644 --- a/tests/test_components/test_types.py +++ b/tests/test_components/test_types.py @@ -71,7 +71,7 @@ class MyClass(Tidy3dBaseModel): e: constrained_array(ndim=3, shape=(1, 2, 3)) # must have certain shape f: ArrayLike = None - fields = MyClass.__fields__ + fields = MyClass.model_fields def correct_field_display(field_name, display_name): """Make sure the field has the expected name.""" diff --git a/tests/test_data/test_data_arrays.py b/tests/test_data/test_data_arrays.py index f61123fb0d..81b4f56769 100644 --- a/tests/test_data/test_data_arrays.py +++ b/tests/test_data/test_data_arrays.py @@ -1,7 +1,5 @@ """Tests tidy3d/components/data/data_array.py""" -from typing import List, Tuple - import numpy as np import pytest import tidy3d as td @@ -124,7 +122,7 @@ def get_xyz( monitor: td.components.monitor.MonitorType, grid_key: str, symmetry: bool -) -> Tuple[List[float], List[float], List[float]]: +) -> tuple[list[float], list[float], list[float]]: sim = SIM_SYM if symmetry else SIM grid = sim.discretize_monitor(monitor) if monitor.colocate: diff --git a/tests/test_data/test_datasets.py b/tests/test_data/test_datasets.py index e6797ac9b7..2ef4e58c4a 100644 --- a/tests/test_data/test_datasets.py +++ b/tests/test_data/test_datasets.py @@ -1,9 +1,9 @@ """Tests tidy3d/components/data/dataset.py""" import numpy as np -import pydantic.v1 as pd import pytest from matplotlib import pyplot as plt +from pydantic import ValidationError from ..utils import AssertLogLevel, cartesian_to_unstructured @@ -55,7 +55,7 @@ def 
test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): assert tri_grid.name == ds_name # wrong points dimensionality - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): tri_grid_points_bad = td.PointDataArray( np.random.random((4, 3)), coords=dict(index=np.arange(4), axis=np.arange(3)), @@ -110,7 +110,7 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): [[0, 1, 2, 3]], coords=dict(cell_index=np.arange(1), vertex_index=np.arange(4)), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( normal_axis=2, normal_pos=-3, @@ -123,7 +123,7 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): [[0, 1, 5], [1, 2, 3]], coords=dict(cell_index=np.arange(2), vertex_index=np.arange(3)), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( normal_axis=2, normal_pos=-3, @@ -137,7 +137,7 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): np.random.rand(3, *[len(coord) for coord in extra_dims.values()]), coords=dict(index=np.arange(3), **extra_dims), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( normal_axis=0, normal_pos=0, @@ -282,7 +282,6 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): # writing/reading tri_grid.to_file(tmp_path / "tri_grid_test.hdf5") - tri_grid_loaded = dataset_type.from_file(tmp_path / "tri_grid_test.hdf5") assert tri_grid == tri_grid_loaded @@ -375,7 +374,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): np.random.random((8, 2)), coords=dict(index=np.arange(8), axis=np.arange(2)), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points_bad, cells=tet_grid_cells, @@ -421,7 +420,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, 
dataset_type_ind, no_vtk=False): [[0, 1, 3], [0, 2, 3], [0, 2, 6], [0, 4, 6], [0, 4, 5], [0, 1, 5]], coords=dict(cell_index=np.arange(6), vertex_index=np.arange(3)), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points, cells=tet_grid_cells_bad, @@ -432,7 +431,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): [[0, 1, 3, 17], [0, 2, 3, 7], [0, 2, 6, 7], [0, 4, 6, 7], [0, 4, 5, 7], [0, 1, 5, 7]], coords=dict(cell_index=np.arange(6), vertex_index=np.arange(4)), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points, cells=tet_grid_cells_bad, @@ -444,7 +443,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): np.random.rand(5, *[len(coord) for coord in extra_dims.values()]), coords=dict(index=np.arange(5), **extra_dims), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points, cells=tet_grid_cells_bad, diff --git a/tests/test_data/test_monitor_data.py b/tests/test_data/test_monitor_data.py index 5ea9d6a6cc..c890b9a058 100644 --- a/tests/test_data/test_monitor_data.py +++ b/tests/test_data/test_monitor_data.py @@ -2,10 +2,10 @@ import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td import xarray as xr +from pydantic import ValidationError from tidy3d.components.data.data_array import ( FreqDataArray, FreqModeDataArray, @@ -344,10 +344,10 @@ def test_mode_solver_data(): _ = data.updated_copy(eps_spec=["tensorial_real"] * num_freqs) _ = data.updated_copy(eps_spec=["tensorial_complex"] * num_freqs) # wrong keyword - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = data.updated_copy(eps_spec=["tensorial"] * num_freqs) # wrong number - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): _ = data.updated_copy(eps_spec=["diagonal"] * (num_freqs + 1)) # check monitor direction changes upon time reversal data_reversed = data.time_reversed_copy @@ -625,7 +625,7 @@ def test_field_data_symmetry_present(): _ = td.FieldTimeData(monitor=monitor, **fields) # fails if symmetry specified but missing symmetry center - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.FieldTimeData( monitor=monitor, symmetry=(1, -1, 0), @@ -634,7 +634,7 @@ def test_field_data_symmetry_present(): ) # fails if symmetry specified but missing etended grid - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.FieldTimeData( monitor=monitor, symmetry=(1, -1, 1), symmetry_center=(0, 0, 0), **fields ) @@ -865,7 +865,7 @@ def test_no_nans(): eps_dataset_nan = td.PermittivityDataset( **{key: eps_nan for key in ["eps_xx", "eps_yy", "eps_zz"]} ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.CustomMedium(eps_dataset=eps_dataset_nan) diff --git a/tests/test_data/test_sim_data.py b/tests/test_data/test_sim_data.py index d2569fb20a..3b5e9b25a7 100644 --- a/tests/test_data/test_sim_data.py +++ b/tests/test_data/test_sim_data.py @@ -2,9 +2,9 @@ import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.components.data.data_array import ScalarFieldTimeDataArray from tidy3d.components.data.monitor_data import FieldTimeData from tidy3d.components.data.sim_data import SimulationData @@ -239,7 +239,7 @@ def test_to_json(tmp_path): sim_data.to_file(fname=FNAME) # saving to json does not store data, so trying to load from file will trigger custom error. 
- with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = SimulationData.from_file(fname=FNAME) @@ -428,9 +428,9 @@ def test_plot_field_title(): def test_missing_monitor(): sim_data = make_sim_data() - new_monitors = list(sim_data.simulation.monitors)[:-1] + new_monitors = tuple(sim_data.simulation.monitors)[:-1] new_sim = sim_data.simulation.copy(update=dict(monitors=new_monitors)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim_data.copy(update=dict(simulation=new_sim)) diff --git a/tests/test_material_library/test_material_library.py b/tests/test_material_library/test_material_library.py index 53a11dd94b..2172b756d5 100644 --- a/tests/test_material_library/test_material_library.py +++ b/tests/test_material_library/test_material_library.py @@ -61,9 +61,7 @@ def test_medium_repr(): repr_noname_medium = test_media[0].__repr__() str_noname_medium_dict = str(noname_medium_in_dict) - assert ( - "type='Medium' permittivity=2.25 conductivity=0.0" in str_noname_medium - ), "Expected medium information in string" + assert "name=None," in str_noname_medium, "Expected medium information in string" assert ( "Medium(attrs={}, name=None, frequency_range=None" in repr_noname_medium ), "Expcted medium information in repr" diff --git a/tests/test_package/test_config.py b/tests/test_package/test_config.py index 63b31e86a7..46c40e162a 100644 --- a/tests/test_package/test_config.py +++ b/tests/test_package/test_config.py @@ -1,8 +1,8 @@ """test the grid operations""" -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.log import DEFAULT_LEVEL, _level_value @@ -19,7 +19,7 @@ def test_logging_level(): def test_log_level_not_found(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.config.logging_level = "NOT_A_LEVEL" diff --git a/tests/test_package/test_log.py b/tests/test_package/test_log.py index 
bc72d5b316..92071c72ae 100644 --- a/tests/test_package/test_log.py +++ b/tests/test_package/test_log.py @@ -3,9 +3,9 @@ import json import numpy as np -import pydantic.v1 as pd import pytest import tidy3d as td +from pydantic import ValidationError from tidy3d.exceptions import Tidy3dError from tidy3d.log import DEFAULT_LEVEL, _get_level_int, set_logging_level @@ -55,7 +55,7 @@ def test_logging_upper(): def test_logging_unrecognized(): """If unrecognized option, raise validation error.""" - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.config.logging_level = "blah" @@ -215,7 +215,7 @@ def test_logging_warning_capture(): sim.validate_pre_upload() warning_list = td.log.captured_warnings() print(json.dumps(warning_list, indent=4)) - assert len(warning_list) == 30 + assert len(warning_list) == 31 td.log.set_capture(False) # check that capture doesn't change validation errors @@ -233,7 +233,7 @@ def test_logging_warning_capture(): try: sim = td.Simulation.parse_obj(sim_dict) sim.validate_pre_upload() - except pd.ValidationError as e: + except ValidationError as e: error_without = e.errors() except Exception as e: error_without = str(e) @@ -242,16 +242,13 @@ def test_logging_warning_capture(): try: sim = td.Simulation.parse_obj(sim_dict) sim.validate_pre_upload() - except pd.ValidationError as e: + except ValidationError as e: error_with = e.errors() except Exception as e: error_with = str(e) td.log.set_capture(False) - print(error_without) - print(error_with) - - assert error_without == error_with + assert str(error_without) == str(error_with) def test_log_suppression(): diff --git a/tests/test_package/test_material_library.py b/tests/test_package/test_material_library.py index 649182de87..8657610ffb 100644 --- a/tests/test_package/test_material_library.py +++ b/tests/test_package/test_material_library.py @@ -1,7 +1,7 @@ import numpy as np -import pydantic.v1 as pydantic import pytest import tidy3d as td +from pydantic import 
ValidationError from tidy3d.components.material.multi_physics import MultiPhysicsMedium from tidy3d.material_library.material_library import ( MaterialItem, @@ -50,7 +50,7 @@ def test_MaterialItem(): material = MaterialItem(name="material", variants=dict(v1=variant1, v2=variant2), default="v1") assert material["v1"] == material.medium - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): material = MaterialItem( name="material", variants=dict(v1=variant1, v2=variant2), default="v3" ) diff --git a/tests/test_web/test_tidy3d_material_library.py b/tests/test_web/test_tidy3d_material_library.py index b897d92962..2b00fe6830 100644 --- a/tests/test_web/test_tidy3d_material_library.py +++ b/tests/test_web/test_tidy3d_material_library.py @@ -1,7 +1,7 @@ import pytest import responses import tidy3d as td -from tidy3d.web.api.material_libray import MaterialLibray +from tidy3d.web.api.material_library import MaterialLibrary from tidy3d.web.core.environment import Env Env.dev.active() @@ -24,6 +24,6 @@ def test_lib(set_api_key): json={"data": [{"id": "3eb06d16-208b-487b-864b-e9b1d3e010a7", "name": "medium1"}]}, status=200, ) - libs = MaterialLibray.list() + libs = MaterialLibrary.list() lib = libs[0] assert lib.name == "medium1" diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py index 1611d046a7..56096180ae 100644 --- a/tests/test_web/test_webapi.py +++ b/tests/test_web/test_webapi.py @@ -298,7 +298,7 @@ def mock_webapi( @responses.activate def test_source_validation(monkeypatch, mock_upload, mock_get_info, mock_metadata): - sim = make_sim().copy(update={"sources": []}) + sim = make_sim().copy(update={"sources": ()}) assert upload(sim, TASK_NAME, PROJECT_NAME, source_required=False) with pytest.raises(SetupError): diff --git a/tests/test_web/test_webapi_mode.py b/tests/test_web/test_webapi_mode.py index 432a3fd986..fd0d57dedf 100644 --- a/tests/test_web/test_webapi_mode.py +++ b/tests/test_web/test_webapi_mode.py @@ 
-55,7 +55,7 @@ def make_mode_sim(): simulation=simulation, plane=td.Box(center=(0, 0, 0), size=(1, 1, 0)), mode_spec=mode_spec, - freqs=[2e14], + freqs=(2e14,), direction="-", ) return ms @@ -298,6 +298,7 @@ def get_str(*args, **kwargs): fname_tmp = str(tmp_path / "web_test_tmp.json") download_json(TASK_ID, fname_tmp) + assert ModeSolver.from_file(fname_tmp) == sim diff --git a/tidy3d/compat.py b/tidy3d/compat.py index 6060407fe8..cce29d3845 100644 --- a/tidy3d/compat.py +++ b/tidy3d/compat.py @@ -5,4 +5,9 @@ except ImportError: from xarray.core import alignment -__all__ = ["alignment"] +try: + from typing import Self # Python >= 3.11 +except ImportError: # Python <3.11 + from typing_extensions import Self + +__all__ = ["alignment", "Self"] diff --git a/tidy3d/components/apodization.py b/tidy3d/components/apodization.py index 9dd8614094..4166fe8c19 100644 --- a/tidy3d/components/apodization.py +++ b/tidy3d/components/apodization.py @@ -1,11 +1,14 @@ """Defines specification for apodization.""" +from typing import Optional + import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat, model_validator +from ..compat import Self from ..constants import SECOND from ..exceptions import SetupError -from .base import Tidy3dBaseModel, skip_if_fields_missing +from .base import Tidy3dBaseModel from .types import ArrayFloat1D, Ax from .viz import add_ax_if_none @@ -24,45 +27,40 @@ class ApodizationSpec(Tidy3dBaseModel): """ - start: pd.NonNegativeFloat = pd.Field( + start: Optional[NonNegativeFloat] = Field( None, title="Start Interval", description="Defines the time at which the start apodization ends.", units=SECOND, ) - end: pd.NonNegativeFloat = pd.Field( + end: Optional[NonNegativeFloat] = Field( None, title="End Interval", description="Defines the time at which the end apodization begins.", units=SECOND, ) - width: pd.PositiveFloat = pd.Field( + width: Optional[PositiveFloat] = Field( None, title="Apodization Width", 
description="Characteristic decay length of the apodization function, i.e., the width of the ramping up of the scaling function from 0 to 1.", units=SECOND, ) - @pd.validator("end", always=True, allow_reuse=True) - @skip_if_fields_missing(["start"]) - def end_greater_than_start(cls, val, values): + @model_validator(mode="after") + def end_greater_than_start(self) -> Self: """Ensure end is greater than or equal to start.""" - start = values.get("start") - if val is not None and start is not None and val < start: + if self.end is not None and self.start is not None and self.end < self.start: raise SetupError("End apodization begins before start apodization ends.") - return val + return self - @pd.validator("width", always=True, allow_reuse=True) - @skip_if_fields_missing(["start", "end"]) - def width_provided(cls, val, values): + @model_validator(mode="after") + def width_provided(self) -> Self: """Check that width is provided if either start or end apodization is requested.""" - start = values.get("start") - end = values.get("end") - if (start is not None or end is not None) and val is None: + if (self.start is not None or self.end is not None) and self.width is None: raise SetupError("Apodization width must be set.") - return val + return self @add_ax_if_none def plot(self, times: ArrayFloat1D, ax: Ax = None) -> Ax: diff --git a/tidy3d/components/autograd/__init__.py b/tidy3d/components/autograd/__init__.py index e80b43fd42..364b481ace 100644 --- a/tidy3d/components/autograd/__init__.py +++ b/tidy3d/components/autograd/__init__.py @@ -2,12 +2,11 @@ from .functions import interpn from .types import ( AutogradFieldMap, - AutogradTraced, + TracedArrayFloat2D, TracedCoordinate, TracedFloat, TracedSize, TracedSize1D, - TracedVertices, ) from .utils import get_static, is_tidy_box, split_list @@ -17,8 +16,7 @@ "TracedSize1D", "TracedSize", "TracedCoordinate", - "TracedVertices", - "AutogradTraced", + "TracedArrayFloat2D", "AutogradFieldMap", "get_static", "interpn", diff 
--git a/tidy3d/components/autograd/boxes.py b/tidy3d/components/autograd/boxes.py index 437005d9d5..0840ee490b 100644 --- a/tidy3d/components/autograd/boxes.py +++ b/tidy3d/components/autograd/boxes.py @@ -2,7 +2,7 @@ # NOTE: we do not subclass ArrayBox since that would break autograd's internal checks import importlib -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable import autograd.numpy as anp from autograd.extend import VJPNode, defjvp, register_notrace @@ -33,9 +33,9 @@ def from_arraybox(cls, box: ArrayBox) -> TidyArrayBox: def __array_function__( self: Any, func: Callable, - types: List[Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], + types: list[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], ) -> Any: """ Handle the dispatch of NumPy functions to autograd's numpy implementation. @@ -46,11 +46,11 @@ def __array_function__( The instance of the class. func : Callable The NumPy function being called. - types : List[Any] + types : list[Any] The types of the arguments that implement __array_function__. - args : Tuple[Any, ...] + args : tuple[Any, ...] The positional arguments to the function. - kwargs : Dict[str, Any] + kwargs : dict[str, Any] The keyword arguments to the function. Returns @@ -102,7 +102,7 @@ def __array_ufunc__( ufunc: Callable, method: str, *inputs: Any, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> Any: """ Handle the dispatch of NumPy ufuncs to autograd's numpy implementation. @@ -117,7 +117,7 @@ def __array_ufunc__( The method of the ufunc being called. inputs : Any The input arguments to the ufunc. - kwargs : Dict[str, Any] + kwargs : dict[str, Any] The keyword arguments to the ufunc. 
Returns @@ -152,7 +152,6 @@ def item(self): TidyArrayBox.__array_namespace__ = lambda self, *, api_version=None: anp TidyArrayBox.__array_ufunc__ = __array_ufunc__ TidyArrayBox.__array_function__ = __array_function__ -TidyArrayBox.__repr__ = str TidyArrayBox.real = property(anp.real) TidyArrayBox.imag = property(anp.imag) TidyArrayBox.conj = anp.conj diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py index e13b2abcf0..a9ad5beab0 100644 --- a/tidy3d/components/autograd/derivative_utils.py +++ b/tidy3d/components/autograd/derivative_utils.py @@ -1,14 +1,16 @@ # utilities for autograd derivative passing -from __future__ import annotations + +from typing import Optional import numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field +from ...compat import Self from ...constants import LARGE_NUMBER from ..base import Tidy3dBaseModel from ..data.data_array import ScalarFieldDataArray, SpatialDataArray -from ..types import ArrayLike, Bound, tidycomplex +from ..types import ArrayLike, Bound, Complex from .types import PathType from .utils import get_static @@ -43,32 +45,27 @@ class DerivativeSurfaceMesh(Tidy3dBaseModel): """ - centers: ArrayLike = pd.Field( - ..., + centers: ArrayLike = Field( title="Centers", description="(N, 3) array storing the centers of each surface element.", ) - areas: ArrayLike = pd.Field( - ..., + areas: ArrayLike = Field( title="Area Elements", description="(N,) array storing the first perpendicular vectors of each surface element.", ) - normals: ArrayLike = pd.Field( - ..., + normals: ArrayLike = Field( title="Normals", description="(N, 3) array storing the normal vectors of each surface element.", ) - perps1: ArrayLike = pd.Field( - ..., + perps1: ArrayLike = Field( title="Perpendiculars 1", description="(N, 3) array storing the first perpendicular vectors of each surface element.", ) - perps2: ArrayLike = pd.Field( - ..., + perps2: ArrayLike = Field( 
title="Perpendiculars 1", description="(N, 3) array storing the first perpendicular vectors of each surface element.", ) @@ -77,14 +74,12 @@ class DerivativeSurfaceMesh(Tidy3dBaseModel): class DerivativeInfo(Tidy3dBaseModel): """Stores derivative information passed to the ``.compute_derivatives`` methods.""" - paths: list[PathType] = pd.Field( - ..., + paths: list[PathType] = Field( title="Paths to Traced Fields", description="List of paths to the traced fields that need derivatives calculated.", ) - E_der_map: FieldData = pd.Field( - ..., + E_der_map: FieldData = Field( title="Electric Field Gradient Map", description='Dataset where the field components ``("Ex", "Ey", "Ez")`` store the ' "multiplication of the forward and adjoint electric fields. The tangential components " @@ -92,91 +87,81 @@ class DerivativeInfo(Tidy3dBaseModel): "All components are used when computing volume-based gradients.", ) - D_der_map: FieldData = pd.Field( - ..., + D_der_map: FieldData = Field( title="Displacement Field Gradient Map", description='Dataset where the field components ``("Ex", "Ey", "Ez")`` store the ' "multiplication of the forward and adjoint displacement fields. 
The normal component " "of this dataset is used when computing adjoint gradients for shifting boundaries.", ) - E_fwd: FieldData = pd.Field( - ..., + E_fwd: FieldData = Field( title="Forward Electric Fields", description='Dataset where the field components ``("Ex", "Ey", "Ez")`` represent the ' "forward electric fields used for computing gradients for a given structure.", ) - E_adj: FieldData = pd.Field( - ..., + E_adj: FieldData = Field( title="Adjoint Electric Fields", description='Dataset where the field components ``("Ex", "Ey", "Ez")`` represent the ' "adjoint electric fields used for computing gradients for a given structure.", ) - D_fwd: FieldData = pd.Field( - ..., + D_fwd: FieldData = Field( title="Forward Displacement Fields", description='Dataset where the field components ``("Ex", "Ey", "Ez")`` represent the ' "forward displacement fields used for computing gradients for a given structure.", ) - D_adj: FieldData = pd.Field( - ..., + D_adj: FieldData = Field( title="Adjoint Displacement Fields", description='Dataset where the field components ``("Ex", "Ey", "Ez")`` represent the ' "adjoint displacement fields used for computing gradients for a given structure.", ) - eps_data: PermittivityData = pd.Field( - ..., + eps_data: PermittivityData = Field( title="Permittivity Dataset", description="Dataset of relative permittivity values along all three dimensions. " "Used for automatically computing permittivity inside or outside of a simple geometry.", ) - eps_in: tidycomplex = pd.Field( + eps_in: Complex = Field( title="Permittivity Inside", description="Permittivity inside of the ``Structure``. " "Typically computed from ``Structure.medium.eps_model``." "Used when it can not be computed from ``eps_data`` or when ``eps_approx==True``.", ) - eps_out: tidycomplex = pd.Field( - ..., + eps_out: Complex = Field( title="Permittivity Outside", description="Permittivity outside of the ``Structure``. " "Typically computed from ``Simulation.medium.eps_model``." 
"Used when it can not be computed from ``eps_data`` or when ``eps_approx==True``.", ) - eps_background: tidycomplex = pd.Field( + eps_background: Complex = Field( None, title="Permittivity in Background", description="Permittivity outside of the ``Structure`` as manually specified by. " "``Structure.background_medium``. ", ) - bounds: Bound = pd.Field( - ..., + bounds: Bound = Field( title="Geometry Bounds", description="Bounds corresponding to the structure, used in ``Medium`` calculations.", ) - bounds_intersect: Bound = pd.Field( - ..., + bounds_intersect: Bound = Field( title="Geometry and Simulation Intersections Bounds", description="Bounds corresponding to the minimum intersection between the " "structure and the simulation it is contained in.", ) - frequency: float = pd.Field( - ..., + frequency: float = Field( title="Frequency of adjoint simulation", description="Frequency at which the adjoint gradient is computed.", ) - eps_no_structure: SpatialDataArray = pd.Field( + eps_no_structure: Optional[SpatialDataArray] = Field( None, title="Permittivity Without Structure", description="The permittivity of the original simulation without the structure that is " @@ -184,7 +169,7 @@ class DerivativeInfo(Tidy3dBaseModel): "structure for shape optimization.", ) - eps_inf_structure: SpatialDataArray = pd.Field( + eps_inf_structure: Optional[SpatialDataArray] = Field( None, title="Permittivity With Infinite Structure", description="The permittivity of the original simulation where the structure being " @@ -192,7 +177,7 @@ class DerivativeInfo(Tidy3dBaseModel): "inside of the structure for shape optimization.", ) - eps_approx: bool = pd.Field( + eps_approx: bool = Field( False, title="Use Permittivity Approximation", description="If ``True``, approximates outside permittivity using ``Simulation.medium``" @@ -201,7 +186,7 @@ class DerivativeInfo(Tidy3dBaseModel): "evaluate the inside and outside relative permittivity for each geometry.", ) - def updated_paths(self, 
paths: list[PathType]) -> DerivativeInfo: + def updated_paths(self, paths: list[PathType]) -> Self: """Update this ``DerivativeInfo`` with new set of paths.""" return self.updated_copy(paths=paths) diff --git a/tidy3d/components/autograd/types.py b/tidy3d/components/autograd/types.py index e8be92bb74..ad767c9e88 100644 --- a/tidy3d/components/autograd/types.py +++ b/tidy3d/components/autograd/types.py @@ -1,17 +1,26 @@ # type information for autograd -# utilities for working with autograd - import copy -import typing +from typing import Annotated, Literal, Optional, TypeAlias, Union -import pydantic.v1 as pd +import autograd.numpy as anp from autograd.builtins import dict as dict_ag from autograd.extend import Box, defvjp, primitive +from autograd.tracer import getval +from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, TypeAdapter from tidy3d.components.type_util import _add_schema -from ..types import ArrayFloat2D, ArrayLike, Complex, Size1D +from ..types import ( + ArrayFloat2D, + Complex, + Coordinate, + PolesAndResidues, + Size, + Size1D, + _auto_serializer, +) +from .utils import contains_box # add schema to the Box _add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box") @@ -25,32 +34,66 @@ Box.__copy__ = lambda v: _copy(v) Box.__deepcopy__ = lambda v, memo: _deepcopy(v, memo) +Box.__str__ = lambda self: f"{self._value} <{type(self).__name__}>" +Box.__repr__ = Box.__str__ + + +def traced_alias(base_alias, *, name: Optional[str] = None) -> TypeAlias: + base_adapter = TypeAdapter(base_alias, config=dict(arbitrary_types_allowed=True)) + + def _validate_box_or_container(v): + # case 1: v itself is a tracer + # in this case we just validate but leave the tracer untouched + if isinstance(v, Box): + base_adapter.validate_python(getval(v)) + return v + + # case 2: v is a plain container that contains at least one tracer + # in this case we try to coerce into ArrayBox for efficiency + if contains_box(v): + dense = anp.array(v) + 
base_adapter.validate_python(getval(dense)) + return dense + + raise ValueError("expected autograd tracer") + + return Annotated[ + Union[ + base_alias, + Annotated[ + Box, + BeforeValidator(_validate_box_or_container), + PlainSerializer(lambda a, _: _auto_serializer(getval(a), _), when_used="json"), + ], + Annotated[object, BeforeValidator(_validate_box_or_container)], + ], + {} if name is None else {"title": name}, + ] + # Types for floats, or collections of floats that can also be autograd tracers -TracedFloat = typing.Union[float, Box] -TracedPositiveFloat = typing.Union[pd.PositiveFloat, Box] -TracedSize1D = typing.Union[Size1D, Box] -TracedSize = typing.Union[tuple[TracedSize1D, TracedSize1D, TracedSize1D], Box] -TracedCoordinate = typing.Union[tuple[TracedFloat, TracedFloat, TracedFloat], Box] -TracedVertices = typing.Union[ArrayFloat2D, Box] +TracedFloat = traced_alias(float) +TracedPositiveFloat = traced_alias(PositiveFloat) +TracedSize1D = traced_alias(Size1D) +TracedSize = traced_alias(Size) +TracedCoordinate = traced_alias(Coordinate) +TracedArrayFloat2D = traced_alias(ArrayFloat2D) # poles -TracedComplex = typing.Union[Complex, Box] -TracedPoleAndResidue = typing.Tuple[TracedComplex, TracedComplex] +TracedComplex = traced_alias(Complex) +TracedPolesAndResidues = traced_alias(PolesAndResidues) # The data type that we pass in and out of the web.run() @autograd.primitive -AutogradTraced = typing.Union[Box, ArrayLike] -PathType = tuple[typing.Union[int, str], ...] -AutogradFieldMap = dict_ag[PathType, AutogradTraced] +PathType = tuple[Union[int, str], ...] 
+AutogradFieldMap = dict_ag[PathType, Box] -InterpolationType = typing.Literal["nearest", "linear"] +InterpolationType = Literal["nearest", "linear"] __all__ = [ "TracedFloat", "TracedSize1D", "TracedSize", "TracedCoordinate", - "TracedVertices", - "AutogradTraced", + "TracedArrayFloat2D", "AutogradFieldMap", ] diff --git a/tidy3d/components/autograd/utils.py b/tidy3d/components/autograd/utils.py index 0a1fbfd43d..4aa03bbc8f 100644 --- a/tidy3d/components/autograd/utils.py +++ b/tidy3d/components/autograd/utils.py @@ -1,26 +1,39 @@ # utilities for working with autograd -import typing +from collections.abc import Mapping, Sequence +from typing import Any +from autograd.extend import Box from autograd.tracer import getval -def get_static(x: typing.Any) -> typing.Any: +def get_static(x: Any) -> Any: """Get the 'static' (untraced) version of some value.""" return getval(x) -def split_list(x: list[typing.Any], index: int) -> (list[typing.Any], list[typing.Any]): +def split_list(x: list[Any], index: int) -> (list[Any], list[Any]): """Split a list at a given index.""" x = list(x) return x[:index], x[index:] -def is_tidy_box(x: typing.Any) -> bool: +def is_tidy_box(x: Any) -> bool: """Check if a value is a tidy box.""" return getattr(x, "_tidy", False) +def contains_box(obj: Any) -> bool: + """True if any element inside obj is an autograd Box.""" + if isinstance(obj, Box): + return True + if isinstance(obj, Mapping): + return any(contains_box(v) for v in obj.values()) + if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)): + return any(contains_box(i) for i in obj) + return False + + __all__ = [ "get_static", "split_list", diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index 9615c33beb..e1c0b94c22 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -1,35 +1,53 @@ """global configuration / base class for pydantic models used to make simulation.""" -from __future__ import annotations - import hashlib import io import 
json import math import os -import pathlib import tempfile -from functools import wraps +import typing as _t +from collections import defaultdict +from functools import total_ordering, wraps from math import ceil -from typing import Any, Callable, Dict, List, Tuple, Union +from pathlib import Path +from typing import ( + Any, + Callable, + Literal, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, +) import h5py import numpy as np -import pydantic.v1 as pydantic import rich import xarray as xr import yaml from autograd.builtins import dict as dict_ag from autograd.tracer import isbox -from pydantic.v1.fields import ModelField - +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + field_validator, + model_validator, +) +from pydantic_core import PydanticCustomError + +from ..compat import Self from ..exceptions import FileError from ..log import log -from .autograd.types import AutogradFieldMap, Box +from .autograd.types import AutogradFieldMap from .autograd.utils import get_static -from .data.data_array import DATA_ARRAY_MAP, DataArray +from .data.data_array import DATA_ARRAY_MAP from .file_util import compress_file_to_gzip, extract_gzip_file -from .types import TYPE_TAG_STR, ComplexNumber, Literal +from .types import TYPE_TAG_STR INDENT_JSON_FILE = 4 # default indentation of json string in json files INDENT = None # default indentation of json string used internally @@ -67,17 +85,10 @@ def cached_property(cached_property_getter): return property(cache(cached_property_getter)) -def ndarray_encoder(val): - """How a ``np.ndarray`` gets handled before saving to json.""" - if np.any(np.iscomplex(val)): - return dict(real=val.real.tolist(), imag=val.imag.tolist()) - return val.real.tolist() - - def _get_valid_extension(fname: str) -> str: """Return the file extension from fname, validated to accepted ones.""" valid_extensions = [".json", ".yaml", ".hdf5", ".h5", ".hdf5.gz"] - extensions = [s.lower() for s in 
pathlib.Path(fname).suffixes[-2:]] + extensions = [s.lower() for s in Path(fname).suffixes[-2:]] if len(extensions) == 0: raise FileError(f"File '{fname}' missing extension.") single_extension = extensions[-1] @@ -92,35 +103,20 @@ def _get_valid_extension(fname: str) -> str: ) -def skip_if_fields_missing(fields: List[str], root=False): - """Decorate ``validator`` to check that other fields have passed validation.""" +def _fmt_ann_literal(ann) -> str: + """Spell the annotation exactly as written.""" + if ann is None: + return "Any" + if isinstance(ann, _t._GenericAlias): + return str(ann).replace("typing.", "") + return ann.__name__ if hasattr(ann, "__name__") else str(ann) - def actual_decorator(validator): - @wraps(validator) - def _validator(cls, *args, **kwargs): - """New validator function.""" - values = kwargs.get("values") - if values is None: - values = args[0] if root else args[1] - for field in fields: - if field not in values: - log.warning( - f"Could not execute validator '{validator.__name__}' because field " - f"'{field}' failed validation." - ) - if root: - return values - else: - return kwargs.get("val") if "val" in kwargs.keys() else args[0] - return validator(cls, *args, **kwargs) +T = TypeVar("T", bound="Tidy3dBaseModel") - return _validator - return actual_decorator - - -class Tidy3dBaseModel(pydantic.BaseModel): +@total_ordering +class Tidy3dBaseModel(BaseModel): """Base pydantic model that all Tidy3d components inherit from. Defines configuration for handling data structures as well as methods for importing, exporting, and hashing tidy3d objects. 
@@ -128,6 +124,73 @@ class Tidy3dBaseModel(pydantic.BaseModel): `Pydantic Models `_ """ + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_default=True, + validate_assignment=True, + populate_by_name=True, + ser_json_inf_nan="strings", + extra="forbid", + frozen=True, + ) + + attrs: dict = Field( + default_factory=dict, + title="Attributes", + description="Dictionary storing arbitrary metadata for a Tidy3D object. " + "This dictionary can be freely used by the user for storing data without affecting the " + "operation of Tidy3D as it is not used internally. " + "Note that, unlike regular Tidy3D fields, ``attrs`` are mutable. " + "For example, the following is allowed for setting an ``attr`` ``obj.attrs['foo'] = bar``. " + "Also note that `Tidy3D`` will raise a ``TypeError`` if ``attrs`` contain objects " + "that can not be serialized. One can check if ``attrs`` are serializable " + "by calling ``obj.json()``.", + ) + + _cached_properties: dict = PrivateAttr(default_factory=dict) + + @field_validator("name", check_fields=False) + def _validate_name_no_special_characters(name): + if name is None: + return name + for character in FORBID_SPECIAL_CHARACTERS: + if character in name: + raise ValueError( + f"Special character '{character}' not allowed in component name {name}." + ) + return name + + def __init__(self, **kwargs): + """Init method, includes post-init validators.""" + log.begin_capture() + super().__init__(**kwargs) + log.end_capture(self) + + def __init_subclass__(cls: Type[T], **kwargs): + """Injects a constant discriminator field before Pydantic builds the model. + + Adds + type: Literal[""] = "" + to every concrete subclass so it can participate in a + `Field(discriminator="type")` union without manual boilerplate. + + Must run *before* `super().__init_subclass__()`; that call lets Pydantic + see the injected field during its normal schema/validator generation. 
+ See also: https://peps.python.org/pep-0487/ + """ + tag = cls.__name__ + cls.__annotations__[TYPE_TAG_STR] = Literal[tag] + setattr(cls, TYPE_TAG_STR, tag) + + super().__init_subclass__(**kwargs) + + @classmethod + def __pydantic_init_subclass__(cls: Type[T], **kwargs): + super().__pydantic_init_subclass__(**kwargs) + + # add docstring once pydantic is done constructing the class + cls.__doc__ = cls.generate_docstring() + def __hash__(self) -> int: """Hash method.""" try: @@ -141,101 +204,63 @@ def _hash_self(self) -> str: self.to_hdf5(bf) return hashlib.sha256(bf.getvalue()).hexdigest() - def __init__(self, **kwargs): - """Init method, includes post-init validators.""" - log.begin_capture() - super().__init__(**kwargs) - self._post_init_validators() - log.end_capture(self) + @model_validator(mode="wrap") + def _call_post_init_validators(cls, data: Any, handler): + obj = handler(data) + for fn in obj._post_init_validators: + try: + fn() + except Exception as exc: + raise PydanticCustomError( + "post_init_validator", + 'post-init validator "{validator}" failed: {msg}', + {"validator": fn.__name__, "msg": str(exc)}, + ) from exc + return obj + + @property + def _post_init_validators(self) -> tuple[Callable[[Self], None], ...]: + """List of functions to run for post-init validation""" + return () + + def copy( + self, *, deep: bool = True, validate: bool = True, update: Mapping[str, Any] | None = None + ) -> Self: + """Return a copy of the model. - def _post_init_validators(self) -> None: - """Call validators taking ``self`` that get run after init, implement in subclasses.""" - - def __init_subclass__(cls) -> None: - """Things that are done to each of the models.""" - - cls.add_type_field() - cls.generate_docstring() - - class Config: - """Sets config for all :class:`Tidy3dBaseModel` objects. - - Configuration Options - --------------------- - allow_population_by_field_name : bool = True - Allow properties to stand in for fields(?). 
- arbitrary_types_allowed : bool = True - Allow types like numpy arrays. - extra : str = 'forbid' - Forbid extra kwargs not specified in model. - json_encoders : Dict[type, Callable] - Defines how to encode type in json file. - validate_all : bool = True - Validate default values just to be safe. - validate_assignment : bool - Re-validate after re-assignment of field in model. + Parameters + ---------- + deep : bool = True + Whether to make a deep copy first (same as v1). + validate : bool = True + If ``True``, run full Pydantic validation on the copied data. + update : Mapping[str, Any] | None = None + Optional mapping of fields to overwrite (passed straight + through to ``model_copy(update=...)``). """ + if update and self.model_config.get("extra") == "forbid": + invalid = set(update) - set(self.model_fields) + if invalid: + raise KeyError(f"'{self.type}' received invalid fields on copy: {invalid}") - arbitrary_types_allowed = True - validate_all = True - extra = "forbid" - validate_assignment = True - allow_population_by_field_name = True - json_encoders = { - np.ndarray: ndarray_encoder, - complex: lambda x: ComplexNumber(real=x.real, imag=x.imag), - xr.DataArray: DataArray._json_encoder, - Box: lambda x: x._value, - } - frozen = True - allow_mutation = False - copy_on_model_validation = "none" - - _cached_properties = pydantic.PrivateAttr({}) - - @pydantic.root_validator(skip_on_failure=True) - def _special_characters_not_in_name(cls, values): - name = values.get("name") - if name: - for character in FORBID_SPECIAL_CHARACTERS: - if character in name: - raise ValueError( - f"Special character '{character}' not allowed in component name {name}." - ) - return values - - attrs: dict = pydantic.Field( - {}, - title="Attributes", - description="Dictionary storing arbitrary metadata for a Tidy3D object. " - "This dictionary can be freely used by the user for storing data without affecting the " - "operation of Tidy3D as it is not used internally. 
" - "Note that, unlike regular Tidy3D fields, ``attrs`` are mutable. " - "For example, the following is allowed for setting an ``attr`` ``obj.attrs['foo'] = bar``. " - "Also note that `Tidy3D`` will raise a ``TypeError`` if ``attrs`` contain objects " - "that can not be serialized. One can check if ``attrs`` are serializable " - "by calling ``obj.json()``.", - ) + new_model = self.model_copy(deep=deep, update=update) - def copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Tidy3dBaseModel: - """Copy a Tidy3dBaseModel. With ``deep=True`` and ``validate=True`` as default.""" - kwargs.update(deep=deep) - new_copy = pydantic.BaseModel.copy(self, **kwargs) if validate: - return self.validate(new_copy.dict()) - # cached property is cleared automatically when validation is on, but it - # needs to be manually cleared when validation is off - new_copy._cached_properties = {} - return new_copy + return self.__class__.model_validate(new_model.model_dump()) + else: + # make sure cache is always cleared + new_model._cached_properties = {} + + return new_model def updated_copy( - self, path: str = None, deep: bool = True, validate: bool = True, **kwargs - ) -> Tidy3dBaseModel: + self, path: str | None = None, *, deep: bool = True, validate: bool = True, **kwargs: Any + ) -> Self: """Make copy of a component instance with ``**kwargs`` indicating updated field values. Note ---- - If ``path`` supplied, applies the updated copy with the update performed on the sub- + If ``path`` is supplied, applies the updated copy with the update performed on the sub- component corresponding to the path. For indexing into a tuple or list, use the integer value. 
@@ -243,54 +268,44 @@ def updated_copy( ------- >>> sim = simulation.updated_copy(size=new_size, path=f"structures/{i}/geometry") # doctest: +SKIP """ - if not path: - return self._updated_copy(**kwargs, deep=deep, validate=validate) + return self.copy(deep=deep, validate=validate, update=kwargs) - path_components = path.split("/") - - field_name = path_components[0] + path_parts = path.split("/") + field_name, *rest = path_parts try: sub_component = getattr(self, field_name) - except AttributeError as e: + except AttributeError as exc: raise AttributeError( - f"Could not field field '{field_name}' in the sub-component `path`. " - f"Found fields of '{tuple(self.__fields__.keys())}'. " - "Please double check the `path` passed to `.updated_copy()`." - ) from e + f"Could not find field '{field_name}' in path '{path}'. " + f"Available top-level fields: {tuple(self.model_fields)}." + ) from exc if isinstance(sub_component, (list, tuple)): - integer_index_path = path_components[1] - try: - index = int(integer_index_path) - except ValueError: + index = int(rest[0]) + except (IndexError, ValueError): raise ValueError( - f"Could not grab integer index from path '{path}'. " - f"Please correct the sub path containing '{integer_index_path}' to be an " - f"integer index into '{field_name}' (containing {len(sub_component)} elements)." + f"Expected integer index into '{field_name}' " f"in path '{path}'." 
) - sub_component_list = list(sub_component) - sub_component = sub_component_list[index] - sub_path = "/".join(path_components[2:]) - - sub_component_list[index] = sub_component.updated_copy( - path=sub_path, deep=deep, validate=validate, **kwargs + sub_component_list[index] = sub_component_list[index].updated_copy( + path="/".join(rest[1:]), + deep=deep, + validate=validate, + **kwargs, ) - new_component = tuple(sub_component_list) + new_value = type(sub_component)(sub_component_list) else: - sub_path = "/".join(path_components[1:]) - new_component = sub_component.updated_copy( - path=sub_path, deep=deep, validate=validate, **kwargs + new_value = sub_component.updated_copy( + path="/".join(rest), + deep=deep, + validate=validate, + **kwargs, ) - return self._updated_copy(deep=deep, validate=validate, **{field_name: new_component}) - - def _updated_copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Tidy3dBaseModel: - """Make copy of a component instance with ``**kwargs`` indicating updated field values.""" - return self.copy(update=kwargs, deep=deep, validate=validate) + return self.copy(deep=deep, validate=validate, update={field_name: new_value}) def help(self, methods: bool = False) -> None: """Prints message describing the fields and methods of a :class:`Tidy3dBaseModel`. @@ -307,7 +322,7 @@ def help(self, methods: bool = False) -> None: rich.inspect(self, methods=methods) @classmethod - def from_file(cls, fname: str, group_path: str = None, **parse_obj_kwargs) -> Tidy3dBaseModel: + def from_file(cls: Type[T], fname: str, group_path: str = None, **parse_obj_kwargs) -> T: """Loads a :class:`Tidy3dBaseModel` from .yaml, .json, .hdf5, or .hdf5.gz file. 
Parameters @@ -330,10 +345,10 @@ def from_file(cls, fname: str, group_path: str = None, **parse_obj_kwargs) -> Ti >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP """ model_dict = cls.dict_from_file(fname=fname, group_path=group_path) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls.model_validate(model_dict, **parse_obj_kwargs) @classmethod - def dict_from_file(cls, fname: str, group_path: str = None) -> dict: + def dict_from_file(cls: Type[T], fname: str, group_path: str = None) -> dict: """Loads a dictionary containing the model from a .yaml, .json, .hdf5, or .hdf5.gz file. Parameters @@ -394,7 +409,7 @@ def to_file(self, fname: str) -> None: return converter(fname=fname) @classmethod - def from_json(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel: + def from_json(cls: Type[T], fname: str, **parse_obj_kwargs) -> T: """Load a :class:`Tidy3dBaseModel` from .json file. Parameters @@ -414,10 +429,10 @@ def from_json(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel: >>> simulation = Simulation.from_json(fname='folder/sim.json') # doctest: +SKIP """ model_dict = cls.dict_from_json(fname=fname) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls.model_validate(model_dict, **parse_obj_kwargs) @classmethod - def dict_from_json(cls, fname: str) -> dict: + def dict_from_json(cls: Type[T], fname: str) -> dict: """Load dictionary of the model from a .json file. 
Parameters @@ -450,13 +465,13 @@ def to_json(self, fname: str) -> None: ------- >>> simulation.to_json(fname='folder/sim.json') # doctest: +SKIP """ - json_string = self._json(indent=INDENT_JSON_FILE) + json_string = self.model_dump_json(indent=INDENT_JSON_FILE) self._warn_if_contains_data(json_string) with open(fname, "w", encoding="utf-8") as file_handle: file_handle.write(json_string) @classmethod - def from_yaml(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel: + def from_yaml(cls: Type[T], fname: str, **parse_obj_kwargs) -> T: """Loads :class:`Tidy3dBaseModel` from .yaml file. Parameters @@ -476,10 +491,10 @@ def from_yaml(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel: >>> simulation = Simulation.from_yaml(fname='folder/sim.yaml') # doctest: +SKIP """ model_dict = cls.dict_from_yaml(fname=fname) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls.model_validate(model_dict, **parse_obj_kwargs) @classmethod - def dict_from_yaml(cls, fname: str) -> dict: + def dict_from_yaml(cls: Type[T], fname: str) -> dict: """Load dictionary of the model from a .yaml file. Parameters @@ -525,8 +540,8 @@ def _warn_if_contains_data(json_str: str) -> None: log.warning( "Data contents found in the model to be written to file. " "Note that this data will not be included in '.json' or '.yaml' formats. " - "As a result, it will not be possible to load the file back to the original model." - "Instead, use `.hdf5` extension in filename passed to 'to_file()'." + "As a result, it will not be possible to load the file back to the original model. " + "Instead, use '.hdf5' extension in filename passed to 'to_file()'." 
) @staticmethod @@ -554,12 +569,12 @@ def get_tuple_index(key_name: str) -> int: return int(str(key_name)) @classmethod - def tuple_to_dict(cls, tuple_values: tuple) -> dict: + def tuple_to_dict(cls: Type[T], tuple_values: tuple) -> dict: """How we generate a dictionary mapping new keys to tuple values for hdf5.""" return {cls.get_tuple_group_name(index=i): val for i, val in enumerate(tuple_values)} @classmethod - def get_sub_model(cls, group_path: str, model_dict: dict | list) -> dict: + def get_sub_model(cls: Type[T], group_path: str, model_dict: dict | list) -> dict: """Get the sub model for a given group path.""" for key in group_path.split("/"): @@ -579,7 +594,7 @@ def _json_string_key(index: int) -> str: return JSON_TAG @classmethod - def _json_string_from_hdf5(cls, fname: str) -> str: + def _json_string_from_hdf5(cls: Type[T], fname: str) -> str: """Load the model json string from an hdf5 file.""" with h5py.File(fname, "r") as f_handle: num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key]) @@ -590,7 +605,10 @@ def _json_string_from_hdf5(cls, fname: str) -> str: @classmethod def dict_from_hdf5( - cls, fname: str, group_path: str = "", custom_decoders: List[Callable] = None + cls: Type[T], + fname: str, + group_path: str = "", + custom_decoders: Optional[list[Callable]] = None, ) -> dict: """Loads a dictionary containing the model contents from a .hdf5 file. @@ -665,12 +683,12 @@ def load_data_from_file(model_dict: dict, group_path: str = "") -> None: @classmethod def from_hdf5( - cls, + cls: Type[T], fname: str, group_path: str = "", - custom_decoders: List[Callable] = None, + custom_decoders: list[Callable] = None, **parse_obj_kwargs, - ) -> Tidy3dBaseModel: + ) -> T: """Loads :class:`Tidy3dBaseModel` instance to .hdf5 file. 
Parameters @@ -696,9 +714,9 @@ def from_hdf5( model_dict = cls.dict_from_hdf5( fname=fname, group_path=group_path, custom_decoders=custom_decoders ) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls.model_validate(model_dict, **parse_obj_kwargs) - def to_hdf5(self, fname: str, custom_encoders: List[Callable] = None) -> None: + def to_hdf5(self, fname: str, custom_encoders: list[Callable] = None) -> None: """Exports :class:`Tidy3dBaseModel` instance to .hdf5 file. Parameters @@ -745,11 +763,11 @@ def add_data_to_file(data_dict: dict, group_path: str = "") -> None: elif isinstance(value, dict): add_data_to_file(data_dict=value, group_path=subpath) - add_data_to_file(data_dict=self.dict()) + add_data_to_file(data_dict=self.model_dump()) @classmethod def dict_from_hdf5_gz( - cls, fname: str, group_path: str = "", custom_decoders: List[Callable] = None + cls: Type[T], fname: str, group_path: str = "", custom_decoders: list[Callable] = None ) -> dict: """Loads a dictionary containing the model contents from a .hdf5.gz file. @@ -787,12 +805,12 @@ def dict_from_hdf5_gz( @classmethod def from_hdf5_gz( - cls, + cls: Type[T], fname: str, group_path: str = "", - custom_decoders: List[Callable] = None, + custom_decoders: list[Callable] = None, **parse_obj_kwargs, - ) -> Tidy3dBaseModel: + ) -> T: """Loads :class:`Tidy3dBaseModel` instance to .hdf5.gz file. Parameters @@ -818,9 +836,9 @@ def from_hdf5_gz( model_dict = cls.dict_from_hdf5_gz( fname=fname, group_path=group_path, custom_decoders=custom_decoders ) - return cls.parse_obj(model_dict, **parse_obj_kwargs) + return cls.model_validate(model_dict, **parse_obj_kwargs) - def to_hdf5_gz(self, fname: str, custom_encoders: List[Callable] = None) -> None: + def to_hdf5_gz(self, fname: str, custom_encoders: list[Callable] = None) -> None: """Exports :class:`Tidy3dBaseModel` instance to .hdf5.gz file. 
Parameters @@ -844,71 +862,57 @@ def to_hdf5_gz(self, fname: str, custom_encoders: List[Callable] = None) -> None finally: os.unlink(decompressed) - def __lt__(self, other): + def __lt__(self, other: object) -> bool: """define < for getting unique indices based on hash.""" return hash(self) < hash(other) - def __gt__(self, other): - """define > for getting unique indices based on hash.""" - return hash(self) > hash(other) - - def __le__(self, other): - """define <= for getting unique indices based on hash.""" - return hash(self) <= hash(other) - - def __ge__(self, other): - """define >= for getting unique indices based on hash.""" - return hash(self) >= hash(other) + def __eq__(self, other: object) -> bool: + """Two models are equal when origins match and every public or extra field matches.""" + if not isinstance(other, BaseModel): + return NotImplemented - def __eq__(self, other): - """Define == for two Tidy3dBaseModels.""" - if other is None: + self_origin = ( + getattr(self, "__pydantic_generic_metadata__", {}).get("origin") or self.__class__ + ) + other_origin = ( + getattr(other, "__pydantic_generic_metadata__", {}).get("origin") or other.__class__ + ) + if self_origin is not other_origin: return False - def check_equal(dict1: dict, dict2: dict) -> bool: - """Check if two dictionaries are equal, with special handlings.""" - - # if different keys, automatically fail - if not dict1.keys() == dict2.keys(): - return False - - # loop through elements in each dict - for key in dict1.keys(): - val1 = dict1[key] - val2 = dict2[key] + if getattr(self, "__pydantic_extra__", None) != getattr(other, "__pydantic_extra__", None): + return False - val1 = get_static(val1) - val2 = get_static(val2) + def _fields_equal(a, b) -> bool: + a = get_static(a) + b = get_static(b) - # if one of val1 or val2 is None (exclusive OR) - if (val1 is None) != (val2 is None): + if a is b: + return True + if type(a) is not type(b): + if not (isinstance(a, (list, tuple)) and isinstance(b, 
(list, tuple))): return False + if isinstance(a, np.ndarray): + return np.array_equal(a, b) + if isinstance(a, (xr.DataArray, xr.Dataset)): + return a.equals(b) + if isinstance(a, Mapping): + if a.keys() != b.keys(): + return False + return all(_fields_equal(a[k], b[k]) for k in a) + if isinstance(a, Sequence) and not isinstance(a, (str, bytes)): + if len(a) != len(b): + return False + return all(_fields_equal(x, y) for i, (x, y) in enumerate(zip(a, b))) + if isinstance(a, float) and isinstance(b, float) and np.isnan(a) and np.isnan(b): + return True + return a == b - # convert tuple to dict to use this recursive function - if isinstance(val1, tuple) or isinstance(val2, tuple): - val1 = dict(zip(range(len(val1)), val1)) - val2 = dict(zip(range(len(val2)), val2)) - - # if dictionaries, recurse - if isinstance(val1, dict) or isinstance(val2, dict): - are_equal = check_equal(val1, val2) - if not are_equal: - return False - - # if numpy arrays, use numpy to do equality check - elif isinstance(val1, np.ndarray) or isinstance(val2, np.ndarray): - if not np.array_equal(val1, val2): - return False - - # everything else - else: - # note: this logic is because != is handled differently in DataArrays apparently - if not val1 == val2: - return False - - return True + for name in self.model_fields: + if not _fields_equal(getattr(self, name), getattr(other, name)): + return False - return check_equal(self.dict(), other.dict()) + return True @cached_property def _json_string(self) -> str: @@ -919,34 +923,7 @@ def _json_string(self) -> str: str Json-formatted string holding :class:`Tidy3dBaseModel` data. """ - return self._json() - - def _json(self, indent=INDENT, exclude_unset=False, **kwargs) -> str: - """Overwrites the model ``json`` representation with some extra customized handling. - - Parameters - ----------- - **kwargs : kwargs passed to `self.json()` - - Returns - ------- - str - Json-formatted string holding :class:`Tidy3dBaseModel` data. 
- """ - - def make_json_compatible(json_string: str) -> str: - """Makes the string compatible with json standards, notably for infinity.""" - - tmp_string = "<>" - json_string = json_string.replace("-Infinity", tmp_string) - json_string = json_string.replace('""-Infinity""', tmp_string) - json_string = json_string.replace("Infinity", '"Infinity"') - json_string = json_string.replace('""Infinity""', '"Infinity"') - return json_string.replace(tmp_string, '"-Infinity"') - - json_string = self.json(indent=indent, exclude_unset=exclude_unset, **kwargs) - json_string = make_json_compatible(json_string) - return json_string + return self.model_dump_json(indent=INDENT, exclude_unset=False) def strip_traced_fields( self, starting_path: tuple[str] = (), include_untraced_data_arrays: bool = False @@ -956,7 +933,7 @@ def strip_traced_fields( Parameters ---------- starting_path : tuple[str, ...] = () - If provided, starts recursing in self.dict() from this path of field names + If provided, starts recursing in self.model_dump() from this path of field names include_untraced_data_arrays : bool = False Whether to include ``DataArray`` objects without tracers. We need to include these when returning data, but are unnecessary for structures. 
@@ -978,8 +955,12 @@ def handle_value(x: Any, path: tuple[str, ...]) -> None: field_mapping[path] = x # for data arrays, need to be more careful as their tracers are stored in .data - elif isinstance(x, xr.DataArray) and (isbox(x.data) or include_untraced_data_arrays): - field_mapping[path] = x.data + elif isinstance(x, xr.DataArray): + data = x.data + if isbox(data) or any(isbox(el) for el in np.asarray(data).ravel()): + field_mapping[path] = x.data + elif include_untraced_data_arrays: + field_mapping[path] = x.data # for sequences, add (i,) to the path and handle each value individually elif isinstance(x, (list, tuple)): @@ -992,22 +973,20 @@ def handle_value(x: Any, path: tuple[str, ...]) -> None: handle_value(val, path=path + (key,)) # recursively parse the dictionary of this object - self_dict = self.dict() + self_dict = self.model_dump(round_trip=True) # if an include_only string was provided, only look at that subset of the dict - if starting_path: - for key in starting_path: - self_dict = self_dict[key] + for key in starting_path: + self_dict = self_dict[key] handle_value(self_dict, path=starting_path) # convert the resulting field_mapping to an autograd-traced dictionary return dict_ag(field_mapping) - def insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Tidy3dBaseModel: + def insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self: """Recursively insert a map of paths to autograd-traced fields into a copy of this obj.""" - - self_dict = self.dict() + self_dict = self.model_dump(round_trip=True) def insert_value(x, path: tuple[str, ...], sub_dict: dict): """Insert a value into the path into a dictionary.""" @@ -1031,9 +1010,9 @@ def insert_value(x, path: tuple[str, ...], sub_dict: dict): for path, value in field_mapping.items(): insert_value(value, path=path, sub_dict=self_dict) - return self.parse_obj(self_dict) + return self.__class__.model_validate(self_dict) - def to_static(self) -> Tidy3dBaseModel: + def to_static(self) -> 
Self: """Version of object with all autograd-traced fields removed.""" # get dictionary of all traced fields @@ -1049,124 +1028,115 @@ def to_static(self) -> Tidy3dBaseModel: # insert the static values into a copy of self return self.insert_traced_fields(field_mapping_static) - @classmethod - def add_type_field(cls) -> None: - """Automatically place "type" field with model name in the model field dictionary.""" - - value = cls.__name__ - annotation = Literal[value] - - tag_field = ModelField.infer( - name=TYPE_TAG_STR, - value=value, - annotation=annotation, - class_validators=None, - config=cls.__config__, - ) - cls.__fields__[TYPE_TAG_STR] = tag_field - @classmethod def generate_docstring(cls) -> str: - """Generates a docstring for a Tidy3D mode and saves it to the __doc__ of the class.""" + """Generates a docstring for a Tidy3D model.""" - # store the docstring in here doc = "" - # if the model already has a docstring, get the first lines and save the rest + # keep any pre-existing class description original_docstrings = [] if cls.__doc__: original_docstrings = cls.__doc__.split("\n\n") - class_description = original_docstrings.pop(0) - doc += class_description + doc += original_docstrings.pop(0) original_docstrings = "\n\n".join(original_docstrings) - # create the list of parameters (arguments) for the model + # parameters doc += "\n\n Parameters\n ----------\n" - for field_name, field in cls.__fields__.items(): - # ignore the type tag + for field_name, field in cls.model_fields.items(): # v2 if field_name == TYPE_TAG_STR: continue - # get data type - data_type = field._type_display() + # type + ann = getattr(field, "annotation", None) + data_type = _fmt_ann_literal(ann) + + # default / default_factory + default_val = ( + f"{field.default_factory.__name__}()" + if field.default_factory is not None + else field.get_default(call_default_factory=False) + ) - # get default values - default_val = field.get_default() - if "=" in str(default_val): - # handle cases 
where default values are pydantic models - default_val = f"{default_val.__class__.__name__}({default_val})" - default_val = (", ").join(default_val.split(" ")) + if isinstance(default_val, BaseModel) or ( + "=" in str(default_val) if default_val is not None else False + ): + default_val = ", ".join( + str(f"{default_val.__class__.__name__}({default_val})").split(" ") + ) - # make first line: name : type = default - default_str = "" if field.required else f" = {default_val}" + default_str = "" if field.is_required() else f" = {default_val}" doc += f" {field_name} : {data_type}{default_str}\n" - # get field metadata - field_info = field.field_info - doc += " " - - # add units (if present) - units = field_info.extra.get("units") + parts = [] + + # units + units = None + extra = getattr(field, "json_schema_extra", None) + if isinstance(extra, dict): + units = extra.get("units") + if units is None and hasattr(field, "metadata"): + for meta in field.metadata: + if isinstance(meta, dict) and "units" in meta: + units = meta["units"] + break if units is not None: - if isinstance(units, (tuple, list)): - unitstr = "(" - for unit in units: - unitstr += str(unit) - unitstr += ", " - unitstr = unitstr[:-2] - unitstr += ")" - else: - unitstr = units - doc += f"[units = {unitstr}]. 
" + unitstr = ( + f"({', '.join(str(u) for u in units)})" + if isinstance(units, (list, tuple)) + else str(units) + ) + parts.append(f"[units = {unitstr}].") - # add description - description_str = field_info.description - if description_str is not None: - doc += f"{description_str}\n" + # description + desc = getattr(field, "description", None) + if desc: + parts.append(desc) - # add in remaining things in the docs - if original_docstrings: - doc += "\n" - doc += original_docstrings + if parts: + doc += " " + " ".join(parts) + "\n" + if original_docstrings: + doc += "\n" + original_docstrings doc += "\n" - cls.__doc__ = doc - - def get_submodels_by_hash(self) -> Dict[int, List[Union[str, Tuple[str, int]]]]: - """Return a dictionary of this object's sub-models indexed by their hash values.""" - fields = {} - for key in self.__fields__: - field = getattr(self, key) - - if isinstance(field, Tidy3dBaseModel): - hash_ = hash(field) - if hash_ not in fields: - fields[hash_] = [] - fields[hash_].append(key) - - # Do we need to consider np.ndarray here? - elif isinstance(field, (list, tuple, np.ndarray)): - for index, sub_field in enumerate(field): - if isinstance(sub_field, Tidy3dBaseModel): - hash_ = hash(sub_field) - if hash_ not in fields: - fields[hash_] = [] - fields[hash_].append((key, index)) - - elif isinstance(field, dict): - for index, sub_field in field.items(): - if isinstance(sub_field, Tidy3dBaseModel): - hash_ = hash(sub_field) - if hash_ not in fields: - fields[hash_] = [] - fields[hash_].append((key, index)) - - return fields + + return doc + + def get_submodels_by_hash(self) -> dict[int, list[Union[str, tuple[str, int]]]]: + """ + Return a mapping ``{hash(submodel): [field_path, ...]}`` for every + nested ``Tidy3dBaseModel`` inside this model. 
+ """ + out = defaultdict(list) + + for name in self.model_fields: + value = getattr(self, name) + + if isinstance(value, Tidy3dBaseModel): + out[hash(value)].append(name) + continue + + if isinstance(value, (list, tuple)): + for idx, item in enumerate(value): + if isinstance(item, Tidy3dBaseModel): + out[hash(item)].append((name, idx)) + + elif isinstance(value, np.ndarray): + for idx, item in enumerate(value.flat): + if isinstance(item, Tidy3dBaseModel): + out[hash(item)].append((name, idx)) + + elif isinstance(value, dict): + for k, item in value.items(): + if isinstance(item, Tidy3dBaseModel): + out[hash(item)].append((name, k)) + + return dict(out) @staticmethod def _scientific_notation( min_val: float, max_val: float, min_digits: int = 4 - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """ Convert numbers to scientific notation, displaying only digits up to the point of difference, with a minimum number of significant digits specified by `min_digits`. @@ -1199,3 +1169,30 @@ def to_sci(value: float, exponent: int, precision: int) -> str: sci_max = to_sci(max_val, common_exponent, precision) return sci_min, sci_max + + def __rich_repr__(self): + """How to pretty-print instances of ``Tidy3dBaseModel``.""" + for name in self.model_fields: + value = getattr(self, name) + + # don't print the type field we add to the models + if name == "type": + continue + + # skip `attrs` if it's an empty dictionary + if name == "attrs" and isinstance(value, dict) and not value: + continue + + yield name, value + + def __str__(self) -> str: + """Return a pretty-printed string representation of the model.""" + from io import StringIO + + from rich.console import Console + + sio = StringIO() + console = Console(file=sio) + console.print(self) + output = sio.getvalue() + return output.rstrip("\n") diff --git a/tidy3d/components/base_sim/data/monitor_data.py b/tidy3d/components/base_sim/data/monitor_data.py index 86d2f0717b..6d8ed6f52d 100644 --- 
a/tidy3d/components/base_sim/data/monitor_data.py +++ b/tidy3d/components/base_sim/data/monitor_data.py @@ -4,7 +4,7 @@ from abc import ABC -import pydantic.v1 as pd +from pydantic import Field from ...data.dataset import Dataset from ..monitor import AbstractMonitor @@ -15,8 +15,7 @@ class AbstractMonitorData(Dataset, ABC): :class:`AbstractMonitor`. """ - monitor: AbstractMonitor = pd.Field( - ..., + monitor: AbstractMonitor = Field( title="Monitor", description="Monitor associated with the data.", ) diff --git a/tidy3d/components/base_sim/data/sim_data.py b/tidy3d/components/base_sim/data/sim_data.py index 88644a26c6..34c9eb8f9a 100644 --- a/tidy3d/components/base_sim/data/sim_data.py +++ b/tidy3d/components/base_sim/data/sim_data.py @@ -1,16 +1,14 @@ """Abstract base for simulation data structures.""" -from __future__ import annotations - from abc import ABC -from typing import Dict, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field, field_validator, model_validator from ....exceptions import DataError, Tidy3dKeyError, ValidationError -from ...base import Tidy3dBaseModel, skip_if_fields_missing +from ...base import Tidy3dBaseModel from ...data.utils import UnstructuredGridDatasetType from ...types import FieldVal from ..simulation import AbstractSimulation @@ -22,20 +20,18 @@ class AbstractSimulationData(Tidy3dBaseModel, ABC): a :class:`AbstractSimulation`. """ - simulation: AbstractSimulation = pd.Field( - ..., + simulation: AbstractSimulation = Field( title="Simulation", description="Original :class:`AbstractSimulation` associated with the data.", ) - data: Tuple[AbstractMonitorData, ...] = pd.Field( - ..., + data: tuple[AbstractMonitorData, ...] 
= Field( title="Monitor Data", description="List of :class:`AbstractMonitorData` instances " "associated with the monitors of the original :class:`AbstractSimulation`.", ) - log: str = pd.Field( + log: Optional[str] = Field( None, title="Solver Log", description="A string containing the log information from the simulation run.", @@ -47,19 +43,18 @@ def __getitem__(self, monitor_name: str) -> AbstractMonitorData: return monitor_data.symmetry_expanded_copy @property - def monitor_data(self) -> Dict[str, AbstractMonitorData]: + def monitor_data(self) -> dict[str, AbstractMonitorData]: """Dictionary mapping monitor name to its associated :class:`AbstractMonitorData`.""" return {monitor_data.monitor.name: monitor_data for monitor_data in self.data} - @pd.validator("data", always=True) - @skip_if_fields_missing(["simulation"]) - def data_monitors_match_sim(cls, val, values): + @model_validator(mode="after") + def data_monitors_match_sim(self): """Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in ``.simulation``. """ - sim = values.get("simulation") + sim = self.simulation - for mnt_data in val: + for mnt_data in self.data: try: monitor_name = mnt_data.monitor.name sim.get_monitor_by_name(monitor_name) @@ -68,11 +63,10 @@ def data_monitors_match_sim(cls, val, values): f"Data with monitor name '{monitor_name}' supplied " f"but not found in the original '{sim.type}'." ) from exc - return val + return self - @pd.validator("data", always=True) - @skip_if_fields_missing(["simulation"]) - def validate_no_ambiguity(cls, val, values): + @field_validator("data") + def validate_no_ambiguity(val): """Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different monitors in ``.simulation``. 
""" diff --git a/tidy3d/components/base_sim/monitor.py b/tidy3d/components/base_sim/monitor.py index e5355e5679..f154511e60 100644 --- a/tidy3d/components/base_sim/monitor.py +++ b/tidy3d/components/base_sim/monitor.py @@ -1,14 +1,13 @@ """Abstract bases for classes that define how data is recorded from simulation.""" from abc import ABC, abstractmethod -from typing import Tuple import numpy as np -import pydantic.v1 as pd +from pydantic import Field from ..base import cached_property from ..geometry.base import Box -from ..types import ArrayFloat1D, Axis, Numpy +from ..types import ArrayFloat1D, Axis from ..validators import _warn_unsupported_traced_argument from ..viz import PlotParams, plot_params_monitor @@ -16,8 +15,7 @@ class AbstractMonitor(Box, ABC): """Abstract base class for steady-state monitors.""" - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for monitor.", min_length=1, @@ -59,20 +57,20 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: Number of bytes to be stored in monitor. """ - def downsample(self, arr: Numpy, axis: Axis) -> Numpy: + def downsample(self, arr: np.ndarray, axis: Axis) -> np.ndarray: """Downsample a 1D array making sure to keep the first and last entries, based on the spatial interval defined for the ``axis``. Parameters ---------- - arr : Numpy + arr : np.ndarray A 1D array of arbitrary type. axis : Axis Axis for which to select the interval_space defined for the monitor. Returns ------- - Numpy + np.ndarray Downsampled array. 
""" @@ -88,7 +86,7 @@ def downsample(self, arr: Numpy, axis: Axis) -> Numpy: inds = np.append(inds, size - 1) return arr[inds] - def downsampled_num_cells(self, num_cells: Tuple[int, int, int]) -> Tuple[int, int, int]: + def downsampled_num_cells(self, num_cells: tuple[int, int, int]) -> tuple[int, int, int]: """Given a tuple of the number of cells spanned by the monitor along each dimension, return the number of cells one would have after downsampling based on ``interval_space``. """ diff --git a/tidy3d/components/base_sim/simulation.py b/tidy3d/components/base_sim/simulation.py index 0d3dfa228e..3af7f84c85 100644 --- a/tidy3d/components/base_sim/simulation.py +++ b/tidy3d/components/base_sim/simulation.py @@ -3,15 +3,15 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import Literal, Optional import autograd.numpy as anp -import pydantic.v1 as pd +from pydantic import Field, model_validator from ...exceptions import Tidy3dKeyError from ...log import log from ...version import __version__ -from ..base import cached_property, skip_if_fields_missing +from ..base import cached_property from ..geometry.base import Box from ..medium import Medium, MediumType3D from ..scene import Scene @@ -22,20 +22,15 @@ assert_objects_in_sim_bounds, assert_unique_names, ) -from ..viz import ( - PlotParams, - add_ax_if_none, - equal_aspect, - plot_params_symmetry, -) +from ..viz import PlotParams, add_ax_if_none, equal_aspect, plot_params_symmetry from .monitor import AbstractMonitor class AbstractSimulation(Box, ABC): """Base class for simulation classes of different solvers.""" - medium: MediumType3D = pd.Field( - Medium(), + medium: MediumType3D = Field( + default_factory=Medium, title="Background Medium", description="Background medium of simulation, defaults to vacuum if not specified.", discriminator=TYPE_TAG_STR, @@ -44,7 +39,7 @@ class AbstractSimulation(Box, ABC): Background medium of simulation, 
defaults to vacuum if not specified. """ - structures: Tuple[Structure, ...] = pd.Field( + structures: tuple[Structure, ...] = Field( (), title="Structures", description="Tuple of structures present in simulation. " @@ -73,7 +68,7 @@ class AbstractSimulation(Box, ABC): ) """ - symmetry: Tuple[Symmetry, Symmetry, Symmetry] = pd.Field( + symmetry: tuple[Symmetry, Symmetry, Symmetry] = Field( (0, 0, 0), title="Symmetries", description="Tuple of integers defining reflection symmetry across a plane " @@ -81,37 +76,37 @@ class AbstractSimulation(Box, ABC): "at the simulation center of each axis, respectively. ", ) - sources: Tuple[None, ...] = pd.Field( + sources: tuple[None, ...] = Field( (), title="Sources", description="Sources in the simulation.", ) - boundary_spec: None = pd.Field( + boundary_spec: Literal[None] = Field( None, title="Boundaries", description="Specification of boundary conditions.", ) - monitors: Tuple[None, ...] = pd.Field( + monitors: tuple[None, ...] = Field( (), title="Monitors", description="Monitors in the simulation. 
", ) - grid_spec: None = pd.Field( + grid_spec: Literal[None] = Field( None, title="Grid Specification", description="Specifications for the simulation grid.", ) - version: str = pd.Field( + version: str = Field( __version__, title="Version", description="String specifying the front end version number.", ) - plot_length_units: Optional[LengthUnit] = pd.Field( + plot_length_units: Optional[LengthUnit] = Field( "μm", title="Plot Units", description="When set to a supported ``LengthUnit``, " @@ -121,17 +116,17 @@ class AbstractSimulation(Box, ABC): """ Validating setup """ - @pd.root_validator(pre=True) - def _update_simulation(cls, values): + @model_validator(mode="before") + @classmethod + def _update_simulation(cls, data): """Update the simulation if it is an earlier version.""" - # dummy upgrade of version number # this should be overriden by each simulation class if needed - current_version = values.get("version") + current_version = data.get("version") if current_version != __version__ and current_version is not None: log.warning(f"updating {cls.__name__} from {current_version} to {__version__}") - values["version"] = __version__ - return values + data["version"] = __version__ + return data # make sure all names are unique _unique_monitor_names = assert_unique_names("monitors") @@ -144,20 +139,19 @@ def _update_simulation(cls, values): _warn_traced_center = _warn_unsupported_traced_argument("center") _warn_traced_size = _warn_unsupported_traced_argument("size") - @pd.validator("structures", always=True) - @skip_if_fields_missing(["size", "center"]) - def _structures_not_at_edges(cls, val, values): + @model_validator(mode="after") + def _structures_not_at_edges(self): """Warn if any structures lie at the simulation boundaries.""" - if val is None: - return val + if self.structures is None: + return self - sim_box = Box(size=values.get("size"), center=values.get("center")) + sim_box = Box(size=self.size, center=self.center) sim_bound_min, sim_bound_max = 
sim_box.bounds sim_bounds = list(sim_bound_min) + list(sim_bound_max) with log as consolidated_logger: - for istruct, structure in enumerate(val): + for istruct, structure in enumerate(self.structures): struct_bound_min, struct_bound_max = structure.geometry.bounds struct_bounds = list(struct_bound_min) + list(struct_bound_max) @@ -172,13 +166,14 @@ def _structures_not_at_edges(cls, val, values): ) continue - return val + return self """ Post-init validators """ - def _post_init_validators(self) -> None: + @property + def _post_init_validators(self): """Call validators taking z`self` that get run after init.""" - _ = self.scene + return (lambda: self.scene,) def validate_pre_upload(self) -> None: """Validate the fully initialized simulation is ok for upload to our servers.""" @@ -233,8 +228,8 @@ def plot( ax: Ax = None, source_alpha: float = None, monitor_alpha: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, fill_structures: bool = True, **patch_kwargs, ) -> Ax: @@ -254,9 +249,9 @@ def plot( Opacity of the monitors. If ``None``, uses Tidy3d default. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill_structures : bool = True Whether to fill structures with color or just draw outlines. 
@@ -295,8 +290,8 @@ def plot_sources( x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, alpha: float = None, ax: Ax = None, ) -> Ax: @@ -310,9 +305,9 @@ def plot_sources( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. alpha : float = None Opacity of the sources, If ``None`` uses Tidy3d default. @@ -343,8 +338,8 @@ def plot_monitors( x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, alpha: float = None, ax: Ax = None, ) -> Ax: @@ -358,9 +353,9 @@ def plot_monitors( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. alpha : float = None Opacity of the sources, If ``None`` uses Tidy3d default. 
@@ -391,8 +386,8 @@ def plot_symmetries( x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ax: Ax = None, ) -> Ax: """Plot each of simulation's symmetries on a plane defined by one nonzero x,y,z coordinate. @@ -405,9 +400,9 @@ def plot_symmetries( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. @@ -501,8 +496,8 @@ def plot_structures( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, fill: bool = True, ) -> Ax: """Plot each of simulation's structures on a plane defined by one nonzero x,y,z coordinate. @@ -517,9 +512,9 @@ def plot_structures( position of plane in z direction, only one of x, y, z must be specified to define plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill : bool = True Whether to fill structures with color or just draw outlines. 
@@ -549,8 +544,8 @@ def plot_structures_eps( cbar: bool = True, reverse: bool = False, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of simulation's structures on a plane defined by one nonzero x,y,z coordinate. The permittivity is plotted in grayscale based on its value at the specified frequency. @@ -576,9 +571,9 @@ def plot_structures_eps( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -615,8 +610,8 @@ def plot_structures_heat_conductivity( cbar: bool = True, reverse: bool = False, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of simulation's structures on a plane defined by one nonzero x,y,z coordinate. The permittivity is plotted in grayscale based on its value at the specified frequency. @@ -642,9 +637,9 @@ def plot_structures_heat_conductivity( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
Returns diff --git a/tidy3d/components/base_sim/source.py b/tidy3d/components/base_sim/source.py index cdaff53d66..eac7c55f4e 100644 --- a/tidy3d/components/base_sim/source.py +++ b/tidy3d/components/base_sim/source.py @@ -1,10 +1,9 @@ """Abstract base for classes that define simulation sources.""" -from __future__ import annotations - from abc import ABC, abstractmethod +from typing import Optional -import pydantic.v1 as pydantic +from pydantic import Field from ..base import Tidy3dBaseModel from ..validators import validate_name_str @@ -14,7 +13,11 @@ class AbstractSource(Tidy3dBaseModel, ABC): """Abstract base class for all sources.""" - name: str = pydantic.Field(None, title="Name", description="Optional name for the source.") + name: Optional[str] = Field( + None, + title="Name", + description="Optional name for the source.", + ) @abstractmethod def plot_params(self) -> PlotParams: diff --git a/tidy3d/components/bc_placement.py b/tidy3d/components/bc_placement.py index c59c11962e..3244c7b769 100644 --- a/tidy3d/components/bc_placement.py +++ b/tidy3d/components/bc_placement.py @@ -1,15 +1,13 @@ """Defines placements for boundary conditions.""" -from __future__ import annotations - from abc import ABC -from typing import Tuple, Union +from typing import Union -import pydantic.v1 as pd +from pydantic import Field, field_validator from ..exceptions import SetupError from .base import Tidy3dBaseModel -from .types import BoxSurface +from .types import BoxSurface, discriminated_union class AbstractBCPlacement(ABC, Tidy3dBaseModel): @@ -24,7 +22,7 @@ class StructureBoundary(AbstractBCPlacement): >>> bc_placement = StructureBoundary(structure="box") """ - structure: str = pd.Field( + structure: str = Field( title="Structure Name", description="Name of the structure.", ) @@ -38,13 +36,13 @@ class StructureStructureInterface(AbstractBCPlacement): >>> bc_placement = StructureStructureInterface(structures=["box", "sphere"]) """ - structures: Tuple[str, str] = pd.Field( + 
structures: tuple[str, str] = Field( title="Structures", description="Names of two structures.", ) - @pd.validator("structures", always=True) - def unique_names(cls, val): + @field_validator("structures") + def unique_names(val): """Error if the same structure is provided twice""" if val[0] == val[1]: raise SetupError( @@ -61,13 +59,13 @@ class MediumMediumInterface(AbstractBCPlacement): >>> bc_placement = MediumMediumInterface(mediums=["dieletric", "metal"]) """ - mediums: Tuple[str, str] = pd.Field( + mediums: tuple[str, str] = Field( title="Mediums", description="Names of two mediums.", ) - @pd.validator("mediums", always=True) - def unique_names(cls, val): + @field_validator("mediums") + def unique_names(val): """Error if the same structure is provided twice""" if val[0] == val[1]: raise SetupError("The same medium is provided twice in 'MediumMediumInterface'.") @@ -82,7 +80,7 @@ class SimulationBoundary(AbstractBCPlacement): >>> bc_placement = SimulationBoundary(surfaces=["x-", "x+"]) """ - surfaces: Tuple[BoxSurface, ...] = pd.Field( + surfaces: tuple[BoxSurface, ...] = Field( ("x-", "x+", "y-", "y+", "z-", "z+"), title="Surfaces", description="Surfaces of simulation domain where to apply boundary conditions.", @@ -97,22 +95,24 @@ class StructureSimulationBoundary(AbstractBCPlacement): >>> bc_placement = StructureSimulationBoundary(structure="box", surfaces=["y-", "y+"]) """ - structure: str = pd.Field( + structure: str = Field( title="Structure Name", description="Name of the structure.", ) - surfaces: Tuple[BoxSurface, ...] = pd.Field( + surfaces: tuple[BoxSurface, ...] 
= Field( ("x-", "x+", "y-", "y+", "z-", "z+"), title="Surfaces", description="Surfaces of simulation domain where to apply boundary conditions.", ) -BCPlacementType = Union[ - StructureBoundary, - StructureStructureInterface, - MediumMediumInterface, - SimulationBoundary, - StructureSimulationBoundary, -] +BCPlacementType = discriminated_union( + Union[ + StructureBoundary, + StructureStructureInterface, + MediumMediumInterface, + SimulationBoundary, + StructureSimulationBoundary, + ] +) diff --git a/tidy3d/components/beam.py b/tidy3d/components/beam.py index 4a6f632466..6f877d8837 100644 --- a/tidy3d/components/beam.py +++ b/tidy3d/components/beam.py @@ -2,10 +2,10 @@ astigmatic Gaussian beam.""" from abc import abstractmethod -from typing import Optional, Tuple, Union +from typing import Optional, Union import autograd.numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from ..constants import C_0, ETA_0, HERTZ, MICROMETER, RADIAN from .base import cached_property @@ -16,7 +16,7 @@ from .medium import Medium, MediumType from .monitor import FieldMonitor from .source.field import FixedAngleSpec, FixedInPlaneKSpec -from .types import TYPE_TAG_STR, Direction, FreqArray, Literal, Numpy +from .types import TYPE_TAG_STR, Direction, FreqArray, Literal from .validators import assert_plane DEFAULT_RESOLUTION = 200 @@ -25,7 +25,7 @@ class BeamProfile(Box): """Base class for handling analytic beams.""" - resolution: float = pd.Field( + resolution: float = Field( DEFAULT_RESOLUTION, title="Sampling resolution", description="Sampling resolution in the tangential directions of the beam (defines a " @@ -33,27 +33,26 @@ class BeamProfile(Box): units=MICROMETER, ) - freqs: FreqArray = pd.Field( - ..., + freqs: FreqArray = Field( title="Frequencies", description="List of frequencies at which the beam is sampled.", units=HERTZ, ) - background_medium: MediumType = pd.Field( - Medium(), + background_medium: MediumType = Field( + default_factory=Medium, 
title="Background Medium", description="Background medium in which the beam is embedded.", ) - angle_theta: float = pd.Field( + angle_theta: float = Field( 0.0, title="Polar Angle", description="Polar angle of the propagation axis from the normal axis.", units=RADIAN, ) - angle_phi: float = pd.Field( + angle_phi: float = Field( 0.0, title="Azimuth Angle", description="Azimuth angle of the propagation axis in the plane orthogonal to the " @@ -61,7 +60,7 @@ class BeamProfile(Box): units=RADIAN, ) - pol_angle: float = pd.Field( + pol_angle: float = Field( 0.0, title="Polarization Angle", description="Specifies the angle between the electric field polarization of the " @@ -75,7 +74,7 @@ class BeamProfile(Box): units=RADIAN, ) - direction: Direction = pd.Field( + direction: Direction = Field( "+", title="Direction", description="Specifies propagation in the positive or negative direction of the normal " @@ -129,7 +128,7 @@ def field_data(self) -> FieldData: return data_raw.updated_copy(**fields_norm) - def _field_data_on_grid(self, grid: Grid, background_n: Numpy, colocate=True) -> dict: + def _field_data_on_grid(self, grid: Grid, background_n: np.ndarray, colocate=True) -> dict: """Compute the field data for each field component on a grid for the beam. A dictionary of the scalar field data arrays is returned, not yet packaged as ``FieldData``. """ @@ -163,15 +162,15 @@ def _field_data_on_grid(self, grid: Grid, background_n: Numpy, colocate=True) -> return scalar_fields @abstractmethod - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray: """Scalar field corresponding to the analytic beam in coordinate system such that the propagation direction is z and the ``E``-field is entirely ``x``-polarized. 
The field is computed on an unstructured array ``points`` of shape ``(3, ...)``.""" pass def analytic_beam_z_normal( - self, points: Numpy, background_n: float, field: Literal["E", "H"] - ) -> Numpy: + self, points: np.ndarray, background_n: float, field: Literal["E", "H"] + ) -> np.ndarray: """Analytic beam with all the beam parameters but assuming ``z`` as the normal axis.""" # Add a frequency dimension to points @@ -211,12 +210,12 @@ def analytic_beam_z_normal( def analytic_beam( self, - x: Numpy, - y: Numpy, - z: Numpy, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, background_n: float, field: Literal["E", "H"], - ) -> Numpy: + ) -> np.ndarray: """Sample the analytic beam fields on a cartesian grid of points in x, y, z.""" # Make a meshgrid @@ -240,13 +239,15 @@ def analytic_beam( # Reshape to (3, Nx, Ny, Nz, num_freqs) return np.reshape(field_vals, (3, Nx, Ny, Nz, len(self.freqs))) - def _rotate_points_z(self, points: Numpy, background_n: Numpy) -> Numpy: + def _rotate_points_z(self, points: np.ndarray, background_n: np.ndarray) -> np.ndarray: """Rotate points to new coordinates where z is the propagation axis.""" points_prop_z = self.rotate_points(points, [0, 0, 1], -self.angle_phi) points_prop_z = self.rotate_points(points_prop_z, [0, 1, 0], -self.angle_theta) return points_prop_z - def _inverse_rotate_field_vals_z(self, field_vals: Numpy, background_n: Numpy) -> Numpy: + def _inverse_rotate_field_vals_z( + self, field_vals: np.ndarray, background_n: np.ndarray + ) -> np.ndarray: """Rotate field values from coordinates where z is the propagation axis to angled coordinates.""" field_vals = self.rotate_points(field_vals, [0, 1, 0], self.angle_theta) @@ -261,14 +262,14 @@ class PlaneWaveBeamProfile(BeamProfile): See also :class:`.PlaneWave`. 
""" - angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = pd.Field( + angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = Field( FixedAngleSpec(), title="Angular Dependence Specification", description="Specification of plane wave propagation direction dependence on wavelength.", discriminator=TYPE_TAG_STR, ) - as_fixed_angle_source: bool = pd.Field( + as_fixed_angle_source: bool = Field( False, title="Fixed Angle Flag", description="Fixed angle flag. Only used internally when computing source beams for " @@ -276,7 +277,7 @@ class PlaneWaveBeamProfile(BeamProfile): "switch between waves with fixed angle and fixed in-plane k.", ) - angle_theta_frequency: Optional[float] = pd.Field( + angle_theta_frequency: Optional[float] = Field( None, title="Frequency at Which Angle Theta is Defined", description="Frequency for which ``angle_theta`` is set. This only has an effect for " @@ -296,7 +297,7 @@ def in_plane_k(self, background_n: float): k_in_plane = k0.real * np.sin(self.angle_theta) return [k_in_plane * np.cos(self.angle_phi), k_in_plane * np.sin(self.angle_phi)] - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray: """Scalar field for plane wave. Scalar field corresponding to the analytic beam in coordinate system such that the propagation direction is z and the ``E``-field is entirely ``x``-polarized. 
The field is @@ -311,14 +312,14 @@ def scalar_field(self, points: Numpy, background_n: float) -> Numpy: kz *= np.cos(self.angle_theta) return np.exp(1j * points[2] * kz) - def _angle_theta_actual(self, background_n: Numpy) -> Numpy: + def _angle_theta_actual(self, background_n: np.ndarray) -> np.ndarray: """Compute the frequency-dependent actual propagation angle theta.""" k0 = 2 * np.pi * np.array(self.freqs) / C_0 * background_n kx, ky = self.in_plane_k(background_n) k_perp = np.sqrt(kx**2 + ky**2) return np.real(np.arcsin(k_perp / k0)) - def _rotate_points_z(self, points: Numpy, background_n: Numpy) -> Numpy: + def _rotate_points_z(self, points: np.ndarray, background_n: np.ndarray) -> np.ndarray: """Rotate points to new coordinates where z is the propagation axis.""" if self.as_fixed_angle_source: # For fixed-angle, we do not rotate the points @@ -332,7 +333,9 @@ def _rotate_points_z(self, points: Numpy, background_n: Numpy) -> Numpy: return points return super()._rotate_points_z(points, background_n) - def _inverse_rotate_field_vals_z(self, field_vals: Numpy, background_n: Numpy) -> Numpy: + def _inverse_rotate_field_vals_z( + self, field_vals: np.ndarray, background_n: np.ndarray + ) -> np.ndarray: """Rotate field values from coordinates where z is the propagation axis to angled coordinates. Special handling is needed if fixed in-plane k wave.""" if isinstance(self.angular_spec, FixedInPlaneKSpec): @@ -354,14 +357,14 @@ class GaussianBeamProfile(BeamProfile): See also :class:`.GaussianBeam`. """ - waist_radius: pd.PositiveFloat = pd.Field( + waist_radius: PositiveFloat = Field( 1.0, title="Waist Radius", description="Radius of the beam at the waist.", units=MICROMETER, ) - waist_distance: float = pd.Field( + waist_distance: float = Field( 0.0, title="Waist Distance", description="Distance from the beam waist along the propagation direction. 
" @@ -373,14 +376,16 @@ class GaussianBeamProfile(BeamProfile): units=MICROMETER, ) - def beam_params(self, z: Numpy, k0: Numpy) -> Tuple[Numpy, Numpy, Numpy]: + def beam_params( + self, z: np.ndarray, k0: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Compute the parameters needed to evaluate a Gaussian beam at z. Parameters ---------- - z : Numpy + z : np.ndarray Axial distance from the beam center. - k0 : Numpy + k0 : np.ndarray Wave vector magnitude. """ @@ -395,7 +400,7 @@ def beam_params(self, z: Numpy, k0: Numpy) -> Tuple[Numpy, Numpy, Numpy]: psi_g = np.arctan((z + z_0) / z_r) - np.arctan(z_0 / z_r) return w_z, inv_r_z, psi_g - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray: """Scalar field for Gaussian beam. Scalar field corresponding to the analytic beam in coordinate system such that the propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is @@ -420,14 +425,14 @@ class AstigmaticGaussianBeamProfile(BeamProfile): See also :class:`.AstigmaticGaussianBeam`. """ - waist_sizes: Tuple[pd.PositiveFloat, pd.PositiveFloat] = pd.Field( + waist_sizes: tuple[PositiveFloat, PositiveFloat] = Field( (1.0, 1.0), title="Waist sizes", description="Size of the beam at the waist in the local x and y directions.", units=MICROMETER, ) - waist_distances: Tuple[float, float] = pd.Field( + waist_distances: tuple[float, float] = Field( (0.0, 0.0), title="Waist distances", description="Distance to the beam waist along the propagation direction " @@ -439,14 +444,16 @@ class AstigmaticGaussianBeamProfile(BeamProfile): units=MICROMETER, ) - def beam_params(self, z: Numpy, k0: Numpy) -> Tuple[Numpy, Numpy, Numpy, Numpy]: + def beam_params( + self, z: np.ndarray, k0: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Compute the parameters needed to evaluate an astigmatic Gaussian beam at z. 
Parameters ---------- - z : Numpy + z : np.ndarray Axial distance from the beam center. - k0 : Numpy + k0 : np.ndarray Wave vector magnitude. """ @@ -466,7 +473,7 @@ def beam_params(self, z: Numpy, k0: Numpy) -> Tuple[Numpy, Numpy, Numpy, Numpy]: return w_0, w_z, inv_r_z, psi_g - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: np.ndarray, background_n: float) -> np.ndarray: """ Scalar field for astigmatic Gaussian beam. Scalar field corresponding to the analytic beam in coordinate system such that the diff --git a/tidy3d/components/boundary.py b/tidy3d/components/boundary.py index 2e0acc62fa..f1fd5111e3 100644 --- a/tidy3d/components/boundary.py +++ b/tidy3d/components/boundary.py @@ -1,20 +1,19 @@ """Defines electromagnetic boundary conditions""" -from __future__ import annotations - from abc import ABC -from typing import List, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, NonNegativeInt, model_validator +from ..compat import Self from ..constants import EPSILON_0, MU_0, PML_SIGMA from ..exceptions import DataError, SetupError from ..log import log from .base import Tidy3dBaseModel, cached_property from .medium import Medium from .source.field import TFSF, GaussianBeam, ModeSource, PlaneWave -from .types import TYPE_TAG_STR, Axis, Complex +from .types import Axis, Complex, discriminated_union MIN_NUM_PML_LAYERS = 6 @@ -22,7 +21,11 @@ class BoundaryEdge(ABC, Tidy3dBaseModel): """Electromagnetic boundary condition at a domain edge.""" - name: str = pd.Field(None, title="Name", description="Optional unique name for boundary.") + name: Optional[str] = Field( + None, + title="Name", + description="Optional unique name for boundary.", + ) # PBC keyword @@ -69,8 +72,7 @@ class BlochBoundary(BoundaryEdge): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - bloch_vec: float = pd.Field( - 
..., + bloch_vec: float = Field( title="Normalized Bloch vector component", description="Normalized component of the Bloch vector in units of " "2 * pi / (size along dimension) in the background medium, " @@ -85,7 +87,7 @@ def bloch_phase(self) -> Complex: @classmethod def from_source( cls, source: BlochSourceType, domain_size: float, axis: Axis, medium: Medium = None - ) -> BlochBoundary: + ) -> Self: """Set the Bloch vector component based on a given angled source and its center frequency. Note that if a broadband angled source is used, only the frequency components near the center frequency will exhibit angled incidence at the expect angle. In this case, a @@ -173,20 +175,20 @@ class AbsorberParams(Tidy3dBaseModel): >>> params = AbsorberParams(sigma_order=3, sigma_min=0.0, sigma_max=1.5) """ - sigma_order: pd.NonNegativeInt = pd.Field( + sigma_order: NonNegativeInt = Field( 3, title="Sigma Order", description="Order of the polynomial describing the absorber profile (~dist^sigma_order).", ) - sigma_min: pd.NonNegativeFloat = pd.Field( + sigma_min: NonNegativeFloat = Field( 0.0, title="Sigma Minimum", description="Minimum value of the absorber conductivity.", units=PML_SIGMA, ) - sigma_max: pd.NonNegativeFloat = pd.Field( + sigma_max: NonNegativeFloat = Field( 1.5, title="Sigma Maximum", description="Maximum value of the absorber conductivity.", @@ -202,29 +204,29 @@ class PMLParams(AbsorberParams): >>> params = PMLParams(sigma_order=3, sigma_min=0.0, sigma_max=1.5, kappa_min=0.0) """ - kappa_order: pd.NonNegativeInt = pd.Field( + kappa_order: NonNegativeInt = Field( 3, title="Kappa Order", description="Order of the polynomial describing the PML kappa profile " "(kappa~dist^kappa_order).", ) - kappa_min: pd.NonNegativeFloat = pd.Field(0.0, title="Kappa Minimum", description="") + kappa_min: NonNegativeFloat = Field(0.0, title="Kappa Minimum") - kappa_max: pd.NonNegativeFloat = pd.Field(1.5, title="Kappa Maximum", description="") + kappa_max: NonNegativeFloat = 
Field(1.5, title="Kappa Maximum") - alpha_order: pd.NonNegativeInt = pd.Field( + alpha_order: NonNegativeInt = Field( 3, title="Alpha Order", description="Order of the polynomial describing the PML alpha profile " "(alpha~dist^alpha_order).", ) - alpha_min: pd.NonNegativeFloat = pd.Field( + alpha_min: NonNegativeFloat = Field( 0.0, title="Alpha Minimum", description="Minimum value of the PML alpha.", units=PML_SIGMA ) - alpha_max: pd.NonNegativeFloat = pd.Field( + alpha_max: NonNegativeFloat = Field( 1.5, title="Alpha Maximum", description="Maximum value of the PML alpha.", units=PML_SIGMA ) @@ -262,14 +264,12 @@ class PMLParams(AbsorberParams): class AbsorberSpec(BoundaryEdge): """Specifies the generic absorber properties along a single dimension.""" - num_layers: int = pd.Field( - ..., + num_layers: int = Field( title="Number of Layers", description="Number of layers of standard PML.", ge=MIN_NUM_PML_LAYERS, ) - parameters: AbsorberParams = pd.Field( - ..., + parameters: AbsorberParams = Field( title="Absorber Parameters", description="Parameters to fine tune the absorber profile and properties.", ) @@ -379,14 +379,14 @@ class PML(AbsorberSpec): """ - num_layers: int = pd.Field( + num_layers: int = Field( 12, title="Number of Layers", description="Number of layers of standard PML.", ge=MIN_NUM_PML_LAYERS, ) - parameters: PMLParams = pd.Field( + parameters: PMLParams = Field( DefaultPMLParameters, title="PML Parameters", description="Parameters of the complex frequency-shifted absorption poles.", @@ -417,14 +417,14 @@ class StablePML(AbsorberSpec): * `Introduction to perfectly matched layer (PML) tutorial `__ """ - num_layers: int = pd.Field( + num_layers: int = Field( 40, title="Number of Layers", description="Number of layers of 'stable' PML.", ge=MIN_NUM_PML_LAYERS, ) - parameters: PMLParams = pd.Field( + parameters: PMLParams = Field( DefaultStablePMLParameters, title="Stable PML Parameters", description="'Stable' parameters of the complex frequency-shifted 
absorption poles.", @@ -470,14 +470,14 @@ class Absorber(AbsorberSpec): * `How to troubleshoot a diverged FDTD simulation <../../notebooks/DivergedFDTDSimulation.html>`_ """ - num_layers: int = pd.Field( + num_layers: int = Field( 40, title="Number of Layers", description="Number of layers of absorber to add to + and - boundaries.", ge=MIN_NUM_PML_LAYERS, ) - parameters: AbsorberParams = pd.Field( + parameters: AbsorberParams = Field( DefaultAbsorberParameters, title="Absorber Parameters", description="Adiabatic absorber parameters.", @@ -492,9 +492,9 @@ class Absorber(AbsorberSpec): # types of boundaries that can be used in Simulation -BoundaryEdgeType = Union[ - Periodic, PECBoundary, PMCBoundary, PML, StablePML, Absorber, BlochBoundary -] +BoundaryEdgeType = discriminated_union( + Union[Periodic, PECBoundary, PMCBoundary, PML, StablePML, Absorber, BlochBoundary] +) class Boundary(Tidy3dBaseModel): @@ -525,67 +525,65 @@ class Boundary(Tidy3dBaseModel): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - plus: BoundaryEdgeType = pd.Field( - PML(), + plus: BoundaryEdgeType = Field( + default_factory=PML, title="Plus BC", description="Boundary condition on the plus side along a dimension.", - discriminator=TYPE_TAG_STR, ) - minus: BoundaryEdgeType = pd.Field( - PML(), + minus: BoundaryEdgeType = Field( + default_factory=PML, title="Minus BC", description="Boundary condition on the minus side along a dimension.", - discriminator=TYPE_TAG_STR, ) - @pd.root_validator(skip_on_failure=True) - def bloch_on_both_sides(cls, values): + @model_validator(mode="after") + def bloch_on_both_sides(self): """Error if a Bloch boundary is applied on only one side.""" - plus = values.get("plus") - minus = values.get("minus") - num_bloch = isinstance(plus, BlochBoundary) + isinstance(minus, BlochBoundary) + num_bloch = isinstance(self.plus, BlochBoundary) + isinstance(self.minus, BlochBoundary) if num_bloch == 1: raise SetupError( "Bloch boundaries 
must be applied either on both sides or on neither side." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def periodic_with_pml(cls, values): + @model_validator(mode="after") + def periodic_with_pml(self): """Error if PBC is specified with a PML.""" - plus = values.get("plus") - minus = values.get("minus") - num_pbc = isinstance(plus, Periodic) + isinstance(minus, Periodic) - num_pml = isinstance(plus, (PML, StablePML, Absorber)) + isinstance( - minus, (PML, StablePML, Absorber) + num_pbc = isinstance(self.plus, Periodic) + isinstance(self.minus, Periodic) + num_pml = isinstance(self.plus, (PML, StablePML, Absorber)) + isinstance( + self.minus, (PML, StablePML, Absorber) ) if num_pbc == 1 and num_pml == 1: - raise SetupError("Cannot have both PML and PBC along the same dimension.") - return values - - @pd.root_validator(skip_on_failure=True) - def periodic_with_pec_pmc(cls, values): - """If a PBC is specified along with PEC or PMC on the other side, manually set the PBC - to PEC or PMC so that no special treatment of halos is required.""" - plus = values.get("plus") - minus = values.get("minus") + raise SetupError("Cannot have both 'PML' and 'Periodic' along the same dimension.") + return self + @model_validator(mode="after") + def periodic_with_pec_pmc(self): + """ + If a PBC is specified along with PEC or PMC on the other side, manually set the PBC + to PEC or PMC so that no special treatment of halos is required. 
+ """ + plus, minus = self.plus, self.minus switched = False + if isinstance(minus, (PECBoundary, PMCBoundary)) and isinstance(plus, Periodic): plus = minus switched = True elif isinstance(plus, (PECBoundary, PMCBoundary)) and isinstance(minus, Periodic): minus = plus switched = True + if switched: - values.update({"plus": plus, "minus": minus}) + object.__setattr__(self, "plus", plus) + object.__setattr__(self, "minus", minus) log.warning( "A periodic boundary condition was specified on the opposite side of a perfect " "electric or magnetic conductor boundary. This periodic boundary condition will " "be replaced by the perfect electric or magnetic conductor across from it." ) - return values + + return self @classmethod def periodic(cls): @@ -676,7 +674,7 @@ def pmc(cls): return cls(plus=plus, minus=minus) @classmethod - def pml(cls, num_layers: pd.NonNegativeInt = 12, parameters: PMLParams = DefaultPMLParameters): + def pml(cls, num_layers: NonNegativeInt = 12, parameters: PMLParams = DefaultPMLParameters): """PML boundary specification on both sides along a dimension. Parameters @@ -696,7 +694,7 @@ def pml(cls, num_layers: pd.NonNegativeInt = 12, parameters: PMLParams = Default @classmethod def stable_pml( - cls, num_layers: pd.NonNegativeInt = 40, parameters: PMLParams = DefaultStablePMLParameters + cls, num_layers: NonNegativeInt = 40, parameters: PMLParams = DefaultStablePMLParameters ): """Stable PML boundary specification on both sides along a dimension. @@ -717,7 +715,7 @@ def stable_pml( @classmethod def absorber( - cls, num_layers: pd.NonNegativeInt = 40, parameters: PMLParams = DefaultAbsorberParameters + cls, num_layers: NonNegativeInt = 40, parameters: PMLParams = DefaultAbsorberParameters ): """Adiabatic absorber boundary specification on both sides along a dimension. 
@@ -772,24 +770,24 @@ class BoundarySpec(Tidy3dBaseModel): * `Using FDTD to Compute a Transmission Spectrum `__ """ - x: Boundary = pd.Field( - Boundary(), + x: Boundary = Field( + default_factory=Boundary, title="Boundary condition along x.", description="Boundary condition on the plus and minus sides along the x axis. " "If ``None``, periodic boundaries are applied. Default will change to PML in 2.0 " "so explicitly setting the boundaries is recommended.", ) - y: Boundary = pd.Field( - Boundary(), + y: Boundary = Field( + default_factory=Boundary, title="Boundary condition along y.", description="Boundary condition on the plus and minus sides along the y axis. " "If ``None``, periodic boundaries are applied. Default will change to PML in 2.0 " "so explicitly setting the boundaries is recommended.", ) - z: Boundary = pd.Field( - Boundary(), + z: Boundary = Field( + default_factory=Boundary, title="Boundary condition along z.", description="Boundary condition on the plus and minus sides along the z axis. " "If ``None``, periodic boundaries are applied. 
Default will change to PML in 2.0 " @@ -906,7 +904,7 @@ def all_sides(cls, boundary: BoundaryEdge): ) @cached_property - def to_list(self) -> List[Tuple[BoundaryEdgeType, BoundaryEdgeType]]: + def to_list(self) -> list[tuple[BoundaryEdgeType, BoundaryEdgeType]]: """Returns edge-wise boundary conditions along each dimension for internal use.""" return [ (self.x.minus, self.x.plus), @@ -915,7 +913,7 @@ def to_list(self) -> List[Tuple[BoundaryEdgeType, BoundaryEdgeType]]: ] @cached_property - def flipped_bloch_vecs(self) -> BoundarySpec: + def flipped_bloch_vecs(self) -> Self: """Return a copy of the instance where all Bloch vectors are multiplied by -1.""" bound_dims = dict(x=self.x.copy(), y=self.y.copy(), z=self.z.copy()) for dim_key, bound_dim in bound_dims.items(): diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index 9cabd11796..de48d32ff9 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -1,15 +1,16 @@ """Storing tidy3d data at it's most fundamental level as xr.DataArray objects""" -from __future__ import annotations - from abc import ABC -from typing import Any, Dict, List, Mapping, Union +from typing import Any, Mapping, Optional, Union import autograd.numpy as anp import h5py import numpy as np import xarray as xr from autograd.tracer import isbox +from pydantic.annotated_handlers import GetCoreSchemaHandler +from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue +from pydantic_core import core_schema from xarray.core import missing from xarray.core.indexes import PandasIndex from xarray.core.indexing import _outer_to_numpy_indexer @@ -68,7 +69,7 @@ class DataArray(xr.DataArray): # stores an ordered tuple of strings corresponding to the data dimensions _dims = () # stores a dictionary of attributes corresponding to the data values - _data_attrs: Dict[str, str] = {} + _data_attrs: dict[str, str] = {} def __init__(self, data, *args, **kwargs): # if data is a 
vanilla autograd box, convert to our box @@ -84,40 +85,116 @@ def __init__(self, data, *args, **kwargs): super().__init__(data, *args, **kwargs) @classmethod - def __get_validators__(cls): - """Validators that get run when :class:`.DataArray` objects are added to pydantic models.""" - yield cls.check_unloaded_data - yield cls.validate_dims - yield cls.assign_data_attrs - yield cls.assign_coord_attrs + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """Core schema definition for validation & serialization.""" + + def _initial_parser(value: Any) -> Self: + if isinstance(value, cls): + return value + + if isinstance(value, str) and value == cls.__name__: + raise DataError( + f"Trying to load '{cls.__name__}' from string placeholder '{value}' " + "but the actual data is missing. DataArrays are not typically stored " + "in JSON. Load from HDF5 or ensure the DataArray object is provided." + ) + + try: + instance = cls(value) + if not isinstance(instance, cls): + raise TypeError( + f"Constructor for {cls.__name__} returned unexpected type {type(instance)}" + ) + return instance + except Exception as e: + raise ValueError( + f"Could not construct '{cls.__name__}' from input of type '{type(value)}'. " + f"Ensure input is compatible with xarray.DataArray constructor. Original error: {e}" + ) from e + + validation_schema = core_schema.no_info_plain_validator_function(_initial_parser) + validation_schema = core_schema.no_info_after_validator_function( + cls._validate_dims, validation_schema + ) + validation_schema = core_schema.no_info_after_validator_function( + cls._assign_data_attrs, validation_schema + ) + validation_schema = core_schema.no_info_after_validator_function( + cls._assign_coord_attrs, validation_schema + ) + + def _serialize_to_name(instance: Self) -> str: + return type(instance).__name__ + + # serialization behavior: + # - for JSON ('json' mode), use the _serialize_to_name function. 
+ # - for Python ('python' mode), use Pydantic's default for the object type + serialization_schema = core_schema.plain_serializer_function_ser_schema( + _serialize_to_name, + return_schema=core_schema.str_schema(), + when_used="json", + ) + + return core_schema.json_or_python_schema( + python_schema=validation_schema, + json_schema=validation_schema, # Use same validation rules for JSON input + serialization=serialization_schema, + ) @classmethod - def check_unloaded_data(cls, val): - """If the data comes in as the raw data array string, raise a custom warning.""" - if isinstance(val, str) and val in DATA_ARRAY_MAP: - raise DataError( - f"Trying to load {cls.__name__} but the data is not present. " - "Note that data will not be saved to .json file. " - "use .hdf5 format instead if data present." - ) - return cls(val) + def __get_pydantic_json_schema__( + cls, core_schema_obj: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + """JSON schema definition (defines how it LOOKS in a schema, not the data).""" + json_schema = handler(core_schema_obj) + json_schema = handler.resolve_ref_schema(json_schema) + json_schema.update( + { + "type": "string", + "title": cls.__name__, + "description": ( + f"Placeholder for a '{cls.__name__}' object. Actual data is typically " + "serialized separately (e.g., via HDF5) and not embedded in JSON." 
+ ), + } + ) + return json_schema @classmethod - def validate_dims(cls, val): - """Make sure the dims are the same as _dims, then put them in the correct order.""" + def _validate_dims(cls, val: Self) -> Self: + """Make sure the dims are the same as ``_dims``, then put them in the correct order.""" if set(val.dims) != set(cls._dims): - raise ValueError(f"wrong dims, expected '{cls._dims}', got '{val.dims}'") - return val.transpose(*cls._dims) + raise ValueError( + f"Wrong dims for {cls.__name__}, expected '{cls._dims}', got '{val.dims}'" + ) + if val.dims != cls._dims: + val = val.transpose(*cls._dims) + return val @classmethod - def assign_data_attrs(cls, val): + def _assign_data_attrs(cls, val: Self) -> Self: """Assign the correct data attributes to the :class:`.DataArray`.""" + for attr_name, attr_val in cls._data_attrs.items(): + val.attrs[attr_name] = attr_val + return val - for attr_name, attr in cls._data_attrs.items(): - val.attrs[attr_name] = attr + @classmethod + def _assign_coord_attrs(cls, val: Self) -> Self: + """Assign the correct coordinate attributes to the :class:`.DataArray`.""" + target_dims = set(val.dims) & set(cls._dims) & set(val.coords) + for dim in target_dims: + template = DIM_ATTRS.get(dim) + if not template: + continue + + coord_attrs = val.coords[dim].attrs + missing = {k: v for k, v in template.items() if coord_attrs.get(k) != v} + coord_attrs.update(missing) return val - def _interp_validator(self, field_name: str = None) -> None: + def _interp_validator(self, field_name: Optional[str] = None) -> None: """Ensure the data can be interpolated or selected by checking for duplicate coordinates. NOTE @@ -126,7 +203,7 @@ def _interp_validator(self, field_name: str = None) -> None: called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'. 
""" if field_name is None: - field_name = "DataArray" + field_name = self.__class__.__name__ for dim, coord in self.coords.items(): if coord.to_index().duplicated().any(): @@ -136,39 +213,6 @@ def _interp_validator(self, field_name: str = None) -> None: f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'." ) - @classmethod - def assign_coord_attrs(cls, val): - """Assign the correct coordinate attributes to the :class:`.DataArray`.""" - - for dim in cls._dims: - dim_attrs = DIM_ATTRS.get(dim) - if dim_attrs is not None: - for attr_name, attr in dim_attrs.items(): - val.coords[dim].attrs[attr_name] = attr - return val - - @classmethod - def __modify_schema__(cls, field_schema): - """Sets the schema of DataArray object.""" - - schema = dict( - title="DataArray", - type="xr.DataArray", - properties=dict( - _dims=dict( - title="_dims", - type="Tuple[str, ...]", - ), - ), - required=["_dims"], - ) - field_schema.update(schema) - - @classmethod - def _json_encoder(cls, val): - """What function to call when writing a DataArray to json.""" - return type(val).__name__ - def __eq__(self, other) -> bool: """Whether two data array objects are equal.""" @@ -210,20 +254,15 @@ def is_uniform(self): return np.allclose(raw_data, raw_data[0]) def to_hdf5(self, fname: Union[str, h5py.File], group_path: str) -> None: - """Save an xr.DataArray to the hdf5 file or file handle with a given path to the group.""" - - # file name passed + """Save an ``xr.DataArray`` to the hdf5 file or file handle with a given path to the group.""" if isinstance(fname, str): with h5py.File(fname, "w") as f_handle: self.to_hdf5_handle(f_handle=f_handle, group_path=group_path) - - # file handle passed else: self.to_hdf5_handle(f_handle=fname, group_path=group_path) def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None: - """Save an xr.DataArray to the hdf5 file handle with a given path to the group.""" - + """Save an ``xr.DataArray`` to the hdf5 file handle with a given path to the 
group.""" sub_group = f_handle.create_group(group_path) sub_group[DATA_ARRAY_VALUE_NAME] = get_static(self.data) for key, val in self.coords.items(): @@ -234,7 +273,7 @@ def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None: @classmethod def from_hdf5(cls, fname: str, group_path: str) -> Self: - """Load an DataArray from an hdf5 file with a given path to the group.""" + """Load a DataArray from an hdf5 file with a given path to the group.""" with h5py.File(fname, "r") as f: sub_group = f[group_path] values = np.array(sub_group[DATA_ARRAY_VALUE_NAME]) @@ -246,7 +285,7 @@ def from_hdf5(cls, fname: str, group_path: str) -> Self: @classmethod def from_file(cls, fname: str, group_path: str) -> Self: - """Load an DataArray from an hdf5 file with a given path to the group.""" + """Load a DataArray from an hdf5 file with a given path to the group.""" if ".hdf5" not in fname: raise FileError( f"'DataArray' objects must be written to '.hdf5' format. Given filename of {fname}." @@ -260,7 +299,7 @@ def __hash__(self) -> int: token_str = dask.base.tokenize(self) return hash(token_str) - def multiply_at(self, value: complex, coord_name: str, indices: List[int]) -> Self: + def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: """Multiply self by value at indices.""" if isbox(self.data) or isbox(value): return self._ag_multiply_at(value, coord_name, indices) @@ -269,7 +308,7 @@ def multiply_at(self, value: complex, coord_name: str, indices: List[int]) -> Se self_mult[{coord_name: indices}] *= value return self_mult - def _ag_multiply_at(self, value: complex, coord_name: str, indices: List[int]) -> Self: + def _ag_multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: """Autograd multiply_at override when tracing.""" key = {coord_name: indices} _, index_tuple, _ = self.variable._broadcast_indexes(key) @@ -526,7 +565,7 @@ class TimeDataArray(DataArray): """ __slots__ = () - _dims = "t" + _dims = ("t",) class 
MixedModeDataArray(DataArray): @@ -553,7 +592,7 @@ class AbstractSpatialDataArray(DataArray, ABC): _data_attrs = {"long_name": "field value"} @property - def _spatially_sorted(self) -> SpatialDataArray: + def _spatially_sorted(self) -> Self: """Check whether sorted and sort if not.""" needs_sorting = [] for axis in "xyz": @@ -566,7 +605,7 @@ def _spatially_sorted(self) -> SpatialDataArray: return self - def sel_inside(self, bounds: Bound) -> SpatialDataArray: + def sel_inside(self, bounds: Bound) -> Self: """Return a new SpatialDataArray that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. Note that the returned data is sorted with respect to spatial coordinates. @@ -669,7 +708,7 @@ class SpatialDataArray(AbstractSpatialDataArray): __slots__ = () - def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> SpatialDataArray: + def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> Self: """Reflect data across the plane define by parameters ``axis`` and ``center`` from right to left. Note that the returned data is sorted with respect to spatial coordinates. 
@@ -997,7 +1036,7 @@ class HeatDataArray(DataArray): """ __slots__ = () - _dims = "T" + _dims = ("T",) class EMEScalarModeFieldDataArray(AbstractSpatialDataArray): @@ -1253,6 +1292,11 @@ class SpatialVoltageDataArray(AbstractSpatialDataArray): _dims = ("x", "y", "z", "voltage") +class PerturbationCoefficientDataArray(DataArray): + __slots__ = () + _dims = ("wvl", "coeff") + + DATA_ARRAY_TYPES = [ SpatialDataArray, ScalarFieldDataArray, @@ -1286,6 +1330,8 @@ class SpatialVoltageDataArray(AbstractSpatialDataArray): CellDataArray, IndexedDataArray, IndexedVoltageDataArray, + SpatialVoltageDataArray, + PerturbationCoefficientDataArray, ] DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in DATA_ARRAY_TYPES} diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index de64f657f4..8e6da724b4 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -3,11 +3,11 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Union, get_args +from typing import Any, Callable, Optional, Union, get_args import numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field from ...constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling from ...exceptions import DataError @@ -44,7 +44,7 @@ class AbstractFieldDataset(Dataset, ABC): @property @abstractmethod - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their associated data.""" def apply_phase(self, phase: float) -> AbstractFieldDataset: @@ -60,15 +60,15 @@ def apply_phase(self, phase: float) -> AbstractFieldDataset: @property @abstractmethod - def grid_locations(self) -> Dict[str, str]: + def grid_locations(self) -> dict[str, str]: """Maps field components to the string key of their grid locations on the yee lattice.""" @property @abstractmethod - def 
symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]: + def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: """Maps field components to their (positive) symmetry eigenvalues.""" - def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArray]) -> Any: + def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArray]) -> Any: """How to package the dictionary of fields computed via self.colocate().""" return xr.Dataset(centered_fields) @@ -150,39 +150,39 @@ def colocate(self, x=None, y=None, z=None) -> xr.Dataset: class ElectromagneticFieldDataset(AbstractFieldDataset, ABC): """Stores a collection of E and H fields with x, y, z components.""" - Ex: Optional[EMScalarFieldType] = pd.Field( + Ex: Optional[EMScalarFieldType] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[EMScalarFieldType] = pd.Field( + Ey: Optional[EMScalarFieldType] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[EMScalarFieldType] = pd.Field( + Ez: Optional[EMScalarFieldType] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[EMScalarFieldType] = pd.Field( + Hx: Optional[EMScalarFieldType] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[EMScalarFieldType] = pd.Field( + Hy: Optional[EMScalarFieldType] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field.", ) - Hz: Optional[EMScalarFieldType] = pd.Field( + Hz: Optional[EMScalarFieldType] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", ) @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their 
associated data.""" fields = { "Ex": self.Ex, @@ -195,12 +195,12 @@ def field_components(self) -> Dict[str, DataArray]: return {field_name: field for field_name, field in fields.items() if field is not None} @property - def grid_locations(self) -> Dict[str, str]: + def grid_locations(self) -> dict[str, str]: """Maps field components to the string key of their grid locations on the yee lattice.""" return dict(Ex="Ex", Ey="Ey", Ez="Ez", Hx="Hx", Hy="Hy", Hz="Hz") @property - def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]: + def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: """Maps field components to their (positive) symmetry eigenvalues.""" return dict( @@ -227,32 +227,32 @@ class FieldDataset(ElectromagneticFieldDataset): >>> data = FieldDataset(Ex=scalar_field, Hz=scalar_field) """ - Ex: Optional[ScalarFieldDataArray] = pd.Field( + Ex: Optional[ScalarFieldDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[ScalarFieldDataArray] = pd.Field( + Ey: Optional[ScalarFieldDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[ScalarFieldDataArray] = pd.Field( + Ez: Optional[ScalarFieldDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[ScalarFieldDataArray] = pd.Field( + Hx: Optional[ScalarFieldDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[ScalarFieldDataArray] = pd.Field( + Hy: Optional[ScalarFieldDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field.", ) - Hz: Optional[ScalarFieldDataArray] = pd.Field( + Hz: Optional[ScalarFieldDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ 
-359,32 +359,32 @@ class FieldTimeDataset(ElectromagneticFieldDataset): >>> data = FieldTimeDataset(Ex=scalar_field, Hz=scalar_field) """ - Ex: Optional[ScalarFieldTimeDataArray] = pd.Field( + Ex: Optional[ScalarFieldTimeDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[ScalarFieldTimeDataArray] = pd.Field( + Ey: Optional[ScalarFieldTimeDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[ScalarFieldTimeDataArray] = pd.Field( + Ez: Optional[ScalarFieldTimeDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[ScalarFieldTimeDataArray] = pd.Field( + Hx: Optional[ScalarFieldTimeDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[ScalarFieldTimeDataArray] = pd.Field( + Hy: Optional[ScalarFieldTimeDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field.", ) - Hz: Optional[ScalarFieldTimeDataArray] = pd.Field( + Hz: Optional[ScalarFieldTimeDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ -402,19 +402,19 @@ def apply_phase(self, phase: float) -> AbstractFieldDataset: class AuxFieldDataset(AbstractFieldDataset, ABC): """Stores a collection of aux fields with x, y, z components.""" - Nfx: Optional[EMScalarFieldType] = pd.Field( + Nfx: Optional[EMScalarFieldType] = Field( None, title="Nfx", description="Spatial distribution of the free carrier density for " "polarization in the x-direction.", ) - Nfy: Optional[EMScalarFieldType] = pd.Field( + Nfy: Optional[EMScalarFieldType] = Field( None, title="Nfy", description="Spatial distribution of the free carrier density for " "polarization in the y-direction.", ) - Nfz: 
Optional[EMScalarFieldType] = pd.Field( + Nfz: Optional[EMScalarFieldType] = Field( None, title="Nfz", description="Spatial distribution of the free carrier density for " @@ -422,7 +422,7 @@ class AuxFieldDataset(AbstractFieldDataset, ABC): ) @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their associated data.""" fields = { "Nfx": self.Nfx, @@ -432,12 +432,12 @@ def field_components(self) -> Dict[str, DataArray]: return {field_name: field for field_name, field in fields.items() if field is not None} @property - def grid_locations(self) -> Dict[str, str]: + def grid_locations(self) -> dict[str, str]: """Maps field components to the string key of their grid locations on the yee lattice.""" return dict(Nfx="Ex", Nfy="Ey", Nfz="Ez") @property - def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]: + def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: """Maps field components to their (positive) symmetry eigenvalues.""" return dict( @@ -461,19 +461,19 @@ class AuxFieldTimeDataset(AuxFieldDataset): >>> data = AuxFieldTimeDataset(Nfx=scalar_field) """ - Nfx: Optional[ScalarFieldTimeDataArray] = pd.Field( + Nfx: Optional[ScalarFieldTimeDataArray] = Field( None, title="Nfx", description="Spatial distribution of the free carrier density for polarization " "in the x-direction.", ) - Nfy: Optional[ScalarFieldTimeDataArray] = pd.Field( + Nfy: Optional[ScalarFieldTimeDataArray] = Field( None, title="Nfy", description="Spatial distribution of the free carrier density for polarization " "in the y-direction.", ) - Nfz: Optional[ScalarFieldTimeDataArray] = pd.Field( + Nfz: Optional[ScalarFieldTimeDataArray] = Field( None, title="Nfz", description="Spatial distribution of the free carrier density for polarization " @@ -507,51 +507,50 @@ class ModeSolverDataset(ElectromagneticFieldDataset): ... 
) """ - Ex: Optional[ScalarModeFieldDataArray] = pd.Field( + Ex: Optional[ScalarModeFieldDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: Optional[ScalarModeFieldDataArray] = pd.Field( + Ey: Optional[ScalarModeFieldDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: Optional[ScalarModeFieldDataArray] = pd.Field( + Ez: Optional[ScalarModeFieldDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: Optional[ScalarModeFieldDataArray] = pd.Field( + Hx: Optional[ScalarModeFieldDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: Optional[ScalarModeFieldDataArray] = pd.Field( + Hy: Optional[ScalarModeFieldDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: Optional[ScalarModeFieldDataArray] = pd.Field( + Hz: Optional[ScalarModeFieldDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", ) - n_complex: ModeIndexDataArray = pd.Field( - ..., + n_complex: ModeIndexDataArray = Field( title="Propagation Index", description="Complex-valued effective propagation constants associated with the mode.", ) - n_group_raw: Optional[GroupIndexDataArray] = pd.Field( + n_group_raw: Optional[GroupIndexDataArray] = Field( None, alias="n_group", # This is for backwards compatibility only when loading old data title="Group Index", description="Index associated with group velocity of the mode.", ) - dispersion_raw: Optional[ModeDispersionDataArray] = pd.Field( + dispersion_raw: Optional[ModeDispersionDataArray] = Field( None, title="Dispersion", description="Dispersion parameter for the mode.", @@ -559,7 +558,7 
@@ class ModeSolverDataset(ElectromagneticFieldDataset): ) @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their associated data.""" fields = { "Ex": self.Ex, @@ -632,32 +631,29 @@ class PermittivityDataset(AbstractFieldDataset): """ @property - def field_components(self) -> Dict[str, ScalarFieldDataArray]: + def field_components(self) -> dict[str, ScalarFieldDataArray]: """Maps the field components to their associated data.""" return dict(eps_xx=self.eps_xx, eps_yy=self.eps_yy, eps_zz=self.eps_zz) @property - def grid_locations(self) -> Dict[str, str]: + def grid_locations(self) -> dict[str, str]: """Maps field components to the string key of their grid locations on the yee lattice.""" return dict(eps_xx="Ex", eps_yy="Ey", eps_zz="Ez") @property - def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]: + def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: """Maps field components to their (positive) symmetry eigenvalues.""" return dict(eps_xx=None, eps_yy=None, eps_zz=None) - eps_xx: ScalarFieldDataArray = pd.Field( - ..., + eps_xx: ScalarFieldDataArray = Field( title="Epsilon xx", description="Spatial distribution of the xx-component of the relative permittivity.", ) - eps_yy: ScalarFieldDataArray = pd.Field( - ..., + eps_yy: ScalarFieldDataArray = Field( title="Epsilon yy", description="Spatial distribution of the yy-component of the relative permittivity.", ) - eps_zz: ScalarFieldDataArray = pd.Field( - ..., + eps_zz: ScalarFieldDataArray = Field( title="Epsilon zz", description="Spatial distribution of the zz-component of the relative permittivity.", ) @@ -666,8 +662,7 @@ def symmetry_eigenvalues(self) -> Dict[str, Callable[[Axis], float]]: class TriangleMeshDataset(Dataset): """Dataset for storing triangular surface data.""" - surface_mesh: TriangleMeshDataArray = pd.Field( - ..., + surface_mesh: TriangleMeshDataArray = Field( 
title="Surface mesh data", description="Dataset containing the surface triangles and corresponding face indices " "for a surface mesh.", @@ -677,6 +672,7 @@ class TriangleMeshDataset(Dataset): class TimeDataset(Dataset): """Dataset for storing a function of time.""" - values: TimeDataArray = pd.Field( - ..., title="Values", description="Values as a function of time." + values: TimeDataArray = Field( + title="Values", + description="Values as a function of time.", ) diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 3a20e43d1d..76a9e94d72 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -6,18 +6,18 @@ import warnings from abc import ABC from math import isclose -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args +from typing import Any, Callable, Literal, Optional, Union, get_args import autograd.numpy as np -import pydantic.v1 as pd import xarray as xr from pandas import DataFrame +from pydantic import Field, model_validator from xarray.core.types import Self from ...constants import C_0, EPSILON_0, ETA_0, MICROMETER, UnitScaling from ...exceptions import DataError, SetupError, Tidy3dNotImplementedError, ValidationError from ...log import log -from ..base import TYPE_TAG_STR, cached_property, skip_if_fields_missing +from ..base import TYPE_TAG_STR, cached_property from ..base_sim.data.monitor_data import AbstractMonitorData from ..grid.grid import Coords, Grid from ..medium import Medium, MediumType @@ -39,27 +39,15 @@ PermittivityMonitor, ) from ..source.base import Source -from ..source.current import ( - CustomCurrentSource, - PointDipole, -) -from ..source.field import ( - CustomFieldSource, - ModeSource, - PlaneWave, -) -from ..source.time import ( - GaussianPulse, - SourceTimeType, -) +from ..source.current import CustomCurrentSource, PointDipole +from ..source.field import CustomFieldSource, ModeSource, PlaneWave +from ..source.time 
import GaussianPulse, SourceTimeType from ..types import ( ArrayFloat1D, ArrayFloat2D, Coordinate, EMField, EpsSpecType, - Literal, - Numpy, PolarizationBasis, Size, Symmetry, @@ -111,8 +99,7 @@ class MonitorData(AbstractMonitorData, ABC): Abstract base class of objects that store data pertaining to a single :class:`.monitor`. """ - monitor: MonitorType = pd.Field( - ..., + monitor: MonitorType = Field( title="Monitor", description="Monitor associated with the data.", discriminator=TYPE_TAG_STR, @@ -152,7 +139,7 @@ def amplitude_fn(freq: list[float]) -> complex: return self.normalize(amplitude_fn) - def _updated(self, update: Dict) -> MonitorData: + def _updated(self, update: dict) -> MonitorData: """Similar to ``updated_copy``, but does not actually copy components, for speed. Note @@ -204,19 +191,19 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC): FieldMonitor, FieldTimeMonitor, AuxFieldTimeMonitor, PermittivityMonitor, ModeMonitor ] - symmetry: Tuple[Symmetry, Symmetry, Symmetry] = pd.Field( + symmetry: tuple[Symmetry, Symmetry, Symmetry] = Field( (0, 0, 0), title="Symmetry", description="Symmetry eigenvalues of the original simulation in x, y, and z.", ) - symmetry_center: Coordinate = pd.Field( + symmetry_center: Optional[Coordinate] = Field( None, title="Symmetry Center", description="Center of the symmetry planes of the original simulation in x, y, and z. 
" "Required only if any of the ``symmetry`` field are non-zero.", ) - grid_expanded: Grid = pd.Field( + grid_expanded: Optional[Grid] = Field( None, title="Expanded Grid", description=":class:`.Grid` discretization of the associated monitor in the simulation " @@ -224,17 +211,19 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC): "well as in order to use some functionalities like getting Poynting vector and flux.", ) - @pd.validator("grid_expanded", always=True) - def warn_missing_grid_expanded(cls, val, values): + @model_validator(mode="after") + def warn_missing_grid_expanded(self): """If ``grid_expanded`` not provided and fields data is present, warn that some methods will break.""" field_comps = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] - if val is None and any(values.get(comp) is not None for comp in field_comps): + if self.grid_expanded is None and any( + getattr(self, comp) is not None for comp in field_comps + ): log.warning( "Monitor data requires 'grid_expanded' to be defined to compute values like " "flux, Poynting and dot product with other data." 
) - return val + return self _require_sym_center = required_if_symmetry_present("symmetry_center") _require_grid_expanded = required_if_symmetry_present("grid_expanded") @@ -277,7 +266,7 @@ def symmetry_expanded_copy(self) -> AbstractFieldData: return self.copy(update=self._symmetry_update_dict) @property - def _symmetry_update_dict(self) -> Dict: + def _symmetry_update_dict(self) -> dict: """Dictionary of data fields to create data with expanded symmetry.""" update_dict = {} @@ -407,7 +396,7 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A TimeDataArray, FreqModeDataArray, EMEFreqModeDataArray, - ] = pd.Field( + ] = Field( 1.0, title="Field correction factor", description="Correction factor that needs to be applied for data corresponding to a 2D " @@ -421,7 +410,7 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A TimeDataArray, FreqModeDataArray, EMEFreqModeDataArray, - ] = pd.Field( + ] = Field( 1.0, title="Field correction factor", description="Correction factor that needs to be applied for data corresponding to a 2D " @@ -446,7 +435,7 @@ def _grid_correction_dict(self): } @property - def _tangential_dims(self) -> List[str]: + def _tangential_dims(self) -> list[str]: """For a 2D monitor data, return the names of the tangential dimensions. 
Raise if cannot confirm that the associated monitor is 2D.""" if len(self.monitor.zero_dims) != 1: @@ -491,7 +480,7 @@ def colocation_centers(self) -> Coords: return Coords(**colocate_centers) @property - def _plane_grid_boundaries(self) -> Tuple[Coords1D, Coords1D]: + def _plane_grid_boundaries(self) -> tuple[Coords1D, Coords1D]: """For a 2D monitor data, return the boundaries of the in-plane grid to be used to compute differential area and to colocate fields if needed.""" if np.any(np.array(self.monitor.interval_space) > 1): @@ -504,7 +493,7 @@ def _plane_grid_boundaries(self) -> Tuple[Coords1D, Coords1D]: return (bounds_dict[dim1], bounds_dict[dim2]) @property - def _plane_grid_centers(self) -> Tuple[Coords1D, Coords1D]: + def _plane_grid_centers(self) -> tuple[Coords1D, Coords1D]: """For 2D monitor data, return the centers of the in-plane grid""" return [(bs[1:] + bs[:-1]) / 2 for bs in self._plane_grid_boundaries] @@ -548,7 +537,7 @@ def _diff_area(self) -> DataArray: return DataArray(np.outer(sizes_dim0, sizes_dim1), dims=self._tangential_dims) - def _tangential_corrected(self, fields: Dict[str, DataArray]) -> Dict[str, DataArray]: + def _tangential_corrected(self, fields: dict[str, DataArray]) -> dict[str, DataArray]: """For a 2D monitor data, extract the tangential components from fields and orient them such that the third component would be the normal axis. This just means that the H field gets an extra minus sign if the normal axis is ``"y"``. Raise if any of the tangential @@ -591,7 +580,7 @@ def _tangential_corrected(self, fields: Dict[str, DataArray]) -> Dict[str, DataA return tan_fields @property - def _tangential_fields(self) -> Dict[str, DataArray]: + def _tangential_fields(self) -> dict[str, DataArray]: """For a 2D monitor data, get the tangential E and H fields in the 2D plane grid. Fields are oriented such that the third component would be the normal axis. This just means that the H field gets an extra minus sign if the normal axis is ``"y"``. 
@@ -603,7 +592,7 @@ def _tangential_fields(self) -> Dict[str, DataArray]: return self._tangential_corrected(self.symmetry_expanded.field_components) @property - def _colocated_fields(self) -> Dict[str, DataArray]: + def _colocated_fields(self) -> dict[str, DataArray]: """For a 2D monitor data, get all E and H fields colocated to the cell boundaries in the 2D plane grid, with symmetries expanded. """ @@ -623,7 +612,7 @@ def _colocated_fields(self) -> Dict[str, DataArray]: return colocated_fields @property - def _colocated_tangential_fields(self) -> Dict[str, DataArray]: + def _colocated_tangential_fields(self) -> dict[str, DataArray]: """For a 2D monitor data, get the tangential E and H fields colocated to the cell boundaries in the 2D plane grid. Fields are oriented such that the third component would be the normal axis. This just means that the H field gets an extra minus sign if the normal axis is @@ -786,7 +775,7 @@ def dot( return ModeAmpsDataArray(0.25 * integrand.sum(dim=d_area.dims)) - def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> Dict[str, DataArray]: + def _interpolated_tangential_fields(self, coords: ArrayFloat2D) -> dict[str, DataArray]: """For 2D monitors, interpolate this fields to given coords in the tangential plane. Parameters @@ -942,11 +931,11 @@ def fn(fields_1, fields_2): @staticmethod def _outer_fn_summation( - fields_1: Dict[str, xr.DataArray], - fields_2: Dict[str, xr.DataArray], + fields_1: dict[str, xr.DataArray], + fields_2: dict[str, xr.DataArray], outer_dim_1: str, outer_dim_2: str, - sum_dims: List[str], + sum_dims: list[str], fn: Callable, ) -> DataArray: """ @@ -1083,7 +1072,7 @@ def to_zbf( z_y: float = 0, rec_efficiency: float = 0, sys_efficiency: float = 0, - ) -> Tuple[ScalarFieldDataArray, ScalarFieldDataArray]: + ) -> tuple[ScalarFieldDataArray, ScalarFieldDataArray]: """For a 2D monitor, export the fields to a Zemax Beam File (``.zbf``). 
The mode area is used to approximate the beam waist, which is only valid @@ -1125,7 +1114,7 @@ def to_zbf( Returns ------- - Tuple[:class:`.ScalarFieldDataArray`,:class:`.ScalarFieldDataArray`] + tuple[:class:`.ScalarFieldDataArray`,:class:`.ScalarFieldDataArray`] The two E field components being exported to ``.zbf``. """ log.warning( @@ -1279,8 +1268,9 @@ class FieldData(FieldDataset, ElectromagneticFieldData): * `Advanced monitor data manipulation and visualization <../../notebooks/XarrayTutorial.html>`_ """ - monitor: FieldMonitor = pd.Field( - ..., title="Monitor", description="Frequency-domain field monitor associated with the data." + monitor: FieldMonitor = Field( + title="Monitor", + description="Frequency-domain field monitor associated with the data.", ) _contains_monitor_fields = enforce_monitor_fields_present() @@ -1303,9 +1293,9 @@ def to_source( ---------- source_time: :class:`.SourceTime` Specification of the source time-dependence. - center: Tuple[float, float, float] + center: tuple[float, float, float] Source center in x, y and z. - size: Tuple[float, float, float] + size: tuple[float, float, float] Source size in x, y, and z. If not provided, the size of the monitor associated to the data is used. **kwargs @@ -1335,7 +1325,7 @@ def to_source( def make_adjoint_sources( self, dataset_names: list[str], fwidth: float - ) -> List[CustomCurrentSource]: + ) -> list[CustomCurrentSource]: """Converts a :class:`.FieldData` to a list of adjoint current or point sources.""" sources = [] @@ -1424,8 +1414,9 @@ class FieldTimeData(FieldTimeDataset, ElectromagneticFieldData): >>> data = FieldTimeData(monitor=monitor, Ex=scalar_field, Hz=scalar_field, grid_expanded=grid) """ - monitor: FieldTimeMonitor = pd.Field( - ..., title="Monitor", description="Time-domain field monitor associated with the data." 
+ monitor: FieldTimeMonitor = Field( + title="Monitor", + description="Time-domain field monitor associated with the data.", ) _contains_monitor_fields = enforce_monitor_fields_present() @@ -1498,8 +1489,7 @@ class AuxFieldTimeData(AuxFieldTimeDataset, AbstractFieldData): >>> data = AuxFieldTimeData(monitor=monitor, Nfx=scalar_field, grid_expanded=grid) """ - monitor: AuxFieldTimeMonitor = pd.Field( - ..., + monitor: AuxFieldTimeMonitor = Field( title="Monitor", description="Time-domain auxiliary field monitor associated with the data.", ) @@ -1532,8 +1522,9 @@ class PermittivityData(PermittivityDataset, AbstractFieldData): ... ) """ - monitor: PermittivityMonitor = pd.Field( - ..., title="Monitor", description="Permittivity monitor associated with the data." + monitor: PermittivityMonitor = Field( + title="Monitor", + description="Permittivity monitor associated with the data.", ) @@ -1576,32 +1567,33 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData): >>> data = ModeData(monitor=monitor, amps=amp_data, n_complex=index_data) """ - monitor: ModeMonitor = pd.Field( - ..., title="Monitor", description="Mode monitor associated with the data." + monitor: ModeMonitor = Field( + title="Monitor", + description="Mode monitor associated with the data.", ) - amps: ModeAmpsDataArray = pd.Field( - ..., title="Amplitudes", description="Complex-valued amplitudes associated with the mode." + amps: ModeAmpsDataArray = Field( + title="Amplitudes", + description="Complex-valued amplitudes associated with the mode.", ) - eps_spec: List[EpsSpecType] = pd.Field( + eps_spec: Optional[list[EpsSpecType]] = Field( None, title="Permettivity Specification", description="Characterization of the permittivity profile on the plane where modes are " "computed. 
Possible values are 'diagonal', 'tensorial_real', 'tensorial_complex'.", ) - @pd.validator("eps_spec", always=True) - @skip_if_fields_missing(["monitor"]) - def eps_spec_match_mode_spec(cls, val, values): + @model_validator(mode="after") + def eps_spec_match_mode_spec(self): """Raise validation error if frequencies in eps_spec does not match frequency list""" - if val: - mode_data_freqs = values["monitor"].freqs - if len(val) != len(mode_data_freqs): + if self.eps_spec: + mode_data_freqs = self.monitor.freqs + if len(self.eps_spec) != len(mode_data_freqs): raise ValidationError( "eps_spec must be provided at the same frequencies as mode solver data." ) - return val + return self def normalize(self, source_spectrum_fn) -> ModeData: """Return copy of self after normalization is applied using source spectrum function.""" @@ -1726,7 +1718,7 @@ def _find_ordering_one_freq( self, data_to_sort: ModeData, overlap_thresh: float, - ) -> Tuple[Numpy, Numpy]: + ) -> tuple[np.ndarray, np.ndarray]: """Find new ordering of modes in data_to_sort based on their similarity to own modes.""" num_modes = self.n_complex.sizes["mode_index"] @@ -1763,7 +1755,7 @@ def _find_ordering_one_freq( return pairs, complex_amps @staticmethod - def _find_closest_pairs(arr: Numpy) -> Tuple[Numpy, Numpy]: + def _find_closest_pairs(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Given a complex overlap matrix pair row and column entries.""" n, k = np.shape(arr) @@ -1784,8 +1776,8 @@ def _find_closest_pairs(arr: Numpy) -> Tuple[Numpy, Numpy]: def _reorder_modes( self, - sorting: Numpy, - phase: Numpy, + sorting: np.ndarray, + phase: np.ndarray, track_freq: TrackFreq, ) -> ModeData: """Rearrange modes for the i-th frequency according to sorting[i, :] and apply phase @@ -2185,12 +2177,15 @@ class ModeSolverData(ModeData): ... ) """ - monitor: ModeSolverMonitor = pd.Field( - ..., title="Monitor", description="Mode solver monitor associated with the data." 
+ monitor: ModeSolverMonitor = Field( + title="Monitor", + description="Mode solver monitor associated with the data.", ) - amps: ModeAmpsDataArray = pd.Field( - None, title="Amplitudes", description="Unused for ModeSolverData." + amps: Optional[ModeAmpsDataArray] = Field( + None, + title="Amplitudes", + description="Unused for ModeSolverData.", ) def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData: @@ -2258,17 +2253,19 @@ class FluxData(MonitorData): * `Advanced monitor data manipulation and visualization <../../notebooks/XarrayTutorial.html>`_ """ - monitor: FluxMonitor = pd.Field( - ..., title="Monitor", description="Frequency-domain flux monitor associated with the data." + monitor: FluxMonitor = Field( + title="Monitor", + description="Frequency-domain flux monitor associated with the data.", ) - flux: FluxDataArray = pd.Field( - ..., title="Flux", description="Flux values in the frequency-domain." + flux: FluxDataArray = Field( + title="Flux", + description="Flux values in the frequency-domain.", ) def make_adjoint_sources( self, dataset_names: list[str], fwidth: float - ) -> List[Union[CustomCurrentSource, PointDipole]]: + ) -> list[Union[CustomCurrentSource, PointDipole]]: """Converts a :class:`.FieldData` to a list of adjoint current or point sources.""" # avoids error in edge case where there are extraneous flux monitors not used in objective @@ -2312,12 +2309,14 @@ class FluxTimeData(MonitorData): >>> data = FluxTimeData(monitor=monitor, flux=flux_data) """ - monitor: FluxTimeMonitor = pd.Field( - ..., title="Monitor", description="Time-domain flux monitor associated with the data." + monitor: FluxTimeMonitor = Field( + title="Monitor", + description="Time-domain flux monitor associated with the data.", ) - flux: FluxTimeDataArray = pd.Field( - ..., title="Flux", description="Flux values in the time-domain." 
+ flux: FluxTimeDataArray = Field( + title="Flux", + description="Flux values in the time-domain.", ) @@ -2340,59 +2339,52 @@ class FluxTimeData(MonitorData): class AbstractFieldProjectionData(MonitorData): """Collection of projected fields in spherical coordinates in the frequency domain.""" - monitor: ProjMonitorType = pd.Field( - ..., + monitor: ProjMonitorType = Field( title="Projection monitor", description="Field projection monitor.", discriminator=TYPE_TAG_STR, ) - Er: ProjFieldType = pd.Field( - ..., + Er: ProjFieldType = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: ProjFieldType = pd.Field( - ..., + Etheta: ProjFieldType = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: ProjFieldType = pd.Field( - ..., + Ephi: ProjFieldType = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: ProjFieldType = pd.Field( - ..., + Hr: ProjFieldType = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: ProjFieldType = pd.Field( - ..., + Htheta: ProjFieldType = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: ProjFieldType = pd.Field( - ..., + Hphi: ProjFieldType = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) - medium: MediumType = pd.Field( - Medium(), + medium: MediumType = Field( + default_factory=Medium, title="Background Medium", description="Background medium through which to project fields.", discriminator=TYPE_TAG_STR, ) - is_2d_simulation: bool = pd.Field( + is_2d_simulation: bool = Field( False, title="2D Simulation", description="Indicates whether the monitor data is for a 2D simulation.", ) @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field 
components to their associated data.""" return dict( Er=self.Er, @@ -2409,12 +2401,12 @@ def f(self) -> np.ndarray: return np.array(self.Etheta.coords["f"]) @property - def coords(self) -> Dict[str, np.ndarray]: + def coords(self) -> dict[str, np.ndarray]: """Coordinates of the fields contained.""" return self.Etheta.coords @property - def coords_spherical(self) -> Dict[str, np.ndarray]: + def coords_spherical(self) -> dict[str, np.ndarray]: """Coordinates grid for the fields in the spherical system.""" if "theta" in self.coords.keys(): r, theta, phi = np.meshgrid( @@ -2442,7 +2434,7 @@ def coords_spherical(self) -> Dict[str, np.ndarray]: return {"r": r, "theta": theta, "phi": phi} @property - def dims(self) -> Tuple[str, ...]: + def dims(self) -> tuple[str, ...]: """Dimensions of the radiation vectors contained.""" return self.Etheta.dims @@ -2450,7 +2442,7 @@ def make_data_array(self, data: np.ndarray) -> DataArray: """Make an DataArray with data and same coords and dims as fields of self.""" return DataArray(data=data, coords=self.coords, dims=self.dims) - def make_dataset(self, keys: Tuple[str, ...], vals: Tuple[np.ndarray, ...]) -> xr.Dataset: + def make_dataset(self, keys: tuple[str, ...], vals: tuple[np.ndarray, ...]) -> xr.Dataset: """Make an xr.Dataset with keys and data with same coords and dims as fields.""" data_arrays = tuple(map(self.make_data_array, vals)) return xr.Dataset(dict(zip(keys, data_arrays))) @@ -2484,7 +2476,7 @@ def wavenumber(medium: MediumType, frequency: float) -> complex: return (2 * np.pi * frequency / C_0) * (index_n + 1j * index_k) @property - def nk(self) -> Tuple[float, float]: + def nk(self) -> tuple[float, float]: """Returns the real and imaginary parts of the background medium's refractive index.""" return self.medium.nk_model(frequency=self.f) @@ -2608,7 +2600,7 @@ def radar_cross_section(self) -> DataArray: def make_adjoint_sources( self, dataset_names: list[str], fwidth: float - ) -> List[Union[CustomCurrentSource, 
PointDipole]]: + ) -> list[Union[CustomCurrentSource, PointDipole]]: """Error if server-side field projection is used for autograd""" raise NotImplementedError( @@ -2644,45 +2636,37 @@ class FieldProjectionAngleData(AbstractFieldProjectionData): ... ) """ - monitor: FieldProjectionAngleMonitor = pd.Field( - ..., + monitor: FieldProjectionAngleMonitor = Field( title="Projection monitor", description="Field projection monitor with an angle-based projection grid.", ) - projection_surfaces: Tuple[FieldProjectionSurface, ...] = pd.Field( - ..., + projection_surfaces: tuple[FieldProjectionSurface, ...] = Field( title="Projection surfaces", description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionAngleDataArray = pd.Field( - ..., + Er: FieldProjectionAngleDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionAngleDataArray = pd.Field( - ..., + Etheta: FieldProjectionAngleDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionAngleDataArray = pd.Field( - ..., + Ephi: FieldProjectionAngleDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionAngleDataArray = pd.Field( - ..., + Hr: FieldProjectionAngleDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionAngleDataArray = pd.Field( - ..., + Htheta: FieldProjectionAngleDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionAngleDataArray = pd.Field( - ..., + Hphi: FieldProjectionAngleDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -2854,45 +2838,37 @@ class FieldProjectionCartesianData(AbstractFieldProjectionData): 
... ) """ - monitor: FieldProjectionCartesianMonitor = pd.Field( - ..., + monitor: FieldProjectionCartesianMonitor = Field( title="Projection monitor", description="Field projection monitor with a Cartesian projection grid.", ) - projection_surfaces: Tuple[FieldProjectionSurface, ...] = pd.Field( - ..., + projection_surfaces: tuple[FieldProjectionSurface, ...] = Field( title="Projection surfaces", description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionCartesianDataArray = pd.Field( - ..., + Er: FieldProjectionCartesianDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionCartesianDataArray = pd.Field( - ..., + Etheta: FieldProjectionCartesianDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionCartesianDataArray = pd.Field( - ..., + Ephi: FieldProjectionCartesianDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionCartesianDataArray = pd.Field( - ..., + Hr: FieldProjectionCartesianDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionCartesianDataArray = pd.Field( - ..., + Htheta: FieldProjectionCartesianDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionCartesianDataArray = pd.Field( - ..., + Hphi: FieldProjectionCartesianDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3007,45 +2983,37 @@ class FieldProjectionKSpaceData(AbstractFieldProjectionData): ... 
) """ - monitor: FieldProjectionKSpaceMonitor = pd.Field( - ..., + monitor: FieldProjectionKSpaceMonitor = Field( title="Projection monitor", description="Field projection monitor with a projection grid defined in k-space.", ) - projection_surfaces: Tuple[FieldProjectionSurface, ...] = pd.Field( - ..., + projection_surfaces: tuple[FieldProjectionSurface, ...] = Field( title="Projection surfaces", description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Er: FieldProjectionKSpaceDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Etheta: FieldProjectionKSpaceDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Ephi: FieldProjectionKSpaceDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Hr: FieldProjectionKSpaceDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Htheta: FieldProjectionKSpaceDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Hphi: FieldProjectionKSpaceDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3136,57 +3104,50 @@ class DiffractionData(AbstractFieldProjectionData): ... ) """ - monitor: DiffractionMonitor = pd.Field( - ..., title="Monitor", description="Diffraction monitor associated with the data." 
+ monitor: DiffractionMonitor = Field( + title="Monitor", + description="Diffraction monitor associated with the data.", ) - Er: DiffractionDataArray = pd.Field( - ..., + Er: DiffractionDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: DiffractionDataArray = pd.Field( - ..., + Etheta: DiffractionDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: DiffractionDataArray = pd.Field( - ..., + Ephi: DiffractionDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: DiffractionDataArray = pd.Field( - ..., + Hr: DiffractionDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: DiffractionDataArray = pd.Field( - ..., + Htheta: DiffractionDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: DiffractionDataArray = pd.Field( - ..., + Hphi: DiffractionDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) - sim_size: Tuple[float, float] = pd.Field( - ..., + sim_size: tuple[float, float] = Field( title="Domain size", description="Size of the near field in the local x and y directions.", units=MICROMETER, ) - bloch_vecs: Union[Tuple[float, float], Tuple[ArrayFloat1D, ArrayFloat1D]] = pd.Field( - ..., + bloch_vecs: Union[tuple[float, float], tuple[ArrayFloat1D, ArrayFloat1D]] = Field( title="Bloch vectors", description="Bloch vectors along the local x and y directions in units of " "``2 * pi / (simulation size along the respective dimension)``.", ) @staticmethod - def shifted_orders(orders: Tuple[int, ...], bloch_vec: Union[float, np.ndarray]) -> np.ndarray: + def shifted_orders(orders: tuple[int, ...], bloch_vec: Union[float, np.ndarray]) -> np.ndarray: """Diffraction orders shifted by the Bloch 
vector.""" return bloch_vec + np.atleast_2d(orders).T @@ -3207,8 +3168,8 @@ def reciprocal_coords( @staticmethod def compute_angles( - reciprocal_vectors: Tuple[np.ndarray, np.ndarray], - ) -> Tuple[np.ndarray, np.ndarray]: + reciprocal_vectors: tuple[np.ndarray, np.ndarray], + ) -> tuple[np.ndarray, np.ndarray]: """Compute the polar and azimuth angles associated with the given reciprocal vectors.""" # some wave number pairs are outside the light cone, leading to warnings from numpy.arcsin with warnings.catch_warnings(): @@ -3220,7 +3181,7 @@ def compute_angles( return (thetas, phis) @property - def coords_spherical(self) -> Dict[str, np.ndarray]: + def coords_spherical(self) -> dict[str, np.ndarray]: """Coordinates grid for the fields in the spherical system.""" theta, phi = self.angles return {"r": None, "theta": theta, "phi": phi} @@ -3236,7 +3197,7 @@ def orders_y(self) -> np.ndarray: return np.atleast_1d(np.array(self.Etheta.coords["orders_y"])) @property - def reciprocal_vectors(self) -> Tuple[np.ndarray, np.ndarray]: + def reciprocal_vectors(self) -> tuple[np.ndarray, np.ndarray]: """Get the normalized "ux" and "uy" reciprocal vectors.""" return (self.ux, self.uy) @@ -3267,7 +3228,7 @@ def uy(self) -> np.ndarray: ) @property - def angles(self) -> Tuple[DataArray]: + def angles(self) -> tuple[DataArray]: """The (theta, phi) angles corresponding to each allowed pair of diffraction orders storeds as data arrays. Disallowed angles are set to ``np.nan``. """ @@ -3350,7 +3311,7 @@ def fields_cartesian(self) -> xr.Dataset: keys = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] return self._make_dataset(fields, keys) - def _make_dataset(self, fields: Tuple[np.ndarray, ...], keys: Tuple[str, ...]) -> xr.Dataset: + def _make_dataset(self, fields: tuple[np.ndarray, ...], keys: tuple[str, ...]) -> xr.Dataset: """Make an xr.Dataset for fields with given field names.""" data_arrays = [] for field in fields: @@ -3474,14 +3435,12 @@ class DirectivityData(FieldProjectionAngleData): ... 
Hr=scalar_field, Htheta=scalar_field, Hphi=scalar_field, projection_surfaces=monitor.projection_surfaces) # doctest: +SKIP """ - monitor: DirectivityMonitor = pd.Field( - ..., + monitor: DirectivityMonitor = Field( title="Monitor", description="Monitor describing the angle-based projection grid on which to measure directivity data.", ) - flux: FluxDataArray = pd.Field( - ..., + flux: FluxDataArray = Field( title="Flux", description="Flux values that are either computed from fields recorded on the " "projection surfaces or by integrating the projected fields over a spherical surface.", diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index af76beae4d..86081bcc68 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -6,12 +6,12 @@ import pathlib from abc import ABC from collections import defaultdict -from typing import Callable, Tuple, Union +from typing import Callable, Union import h5py import numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field from ...constants import C_0, inf from ...exceptions import DataError, FileError, Tidy3dKeyError @@ -25,7 +25,7 @@ from ..source.time import GaussianPulse from ..source.utils import SourceType from ..structure import Structure -from ..types import Ax, Axis, ColormapType, FieldVal, PlotScale, annotate_type +from ..types import Ax, Axis, ColormapType, FieldVal, PlotScale, discriminated_union from ..viz import add_ax_if_none, equal_aspect from .data_array import FreqDataArray from .monitor_data import ( @@ -35,10 +35,12 @@ MonitorDataTypes, ) -DATA_TYPE_MAP = {data.__fields__["monitor"].type_: data for data in MonitorDataTypes} +DATA_TYPE_MAP = {data.model_fields["monitor"].annotation: data for data in MonitorDataTypes} # maps monitor type (string) to the class of the corresponding data -DATA_TYPE_NAME_MAP = {val.__fields__["monitor"].type_.__name__: val for val in MonitorDataTypes} +DATA_TYPE_NAME_MAP = { + 
val.model_fields["monitor"].annotation.__name__: val for val in MonitorDataTypes +} # residuals below this are considered good fits for broadband adjoint source creation RESIDUAL_CUTOFF_ADJOINT = 1e-6 @@ -47,21 +49,18 @@ class AdjointSourceInfo(Tidy3dBaseModel): """Stores information about the adjoint sources to pass to autograd pipeline.""" - sources: Tuple[annotate_type(SourceType), ...] = pd.Field( - ..., + sources: tuple[discriminated_union(SourceType), ...] = Field( title="Adjoint Sources", description="Set of processed sources to include in the adjoint simulation.", ) - post_norm: Union[float, FreqDataArray] = pd.Field( - ..., + post_norm: Union[float, FreqDataArray] = Field( title="Post Normalization Values", description="Factor to multiply the adjoint fields by after running " "given the adjoint source pipeline used.", ) - normalize_sim: bool = pd.Field( - ..., + normalize_sim: bool = Field( title="Normalize Adjoint Simulation", description="Whether the adjoint simulation needs to be normalized " "given the adjoint source pipeline used.", @@ -906,20 +905,18 @@ class SimulationData(AbstractYeeGridSimulationData): """ - simulation: Simulation = pd.Field( - ..., + simulation: Simulation = Field( title="Simulation", description="Original :class:`.Simulation` associated with the data.", ) - data: Tuple[annotate_type(MonitorDataType), ...] = pd.Field( - ..., + data: tuple[discriminated_union(MonitorDataType), ...] 
= Field( title="Monitor Data", description="List of :class:`.MonitorData` instances " "associated with the monitors of the original :class:`.Simulation`.", ) - diverged: bool = pd.Field( + diverged: bool = Field( False, title="Diverged", description="A boolean flag denoting whether the simulation run diverged.", @@ -988,7 +985,7 @@ def source_spectrum_fn(freqs): return new_spectrum_fn(freqs) / old_spectrum_fn(freqs) # Make a new monitor_data dictionary with renormalized data - data_normalized = [mnt_data.normalize(source_spectrum_fn) for mnt_data in self.data] + data_normalized = tuple(mnt_data.normalize(source_spectrum_fn) for mnt_data in self.data) simulation = self.simulation.copy(update=dict(normalize_index=normalize_index)) @@ -1008,7 +1005,7 @@ def split_adjoint_data(self: SimulationData, num_mnts_original: int) -> tuple[li return data_original, data_adjoint - def split_original_fwd(self, num_mnts_original: int) -> Tuple[SimulationData, SimulationData]: + def split_original_fwd(self, num_mnts_original: int) -> tuple[SimulationData, SimulationData]: """Split this simulation data into original and fwd data from number of original mnts.""" # split the data and monitors into the original ones & adjoint gradient ones (for 'fwd') diff --git a/tidy3d/components/data/unstructured/base.py b/tidy3d/components/data/unstructured/base.py index cd27e32707..ae0d31664c 100644 --- a/tidy3d/components/data/unstructured/base.py +++ b/tidy3d/components/data/unstructured/base.py @@ -4,13 +4,14 @@ import numbers from abc import ABC, abstractmethod -from typing import Literal, Tuple, Union +from typing import Literal, Union import numpy as np -import pydantic.v1 as pd +from pandas import RangeIndex +from pydantic import Field, PositiveInt, field_validator, model_validator from xarray import DataArray as XrDataArray -from tidy3d.components.base import cached_property, skip_if_fields_missing +from tidy3d.components.base import cached_property from tidy3d.components.data.data_array 
import ( DATA_ARRAY_MAP, CellDataArray, @@ -34,20 +35,17 @@ class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC): """Abstract base for datasets that store unstructured grid data.""" - points: PointDataArray = pd.Field( - ..., + points: PointDataArray = Field( title="Grid Points", description="Coordinates of points composing the unstructured grid.", ) - values: IndexedDataArrayTypes = pd.Field( - ..., + values: IndexedDataArrayTypes = Field( title="Point Values", description="Values stored at the grid points.", ) - cells: CellDataArray = pd.Field( - ..., + cells: CellDataArray = Field( title="Grid Cells", description="Cells composing the unstructured grid specified as connections between grid " "points.", @@ -57,17 +55,18 @@ class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC) @classmethod @abstractmethod - def _point_dims(cls) -> pd.PositiveInt: + def _point_dims(cls) -> PositiveInt: """Dimensionality of stored grid point coordinates.""" @classmethod @abstractmethod - def _cell_num_vertices(cls) -> pd.PositiveInt: + def _cell_num_vertices(cls) -> PositiveInt: """Number of vertices in a cell.""" """ Validators """ - @pd.validator("points", always=True) + @field_validator("points") + @classmethod def points_right_dims(cls, val): """Check that point coordinates have the right dimensionality.""" # currently support only the standard axis ordering, that is 01(2) @@ -80,8 +79,8 @@ def points_right_dims(cls, val): ) return val - @pd.validator("points", always=True) - def points_right_indexing(cls, val): + @field_validator("points") + def points_right_indexing(val): """Check that points are indexed corrrectly.""" indices_expected = np.arange(len(val.data)) indices_given = val.index.data @@ -93,15 +92,15 @@ def points_right_indexing(cls, val): ) return val - @pd.validator("values", always=True) - def first_values_dim_is_index(cls, val): + @field_validator("values") + def first_values_dim_is_index(val): """Check that 
the number of data values matches the number of grid points.""" if val.dims[0] != "index": raise ValidationError("First dimension of array 'values' must be 'index'.") return val - @pd.validator("values", always=True) - def values_right_indexing(cls, val): + @field_validator("values") + def values_right_indexing(val): """Check that data values are indexed correctly.""" # currently support only simple ordered indexing of points, that is, 0, 1, 2, ... indices_expected = np.arange(len(val.index.data)) @@ -114,24 +113,21 @@ def values_right_indexing(cls, val): ) return val - @pd.validator("values", always=True) - @skip_if_fields_missing(["points"]) - def number_of_values_matches_points(cls, val, values): + @model_validator(mode="after") + def number_of_values_matches_points(self): """Check that the number of data values matches the number of grid points.""" - num_values = len(val.index) - - points = values.get("points") - num_points = len(points) + num_values = len(self.values.index) + num_points = len(self.points) if num_points != num_values: raise ValidationError( f"The number of data values ({num_values}) does not match the number of grid " f"points ({num_points})." 
) - return val + return self - @pd.validator("cells", always=True) - def match_cells_to_vtk_type(cls, val): + @field_validator("cells") + def match_cells_to_vtk_type(val): """Check that cell connections does not have duplicate points.""" if vtk is None: return val @@ -139,7 +135,8 @@ def match_cells_to_vtk_type(cls, val): # using val.astype(np.int32/64) directly causes issues when dataarray are later checked == return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords) - @pd.validator("cells", always=True) + @field_validator("cells") + @classmethod def cells_right_type(cls, val): """Check that cell are of the right type.""" # only supporting the standard ordering of cell vertices 012(3) @@ -152,18 +149,19 @@ def cells_right_type(cls, val): ) return val - @pd.validator("cells", always=True) - @skip_if_fields_missing(["points"]) - def check_cell_vertex_range(cls, val, values): + @model_validator(mode="after") + def check_cell_vertex_range(self): """Check that cell connections use only defined points.""" + val = getattr(self, "cells", None) + if val is None: + return self all_point_indices_used = val.data.ravel() # skip validation if zero size data if len(all_point_indices_used) > 0: min_index_used = np.min(all_point_indices_used) max_index_used = np.max(all_point_indices_used) - points = values.get("points") - num_points = len(points) + num_points = len(self.points) if max_index_used > num_points - 1 or min_index_used < 0: raise ValidationError( @@ -171,9 +169,9 @@ def check_cell_vertex_range(cls, val, values): f"[{min_index_used}, {max_index_used}]. The valid range of point indices is " f"[0, {num_points-1}]." 
) - return val + return self - @pd.validator("cells", always=True) + @field_validator("cells") def warn_degenerate_cells(cls, val): """Check that cell connections does not have duplicate points.""" degenerate_cells = cls._find_degenerate_cells(val) @@ -187,13 +185,14 @@ def warn_degenerate_cells(cls, val): ) return val - @pd.root_validator(pre=True, allow_reuse=True) - def _warn_if_none(cls, values): + @model_validator(mode="before") + @classmethod + def _warn_if_none(cls, data: dict) -> dict: """Warn if any of data arrays are not loaded.""" no_data_fields = [] for field_name in ["points", "cells", "values"]: - field = values.get(field_name) + field = data.get(field_name) if isinstance(field, str) and field in DATA_ARRAY_MAP.keys(): no_data_fields.append(field_name) if len(no_data_fields) > 0: @@ -201,20 +200,37 @@ def _warn_if_none(cls, values): log.warning( f"Loading {', '.join(formatted_names)} without data. Constructing an empty dataset." ) - values["points"] = PointDataArray( + data["points"] = PointDataArray( np.zeros((0, cls._point_dims())), dims=["index", "axis"] ) - values["cells"] = CellDataArray( + data["cells"] = CellDataArray( np.zeros((0, cls._cell_num_vertices())), dims=["cell_index", "vertex_index"] ) - values["values"] = IndexedDataArray(np.zeros(0), dims=["index"]) - return values - - @pd.root_validator(skip_on_failure=True, allow_reuse=True) - def _warn_unused_points(cls, values): + data["values"] = IndexedDataArray(np.zeros(0), dims=["index"]) + return data + + @model_validator(mode="before") + def _add_default_coords(cls, data: dict) -> dict: + def _add_default_coords(da): + """Add 0..N-1 coordinates to any dimension that does not already have one. + Note: We use a pandas `RangeIndex` here for constant memory. 
+ """ + missing = {d: RangeIndex(da.sizes[d]) for d in da.dims if d not in da.coords} + return da.assign_coords(missing) if missing else da + + if "points" in data: + data["points"] = _add_default_coords(data["points"]) + if "cells" in data: + data["cells"] = _add_default_coords(data["cells"]) + if "values" in data: + data["values"] = _add_default_coords(data["values"]) + return data + + @model_validator(mode="after") + def _warn_unused_points(self): """Warn if some points are unused.""" - point_indices = set(np.arange(len(values["points"].data))) - used_indices = set(values["cells"].values.ravel()) + point_indices = set(np.arange(len(self.points.data))) + used_indices = set(self.cells.values.ravel()) if not point_indices.issubset(used_indices): log.warning( @@ -222,7 +238,7 @@ def _warn_unused_points(cls, values): "Consider calling 'clean()' to remove them." ) - return values + return self """ Convenience properties """ @@ -571,7 +587,7 @@ def to_vtu(self, fname: str): def _get_values_from_vtk( cls, vtk_obj, - num_points: pd.PositiveInt, + num_points: PositiveInt, field: str = None, values_type=IndexedDataArray, expect_complex=None, @@ -720,7 +736,7 @@ def box_clip(self, bounds: Bound) -> UnstructuredGridDataset: Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -1380,18 +1396,18 @@ def _interp_py_general( def _interp_py_chunk( self, - xyz_grid: Tuple[ArrayLike[float], ...], + xyz_grid: tuple[ArrayLike[float], ...], cell_inds: ArrayLike[int], cell_ind_min: ArrayLike[int], cell_ind_max: ArrayLike[int], sdf_tol: float, - ) -> Tuple[Tuple[ArrayLike, ...], ArrayLike]: + ) -> tuple[tuple[ArrayLike, ...], ArrayLike]: """For each cell listed in ``cell_inds`` perform interpolation at a rectilinear subarray of xyz_grid given by a (3D) index span (cell_ind_min, cell_ind_max). 
Parameters ---------- - xyz_grid : Tuple[ArrayLike[float], ...] + xyz_grid : tuple[ArrayLike[float], ...] x, y, and z coordiantes defining rectilinear grid. cell_inds : ArrayLike[int] Indices of cells to perfrom interpolation from. @@ -1404,7 +1420,7 @@ def _interp_py_chunk( Returns ------- - Tuple[Tuple[ArrayLike, ...], ArrayLike] + tuple[tuple[ArrayLike, ...], ArrayLike] x, y, and z indices of interpolated values and values themselves. """ @@ -1712,7 +1728,7 @@ def sel_inside(self, bounds: Bound) -> UnstructuredGridDataset: Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -1779,7 +1795,7 @@ def does_cover(self, bounds: Bound) -> bool: Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns diff --git a/tidy3d/components/data/unstructured/tetrahedral.py b/tidy3d/components/data/unstructured/tetrahedral.py index 6c7d8ffc16..74c14de630 100644 --- a/tidy3d/components/data/unstructured/tetrahedral.py +++ b/tidy3d/components/data/unstructured/tetrahedral.py @@ -5,15 +5,11 @@ from typing import Union import numpy as np -import pydantic.v1 as pd +from pydantic import PositiveInt from xarray import DataArray as XrDataArray from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import ( - CellDataArray, - IndexedDataArray, - PointDataArray, -) +from tidy3d.components.data.data_array import CellDataArray, IndexedDataArray, PointDataArray from tidy3d.components.types import ArrayLike, Axis, Bound, Coordinate from tidy3d.exceptions import DataError from tidy3d.packaging import requires_vtk, vtk @@ -58,17 +54,17 @@ class TetrahedralGridDataset(UnstructuredGridDataset): """ Fundametal parameters to set up based on grid dimensionality """ @classmethod - def _traingular_dataset_type(cls) -> type: + def _triangular_dataset_type(cls) -> type: """Corresponding class for triangular grid datasets. 
We need to know this when creating a triangular slice from a tetrahedral grid.""" return TriangularGridDataset @classmethod - def _point_dims(cls) -> pd.PositiveInt: + def _point_dims(cls) -> PositiveInt: """Dimensionality of stored grid point coordinates.""" return 3 @classmethod - def _cell_num_vertices(cls) -> pd.PositiveInt: + def _cell_num_vertices(cls) -> PositiveInt: """Number of vertices in a cell.""" return 4 @@ -159,7 +155,7 @@ def plane_slice(self, axis: Axis, pos: float) -> TriangularGridDataset: slice_vtk = self._plane_slice_raw(axis=axis, pos=pos) - return self._traingular_dataset_type()._from_vtk_obj( + return self._triangular_dataset_type()._from_vtk_obj( slice_vtk, remove_degenerate_cells=True, remove_unused_points=True, diff --git a/tidy3d/components/data/unstructured/triangular.py b/tidy3d/components/data/unstructured/triangular.py index 21c2416da9..f8ddbf96b6 100644 --- a/tidy3d/components/data/unstructured/triangular.py +++ b/tidy3d/components/data/unstructured/triangular.py @@ -2,10 +2,11 @@ from __future__ import annotations -from typing import Dict, Literal, Union +from typing import Literal, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveInt +from xarray import DataArray as XrDataArray try: from matplotlib import pyplot as plt @@ -13,8 +14,6 @@ except ImportError: pass -from xarray import DataArray as XrDataArray - from tidy3d.components.base import cached_property from tidy3d.components.data.data_array import ( CellDataArray, @@ -72,14 +71,12 @@ class TriangularGridDataset(UnstructuredGridDataset): ... 
) """ - normal_axis: Axis = pd.Field( - ..., + normal_axis: Axis = Field( title="Grid Axis", description="Orientation of the grid.", ) - normal_pos: float = pd.Field( - ..., + normal_pos: float = Field( title="Position", description="Coordinate of the grid along the normal direction.", ) @@ -87,12 +84,12 @@ class TriangularGridDataset(UnstructuredGridDataset): """ Fundamental parameters to set up based on grid dimensionality """ @classmethod - def _point_dims(cls) -> pd.PositiveInt: + def _point_dims(cls) -> PositiveInt: """Dimensionality of stored grid point coordinates.""" return 2 @classmethod - def _cell_num_vertices(cls) -> pd.PositiveInt: + def _cell_num_vertices(cls) -> PositiveInt: """Number of vertices in a cell.""" return 3 @@ -577,8 +574,8 @@ def plot( vmin: float = None, vmax: float = None, shading: Literal["gourand", "flat"] = "gouraud", - cbar_kwargs: Dict = None, - pcolor_kwargs: Dict = None, + cbar_kwargs: dict = None, + pcolor_kwargs: dict = None, ) -> Ax: """Plot the data field and/or the unstructured grid. 
diff --git a/tidy3d/components/data/utils.py b/tidy3d/components/data/utils.py index 9e64765fd3..d0bf492f71 100644 --- a/tidy3d/components/data/utils.py +++ b/tidy3d/components/data/utils.py @@ -7,7 +7,7 @@ import numpy as np import xarray as xr -from ..types import ArrayLike, annotate_type +from ..types import ArrayLike, discriminated_union from .data_array import DataArray, SpatialDataArray from .unstructured.base import UnstructuredGridDataset from .unstructured.tetrahedral import TetrahedralGridDataset @@ -16,7 +16,10 @@ UnstructuredGridDatasetType = Union[TriangularGridDataset, TetrahedralGridDataset] CustomSpatialDataType = Union[SpatialDataArray, UnstructuredGridDatasetType] -CustomSpatialDataTypeAnnotated = Union[SpatialDataArray, annotate_type(UnstructuredGridDatasetType)] +CustomSpatialDataTypeAnnotated = Union[ + discriminated_union(UnstructuredGridDatasetType), + SpatialDataArray, +] def _get_numpy_array(data_array: Union[ArrayLike, DataArray, UnstructuredGridDataset]) -> ArrayLike: diff --git a/tidy3d/components/data/validators.py b/tidy3d/components/data/validators.py index 98d49ef56b..cc34764a02 100644 --- a/tidy3d/components/data/validators.py +++ b/tidy3d/components/data/validators.py @@ -1,7 +1,7 @@ # special validators for Datasets import numpy as np -import pydantic.v1 as pd +from pydantic import field_validator from ...exceptions import ValidationError from .data_array import DataArray @@ -9,11 +9,11 @@ # this can't go in validators.py because that file imports dataset.py -def validate_no_nans(field_name: str): +def validate_no_nans(*field_names: str): """Raise validation error if nans found in Dataset, or other data-containing item.""" - @pd.validator(field_name, always=True, allow_reuse=True) - def no_nans(cls, val): + @field_validator(*field_names) + def no_nans(val, info): """Raise validation error if nans found in Dataset, or other data-containing item.""" if val is None: @@ -44,15 +44,15 @@ def has_nans(values) -> bool: else: if 
has_nans(value): # the identifier is used to make the message more clear by appending some more info - field_name_display = field_name + field_name_display = info.field_name if identifier: field_name_display += identifier raise ValidationError( - f"Found NaN values in '{field_name_display}'. " + f"Found 'NaN' values in '{field_name_display}'. " "If they were not intended, please double check your construction. " - "If intended, to replace these data points with a value 'x'," - " call 'values = np.nan_to_num(values, nan=x)'." + "If intended, to replace these data points with a value 'x', " + "call 'values = np.nan_to_num(values, nan=x)'." ) error_if_has_nans(val) @@ -61,11 +61,11 @@ def has_nans(values) -> bool: return no_nans -def validate_can_interpolate(field_name: str): - """Make sure the data in 'field_name' can be interpolated.""" +def validate_can_interpolate(*field_names: str): + """Make sure the data in ``field_name`` can be interpolated.""" - @pd.validator(field_name, always=True, allow_reuse=True) - def check_fields_interpolate(cls, val: AbstractFieldDataset) -> AbstractFieldDataset: + @field_validator(*field_names) + def check_fields_interpolate(val: AbstractFieldDataset) -> AbstractFieldDataset: if isinstance(val, AbstractFieldDataset): for name, data in val.field_components.items(): if isinstance(data, ScalarFieldDataArray): diff --git a/tidy3d/components/data/zbf.py b/tidy3d/components/data/zbf.py index 947953d00c..7e4e4dbddf 100644 --- a/tidy3d/components/data/zbf.py +++ b/tidy3d/components/data/zbf.py @@ -5,7 +5,7 @@ from struct import unpack import numpy as np -import pydantic.v1 as pd +from pydantic import Field from ..base import Tidy3dBaseModel @@ -15,58 +15,58 @@ class ZBFData(Tidy3dBaseModel): Contains data read in from a ``.zbf`` file """ - version: int = pd.Field(title="Version", description="File format version number.") - nx: int = pd.Field(title="Samples in X", description="Number of samples in the x direction.") - ny: int = 
pd.Field(title="Samples in Y", description="Number of samples in the y direction.") - ispol: bool = pd.Field( + version: int = Field(title="Version", description="File format version number.") + nx: int = Field(title="Samples in X", description="Number of samples in the x direction.") + ny: int = Field(title="Samples in Y", description="Number of samples in the y direction.") + ispol: bool = Field( title="Is Polarized", description="``True`` if the beam is polarized, ``False`` otherwise.", ) - unit: str = pd.Field( + unit: str = Field( title="Spatial Units", description="Spatial units, either 'mm', 'cm', 'in', or 'm'." ) - dx: float = pd.Field(title="Grid Spacing, X", description="Grid spacing in x.") - dy: float = pd.Field(title="Grid Spacing, Y", description="Grid spacing in y.") - zposition_x: float = pd.Field( + dx: float = Field(title="Grid Spacing, X", description="Grid spacing in x.") + dy: float = Field(title="Grid Spacing, Y", description="Grid spacing in y.") + zposition_x: float = Field( title="Z Position, X Direction", description="The pilot beam z position with respect to the pilot beam waist, x direction.", ) - zposition_y: float = pd.Field( + zposition_y: float = Field( title="Z Position, Y Direction", description="The pilot beam z position with respect to the pilot beam waist, y direction.", ) - rayleigh_x: float = pd.Field( + rayleigh_x: float = Field( title="Rayleigh Distance, X Direction", description="The pilot beam Rayleigh distance in the x direction.", ) - rayleigh_y: float = pd.Field( + rayleigh_y: float = Field( title="Rayleigh Distance, Y Direction", description="The pilot beam Rayleigh distance in the y direction.", ) - waist_x: float = pd.Field( + waist_x: float = Field( title="Beam Waist, X", description="The pilot beam waist in the x direction." ) - waist_y: float = pd.Field( + waist_y: float = Field( title="Beam Waist, Y", description="The pilot beam waist in the y direction." 
) - wavelength: float = pd.Field(..., title="Wavelength", description="The wavelength of the beam.") - background_refractive_index: float = pd.Field( + wavelength: float = Field(title="Wavelength", description="The wavelength of the beam.") + background_refractive_index: float = Field( title="Background Refractive Index", description="The index of refraction in the current medium.", ) - receiver_eff: float = pd.Field( + receiver_eff: float = Field( title="Receiver Efficiency", description="The receiver efficiency. Zero if fiber coupling is not computed.", ) - system_eff: float = pd.Field( + system_eff: float = Field( title="System Efficiency", description="The system efficiency. Zero if fiber coupling is not computed.", ) - Ex: np.ndarray = pd.Field( + Ex: np.ndarray = Field( title="Electric Field, X Component", description="Complex-valued electric field, x component.", ) - Ey: np.ndarray = pd.Field( + Ey: np.ndarray = Field( title="Electric Field, Y Component", description="Complex-valued electric field, y component.", ) diff --git a/tidy3d/components/dispersion_fitter.py b/tidy3d/components/dispersion_fitter.py index 07fecf866e..b7f5974c39 100644 --- a/tidy3d/components/dispersion_fitter.py +++ b/tidy3d/components/dispersion_fitter.py @@ -2,17 +2,24 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Optional import numpy as np import scipy -from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator +from pydantic import ( + Field, + NonNegativeFloat, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) from rich.progress import Progress from ..constants import fp_eps from ..exceptions import ValidationError from ..log import get_logging_console, log -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .types import ArrayComplex1D, ArrayComplex2D, ArrayFloat1D, ArrayFloat2D # numerical tolerance 
for pole relocation for fast fitter @@ -111,7 +118,7 @@ def _extrema_loss_freq_finder(areal, aimag, creal, cimag): class AdvancedFastFitterParam(Tidy3dBaseModel): """Advanced fast fitter parameters.""" - loss_bounds: Tuple[float, float] = Field( + loss_bounds: tuple[float, float] = Field( (0, np.inf), title="Loss bounds", description="Bounds (lower, upper) on Im[resp]. Default corresponds to only passivity. " @@ -121,7 +128,7 @@ class AdvancedFastFitterParam(Tidy3dBaseModel): "A finite upper bound may be helpful when fitting lossless materials. " "In this case, consider also increasing the weight for fitting the imaginary part.", ) - weights: Tuple[NonNegativeFloat, NonNegativeFloat] = Field( + weights: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Weights", description="Weights (real, imag) in objective function for fitting. The weights " @@ -187,8 +194,8 @@ class AdvancedFastFitterParam(Tidy3dBaseModel): "There will be a warning if this value is too small.", ) - @validator("loss_bounds", always=True) - def _max_loss_geq_min_loss(cls, val): + @field_validator("loss_bounds") + def _max_loss_geq_min_loss(val): """Must have max_loss >= min_loss.""" if val[0] > val[1]: raise ValidationError( @@ -196,8 +203,8 @@ def _max_loss_geq_min_loss(cls, val): ) return val - @validator("weights", always=True) - def _weights_average_to_one(cls, val): + @field_validator("weights") + def _weights_average_to_one(val): """Weights must average to one.""" if val is None: return None @@ -210,25 +217,39 @@ class FastFitterData(AdvancedFastFitterParam): """Data class for internal use while running fitter.""" omega: ArrayComplex1D = Field( - ..., title="Angular frequencies in eV", description="Angular frequencies in eV" + title="Angular frequencies in eV", + description="Angular frequencies in eV", + ) + eps: ArrayComplex1D = Field( + title="Permittivity", + description="Permittivity to fit", ) - eps: ArrayComplex1D = Field(..., title="Permittivity", 
description="Permittivity to fit") - optimize_eps_inf: bool = Field( - None, title="Optimize eps_inf", description="Whether to optimize ``eps_inf``." + optimize_eps_inf: Optional[bool] = Field( + None, + title="Optimize eps_inf", + description="Whether to optimize ``eps_inf``.", ) - num_poles: PositiveInt = Field(None, title="Number of poles", description="Number of poles") - eps_inf: float = Field( + num_poles: Optional[PositiveInt] = Field( + None, + title="Number of poles", + description="Number of poles", + ) + eps_inf: Optional[float] = Field( None, title="eps_inf", description="Value of ``eps_inf``.", ) poles: Optional[ArrayComplex1D] = Field( - None, title="Pole frequencies in eV", description="Pole frequencies in eV" + None, + title="Pole frequencies in eV", + description="Pole frequencies in eV", ) residues: Optional[ArrayComplex1D] = Field( - None, title="Residues in eV", description="Residues in eV" + None, + title="Residues in eV", + description="Residues in eV", ) passivity_optimized: Optional[bool] = Field( @@ -253,37 +274,31 @@ class FastFitterData(AdvancedFastFitterParam): ) scale_factor: PositiveFloat = Field( - ..., title="Scale Factor", description="Factor by which frequencies have been rescaled prior to fitting. " "The ``pole_residue`` model returned will be rescaled by the inverse of this factor " "in order to restore it to the original units.", ) - @validator("eps_inf", always=True) - @skip_if_fields_missing(["optimize_eps_inf"]) - def _eps_inf_geq_one(cls, val, values): + @model_validator(mode="after") + def _eps_inf_geq_one(self): """Must have eps_inf >= 1 unless it is being optimized. 
In the latter case, it will be made >= 1 later.""" - if values["optimize_eps_inf"] is False and val < 1: + if self.optimize_eps_inf is False and self.eps_inf < 1: raise ValidationError("The value of 'eps_inf' must be at least 1.") - return val + return self - @validator("poles", always=True) - @skip_if_fields_missing(["logspacing", "smooth", "num_poles", "omega", "num_poles"]) - def _generate_initial_poles(cls, val, values): + @model_validator(mode="after") + def _generate_initial_poles(self): """Generate initial poles.""" + val = self.poles if val is not None: - return val - if ( - values.get("logspacing") is None - or values.get("smooth") is None - or values.get("num_poles") is None - ): - return None - omega = values["omega"] - num_poles = values["num_poles"] - if values["logspacing"]: + return self + if self.logspacing is None or self.smooth is None or self.num_poles is None: + return self + omega = self.omega + num_poles = self.num_poles + if self.logspacing: pole_range = np.logspace( np.log10(min(omega) / SCALE_FACTOR), np.log10(max(omega) * SCALE_FACTOR), num_poles ) @@ -291,22 +306,22 @@ def _generate_initial_poles(cls, val, values): pole_range = np.linspace( min(omega) / SCALE_FACTOR, max(omega) * SCALE_FACTOR, num_poles ) - if values["smooth"]: + if self.smooth: poles = -pole_range else: poles = -pole_range / 100 + 1j * pole_range - return poles + object.__setattr__(self, "poles", poles) + return self - @validator("residues", always=True) - @skip_if_fields_missing(["poles"]) - def _generate_initial_residues(cls, val, values): + @model_validator(mode="after") + def _generate_initial_residues(self): """Generate initial residues.""" - if val is not None: - return val - poles = values.get("poles") - if poles is None: - return None - return np.zeros(len(poles)) + if self.residues is not None: + return self + if self.poles is None: + return self + object.__setattr__(self, "residues", np.zeros(len(self.poles))) + return self @classmethod def initialize( @@ 
-350,7 +365,7 @@ def complex_poles(self) -> ArrayFloat1D: return self.poles[np.iscomplex(self.poles)] @classmethod - def get_default_weights(cls, eps: ArrayComplex1D) -> Tuple[float, float]: + def get_default_weights(cls, eps: ArrayComplex1D) -> tuple[float, float]: """Default weights based on real and imaginary part of eps.""" rms = np.array([np.sqrt(np.mean(x**2)) for x in (np.real(eps), np.imag(eps))]) rms = np.maximum(RMS_MIN, rms) @@ -360,7 +375,7 @@ def get_default_weights(cls, eps: ArrayComplex1D) -> Tuple[float, float]: return tuple(weights) @cached_property - def pole_residue(self) -> Tuple[float, ArrayComplex1D, ArrayComplex1D]: + def pole_residue(self) -> tuple[float, ArrayComplex1D, ArrayComplex1D]: """Parameters for pole-residue model in original units.""" if self.eps_inf is None or self.poles is None: return 1, [], [] @@ -647,7 +662,7 @@ def iterate_fit(self) -> FastFitterData: return model - def iterate_passivity(self, passivity_omega: ArrayFloat1D) -> Tuple[FastFitterData, int]: + def iterate_passivity(self, passivity_omega: ArrayFloat1D) -> tuple[FastFitterData, int]: """Iterate passivity enforcement algorithm.""" size = len(self.real_poles) + 2 * len(self.complex_poles) @@ -724,7 +739,7 @@ def enforce_passivity( def _fit_fixed_parameters( - num_poles_range: Tuple[PositiveInt, PositiveInt], model: FastFitterData + num_poles_range: tuple[PositiveInt, PositiveInt], model: FastFitterData ) -> FastFitterData: def fit_non_passive(model: FastFitterData) -> FastFitterData: best_model = model @@ -759,7 +774,7 @@ def fit( tolerance_rms: NonNegativeFloat = DEFAULT_TOLERANCE_RMS, advanced_param: AdvancedFastFitterParam = None, scale_factor: PositiveFloat = 1, -) -> Tuple[Tuple[float, ArrayComplex1D, ArrayComplex1D], float]: +) -> tuple[tuple[float, ArrayComplex1D, ArrayComplex1D], float]: """Fit data using a fast fitting algorithm. 
Note @@ -818,7 +833,7 @@ def fit( Returns ------- - Tuple[Tuple[float, ArrayComplex1D, ArrayComplex1D], float] + tuple[tuple[float, ArrayComplex1D, ArrayComplex1D], float] Best fitting result: (dispersive medium parameters, weighted RMS error). The dispersive medium parameters have the form (resp_inf, poles, residues) and are in the original unscaled units. @@ -963,7 +978,4 @@ def make_configs(): best_model.unweighted_rms_error, ) - return ( - best_model.pole_residue, - best_model.rms_error, - ) + return best_model.pole_residue, best_model.rms_error diff --git a/tidy3d/components/eme/data/dataset.py b/tidy3d/components/eme/data/dataset.py index f0dcd772c8..eb9bb37fba 100644 --- a/tidy3d/components/eme/data/dataset.py +++ b/tidy3d/components/eme/data/dataset.py @@ -1,8 +1,8 @@ """EME dataset""" -from __future__ import annotations +from typing import Optional -import pydantic.v1 as pd +from pydantic import Field from ...data.data_array import ( EMECoefficientDataArray, @@ -17,23 +17,19 @@ class EMESMatrixDataset(Dataset): """Dataset storing S matrix.""" - S11: EMESMatrixDataArray = pd.Field( - ..., + S11: EMESMatrixDataArray = Field( title="S11 matrix", description="S matrix relating output modes at port 1 to input modes at port 1.", ) - S12: EMESMatrixDataArray = pd.Field( - ..., + S12: EMESMatrixDataArray = Field( title="S12 matrix", description="S matrix relating output modes at port 1 to input modes at port 2.", ) - S21: EMESMatrixDataArray = pd.Field( - ..., + S21: EMESMatrixDataArray = Field( title="S21 matrix", description="S matrix relating output modes at port 2 to input modes at port 1.", ) - S22: EMESMatrixDataArray = pd.Field( - ..., + S22: EMESMatrixDataArray = Field( title="S22 matrix", description="S matrix relating output modes at port 2 to input modes at port 2.", ) @@ -44,13 +40,11 @@ class EMECoefficientDataset(Dataset): These are defined at the cell centers. 
""" - A: EMECoefficientDataArray = pd.Field( - ..., + A: EMECoefficientDataArray = Field( title="A coefficient", description="Coefficient for forward mode in this cell.", ) - B: EMECoefficientDataArray = pd.Field( - ..., + B: EMECoefficientDataArray = Field( title="B coefficient", description="Coefficient for backward mode in this cell.", ) @@ -59,32 +53,32 @@ class EMECoefficientDataset(Dataset): class EMEFieldDataset(ElectromagneticFieldDataset): """Dataset storing scalar components of E and H fields as a function of freq, mode_index, and port_index.""" - Ex: EMEScalarFieldDataArray = pd.Field( + Ex: Optional[EMEScalarFieldDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: EMEScalarFieldDataArray = pd.Field( + Ey: Optional[EMEScalarFieldDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: EMEScalarFieldDataArray = pd.Field( + Ez: Optional[EMEScalarFieldDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: EMEScalarFieldDataArray = pd.Field( + Hx: Optional[EMEScalarFieldDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: EMEScalarFieldDataArray = pd.Field( + Hy: Optional[EMEScalarFieldDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: EMEScalarFieldDataArray = pd.Field( + Hz: Optional[EMEScalarFieldDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", @@ -94,39 +88,32 @@ class EMEFieldDataset(ElectromagneticFieldDataset): class EMEModeSolverDataset(ElectromagneticFieldDataset): """Dataset storing EME modes as a function of freq, mode_index, and cell_index.""" - n_complex: 
EMEModeIndexDataArray = pd.Field( - ..., + n_complex: EMEModeIndexDataArray = Field( title="Propagation Index", description="Complex-valued effective propagation constants associated with the mode.", ) - Ex: EMEScalarModeFieldDataArray = pd.Field( - ..., + Ex: EMEScalarModeFieldDataArray = Field( title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: EMEScalarModeFieldDataArray = pd.Field( - ..., + Ey: EMEScalarModeFieldDataArray = Field( title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: EMEScalarModeFieldDataArray = pd.Field( - ..., + Ez: EMEScalarModeFieldDataArray = Field( title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: EMEScalarModeFieldDataArray = pd.Field( - ..., + Hx: EMEScalarModeFieldDataArray = Field( title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: EMEScalarModeFieldDataArray = pd.Field( - ..., + Hy: EMEScalarModeFieldDataArray = Field( title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: EMEScalarModeFieldDataArray = pd.Field( - ..., + Hz: EMEScalarModeFieldDataArray = Field( title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", ) diff --git a/tidy3d/components/eme/data/monitor_data.py b/tidy3d/components/eme/data/monitor_data.py index 7014562e11..75242158ce 100644 --- a/tidy3d/components/eme/data/monitor_data.py +++ b/tidy3d/components/eme/data/monitor_data.py @@ -1,10 +1,8 @@ """EME monitor data""" -from __future__ import annotations - from typing import Union -import pydantic.v1 as pd +from pydantic import Field from ...base_sim.data.monitor_data import AbstractMonitorData from ...data.monitor_data import ElectromagneticFieldData, ModeSolverData, PermittivityData @@ -15,8 +13,7 @@ class 
EMEModeSolverData(ElectromagneticFieldData, EMEModeSolverDataset): """Data associated with an EME mode solver monitor.""" - monitor: EMEModeSolverMonitor = pd.Field( - ..., + monitor: EMEModeSolverMonitor = Field( title="EME Mode Solver Monitor", description="EME mode solver monitor associated with this data.", ) @@ -25,16 +22,16 @@ class EMEModeSolverData(ElectromagneticFieldData, EMEModeSolverDataset): class EMEFieldData(ElectromagneticFieldData, EMEFieldDataset): """Data associated with an EME field monitor.""" - monitor: EMEFieldMonitor = pd.Field( - ..., title="EME Field Monitor", description="EME field monitor associated with this data." + monitor: EMEFieldMonitor = Field( + title="EME Field Monitor", + description="EME field monitor associated with this data.", ) class EMECoefficientData(AbstractMonitorData, EMECoefficientDataset): """Data associated with an EME coefficient monitor.""" - monitor: EMECoefficientMonitor = pd.Field( - ..., + monitor: EMECoefficientMonitor = Field( title="EME Coefficient Monitor", description="EME coefficient monitor associated with this data.", ) diff --git a/tidy3d/components/eme/data/sim_data.py b/tidy3d/components/eme/data/sim_data.py index ef15a45c55..1734759660 100644 --- a/tidy3d/components/eme/data/sim_data.py +++ b/tidy3d/components/eme/data/sim_data.py @@ -1,18 +1,16 @@ """EME simulation data""" -from __future__ import annotations - -from typing import List, Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field from ....exceptions import SetupError from ...base import cached_property from ...data.data_array import EMEScalarFieldDataArray, EMESMatrixDataArray from ...data.monitor_data import FieldData, ModeData, ModeSolverData from ...data.sim_data import AbstractYeeGridSimulationData -from ...types import annotate_type +from ...types import discriminated_union from ..simulation import EMESimulation from .dataset import 
EMESMatrixDataset from .monitor_data import EMEFieldData, EMEModeSolverData, EMEMonitorDataType @@ -21,22 +19,24 @@ class EMESimulationData(AbstractYeeGridSimulationData): """Data associated with an EME simulation.""" - simulation: EMESimulation = pd.Field( - ..., title="EME simulation", description="EME simulation associated with this data." + simulation: EMESimulation = Field( + title="EME simulation", + description="EME simulation associated with this data.", ) - data: Tuple[annotate_type(EMEMonitorDataType), ...] = pd.Field( - ..., + data: tuple[discriminated_union(EMEMonitorDataType), ...] = Field( title="Monitor Data", description="List of EME monitor data " "associated with the monitors of the original :class:`.EMESimulation`.", ) - smatrix: Optional[EMESMatrixDataset] = pd.Field( - None, title="S Matrix", description="Scattering matrix of the EME simulation." + smatrix: Optional[EMESMatrixDataset] = Field( + None, + title="S Matrix", + description="Scattering matrix of the EME simulation.", ) - port_modes: Optional[EMEModeSolverData] = pd.Field( + port_modes: Optional[EMEModeSolverData] = Field( None, title="Port Modes", description="Modes associated with the two ports of the EME device. " @@ -78,7 +78,7 @@ def _extract_mode_solver_data( return ModeSolverData(**update_dict, monitor=monitor, grid_expanded=grid_expanded) @cached_property - def port_modes_tuple(self) -> Tuple[ModeSolverData, ModeSolverData]: + def port_modes_tuple(self) -> tuple[ModeSolverData, ModeSolverData]: """Port modes as a tuple ``(port_modes_1, port_modes_2)``.""" if self.port_modes is None: raise SetupError( @@ -101,7 +101,7 @@ def port_modes_tuple(self) -> Tuple[ModeSolverData, ModeSolverData]: return port_modes_1, port_modes_2 @cached_property - def port_modes_list_sweep(self) -> List[Tuple[ModeSolverData, ModeSolverData]]: + def port_modes_list_sweep(self) -> list[tuple[ModeSolverData, ModeSolverData]]: """Port modes as a list of tuples ``(port_modes_1, port_modes_2)``. 
There is one entry for every sweep index if the port modes vary with sweep index.""" if self.port_modes is None: diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py index 18c0403b3e..d782e6f819 100644 --- a/tidy3d/components/eme/grid.py +++ b/tidy3d/components/eme/grid.py @@ -3,14 +3,21 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveInt, + field_validator, + model_validator, +) from ...constants import RADIAN, fp_eps, inf from ...exceptions import SetupError, ValidationError -from ..base import Tidy3dBaseModel, skip_if_fields_missing +from ..base import Tidy3dBaseModel from ..geometry.base import Box from ..grid.grid import Coords1D from ..mode_spec import ModeSpec @@ -26,7 +33,7 @@ class EMEModeSpec(ModeSpec): """Mode spec for EME cells. Overrides some of the defaults and allowed values.""" - track_freq: Union[TrackFreq, None] = pd.Field( + track_freq: Optional[Union[TrackFreq]] = Field( None, title="Mode Tracking Frequency", description="Parameter that turns on/off mode tracking based on their similarity. " @@ -35,7 +42,7 @@ class EMEModeSpec(ModeSpec): "If ``None`` no mode tracking is performed, which is the default for best performance.", ) - angle_theta: Literal[0.0] = pd.Field( + angle_theta: Literal[0.0] = Field( 0.0, title="Polar Angle", description="Polar angle of the propagation axis from the injection axis. Not currently " @@ -44,7 +51,7 @@ class EMEModeSpec(ModeSpec): units=RADIAN, ) - angle_phi: Literal[0.0] = pd.Field( + angle_phi: Literal[0.0] = Field( 0.0, title="Azimuth Angle", description="Azimuth angle of the propagation axis in the plane orthogonal to the " @@ -98,7 +105,7 @@ class EMEGridSpec(Tidy3dBaseModel, ABC): in the simulation. 
""" - num_reps: pd.PositiveInt = pd.Field( + num_reps: PositiveInt = Field( 1, title="Number of Repetitions", description="Number of periodic repetitions of this EME grid. Useful for " @@ -107,12 +114,14 @@ class EMEGridSpec(Tidy3dBaseModel, ABC): "the EME solver to reuse the modes and cell interface scattering matrices.", ) - name: Optional[str] = pd.Field( - None, title="Name", description="Name of this 'EMEGridSpec'. Used in 'EMEPeriodicitySweep'." + name: Optional[str] = Field( + None, + title="Name", + description="Name of this 'EMEGridSpec'. Used in 'EMEPeriodicitySweep'.", ) - @pd.validator("num_reps", always=True) - def _validate_num_reps(cls, val): + @field_validator("num_reps") + def _validate_num_reps(val): """Check num_reps is not too large.""" if val > MAX_NUM_REPS: raise SetupError( @@ -163,7 +172,7 @@ def num_virtual_cells(self) -> int: """Number of virtual cells in this EME grid spec.""" return len(self.virtual_cell_indices) - def _updated_copy_num_reps(self, num_reps: Dict[str, pd.PositiveInt]) -> EMEGridSpec: + def _updated_copy_num_reps(self, num_reps: dict[str, PositiveInt]) -> EMEGridSpec: """Update ``num_reps`` of named subgrids.""" if self.name is not None: new_num_reps = num_reps.get(self.name) @@ -172,7 +181,7 @@ def _updated_copy_num_reps(self, num_reps: Dict[str, pd.PositiveInt]) -> EMEGrid return self @property - def _cell_index_pairs(self) -> List[pd.NonNegativeInt]: + def _cell_index_pairs(self) -> list[NonNegativeInt]: """Pairs of adjacent cell indices.""" cell_indices = self.virtual_cell_indices pairs = [] @@ -192,12 +201,14 @@ class EMEUniformGrid(EMEGridSpec): >>> eme_grid = EMEUniformGrid(num_cells=10, mode_spec=mode_spec) """ - num_cells: pd.PositiveInt = pd.Field( - ..., title="Number of cells", description="Number of cells in the uniform EME grid." 
+ num_cells: PositiveInt = Field( + title="Number of cells", + description="Number of cells in the uniform EME grid.", ) - mode_spec: EMEModeSpec = pd.Field( - ..., title="Mode Specification", description="Mode specification for the uniform EME grid." + mode_spec: EMEModeSpec = Field( + title="Mode Specification", + description="Mode specification for the uniform EME grid.", ) def make_grid(self, center: Coordinate, size: Size, axis: Axis) -> EMEGrid: @@ -246,14 +257,12 @@ class EMEExplicitGrid(EMEGridSpec): ... ) """ - mode_specs: List[EMEModeSpec] = pd.Field( - ..., + mode_specs: list[EMEModeSpec] = Field( title="Mode Specifications", description="Mode specifications for each cell " "in the explicit EME grid.", ) - boundaries: ArrayFloat1D = pd.Field( - ..., + boundaries: ArrayFloat1D = Field( title="Boundaries", description="List of coordinates of internal cell boundaries along the propagation axis. " "Must contain one fewer item than 'mode_specs', and must be strictly increasing. " @@ -262,11 +271,11 @@ class EMEExplicitGrid(EMEGridSpec): "and the simulation boundary.", ) - @pd.validator("boundaries", always=True) - @skip_if_fields_missing(["mode_specs"]) - def _validate_boundaries(cls, val, values): + @model_validator(mode="after") + def _validate_boundaries(self): """Check that boundaries is increasing and contains one fewer element than mode_specs.""" - mode_specs = values["mode_specs"] + val = self.boundaries + mode_specs = self.mode_specs boundaries = val if len(mode_specs) - 1 != len(boundaries): raise ValidationError( @@ -278,7 +287,7 @@ def _validate_boundaries(cls, val, values): if rmax < rmin: raise ValidationError("The 'boundaries' must be increasing.") rmin = rmax - return val + return self def make_grid(self, center: Coordinate, size: Size, axis: Axis) -> EMEGrid: """Generate EME grid from the EME grid spec. 
@@ -321,7 +330,7 @@ def make_grid(self, center: Coordinate, size: Size, axis: Axis) -> EMEGrid: @classmethod def from_structures( - cls, structures: List[Structure], axis: Axis, mode_spec: EMEModeSpec, **kwargs + cls, structures: list[Structure], axis: Axis, mode_spec: EMEModeSpec, **kwargs ) -> EMEExplicitGrid: """Create an explicit EME grid with boundaries aligned with structure bounding boxes. Every cell in the resulting grid @@ -329,7 +338,7 @@ def from_structures( Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] A list of structures to define the :class:`.EMEExplicitGrid`. The EME grid boundaries will be placed at the lower and upper bounds of the bounding boxes of all the structures in the list. @@ -398,12 +407,12 @@ class EMECompositeGrid(EMEGridSpec): ... ) """ - subgrids: List[EMESubgridType] = pd.Field( - ..., title="Subgrids", description="Subgrids in the composite grid." + subgrids: list[EMESubgridType] = Field( + title="Subgrids", + description="Subgrids in the composite grid.", ) - subgrid_boundaries: ArrayFloat1D = pd.Field( - ..., + subgrid_boundaries: ArrayFloat1D = Field( title="Subgrid Boundaries", description="List of coordinates of internal subgrid boundaries along the propagation axis. " "Must contain one fewer item than 'subgrids', and must be strictly increasing. 
" @@ -412,10 +421,11 @@ class EMECompositeGrid(EMEGridSpec): "and the simulation boundary.", ) - @pd.validator("subgrid_boundaries", always=True) - def _validate_subgrid_boundaries(cls, val, values): + @model_validator(mode="after") + def _validate_subgrid_boundaries(self): """Check that subgrid boundaries is increasing and contains one fewer element than subgrids.""" - subgrids = values["subgrids"] + val = self.subgrid_boundaries + subgrids = self.subgrids subgrid_boundaries = val if len(subgrids) - 1 != len(subgrid_boundaries): raise ValidationError( @@ -427,11 +437,11 @@ def _validate_subgrid_boundaries(cls, val, values): if rmax < rmin: raise ValidationError("The 'subgrid_boundaries' must be increasing.") rmin = rmax - return val + return self def subgrid_bounds( self, center: Coordinate, size: Size, axis: Axis - ) -> List[Tuple[float, float]]: + ) -> list[tuple[float, float]]: """Subgrid bounds: a list of pairs (rmin, rmax) of the bounds of the subgrids along the propagation axis. @@ -446,7 +456,7 @@ def subgrid_bounds( Returns ------- - List[Tuple[float, float]] + list[tuple[float, float]] A list of pairs (rmin, rmax) of the bounds of the subgrids along the propagation axis. 
""" @@ -524,7 +534,7 @@ def virtual_cell_indices(self) -> int: inds += [ind + start_ind for ind in subgrid.virtual_cell_indices] return list(inds) * self.num_reps - def _updated_copy_num_reps(self, num_reps: Dict[str, pd.PositiveInt]) -> EMEGridSpec: + def _updated_copy_num_reps(self, num_reps: dict[str, PositiveInt]) -> EMEGridSpec: """Update ``num_reps`` of named subgrids.""" new_self = super()._updated_copy_num_reps(num_reps=num_reps) new_subgrids = [ @@ -535,18 +545,18 @@ def _updated_copy_num_reps(self, num_reps: Dict[str, pd.PositiveInt]) -> EMEGrid @classmethod def from_structure_groups( cls, - structure_groups: List[List[Structure]], + structure_groups: list[list[Structure]], axis: Axis, - mode_specs: List[EMEModeSpec], - names: List[str] = None, - num_reps: List[pd.PositiveInt] = None, + mode_specs: list[EMEModeSpec], + names: list[str] = None, + num_reps: list[PositiveInt] = None, ) -> EMECompositeGrid: """Create a composite EME grid with boundaries aligned with structure bounding boxes. Parameters ---------- - structure_groups : List[List[:class:`.Structure`]] + structure_groups : list[list[:class:`.Structure`]] A list of structure groups to define the :class:`.EMECompositeGrid`. Each structure group will be used to generate an :class:`.EMEExplicitGrid` with boundaries aligned with the bounding boxes of the structures @@ -559,13 +569,13 @@ def from_structure_groups( Two adjacent structure groups cannot be empty. axis : :class:`.Axis` Propagation axis for the EME simulation. - mode_specs : List[:class:`.EMEModeSpec`] + mode_specs : list[:class:`.EMEModeSpec`] Mode specifications for each subgrid. Must be the same length as ``structure_groups``. - names : List[str] = None + names : list[str] = None Names for each subgrid. Must be the same length as ``structure_groups``. If ``None``, the subgrids do not recieve names. - num_reps : List[pd.PositiveInt] = None + num_reps : list[PositiveInt] = None Number of repetitions for each subgrid. 
Must be the same length as ``structure_groups``. If ``None``, the subgrids are not repeated. @@ -672,20 +682,23 @@ class EMEGrid(Box): in the simulation. """ - axis: Axis = pd.Field( - ..., title="Propagation axis", description="Propagation axis for the EME simulation." + axis: Axis = Field( + title="Propagation axis", + description="Propagation axis for the EME simulation.", ) - mode_specs: List[EMEModeSpec] = pd.Field( - ..., title="Mode Specifications", description="Mode specifications for the EME cells." + mode_specs: list[EMEModeSpec] = Field( + title="Mode Specifications", + description="Mode specifications for the EME cells.", ) - boundaries: Coords1D = pd.Field( - ..., title="Cell boundaries", description="Boundary coordinates of the EME cells." + boundaries: Coords1D = Field( + title="Cell boundaries", + description="Boundary coordinates of the EME cells.", ) - @pd.validator("mode_specs", always=True) - def _validate_size(cls, val): + @field_validator("mode_specs") + def _validate_size(val): """Check grid size and num modes.""" num_eme_cells = len(val) if num_eme_cells > MAX_NUM_EME_CELLS: @@ -702,16 +715,15 @@ def _validate_size(cls, val): ) return val - @pd.validator("boundaries", always=True, pre=False) - @skip_if_fields_missing(["mode_specs", "axis", "center", "size"]) - def _validate_boundaries(cls, val, values): + @model_validator(mode="after") + def _validate_boundaries(self): """Check that boundaries is increasing, in simulation domain, and contains one more element than 'mode_specs'.""" - mode_specs = values["mode_specs"] - boundaries = val - axis = values["axis"] - center = values["center"][axis] - size = values["size"][axis] + boundaries = self.boundaries + mode_specs = self.mode_specs + axis = self.axis + center = self.center[axis] + size = self.size[axis] sim_rmin = center - size / 2 sim_rmax = center + size / 2 if len(mode_specs) + 1 != len(boundaries): @@ -730,7 +742,7 @@ def _validate_boundaries(cls, val, values): rmin = rmax if rmax - 
sim_rmax > fp_eps: raise ValidationError("The last item in 'boundaries' is outside the simulation domain.") - return val + return self @property def centers(self) -> Coords1D: @@ -744,7 +756,7 @@ def centers(self) -> Coords1D: return centers @property - def lengths(self) -> List[pd.NonNegativeFloat]: + def lengths(self) -> list[NonNegativeFloat]: """Lengths of the EME cells along the propagation axis.""" rmin = self.boundaries[0] lengths = [] @@ -755,12 +767,12 @@ def lengths(self) -> List[pd.NonNegativeFloat]: return lengths @property - def num_cells(self) -> pd.NonNegativeInteger: + def num_cells(self) -> NonNegativeInt: """The number of cells in the EME grid.""" return len(self.centers) @property - def mode_planes(self) -> List[Box]: + def mode_planes(self) -> list[Box]: """Planes for mode solving, aligned with cell centers.""" size = [inf, inf, inf] center = list(self.center) @@ -773,7 +785,7 @@ def mode_planes(self) -> List[Box]: return mode_planes @property - def boundary_planes(self) -> List[Box]: + def boundary_planes(self) -> list[Box]: """Planes aligned with cell boundaries.""" size = list(self.size) center = list(self.center) @@ -786,7 +798,7 @@ def boundary_planes(self) -> List[Box]: return boundary_planes @property - def cells(self) -> List[Box]: + def cells(self) -> list[Box]: """EME cells in the grid. Each cell is a :class:`.Box`.""" size = list(self.size) center = list(self.center) @@ -798,7 +810,7 @@ def cells(self) -> List[Box]: cells.append(Box(center=center, size=size)) return cells - def cell_indices_in_box(self, box: Box) -> List[pd.NonNegativeInteger]: + def cell_indices_in_box(self, box: Box) -> list[NonNegativeInt]: """Indices of cells that overlap with 'box'. Used to determine which data is recorded by a monitor. @@ -809,7 +821,7 @@ def cell_indices_in_box(self, box: Box) -> List[pd.NonNegativeInteger]: Returns ------- - List[pd.NonNegativeInteger] + list[NonNegativeInteger] The indices of the cells that intersect the provided box. 
""" indices = [] diff --git a/tidy3d/components/eme/monitor.py b/tidy3d/components/eme/monitor.py index aad5e5eb7f..3ffbea3b00 100644 --- a/tidy3d/components/eme/monitor.py +++ b/tidy3d/components/eme/monitor.py @@ -1,11 +1,9 @@ """EME monitors""" -from __future__ import annotations - from abc import ABC, abstractmethod -from typing import Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union -import pydantic.v1 as pd +from pydantic import Field, NonNegativeInt, PositiveInt from ..base_sim.monitor import AbstractMonitor from ..monitor import AbstractFieldMonitor, ModeSolverMonitor, PermittivityMonitor @@ -17,7 +15,7 @@ class EMEMonitor(AbstractMonitor, ABC): """Abstract EME monitor.""" - freqs: Optional[FreqArray] = pd.Field( + freqs: Optional[FreqArray] = Field( None, title="Monitor Frequencies", description="Frequencies at which the monitor will record. " @@ -25,7 +23,7 @@ class EMEMonitor(AbstractMonitor, ABC): "A value of 'None' will record at all simulation 'freqs'.", ) - num_modes: Optional[pd.NonNegativeInt] = pd.Field( + num_modes: Optional[NonNegativeInt] = Field( None, title="Number of Modes", description="Maximum number of modes for the monitor to record. " @@ -33,7 +31,7 @@ class EMEMonitor(AbstractMonitor, ABC): "A value of 'None' will record all modes.", ) - num_sweep: Optional[pd.NonNegativeInt] = pd.Field( + num_sweep: Optional[NonNegativeInt] = Field( 1, title="Number of Sweep Indices", description="Number of sweep indices for the monitor to record. " @@ -42,7 +40,7 @@ class EMEMonitor(AbstractMonitor, ABC): "will be omitted. A value of 'None' will record all sweep indices.", ) - interval_space: Tuple[Literal[1], Literal[1], Literal[1]] = pd.Field( + interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field( (1, 1, 1), title="Spatial Interval", description="Number of grid step intervals between monitor recordings. 
If equal to 1, " @@ -51,7 +49,7 @@ class EMEMonitor(AbstractMonitor, ABC): "Not all monitors support values different from 1.", ) - eme_cell_interval_space: Literal[1] = pd.Field( + eme_cell_interval_space: Literal[1] = Field( 1, title="EME Cell Interval", description="Number of eme cells between monitor recordings. If equal to 1, " @@ -60,7 +58,7 @@ class EMEMonitor(AbstractMonitor, ABC): "Not all monitors support values different from 1.", ) - colocate: Literal[True] = pd.Field( + colocate: Literal[True] = Field( True, title="Colocate Fields", description="Defines whether fields are colocated to grid cell boundaries (i.e. to the " @@ -124,7 +122,7 @@ class EMEModeSolverMonitor(EMEMonitor): ... ) """ - interval_space: Tuple[Literal[1], Literal[1], Literal[1]] = pd.Field( + interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field( (1, 1, 1), title="Spatial Interval", description="Note: not yet supported. Number of grid step intervals between monitor recordings. If equal to 1, " @@ -133,7 +131,7 @@ class EMEModeSolverMonitor(EMEMonitor): "in the propagation direction is not used. Note: this is not yet supported.", ) - eme_cell_interval_space: pd.PositiveInt = pd.Field( + eme_cell_interval_space: PositiveInt = Field( 1, title="EME Cell Interval", description="Number of eme cells between monitor recordings. If equal to 1, " @@ -142,20 +140,20 @@ class EMEModeSolverMonitor(EMEMonitor): "Not all monitors support values different from 1.", ) - colocate: bool = pd.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " "primal grid nodes). 
Default (False) is used internally in EME propagation.", ) - normalize: bool = pd.Field( + normalize: bool = Field( True, title="Normalize Modes", description="Whether to normalize the EME modes to unity flux.", ) - keep_invalid_modes: bool = pd.Field( + keep_invalid_modes: bool = Field( False, title="Keep Invalid Modes", description="Whether to store modes containing nan values and modes which are " @@ -198,7 +196,7 @@ class EMEFieldMonitor(EMEMonitor, AbstractFieldMonitor): ... ) """ - interval_space: Tuple[Literal[1], Literal[1], Literal[1]] = pd.Field( + interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field( (1, 1, 1), title="Spatial Interval", description="Note: not yet supported. Number of grid step intervals between monitor recordings. If equal to 1, " @@ -206,7 +204,7 @@ class EMEFieldMonitor(EMEMonitor, AbstractFieldMonitor): "first and last point of the monitor grid are always included.", ) - eme_cell_interval_space: Literal[1] = pd.Field( + eme_cell_interval_space: Literal[1] = Field( 1, title="EME Cell Interval", description="Number of eme cells between monitor recordings. If equal to 1, " @@ -216,14 +214,14 @@ class EMEFieldMonitor(EMEMonitor, AbstractFieldMonitor): "EME field monitor.", ) - colocate: bool = pd.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " "primal grid nodes). Default (False) is used internally in EME propagation.", ) - num_modes: Optional[pd.NonNegativeInt] = pd.Field( + num_modes: Optional[NonNegativeInt] = Field( None, title="Number of Modes", description="Maximum number of modes for the monitor to record. " @@ -262,7 +260,7 @@ class EMECoefficientMonitor(EMEMonitor): ... 
) """ - interval_space: Tuple[Literal[1], Literal[1], Literal[1]] = pd.Field( + interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field( (1, 1, 1), title="Spatial Interval", description="Number of grid step intervals between monitor recordings. If equal to 1, " @@ -272,7 +270,7 @@ class EMECoefficientMonitor(EMEMonitor): "for 'EMECoefficientMonitor'.", ) - eme_cell_interval_space: pd.PositiveInt = pd.Field( + eme_cell_interval_space: PositiveInt = Field( 1, title="EME Cell Interval", description="Number of eme cells between monitor recordings. If equal to 1, " diff --git a/tidy3d/components/eme/simulation.py b/tidy3d/components/eme/simulation.py index a624ad9024..2e08fb6f35 100644 --- a/tidy3d/components/eme/simulation.py +++ b/tidy3d/components/eme/simulation.py @@ -2,14 +2,21 @@ from __future__ import annotations -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union + +import numpy as np +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveInt, + field_validator, +) try: import matplotlib as mpl except ImportError: pass -import numpy as np -import pydantic.v1 as pd from ...constants import C_0 from ...exceptions import SetupError, ValidationError @@ -23,7 +30,7 @@ from ..monitor import AbstractModeMonitor, ModeSolverMonitor, Monitor, MonitorType from ..scene import Scene from ..simulation import AbstractYeeGridSimulation, Simulation, validate_boundaries_for_zero_dims -from ..types import Ax, Axis, FreqArray, Symmetry, annotate_type +from ..types import Ax, Axis, FreqArray, Symmetry, discriminated_union from ..validators import ( MIN_FREQUENCY, validate_freqs_min, @@ -153,8 +160,7 @@ class EMESimulation(AbstractYeeGridSimulation): * `EME Solver Demonstration <../../notebooks/docs/features/eme.rst>`_ """ - freqs: FreqArray = pd.Field( - ..., + freqs: FreqArray = Field( title="Frequencies", description="Frequencies for the EME simulation. 
" "The field is propagated independently at each provided frequency. " @@ -163,14 +169,12 @@ class EMESimulation(AbstractYeeGridSimulation): "instead of providing all desired frequencies here.", ) - axis: Axis = pd.Field( - ..., + axis: Axis = Field( title="Propagation Axis", description="Propagation axis (0, 1, or 2) for the EME simulation.", ) - eme_grid_spec: EMEGridSpecType = pd.Field( - ..., + eme_grid_spec: EMEGridSpecType = Field( title="EME Grid Specification", description="Specification for the EME propagation grid. " "The simulation is divided into cells in the propagation direction; " @@ -181,15 +185,15 @@ class EMESimulation(AbstractYeeGridSimulation): "tangential directions, as well as the grid used for field monitors.", ) - monitors: Tuple[annotate_type(EMEMonitorType), ...] = pd.Field( + monitors: tuple[discriminated_union(EMEMonitorType), ...] = Field( (), title="Monitors", description="Tuple of monitors in the simulation. " "Note: monitor names are used to access data after simulation is run.", ) - boundary_spec: BoundarySpec = pd.Field( - BoundarySpec.all_sides(PECBoundary()), + boundary_spec: BoundarySpec = Field( + default_factory=lambda: BoundarySpec.all_sides(PECBoundary()), title="Boundaries", description="Specification of boundary conditions along each dimension. " "By default, PEC boundary conditions are applied on all sides. " @@ -199,7 +203,7 @@ class EMESimulation(AbstractYeeGridSimulation): "apply PML layers in the mode solver.", ) - sources: Tuple[None, ...] = pd.Field( + sources: tuple[None, ...] = Field( (), title="Sources", description="Sources in the simulation. 
NOTE: sources are not currently supported " @@ -208,8 +212,8 @@ class EMESimulation(AbstractYeeGridSimulation): "use 'smatrix_in_basis' to use another set of modes or input field.", ) - grid_spec: GridSpec = pd.Field( - GridSpec(), + grid_spec: GridSpec = Field( + default_factory=GridSpec, title="Grid Specification", description="Specifications for the simulation grid along each of the three directions. " "This is distinct from 'eme_grid_spec', which defines the 1D EME grid in the " @@ -217,28 +221,28 @@ class EMESimulation(AbstractYeeGridSimulation): validate_default=True, ) - store_port_modes: bool = pd.Field( + store_port_modes: bool = Field( True, title="Store Port Modes", description="Whether to store the modes associated with the two ports. " "Required to find scattering matrix in basis besides the computational basis.", ) - normalize: bool = pd.Field( + normalize: bool = Field( True, title="Normalize Scattering Matrix", description="Whether to normalize the port modes to unity flux, " "thereby normalizing the scattering matrix and expansion coefficients.", ) - port_offsets: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + port_offsets: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, 0), title="Port Offsets", description="Offsets for the two ports, relative to the simulation bounds " "along the propagation axis.", ) - sweep_spec: Optional[EMESweepSpecType] = pd.Field( + sweep_spec: Optional[EMESweepSpecType] = Field( None, title="EME Sweep Specification", description="Specification for a parameter sweep to be performed during the EME " @@ -246,7 +250,7 @@ class EMESimulation(AbstractYeeGridSimulation): "in 'sim_data.smatrix'. Other simulation monitor data is not included in the sweep.", ) - constraint: Optional[Literal["passive", "unitary"]] = pd.Field( + constraint: Optional[Literal["passive", "unitary"]] = Field( "passive", title="EME Constraint", description="Constraint for EME propagation, imposed at cell interfaces. 
" @@ -260,21 +264,21 @@ class EMESimulation(AbstractYeeGridSimulation): _freqs_not_empty = validate_freqs_not_empty() _freqs_lower_bound = validate_freqs_min() - @pd.validator("grid_spec", always=True) - def _validate_auto_grid_wavelength(cls, val, values): + @field_validator("grid_spec") + def _validate_auto_grid_wavelength(val): """Handle the case where grid_spec is auto and wavelength is not provided.""" # this is handled instead post-init to ensure freqs is defined return val - @pd.validator("freqs", always=True) - def _validate_freqs(cls, val): + @field_validator("freqs") + def _validate_freqs(val): """Freqs cannot contain duplicates.""" if len(set(val)) != len(val): raise SetupError(f"'EMESimulation' 'freqs={val}' cannot contain duplicate frequencies.") return val - @pd.validator("structures", always=True) - def _validate_structures(cls, val): + @field_validator("structures") + def _validate_structures(val): """Validate and warn for certain medium types.""" for ind, structure in enumerate(val): medium = structure.medium @@ -303,8 +307,8 @@ def plot_eme_ports( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, **kwargs, ) -> Ax: """Plot the EME ports.""" @@ -347,8 +351,8 @@ def plot_eme_subgrid_boundaries( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, **kwargs, ) -> Ax: """Plot the EME subgrid boundaries. 
@@ -399,8 +403,8 @@ def plot_eme_grid( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, **kwargs, ) -> Ax: """Plot the EME grid.""" @@ -442,8 +446,8 @@ def plot( ax: Ax = None, source_alpha: float = None, monitor_alpha: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, **patch_kwargs, ) -> Ax: """Plot each of simulation's components on a plane defined by one nonzero x,y,z coordinate. @@ -462,9 +466,9 @@ def plot( Opacity of the monitors. If ``None``, uses Tidy3d default. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -536,7 +540,7 @@ def from_scene(cls, scene: Scene, **kwargs) -> EMESimulation: ) @property - def mode_solver_monitors(self) -> List[ModeSolverMonitor]: + def mode_solver_monitors(self) -> list[ModeSolverMonitor]: """A list of mode solver monitors at the cell centers. Each monitor has a mode spec. 
The cells and mode specs are specified by 'eme_grid_spec'.""" @@ -570,17 +574,20 @@ def port_modes_monitor(self) -> EMEModeSolverMonitor: normalize=self.normalize, ) + @property def _post_init_validators(self) -> None: - """Call validators taking `self` that get run after init.""" - self._validate_port_offsets() - _ = self.grid - _ = self.eme_grid - _ = self.mode_solver_monitors - _ = self._cell_index_pairs - self._validate_too_close_to_edges() - self._validate_sweep_spec() - self._validate_symmetry() - self._validate_monitor_setup() + """Return validators taking `self` that get run after init.""" + return ( + self._validate_port_offsets, + lambda: self.grid, + lambda: self.eme_grid, + lambda: self.mode_solver_monitors, + lambda: self._cell_index_pairs, + self._validate_too_close_to_edges, + self._validate_sweep_spec, + self._validate_symmetry, + self._validate_monitor_setup, + ) def validate_pre_upload(self) -> None: """Validate the fully initialized EME simulation is ok for upload to our servers.""" @@ -856,7 +863,7 @@ def _validate_monitor_size(self) -> None: def _validate_modes_size(self) -> None: """Warn if mode sources or monitors have a large number of points.""" - def warn_mode_size(monitor: AbstractModeMonitor, msg_header: str, custom_loc: List): + def warn_mode_size(monitor: AbstractModeMonitor, msg_header: str, custom_loc: list): """Warn if a mode component has a large number of points.""" num_cells = np.prod(self.discretize_monitor(monitor).num_cells) if num_cells > WARN_MODE_NUM_CELLS: @@ -880,14 +887,14 @@ def warn_mode_size(monitor: AbstractModeMonitor, msg_header: str, custom_loc: Li warn_mode_size(monitor=monitor, msg_header=msg_header, custom_loc=custom_loc) @property - def _monitors_full(self) -> Tuple[EMEMonitorType, ...]: + def _monitors_full(self) -> tuple[EMEMonitorType, ...]: """All monitors, including port modes monitor.""" if self.store_port_modes: return list(self.monitors) + [self.port_modes_monitor] return list(self.monitors) 
@cached_property - def monitors_data_size(self) -> Dict[str, float]: + def monitors_data_size(self) -> dict[str, float]: """Dictionary mapping monitor names to their estimated storage size in bytes.""" data_size = {} for monitor in self._monitors_full: @@ -914,7 +921,7 @@ def monitors_data_size(self) -> Dict[str, float]: return data_size @property - def _num_sweep(self) -> pd.PositiveInt: + def _num_sweep(self) -> PositiveInt: """Number of sweep indices.""" if self.sweep_spec is None: return 1 @@ -926,7 +933,7 @@ def _sweep_modes(self) -> bool: return self.sweep_spec is not None and isinstance(self.sweep_spec, EMEFreqSweep) @property - def _num_sweep_modes(self) -> pd.PositiveInt: + def _num_sweep_modes(self) -> PositiveInt: """Number of sweep indices for modes.""" if self._sweep_modes: return self._num_sweep @@ -940,7 +947,7 @@ def _sweep_interfaces(self) -> bool: ) @property - def _num_sweep_interfaces(self) -> pd.PositiveInt: + def _num_sweep_interfaces(self) -> PositiveInt: """Number of sweep indices for interfaces.""" if self._sweep_interfaces: return self._num_sweep @@ -954,13 +961,13 @@ def _sweep_cells(self) -> bool: ) @property - def _num_sweep_cells(self) -> pd.PositiveInt: + def _num_sweep_cells(self) -> PositiveInt: """Number of sweep indices for cells.""" if self._sweep_cells: return self._num_sweep return 1 - def _monitor_num_sweep(self, monitor: EMEMonitor) -> pd.PositiveInt: + def _monitor_num_sweep(self, monitor: EMEMonitor) -> PositiveInt: """Number of sweep indices for a certain monitor.""" if self.sweep_spec is None: return 1 @@ -971,7 +978,7 @@ def _monitor_num_sweep(self, monitor: EMEMonitor) -> pd.PositiveInt: return self.sweep_spec.num_sweep return min(self.sweep_spec.num_sweep, monitor.num_sweep) - def _monitor_eme_cell_indices(self, monitor: EMEMonitor) -> List[pd.NonNegativeInt]: + def _monitor_eme_cell_indices(self, monitor: EMEMonitor) -> list[NonNegativeInt]: """EME cell indices inside monitor. 
Takes into account 'eme_cell_interval_space'.""" cell_indices_full = self.eme_grid.cell_indices_in_box(box=monitor.geometry) if len(cell_indices_full) == 0: @@ -986,7 +993,7 @@ def _monitor_num_eme_cells(self, monitor: EMEMonitor) -> int: """Total number of EME cells included in monitor based on simulation grid.""" return len(self._monitor_eme_cell_indices(monitor=monitor)) - def _monitor_freqs(self, monitor: Monitor) -> List[pd.NonNegativeFloat]: + def _monitor_freqs(self, monitor: Monitor) -> list[NonNegativeFloat]: """Monitor frequencies.""" if monitor.freqs is None: return list(self.freqs) @@ -1079,7 +1086,9 @@ def _to_fdtd_sim(self) -> Simulation: grid_spec = grid_spec.updated_copy(wavelength=min_wvl) # copy over all FDTD monitors too - monitors = [monitor for monitor in self.monitors if not isinstance(monitor, EMEMonitor)] + monitors = tuple( + monitor for monitor in self.monitors if not isinstance(monitor, EMEMonitor) + ) kwargs = {key: getattr(self, key) for key in EME_SIM_YEE_SIM_SHARED_ATTRS} return Simulation( @@ -1094,8 +1103,8 @@ def subsection( region: Box, grid_spec: Union[GridSpec, Literal["identical"]] = None, eme_grid_spec: Union[EMEGridSpec, Literal["identical"]] = None, - symmetry: Tuple[Symmetry, Symmetry, Symmetry] = None, - monitors: Tuple[MonitorType, ...] = None, + symmetry: tuple[Symmetry, Symmetry, Symmetry] = None, + monitors: tuple[MonitorType, ...] = None, remove_outside_structures: bool = True, remove_outside_custom_mediums: bool = False, **kwargs, @@ -1117,11 +1126,11 @@ def subsection( simulation. If ``identical``, then the original grid is transferred directly as a :class:`.EMEExplicitGrid`. Noe that in the latter case the region of the new simulation is expanded to contain full EME cells. - symmetry : Tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None + symmetry : tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None New simulation symmetry. 
If ``None``, then it is inherited from the original simulation. Note that in this case the size and placement of new simulation domain must be commensurate with the original symmetry. - monitors : Tuple[MonitorType, ...] = None + monitors : tuple[MonitorType, ...] = None New list of monitors. If ``None``, then the monitors intersecting the new simulation domain are inherited from the original simulation. remove_outside_structures : bool = True @@ -1170,7 +1179,7 @@ def subsection( return new_sim @property - def _cell_index_pairs(self) -> List[pd.NonNegativeInt]: + def _cell_index_pairs(self) -> list[NonNegativeInt]: """All the pairs of adjacent EME cells needed, taken over all sweep indices.""" pairs = set() if isinstance(self.sweep_spec, EMEPeriodicitySweep): diff --git a/tidy3d/components/eme/sweep.py b/tidy3d/components/eme/sweep.py index d4130e00f3..c5f9f75f97 100644 --- a/tidy3d/components/eme/sweep.py +++ b/tidy3d/components/eme/sweep.py @@ -3,9 +3,9 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, List, Union +from typing import Union -import pydantic.v1 as pd +from pydantic import Field, PositiveInt, field_validator from ...exceptions import SetupError from ..base import Tidy3dBaseModel @@ -18,15 +18,14 @@ class EMESweepSpec(Tidy3dBaseModel, ABC): @property @abstractmethod - def num_sweep(self) -> pd.PositiveInt: + def num_sweep(self) -> PositiveInt: """Number of sweep indices.""" class EMELengthSweep(EMESweepSpec): """Spec for sweeping EME cell lengths.""" - scale_factors: ArrayLike = pd.Field( - ..., + scale_factors: ArrayLike = Field( title="Length Scale Factor", description="Length scale factors to be used in the EME propagation step. " "The EME propagation step is repeated after scaling every cell length by this amount. 
" @@ -36,7 +35,7 @@ class EMELengthSweep(EMESweepSpec): ) @property - def num_sweep(self) -> pd.PositiveInt: + def num_sweep(self) -> PositiveInt: """Number of sweep indices.""" return len(self.scale_factors) @@ -45,8 +44,7 @@ class EMEModeSweep(EMESweepSpec): """Spec for sweeping number of modes in EME propagation step. Used for convergence testing.""" - num_modes: ArrayInt1D = pd.Field( - ..., + num_modes: ArrayInt1D = Field( title="Number of Modes", description="Max number of modes to use in the EME propagation step. " "The EME propagation step is repeated after dropping modes with mode_index " @@ -56,7 +54,7 @@ class EMEModeSweep(EMESweepSpec): ) @property - def num_sweep(self) -> pd.PositiveInt: + def num_sweep(self) -> PositiveInt: """Number of sweep indices.""" return len(self.num_modes) @@ -67,8 +65,7 @@ class EMEFreqSweep(EMESweepSpec): perturbative mode solver relative to the simulation EME modes. This can be a faster way to solve at a larger number of frequencies.""" - freq_scale_factors: ArrayFloat1D = pd.Field( - ..., + freq_scale_factors: ArrayFloat1D = Field( title="Frequency Scale Factors", description="Scale factors " "applied to every frequency in 'EMESimulation.freqs'. After applying the scale factors, " @@ -78,7 +75,7 @@ class EMEFreqSweep(EMESweepSpec): ) @property - def num_sweep(self) -> pd.PositiveInt: + def num_sweep(self) -> PositiveInt: """Number of sweep indices.""" return len(self.freq_scale_factors) @@ -99,16 +96,15 @@ class EMEPeriodicitySweep(EMESweepSpec): >>> sweep_spec = EMEPeriodicitySweep(num_reps=[{"unit_cell": n} for n in n_list]) """ - num_reps: List[Dict[str, pd.PositiveInt]] = pd.Field( - ..., + num_reps: list[dict[str, PositiveInt]] = Field( title="Number of Repetitions", description="Number of periodic repetitions of named subgrids in this EME grid. 
" "At each sweep index, contains a dict mapping the name of a subgrid to the " "number of repetitions of that subgrid at that sweep index.", ) - @pd.validator("num_reps", always=True) - def _validate_num_reps(cls, val): + @field_validator("num_reps") + def _validate_num_reps(val): """Check num_reps is not too large.""" for num_reps_dict in val: for value in num_reps_dict.values(): @@ -120,7 +116,7 @@ def _validate_num_reps(cls, val): return val @property - def num_sweep(self) -> pd.PositiveInt: + def num_sweep(self) -> PositiveInt: """Number of sweep indices.""" return len(self.num_reps) diff --git a/tidy3d/components/field_projection.py b/tidy3d/components/field_projection.py index b66e14484a..6ac963008a 100644 --- a/tidy3d/components/field_projection.py +++ b/tidy3d/components/field_projection.py @@ -1,20 +1,18 @@ """Near field to far field transformation plugin""" -from __future__ import annotations - -from typing import Iterable, List, Tuple, Union +from typing import Iterable, Optional, Union import autograd.numpy as anp import numpy as np -import pydantic.v1 as pydantic import xarray as xr +from pydantic import Field, model_validator from rich.progress import track from ..constants import C_0, EPSILON_0, ETA_0, MICROMETER, MU_0 from ..exceptions import SetupError from ..log import get_logging_console from .autograd.functions import add_at, trapz -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .data.data_array import ( FieldProjectionAngleDataArray, FieldProjectionCartesianDataArray, @@ -44,7 +42,7 @@ # Numpy float array and related array types -ArrayLikeN2F = Union[float, Tuple[float, ...], ArrayComplex4D] +ArrayLikeN2F = Union[float, tuple[float, ...], ArrayComplex4D] class FieldProjector(Tidy3dBaseModel): @@ -64,20 +62,18 @@ class FieldProjector(Tidy3dBaseModel): * `Performing near field to far field projections <../../notebooks/FieldProjections.html>`_ """ - sim_data: 
SimulationData = pydantic.Field( - ..., + sim_data: SimulationData = Field( title="Simulation data", description="Container for simulation data containing the near field monitors.", ) - surfaces: Tuple[FieldProjectionSurface, ...] = pydantic.Field( - ..., + surfaces: tuple[FieldProjectionSurface, ...] = Field( title="Surface monitor with direction", - description="Tuple of each :class:`.FieldProjectionSurface` to use as source of " + description="Tuple of each :class:`.FieldProjectionSurface` to use as source of " "near field.", ) - pts_per_wavelength: Union[int, type(None)] = pydantic.Field( + pts_per_wavelength: Optional[int] = Field( PTS_PER_WVL, title="Points per wavelength", description="Number of points per wavelength in the background medium with which " @@ -85,7 +81,7 @@ class FieldProjector(Tidy3dBaseModel): "will not resampled, but will still be colocated.", ) - origin: Coordinate = pydantic.Field( + origin: Optional[Coordinate] = Field( None, title="Local origin", description="Local origin used for defining observation points. 
If ``None``, uses the " @@ -93,21 +89,19 @@ class FieldProjector(Tidy3dBaseModel): units=MICROMETER, ) + @model_validator(mode="after") + def _check_origin_set(self): + """Sets ``.origin`` as the average of centers of all surface monitors if not provided.""" + if self.origin is None: + centers = np.array([surface.monitor.center for surface in self.surfaces]) + object.__setattr__(self, "origin", tuple(np.mean(centers, axis=0))) + return self + @cached_property def is_2d_simulation(self) -> bool: non_zero_dims = sum(1 for size in self.sim_data.simulation.size if size != 0) return non_zero_dims == 2 - @pydantic.validator("origin", always=True) - @skip_if_fields_missing(["surfaces"]) - def set_origin(cls, val, values): - """Sets .origin as the average of centers of all surface monitors if not provided.""" - if val is None: - surfaces = values.get("surfaces") - val = np.array([surface.monitor.center for surface in surfaces]) - return tuple(np.mean(val, axis=0)) - return val - @cached_property def medium(self) -> MediumType: """Medium into which fields are to be projected.""" @@ -116,7 +110,7 @@ def medium(self) -> MediumType: return sim.monitor_medium(monitor) @cached_property - def frequencies(self) -> List[float]: + def frequencies(self) -> list[float]: """Return the list of frequencies associated with the field monitors.""" return self.surfaces[0].monitor.freqs @@ -124,8 +118,8 @@ def frequencies(self) -> List[float]: def from_near_field_monitors( cls, sim_data: SimulationData, - near_monitors: List[FieldMonitor], - normal_dirs: List[Direction], + near_monitors: list[FieldMonitor], + normal_dirs: list[Direction], pts_per_wavelength: int = PTS_PER_WVL, origin: Coordinate = None, ): @@ -135,10 +129,10 @@ def from_near_field_monitors( ---------- sim_data : :class:`.SimulationData` Container for simulation data containing the near field monitors. - near_monitors : List[:class:`.FieldMonitor`] - Tuple of :class:`.FieldMonitor` objects on which near fields will be sampled. 
- normal_dirs : List[:class:`.Direction`] - Tuple containing the :class:`.Direction` of the normal to each surface monitor + near_monitors : list[:class:`.FieldMonitor`] + Tuple of :class:`.FieldMonitor` objects on which near fields will be sampled. + normal_dirs : list[:class:`.Direction`] + Tuple containing the :class:`.Direction` of the normal to each surface monitor w.r.t. to the positive x, y or z unit vectors. Must have the same length as monitors. pts_per_wavelength : int = 10 Number of points per wavelength with which to discretize the @@ -268,7 +262,7 @@ def _fields_to_currents(field_data: FieldData, surface: FieldProjectionSurface) surface_currents[H2] = field_data.field_components[E1] * signs[0] surface_currents[H1] = field_data.field_components[E2] * signs[1] - new_monitor = surface.monitor.copy(update=dict(fields=[E1, E2, H1, H2])) + new_monitor = surface.monitor.copy(update=dict(fields=(E1, E2, H1, H2))) return FieldData( monitor=new_monitor, @@ -400,9 +394,9 @@ def _far_fields_for_surface( frequency : float Frequency to select from each :class:`.FieldMonitor` to use for projection. Must be a frequency stored in each :class:`FieldMonitor`. - theta : Union[float, Tuple[float, ...], np.ndarray] + theta : Union[float, tuple[float, ...], np.ndarray] Polar angles (rad) downward from x=y=0 line relative to the local origin. - phi : Union[float, Tuple[float, ...], np.ndarray] + phi : Union[float, tuple[float, ...], np.ndarray] Azimuthal (rad) angles from y=z=0 line relative to the local origin. surface: :class:`FieldProjectionSurface` :class:`FieldProjectionSurface` object to use as source of near field. 
@@ -905,7 +899,7 @@ def _fields_for_surface_exact( d2G_dr2 = dG_dr * (ikr - 1.0) / r + G / (r**2) # operations between unit vectors and currents - def r_x_current(current: Tuple[np.ndarray, ...]) -> Tuple[np.ndarray, ...]: + def r_x_current(current: tuple[np.ndarray, ...]) -> tuple[np.ndarray, ...]: """Cross product between the r unit vector and the current.""" return [ sin_theta * sin_phi * current[2] - cos_theta * current[1], @@ -913,7 +907,7 @@ def r_x_current(current: Tuple[np.ndarray, ...]) -> Tuple[np.ndarray, ...]: sin_theta * cos_phi * current[1] - sin_theta * sin_phi * current[0], ] - def r_dot_current(current: Tuple[np.ndarray, ...]) -> np.ndarray: + def r_dot_current(current: tuple[np.ndarray, ...]) -> np.ndarray: """Dot product between the r unit vector and the current.""" return ( sin_theta * cos_phi * current[0] @@ -921,7 +915,7 @@ def r_dot_current(current: Tuple[np.ndarray, ...]) -> np.ndarray: + cos_theta * current[2] ) - def r_dot_current_dtheta(current: Tuple[np.ndarray, ...]) -> np.ndarray: + def r_dot_current_dtheta(current: tuple[np.ndarray, ...]) -> np.ndarray: """Theta derivative of the dot product between the r unit vector and the current.""" return ( cos_theta * cos_phi * current[0] @@ -929,12 +923,12 @@ def r_dot_current_dtheta(current: Tuple[np.ndarray, ...]) -> np.ndarray: - sin_theta * current[2] ) - def r_dot_current_dphi_div_sin_theta(current: Tuple[np.ndarray, ...]) -> np.ndarray: + def r_dot_current_dphi_div_sin_theta(current: tuple[np.ndarray, ...]) -> np.ndarray: """Phi derivative of the dot product between the r unit vector and the current, analytically divided by sin theta.""" return -sin_phi * current[0] + cos_phi * current[1] - def grad_Gr_r_dot_current(current: Tuple[np.ndarray, ...]) -> Tuple[np.ndarray, ...]: + def grad_Gr_r_dot_current(current: tuple[np.ndarray, ...]) -> tuple[np.ndarray, ...]: """Gradient of the product of the gradient of the Green's function and the dot product between the r unit vector and the 
current.""" temp = [ @@ -945,7 +939,7 @@ def grad_Gr_r_dot_current(current: Tuple[np.ndarray, ...]) -> Tuple[np.ndarray, # convert to Cartesian coordinates return surface.monitor.sph_2_car_field(temp[0], temp[1], temp[2], theta_obs, phi_obs) - def potential_terms(current: Tuple[np.ndarray, ...], const: complex): + def potential_terms(current: tuple[np.ndarray, ...], const: complex): """Assemble vector potential and its derivatives.""" r_x_c = r_x_current(current) pot = [const * item * G for item in current] diff --git a/tidy3d/components/frequencies.py b/tidy3d/components/frequencies.py index 6a14fb5b6b..833b5547b8 100644 --- a/tidy3d/components/frequencies.py +++ b/tidy3d/components/frequencies.py @@ -1,7 +1,7 @@ """Frequency utilities.""" import numpy as np -import pydantic as pd +from pydantic import Field from ..constants import C_0 from .base import Tidy3dBaseModel @@ -17,7 +17,7 @@ class FrequencyUtils(Tidy3dBaseModel): """Class for general frequency/wavelength utilities.""" - use_wavelength: bool = pd.Field( + use_wavelength: bool = Field( False, title="Use wavelength", description="Indicate whether to use wavelengths instead of frequencies for the return " diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py index 9fddfddc72..e57ad62c29 100644 --- a/tidy3d/components/geometry/base.py +++ b/tidy3d/components/geometry/base.py @@ -5,12 +5,19 @@ import functools import pathlib from abc import ABC, abstractmethod -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, Optional, Union import autograd.numpy as np -import pydantic.v1 as pydantic import shapely import xarray as xr +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + field_validator, + model_validator, +) try: from matplotlib import patches @@ -45,7 +52,7 @@ PlanePosition, Shapely, Size, - annotate_type, + discriminated_union, ) from ..viz import ( ARROW_LENGTH, @@ -151,7 +158,7 @@ def 
make_shapely_point(minx: float, miny: float) -> shapely.Point: def _inds_inside_bounds( self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] - ) -> Tuple[slice, slice, slice]: + ) -> tuple[slice, slice, slice]: """Return slices into the sorted input arrays that are inside the geometry bounds. Parameters @@ -165,7 +172,7 @@ def _inds_inside_bounds( Returns ------- - Tuple[slice, slice, slice] + tuple[slice, slice, slice] Slices into each of the three arrays that are inside the geometry bounds. """ bounds = self.bounds @@ -213,7 +220,7 @@ def inside_meshgrid( @abstractmethod def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -227,7 +234,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -235,7 +242,7 @@ def intersections_tilted_plane( def intersections_plane( self, x: float = None, y: float = None, z: float = None - ) -> List[Shapely]: + ) -> list[Shapely]: """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. Parameters @@ -249,7 +256,7 @@ def intersections_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -263,13 +270,13 @@ def intersections_plane( to_2D = to_2D[list(indices) + [last, 3]] return self.intersections_tilted_plane(normal, origin, to_2D) - def intersections_2dbox(self, plane: Box) -> List[Shapely]: + def intersections_2dbox(self, plane: Box) -> list[Shapely]: """Returns list of shapely geometries representing the intersections of the geometry with a 2D box. 
Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. """ @@ -280,7 +287,7 @@ def intersections_2dbox(self, plane: Box) -> List[Shapely]: return plane.intersections_with(self) def intersects( - self, other, strict_inequality: Tuple[bool, bool, bool] = [False, False, False] + self, other, strict_inequality: tuple[bool, bool, bool] = [False, False, False] ) -> bool: """Returns ``True`` if two :class:`Geometry` have intersecting `.bounds`. @@ -288,7 +295,7 @@ def intersects( ---------- other : :class:`Geometry` Geometry to check intersection with. - strict_inequality : Tuple[bool, bool, bool] = [False, False, False] + strict_inequality : tuple[bool, bool, bool] = [False, False, False] For each dimension, defines whether to include equality in the boundaries comparison. If ``False``, equality is included, and two geometries that only intersect at their boundaries will evaluate as ``True``. If ``True``, such geometries will evaluate as @@ -318,7 +325,7 @@ def intersects( return True def contains( - self, other: Geometry, strict_inequality: Tuple[bool, bool, bool] = [False, False, False] + self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False] ) -> bool: """Returns ``True`` if the `.bounds` of ``other`` are contained within the `.bounds` of ``self``. @@ -327,7 +334,7 @@ def contains( ---------- other : :class:`Geometry` Geometry to check containment with. - strict_inequality : Tuple[bool, bool, bool] = [False, False, False] + strict_inequality : tuple[bool, bool, bool] = [False, False, False] For each dimension, defines whether to include equality in the boundaries comparison. If ``False``, equality will be considered as contained. If ``True``, ``other``'s bounds must be strictly within the bounds of ``self``. 
@@ -401,7 +408,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float float] + tuple[float, float, float], tuple[float, float, float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. """ @@ -435,7 +442,7 @@ def bounding_box(self): return Box.from_bounds(*self.bounds) @cached_property - def zero_dims(self) -> List[Axis]: + def zero_dims(self) -> list[Axis]: """A list of axes along which the :class:`Geometry` is zero-sized based on its bounds.""" zero_dims = [] for dim in range(3): @@ -443,7 +450,7 @@ def zero_dims(self) -> List[Axis]: zero_dims.append(dim) return zero_dims - def _pop_bounds(self, axis: Axis) -> Tuple[Coordinate2D, Tuple[Coordinate2D, Coordinate2D]]: + def _pop_bounds(self, axis: Axis) -> tuple[Coordinate2D, tuple[Coordinate2D, Coordinate2D]]: """Returns min and max bounds in plane normal to and tangential to ``axis``. Parameters @@ -453,7 +460,7 @@ def _pop_bounds(self, axis: Axis) -> Tuple[Coordinate2D, Tuple[Coordinate2D, Coo Returns ------- - Tuple[float, float], Tuple[Tuple[float, float], Tuple[float, float]] + tuple[float, float], tuple[tuple[float, float], tuple[float, float]] Bounds along axis and a tuple of bounds in the ordered planar coordinates. Packed as ``(zmin, zmax), ((xmin, ymin), (xmax, ymax))``. 
""" @@ -483,7 +490,7 @@ def _normal_2dmaterial(self) -> Axis: """Get the normal to the given geometry, checking that it is a 2D geometry.""" raise ValidationError("'Medium2D' is not compatible with this geometry class.") - def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Geometry: + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Geometry: """Returns an updated geometry which has been transformed to fit within ``bounds`` along the ``axis`` direction.""" raise NotImplementedError( @@ -594,7 +601,7 @@ def _do_not_intersect(bounds_a, bounds_b, shape_a, shape_b): return False @staticmethod - def _get_plot_labels(axis: Axis) -> Tuple[str, str]: + def _get_plot_labels(axis: Axis) -> tuple[str, str]: """Returns planar coordinate x and y axis labels for cross section plots. Parameters @@ -612,7 +619,7 @@ def _get_plot_labels(axis: Axis) -> Tuple[str, str]: def _get_plot_limits( self, axis: Axis, buffer: float = PLOT_BUFFER - ) -> Tuple[Coordinate2D, Coordinate2D]: + ) -> tuple[Coordinate2D, Coordinate2D]: """Gets planar coordinate limits for cross section plots. Parameters @@ -624,7 +631,7 @@ def _get_plot_limits( Returns ------- - Tuple[float, float], Tuple[float, float] + tuple[float, float], tuple[float, float] The x and y plot limits, packed as ``(xmin, xmax), (ymin, ymax)``. """ _, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis) @@ -726,19 +733,19 @@ def evaluate_inf_shape(shape: Shapely) -> Shapely: return shape @staticmethod - def pop_axis(coord: Tuple[Any, Any, Any], axis: int) -> Tuple[Any, Tuple[Any, Any]]: + def pop_axis(coord: tuple[Any, Any, Any], axis: int) -> tuple[Any, tuple[Any, Any]]: """Separates coordinate at ``axis`` index from coordinates on the plane tangent to ``axis``. Parameters ---------- - coord : Tuple[Any, Any, Any] + coord : tuple[Any, Any, Any] Tuple of three values in original coordinate system. axis : int Integer index into 'xyz' (0,1,2). 
Returns ------- - Any, Tuple[Any, Any] + Any, tuple[Any, Any] The input coordinates are separated into the one along the axis provided and the two on the planar coordinates, like ``axis_coord, (planar_coord1, planar_coord2)``. @@ -748,21 +755,21 @@ def pop_axis(coord: Tuple[Any, Any, Any], axis: int) -> Tuple[Any, Tuple[Any, An return axis_val, tuple(plane_vals) @staticmethod - def unpop_axis(ax_coord: Any, plane_coords: Tuple[Any, Any], axis: int) -> Tuple[Any, Any, Any]: + def unpop_axis(ax_coord: Any, plane_coords: tuple[Any, Any], axis: int) -> tuple[Any, Any, Any]: """Combine coordinate along axis with coordinates on the plane tangent to the axis. Parameters ---------- ax_coord : Any Value along axis direction. - plane_coords : Tuple[Any, Any] + plane_coords : tuple[Any, Any] Values along ordered planar directions. axis : int Integer index into 'xyz' (0,1,2). Returns ------- - Tuple[Any, Any, Any] + tuple[Any, Any, Any] The three values in the xyz coordinate system. """ coords = list(plane_coords) @@ -770,7 +777,7 @@ def unpop_axis(ax_coord: Any, plane_coords: Tuple[Any, Any], axis: int) -> Tuple return tuple(coords) @staticmethod - def parse_xyz_kwargs(**xyz) -> Tuple[Axis, float]: + def parse_xyz_kwargs(**xyz) -> tuple[Axis, float]: """Turns x,y,z kwargs into index of the normal axis and position along that axis. Parameters @@ -795,7 +802,7 @@ def parse_xyz_kwargs(**xyz) -> Tuple[Axis, float]: return axis, position @staticmethod - def parse_two_xyz_kwargs(**xyz) -> List[Tuple[Axis, float]]: + def parse_two_xyz_kwargs(**xyz) -> list[tuple[Axis, float]]: """Turns x,y,z kwargs into indices of axes and the position along each axis. Parameters @@ -879,7 +886,7 @@ def volume(self, bounds: Bound = None): Parameters ---------- - bounds : Tuple[Tuple[float, float, float], Tuple[float, float, float]] = None + bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns @@ -902,7 +909,7 @@ def surface_area(self, bounds: Bound = None): Parameters ---------- - bounds : Tuple[Tuple[float, float, float], Tuple[float, float, float]] = None + bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -965,7 +972,7 @@ def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> Geometry: ---------- angle : float Rotation angle (in radians). - axis : Union[int, Tuple[float, float, float]] + axis : Union[int, tuple[float, float, float]] Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. Returns @@ -980,7 +987,7 @@ def reflected(self, normal: Coordinate) -> Geometry: Parameters ---------- - normal : Tuple[float, float, float] + normal : tuple[float, float, float] The 3D normal vector of the plane of reflection. The plane is assumed to pass through the origin (0,0,0). @@ -994,7 +1001,7 @@ def reflected(self, normal: Coordinate) -> Geometry: """ Field and coordinate transformations """ @staticmethod - def car_2_sph(x: float, y: float, z: float) -> Tuple[float, float, float]: + def car_2_sph(x: float, y: float, z: float) -> tuple[float, float, float]: """Convert Cartesian to spherical coordinates. Parameters @@ -1008,7 +1015,7 @@ def car_2_sph(x: float, y: float, z: float) -> Tuple[float, float, float]: Returns ------- - Tuple[float, float, float] + tuple[float, float, float] r, theta, and phi coordinates relative to ``local_origin``. """ r = np.sqrt(x**2 + y**2 + z**2) @@ -1017,7 +1024,7 @@ def car_2_sph(x: float, y: float, z: float) -> Tuple[float, float, float]: return r, theta, phi @staticmethod - def sph_2_car(r: float, theta: float, phi: float) -> Tuple[float, float, float]: + def sph_2_car(r: float, theta: float, phi: float) -> tuple[float, float, float]: """Convert spherical to Cartesian coordinates. 
Parameters @@ -1031,7 +1038,7 @@ def sph_2_car(r: float, theta: float, phi: float) -> Tuple[float, float, float]: Returns ------- - Tuple[float, float, float] + tuple[float, float, float] x, y, and z coordinates relative to ``local_origin``. """ r_sin_theta = r * np.sin(theta) @@ -1043,7 +1050,7 @@ def sph_2_car(r: float, theta: float, phi: float) -> Tuple[float, float, float]: @staticmethod def sph_2_car_field( f_r: float, f_theta: float, f_phi: float, theta: float, phi: float - ) -> Tuple[complex, complex, complex]: + ) -> tuple[complex, complex, complex]: """Convert vector field components in spherical coordinates to cartesian. Parameters @@ -1061,7 +1068,7 @@ def sph_2_car_field( Returns ------- - Tuple[float, float, float] + tuple[float, float, float] x, y, and z components of the vector field in cartesian coordinates. """ sin_theta = np.sin(theta) @@ -1076,7 +1083,7 @@ def sph_2_car_field( @staticmethod def car_2_sph_field( f_x: float, f_y: float, f_z: float, theta: float, phi: float - ) -> Tuple[complex, complex, complex]: + ) -> tuple[complex, complex, complex]: """Convert vector field components in cartesian coordinates to spherical. Parameters @@ -1094,7 +1101,7 @@ def car_2_sph_field( Returns ------- - Tuple[float, float, float] + tuple[float, float, float] radial (s), elevation (theta), and azimuthal (phi) components of the vector field in spherical coordinates. """ @@ -1108,7 +1115,7 @@ def car_2_sph_field( return f_r, f_theta, f_phi @staticmethod - def kspace_2_sph(ux: float, uy: float, axis: Axis) -> Tuple[float, float]: + def kspace_2_sph(ux: float, uy: float, axis: Axis) -> tuple[float, float]: """Convert normalized k-space coordinates to angles. Parameters @@ -1122,7 +1129,7 @@ def kspace_2_sph(ux: float, uy: float, axis: Axis) -> Tuple[float, float]: Returns ------- - Tuple[float, float] + tuple[float, float] theta and phi coordinates relative to ``local_origin``. 
""" phi_local = np.arctan2(uy, ux) @@ -1147,8 +1154,8 @@ def kspace_2_sph(ux: float, uy: float, axis: Axis) -> Tuple[float, float]: @staticmethod @verify_packages_import(["gdstk", "gdspy"], required="any") def load_gds_vertices_gdstk( - gds_cell, gds_layer: int, gds_dtype: int = None, gds_scale: pydantic.PositiveFloat = 1.0 - ) -> List[ArrayFloat2D]: + gds_cell, gds_layer: int, gds_dtype: int = None, gds_scale: PositiveFloat = 1.0 + ) -> list[ArrayFloat2D]: """Load polygon vertices from a ``gdstk.Cell``. Parameters @@ -1166,7 +1173,7 @@ def load_gds_vertices_gdstk( Returns ------- - List[ArrayFloat2D] + list[ArrayFloat2D] List of polygon vertices """ @@ -1196,8 +1203,8 @@ def load_gds_vertices_gdstk( @staticmethod @verify_packages_import(["gdstk", "gdspy"], required="any") def load_gds_vertices_gdspy( - gds_cell, gds_layer: int, gds_dtype: int = None, gds_scale: pydantic.PositiveFloat = 1.0 - ) -> List[ArrayFloat2D]: + gds_cell, gds_layer: int, gds_dtype: int = None, gds_scale: PositiveFloat = 1.0 + ) -> list[ArrayFloat2D]: """Load polygon vertices from a ``gdspy.Cell``. Parameters @@ -1215,7 +1222,7 @@ def load_gds_vertices_gdspy( Returns ------- - List[ArrayFloat2D] + list[ArrayFloat2D] List of polygon vertices """ @@ -1241,10 +1248,10 @@ def load_gds_vertices_gdspy( def from_gds( gds_cell, axis: Axis, - slab_bounds: Tuple[float, float], + slab_bounds: tuple[float, float], gds_layer: int, gds_dtype: int = None, - gds_scale: pydantic.PositiveFloat = 1.0, + gds_scale: PositiveFloat = 1.0, dilation: float = 0.0, sidewall_angle: float = 0, reference_plane: PlanePosition = "middle", @@ -1257,7 +1264,7 @@ def from_gds( ``gdstk.Cell`` or ``gdspy.Cell`` containing 2D geometric data. axis : int Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: Tuple[float, float] + slab_bounds: tuple[float, float] Minimal and maximal positions of the extruded slab along ``axis``. gds_layer : int Layer index in the ``gds_cell``. 
@@ -1320,7 +1327,7 @@ def from_gds( shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane ) ) - except pydantic.ValidationError as error: + except ValidationError as error: consolidated_logger.warning(str(error)) except Tidy3dError as error: consolidated_logger.warning(str(error)) @@ -1330,7 +1337,7 @@ def from_gds( def from_shapely( shape: Shapely, axis: Axis, - slab_bounds: Tuple[float, float], + slab_bounds: tuple[float, float], dilation: float = 0.0, sidewall_angle: float = 0, reference_plane: PlanePosition = "middle", @@ -1344,7 +1351,7 @@ def from_shapely( of any of those. axis : int Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: Tuple[float, float] + slab_bounds: tuple[float, float] Minimal and maximal positions of the extruded slab along ``axis``. dilation : float Dilation of the polygon in the base by shifting each edge along its normal outwards @@ -1371,9 +1378,9 @@ def to_gdstk( x: float = None, y: float = None, z: float = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, - ) -> List: + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + ) -> list: """Convert a Geometry object's planar slice to a .gds type polygon. Parameters @@ -1420,9 +1427,9 @@ def to_gdspy( x: float = None, y: float = None, z: float = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, - ) -> List: + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + ) -> list: """Convert a Geometry object's planar slice to a .gds type polygon. Parameters @@ -1470,8 +1477,8 @@ def to_gds( x: float = None, y: float = None, z: float = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, ) -> None: """Append a Geometry object's planar slice to a .gds cell. 
@@ -1529,8 +1536,8 @@ def to_gds_file( x: float = None, y: float = None, z: float = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, gds_cell_name: str = "MAIN", ) -> None: """Export a Geometry object's planar slice to a .gds file. @@ -1580,7 +1587,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM """Compute the adjoint derivatives for this object.""" raise NotImplementedError(f"Can't compute derivative for 'Geometry': '{type(self)}'.") - def _as_union(self) -> List[Geometry]: + def _as_union(self) -> list[Geometry]: """Return a list of geometries that, united, make up the given geometry.""" if isinstance(self, GeometryGroup): return self.geometries @@ -1660,15 +1667,23 @@ def __invert__(self): class Centered(Geometry, ABC): """Geometry with a well defined center.""" - center: TracedCoordinate = pydantic.Field( - (0.0, 0.0, 0.0), + center: Optional[TracedCoordinate] = Field( + None, title="Center", description="Center of object in x, y, and z.", units=MICROMETER, ) - @pydantic.validator("center", always=True) - def _center_not_inf(cls, val): + @field_validator("center", mode="before") + @classmethod + def _center_default(cls, val): + """Make sure center is not infinitiy.""" + if val is None: + val = (0.0, 0.0, 0.0) + return val + + @field_validator("center") + def _center_not_inf(val): """Make sure center is not infinitiy.""" if any(np.isinf(v) for v in val): raise ValidationError("center can not contain td.inf terms.") @@ -1680,7 +1695,7 @@ class SimplePlaneIntersection(Geometry, ABC): def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Checks special cases before relying on the complete computation. 
@@ -1695,7 +1710,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -1723,7 +1738,7 @@ def transform(p_array): @abstractmethod def _do_intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -1737,7 +1752,7 @@ def _do_intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -1747,11 +1762,13 @@ def _do_intersections_tilted_plane( class Planar(SimplePlaneIntersection, Geometry, ABC): """Geometry with one ``axis`` that is slab-like with thickness ``height``.""" - axis: Axis = pydantic.Field( - 2, title="Axis", description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z)." + axis: Axis = Field( + 2, + title="Axis", + description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z).", ) - sidewall_angle: float = pydantic.Field( + sidewall_angle: float = Field( 0.0, title="Sidewall angle", description="Angle of the sidewall. 
" @@ -1763,7 +1780,7 @@ class Planar(SimplePlaneIntersection, Geometry, ABC): units=RADIAN, ) - reference_plane: PlanePosition = pydantic.Field( + reference_plane: PlanePosition = Field( "middle", title="Reference plane for cross section", description="The position of the plane where the supplied cross section are " @@ -1774,15 +1791,14 @@ class Planar(SimplePlaneIntersection, Geometry, ABC): "``top`` refers to the positive side of the y-axis.", ) - @pydantic.validator("sidewall_angle", always=True) - def validate_angle(cls, value: float) -> float: + @field_validator("sidewall_angle") + def validate_angle(val) -> float: lower_bound = -np.pi / 2 upper_bound = np.pi / 2 - if (value <= lower_bound) or (value >= upper_bound): + if (val <= lower_bound) or (val >= upper_bound): # u03C0 is unicode for pi - raise ValidationError(f"Sidewall angle ({value}) must be between -π/2 and π/2 rad.") - - return value + raise ValidationError(f"Sidewall angle ({val}) must be between -π/2 and π/2 rad.") + return val @property @abstractmethod @@ -1815,7 +1831,7 @@ def intersections_plane(self, x: float = None, y: float = None, z: float = None) Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -1838,7 +1854,7 @@ def _intersections_normal(self, z: float) -> list: Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -1857,7 +1873,7 @@ def _intersections_side(self, position: float, axis: Axis) -> list: Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. 
@@ -1880,7 +1896,7 @@ def _order_axis(self, axis: int) -> int: axis_index.insert(self.axis, 2) return axis_index[axis] - def _order_by_axis(self, plane_val: Any, axis_val: Any, axis: int) -> Tuple[Any, Any]: + def _order_by_axis(self, plane_val: Any, axis_val: Any, axis: int) -> tuple[Any, Any]: """Orders a value in the plane and value along axis in correct (x,y) order for plotting. Note: sometimes if axis=1 and we compute cross section values orthogonal to axis, they can either be x or y in the plots. @@ -1917,15 +1933,17 @@ def _tanq(self) -> float: class Circular(Geometry): """Geometry with circular characteristics (specified by a radius).""" - radius: pydantic.NonNegativeFloat = pydantic.Field( - ..., title="Radius", description="Radius of geometry.", units=MICROMETER + radius: NonNegativeFloat = Field( + title="Radius", + description="Radius of geometry.", + units=MICROMETER, ) - @pydantic.validator("radius", always=True) - def _radius_not_inf(cls, val): + @field_validator("radius") + def _radius_not_inf(val): """Make sure center is not infinitiy.""" if np.isinf(val): - raise ValidationError("radius can not be td.inf.") + raise ValidationError("radius can not be 'td.inf'.") return val def _intersect_dist(self, position, z0) -> float: @@ -1961,8 +1979,7 @@ class Box(SimplePlaneIntersection, Centered): >>> b = Box(center=(1,2,3), size=(2,2,2)) """ - size: TracedSize = pydantic.Field( - ..., + size: TracedSize = Field( title="Size", description="Size in x, y, and z directions.", units=MICROMETER, @@ -1974,9 +1991,9 @@ def from_bounds(cls, rmin: Coordinate, rmax: Coordinate, **kwargs): Parameters ---------- - rmin : Tuple[float, float, float] + rmin : tuple[float, float, float] (x, y, z) coordinate of the minimum values. - rmax : Tuple[float, float, float] + rmax : tuple[float, float, float] (x, y, z) coordinate of the maximum values. 
Example @@ -2009,9 +2026,9 @@ def surfaces(cls, size: Size, center: Coordinate, **kwargs): Parameters ---------- - size : Tuple[float, float, float] + size : tuple[float, float, float] Size of object in x, y, and z directions. - center : Tuple[float, float, float] + center : tuple[float, float, float] Center of object in x, y, and z. Example @@ -2075,10 +2092,10 @@ def del_items(items, indices): surfaces = [] for _cent, _size, _name, _normal_dir in zip(centers, sizes, names, normal_dirs): - if "normal_dir" in cls.__dict__["__fields__"]: + if "normal_dir" in cls.model_fields: kwargs["normal_dir"] = _normal_dir - if "name" in cls.__dict__["__fields__"]: + if "name" in cls.model_fields: kwargs["name"] = _name surface = cls(center=_cent, size=_size, **kwargs) @@ -2099,9 +2116,9 @@ def surfaces_with_exclusion(cls, size: Size, center: Coordinate, **kwargs): Parameters ---------- - size : Tuple[float, float, float] + size : tuple[float, float, float] Size of object in x, y, and z directions. - center : Tuple[float, float, float] + center : tuple[float, float, float] Center of object in x, y, and z. Example @@ -2112,14 +2129,14 @@ def surfaces_with_exclusion(cls, size: Size, center: Coordinate, **kwargs): """ exclude_surfaces = kwargs.pop("exclude_surfaces", None) surfaces = cls.surfaces(size=size, center=center, **kwargs) - if "name" in cls.__dict__["__fields__"] and exclude_surfaces: + if "name" in cls.model_fields and exclude_surfaces: surfaces = [surf for surf in surfaces if surf.name[-2:] not in exclude_surfaces] return surfaces @verify_packages_import(["trimesh"]) def _do_intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. 
Parameters @@ -2133,7 +2150,7 @@ def _do_intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -2181,7 +2198,7 @@ def intersections_plane(self, x: float = None, y: float = None, z: float = None) Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -2241,7 +2258,7 @@ def intersections_with(self, other): Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect this 2D box. For more details refer to `Shapely's Documentation `_. @@ -2277,7 +2294,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float float] + tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
""" size = self.size @@ -2298,7 +2315,7 @@ def geometry(self): return Box(center=self.center, size=self.size) @cached_property - def zero_dims(self) -> List[Axis]: + def zero_dims(self) -> list[Axis]: """A list of axes along which the :class:`Box` is zero-sized.""" return [dim for dim, size in enumerate(self.size) if size == 0] @@ -2311,7 +2328,7 @@ def _normal_2dmaterial(self) -> Axis: ) return self.size.index(0) - def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Box: + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Box: """Returns an updated geometry which has been transformed to fit within ``bounds`` along the ``axis`` direction.""" new_center = list(self.center) @@ -2322,7 +2339,7 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Box: def _plot_arrow( self, - direction: Tuple[float, float, float], + direction: tuple[float, float, float], x: float = None, y: float = None, z: float = None, @@ -2338,7 +2355,7 @@ def _plot_arrow( Parameters ---------- - direction: Tuple[float, float, float] + direction: tuple[float, float, float] Normalized vector describing the arrow direction. x : float = None Position of plotting plane in x direction. @@ -2667,24 +2684,25 @@ def integrate_face(arr: xr.DataArray) -> complex: class Transformed(Geometry): """Class representing a transformed geometry.""" - geometry: annotate_type(GeometryType) = pydantic.Field( - ..., title="Geometry", description="Base geometry to be transformed." 
+ geometry: discriminated_union(GeometryType) = Field( + title="Geometry", + description="Base geometry to be transformed.", ) - transform: MatrixReal4x4 = pydantic.Field( - np.eye(4).tolist(), + transform: MatrixReal4x4 = Field( + default_factory=lambda: np.eye(4).tolist(), title="Transform", description="Transform matrix applied to the base geometry.", ) - @pydantic.validator("transform") - def _transform_is_invertible(cls, val): + @field_validator("transform") + def _transform_is_invertible(val): # If the transform is not invertible, this will raise an error _ = np.linalg.inv(val) return val - @pydantic.validator("geometry") - def _geometry_is_finite(cls, val): + @field_validator("geometry") + def _geometry_is_finite(val): if not np.isfinite(val.bounds).all(): raise ValidationError( "Transformations are only supported on geometries with finite dimensions. " @@ -2693,13 +2711,13 @@ def _geometry_is_finite(cls, val): ) return val - @pydantic.root_validator(skip_on_failure=True) - def _apply_transforms(cls, values): - while isinstance(values["geometry"], Transformed): - inner = values["geometry"] - values["geometry"] = inner.geometry - values["transform"] = np.dot(values["transform"], inner.transform) - return values + @model_validator(mode="after") + def _apply_transforms(self): + while isinstance(self.geometry, Transformed): + inner = self.geometry + object.__setattr__(self, "geometry", inner.geometry) + object.__setattr__(self, "transform", np.dot(self.transform, inner.transform)) + return self @cached_property def inverse(self) -> MatrixReal4x4: @@ -2738,7 +2756,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float, float] + tuple[float, float, float], tuple[float, float, float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
""" # NOTE (Lucas): The bounds are overestimated because we don't want to calculate @@ -2748,7 +2766,7 @@ def bounds(self) -> Bound: def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -2762,7 +2780,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -2881,7 +2899,7 @@ def rotation(angle: float, axis: Union[Axis, Coordinate]) -> MatrixReal4x4: ---------- angle : float Rotation angle (in radians). - axis : Union[int, Tuple[float, float, float]] + axis : Union[int, tuple[float, float, float]] Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. Returns @@ -2899,7 +2917,7 @@ def reflection(normal: Coordinate) -> MatrixReal4x4: Parameters ---------- - normal : Tuple[float, float, float] + normal : tuple[float, float, float] Normal of the plane of reflection. 
Returns @@ -2945,7 +2963,7 @@ def _normal_2dmaterial(self) -> Axis: return normal - def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Transformed: + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Transformed: """Returns an updated geometry which has been transformed to fit within ``bounds`` along the ``axis`` direction.""" min_bound = np.array([0, 0, 0, 1.0]) @@ -2962,27 +2980,24 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Transf class ClipOperation(Geometry): """Class representing the result of a set operation between geometries.""" - operation: ClipOperationType = pydantic.Field( - ..., + operation: ClipOperationType = Field( title="Operation Type", description="Operation to be performed between geometries.", ) - geometry_a: annotate_type(GeometryType) = pydantic.Field( - ..., + geometry_a: discriminated_union(GeometryType) = Field( title="Geometry A", description="First operand for the set operation. It can be any geometry type, including " ":class:`GeometryGroup`.", ) - geometry_b: annotate_type(GeometryType) = pydantic.Field( - ..., + geometry_b: discriminated_union(GeometryType) = Field( title="Geometry B", description="Second operand for the set operation. It can also be any geometry type.", ) - @pydantic.validator("geometry_a", "geometry_b", always=True) - def _geometries_untraced(cls, val): + @field_validator("geometry_a", "geometry_b") + def _geometries_untraced(val): """Make sure that ``ClipOperation`` geometries do not contain tracers.""" traced = val.strip_traced_fields() if traced: @@ -2993,7 +3008,7 @@ def _geometries_untraced(cls, val): return val @staticmethod - def to_polygon_list(base_geometry: Shapely) -> List[Shapely]: + def to_polygon_list(base_geometry: Shapely) -> list[Shapely]: """Return a list of valid polygons from a shapely geometry, discarding points, lines, and empty polygons. 
@@ -3004,7 +3019,7 @@ def to_polygon_list(base_geometry: Shapely) -> List[Shapely]: Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] Valid polygons retrieved from ``base geometry``. """ if base_geometry.geom_type == "GeometryCollection": @@ -3039,7 +3054,7 @@ def _bit_operation(self) -> Callable[[Any, Any], Any]: def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -3053,7 +3068,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -3066,7 +3081,7 @@ def intersections_tilted_plane( def intersections_plane( self, x: float = None, y: float = None, z: float = None - ) -> List[Shapely]: + ) -> list[Shapely]: """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. Parameters @@ -3080,7 +3095,7 @@ def intersections_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentaton `_. @@ -3097,7 +3112,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float float] + tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
""" # Overestimates @@ -3203,7 +3218,7 @@ def _normal_2dmaterial(self) -> Axis: ) return normal_a - def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> ClipOperation: + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> ClipOperation: """Returns an updated geometry which has been transformed to fit within ``bounds`` along the ``axis`` direction.""" new_geom_a = self.geometry_a._update_from_bounds(bounds=bounds, axis=axis) @@ -3214,16 +3229,15 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> ClipOp class GeometryGroup(Geometry): """A collection of Geometry objects that can be called as a single geometry object.""" - geometries: Tuple[annotate_type(GeometryType), ...] = pydantic.Field( - ..., + geometries: tuple[discriminated_union(GeometryType), ...] = Field( title="Geometries", description="Tuple of geometries in a single grouping. " "Can provide significant performance enhancement in ``Structure`` when all geometries are " "assigned the same medium.", ) - @pydantic.validator("geometries", always=True) - def _geometries_not_empty(cls, val): + @field_validator("geometries") + def _geometries_not_empty(val): """make sure geometries are not empty.""" if not len(val) > 0: raise ValidationError("GeometryGroup.geometries must not be empty.") @@ -3235,7 +3249,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float, float] + tuple[float, float, float], tuple[float, float, float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. """ @@ -3247,7 +3261,7 @@ def bounds(self) -> Bound: def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. 
Parameters @@ -3261,7 +3275,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -3274,7 +3288,7 @@ def intersections_tilted_plane( def intersections_plane( self, x: float = None, y: float = None, z: float = None - ) -> List[Shapely]: + ) -> list[Shapely]: """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. Parameters @@ -3288,7 +3302,7 @@ def intersections_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -3394,12 +3408,12 @@ def _normal_2dmaterial(self) -> Axis: ) return normal - def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> GeometryGroup: + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> GeometryGroup: """Returns an updated geometry which has been transformed to fit within ``bounds`` along the ``axis`` direction.""" - new_geometries = [ + new_geometries = tuple( geometry._update_from_bounds(bounds=bounds, axis=axis) for geometry in self.geometries - ] + ) return self.updated_copy(geometries=new_geometries) def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: diff --git a/tidy3d/components/geometry/mesh.py b/tidy3d/components/geometry/mesh.py index d2e258ee66..96e9cf6c7a 100644 --- a/tidy3d/components/geometry/mesh.py +++ b/tidy3d/components/geometry/mesh.py @@ -3,10 +3,10 @@ from __future__ import annotations from abc import ABC -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pydantic +from pydantic import Field, field_validator, model_validator from ...constants import inf from ...exceptions import DataError, ValidationError @@ 
-33,21 +33,21 @@ class TriangleMesh(base.Geometry, ABC): >>> stl_geom = TriangleMesh.from_vertices_faces(vertices, faces) """ - mesh_dataset: Optional[TriangleMeshDataset] = pydantic.Field( - ..., + mesh_dataset: Optional[TriangleMeshDataset] = Field( + None, title="Surface mesh data", description="Surface mesh data.", ) _no_nans_mesh = validate_no_nans("mesh_dataset") - @pydantic.root_validator(pre=True) + @model_validator(mode="before") @verify_packages_import(["trimesh"]) - def _validate_trimesh_library(cls, values): + def _validate_trimesh_library(data): """Check if the trimesh package is imported as a validator.""" - return values + return data - @pydantic.validator("mesh_dataset", pre=True, always=True) + @field_validator("mesh_dataset", mode="before") def _warn_if_none(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: """Warn if the Dataset fails to load.""" if isinstance(val, dict): @@ -56,8 +56,7 @@ def _warn_if_none(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: return None return val - @pydantic.validator("mesh_dataset", always=True) - @verify_packages_import(["trimesh"]) + @field_validator("mesh_dataset") def _check_mesh(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: """Check that the mesh is valid.""" if val is None: @@ -150,7 +149,7 @@ def from_stl( cls, filename: str, scale: float = 1.0, - origin: Tuple[float, float, float] = (0, 0, 0), + origin: tuple[float, float, float] = (0, 0, 0), solid_index: int = None, **kwargs, ) -> Union[TriangleMesh, base.GeometryGroup]: @@ -168,7 +167,7 @@ def from_stl( The length scale for the loaded geometry (um). For example, a scale of 10.0 means that a vertex (1, 0, 0) will be placed at x = 10 um. - origin : Tuple[float, float, float] = (0, 0, 0) + origin : tuple[float, float, float] = (0, 0, 0) The origin of the loaded geometry, in units of ``scale``. Translates from (0, 0, 0) to this point after applying the scaling. 
solid_index : int = None @@ -339,7 +338,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float float] + tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. """ if self.mesh_dataset is None: @@ -348,7 +347,7 @@ def bounds(self) -> Bound: def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -362,7 +361,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -375,7 +374,7 @@ def intersections_tilted_plane( def intersections_plane( self, x: float = None, y: float = None, z: float = None - ) -> List[Shapely]: + ) -> list[Shapely]: """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. Parameters @@ -389,7 +388,7 @@ def intersections_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentaton `_. 
diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index 51b66ab8a1..32a21f24c7 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -4,21 +4,21 @@ import math from copy import copy -from typing import List, Tuple, Union +from typing import Union import autograd.numpy as np -import pydantic.v1 as pydantic import shapely from autograd.tracer import getval, isbox +from pydantic import Field, PositiveFloat, field_validator, model_validator from ...constants import LARGE_NUMBER, MICROMETER, fp_eps from ...exceptions import SetupError, ValidationError from ...log import log from ...packaging import verify_packages_import -from ..autograd import AutogradFieldMap, TracedVertices, get_static +from ..autograd import AutogradFieldMap, TracedArrayFloat2D, get_static from ..autograd.derivative_utils import DerivativeInfo, DerivativeSurfaceMesh from ..autograd.types import TracedFloat -from ..base import cached_property, skip_if_fields_missing +from ..base import cached_property from ..transformation import ReflectionFromPlane, RotationAroundAxis from ..types import ( ArrayFloat1D, @@ -60,14 +60,13 @@ class PolySlab(base.Planar): >>> p = PolySlab(vertices=vertices, axis=2, slab_bounds=(-1, 1)) """ - slab_bounds: Tuple[TracedFloat, TracedFloat] = pydantic.Field( - ..., + slab_bounds: tuple[TracedFloat, TracedFloat] = Field( title="Slab Bounds", description="Minimum and maximum positions of the slab along axis dimension.", units=MICROMETER, ) - dilation: float = pydantic.Field( + dilation: float = Field( 0.0, title="Dilation", description="Dilation of the supplied polygon by shifting each edge along its " @@ -75,8 +74,7 @@ class PolySlab(base.Planar): units=MICROMETER, ) - vertices: TracedVertices = pydantic.Field( - ..., + vertices: TracedArrayFloat2D = Field( title="Vertices", description="List of (d1, d2) defining the 2 dimensional positions of the polygon " "face vertices at the 
``reference_plane``. " @@ -91,8 +89,8 @@ def make_shapely_polygon(vertices: ArrayLike) -> shapely.Polygon: vertices = get_static(vertices) return shapely.Polygon(vertices) - @pydantic.validator("slab_bounds", always=True) - def slab_bounds_order(cls, val): + @field_validator("slab_bounds") + def slab_bounds_order(val): """Maximum position of the slab should be no smaller than its minimal position.""" if val[1] < val[0]: raise SetupError( @@ -102,18 +100,16 @@ def slab_bounds_order(cls, val): ) return val - @pydantic.validator("vertices", always=True) + @field_validator("vertices") + @classmethod def correct_shape(cls, val): - """Makes sure vertices size is correct. - Make sure no intersecting edges. - """ + """Makes sure vertices size is correct. Make sure no intersecting edges.""" # overall shape of vertices if val.shape[1] != 2: raise SetupError( "PolySlab.vertices must be a 2 dimensional array shaped (N, 2). " f"Given array with shape of {val.shape}." ) - # make sure no polygon splitting, islands, 0 area poly_heal = shapely.make_valid(cls.make_shapely_polygon(val)) if poly_heal.area < _MIN_POLYGON_AREA: @@ -128,9 +124,8 @@ def correct_shape(cls, val): ) return val - @pydantic.validator("vertices", always=True) - @skip_if_fields_missing(["dilation"]) - def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values): + @model_validator(mode="after") + def no_complex_self_intersecting_polygon_at_reference_plane(self): """At the reference plane, check if the polygon is self-intersecting. There are two types of self-intersection that can occur during dilation: @@ -140,12 +135,13 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values): For 1), we issue an error since it is yet to be supported; For 2), we heal the polygon, and warn that the polygon has been cleaned up. 
""" + val = self.vertices # no need to validate anything here - if math.isclose(values["dilation"], 0): - return val + if math.isclose(self.dilation, 0): + return self val_np = PolySlab._proper_vertices(val) - dist = values["dilation"] + dist = self.dilation # 0) fully eroded if dist < 0 and dist < -PolySlab._maximal_erosion(val_np): @@ -153,14 +149,14 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values): # no edge events if not PolySlab._edge_events_detection(val_np, dist, ignore_at_dist=False): - return val + return self poly_offset = PolySlab._shift_vertices(val_np, dist)[0] if PolySlab._area(poly_offset) < fp_eps**2: raise SetupError("Erosion value is too large. The polygon is fully eroded.") # edge events - poly_offset = shapely.make_valid(cls.make_shapely_polygon(poly_offset)) + poly_offset = shapely.make_valid(self.make_shapely_polygon(poly_offset)) # 1) polygon split or create holes/islands if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: raise SetupError( @@ -176,11 +172,10 @@ def no_complex_self_intersecting_polygon_at_reference_plane(cls, val, values): "self-intersecting polygon. " "The vertices have been modified to make a valid polygon." ) - return val + return self - @pydantic.validator("vertices", always=True) - @skip_if_fields_missing(["sidewall_angle", "dilation", "slab_bounds", "reference_plane"]) - def no_self_intersecting_polygon_during_extrusion(cls, val, values): + @model_validator(mode="after") + def no_self_intersecting_polygon_during_extrusion(self): """In this simple polyslab, we don't support self-intersecting polygons yet, meaning that any normal cross section of the PolySlab cannot be self-intersecting. This part checks if any self-interction will occur during extrusion with non-zero sidewall angle. 
@@ -194,28 +189,29 @@ def no_self_intersecting_polygon_during_extrusion(cls, val, values): To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation of polygons/holes, and changes in vertices number. """ + val = self.vertices # no need to validate anything here - if math.isclose(values["sidewall_angle"], 0): - return val + if math.isclose(self.sidewall_angle, 0): + return self # apply dilation poly_ref = PolySlab._proper_vertices(val) - if not math.isclose(values["dilation"], 0): - poly_ref = PolySlab._shift_vertices(poly_ref, values["dilation"])[0] + if not math.isclose(self.dilation, 0): + poly_ref = PolySlab._shift_vertices(poly_ref, self.dilation)[0] poly_ref = PolySlab._heal_polygon(poly_ref) - slab_min, slab_max = values["slab_bounds"] + slab_min, slab_max = self.slab_bounds slab_bounds = [getval(slab_min), getval(slab_max)] # Fist, check vertex-vertex crossing at any point during extrusion length = slab_bounds[1] - slab_bounds[0] - dist = [-length * np.tan(values["sidewall_angle"])] + dist = [-length * np.tan(self.sidewall_angle)] # reverse the dilation value if it's defined on the top - if values["reference_plane"] == "top": + if self.reference_plane == "top": dist = [-dist[0]] # for middle, both direction needs to be examined - elif values["reference_plane"] == "middle": + elif self.reference_plane == "middle": dist = [dist[0] / 2, -dist[0] / 2] # capture vertex crossing events @@ -245,21 +241,21 @@ def no_self_intersecting_polygon_during_extrusion(cls, val, values): "A general treatment to self-intersecting polygon will be available " "in future releases." 
) - return val + return self @classmethod def from_gds( cls, gds_cell, axis: Axis, - slab_bounds: Tuple[float, float], + slab_bounds: tuple[float, float], gds_layer: int, gds_dtype: int = None, - gds_scale: pydantic.PositiveFloat = 1.0, + gds_scale: PositiveFloat = 1.0, dilation: float = 0.0, sidewall_angle: float = 0, reference_plane: PlanePosition = "middle", - ) -> List[PolySlab]: + ) -> list[PolySlab]: """Import :class:`PolySlab` from a ``gdstk.Cell`` or a ``gdspy.Cell``. Parameters @@ -268,7 +264,7 @@ def from_gds( ``gdstk.Cell`` or ``gdspy.Cell`` containing 2D geometric data. axis : int Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). - slab_bounds: Tuple[float, float] + slab_bounds: tuple[float, float] Minimum and maximum positions of the slab along ``axis``. gds_layer : int Layer index in the ``gds_cell``. @@ -294,7 +290,7 @@ def from_gds( Returns ------- - List[:class:`PolySlab`] + list[:class:`PolySlab`] List of :class:`PolySlab` objects sharing ``axis`` and slab bound properties. """ @@ -317,8 +313,8 @@ def _load_gds_vertices( gds_cell, gds_layer: int, gds_dtype: int = None, - gds_scale: pydantic.PositiveFloat = 1.0, - ) -> List[ArrayFloat2D]: + gds_scale: PositiveFloat = 1.0, + ) -> list[ArrayFloat2D]: """Import :class:`PolySlab` from a ``gdstk.Cell`` or a ``gdspy.Cell``. 
Parameters @@ -337,7 +333,7 @@ def _load_gds_vertices( Returns ------- - List[ArrayFloat2D] + list[ArrayFloat2D] List of :class:`.ArrayFloat2D` """ @@ -466,7 +462,7 @@ def _normal_2dmaterial(self) -> Axis: raise ValidationError("'Medium2D' requires the 'PolySlab' bounds to be equal.") return self.axis - def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> PolySlab: + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> PolySlab: """Returns an updated geometry which has been transformed to fit within ``bounds`` along the ``axis`` direction.""" if axis != self.axis: @@ -474,7 +470,7 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> PolySl f"'_update_from_bounds' may only be applied along axis '{self.axis}', " f"but was given axis '{axis}'." ) - return self.updated_copy(slab_bounds=bounds) + return self.updated_copy(slab_bounds=tuple(bounds)) @cached_property def is_ccw(self) -> bool: @@ -578,7 +574,7 @@ def _move_axis_reverse(arr): @verify_packages_import(["trimesh"]) def _do_intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -592,7 +588,7 @@ def _do_intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -640,7 +636,7 @@ def _intersections_normal(self, z: float): Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. 
@@ -682,7 +678,7 @@ def _intersections_side(self, position, axis) -> list: Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -800,7 +796,7 @@ def _find_intersecting_height(self, position: float, axis: int) -> np.ndarray: def _find_intersecting_ys_angle_vertical( self, vertices: np.ndarray, position: float, axis: int, exclude_on_vertices: bool = False - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Finds pairs of forward and backwards vertices where polygon intersects position at axis, Find intersection point (in y) assuming straight line,and intersecting angle between plane and edges. (For unslanted polyslab). @@ -880,7 +876,7 @@ def _find_intersecting_ys_angle_vertical( def _find_intersecting_ys_angle_slant( self, vertices: np.ndarray, position: float, axis: int - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Finds pairs of forward and backwards vertices where polygon intersects position at axis, Find intersection point (in y) assuming straight line,and intersecting angle between plane and edges. (For slanted polyslab) @@ -992,7 +988,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float float] + tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. """ @@ -1252,7 +1248,7 @@ def normalize(v): @staticmethod def _shift_vertices( vertices: np.ndarray, dist - ) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]: + ) -> tuple[np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray]]: """Shifts the vertices of a polygon outward uniformly by distances `dists`. 
@@ -1265,7 +1261,7 @@ def _shift_vertices( Returns ------- - Tuple[np.ndarray, np.narray,Tuple[np.ndarray,np.ndarray]] + tuple[np.ndarray, np.narray,tuple[np.ndarray,np.ndarray]] New polygon vertices; and the shift of vertices in direction parallel to the edges. Shift along x and y direction. @@ -1313,7 +1309,7 @@ def normalize(v): return np.swapaxes(vs_orig + shift_total, -2, -1), parallel_shift, (shift_x, shift_y) @staticmethod - def _edge_length_and_reduction_rate(vertices: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def _edge_length_and_reduction_rate(vertices: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Edge length of reduction rate of each edge with unit offset length. Parameters @@ -1323,7 +1319,7 @@ def _edge_length_and_reduction_rate(vertices: np.ndarray) -> Tuple[np.ndarray, n Returns ------- - Tuple[np.ndarray, np.narray] + tuple[np.ndarray, np.narray] edge length, and reduction rate """ @@ -1445,7 +1441,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM def compute_derivative_slab_face( self, derivative_info: DerivativeInfo, min_max_index: int - ) -> TracedVertices: + ) -> TracedArrayFloat2D: """Derivative with respect to slab_bounds.""" rmin, rmax = derivative_info.bounds @@ -1519,7 +1515,7 @@ def get_grad(min_max_index: int) -> float: def compute_derivative_slab_face_single_pt( self, derivative_info: DerivativeInfo, min_max_index: int - ) -> TracedVertices: + ) -> TracedArrayFloat2D: """Derivative with respect to slab faces (single point approximation).""" self_static = self.to_static() @@ -1547,7 +1543,7 @@ def compute_derivative_slab_face_single_pt( return vjp - def compute_derivative_vertices(self, derivative_info: DerivativeInfo) -> TracedVertices: + def compute_derivative_vertices(self, derivative_info: DerivativeInfo) -> TracedArrayFloat2D: # derivative w.r.t each edge vertices = np.array(self.vertices) @@ -1642,7 +1638,7 @@ def unpop_axis_vect(self, ax_coords: np.ndarray, plane_coords: np.ndarray) -> 
np arr_xyz = np.stack(arr_xyz, axis=-1) return arr_xyz - def pop_axis_vect(self, coord: np.ndarray) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + def pop_axis_vect(self, coord: np.ndarray) -> tuple[np.ndarray, tuple[np.ndarray, np.ndarray]]: """Combine coordinate along axis with coordinates on the plane tangent to the axis. coord.shape == [N, 3] @@ -1714,7 +1710,7 @@ def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> PolySlab: ---------- angle : float Rotation angle (in radians). - axis : Union[int, Tuple[float, float, float]] + axis : Union[int, tuple[float, float, float]] Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. Returns @@ -1741,7 +1737,7 @@ def reflected(self, normal: Coordinate) -> PolySlab: Parameters ---------- - normal : Tuple[float, float, float] + normal : tuple[float, float, float] The 3D normal vector of the plane of reflection. The plane is assumed to pass through the origin (0,0,0). @@ -1769,24 +1765,24 @@ class ComplexPolySlabBase(PolySlab): occur during extrusion. This class should not be used directly. Use instead :class:`plugins.polyslab.ComplexPolySlab`.""" - @pydantic.validator("vertices", always=True) - def no_self_intersecting_polygon_during_extrusion(cls, val, values): + @model_validator(mode="after") + def no_self_intersecting_polygon_during_extrusion(self): """Turn off the validation for this class.""" - return val + return self @classmethod def from_gds( cls, gds_cell, axis: Axis, - slab_bounds: Tuple[float, float], + slab_bounds: tuple[float, float], gds_layer: int, gds_dtype: int = None, - gds_scale: pydantic.PositiveFloat = 1.0, + gds_scale: PositiveFloat = 1.0, dilation: float = 0.0, sidewall_angle: float = 0, reference_plane: PlanePosition = "middle", - ) -> List[PolySlab]: + ) -> list[PolySlab]: """Import :class:`.PolySlab` from a ``gdstk.Cell``. Parameters @@ -1795,7 +1791,7 @@ def from_gds( ``gdstk.Cell`` containing 2D geometric data. 
axis : int Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). - slab_bounds: Tuple[float, float] + slab_bounds: tuple[float, float] Minimum and maximum positions of the slab along ``axis``. gds_layer : int Layer index in the ``gds_cell``. @@ -1821,7 +1817,7 @@ def from_gds( Returns ------- - List[:class:`.PolySlab`] + list[:class:`.PolySlab`] List of :class:`.PolySlab` objects sharing ``axis`` and slab bound properties. """ @@ -1855,14 +1851,14 @@ def geometry_group(self) -> base.GeometryGroup: return base.GeometryGroup(geometries=self.sub_polyslabs) @property - def sub_polyslabs(self) -> List[PolySlab]: + def sub_polyslabs(self) -> list[PolySlab]: """Divide a complex polyslab into a list of simple polyslabs. Only neighboring vertex-vertex crossing events are treated in this version. Returns ------- - List[PolySlab] + list[PolySlab] A list of simple polyslabs. """ sub_polyslab_list = [] @@ -1940,7 +1936,7 @@ def sub_polyslabs(self) -> List[PolySlab]: return sub_polyslab_list @property - def _dilation_length(self) -> List[float]: + def _dilation_length(self) -> list[float]: """dilation length from reference plane to the top/bottom of the polyslab.""" # for "bottom", only needs to compute the offset length to the top @@ -1966,7 +1962,7 @@ def _dilation_value_at_reference_to_coord(self, dilation: float) -> float: def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -1980,7 +1976,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. 
diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py index 3b5adc69a1..289ca93e2d 100644 --- a/tidy3d/components/geometry/primitives.py +++ b/tidy3d/components/geometry/primitives.py @@ -3,20 +3,19 @@ from __future__ import annotations from math import isclose -from typing import List import autograd.numpy as anp import numpy as np -import pydantic.v1 as pydantic import shapely +from pydantic import Field, model_validator from ...constants import C_0, LARGE_NUMBER, MICROMETER from ...exceptions import SetupError, ValidationError from ...packaging import verify_packages_import from ..autograd import AutogradFieldMap, TracedSize1D from ..autograd.derivative_utils import DerivativeInfo -from ..base import cached_property, skip_if_fields_missing -from ..types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely, Tuple +from ..base import cached_property +from ..types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely from . import base from .polyslab import PolySlab @@ -71,7 +70,7 @@ def inside( def intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -85,7 +84,7 @@ def intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -124,7 +123,7 @@ def intersections_plane(self, x: float = None, y: float = None, z: float = None) Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. 
@@ -195,35 +194,32 @@ class Cylinder(base.Centered, base.Circular, base.Planar): """ # Provide more explanations on where radius is defined - radius: TracedSize1D = pydantic.Field( - ..., + radius: TracedSize1D = Field( title="Radius", description="Radius of geometry at the ``reference_plane``.", units=MICROMETER, ) - length: TracedSize1D = pydantic.Field( - ..., + length: TracedSize1D = Field( title="Length", description="Defines thickness of cylinder along axis dimension.", units=MICROMETER, ) - @pydantic.validator("length", always=True) - @skip_if_fields_missing(["sidewall_angle", "reference_plane"]) - def _only_middle_for_infinite_length_slanted_cylinder(cls, val, values): + @model_validator(mode="after") + def _only_middle_for_infinite_length_slanted_cylinder(self): """For a slanted cylinder of infinite length, ``reference_plane`` can only be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0. """ - if isclose(values["sidewall_angle"], 0) or not np.isinf(val): - return val - if values["reference_plane"] != "middle": + if isclose(self.sidewall_angle, 0) or not np.isinf(self.length): + return self + if self.reference_plane != "middle": raise SetupError( "For a slanted cylinder here is of infinite length, " "defining the reference_plane other than 'middle' " "leads to undefined cylinder behaviors near 'center'." 
) - return val + return self def to_polyslab( self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB, **kwargs @@ -354,7 +350,7 @@ def _normal_2dmaterial(self) -> Axis: raise ValidationError("'Medium2D' requires the 'Cylinder' length to be zero.") return self.axis - def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Cylinder: + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Cylinder: """Returns an updated geometry which has been transformed to fit within ``bounds`` along the ``axis`` direction.""" if axis != self.axis: @@ -365,12 +361,12 @@ def _update_from_bounds(self, bounds: Tuple[float, float], axis: Axis) -> Cylind new_center = list(self.center) new_center[axis] = (bounds[0] + bounds[1]) / 2 new_length = bounds[1] - bounds[0] - return self.updated_copy(center=new_center, length=new_length) + return self.updated_copy(center=tuple(new_center), length=new_length) @verify_packages_import(["trimesh"]) def _do_intersections_tilted_plane( self, normal: Coordinate, origin: Coordinate, to_2D: MatrixReal4x4 - ) -> List[Shapely]: + ) -> list[Shapely]: """Return a list of shapely geometries at the plane specified by normal and origin. Parameters @@ -384,7 +380,7 @@ def _do_intersections_tilted_plane( Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -474,7 +470,7 @@ def _intersections_normal(self, z: float): Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. @@ -506,7 +502,7 @@ def _intersections_side(self, position, axis): Returns ------- - List[shapely.geometry.base.BaseGeometry] + list[shapely.geometry.base.BaseGeometry] List of 2D shapes that intersect plane. For more details refer to `Shapely's Documentation `_. 
@@ -728,7 +724,7 @@ def _radius_z(self, z: float): return radius_middle - (z - self.center_axis) * self._tanq - def _local_to_global_side_cross_section(self, coords: List[float], axis: int) -> List[float]: + def _local_to_global_side_cross_section(self, coords: list[float], axis: int) -> list[float]: """Map a point (x,y) from local to global coordinate system in the side cross section. @@ -741,7 +737,7 @@ def _local_to_global_side_cross_section(self, coords: List[float], axis: int) -> ---------- axis : int Integer index into 'xyz' (0, 1, 2). - coords : List[float, float] + coords : list[float, float] The value in the planar coordinate. Returns diff --git a/tidy3d/components/geometry/triangulation.py b/tidy3d/components/geometry/triangulation.py index a6a6e00d2c..f750461fb9 100644 --- a/tidy3d/components/geometry/triangulation.py +++ b/tidy3d/components/geometry/triangulation.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import List, Tuple import numpy as np import shapely @@ -33,12 +32,12 @@ class Vertex: is_ear: bool -def update_convexity(vertices: List[Vertex], i: int) -> int: +def update_convexity(vertices: list[Vertex], i: int) -> int: """Update the convexity of a vertex in a polygon. Parameters ---------- - vertices : List[Vertex] + vertices : list[Vertex] Vertices of the polygon. i : int Index of the vertex to be updated. @@ -69,7 +68,7 @@ def update_convexity(vertices: List[Vertex], i: int) -> int: def is_inside( - vertex: ArrayFloat1D, triangle: Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] + vertex: ArrayFloat1D, triangle: tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] ) -> bool: """Check if a vertex is inside a triangle. @@ -77,7 +76,7 @@ def is_inside( ---------- vertex : ArrayFloat1D Vertex coordinates. - triangle : Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] + triangle : tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] Vertices of the triangle in CCW order. 
Returns @@ -90,12 +89,12 @@ def is_inside( ) -def update_ear_flag(vertices: List[Vertex], i: int) -> None: +def update_ear_flag(vertices: list[Vertex], i: int) -> None: """Update the ear flag of a vertex in a polygon. Parameters ---------- - vertices : List[Vertex] + vertices : list[Vertex] Vertices of the polygon. i : int Index of the vertex to be updated. @@ -112,7 +111,7 @@ def update_ear_flag(vertices: List[Vertex], i: int) -> None: # TODO: This is an inefficient algorithm that runs in O(n^2). We should use something # better, and probably as a compiled extension. -def triangulate(vertices: ArrayFloat2D) -> List[Tuple[int, int, int]]: +def triangulate(vertices: ArrayFloat2D) -> list[tuple[int, int, int]]: """Triangulate a simple polygon. Parameters @@ -122,7 +121,7 @@ def triangulate(vertices: ArrayFloat2D) -> List[Tuple[int, int, int]]: Returns ------- - List[Tuple[int, int, int]] + list[tuple[int, int, int]] List of indices of the vertices of the triangles. """ is_ccw = shapely.LinearRing(vertices).is_ccw diff --git a/tidy3d/components/geometry/utils.py b/tidy3d/components/geometry/utils.py index 9eaa7542e3..6a349ad02c 100644 --- a/tidy3d/components/geometry/utils.py +++ b/tidy3d/components/geometry/utils.py @@ -4,10 +4,10 @@ from enum import Enum from math import isclose -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np -import pydantic as pydantic +from pydantic import Field from ...constants import fp_eps from ...exceptions import SetupError, Tidy3dError @@ -31,16 +31,16 @@ def merging_geometries_on_plane( - geometries: List[GeometryType], + geometries: list[GeometryType], plane: Box, - property_list: List[Any], -) -> List[Tuple[Any, Shapely]]: + property_list: list[Any], +) -> list[tuple[Any, Shapely]]: """Compute list of shapes on plane. Overlaps are removed or merged depending on provided property_list. 
Parameters ---------- - geometries : List[GeometryType] + geometries : list[GeometryType] List of structures to filter on the plane. plane : Box Plane specification. @@ -49,7 +49,7 @@ def merging_geometries_on_plane( Returns ------- - List[Tuple[Any, shapely]] + list[tuple[Any, shapely]] List of shapes and their property value on the plane after merging. """ @@ -191,7 +191,7 @@ def traverse_geometries(geometry: GeometryType) -> GeometryType: def from_shapely( shape: Shapely, axis: Axis, - slab_bounds: Tuple[float, float], + slab_bounds: tuple[float, float], dilation: float = 0.0, sidewall_angle: float = 0, reference_plane: PlanePosition = "middle", @@ -205,7 +205,7 @@ def from_shapely( of any of those. axis : int Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: Tuple[float, float] + slab_bounds: tuple[float, float] Minimal and maximal positions of the extruded slab along ``axis``. dilation : float Dilation of the polygon in the base by shifting each edge along its normal outwards @@ -278,7 +278,7 @@ def vertices_from_shapely(shape: Shapely) -> ArrayFloat2D: Returns ------- - List[Tuple[ArrayFloat2D]] + list[tuple[ArrayFloat2D]] List of tuples ``(exterior, *interiors)``. 
""" if shape.geom_type == "LinearRing": @@ -353,14 +353,12 @@ class SnapBehavior(Enum): class SnappingSpec(Tidy3dBaseModel): """Specifies how to apply grid snapping along each dimension.""" - location: tuple[SnapLocation, SnapLocation, SnapLocation] = pydantic.Field( - ..., + location: tuple[SnapLocation, SnapLocation, SnapLocation] = Field( title="Location", description="Describes which positions in the grid will be considered for snapping.", ) - behavior: tuple[SnapBehavior, SnapBehavior, SnapBehavior] = pydantic.Field( - ..., + behavior: tuple[SnapBehavior, SnapBehavior, SnapBehavior] = Field( title="Behavior", description="Describes how snapping positions will be chosen.", ) diff --git a/tidy3d/components/geometry/utils_2d.py b/tidy3d/components/geometry/utils_2d.py index 5d7b0860f9..df318e2b41 100644 --- a/tidy3d/components/geometry/utils_2d.py +++ b/tidy3d/components/geometry/utils_2d.py @@ -1,7 +1,6 @@ """Utilities for 2D geometry manipulation.""" from math import isclose -from typing import List, Tuple import numpy as np import shapely @@ -46,7 +45,7 @@ def snap_coordinate_to_grid(grid: Grid, center: float, axis: Axis) -> float: return new_center -def get_bounds(geom: Geometry, axis: Axis) -> Tuple[float, float]: +def get_bounds(geom: Geometry, axis: Axis) -> tuple[float, float]: """Get the bounds of a geometry in the axis direction.""" return (geom.bounds[0][axis], geom.bounds[1][axis]) @@ -62,7 +61,7 @@ def get_thickened_geom(geom: Geometry, axis: Axis): def get_neighbors( geom: Geometry, axis: Axis, - structures: List[Structure], + structures: list[Structure], ): """Find the neighboring structures and return the tested positions above and below.""" center = get_bounds(geom, axis)[0] @@ -100,8 +99,8 @@ def get_neighbors( def subdivide( - geom: Geometry, structures: List[Structure] -) -> List[Tuple[Geometry, Structure, Structure]]: + geom: Geometry, structures: list[Structure] +) -> list[tuple[Geometry, Structure, Structure]]: """Subdivide geometry 
associated with a :class:`.Medium2D` into partitions that each have a homogeneous substrate / superstrate. Partitions are computed using ``shapely`` boolean operations on polygons. @@ -110,12 +109,12 @@ def subdivide( ---------- geom : Geometry A 2D geometry associated with the :class:`.Medium2D`. - structures : List[Structure] + structures : list[Structure] List of structures that are checked for intersection with ``geom``. Returns ------- - List[Tuple[Geometry, Structure, Structure]] + list[tuple[Geometry, Structure, Structure]] List of the created partitions. Each element of the list represents a partition of the 2D geometry, which includes the newly created structures below and above. diff --git a/tidy3d/components/grid/corner_finder.py b/tidy3d/components/grid/corner_finder.py index aa3c5347bf..2a3e85222a 100644 --- a/tidy3d/components/grid/corner_finder.py +++ b/tidy3d/components/grid/corner_finder.py @@ -1,9 +1,9 @@ """Find corners of structures on a 2D plane.""" -from typing import Any, List, Literal, Optional, Tuple +from typing import Any, Literal, Optional import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, PositiveInt from ...constants import inf from ..base import Tidy3dBaseModel, cached_property @@ -19,7 +19,7 @@ class CornerFinderSpec(Tidy3dBaseModel): """Specification for corner detection on a 2D plane.""" - medium: Literal["metal", "dielectric", "all"] = pd.Field( + medium: Literal["metal", "dielectric", "all"] = Field( "metal", title="Material Type For Corner Identification", description="Find corners of structures made of ``medium``, " @@ -27,7 +27,7 @@ class CornerFinderSpec(Tidy3dBaseModel): "for non-metallic materials, and ``all`` for all materials.", ) - angle_threshold: float = pd.Field( + angle_threshold: float = Field( CORNER_ANGLE_THRESOLD, title="Angle Threshold In Corner Identification", description="A vertex is qualified as a corner if the angle spanned by its two edges " @@ -37,28 +37,28 @@ class 
CornerFinderSpec(Tidy3dBaseModel): lt=np.pi, ) - distance_threshold: Optional[pd.PositiveFloat] = pd.Field( + distance_threshold: Optional[PositiveFloat] = Field( None, title="Distance Threshold In Corner Identification", description="If not ``None`` and the distance of the vertex to its neighboring vertices " "is below the threshold value based on Douglas-Peucker algorithm, the vertex is disqualified as a corner.", ) - concave_resolution: Optional[pd.PositiveInt] = pd.Field( + concave_resolution: Optional[PositiveInt] = Field( None, title="Concave Region Resolution.", description="Specifies number of steps to use for determining `dl_min` based on concave featues." "If set to ``None``, then the corresponding `dl_min` reduction is not applied.", ) - convex_resolution: Optional[pd.PositiveInt] = pd.Field( + convex_resolution: Optional[PositiveInt] = Field( None, title="Convex Region Resolution.", description="Specifies number of steps to use for determining `dl_min` based on convex featues." "If set to ``None``, then the corresponding `dl_min` reduction is not applied.", ) - mixed_resolution: Optional[pd.PositiveInt] = pd.Field( + mixed_resolution: Optional[PositiveInt] = Field( None, title="Mixed Region Resolution.", description="Specifies number of steps to use for determining `dl_min` based on mixed featues." @@ -80,10 +80,10 @@ def _merged_pec_on_plane( cls, normal_axis: Axis, coord: float, - structure_list: List[Structure], - center: Tuple[float, float] = [0, 0, 0], - size: Tuple[float, float, float] = [inf, inf, inf], - ) -> List[Tuple[Any, Shapely]]: + structure_list: list[Structure], + center: tuple[float, float] = [0, 0, 0], + size: tuple[float, float, float] = [inf, inf, inf], + ) -> list[tuple[Any, Shapely]]: """On a 2D plane specified by axis = `normal_axis` and coordinate `coord`, merge geometries made of PEC. Parameters @@ -92,16 +92,16 @@ def _merged_pec_on_plane( Axis normal to the 2D plane. coord : float Position of plane along the normal axis. 
- structure_list : List[Structure] + structure_list : list[Structure] List of structures present in simulation. - center : Tuple[float, float] = [0, 0, 0] + center : tuple[float, float] = [0, 0, 0] Center of the 2D plane (coordinate along ``axis`` is ignored) - size : Tuple[float, float, float] = [inf, inf, inf] + size : tuple[float, float, float] = [inf, inf, inf] Size of the 2D plane (size along ``axis`` is ignored) Returns ------- - List[Tuple[Any, Shapely]] + list[tuple[Any, Shapely]] List of shapes and their property value on the plane after merging. """ @@ -129,9 +129,9 @@ def _corners_and_convexity( self, normal_axis: Axis, coord: float, - structure_list: List[Structure], + structure_list: list[Structure], ravel: bool, - ) -> Tuple[ArrayFloat2D, ArrayFloat1D]: + ) -> tuple[ArrayFloat2D, ArrayFloat1D]: """On a 2D plane specified by axis = `normal_axis` and coordinate `coord`, find out corners of merged geometries made of PEC. @@ -142,14 +142,14 @@ def _corners_and_convexity( Axis normal to the 2D plane. coord : float Position of plane along the normal axis. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in simulation. ravel : bool Whether to put the resulting corners in a single list or per polygon. Returns ------- - Tuple[ArrayFloat2D, ArrayFloat1D] + tuple[ArrayFloat2D, ArrayFloat1D] Corner coordinates and their convexity. """ @@ -192,7 +192,7 @@ def corners( self, normal_axis: Axis, coord: float, - structure_list: List[Structure], + structure_list: list[Structure], ) -> ArrayFloat2D: """On a 2D plane specified by axis = `normal_axis` and coordinate `coord`, find out corners of merged geometries made of `medium`. @@ -204,7 +204,7 @@ def corners( Axis normal to the 2D plane. coord : float Position of plane along the normal axis. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in simulation. 
Returns @@ -220,7 +220,7 @@ def corners( def _filter_collinear_vertices( self, vertices: ArrayFloat2D - ) -> Tuple[ArrayFloat2D, ArrayFloat1D]: + ) -> tuple[ArrayFloat2D, ArrayFloat1D]: """Filter collinear vertices of a polygon, and return corners locations and their convexity. Parameters diff --git a/tidy3d/components/grid/grid.py b/tidy3d/components/grid/grid.py index 62353b659e..fd215e908b 100644 --- a/tidy3d/components/grid/grid.py +++ b/tidy3d/components/grid/grid.py @@ -2,17 +2,17 @@ from __future__ import annotations -from typing import Dict, List, Tuple, Union +from typing import Literal, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field from ...exceptions import SetupError from ..base import Tidy3dBaseModel, cached_property from ..data.data_array import DataArray, ScalarFieldDataArray, SpatialDataArray from ..data.utils import UnstructuredGridDataset, UnstructuredGridDatasetType from ..geometry.base import Box -from ..types import ArrayFloat1D, Axis, Coordinate, InterpMethod, Literal +from ..types import ArrayFloat1D, Axis, Coordinate, InterpMethod # data type of one dimensional coordinate array. Coords1D = ArrayFloat1D @@ -29,16 +29,19 @@ class Coords(Tidy3dBaseModel): >>> coords = Coords(x=x, y=y, z=z) """ - x: Coords1D = pd.Field( - ..., title="X Coordinates", description="1-dimensional array of x coordinates." + x: Coords1D = Field( + title="X Coordinates", + description="1-dimensional array of x coordinates.", ) - y: Coords1D = pd.Field( - ..., title="Y Coordinates", description="1-dimensional array of y coordinates." + y: Coords1D = Field( + title="Y Coordinates", + description="1-dimensional array of y coordinates.", ) - z: Coords1D = pd.Field( - ..., title="Z Coordinates", description="1-dimensional array of z coordinates." 
+ z: Coords1D = Field( + title="Z Coordinates", + description="1-dimensional array of z coordinates.", ) @property @@ -280,20 +283,17 @@ class FieldGrid(Tidy3dBaseModel): >>> field_grid = FieldGrid(x=coords, y=coords, z=coords) """ - x: Coords = pd.Field( - ..., + x: Coords = Field( title="X Positions", description="x,y,z coordinates of the locations of the x-component of a vector field.", ) - y: Coords = pd.Field( - ..., + y: Coords = Field( title="Y Positions", description="x,y,z coordinates of the locations of the y-component of a vector field.", ) - z: Coords = pd.Field( - ..., + z: Coords = Field( title="Z Positions", description="x,y,z coordinates of the locations of the z-component of a vector field.", ) @@ -313,14 +313,12 @@ class YeeGrid(Tidy3dBaseModel): >>> Ex_coords = yee_grid.E.x """ - E: FieldGrid = pd.Field( - ..., + E: FieldGrid = Field( title="Electric Field Grid", description="Coordinates of the locations of all three components of the electric field.", ) - H: FieldGrid = pd.Field( - ..., + H: FieldGrid = Field( title="Electric Field Grid", description="Coordinates of the locations of all three components of the magnetic field.", ) @@ -353,8 +351,7 @@ class Grid(Tidy3dBaseModel): >>> yee_grid = grid.yee """ - boundaries: Coords = pd.Field( - ..., + boundaries: Coords = Field( title="Boundary Coordinates", description="x,y,z coordinates of the boundaries between cells, defining the FDTD grid.", ) @@ -410,7 +407,7 @@ def sizes(self) -> Coords: return Coords(**{key: np.diff(val) for key, val in self.boundaries.to_dict.items()}) @property - def num_cells(self) -> Tuple[int, int, int]: + def num_cells(self) -> tuple[int, int, int]: """Return sizes of the cells in the :class:`Grid`. 
Returns @@ -452,7 +449,7 @@ def max_size(self) -> float: return float(max(max(sizes) for sizes in self.sizes.to_list)) @property - def info(self) -> Dict: + def info(self) -> dict: """Dictionary collecting various properties of the grids.""" num_cells = self.num_cells total_cells = int(np.prod(num_cells)) @@ -567,7 +564,7 @@ def _yee_h(self, axis: Axis): return Coords(**yee_coords) - def discretize_inds(self, box: Box, extend: bool = False) -> List[Tuple[int, int]]: + def discretize_inds(self, box: Box, extend: bool = False) -> list[tuple[int, int]]: """Start and stopping indexes for the cells that intersect with a :class:`Box`. Parameters @@ -581,7 +578,7 @@ def discretize_inds(self, box: Box, extend: bool = False) -> List[Tuple[int, int Returns ------- - List[Tuple[int, int]] + list[tuple[int, int]] The (start, stop) indexes of the cells that intersect with ``box`` in each of the three dimensions. """ diff --git a/tidy3d/components/grid/grid_spec.py b/tidy3d/components/grid/grid_spec.py index 17f41e088c..fe3333121d 100644 --- a/tidy3d/components/grid/grid_spec.py +++ b/tidy3d/components/grid/grid_spec.py @@ -3,15 +3,23 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, Literal, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) from ...constants import C_0, MICROMETER, dp_eps, inf from ...exceptions import SetupError from ...log import log -from ..base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from ..base import Tidy3dBaseModel, cached_property from ..geometry.base import Box, ClipOperation from ..lumped_element import LumpedElementType from ..source.utils import SourceType @@ -23,7 +31,7 @@ Coordinate, CoordinateOptional, Symmetry, - annotate_type, + discriminated_union, ) from 
.corner_finder import CornerFinderSpec from .grid import Coords, Coords1D, Grid @@ -46,12 +54,12 @@ class GridSpec1d(Tidy3dBaseModel, ABC): def make_coords( self, axis: Axis, - structures: List[StructureType], - symmetry: Tuple[Symmetry, Symmetry, Symmetry], + structures: list[StructureType], + symmetry: tuple[Symmetry, Symmetry, Symmetry], periodic: bool, - wavelength: pd.PositiveFloat, - num_pml_layers: Tuple[pd.NonNegativeInt, pd.NonNegativeInt], - snapping_points: Tuple[CoordinateOptional, ...], + wavelength: PositiveFloat, + num_pml_layers: tuple[NonNegativeInt, NonNegativeInt], + snapping_points: tuple[CoordinateOptional, ...], ) -> Coords1D: """Generate 1D coords to be used as grid boundaries, based on simulation parameters. Symmetry, and PML layers will be treated here. @@ -60,9 +68,9 @@ def make_coords( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. periodic : bool @@ -70,9 +78,9 @@ def make_coords( Only relevant for autogrids. wavelength : float Free-space wavelength. - num_pml_layers : Tuple[int, int] + num_pml_layers : tuple[int, int] number of layers in the absorber + and - direction along one dimension. - snapping_points : Tuple[CoordinateOptional, ...] + snapping_points : tuple[CoordinateOptional, ...] A set of points that enforce grid boundaries to pass through them. Returns @@ -112,7 +120,7 @@ def make_coords( def _make_coords_initial( self, axis: Axis, - structures: List[StructureType], + structures: list[StructureType], **kwargs, ) -> Coords1D: """Generate 1D coords to be used as grid boundaries, based on simulation parameters. 
@@ -122,7 +130,7 @@ def _make_coords_initial( Parameters ---------- - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. **kwargs Other arguments @@ -134,13 +142,13 @@ def _make_coords_initial( """ @staticmethod - def _add_pml_to_bounds(num_layers: Tuple[int, int], bounds: Coords1D) -> Coords1D: + def _add_pml_to_bounds(num_layers: tuple[int, int], bounds: Coords1D) -> Coords1D: """Append absorber layers to the beginning and end of the simulation bounds along one dimension. Parameters ---------- - num_layers : Tuple[int, int] + num_layers : tuple[int, int] number of layers in the absorber + and - direction along one dimension. bound_coords : np.ndarray coordinates specifying boundaries between cells along one dimension. @@ -173,7 +181,7 @@ def _postprocess_unaligned_grid( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. machine_error_relaxation : bool When operations such as translation are applied to the 1d grids, fix the bounds @@ -233,7 +241,7 @@ def _postprocess_unaligned_grid( @abstractmethod def estimated_min_dl( - self, wavelength: float, structure_list: List[Structure], sim_size: Tuple[float, 3] + self, wavelength: float, structure_list: list[Structure], sim_size: tuple[float, 3] ) -> float: """Estimated minimal grid size along the axis. The actual minimal grid size from mesher might be smaller. @@ -242,9 +250,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. 
Returns @@ -275,15 +283,14 @@ class UniformGrid(GridSpec1d): * `Using automatic nonuniform meshing <../../notebooks/AutoGrid.html>`_ """ - dl: pd.PositiveFloat = pd.Field( - ..., + dl: PositiveFloat = Field( title="Grid Size", description="Grid size for uniform grid generation.", units=MICROMETER, ) - @pd.validator("dl", always=True) - def _validate_dl(cls, val): + @field_validator("dl") + def _validate_dl(val): """ Ensure 'dl' is not too small. """ @@ -298,7 +305,7 @@ def _validate_dl(cls, val): def _make_coords_initial( self, axis: Axis, - structures: List[StructureType], + structures: list[StructureType], **kwargs, ) -> Coords1D: """Uniform 1D coords to be used as grid boundaries. @@ -307,7 +314,7 @@ def _make_coords_initial( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. Returns @@ -330,7 +337,7 @@ def _make_coords_initial( return center - size / 2 + np.arange(num_cells + 1) * dl_snapped def estimated_min_dl( - self, wavelength: float, structure_list: List[Structure], sim_size: Tuple[float, 3] + self, wavelength: float, structure_list: list[Structure], sim_size: tuple[float, 3] ) -> float: """Minimal grid size, which equals grid size here. @@ -338,9 +345,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. 
Returns @@ -360,8 +367,7 @@ class CustomGridBoundaries(GridSpec1d): >>> grid_1d = CustomGridBoundaries(coords=[-0.2, 0.0, 0.2, 0.4, 0.5, 0.6, 0.7]) """ - coords: Coords1D = pd.Field( - ..., + coords: Coords1D = Field( title="Grid Boundary Coordinates", description="An array of grid boundary coordinates.", units=MICROMETER, @@ -370,7 +376,7 @@ class CustomGridBoundaries(GridSpec1d): def _make_coords_initial( self, axis: Axis, - structures: List[StructureType], + structures: list[StructureType], **kwargs, ) -> Coords1D: """Customized 1D coords to be used as grid boundaries. @@ -379,7 +385,7 @@ def _make_coords_initial( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. Returns @@ -396,7 +402,7 @@ def _make_coords_initial( ) def estimated_min_dl( - self, wavelength: float, structure_list: List[Structure], sim_size: Tuple[float, 3] + self, wavelength: float, structure_list: list[Structure], sim_size: tuple[float, 3] ) -> float: """Minimal grid size from grid specification. @@ -404,9 +410,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. Returns @@ -426,8 +432,7 @@ class CustomGrid(GridSpec1d): >>> grid_1d = CustomGrid(dl=[0.2, 0.2, 0.1, 0.1, 0.1, 0.2, 0.2]) """ - dl: Tuple[pd.PositiveFloat, ...] = pd.Field( - ..., + dl: tuple[PositiveFloat, ...] = Field( title="Customized grid sizes.", description="An array of custom nonuniform grid sizes. 
The resulting grid is centered on " "the simulation center such that it spans the region " @@ -437,7 +442,7 @@ class CustomGrid(GridSpec1d): units=MICROMETER, ) - custom_offset: float = pd.Field( + custom_offset: Optional[float] = Field( None, title="Customized grid offset.", description="The starting coordinate of the grid which defines the simulation center. " @@ -449,7 +454,7 @@ class CustomGrid(GridSpec1d): def _make_coords_initial( self, axis: Axis, - structures: List[StructureType], + structures: list[StructureType], **kwargs, ) -> Coords1D: """Customized 1D coords to be used as grid boundaries. @@ -458,7 +463,7 @@ def _make_coords_initial( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. Returns @@ -488,7 +493,7 @@ def _make_coords_initial( ) def estimated_min_dl( - self, wavelength: float, structure_list: List[Structure], sim_size: Tuple[float, 3] + self, wavelength: float, structure_list: list[Structure], sim_size: tuple[float, 3] ) -> float: """Minimal grid size from grid specification. @@ -496,9 +501,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. 
Returns @@ -512,7 +517,7 @@ def estimated_min_dl( class AbstractAutoGrid(GridSpec1d): """Specification for non-uniform or quasi-uniform grid along a given dimension.""" - max_scale: float = pd.Field( + max_scale: float = Field( 1.4, title="Maximum Grid Size Scaling", description="Sets the maximum ratio between any two consecutive grid steps.", @@ -520,13 +525,13 @@ class AbstractAutoGrid(GridSpec1d): lt=2.0, ) - mesher: MesherType = pd.Field( - GradedMesher(), + mesher: MesherType = Field( + default_factory=GradedMesher, title="Grid Construction Tool", description="The type of mesher to use to generate the grid automatically.", ) - dl_min: pd.NonNegativeFloat = pd.Field( + dl_min: Optional[NonNegativeFloat] = Field( None, title="Lower Bound of Grid Size", description="Lower bound of the grid size along this dimension regardless of " @@ -538,11 +543,11 @@ class AbstractAutoGrid(GridSpec1d): ) @abstractmethod - def _preprocessed_structures(self, structures: List[StructureType]) -> List[StructureType]: + def _preprocessed_structures(self, structures: list[StructureType]) -> list[StructureType]: """Preprocess structure list before passing to ``mesher``.""" @abstractmethod - def _dl_collapsed_axis(self, wavelength: float, sim_size: Tuple[float, 3]) -> float: + def _dl_collapsed_axis(self, wavelength: float, sim_size: tuple[float, 3]) -> float: """The grid step size if just a single grid along an axis in the simulation domain.""" @property @@ -556,7 +561,7 @@ def _min_steps_per_wvl(self) -> float: """Minimal steps per wavelength applied internally.""" @abstractmethod - def _dl_max(self, sim_size: Tuple[float, 3]) -> float: + def _dl_max(self, sim_size: tuple[float, 3]) -> float: """Upper bound of grid size applied internally.""" @property @@ -564,18 +569,18 @@ def _undefined_dl_min(self) -> bool: """Whether `dl_min` has been specified or not.""" return self.dl_min is None or self.dl_min == 0 - def _filtered_dl(self, dl: float, sim_size: Tuple[float, 3]) -> float: + def 
_filtered_dl(self, dl: float, sim_size: tuple[float, 3]) -> float: """Grid step size after applying minimal and maximal filtering.""" return max(min(dl, self._dl_max(sim_size)), self._dl_min) def _make_coords_initial( self, axis: Axis, - structures: List[StructureType], + structures: list[StructureType], wavelength: float, symmetry: Symmetry, is_periodic: bool, - snapping_points: Tuple[CoordinateOptional, ...], + snapping_points: tuple[CoordinateOptional, ...], ) -> Coords1D: """Customized 1D coords to be used as grid boundaries. @@ -583,16 +588,16 @@ def _make_coords_initial( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation. wavelength : float Free-space wavelength. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. is_periodic : bool Apply periodic boundary condition or not. - snapping_points : Tuple[CoordinateOptional, ...] + snapping_points : tuple[CoordinateOptional, ...] A set of points that enforce grid boundaries to pass through them. Returns @@ -695,15 +700,14 @@ class QuasiUniformGrid(AbstractAutoGrid): * `Using automatic nonuniform meshing <../../notebooks/AutoGrid.html>`_ """ - dl: pd.PositiveFloat = pd.Field( - ..., + dl: PositiveFloat = Field( title="Grid Size", description="Grid size for quasi-uniform grid generation. Grid size at some locations can be " "slightly smaller.", units=MICROMETER, ) - def _preprocessed_structures(self, structures: List[StructureType]) -> List[StructureType]: + def _preprocessed_structures(self, structures: list[StructureType]) -> list[StructureType]: """Processing structure list before passing to ``mesher``. Adjust all structures to drop their material properties so that they all have step size ``dl``. 
""" @@ -731,16 +735,16 @@ def _min_steps_per_wvl(self) -> float: # irrelevant in this class, just supply an arbitrary number return 1 - def _dl_max(self, sim_size: Tuple[float, 3]) -> float: + def _dl_max(self, sim_size: tuple[float, 3]) -> float: """Upper bound of grid size.""" return self.dl - def _dl_collapsed_axis(self, wavelength: float, sim_size: Tuple[float, 3]) -> float: + def _dl_collapsed_axis(self, wavelength: float, sim_size: tuple[float, 3]) -> float: """The grid step size if just a single grid along an axis.""" return self._filtered_dl(self.dl, sim_size) def estimated_min_dl( - self, wavelength: float, structure_list: List[Structure], sim_size: Tuple[float, 3] + self, wavelength: float, structure_list: list[Structure], sim_size: tuple[float, 3] ) -> float: """Estimated minimal grid size, which equals grid size here. @@ -748,9 +752,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. 
Returns @@ -786,14 +790,14 @@ class AutoGrid(AbstractAutoGrid): * `Numerical dispersion in FDTD `_ """ - min_steps_per_wvl: float = pd.Field( + min_steps_per_wvl: float = Field( 10.0, title="Minimal Number of Steps Per Wavelength", description="Minimal number of steps per wavelength in each medium.", ge=6.0, ) - min_steps_per_sim_size: float = pd.Field( + min_steps_per_sim_size: float = Field( 10.0, title="Minimal Number of Steps Per Simulation Domain Size", description="Minimal number of steps per longest edge length of simulation domain " @@ -801,7 +805,7 @@ class AutoGrid(AbstractAutoGrid): ge=1.0, ) - def _dl_max(self, sim_size: Tuple[float, 3]) -> float: + def _dl_max(self, sim_size: tuple[float, 3]) -> float: """Upper bound of grid size, constrained by `min_steps_per_sim_size`.""" return max(sim_size) / self.min_steps_per_sim_size @@ -818,20 +822,20 @@ def _min_steps_per_wvl(self) -> float: """Minimal steps per wavelength.""" return self.min_steps_per_wvl - def _preprocessed_structures(self, structures: List[StructureType]) -> List[StructureType]: + def _preprocessed_structures(self, structures: list[StructureType]) -> list[StructureType]: """Processing structure list before passing to ``mesher``.""" return structures - def _dl_collapsed_axis(self, wavelength: float, sim_size: Tuple[float, 3]) -> float: + def _dl_collapsed_axis(self, wavelength: float, sim_size: tuple[float, 3]) -> float: """The grid step size if just a single grid along an axis.""" return self._vacuum_dl(wavelength, sim_size) - def _vacuum_dl(self, wavelength: float, sim_size: Tuple[float, 3]) -> float: + def _vacuum_dl(self, wavelength: float, sim_size: tuple[float, 3]) -> float: """Grid step size when computed in vacuum region.""" return self._filtered_dl(wavelength / self.min_steps_per_wvl, sim_size) def estimated_min_dl( - self, wavelength: float, structure_list: List[Structure], sim_size: Tuple[float, 3] + self, wavelength: float, structure_list: list[Structure], sim_size: tuple[float, 
3] ) -> float: """Estimated minimal grid size along the axis. The actual minimal grid size from mesher might be smaller. @@ -840,9 +844,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. Returns @@ -879,27 +883,27 @@ class GridRefinement(Tidy3dBaseModel): """ - refinement_factor: Optional[pd.PositiveFloat] = pd.Field( + refinement_factor: Optional[PositiveFloat] = Field( None, title="Mesh Refinement Factor", description="Refine grid step size in vacuum by this factor.", ) - dl: Optional[pd.PositiveFloat] = pd.Field( + dl: Optional[PositiveFloat] = Field( None, title="Grid Size", description="Grid step size in the refined region.", units=MICROMETER, ) - num_cells: pd.PositiveInt = pd.Field( + num_cells: PositiveInt = Field( 3, title="Number of Refined Grid Cells", description="Number of grid cells in the refinement region.", ) @property - def _refinement_factor(self) -> pd.PositiveFloat: + def _refinement_factor(self) -> PositiveFloat: """Refinement factor applied internally.""" if self.refinement_factor is None and self.dl is None: return DEFAULT_REFINEMENT_FACTOR @@ -985,20 +989,19 @@ class LayerRefinementSpec(Box): """ - axis: Axis = pd.Field( - ..., + axis: Axis = Field( title="Axis", description="Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z).", ) - min_steps_along_axis: Optional[pd.PositiveFloat] = pd.Field( + min_steps_along_axis: Optional[PositiveFloat] = Field( None, title="Minimal Number Of Steps Along Axis", description="If not ``None`` and the thickness of the layer is nonzero, set minimal " "number of steps discretizing the layer thickness.", ) - bounds_refinement: Optional[GridRefinement] = pd.Field( + bounds_refinement: Optional[GridRefinement] = Field( None, 
title="Mesh Refinement Factor Around Layer Bounds", description="If not ``None``, refine mesh around minimum and maximum positions " @@ -1006,35 +1009,35 @@ class LayerRefinementSpec(Box): "refinement here is only applied if it sets a smaller grid size.", ) - bounds_snapping: Optional[Literal["bounds", "lower", "upper", "center"]] = pd.Field( + bounds_snapping: Optional[Literal["bounds", "lower", "upper", "center"]] = Field( "lower", title="Placing Grid Snapping Point Along Axis", description="If not ``None``, enforcing grid boundaries to pass through ``lower``, " "``center``, or ``upper`` position of the layer; or both ``lower`` and ``upper`` with ``bounds``.", ) - corner_finder: Optional[CornerFinderSpec] = pd.Field( - CornerFinderSpec(), + corner_finder: Optional[CornerFinderSpec] = Field( + default_factory=CornerFinderSpec, title="Inplane Corner Detection Specification", description="Specification for inplane corner detection. Inplane mesh refinement " "is based on the coordinates of those corners.", ) - corner_snapping: bool = pd.Field( + corner_snapping: bool = Field( True, title="Placing Grid Snapping Point At Corners", description="If ``True`` and ``corner_finder`` is not ``None``, enforcing inplane " "grid boundaries to pass through corners of geometries specified by ``corner_finder``.", ) - corner_refinement: Optional[GridRefinement] = pd.Field( - GridRefinement(), + corner_refinement: Optional[GridRefinement] = Field( + default_factory=GridRefinement, title="Inplane Mesh Refinement Factor Around Corners", description="If not ``None`` and ``corner_finder`` is not ``None``, refine mesh around " "corners of geometries specified by ``corner_finder``. 
", ) - refinement_inside_sim_only: bool = pd.Field( + refinement_inside_sim_only: bool = Field( True, title="Apply Refinement Only To Features Inside Simulation Domain", description="If ``True``, only apply mesh refinement to features such as corners inside " @@ -1043,33 +1046,32 @@ class LayerRefinementSpec(Box): "and the projection of the simulation domain overlaps.", ) - gap_meshing_iters: pd.NonNegativeInt = pd.Field( + gap_meshing_iters: NonNegativeInt = Field( 1, title="Gap Meshing Iterations", description="Number of recursive iterations for resolving thin gaps. " "The underlying algorithm detects gaps contained in a single cell and places a snapping plane at the gaps's centers.", ) - dl_min_from_gap_width: bool = pd.Field( + dl_min_from_gap_width: bool = Field( True, title="Set ``dl_min`` from Estimated Gap Width", description="Take into account autodetected minimal PEC gap width when determining ``dl_min``. " "This only applies if ``dl_min`` in ``AutoGrid`` specification is not set.", ) - @pd.validator("axis", always=True) - @skip_if_fields_missing(["size"]) - def _finite_size_along_axis(cls, val, values): + @model_validator(mode="after") + def _finite_size_along_axis(self): """size must be finite along axis.""" - if np.isinf(values["size"][val]): + if np.isinf(self.size[self.axis]): raise SetupError("'size' must take finite values along 'axis' dimension.") - return val + return self @classmethod def from_layer_bounds( cls, axis: Axis, - bounds: Tuple[float, float], + bounds: tuple[float, float], min_steps_along_axis: np.PositiveFloat = None, bounds_refinement: GridRefinement = None, bounds_snapping: Literal["bounds", "lower", "upper", "center"] = "lower", @@ -1077,7 +1079,7 @@ def from_layer_bounds( corner_snapping: bool = True, corner_refinement: GridRefinement = GridRefinement(), refinement_inside_sim_only: bool = True, - gap_meshing_iters: pd.NonNegativeInt = 1, + gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, ): 
"""Constructs a :class:`LayerRefiementSpec` that is unbounded in inplane dimensions from bounds along @@ -1087,7 +1089,7 @@ def from_layer_bounds( ---------- axis : Axis Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z). - bounds : Tuple[float, float] + bounds : tuple[float, float] Minimum and maximum positions of the layer along axis dimension. min_steps_along_axis : np.PositiveFloat = None Minimal number of steps along axis. @@ -1146,16 +1148,16 @@ def from_bounds( corner_snapping: bool = True, corner_refinement: GridRefinement = GridRefinement(), refinement_inside_sim_only: bool = True, - gap_meshing_iters: pd.NonNegativeInt = 1, + gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, ): """Constructs a :class:`LayerRefiementSpec` from minimum and maximum coordinate bounds. Parameters ---------- - rmin : Tuple[float, float, float] + rmin : tuple[float, float, float] (x, y, z) coordinate of the minimum values. - rmax : Tuple[float, float, float] + rmax : tuple[float, float, float] (x, y, z) coordinate of the maximum values. axis : Axis Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z). If ``None``, apply the dimension @@ -1207,7 +1209,7 @@ def from_bounds( @classmethod def from_structures( cls, - structures: List[Structure], + structures: list[Structure], axis: Axis = None, min_steps_along_axis: np.PositiveFloat = None, bounds_refinement: GridRefinement = None, @@ -1216,14 +1218,14 @@ def from_structures( corner_snapping: bool = True, corner_refinement: GridRefinement = GridRefinement(), refinement_inside_sim_only: bool = True, - gap_meshing_iters: pd.NonNegativeInt = 1, + gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, ): """Constructs a :class:`LayerRefiementSpec` from the bounding box of a list of structures. 
Parameters ---------- - structures : List[Structure] + structures : list[Structure] A list of structures whose overall bounding box is used to define mesh refinement axis : Axis Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z). If ``None``, apply the dimension @@ -1306,7 +1308,7 @@ def _unpop_axis(self, ax_coord: float, plane_coord: Any) -> CoordinateOptional: """ return self.unpop_axis(ax_coord, [plane_coord, plane_coord], self.axis) - def suggested_dl_min(self, grid_size_in_vacuum: float, structures: List[Structure]) -> float: + def suggested_dl_min(self, grid_size_in_vacuum: float, structures: list[Structure]) -> float: """Suggested lower bound of grid step size for this layer. Parameters @@ -1345,7 +1347,7 @@ def suggested_dl_min(self, grid_size_in_vacuum: float, structures: List[Structur return dl_min - def generate_snapping_points(self, structure_list: List[Structure]) -> List[CoordinateOptional]: + def generate_snapping_points(self, structure_list: list[Structure]) -> list[CoordinateOptional]: """generate snapping points for mesh refinement.""" snapping_points = self._snapping_points_along_axis if self.corner_snapping: @@ -1353,8 +1355,8 @@ def generate_snapping_points(self, structure_list: List[Structure]) -> List[Coor return snapping_points def generate_override_structures( - self, grid_size_in_vacuum: float, structure_list: List[Structure] - ) -> List[MeshOverrideStructure]: + self, grid_size_in_vacuum: float, structure_list: list[Structure] + ) -> list[MeshOverrideStructure]: """Generate mesh override structures for mesh refinement.""" return self._override_structures_along_axis( grid_size_in_vacuum @@ -1380,8 +1382,8 @@ def _inplane_inside(self, point: ArrayFloat2D) -> bool: return self.inside(point_3d[0], point_3d[1], point_3d[2]) def _corners_and_convexity_2d( - self, structure_list: List[Structure], ravel: bool - ) -> List[CoordinateOptional]: + self, structure_list: list[Structure], ravel: bool + ) -> list[CoordinateOptional]: """Raw 
inplane corners and their convexity.""" if self.corner_finder is None: return [], [] @@ -1410,7 +1412,7 @@ def _corners_and_convexity_2d( return inplane_points, convexity - def _dl_min_from_smallest_feature(self, structure_list: List[Structure]): + def _dl_min_from_smallest_feature(self, structure_list: list[Structure]): """Calculate `dl_min` suggestion based on smallest feature size.""" inplane_points, convexity = self._corners_and_convexity_2d( @@ -1448,7 +1450,7 @@ def _dl_min_from_smallest_feature(self, structure_list: List[Structure]): return dl_min - def _corners(self, structure_list: List[Structure]) -> List[CoordinateOptional]: + def _corners(self, structure_list: list[Structure]) -> list[CoordinateOptional]: """Inplane corners in 3D coordinate.""" inplane_points, _ = self._corners_and_convexity_2d( structure_list=structure_list, ravel=True @@ -1461,7 +1463,7 @@ def _corners(self, structure_list: List[Structure]) -> List[CoordinateOptional]: ] @property - def _snapping_points_along_axis(self) -> List[CoordinateOptional]: + def _snapping_points_along_axis(self) -> list[CoordinateOptional]: """Snapping points for layer bounds.""" if self.bounds_snapping is None: @@ -1486,8 +1488,8 @@ def _snapping_points_along_axis(self) -> List[CoordinateOptional]: ] def _override_structures_inplane( - self, structure_list: List[Structure], grid_size_in_vacuum: float - ) -> List[MeshOverrideStructure]: + self, structure_list: list[Structure], grid_size_in_vacuum: float + ) -> list[MeshOverrideStructure]: """Inplane mesh override structures for refining mesh around corners.""" if self.corner_refinement is None: return [] @@ -1501,7 +1503,7 @@ def _override_structures_inplane( def _override_structures_along_axis( self, grid_size_in_vacuum: float - ) -> List[MeshOverrideStructure]: + ) -> list[MeshOverrideStructure]: """Mesh override structures for refining mesh along layer axis dimension.""" override_structures = [] @@ -1556,7 +1558,7 @@ def _override_structures_along_axis( 
def _find_vertical_intersections( self, grid_x_coords, grid_y_coords, poly_vertices, boundary - ) -> Tuple[List[Tuple[int, int]], List[float]]: + ) -> tuple[list[tuple[int, int]], list[float]]: """Detect intersection points of single polygon and vertical grid lines.""" # indices of cells that contain intersection with grid lines (left edge of a cell) @@ -1709,7 +1711,7 @@ def _find_vertical_intersections( def _process_poly( self, grid_x_coords, grid_y_coords, poly_vertices, boundaries - ) -> Tuple[List[Tuple[int, int]], List[float], List[Tuple[int, int]], List[float]]: + ) -> tuple[list[tuple[int, int]], list[float], list[tuple[int, int]], list[float]]: """Detect intersection points of single polygon and grid lines.""" # find cells that contain intersections of vertical grid lines @@ -1732,7 +1734,7 @@ def _process_poly( def _process_slice( self, x, y, merged_geos, boundaries - ) -> Tuple[List[Tuple[int, int]], List[float], List[Tuple[int, int]], List[float]]: + ) -> tuple[list[tuple[int, int]], list[float], list[tuple[int, int]], list[float]]: """Detect intersection points of geometries boundaries and grid lines.""" # cells that contain intersections of vertical grid lines @@ -1815,7 +1817,7 @@ def _process_slice( def _generate_horizontal_snapping_lines( self, grid_y_coords, intersected_cells_ij, relative_vert_disp - ) -> Tuple[List[CoordinateOptional], float]: + ) -> tuple[list[CoordinateOptional], float]: """Convert a list of intersections of vertical grid lines, given as coordinates of cells and relative vertical displacement inside each cell, into locations of snapping lines that resolve thin gaps and strips. 
@@ -1892,8 +1894,8 @@ def _generate_horizontal_snapping_lines( return snapping_lines_y, min_gap_width def _resolve_gaps( - self, structures: List[Structure], grid: Grid, boundaries: Tuple, center, size - ) -> Tuple[List[CoordinateOptional], float]: + self, structures: list[Structure], grid: Grid, boundaries: tuple, center, size + ) -> tuple[list[CoordinateOptional], float]: """Detect underresolved gaps and place snapping lines in them. Also return the detected minimal gap width.""" # get x and y coordinates of grid lines @@ -2025,28 +2027,28 @@ class GridSpec(Tidy3dBaseModel): * `Numerical dispersion in FDTD `_ """ - grid_x: GridType = pd.Field( - AutoGrid(), + grid_x: GridType = Field( + default_factory=AutoGrid, title="Grid specification along x-axis", description="Grid specification along x-axis", discriminator=TYPE_TAG_STR, ) - grid_y: GridType = pd.Field( - AutoGrid(), + grid_y: GridType = Field( + default_factory=AutoGrid, title="Grid specification along y-axis", description="Grid specification along y-axis", discriminator=TYPE_TAG_STR, ) - grid_z: GridType = pd.Field( - AutoGrid(), + grid_z: GridType = Field( + default_factory=AutoGrid, title="Grid specification along z-axis", description="Grid specification along z-axis", discriminator=TYPE_TAG_STR, ) - wavelength: float = pd.Field( + wavelength: Optional[PositiveFloat] = Field( None, title="Free-space wavelength", description="Free-space wavelength for automatic nonuniform grid. It can be 'None' " @@ -2057,7 +2059,7 @@ class GridSpec(Tidy3dBaseModel): units=MICROMETER, ) - override_structures: Tuple[annotate_type(StructureType), ...] = pd.Field( + override_structures: tuple[discriminated_union(StructureType), ...] 
= Field( (), title="Grid specification override structures", description="A set of structures that is added on top of the simulation structures in " @@ -2067,7 +2069,7 @@ class GridSpec(Tidy3dBaseModel): "uses :class:`.AutoGrid` or :class:`.QuasiUniformGrid`.", ) - snapping_points: Tuple[CoordinateOptional, ...] = pd.Field( + snapping_points: tuple[CoordinateOptional, ...] = Field( (), title="Grid specification snapping_points", description="A set of points that enforce grid boundaries to pass through them. " @@ -2078,7 +2080,7 @@ class GridSpec(Tidy3dBaseModel): "uses :class:`.AutoGrid` or :class:`.QuasiUniformGrid`.", ) - layer_refinement_specs: Tuple[LayerRefinementSpec, ...] = pd.Field( + layer_refinement_specs: tuple[LayerRefinementSpec, ...] = Field( (), title="Mesh Refinement In Layered Structures", description="Automatic mesh refinement according to layer specifications. The material " @@ -2108,7 +2110,7 @@ def custom_grid_used(self) -> bool: return np.any([isinstance(mesh, (CustomGrid, CustomGridBoundaries)) for mesh in grid_list]) @staticmethod - def wavelength_from_sources(sources: List[SourceType]) -> pd.PositiveFloat: + def wavelength_from_sources(sources: list[SourceType]) -> PositiveFloat: """Define a wavelength based on supplied sources. Called if auto mesh is used and ``self.wavelength is None``.""" @@ -2136,7 +2138,7 @@ def layer_refinement_used(self) -> bool: return len(self.layer_refinement_specs) > 0 @property - def snapping_points_used(self) -> List[bool, bool, bool]: + def snapping_points_used(self) -> list[bool, bool, bool]: """Along each axis, ``True`` if any snapping point is used. However, it is still ``False`` if all snapping points take value ``None`` along the axis. 
""" @@ -2155,7 +2157,7 @@ def snapping_points_used(self) -> List[bool, bool, bool]: return snapping_used @property - def override_structures_used(self) -> List[bool, bool, bool]: + def override_structures_used(self) -> list[bool, bool, bool]: """Along each axis, ``True`` if any override structure is used. However, it is still ``False`` if only :class:`.MeshOverrideStructure` is supplied, and their ``dl[axis]`` all take the ``None`` value. @@ -2177,21 +2179,21 @@ def override_structures_used(self) -> List[bool, bool, bool]: return override_used def internal_snapping_points( - self, structures: List[Structure], lumped_elements: List[LumpedElementType] - ) -> List[CoordinateOptional]: + self, structures: list[Structure], lumped_elements: list[LumpedElementType] + ) -> list[CoordinateOptional]: """Internal snapping points. So far, internal snapping points are generated by `layer_refinement_specs` and lumped element. Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of physical structures. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. Returns ------- - List[CoordinateOptional] + list[CoordinateOptional] List of snapping points coordinates. """ @@ -2211,25 +2213,25 @@ def internal_snapping_points( def all_snapping_points( self, - structures: List[Structure], - lumped_elements: List[LumpedElementType], - internal_snapping_points: List[CoordinateOptional] = None, - ) -> List[CoordinateOptional]: + structures: list[Structure], + lumped_elements: list[LumpedElementType], + internal_snapping_points: list[CoordinateOptional] = None, + ) -> list[CoordinateOptional]: """Internal and external snapping points. External snapping points take higher priority. So far, internal snapping points are generated by `layer_refinement_specs`. Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of physical structures. 
- lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. - internal_snapping_points : List[CoordinateOptional] + internal_snapping_points : list[CoordinateOptional] If `None`, recomputes internal snapping points. Returns ------- - List[CoordinateOptional] + list[CoordinateOptional] List of snapping points coordinates. """ @@ -2240,34 +2242,34 @@ def all_snapping_points( return internal_snapping_points + list(self.snapping_points) @property - def external_override_structures(self) -> List[StructureType]: + def external_override_structures(self) -> list[StructureType]: """External supplied override structure list.""" return [s.to_static() for s in self.override_structures] def internal_override_structures( self, - structures: List[Structure], - wavelength: pd.PositiveFloat, - sim_size: Tuple[float, 3], - lumped_elements: List[LumpedElementType], - ) -> List[StructureType]: + structures: list[Structure], + wavelength: PositiveFloat, + sim_size: tuple[float, 3], + lumped_elements: list[LumpedElementType], + ) -> list[StructureType]: """Internal mesh override structures. So far, internal override structures are generated by `layer_refinement_specs` and lumped element. Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures, with the simulation structure being the first item. - wavelength : pd.PositiveFloat + wavelength : PositiveFloat Wavelength to use for minimal step size in vaccum. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. Returns ------- - List[StructureType] + list[StructureType] List of override structures. 
""" @@ -2289,31 +2291,31 @@ def internal_override_structures( def all_override_structures( self, - structures: List[Structure], - wavelength: pd.PositiveFloat, - sim_size: Tuple[float, 3], - lumped_elements: List[LumpedElementType], - internal_override_structures: List[MeshOverrideStructure] = None, - ) -> List[StructureType]: + structures: list[Structure], + wavelength: PositiveFloat, + sim_size: tuple[float, 3], + lumped_elements: list[LumpedElementType], + internal_override_structures: list[MeshOverrideStructure] = None, + ) -> list[StructureType]: """Internal and external mesh override structures. External override structures take higher priority. So far, internal override structures all come from `layer_refinement_specs`. Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures, with the simulation structure being the first item. - wavelength : pd.PositiveFloat + wavelength : PositiveFloat Wavelength to use for minimal step size in vaccum. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. - internal_override_structures : List[MeshOverrideStructure] + internal_override_structures : list[MeshOverrideStructure] If `None`, recomputes internal override structures. Returns ------- - List[StructureType] + list[StructureType] List of override structures. """ @@ -2325,7 +2327,7 @@ def all_override_structures( return internal_override_structures + self.external_override_structures - def _min_vacuum_dl_in_autogrid(self, wavelength: float, sim_size: Tuple[float, 3]) -> float: + def _min_vacuum_dl_in_autogrid(self, wavelength: float, sim_size: tuple[float, 3]) -> float: """Compute grid step size in vacuum for Autogrd. If AutoGrid is applied along more than 1 dimension, return the minimal. 
""" @@ -2338,9 +2340,9 @@ def _min_vacuum_dl_in_autogrid(self, wavelength: float, sim_size: Tuple[float, 3 def _dl_min( self, wavelength: float, - structure_list: List[StructureType], - sim_size: Tuple[float, 3], - lumped_elements: List[LumpedElementType], + structure_list: list[StructureType], + sim_size: tuple[float, 3], + lumped_elements: list[LumpedElementType], ) -> float: """Lower bound of grid size to be applied to dimensions where AutoGrid with unset `dl_min` (0 or None) is applied. @@ -2375,7 +2377,7 @@ def _dl_min( min_dl = min(min_dl, min(override_structure.dl)) return min_dl * MIN_STEP_BOUND_SCALE - def get_wavelength(self, sources: List[SourceType]) -> float: + def get_wavelength(self, sources: list[SourceType]) -> float: """Get wavelength for automatic mesh generation if needed.""" wavelength = self.wavelength if wavelength is None and self.auto_grid_used: @@ -2385,15 +2387,15 @@ def get_wavelength(self, sources: List[SourceType]) -> float: def make_grid( self, - structures: List[Structure], - symmetry: Tuple[Symmetry, Symmetry, Symmetry], - periodic: Tuple[bool, bool, bool], - sources: List[SourceType], - num_pml_layers: List[Tuple[pd.NonNegativeInt, pd.NonNegativeInt]], - lumped_elements: List[LumpedElementType] = (), - internal_override_structures: List[MeshOverrideStructure] = None, - internal_snapping_points: List[CoordinateOptional] = None, - boundary_types: Tuple[Tuple[str, str], Tuple[str, str], Tuple[str, str]] = [ + structures: list[Structure], + symmetry: tuple[Symmetry, Symmetry, Symmetry], + periodic: tuple[bool, bool, bool], + sources: list[SourceType], + num_pml_layers: list[tuple[NonNegativeInt, NonNegativeInt]], + lumped_elements: list[LumpedElementType] = (), + internal_override_structures: list[MeshOverrideStructure] = None, + internal_snapping_points: list[CoordinateOptional] = None, + boundary_types: tuple[tuple[str, str], tuple[str, str], tuple[str, str]] = [ [None, None], [None, None], [None, None], @@ -2403,26 +2405,26 @@ def 
make_grid( Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures present in the simulation. The first structure must be the simulation geometry with the simulation background medium. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. - periodic: Tuple[bool, bool, bool] + periodic: tuple[bool, bool, bool] Apply periodic boundary condition or not along each of the dimensions. Only relevant for autogrids. - sources : List[SourceType] + sources : list[SourceType] List of sources. - num_pml_layers : List[Tuple[float, float]] + num_pml_layers : list[tuple[float, float]] List containing the number of absorber layers in - and + boundaries. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. - internal_override_structures : List[MeshOverrideStructure] - If `None`, recomputes internal override structures. - internal_snapping_points : List[CoordinateOptional] - If `None`, recomputes internal snapping points. - boundary_types : Tuple[Tuple[str, str], Tuple[str, str], Tuple[str, str]] = [[None, None], [None, None], [None, None]] + internal_override_structures : list[MeshOverrideStructure] + If ``None``, recomputes internal override structures. + internal_snapping_points : list[CoordinateOptional] + If ``None``, recomputes internal snapping points. + boundary_types : tuple[tuple[str, str], tuple[str, str], tuple[str, str]] = [[None, None], [None, None], [None, None]] Type of boundary conditions along each dimension: "pec/pmc", "periodic", or None for any other. This is relevant only for gap meshing. 
@@ -2447,51 +2449,51 @@ def make_grid( def _make_grid_and_snapping_lines( self, - structures: List[Structure], - symmetry: Tuple[Symmetry, Symmetry, Symmetry], - periodic: Tuple[bool, bool, bool], - sources: List[SourceType], - num_pml_layers: List[Tuple[pd.NonNegativeInt, pd.NonNegativeInt]], - lumped_elements: List[LumpedElementType] = (), - internal_override_structures: List[MeshOverrideStructure] = None, - internal_snapping_points: List[CoordinateOptional] = None, - boundary_types: Tuple[Tuple[str, str], Tuple[str, str], Tuple[str, str]] = [ + structures: list[Structure], + symmetry: tuple[Symmetry, Symmetry, Symmetry], + periodic: tuple[bool, bool, bool], + sources: list[SourceType], + num_pml_layers: list[tuple[NonNegativeInt, NonNegativeInt]], + lumped_elements: list[LumpedElementType] = (), + internal_override_structures: list[MeshOverrideStructure] = None, + internal_snapping_points: list[CoordinateOptional] = None, + boundary_types: tuple[tuple[str, str], tuple[str, str], tuple[str, str]] = [ [None, None], [None, None], [None, None], ], - ) -> Tuple[Grid, List[CoordinateOptional]]: + ) -> tuple[Grid, list[CoordinateOptional]]: """Make the entire simulation grid based on some simulation parameters. Also return snappiung point resulted from iterative gap meshing. Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures present in the simulation. The first structure must be the simulation geometry with the simulation background medium. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. - periodic: Tuple[bool, bool, bool] + periodic: tuple[bool, bool, bool] Apply periodic boundary condition or not along each of the dimensions. Only relevant for autogrids. - sources : List[SourceType] + sources : list[SourceType] List of sources. 
- num_pml_layers : List[Tuple[float, float]] + num_pml_layers : list[tuple[float, float]] List containing the number of absorber layers in - and + boundaries. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. - internal_override_structures : List[MeshOverrideStructure] + internal_override_structures : list[MeshOverrideStructure] If `None`, recomputes internal override structures. - internal_snapping_points : List[CoordinateOptional] + internal_snapping_points : list[CoordinateOptional] If `None`, recomputes internal snapping points. - boundary_types : Tuple[Tuple[str, str], Tuple[str, str], Tuple[str, str]] = [[None, None], [None, None], [None, None]] + boundary_types : tuple[tuple[str, str], tuple[str, str], tuple[str, str]] = [[None, None], [None, None], [None, None]] Type of boundary conditions along each dimension: "pec/pmc", "periodic", or None for any other. This is relevant only for gap meshing. Returns ------- - Tuple[Grid, List[CoordinateOptional]]: + tuple[Grid, list[CoordinateOptional]]: Entire simulation grid and snapping points generated during iterative gap meshing. 
""" @@ -2566,40 +2568,40 @@ def _make_grid_and_snapping_lines( def _make_grid_one_iteration( self, - structures: List[Structure], - symmetry: Tuple[Symmetry, Symmetry, Symmetry], - periodic: Tuple[bool, bool, bool], - sources: List[SourceType], - num_pml_layers: List[Tuple[pd.NonNegativeInt, pd.NonNegativeInt]], - lumped_elements: List[LumpedElementType] = (), - internal_override_structures: List[MeshOverrideStructure] = None, - internal_snapping_points: List[CoordinateOptional] = None, - dl_min_from_gaps: pd.PositiveFloat = inf, + structures: list[Structure], + symmetry: tuple[Symmetry, Symmetry, Symmetry], + periodic: tuple[bool, bool, bool], + sources: list[SourceType], + num_pml_layers: list[tuple[NonNegativeInt, NonNegativeInt]], + lumped_elements: list[LumpedElementType] = (), + internal_override_structures: list[MeshOverrideStructure] = None, + internal_snapping_points: list[CoordinateOptional] = None, + dl_min_from_gaps: PositiveFloat = inf, ) -> Grid: """Make the entire simulation grid based on some simulation parameters. Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures present in the simulation. The first structure must be the simulation geometry with the simulation background medium. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. - periodic: Tuple[bool, bool, bool] + periodic: tuple[bool, bool, bool] Apply periodic boundary condition or not along each of the dimensions. Only relevant for autogrids. - sources : List[SourceType] + sources : list[SourceType] List of sources. - num_pml_layers : List[Tuple[float, float]] + num_pml_layers : list[tuple[float, float]] List containing the number of absorber layers in - and + boundaries. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. 
- internal_override_structures : List[MeshOverrideStructure] + internal_override_structures : list[MeshOverrideStructure] If `None`, recomputes internal override structures. - internal_snapping_points : List[CoordinateOptional] + internal_snapping_points : list[CoordinateOptional] If `None`, recomputes internal snapping points. - dl_min_from_gaps : pd.PositiveFloat + dl_min_from_gaps : PositiveFloat Minimal grid size computed based on autodetected gaps. @@ -2713,39 +2715,39 @@ def from_grid(cls, grid: Grid) -> GridSpec: @classmethod def auto( cls, - wavelength: pd.PositiveFloat = None, - min_steps_per_wvl: pd.PositiveFloat = 10.0, - max_scale: pd.PositiveFloat = 1.4, - override_structures: List[StructureType] = (), - snapping_points: Tuple[CoordinateOptional, ...] = (), - layer_refinement_specs: List[LayerRefinementSpec] = (), - dl_min: pd.NonNegativeFloat = 0.0, - min_steps_per_sim_size: pd.PositiveFloat = 10.0, + wavelength: PositiveFloat = None, + min_steps_per_wvl: PositiveFloat = 10.0, + max_scale: PositiveFloat = 1.4, + override_structures: list[StructureType] = (), + snapping_points: tuple[CoordinateOptional, ...] = (), + layer_refinement_specs: list[LayerRefinementSpec] = (), + dl_min: NonNegativeFloat = 0.0, + min_steps_per_sim_size: PositiveFloat = 10.0, mesher: MesherType = GradedMesher(), ) -> GridSpec: """Use the same :class:`AutoGrid` along each of the three directions. Parameters ---------- - wavelength : pd.PositiveFloat, optional + wavelength : PositiveFloat, optional Free-space wavelength for automatic nonuniform grid. It can be 'None' if there is at least one source in the simulation, in which case it is defined by the source central frequency. - min_steps_per_wvl : pd.PositiveFloat, optional + min_steps_per_wvl : PositiveFloat, optional Minimal number of steps per wavelength in each medium. - max_scale : pd.PositiveFloat, optional + max_scale : PositiveFloat, optional Sets the maximum ratio between any two consecutive grid steps. 
- override_structures : List[StructureType] + override_structures : list[StructureType] A list of structures that is added on top of the simulation structures in the process of generating the grid. This can be used to refine the grid or make it coarser depending than the expected need for higher/lower resolution regions. - snapping_points : Tuple[CoordinateOptional, ...] + snapping_points : tuple[CoordinateOptional, ...] A set of points that enforce grid boundaries to pass through them. - layer_refinement_specs: List[LayerRefinementSpec] + layer_refinement_specs: list[LayerRefinementSpec] Mesh refinement according to layer specifications. - dl_min: pd.NonNegativeFloat + dl_min: NonNegativeFloat Lower bound of grid size. - min_steps_per_sim_size : pd.PositiveFloat, optional + min_steps_per_sim_size : PositiveFloat, optional Minimal number of steps per longest edge length of simulation domain. mesher : MesherType = GradedMesher() The type of mesher to use to generate the grid automatically. @@ -2795,9 +2797,9 @@ def uniform(cls, dl: float) -> GridSpec: def quasiuniform( cls, dl: float, - max_scale: pd.PositiveFloat = 1.4, - override_structures: List[StructureType] = (), - snapping_points: Tuple[CoordinateOptional, ...] = (), + max_scale: PositiveFloat = 1.4, + override_structures: list[StructureType] = (), + snapping_points: tuple[CoordinateOptional, ...] = (), mesher: MesherType = GradedMesher(), ) -> GridSpec: """Use the same :class:`QuasiUniformGrid` along each of the three directions. @@ -2806,13 +2808,13 @@ def quasiuniform( ---------- dl : float Grid size for quasi-uniform grid generation. - max_scale : pd.PositiveFloat, optional + max_scale : PositiveFloat, optional Sets the maximum ratio between any two consecutive grid steps. - override_structures : List[StructureType] + override_structures : list[StructureType] A list of structures that is added on top of the simulation structures in the process of generating the grid. 
This can be used to snap grid points to the bounding box boundary. - snapping_points : Tuple[CoordinateOptional, ...] + snapping_points : tuple[CoordinateOptional, ...] A set of points that enforce grid boundaries to pass through them. mesher : MesherType = GradedMesher() The type of mesher to use to generate the grid automatically. diff --git a/tidy3d/components/grid/mesher.py b/tidy3d/components/grid/mesher.py index 039e975d0b..fe58b36563 100644 --- a/tidy3d/components/grid/mesher.py +++ b/tidy3d/components/grid/mesher.py @@ -4,10 +4,10 @@ from abc import ABC, abstractmethod from itertools import compress from math import isclose -from typing import Dict, List, Tuple, Union +from typing import Union import numpy as np -import pydantic.v1 as pd +from pydantic import NonNegativeFloat, NonNegativeInt, PositiveFloat from pyroots import Brentq from shapely.errors import ShapelyDeprecationWarning from shapely.geometry import box as shapely_box @@ -35,23 +35,23 @@ class Mesher(Tidy3dBaseModel, ABC): def parse_structures( self, axis: Axis, - structures: List[StructureType], - wavelength: pd.PositiveFloat, - min_steps_per_wvl: pd.NonNegativeInt, - dl_min: pd.NonNegativeFloat, - dl_max: pd.NonNegativeFloat, - ) -> Tuple[ArrayFloat1D, ArrayFloat1D]: + structures: list[StructureType], + wavelength: PositiveFloat, + min_steps_per_wvl: NonNegativeInt, + dl_min: NonNegativeFloat, + dl_max: NonNegativeFloat, + ) -> tuple[ArrayFloat1D, ArrayFloat1D]: """Calculate the positions of all bounding box interfaces along a given axis.""" @abstractmethod def insert_snapping_points( self, - dl_min: pd.NonNegativeFloat, + dl_min: NonNegativeFloat, axis: Axis, interval_coords: ArrayFloat1D, max_dl_list: ArrayFloat1D, - snapping_points: List[CoordinateOptional], - ) -> Tuple[ArrayFloat1D, ArrayFloat1D]: + snapping_points: list[CoordinateOptional], + ) -> tuple[ArrayFloat1D, ArrayFloat1D]: """Insert snapping_points to the intervals.""" @abstractmethod @@ -61,7 +61,7 @@ def 
make_grid_multiple_intervals( len_interval_list: ArrayFloat1D, max_scale: float, is_periodic: bool, - ) -> List[ArrayFloat1D]: + ) -> list[ArrayFloat1D]: """Create grid steps in multiple connecting intervals.""" @staticmethod @@ -96,17 +96,17 @@ class GradedMesher(Mesher): def insert_snapping_points( self, - dl_min: pd.NonNegativeFloat, + dl_min: NonNegativeFloat, axis: Axis, interval_coords: ArrayFloat1D, max_dl_list: ArrayFloat1D, - snapping_points: List[CoordinateOptional], - ) -> Tuple[ArrayFloat1D, ArrayFloat1D]: + snapping_points: list[CoordinateOptional], + ) -> tuple[ArrayFloat1D, ArrayFloat1D]: """Insert snapping_points to the intervals. Parameters ---------- - dl_min: pd.NonNegativeFloat + dl_min: NonNegativeFloat Lower bound of grid size. axis : Axis Axis index along which to operate. @@ -114,7 +114,7 @@ def insert_snapping_points( Coordinate of interval boundaries. max_dl_list : ArrayFloat1D Maximal allowed step size of each interval generated from `parse_structures`. - snapping_points : List[CoordinateOptional] + snapping_points : list[CoordinateOptional] A set of points that enforce grid boundaries to pass through them. Returns @@ -180,12 +180,12 @@ def insert_snapping_points( def parse_structures( self, axis: Axis, - structures: List[StructureType], - wavelength: pd.PositiveFloat, - min_steps_per_wvl: pd.NonNegativeInt, - dl_min: pd.NonNegativeFloat, - dl_max: pd.NonNegativeFloat, - ) -> Tuple[ArrayFloat1D, ArrayFloat1D]: + structures: list[StructureType], + wavelength: PositiveFloat, + min_steps_per_wvl: NonNegativeInt, + dl_min: NonNegativeFloat, + dl_max: NonNegativeFloat, + ) -> tuple[ArrayFloat1D, ArrayFloat1D]: """Calculate the positions of all bounding box interfaces along a given axis. In this implementation, in most cases the complexity should be O(len(structures)**2), although the worst-case complexity may approach O(len(structures)**3). @@ -195,15 +195,15 @@ def parse_structures( ---------- axis : Axis Axis index along which to operate. 
- structures : List[StructureType] + structures : list[StructureType] List of structures, with the simulation structure being the first item. - wavelength : pd.PositiveFloat + wavelength : PositiveFloat Wavelength to use for the step size and for dispersive media epsilon. - min_steps_per_wvl : pd.NonNegativeInt + min_steps_per_wvl : NonNegativeInt Minimum requested steps per wavelength. - dl_min: pd.NonNegativeFloat + dl_min: NonNegativeFloat Lower bound of grid size. - dl_max: pd.NonNegativeFloat + dl_max: NonNegativeFloat Upper bound of grid size. Returns @@ -382,14 +382,14 @@ def parse_structures( def insert_bbox( self, - intervals: Dict[str, List], + intervals: dict[str, list], str_ind: int, str_bbox: ArrayFloat1D, - bbox_contained_2d: List[ArrayFloat1D], + bbox_contained_2d: list[ArrayFloat1D], min_step: float, structure_steps: ArrayFloat1D, unshadowed: bool, - ) -> Dict[str, List]: + ) -> dict[str, list]: """Figure out where to place the bounding box coordinates of current structure. For both the left and the right bounds of the structure along the meshing direction, we check if they are not too close to an already existing coordinate, if the @@ -403,14 +403,14 @@ def insert_bbox( Parameters ---------- - intervals : Dict[str, List] + intervals : dict[str, List] Dictionary containing the coordinates of the interval boundaries, and a list of lists of structures contained in each interval. str_ind : int Index of the current structure. str_bbox : ArrayFloat1D Bounding box of the current structure. - bbox_contained_2d : List[ArrayFloat1D] + bbox_contained_2d : list[ArrayFloat1D] List of 3D bounding boxes that contain the current structure in 2D. min_step : float Absolute minimum interval size to impose. 
@@ -508,8 +508,8 @@ def insert_bbox( @staticmethod def reorder_structures( - structures: List[StructureType], - ) -> Tuple[int, List[StructureType]]: + structures: list[StructureType], + ) -> tuple[int, list[StructureType]]: """Reorder structure list to order as follows: 1). simulation structure `str[0]` remains as the first structure; 2). MeshOverrideStructures with ``shadow=False``; @@ -518,12 +518,12 @@ def reorder_structures( Parameters ---------- - structures : List[StructureType] + structures : list[StructureType] List of structures, with the simulation structure being the first item. Returns ------- - Tuple[int, List[StructureType]] + tuple[int, list[StructureType]] The number of unenforced structures, reordered structure list """ @@ -565,22 +565,22 @@ def reorder_structures( @staticmethod def filter_structures_effective_dl( - structures: List[StructureType], axis: Axis - ) -> List[StructureType]: + structures: list[StructureType], axis: Axis + ) -> list[StructureType]: """For :class:`.MeshOverrideStructure`, we allow ``dl`` along some axis to be ``None`` so that no override occurs along this axis.Here those structures with ``dl[axis]=None`` is filtered. Parameters ---------- - structures : List[StructureType] + structures : list[StructureType] List of structures, with the simulation structure being the first item. axis : Axis Axis index to place last. Returns ------- - List[StructureType] + list[StructureType] A list of filtered structures whose ``dl`` along this axis is not ``None``. """ @@ -623,11 +623,11 @@ def structure_step( @staticmethod def structure_steps( - structures: List[StructureType], + structures: list[StructureType], wavelength: float, min_steps_per_wvl: float, - dl_min: pd.NonNegativeFloat, - dl_max: pd.NonNegativeFloat, + dl_min: NonNegativeFloat, + dl_max: NonNegativeFloat, axis: Axis, ) -> ArrayFloat1D: """Get the minimum mesh required in each structure. 
Special media are set to index of 1, @@ -636,15 +636,15 @@ def structure_steps( Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures, with the simulation structure being the first item. wavelength : float Wavelength to use for the step size and for dispersive media epsilon. min_steps_per_wvl : float Minimum requested steps per wavelength. - dl_min: pd.NonNegativeFloat + dl_min: NonNegativeFloat Lower bound of grid size. - dl_max: pd.NonNegativeFloat + dl_max: NonNegativeFloat Upper bound of grid size. axis : Axis Axis index along which to operate. @@ -662,19 +662,19 @@ def structure_steps( return np.where(min_steps < dl_min, dl_min, min_steps) @staticmethod - def rotate_structure_bounds(structures: List[StructureType], axis: Axis) -> List[ArrayFloat1D]: + def rotate_structure_bounds(structures: list[StructureType], axis: Axis) -> list[ArrayFloat1D]: """Get structure bounding boxes with a given ``axis`` rotated to z. Parameters ---------- - structures : List[StructureType] + structures : list[StructureType] List of structures, with the simulation structure being the first item. axis : Axis Axis index to place last. Returns ------- - List[ArrayFloat1D] + list[ArrayFloat1D] A list of the bounding boxes of shape ``(2, 3)`` for each structure, with the bounds along ``axis`` being ``(:, 2)``. 
""" @@ -689,7 +689,7 @@ def rotate_structure_bounds(structures: List[StructureType], axis: Axis) -> List return struct_bbox @staticmethod - def bounds_2d_tree(struct_bbox: List[ArrayFloat1D]): + def bounds_2d_tree(struct_bbox: list[ArrayFloat1D]): """Make a shapely Rtree for the 2D bounding boxes of all structures in the plane perpendicular to the meshing axis.""" @@ -704,7 +704,7 @@ def bounds_2d_tree(struct_bbox: List[ArrayFloat1D]): return stree @staticmethod - def contained_2d(bbox0: ArrayFloat1D, query_bbox: List[ArrayFloat1D]) -> List[ArrayFloat1D]: + def contained_2d(bbox0: ArrayFloat1D, query_bbox: list[ArrayFloat1D]) -> list[ArrayFloat1D]: """Return a list of all bounding boxes among ``query_bbox`` that contain ``bbox0`` in 2D.""" return [ bbox @@ -720,7 +720,7 @@ def contained_2d(bbox0: ArrayFloat1D, query_bbox: List[ArrayFloat1D]) -> List[Ar ] @staticmethod - def contains_3d(bbox0: ArrayFloat1D, query_bbox: List[ArrayFloat1D]) -> List[int]: + def contains_3d(bbox0: ArrayFloat1D, query_bbox: list[ArrayFloat1D]) -> list[int]: """Return a list of all indexes of bounding boxes in the ``query_bbox`` list that ``bbox0`` fully contains.""" return [ @@ -739,7 +739,7 @@ def contains_3d(bbox0: ArrayFloat1D, query_bbox: List[ArrayFloat1D]) -> List[int ] @staticmethod - def is_close(coord: float, interval_coords: List[float], coord_ind: int, atol: float) -> bool: + def is_close(coord: float, interval_coords: list[float], coord_ind: int, atol: float) -> bool: """Check if a given ``coord`` is within ``atol`` of an interval coordinate at a given interval index. 
If the index is out of bounds, return ``False``.""" return ( @@ -749,7 +749,7 @@ def is_close(coord: float, interval_coords: List[float], coord_ind: int, atol: f ) @staticmethod - def is_contained(normal_pos: float, contained_2d: List[ArrayFloat1D]) -> bool: + def is_contained(normal_pos: float, contained_2d: list[ArrayFloat1D]) -> bool: """Check if a given ``normal_pos`` along the meshing direction is contained inside any of the bounding boxes that are in the ``contained_2d`` list. """ @@ -759,8 +759,8 @@ def is_contained(normal_pos: float, contained_2d: List[ArrayFloat1D]) -> bool: @staticmethod def filter_min_step( - interval_coords: List[float], max_steps: List[float] - ) -> Tuple[List[float], List[float]]: + interval_coords: list[float], max_steps: list[float] + ) -> tuple[list[float], list[float]]: """Filter intervals that are smaller than the absolute smallest of the ``max_steps``.""" # Re-compute minimum step in case some high-index structures were completely covered @@ -784,7 +784,7 @@ def make_grid_multiple_intervals( len_interval_list: ArrayFloat1D, max_scale: float, is_periodic: bool, - ) -> List[ArrayFloat1D]: + ) -> list[ArrayFloat1D]: """Create grid steps in multiple connecting intervals of length specified by ``len_interval_list``. The maximal allowed step size in each interval is given by ``max_dl_list``. The maximum ratio between neighboring steps is bounded by ``max_scale``. @@ -802,7 +802,7 @@ def make_grid_multiple_intervals( Returns ------- - List[ArrayFloat1D] + list[ArrayFloat1D] A list of of step sizes in each interval. """ @@ -880,7 +880,7 @@ def grid_multiple_interval_analy_refinement( len_interval_list: ArrayFloat1D, max_scale: float, is_periodic: bool, - ) -> Tuple[ArrayFloat1D, ArrayFloat1D]: + ) -> tuple[ArrayFloat1D, ArrayFloat1D]: """Analytical refinement for multiple intervals. "analytical" meaning we allow non-integar step sizes, so that we don't consider snapping here. 
@@ -897,7 +897,7 @@ def grid_multiple_interval_analy_refinement( Returns ------- - Tuple[ArrayFloat1D, ArrayFloat1D] + tuple[ArrayFloat1D, ArrayFloat1D] left and right step sizes of each interval. """ diff --git a/tidy3d/components/lumped_element.py b/tidy3d/components/lumped_element.py index 8d565f6ebe..d3a3bc8f48 100644 --- a/tidy3d/components/lumped_element.py +++ b/tidy3d/components/lumped_element.py @@ -1,32 +1,25 @@ """Defines lumped elements that should be included in the simulation.""" -from __future__ import annotations - from abc import ABC, abstractmethod from math import isclose -from typing import Annotated, Literal, Optional, Union +from typing import Literal, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import ( + Field, + NonNegativeFloat, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) from tidy3d.log import log -from ..components.grid.grid import Grid -from ..components.medium import ( - PEC2D, - Debye, - Drude, - Lorentz, - Medium, - Medium2D, - PoleResidue, -) -from ..components.monitor import FieldMonitor -from ..components.structure import MeshOverrideStructure, Structure -from ..components.validators import assert_line_or_plane, assert_plane, validate_name_str +from ..compat import Self from ..constants import EPSILON_0, FARAD, HENRY, MICROMETER, OHM, fp_eps from ..exceptions import ValidationError -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .geometry.base import Box, ClipOperation, Geometry, GeometryGroup from .geometry.primitives import Cylinder from .geometry.utils import ( @@ -37,21 +30,34 @@ snap_point_to_grid, ) from .geometry.utils_2d import increment_float +from .grid.grid import Grid +from .medium import ( + PEC2D, + Debye, + Drude, + Lorentz, + Medium, + Medium2D, + PoleResidue, +) from .microwave.formulas.circuit_parameters import ( capacitance_colinear_cylindrical_wire_segments, 
capacitance_rectangular_sheets, inductance_straight_rectangular_wire, total_inductance_colinear_rectangular_wire_segments, ) +from .monitor import FieldMonitor +from .structure import MeshOverrideStructure, Structure from .types import ( - TYPE_TAG_STR, Axis, Axis2D, Coordinate, CoordinateOptional, FreqArray, LumpDistType, + discriminated_union, ) +from .validators import assert_line_or_plane, assert_plane, validate_name_str from .viz import PlotParams, plot_params_lumped_element DEFAULT_LUMPED_ELEMENT_NUM_CELLS = 1 @@ -61,14 +67,13 @@ class LumpedElement(Tidy3dBaseModel, ABC): """Base class describing the interface all lumped elements obey.""" - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for the lumped element.", min_length=1, ) - num_grid_cells: Optional[pd.PositiveInt] = pd.Field( + num_grid_cells: Optional[PositiveInt] = Field( DEFAULT_LUMPED_ELEMENT_NUM_CELLS, title="Lumped element grid cells", description="Number of mesh grid cells associated with the lumped element along each direction. " @@ -76,7 +81,7 @@ class LumpedElement(Tidy3dBaseModel, ABC): "A value of ``None`` will turn off mesh refinement suggestions.", ) - enable_snapping_points: bool = pd.Field( + enable_snapping_points: bool = Field( True, title="Snap Grid To Lumped Element", description="When enabled, snapping points are automatically generated to snap grids to key " @@ -112,13 +117,13 @@ def to_structures(self, grid: Grid = None) -> list[Structure]: which are ready to be added to the :class:`.Simulation`""" return [self.to_structure(grid)] - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. 
You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data class RectangularLumpedElement(LumpedElement, Box): @@ -126,14 +131,13 @@ class RectangularLumpedElement(LumpedElement, Box): is appended to the list of structures in the simulation as a :class:`.Medium2D` with the appropriate material properties given their size, voltage axis, and the network they represent.""" - voltage_axis: Axis = pd.Field( - ..., + voltage_axis: Axis = Field( title="Voltage Drop Axis", description="Specifies the axis along which the component is oriented and along which the " "associated voltage drop will occur. Must be in the plane of the element.", ) - snap_perimeter_to_grid: bool = pd.Field( + snap_perimeter_to_grid: bool = Field( True, title="Snap Perimeter to Grid", description="When enabled, the perimeter of the lumped element is snapped to the simulation grid, " @@ -184,7 +188,7 @@ def _snapping_spec(self) -> SnappingSpec: snap_behavior = [SnapBehavior.Closest] * 3 snap_location[self.lateral_axis] = SnapLocation.Center snap_behavior[self.lateral_axis] = SnapBehavior.Expand - return SnappingSpec(location=snap_location, behavior=snap_behavior) + return SnappingSpec(location=tuple(snap_location), behavior=tuple(snap_behavior)) def to_mesh_overrides(self) -> list[MeshOverrideStructure]: """Creates a suggested :class:`.MeshOverrideStructure` list for mesh refinement both on the @@ -288,18 +292,18 @@ def to_monitor(self, freqs: FreqArray) -> FieldMonitor: def monitor_name(self): return f"{self.name}_monitor" - @pd.validator("voltage_axis", always=True) - @skip_if_fields_missing(["name", "size"]) - def _voltage_axis_in_plane(cls, val, values): + @model_validator(mode="after") + def _voltage_axis_in_plane(self) -> Self: """Ensure voltage drop axis is in the plane of the lumped element.""" - name = values.get("name") - size = values.get("size") + val = self.voltage_axis + name = self.name + size = self.size if size.count(0.0) == 1 and 
size.index(0.0) == val: # if not planar, then a separate validator should be triggered, not this one raise ValidationError( f"'voltage_axis' must be in the plane of lumped element '{name}'." ) - return val + return self class LumpedResistor(RectangularLumpedElement): @@ -307,8 +311,7 @@ class LumpedResistor(RectangularLumpedElement): of structures in the simulation as :class:`Medium2D` with the appropriate conductivity given their size and voltage axis.""" - resistance: pd.PositiveFloat = pd.Field( - ..., + resistance: PositiveFloat = Field( title="Resistance", description="Resistance value in ohms.", unit=OHM, @@ -343,36 +346,32 @@ class CoaxialLumpedResistor(LumpedElement): structures in the simulation as :class:`Medium2D` with the appropriate conductivity given their size and geometry.""" - resistance: pd.PositiveFloat = pd.Field( - ..., + resistance: PositiveFloat = Field( title="Resistance", description="Resistance value in ohms.", unit=OHM, ) - center: Coordinate = pd.Field( + center: Coordinate = Field( (0.0, 0.0, 0.0), title="Center", description="Center of object in x, y, and z.", units=MICROMETER, ) - outer_diameter: pd.PositiveFloat = pd.Field( - ..., + outer_diameter: PositiveFloat = Field( title="Outer Diameter", description="Diameter of the outer concentric circle.", units=MICROMETER, ) - inner_diameter: pd.PositiveFloat = pd.Field( - ..., + inner_diameter: PositiveFloat = Field( title="Inner Diameter", description="Diameter of the inner concentric circle.", units=MICROMETER, ) - normal_axis: Axis = pd.Field( - ..., + normal_axis: Axis = Field( title="Normal Axis", description="Specifies the normal axis, which defines " "the orientation of the circles making up the coaxial lumped element.", @@ -411,23 +410,23 @@ def to_mesh_overrides(self) -> list[MeshOverrideStructure]: ) ] - @pd.validator("center", always=True) - def _center_not_inf(cls, val): + @field_validator("center") + def _center_not_inf(val): """Make sure center is not infinitiy.""" if 
any(np.isinf(v) for v in val): raise ValidationError("'center' can not contain 'td.inf' terms.") return val - @pd.validator("inner_diameter", always=True) - @skip_if_fields_missing(["outer_diameter"]) - def _ensure_inner_diameter_is_smaller(cls, val, values): + @model_validator(mode="after") + def _ensure_inner_diameter_is_smaller(self) -> Self: """Ensures that the inner diameter is smaller than the outer diameter, so that the final shape is an annulus.""" - outer_diameter = values.get("outer_diameter") + val = self.inner_diameter + outer_diameter = self.outer_diameter if val >= outer_diameter: raise ValidationError( f"The 'inner_diameter' {val} of a coaxial lumped element must be less than its 'outer_diameter' {outer_diameter}." ) - return val + return self @cached_property def _sheet_conductance(self): @@ -553,13 +552,13 @@ def complex_permittivity(a: tuple[float, ...], b: tuple[float, ...], freqs: np.n sigma = NetworkConversions.complex_conductivity(a, b, freqs) return 1j * sigma / (2 * np.pi * freqs * EPSILON_0) - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. 
You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data class RLCNetwork(Tidy3dBaseModel): @@ -586,35 +585,35 @@ class RLCNetwork(Tidy3dBaseModel): """ - resistance: Optional[pd.PositiveFloat] = pd.Field( + resistance: Optional[PositiveFloat] = Field( None, title="Resistance", description="Resistance value in ohms.", unit=OHM, ) - capacitance: Optional[pd.PositiveFloat] = pd.Field( + capacitance: Optional[PositiveFloat] = Field( None, title="Capacitance", description="Capacitance value in farads.", unit=FARAD, ) - inductance: Optional[pd.PositiveFloat] = pd.Field( + inductance: Optional[PositiveFloat] = Field( None, title="Inductance", description="Inductance value in henrys.", unit=HENRY, ) - network_topology: Literal["series", "parallel"] = pd.Field( + network_topology: Literal["series", "parallel"] = Field( "series", title="Network Topology", description="Describes whether network elements are connected in ``series`` or ``parallel``.", ) @cached_property - def _number_network_elements(self) -> pd.PositiveInt: + def _number_network_elements(self) -> PositiveInt: num_elements = 0 if self.resistance: num_elements += 1 @@ -801,24 +800,24 @@ def combine_equivalent_medium_in_parallel(first: PoleResidue, second: PoleResidu result_medium = combine_equivalent_medium_in_parallel(med, result_medium) return result_medium - @pd.validator("inductance", always=True) - @skip_if_fields_missing(["resistance", "capacitance"]) - def _validate_single_element(cls, val, values): + @model_validator(mode="after") + def _validate_single_element(self): """At least one element should be defined.""" - resistance = values.get("resistance") - capacitance = values.get("capacitance") + val = self.inductance + resistance = self.resistance + capacitance = self.capacitance all_items_are_none = all(item is None for item in [resistance, capacitance, val]) if all_items_are_none: raise ValueError("At least one element must be defined in the 
'RLCNetwork'.") - return val + return self - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data class AdmittanceNetwork(Tidy3dBaseModel): @@ -869,15 +868,13 @@ class AdmittanceNetwork(Tidy3dBaseModel): """ - a: tuple[pd.NonNegativeFloat, ...] = pd.Field( - ..., + a: tuple[NonNegativeFloat, ...] = Field( title="Numerator Coefficients", description="A ``tuple`` of floats describing the coefficients of the numerator polynomial. " "The length of the ``tuple`` is equal to the order of the network.", ) - b: tuple[pd.NonNegativeFloat, ...] = pd.Field( - ..., + b: tuple[NonNegativeFloat, ...] = Field( title="Denominator Coefficients", description="A ``tuple`` of floats describing the coefficients of the denomiator polynomial. " "The length of the ``tuple`` is equal to the order of the network.", @@ -897,13 +894,16 @@ def _as_admittance_function(self) -> tuple[tuple[float, ...], tuple[float, ...]] """ return (self.a, self.b) - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. 
You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data + + +NetworkType = discriminated_union(Union[RLCNetwork, AdmittanceNetwork]) class LinearLumpedElement(RectangularLumpedElement): @@ -944,15 +944,13 @@ class LinearLumpedElement(RectangularLumpedElement): * `Using lumped elements in Tidy3D simulations <../../notebooks/LinearLumpedElements.html>`_ """ - network: Union[RLCNetwork, AdmittanceNetwork] = pd.Field( - ..., + network: NetworkType = Field( title="Network", description="The linear element produces an equivalent medium that emulates the " "voltage-current relationship described by the ``network`` field.", - discriminator=TYPE_TAG_STR, ) - dist_type: LumpDistType = pd.Field( + dist_type: LumpDistType = Field( "on", title="Distribute Type", description="Switches between the different methods for distributing the lumped element over " @@ -1005,7 +1003,7 @@ def _create_box_for_network(self, grid: Grid) -> Box: if size[self.voltage_axis] == 0: behavior = list(snap_spec.behavior) behavior[self.voltage_axis] = SnapBehavior.Expand - snap_spec = snap_spec.updated_copy(behavior=behavior) + snap_spec = snap_spec.updated_copy(behavior=tuple(behavior)) return snap_box_to_grid(grid, cell_box, snap_spec=snap_spec) @@ -1198,11 +1196,10 @@ def impedance(self, freqs: np.ndarray) -> np.ndarray: # lumped elements allowed in Simulation.lumped_elements -LumpedElementType = Annotated[ +LumpedElementType = discriminated_union( Union[ LumpedResistor, CoaxialLumpedResistor, LinearLumpedElement, - ], - pd.Field(discriminator=TYPE_TAG_STR), -] + ] +) diff --git a/tidy3d/components/material/multi_physics.py b/tidy3d/components/material/multi_physics.py index ea60652536..723bf39f68 100644 --- a/tidy3d/components/material/multi_physics.py +++ b/tidy3d/components/material/multi_physics.py @@ -1,6 +1,6 @@ from typing import Optional -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel 
from tidy3d.components.material.solver_types import ( @@ -78,24 +78,30 @@ class MultiPhysicsMedium(Tidy3dBaseModel): ... ) """ - name: Optional[str] = pd.Field(None, title="Name", description="Medium name") + name: Optional[str] = Field(None, title="Name", description="Medium name") - optical: Optional[OpticalMediumType] = pd.Field( - None, title="Optical properties", description="Specifies optical properties." + optical: Optional[OpticalMediumType] = Field( + None, + title="Optical properties", + description="Specifies optical properties.", ) - # electrical: Optional[ElectricalMediumType] = pd.Field( + # electrical: Optional[ElectricalMediumType] = Field( # None, # title="Electrical properties", # description="Specifies electrical properties for RF simulations. This is currently not in use.", # ) - heat: Optional[HeatMediumType] = pd.Field( - None, title="Heat properties", description="Specifies properties for Heat simulations." + heat: Optional[HeatMediumType] = Field( + None, + title="Heat properties", + description="Specifies properties for Heat simulations.", ) - charge: Optional[ChargeMediumType] = pd.Field( - None, title="Charge properties", description="Specifies properties for Charge simulations." 
+ charge: Optional[ChargeMediumType] = Field( + None, + title="Charge properties", + description="Specifies properties for Charge simulations.", ) def __getattr__(self, name: str): diff --git a/tidy3d/components/material/tcad/charge.py b/tidy3d/components/material/tcad/charge.py index 406a862d9d..5dc963f457 100644 --- a/tidy3d/components/material/tcad/charge.py +++ b/tidy3d/components/material/tcad/charge.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Tuple +from typing import Optional, Union -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.data.data_array import SpatialDataArray from tidy3d.components.medium import AbstractMedium @@ -14,20 +14,19 @@ MobilityModelType, RecombinationModelType, ) -from tidy3d.components.types import Union -from tidy3d.constants import ( - CONDUCTIVITY, - ELECTRON_VOLT, - PERMITTIVITY, -) +from tidy3d.constants import CONDUCTIVITY, ELECTRON_VOLT, PERMITTIVITY class AbstractChargeMedium(AbstractMedium): """Abstract class for Charge specifications Currently, permittivity is treated as a constant.""" - permittivity: float = pd.Field( - 1.0, ge=1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY + permittivity: float = Field( + 1.0, + ge=1.0, + title="Permittivity", + description="Relative permittivity.", + units=PERMITTIVITY, ) @property @@ -75,8 +74,7 @@ class ChargeConductorMedium(AbstractChargeMedium): A relative permittivity will be assumed 1 if no value is specified. 
""" - conductivity: pd.PositiveFloat = pd.Field( - ..., + conductivity: PositiveFloat = Field( title="Electric conductivity", description=f"Electric conductivity of material in units of {CONDUCTIVITY}.", units=CONDUCTIVITY, @@ -256,59 +254,54 @@ class SemiconductorMedium(AbstractChargeMedium): """ - N_c: pd.PositiveFloat = pd.Field( - ..., + N_c: PositiveFloat = Field( title="Effective density of electron states", description=r"$N_c$ Effective density of states in the conduction band.", units="cm^(-3)", ) - N_v: pd.PositiveFloat = pd.Field( - ..., + N_v: PositiveFloat = Field( title="Effective density of hole states", description=r"$N_v$ Effective density of states in the valence band.", units="cm^(-3)", ) - E_g: pd.PositiveFloat = pd.Field( - ..., + E_g: PositiveFloat = Field( title="Band-gap energy", description="Band-gap energy", units=ELECTRON_VOLT, ) - mobility_n: MobilityModelType = pd.Field( - ..., + mobility_n: MobilityModelType = Field( title="Mobility model for electrons", description="Mobility model for electrons", ) - mobility_p: MobilityModelType = pd.Field( - ..., + mobility_p: MobilityModelType = Field( title="Mobility model for holes", description="Mobility model for holes", ) - R: Tuple[RecombinationModelType, ...] = pd.Field( - [], + R: tuple[RecombinationModelType, ...] 
= Field( + (), title="Generation-Recombination models", description="Array containing the R models to be applied to the material.", ) - delta_E_g: BandGapNarrowingModelType = pd.Field( + delta_E_g: Optional[BandGapNarrowingModelType] = Field( None, title=r"$\Delta E_g$ Bandgap narrowing model.", description="Bandgap narrowing model.", ) - N_a: Union[pd.NonNegativeFloat, SpatialDataArray, Tuple[DopingBoxType, ...]] = pd.Field( + N_a: Union[NonNegativeFloat, SpatialDataArray, tuple[DopingBoxType, ...]] = Field( 0, title="Doping: Acceptor concentration", description="Units of 1/cm^3", units="1/cm^3", ) - N_d: Union[pd.NonNegativeFloat, SpatialDataArray, Tuple[DopingBoxType, ...]] = pd.Field( + N_d: Union[NonNegativeFloat, SpatialDataArray, tuple[DopingBoxType, ...]] = Field( 0, title="Doping: Donor concentration", description="Units of 1/cm^3", diff --git a/tidy3d/components/material/tcad/heat.py b/tidy3d/components/material/tcad/heat.py index fc36ed29c3..6d376d30b3 100644 --- a/tidy3d/components/material/tcad/heat.py +++ b/tidy3d/components/material/tcad/heat.py @@ -3,22 +3,19 @@ from __future__ import annotations from abc import ABC -from typing import Union +from typing import Optional, Union -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.constants import ( - SPECIFIC_HEAT_CAPACITY, - THERMAL_CONDUCTIVITY, -) +from tidy3d.constants import SPECIFIC_HEAT_CAPACITY, THERMAL_CONDUCTIVITY # Liquid class class AbstractHeatMedium(ABC, Tidy3dBaseModel): """Abstract heat material specification.""" - name: str = pd.Field(None, title="Name", description="Optional unique name for medium.") + name: Optional[str] = Field(None, title="Name", description="Optional unique name for medium.") @property def heat(self): @@ -68,13 +65,13 @@ class SolidMedium(AbstractHeatMedium): ... 
) """ - capacity: pd.PositiveFloat = pd.Field( + capacity: PositiveFloat = Field( title="Heat capacity", description=f"Volumetric heat capacity in unit of {SPECIFIC_HEAT_CAPACITY}.", units=SPECIFIC_HEAT_CAPACITY, ) - conductivity: pd.PositiveFloat = pd.Field( + conductivity: PositiveFloat = Field( title="Thermal conductivity", description=f"Thermal conductivity of material in units of {THERMAL_CONDUCTIVITY}.", units=THERMAL_CONDUCTIVITY, diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 1383ce93bf..8abb71b21c 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -6,19 +6,21 @@ import warnings from abc import ABC, abstractmethod from math import isclose -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Union import autograd as ag import autograd.numpy as np - -# TODO: it's hard to figure out which functions need this, for now all get it import numpy as npo -import pydantic.v1 as pd import xarray as xr from numpy.typing import NDArray -from scipy import signal - -from tidy3d.components.material.tcad.heat import ThermalSpecType +from pydantic import ( + Field, + NonNegativeFloat, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) from ..constants import ( C_0, @@ -40,13 +42,15 @@ from ..exceptions import SetupError, ValidationError from ..log import log from .autograd.derivative_utils import DerivativeInfo, integrate_within_bounds -from .autograd.types import AutogradFieldMap, TracedFloat, TracedPoleAndResidue, TracedPositiveFloat -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from .data.data_array import DATA_ARRAY_MAP, ScalarFieldDataArray, SpatialDataArray -from .data.dataset import ( - ElectromagneticFieldDataset, - PermittivityDataset, +from .autograd.types import ( + AutogradFieldMap, + TracedFloat, + TracedPolesAndResidues, + TracedPositiveFloat, ) +from .base import Tidy3dBaseModel, cached_property +from 
.data.data_array import DATA_ARRAY_MAP, ScalarFieldDataArray, SpatialDataArray +from .data.dataset import ElectromagneticFieldDataset, PermittivityDataset from .data.unstructured.base import UnstructuredGridDataset from .data.utils import ( CustomSpatialDataType, @@ -66,6 +70,7 @@ ) from .geometry.base import Geometry from .grid.grid import Coords, Grid +from .material.tcad.heat import ThermalSpecType from .parameter_perturbation import ( IndexPerturbation, ParameterPerturbation, @@ -86,7 +91,7 @@ InterpMethod, Literal, PermittivityComponent, - PoleAndResidue, + PolesAndResidues, TensorReal, ) from .validators import _warn_potential_error, validate_name_str, validate_parameter_perturbation @@ -172,17 +177,17 @@ def _validate_medium(self, medium: AbstractMedium): """Any additional validation that depends on the medium""" pass - def _validate_medium_freqs(self, medium: AbstractMedium, freqs: List[pd.PositiveFloat]) -> None: + def _validate_medium_freqs(self, medium: AbstractMedium, freqs: list[PositiveFloat]) -> None: """Any additional validation that depends on the central frequencies of the sources.""" pass def _hardcode_medium_freqs( - self, medium: AbstractMedium, freqs: List[pd.PositiveFloat] + self, medium: AbstractMedium, freqs: list[PositiveFloat] ) -> NonlinearSpec: """Update the nonlinear model to hardcode information on medium and freqs.""" return self - def _get_freq0(self, freq0, freqs: List[pd.PositiveFloat]) -> float: + def _get_freq0(self, freq0, freqs: list[PositiveFloat]) -> float: """Get a single value for freq0.""" # freq0 is not specified; need to calculate it @@ -219,7 +224,7 @@ def _get_n0( self, n0: complex, medium: AbstractMedium, - freqs: List[pd.PositiveFloat], + freqs: list[PositiveFloat], ) -> complex: """Get a single value for n0.""" if freqs is None: @@ -265,7 +270,7 @@ def complex_fields(self) -> bool: return False @property - def aux_fields(self) -> List[str]: + def aux_fields(self) -> list[str]: """List of available aux fields in 
this model.""" return [] @@ -305,14 +310,14 @@ class NonlinearSusceptibility(NonlinearModel): >>> nonlinear_susceptibility = NonlinearSusceptibility(chi3=1) """ - chi3: float = pd.Field( + chi3: float = Field( 0, title="Chi3", description="Chi3 nonlinear susceptibility.", units=f"{MICROMETER}^2 / {VOLT}^2", ) - numiters: pd.PositiveInt = pd.Field( + numiters: Optional[PositiveInt] = Field( None, title="Number of iterations", description="Deprecated. The old usage 'nonlinear_spec=model' with 'model.numiters' " @@ -321,8 +326,8 @@ class NonlinearSusceptibility(NonlinearModel): "usage, this parameter is ignored, and 'NonlinearSpec.num_iters' is used instead.", ) - @pd.validator("numiters", always=True) - def _validate_numiters(cls, val): + @field_validator("numiters") + def _validate_numiters(val): """Check that numiters is not too large.""" if val is None: return val @@ -385,7 +390,7 @@ class TwoPhotonAbsorption(NonlinearModel): >>> tpa_model = TwoPhotonAbsorption(beta=1) """ - use_complex_fields: bool = pd.Field( + use_complex_fields: bool = Field( False, title="Use complex fields", description="Whether to use the old deprecated complex-fields implementation. " @@ -394,51 +399,51 @@ class TwoPhotonAbsorption(NonlinearModel): "with Tidy3D version < 2.8 and may be removed in a future release.", ) - beta: Union[float, Complex] = pd.Field( + beta: Union[float, Complex] = Field( 0, title="TPA coefficient", description="Coefficient for two-photon absorption (TPA).", units=f"{MICROMETER} / {WATT}", ) - tau: pd.NonNegativeFloat = pd.Field( + tau: NonNegativeFloat = Field( 0, title="Carrier lifetime", description="Lifetime for the free carriers created by two-photon absorption (TPA).", units=f"{SECOND}", ) - sigma: pd.NonNegativeFloat = pd.Field( + sigma: NonNegativeFloat = Field( 0, title="FCA cross section", description="Total cross section for free-carrier absorption (FCA). 
" "Contains contributions from electrons and from holes.", units=f"{MICROMETER}^2", ) - e_e: pd.NonNegativeFloat = pd.Field( + e_e: NonNegativeFloat = Field( 1, title="Electron exponent", description="Exponent for the free electron refractive index shift in the free-carrier plasma dispersion (FCPD).", ) - e_h: pd.NonNegativeFloat = pd.Field( + e_h: NonNegativeFloat = Field( 1, title="Hole exponent", description="Exponent for the free hole refractive index shift in the free-carrier plasma dispersion (FCPD).", ) - c_e: float = pd.Field( + c_e: float = Field( 0, title="Electron coefficient", description="Coefficient for the free electron refractive index shift in the free-carrier plasma dispersion (FCPD).", units=f"{MICROMETER}^(3 e_e)", ) - c_h: float = pd.Field( + c_h: float = Field( 0, title="Hole coefficient", description="Coefficient for the free hole refractive index shift in the free-carrier plasma dispersion (FCPD).", units=f"{MICROMETER}^(3 e_h)", ) - n0: Optional[Complex] = pd.Field( + n0: Optional[Complex] = Field( None, title="Complex linear refractive index", description="Complex linear refractive index of the medium, computed for instance using " @@ -446,7 +451,7 @@ class TwoPhotonAbsorption(NonlinearModel): "frequencies of the simulation sources (as long as these are all equal).", ) - freq0: Optional[pd.PositiveFloat] = pd.Field( + freq0: Optional[PositiveFloat] = Field( None, title="Central frequency", description="Central frequency, used to calculate the energy of the free-carriers " @@ -454,21 +459,22 @@ class TwoPhotonAbsorption(NonlinearModel): "from the simulation sources (as long as these are all equal).", ) - @pd.validator("beta", always=True) - def _validate_beta_real(cls, val, values): + @model_validator(mode="after") + def _validate_beta_real(self): """Check that beta is real and give a useful error if it is not.""" - use_complex_fields = values.get("use_complex_fields") + val = self.beta + use_complex_fields = self.use_complex_fields if 
use_complex_fields: - return val + return self if not np.isreal(val): raise SetupError( "Complex values of 'beta' in 'TwoPhotonAbsorption' are not " "supported; the implementation uses the " "physical real-valued fields." ) - return val + return self - def _validate_medium_freqs(self, medium: AbstractMedium, freqs: List[pd.PositiveFloat]) -> None: + def _validate_medium_freqs(self, medium: AbstractMedium, freqs: list[PositiveFloat]) -> None: """Any validation that depends on knowing the central frequencies of the sources. This includes passivity checking, if necessary.""" n0 = self._get_n0(self.n0, medium, freqs) @@ -487,7 +493,7 @@ def _validate_medium_freqs(self, medium: AbstractMedium, freqs: List[pd.Positive ) def _hardcode_medium_freqs( - self, medium: AbstractMedium, freqs: List[pd.PositiveFloat] + self, medium: AbstractMedium, freqs: list[PositiveFloat] ) -> TwoPhotonAbsorption: """Update the nonlinear model to hardcode information on medium and freqs.""" n0 = self._get_n0(n0=self.n0, medium=medium, freqs=freqs) @@ -506,7 +512,7 @@ def complex_fields(self) -> bool: return self.use_complex_fields @property - def aux_fields(self) -> List[str]: + def aux_fields(self) -> list[str]: """List of available aux fields in this model.""" if self.tau == 0: return [] @@ -560,7 +566,7 @@ class KerrNonlinearity(NonlinearModel): >>> kerr_model = KerrNonlinearity(n2=1) """ - use_complex_fields: bool = pd.Field( + use_complex_fields: bool = Field( False, title="Use complex fields", description="Whether to use the old deprecated complex-fields implementation. 
" @@ -569,14 +575,14 @@ class KerrNonlinearity(NonlinearModel): "with Tidy3D version < 2.8 and may be removed in a future release.", ) - n2: Complex = pd.Field( + n2: Complex = Field( 0, title="Nonlinear refractive index", description="Nonlinear refractive index in the Kerr nonlinearity.", units=f"{MICROMETER}^2 / {WATT}", ) - n0: Optional[Complex] = pd.Field( + n0: Optional[Complex] = Field( None, title="Complex linear refractive index", description="Complex linear refractive index of the medium, computed for instance using " @@ -584,12 +590,13 @@ class KerrNonlinearity(NonlinearModel): "frequencies of the simulation sources (as long as these are all equal).", ) - @pd.validator("n2", always=True) - def _validate_n2_real(cls, val, values): + @model_validator(mode="after") + def _validate_n2_real(self): """Check that n2 is real and give a useful error if it is not.""" - use_complex_fields = values.get("use_complex_fields") + val = self.n2 + use_complex_fields = self.use_complex_fields if use_complex_fields: - return val + return self if not np.isreal(val): raise SetupError( "Complex values of 'n2' in 'KerrNonlinearity' are not " @@ -600,9 +607,9 @@ def _validate_n2_real(cls, val, values): "more physical dispersive loss of the form " "'chi_{TPA} = i (c_0 n_0 beta / omega) I'." ) - return val + return self - def _validate_medium_freqs(self, medium: AbstractMedium, freqs: List[pd.PositiveFloat]) -> None: + def _validate_medium_freqs(self, medium: AbstractMedium, freqs: list[PositiveFloat]) -> None: """Any validation that depends on knowing the central frequencies of the sources. 
This includes passivity checking, if necessary.""" n0 = self._get_n0(self.n0, medium, freqs) @@ -627,7 +634,7 @@ def _validate_medium(self, medium: AbstractMedium): self._validate_medium_freqs(medium, []) def _hardcode_medium_freqs( - self, medium: AbstractMedium, freqs: List[pd.PositiveFloat] + self, medium: AbstractMedium, freqs: list[PositiveFloat] ) -> KerrNonlinearity: """Update the nonlinear model to hardcode information on medium and freqs.""" n0 = self._get_n0(n0=self.n0, medium=medium, freqs=freqs) @@ -657,7 +664,7 @@ class NonlinearSpec(ABC, Tidy3dBaseModel): >>> medium = Medium(permittivity=2, nonlinear_spec=nonlinear_spec) """ - models: Tuple[NonlinearModelType, ...] = pd.Field( + models: tuple[NonlinearModelType, ...] = Field( (), title="Nonlinear models", description="The nonlinear models present in this nonlinear spec. " @@ -665,14 +672,14 @@ class NonlinearSpec(ABC, Tidy3dBaseModel): "Multiple nonlinear models of the same type are not allowed.", ) - num_iters: pd.PositiveInt = pd.Field( + num_iters: PositiveInt = Field( NONLINEAR_DEFAULT_NUM_ITERS, title="Number of iterations", description="Number of iterations for solving nonlinear constitutive relation.", ) - @pd.validator("models", always=True) - def _no_duplicate_models(cls, val): + @field_validator("models") + def _no_duplicate_models(val): """Ensure each type of model appears at most once.""" if val is None: return val @@ -686,8 +693,8 @@ def _no_duplicate_models(cls, val): ) return val - @pd.validator("models", always=True) - def _consistent_old_complex_fields(cls, val): + @field_validator("models") + def _consistent_old_complex_fields(val): """Ensure that old complex fields implementation is used consistently.""" if val is None: return val @@ -709,8 +716,8 @@ def _consistent_old_complex_fields(cls, val): ) return val - @pd.validator("num_iters", always=True) - def _validate_num_iters(cls, val, values): + @field_validator("num_iters") + def _validate_num_iters(val): """Check that num_iters is 
not too large.""" if val > NONLINEAR_MAX_NUM_ITERS: raise ValidationError( @@ -720,17 +727,17 @@ def _validate_num_iters(cls, val, values): return val def _hardcode_medium_freqs( - self, medium: AbstractMedium, freqs: List[pd.PositiveFloat] + self, medium: AbstractMedium, freqs: list[PositiveFloat] ) -> NonlinearSpec: """Update the nonlinear spec to hardcode information on medium and freqs.""" new_models = [] for model in self.models: new_model = model._hardcode_medium_freqs(medium=medium, freqs=freqs) new_models.append(new_model) - return self.updated_copy(models=new_models) + return self.updated_copy(models=tuple(new_models)) @property - def aux_fields(self) -> List[str]: + def aux_fields(self) -> list[str]: """List of available aux fields in all present models.""" fields = [] for model in self.models: @@ -741,16 +748,16 @@ def aux_fields(self) -> List[str]: class AbstractMedium(ABC, Tidy3dBaseModel): """A medium within which electromagnetic waves propagate.""" - name: str = pd.Field(None, title="Name", description="Optional unique name for medium.") + name: Optional[str] = Field(None, title="Name", description="Optional unique name for medium.") - frequency_range: FreqBound = pd.Field( + frequency_range: Optional[FreqBound] = Field( None, title="Frequency Range", description="Optional range of validity for the medium.", units=(HERTZ, HERTZ), ) - allow_gain: bool = pd.Field( + allow_gain: bool = Field( False, title="Allow gain medium", description="Allow the medium to be active. 
Caution: " @@ -760,51 +767,39 @@ class AbstractMedium(ABC, Tidy3dBaseModel): "useful in some cases.", ) - nonlinear_spec: Union[NonlinearSpec, NonlinearSusceptibility] = pd.Field( + nonlinear_spec: Optional[Union[NonlinearSpec, NonlinearSusceptibility]] = Field( None, title="Nonlinear Spec", description="Nonlinear spec applied on top of the base medium properties.", ) - modulation_spec: ModulationSpec = pd.Field( + modulation_spec: Optional[ModulationSpec] = Field( None, title="Modulation Spec", description="Modulation spec applied on top of the base medium properties.", ) - viz_spec: Optional[VisualizationSpec] = pd.Field( + viz_spec: Optional[VisualizationSpec] = Field( None, title="Visualization Specification", description="Plotting specification for visualizing medium.", ) - @cached_property - def _nonlinear_models(self) -> List: - """The nonlinear models in the nonlinear_spec.""" - if self.nonlinear_spec is None: - return [] - if isinstance(self.nonlinear_spec, NonlinearModel): - return [self.nonlinear_spec] - if self.nonlinear_spec.models is None: - return [] - return list(self.nonlinear_spec.models) - - @cached_property - def _nonlinear_num_iters(self) -> pd.PositiveInt: - """The num_iters of the nonlinear_spec.""" - if self.nonlinear_spec is None: - return 0 - if isinstance(self.nonlinear_spec, NonlinearModel): - if self.nonlinear_spec.numiters is None: - return 1 # old default value for backwards compatibility - return self.nonlinear_spec.numiters - return self.nonlinear_spec.num_iters - - def _post_init_validators(self) -> None: - """Call validators taking ``self`` that get run after init.""" - self._validate_nonlinear_spec() - self._validate_modulation_spec_post_init() + heat_spec: Optional[ThermalSpecType] = Field( + None, + title="Heat Specification", + description="DEPRECATED: Use `td.MultiPhysicsMedium`. Specification of the medium heat properties. They are " + "used for solving the heat equation via the ``HeatSimulation`` interface. 
Such simulations can be" + "used for investigating the influence of heat propagation on the properties of optical systems. " + "Once the temperature distribution in the system is found using ``HeatSimulation`` object, " + "``Simulation.perturbed_mediums_copy()`` can be used to convert mediums with perturbation " + "models defined into spatially dependent custom mediums. " + "Otherwise, the ``heat_spec`` does not directly affect the running of an optical " + "``Simulation``.", + discriminator=TYPE_TAG_STR, + ) + @model_validator(mode="after") def _validate_nonlinear_spec(self): """Check compatibility with nonlinear_spec.""" if self.__class__.__name__ == "AnisotropicMedium" and any( @@ -822,7 +817,7 @@ def _validate_nonlinear_spec(self): ) if self.nonlinear_spec is None: - return + return self if isinstance(self.nonlinear_spec, NonlinearModel): log.warning( "The API for 'nonlinear_spec' has changed. " @@ -842,6 +837,22 @@ def _validate_nonlinear_spec(self): "'NonlinearSusceptibility.numiters' is deprecated. " "Please use 'NonlinearSpec.num_iters' instead." ) + return self + + @model_validator(mode="after") + def _check_either_modulation_or_nonlinear_spec(self): + """Check compatibility with modulation_spec.""" + val = self.modulation_spec + nonlinear_spec = self.nonlinear_spec + if val is not None and nonlinear_spec is not None: + raise ValidationError( + f"For medium class {self.type}, 'modulation_spec' of class {type(val)} and " + f"'nonlinear_spec' of class {type(nonlinear_spec)} are " + "not simultaneously supported." + ) + return self + + _name_validator = validate_name_str() def _validate_modulation_spec_post_init(self): """Check compatibility with nonlinear_spec.""" @@ -852,19 +863,9 @@ def _validate_modulation_spec_post_init(self): "Time modulation is not currently supported for the components " "of a 2D medium." ) - heat_spec: Optional[ThermalSpecType] = pd.Field( - None, - title="Heat Specification", - description="DEPRECATED: Use `td.MultiPhysicsMedium`. 
Specification of the medium heat properties. They are " - "used for solving the heat equation via the ``HeatSimulation`` interface. Such simulations can be" - "used for investigating the influence of heat propagation on the properties of optical systems. " - "Once the temperature distribution in the system is found using ``HeatSimulation`` object, " - "``Simulation.perturbed_mediums_copy()`` can be used to convert mediums with perturbation " - "models defined into spatially dependent custom mediums. " - "Otherwise, the ``heat_spec`` does not directly affect the running of an optical " - "``Simulation``.", - discriminator=TYPE_TAG_STR, - ) + @property + def _post_init_validators(self): + return (lambda: self._validate_modulation_spec_post_init(),) @property def charge(self): @@ -882,20 +883,27 @@ def heat(self): def optical(self): return None - @pd.validator("modulation_spec", always=True) - @skip_if_fields_missing(["nonlinear_spec"]) - def _validate_modulation_spec(cls, val, values): - """Check compatibility with modulation_spec.""" - nonlinear_spec = values.get("nonlinear_spec") - if val is not None and nonlinear_spec is not None: - raise ValidationError( - f"For medium class {cls}, 'modulation_spec' of class {type(val)} and " - f"'nonlinear_spec' of class {type(nonlinear_spec)} are " - "not simultaneously supported." 
- ) - return val + @cached_property + def _nonlinear_models(self) -> list: + """The nonlinear models in the nonlinear_spec.""" + if self.nonlinear_spec is None: + return [] + if isinstance(self.nonlinear_spec, NonlinearModel): + return [self.nonlinear_spec] + if self.nonlinear_spec.models is None: + return [] + return list(self.nonlinear_spec.models) - _name_validator = validate_name_str() + @cached_property + def _nonlinear_num_iters(self) -> PositiveInt: + """The num_iters of the nonlinear_spec.""" + if self.nonlinear_spec is None: + return 0 + if isinstance(self.nonlinear_spec, NonlinearModel): + if self.nonlinear_spec.numiters is None: + return 1 # old default value for backwards compatibility + return self.nonlinear_spec.numiters + return self.nonlinear_spec.num_iters @cached_property def is_spatially_uniform(self) -> bool: @@ -923,7 +931,7 @@ def is_fully_anisotropic(self) -> bool: return isinstance(self, FullyAnisotropicMedium) @cached_property - def _incompatible_material_types(self) -> List[str]: + def _incompatible_material_types(self) -> list[str]: """A list of material properties present which may lead to incompatibilities.""" properties = [ self.is_time_modulated, @@ -974,7 +982,7 @@ def eps_model(self, frequency: float) -> complex: Complex-valued relative permittivity evaluated at ``frequency``. """ - def nk_model(self, frequency: float) -> Tuple[float, float]: + def nk_model(self, frequency: float) -> tuple[float, float]: """Real and imaginary parts of the refactive index as a function of frequency. Parameters @@ -984,13 +992,13 @@ def nk_model(self, frequency: float) -> Tuple[float, float]: Returns ------- - Tuple[float, float] + tuple[float, float] Real part (n) and imaginary part (k) of refractive index of medium. 
""" eps_complex = self.eps_model(frequency=frequency) return self.eps_complex_to_nk(eps_complex) - def loss_tangent_model(self, frequency: float) -> Tuple[float, float]: + def loss_tangent_model(self, frequency: float) -> tuple[float, float]: """Permittivity and loss tangent as a function of frequency. Parameters @@ -1000,14 +1008,14 @@ def loss_tangent_model(self, frequency: float) -> Tuple[float, float]: Returns ------- - Tuple[float, float] + tuple[float, float] Real part of permittivity and loss tangent. """ eps_complex = self.eps_model(frequency=frequency) return self.eps_complex_to_eps_loss_tangent(eps_complex) @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor as a function of frequency. Parameters @@ -1017,7 +1025,7 @@ def eps_diagonal(self, frequency: float) -> Tuple[complex, complex, complex]: Returns ------- - Tuple[complex, complex, complex] + tuple[complex, complex, complex] The diagonal elements of the relative permittivity tensor evaluated at ``frequency``. """ @@ -1025,7 +1033,7 @@ def eps_diagonal(self, frequency: float) -> Tuple[complex, complex, complex]: eps = self.eps_model(frequency) return (eps, eps, eps) - def eps_diagonal_numerical(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor for numerical considerations such as meshing and runtime estimation. @@ -1036,7 +1044,7 @@ def eps_diagonal_numerical(self, frequency: float) -> Tuple[complex, complex, co Returns ------- - Tuple[complex, complex, complex] + tuple[complex, complex, complex] The diagonal elements of relative permittivity tensor relevant for numerical considerations evaluated at ``frequency``. 
""" @@ -1163,7 +1171,7 @@ def nk_to_eps_complex(n: float, k: float = 0.0) -> complex: return eps_real + 1j * eps_imag @staticmethod - def eps_complex_to_nk(eps_c: complex) -> Tuple[float, float]: + def eps_complex_to_nk(eps_c: complex) -> tuple[float, float]: """Convert complex permittivity to n, k values. Parameters @@ -1173,7 +1181,7 @@ def eps_complex_to_nk(eps_c: complex) -> Tuple[float, float]: Returns ------- - Tuple[float, float] + tuple[float, float] Real and imaginary parts of refractive index (n & k). """ eps_c = np.array(eps_c) @@ -1181,7 +1189,7 @@ def eps_complex_to_nk(eps_c: complex) -> Tuple[float, float]: return np.real(ref_index), np.imag(ref_index) @staticmethod - def nk_to_eps_sigma(n: float, k: float, freq: float) -> Tuple[float, float]: + def nk_to_eps_sigma(n: float, k: float, freq: float) -> tuple[float, float]: """Convert ``n``, ``k`` at frequency ``freq`` to permittivity and conductivity values. Parameters @@ -1195,7 +1203,7 @@ def nk_to_eps_sigma(n: float, k: float, freq: float) -> Tuple[float, float]: Returns ------- - Tuple[float, float] + tuple[float, float] Real part of relative permittivity & electric conductivity. """ eps_complex = AbstractMedium.nk_to_eps_complex(n, k) @@ -1230,7 +1238,7 @@ def eps_sigma_to_eps_complex(eps_real: float, sigma: float, freq: float) -> comp return eps_real + 1j * sigma / omega / EPSILON_0 @staticmethod - def eps_complex_to_eps_sigma(eps_complex: complex, freq: float) -> Tuple[float, float]: + def eps_complex_to_eps_sigma(eps_complex: complex, freq: float) -> tuple[float, float]: """Convert complex permittivity at frequency ``freq`` to permittivity and conductivity values. @@ -1243,7 +1251,7 @@ def eps_complex_to_eps_sigma(eps_complex: complex, freq: float) -> Tuple[float, Returns ------- - Tuple[float, float] + tuple[float, float] Real part of relative permittivity & electric conductivity. 
""" eps_real, eps_imag = eps_complex.real, eps_complex.imag @@ -1252,7 +1260,7 @@ def eps_complex_to_eps_sigma(eps_complex: complex, freq: float) -> Tuple[float, return eps_real, sigma @staticmethod - def eps_complex_to_eps_loss_tangent(eps_complex: complex) -> Tuple[float, float]: + def eps_complex_to_eps_loss_tangent(eps_complex: complex) -> tuple[float, float]: """Convert complex permittivity to permittivity and loss tangent. Parameters @@ -1262,7 +1270,7 @@ def eps_complex_to_eps_loss_tangent(eps_complex: complex) -> Tuple[float, float] Returns ------- - Tuple[float, float] + tuple[float, float] Real part of relative permittivity & loss tangent """ eps_real, eps_imag = eps_complex.real, eps_complex.imag @@ -1362,7 +1370,7 @@ def sel_inside(self, bounds: Bound) -> AbstractMedium: Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns @@ -1429,7 +1437,7 @@ def __repr__(self): class AbstractCustomMedium(AbstractMedium, ABC): """A spatially varying medium.""" - interp_method: InterpMethod = pd.Field( + interp_method: InterpMethod = Field( "nearest", title="Interpolation method", description="Interpolation method to obtain permittivity values " @@ -1439,7 +1447,7 @@ class AbstractCustomMedium(AbstractMedium, ABC): "the extrapolated value will take the minimal (maximal) of the supplied data.", ) - subpixel: bool = pd.Field( + subpixel: bool = Field( False, title="Subpixel averaging", description="If ``True``, apply the subpixel averaging method specified by " @@ -1460,7 +1468,7 @@ def _interp_method(self, comp: Axis) -> InterpMethod: @abstractmethod def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. Parameters @@ -1470,7 +1478,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -1494,7 +1502,7 @@ def eps_diagonal_on_grid( self, frequency: float, coords: Coords, - ) -> Tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]: + ) -> tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]: """Spatial profile of main diagonal of the complex-valued permittivity at ``frequency`` interpolated at the supplied coordinates. @@ -1507,7 +1515,7 @@ def eps_diagonal_on_grid( Returns ------- - Tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] + tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] The complex-valued permittivity tensor at ``frequency`` interpolated at the supplied coordinate. 
""" @@ -1564,7 +1572,7 @@ def eps_model(self, frequency: float) -> complex: ) @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor at ``frequency``. Spatially, we take max{||eps||}, so that autoMesh generation works appropriately. @@ -1585,7 +1593,7 @@ def _get_real_vals(self, x: np.ndarray) -> np.ndarray: def _eps_bounds( self, frequency: float = None, eps_component: Optional[PermittivityComponent] = None - ) -> Tuple[float, float]: + ) -> tuple[float, float]: """Returns permittivity bounds for setting the color bounds when plotting. Parameters @@ -1600,7 +1608,7 @@ def _eps_bounds( Returns ------- - Tuple[float, float] + tuple[float, float] The min and max values of the permittivity for the selected component and evaluated at ``frequency``. """ eps_dataarray = self.eps_dataarray_freq(frequency) @@ -1614,7 +1622,7 @@ def _validate_isreal_dataarray(dataarray: CustomSpatialDataType) -> bool: @staticmethod def _validate_isreal_dataarray_tuple( - dataarray_tuple: Tuple[CustomSpatialDataType, ...], + dataarray_tuple: tuple[CustomSpatialDataType, ...], ) -> bool: """Validate that the dataarray is real""" return np.all([AbstractCustomMedium._validate_isreal_dataarray(f) for f in dataarray_tuple]) @@ -1631,7 +1639,7 @@ def sel_inside(self, bounds: Bound) -> AbstractCustomMedium: Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns @@ -1736,8 +1744,8 @@ class PECMedium(AbstractMedium): """ - @pd.validator("modulation_spec", always=True) - def _validate_modulation_spec(cls, val): + @field_validator("modulation_spec") + def _validate_modulation_spec(cls, val, info): """Check compatibility with modulation_spec.""" if val is not None: raise ValidationError( @@ -1802,11 +1810,11 @@ class Medium(AbstractMedium): """ - permittivity: TracedFloat = pd.Field( + permittivity: TracedFloat = Field( 1.0, ge=1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY ) - conductivity: TracedFloat = pd.Field( + conductivity: TracedFloat = Field( 0.0, title="Conductivity", description="Electric conductivity. Defined such that the imaginary part of the complex " @@ -1814,43 +1822,43 @@ class Medium(AbstractMedium): units=CONDUCTIVITY, ) - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if ``allow_gain`` is False.""" - if not values.get("allow_gain") and val < 0: + val = self.conductivity + if not self.allow_gain and val < 0: raise ValidationError( "For passive medium, 'conductivity' must be non-negative. " "To simulate a gain medium, please set 'allow_gain=True'. " "Caution: simulations with a gain medium are unstable, and are likely to diverge." 
) - return val + return self - @pd.validator("permittivity", always=True) - @skip_if_fields_missing(["modulation_spec"]) - def _permittivity_modulation_validation(cls, val, values): + @model_validator(mode="after") + def _permittivity_modulation_validation(self): """Assert modulated permittivity cannot be <= 0.""" - modulation = values.get("modulation_spec") + val = self.permittivity + modulation = self.modulation_spec if modulation is None or modulation.permittivity is None: - return val + return self min_eps_inf = np.min(_get_numpy_array(val)) if min_eps_inf - modulation.permittivity.max_modulation <= 0: raise ValidationError( "The minimum permittivity value with modulation applied was found to be negative." ) - return val + return self - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["modulation_spec", "allow_gain"]) - def _passivity_modulation_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_modulation_validation(self): """Assert passive medium if ``allow_gain`` is False.""" - modulation = values.get("modulation_spec") + val = self.conductivity + modulation = self.modulation_spec if modulation is None or modulation.conductivity is None: - return val + return self min_sigma = np.min(_get_numpy_array(val)) - if not values.get("allow_gain") and min_sigma - modulation.conductivity.max_modulation < 0: + if not self.allow_gain and min_sigma - modulation.conductivity.max_modulation < 0: raise ValidationError( "For passive medium, 'conductivity' must be non-negative at any time." "With conductivity modulation, this medium can sometimes be active. " @@ -1858,7 +1866,7 @@ def _passivity_modulation_validation(cls, val, values): "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." 
) - return val + return self @cached_property def n_cfl(self): @@ -1990,14 +1998,13 @@ class CustomIsotropicMedium(AbstractCustomMedium, Medium): >>> eps = dielectric.eps_model(200e12) """ - permittivity: CustomSpatialDataTypeAnnotated = pd.Field( - ..., + permittivity: CustomSpatialDataTypeAnnotated = Field( title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY, ) - conductivity: Optional[CustomSpatialDataTypeAnnotated] = pd.Field( + conductivity: Optional[CustomSpatialDataTypeAnnotated] = Field( None, title="Conductivity", description="Electric conductivity. Defined such that the imaginary part of the complex " @@ -2005,11 +2012,10 @@ class CustomIsotropicMedium(AbstractCustomMedium, Medium): units=CONDUCTIVITY, ) - _no_nans_eps = validate_no_nans("permittivity") - _no_nans_sigma = validate_no_nans("conductivity") + _no_nans = validate_no_nans("permittivity", "conductivity") - @pd.validator("permittivity", always=True) - def _eps_inf_greater_no_less_than_one(cls, val): + @field_validator("permittivity") + def _eps_inf_greater_no_less_than_one(val): """Assert any eps_inf must be >=1""" if not CustomIsotropicMedium._validate_isreal_dataarray(val): @@ -2020,34 +2026,34 @@ def _eps_inf_greater_no_less_than_one(cls, val): return val - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity"]) - def _conductivity_real_and_correct_shape(cls, val, values): + @model_validator(mode="after") + def _conductivity_real_and_correct_shape(self): """Assert conductivity is real and of right shape.""" + val = self.conductivity if val is None: - return val + return self if not CustomIsotropicMedium._validate_isreal_dataarray(val): raise SetupError("'conductivity' must be real.") - if not _check_same_coordinates(values["permittivity"], val): + if not _check_same_coordinates(self.permittivity, val): raise SetupError("'permittivity' and 'conductivity' must have the same coordinates.") - return val + return self - 
@pd.validator("conductivity", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if ``allow_gain`` is False.""" + val = self.conductivity if val is None: - return val - if not values.get("allow_gain") and np.any(_get_numpy_array(val) < 0): + return self + if not self.allow_gain and np.any(_get_numpy_array(val) < 0): raise ValidationError( "For passive medium, 'conductivity' must be non-negative. " "To simulate a gain medium, please set 'allow_gain=True'. " "Caution: simulations with a gain medium are unstable, and are likely to diverge." ) - return val + return self @cached_property def is_spatially_uniform(self) -> bool: @@ -2077,7 +2083,7 @@ def is_isotropic(self): def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. Parameters @@ -2087,7 +2093,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -2119,7 +2125,7 @@ def _sel_custom_data_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns @@ -2162,7 +2168,7 @@ class CustomMedium(AbstractCustomMedium): >>> eps = dielectric.eps_model(200e12) """ - eps_dataset: Optional[PermittivityDataset] = pd.Field( + eps_dataset: Optional[PermittivityDataset] = Field( None, title="Permittivity Dataset", description="[To be deprecated] User-supplied dataset containing complex-valued " @@ -2170,14 +2176,14 @@ class CustomMedium(AbstractCustomMedium): "will be interpolated based on ``interp_method``.", ) - permittivity: Optional[CustomSpatialDataTypeAnnotated] = pd.Field( + permittivity: Optional[CustomSpatialDataTypeAnnotated] = Field( None, title="Permittivity", description="Spatial profile of relative permittivity.", units=PERMITTIVITY, ) - conductivity: Optional[CustomSpatialDataTypeAnnotated] = pd.Field( + conductivity: Optional[CustomSpatialDataTypeAnnotated] = Field( None, title="Conductivity", description="Spatial profile Electric conductivity. Defined such " @@ -2186,45 +2192,42 @@ class CustomMedium(AbstractCustomMedium): units=CONDUCTIVITY, ) - _no_nans_eps_dataset = validate_no_nans("eps_dataset") - _no_nans_permittivity = validate_no_nans("permittivity") - _no_nans_sigma = validate_no_nans("conductivity") + _no_nans = validate_no_nans("eps_dataset", "permittivity", "conductivity") - @pd.root_validator(pre=True) - def _warn_if_none(cls, values): + @model_validator(mode="before") + def _warn_if_none(cls, data: dict) -> dict: """Warn if the data array fails to load, and return a vacuum medium.""" - eps_dataset = values.get("eps_dataset") - permittivity = values.get("permittivity") - conductivity = values.get("conductivity") fail_load = False - if cls._not_loaded(permittivity): + if cls._not_loaded(data.get("permittivity")): log.warning( "Loading 'permittivity' without data; constructing a vacuum medium instead." 
) fail_load = True - if cls._not_loaded(conductivity): + if cls._not_loaded(data.get("conductivity")): log.warning( "Loading 'conductivity' without data; constructing a vacuum medium instead." ) fail_load = True - if isinstance(eps_dataset, dict): - if any((v in DATA_ARRAY_MAP for _, v in eps_dataset.items() if isinstance(v, str))): + eps_ds = data.get("eps_dataset") + if isinstance(eps_ds, dict): + if any(isinstance(v, str) and v in DATA_ARRAY_MAP for v in eps_ds.values()): log.warning( "Loading 'eps_dataset' without data; constructing a vacuum medium instead." ) fail_load = True if fail_load: - eps_real = SpatialDataArray(np.ones((1, 1, 1)), coords=dict(x=[0], y=[0], z=[0])) - return dict(permittivity=eps_real) - return values + data["permittivity"] = SpatialDataArray( + np.ones((1, 1, 1)), coords=dict(x=[0], y=[0], z=[0]) + ) + return data - @pd.root_validator(pre=True) - def _deprecation_dataset(cls, values): + @model_validator(mode="after") + def _deprecation_dataset(self): """Raise deprecation warning if dataset supplied and convert to dataset.""" - eps_dataset = values.get("eps_dataset") - permittivity = values.get("permittivity") - conductivity = values.get("conductivity") + eps_dataset = self.eps_dataset + permittivity = self.permittivity + conductivity = self.conductivity # Incomplete custom medium definition. if eps_dataset is None and permittivity is None and conductivity is None: @@ -2240,7 +2243,7 @@ def _deprecation_dataset(cls, values): ) if eps_dataset is None: - return values + return self # TODO: sometime before 3.0, uncomment these lines to warn users to start using new API # if isinstance(eps_dataset, dict): @@ -2265,10 +2268,10 @@ def _deprecation_dataset(cls, values): # "We recommend you change your scripts to be compatible with the new API." 
# ) - return values + return self - @pd.validator("eps_dataset", always=True) - def _eps_dataset_single_frequency(cls, val): + @field_validator("eps_dataset") + def _eps_dataset_single_frequency(val): """Assert only one frequency supplied.""" if val is None: return val @@ -2282,13 +2285,13 @@ def _eps_dataset_single_frequency(cls, val): ) return val - @pd.validator("eps_dataset", always=True) - @skip_if_fields_missing(["modulation_spec", "allow_gain"]) - def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, values): + @model_validator(mode="after") + def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(self): """Assert any eps_inf must be >=1""" + val = self.eps_dataset if val is None: - return val - modulation = values.get("modulation_spec") + return self + modulation = self.modulation_spec for comp in ["eps_xx", "eps_yy", "eps_zz"]: eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma( @@ -2307,7 +2310,7 @@ def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, value "was found to be negative." ) - if not values.get("allow_gain") and np.any(_get_numpy_array(sigma) < 0): + if not self.allow_gain and np.any(_get_numpy_array(sigma) < 0): raise ValidationError( "For passive medium, imaginary part of permittivity must be non-negative. " "To simulate a gain medium, please set 'allow_gain=True'. " @@ -2316,7 +2319,7 @@ def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, value ) if ( - not values.get("allow_gain") + not self.allow_gain and modulation is not None and modulation.conductivity is not None and np.any(_get_numpy_array(sigma) - modulation.conductivity.max_modulation <= 0) @@ -2329,14 +2332,14 @@ def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, value "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." 
) - return val + return self - @pd.validator("permittivity", always=True) - @skip_if_fields_missing(["modulation_spec"]) - def _eps_inf_greater_no_less_than_one(cls, val, values): + @model_validator(mode="after") + def _eps_inf_greater_no_less_than_one(self): """Assert any eps_inf must be >=1""" + val = self.permittivity if val is None: - return val + return self if not CustomMedium._validate_isreal_dataarray(val): raise SetupError("'permittivity' must be real.") @@ -2344,29 +2347,29 @@ def _eps_inf_greater_no_less_than_one(cls, val, values): if np.any(_get_numpy_array(val) < 1): raise SetupError("'permittivity' must be no less than one.") - modulation = values.get("modulation_spec") + modulation = self.modulation_spec if modulation is None or modulation.permittivity is None: - return val + return self if np.any(_get_numpy_array(val) - modulation.permittivity.max_modulation <= 0): raise ValidationError( "The minimum permittivity value with modulation applied was found to be negative." ) - return val + return self - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity", "allow_gain"]) - def _conductivity_non_negative_correct_shape(cls, val, values): + @model_validator(mode="after") + def _conductivity_non_negative_correct_shape(self): """Assert conductivity>=0""" + val = self.conductivity if val is None: - return val + return self if not CustomMedium._validate_isreal_dataarray(val): raise SetupError("'conductivity' must be real.") - if not values.get("allow_gain") and np.any(_get_numpy_array(val) < 0): + if not self.allow_gain and np.any(_get_numpy_array(val) < 0): raise ValidationError( "For passive medium, 'conductivity' must be non-negative. " "To simulate a gain medium, please set 'allow_gain=True'. " @@ -2374,24 +2377,24 @@ def _conductivity_non_negative_correct_shape(cls, val, values): "and are likely to diverge." 
) - if not _check_same_coordinates(values["permittivity"], val): + if not _check_same_coordinates(self.permittivity, val): raise SetupError("'permittivity' and 'conductivity' must have the same coordinates.") - return val + return self - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["eps_dataset", "modulation_spec", "allow_gain"]) - def _passivity_modulation_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_modulation_validation(self): """Assert passive medium at any time during modulation if ``allow_gain`` is False.""" + val = self.conductivity # validated already when the data is supplied through `eps_dataset` - if values.get("eps_dataset"): - return val + if self.eps_dataset: + return self # permittivity defined with ``permittivity`` and ``conductivity`` - modulation = values.get("modulation_spec") - if values.get("allow_gain") or modulation is None or modulation.conductivity is None: - return val + modulation = self.modulation_spec + if self.allow_gain or modulation is None or modulation.conductivity is None: + return self if val is None or np.any( _get_numpy_array(val) - modulation.conductivity.max_modulation < 0 ): @@ -2402,14 +2405,14 @@ def _passivity_modulation_validation(cls, val, values): "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." 
) - return val + return self - @pd.validator("permittivity", "conductivity", always=True) - def _check_permittivity_conductivity_interpolate(cls, val, values, field): + @field_validator("permittivity", "conductivity") + def _check_permittivity_conductivity_interpolate(val, info): """Check that the custom medium 'SpatialDataArrays' can be interpolated.""" if isinstance(val, SpatialDataArray): - val._interp_validator(field.name) + val._interp_validator(info.field_name) return val @@ -2505,7 +2508,7 @@ def n_cfl(self): def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. () Parameters @@ -2515,7 +2518,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -2540,7 +2543,7 @@ def eps_diagonal_on_grid( self, frequency: float, coords: Coords, - ) -> Tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]: + ) -> tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]: """Spatial profile of main diagonal of the complex-valued permittivity at ``frequency`` interpolated at the supplied coordinates. @@ -2553,14 +2556,14 @@ def eps_diagonal_on_grid( Returns ------- - Tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] + tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] The complex-valued permittivity tensor at ``frequency`` interpolated at the supplied coordinate. """ return self._medium.eps_diagonal_on_grid(frequency, coords) @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor at ``frequency``. Spatially, we take max{|eps|}, so that autoMesh generation works appropriately. 
@@ -2731,7 +2734,7 @@ def from_nk( sigma = SpatialDataArray(sigma.squeeze(dim="f", drop=True)) return cls(permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs) - def grids(self, bounds: Bound) -> Dict[str, Grid]: + def grids(self, bounds: Bound) -> dict[str, Grid]: """Make a :class:`.Grid` corresponding to the data in each ``eps_ii`` component. The min and max coordinates along each dimension are bounded by ``bounds``.""" @@ -2742,7 +2745,7 @@ def grids(self, bounds: Bound) -> Dict[str, Grid]: def make_grid(scalar_field: Union[ScalarFieldDataArray, SpatialDataArray]) -> Grid: """Make a grid for a single dataset.""" - def make_bound_coords(coords: np.ndarray, pt_min: float, pt_max: float) -> List[float]: + def make_bound_coords(coords: np.ndarray, pt_min: float, pt_max: float) -> list[float]: """Convert user supplied coords into boundary coords to use in :class:`.Grid`.""" # get coordinates of the bondaries halfway between user-supplied data @@ -2787,7 +2790,7 @@ def _sel_custom_data_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns @@ -2979,20 +2982,20 @@ class DispersiveMedium(AbstractMedium, ABC): def _permittivity_modulation_validation(): """Assert modulated permittivity cannot be <= 0 at any time.""" - @pd.validator("eps_inf", allow_reuse=True, always=True) - @skip_if_fields_missing(["modulation_spec"]) - def _validate_permittivity_modulation(cls, val, values): + @model_validator(mode="after") + def _validate_permittivity_modulation(self): """Assert modulated permittivity cannot be <= 0.""" - modulation = values.get("modulation_spec") + val = self.eps_inf + modulation = self.modulation_spec if modulation is None or modulation.permittivity is None: - return val + return self min_eps_inf = np.min(_get_numpy_array(val)) if min_eps_inf - modulation.permittivity.max_modulation <= 0: raise ValidationError( "The minimum permittivity value with modulation applied was found to be negative." ) - return val + return self return _validate_permittivity_modulation @@ -3000,28 +3003,28 @@ def _validate_permittivity_modulation(cls, val, values): def _conductivity_modulation_validation(): """Assert passive medium at any time if not ``allow_gain``.""" - @pd.validator("modulation_spec", allow_reuse=True, always=True) - @skip_if_fields_missing(["allow_gain"]) - def _validate_conductivity_modulation(cls, val, values): + @model_validator(mode="after") + def _validate_conductivity_modulation(self): """With conductivity modulation, the medium can exhibit gain during the cycle. So `allow_gain` must be True when the conductivity is modulated. """ + val = self.modulation_spec if val is None or val.conductivity is None: - return val + return self - if not values.get("allow_gain"): + if not self.allow_gain: raise ValidationError( "For passive medium, 'conductivity' must be non-negative at any time. " "With conductivity modulation, this medium can sometimes be active. " "Please set 'allow_gain=True'. " "Caution: simulations with a gain medium are unstable, and are likely to diverge." 
) - return val + return self return _validate_conductivity_modulation @abstractmethod - def _pole_residue_dict(self) -> Dict: + def _pole_residue_dict(self) -> dict: """Dict representation of Medium as a pole-residue model.""" @cached_property @@ -3045,14 +3048,14 @@ def n_cfl(self): return n @staticmethod - def tuple_to_complex(value: Tuple[float, float]) -> complex: + def tuple_to_complex(value: tuple[float, float]) -> complex: """Convert a tuple of real and imaginary parts to complex number.""" val_r, val_i = value return val_r + 1j * val_i @staticmethod - def complex_to_tuple(value: complex) -> Tuple[float, float]: + def complex_to_tuple(value: complex) -> tuple[float, float]: """Convert a complex number to a tuple of real and imaginary parts.""" return (value.real, value.imag) @@ -3097,34 +3100,33 @@ def _warn_if_data_none(nested_tuple_field: str): and return a vacuum with eps_inf = 1. """ - @pd.root_validator(pre=True, allow_reuse=True) - def _warn_if_none(cls, values): - """Warn if any of `eps_inf` and nested_tuple_field are not load.""" - eps_inf = values.get("eps_inf") - coeffs = values.get(nested_tuple_field) - fail_load = False + @model_validator(mode="before") + def _warn_if_none(cls, data: dict): + is_not_loaded = AbstractCustomMedium._not_loaded + + eps_inf = data.get("eps_inf") + coeffs = data.get(nested_tuple_field, ()) + + eps_bad = is_not_loaded(eps_inf) + coeff_bad = any(is_not_loaded(c) for coeff in coeffs for c in coeff) - if AbstractCustomMedium._not_loaded(eps_inf): + if not (eps_bad or coeff_bad): + return data + + if eps_bad: log.warning("Loading 'eps_inf' without data; constructing a vacuum medium instead.") - fail_load = True - for coeff in coeffs: - if fail_load: - break - for coeff_i in coeff: - if AbstractCustomMedium._not_loaded(coeff_i): - log.warning( - f"Loading '{nested_tuple_field}' without data; " - "constructing a vacuum medium instead." 
- ) - fail_load = True - break + if coeff_bad: + log.warning( + f"Loading '{nested_tuple_field}' without data; constructing a vacuum medium instead." + ) - if fail_load and eps_inf is None: - return {nested_tuple_field: ()} - if fail_load: - eps_inf = SpatialDataArray(np.ones((1, 1, 1)), coords=dict(x=[0], y=[0], z=[0])) - return {"eps_inf": eps_inf, nested_tuple_field: ()} - return values + data[nested_tuple_field] = () + if eps_inf is not None: + data["eps_inf"] = SpatialDataArray( + np.ones((1, 1, 1)), coords=dict(x=[0], y=[0], z=[0]) + ) + + return data return _warn_if_none @@ -3161,22 +3163,22 @@ class PoleResidue(DispersiveMedium): * `Modeling dispersive material in FDTD `_ """ - eps_inf: TracedPositiveFloat = pd.Field( + eps_inf: TracedPositiveFloat = Field( 1.0, title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - poles: Tuple[TracedPoleAndResidue, ...] = pd.Field( + poles: TracedPolesAndResidues = Field( (), title="Poles", description="Tuple of complex-valued (:math:`a_i, c_i`) poles for the model.", units=(RADPERSEC, RADPERSEC), ) - @pd.validator("poles", always=True) - def _causality_validation(cls, val): + @field_validator("poles") + def _causality_validation(val): """Assert causal medium.""" for a, _ in val: if np.any(np.real(_get_numpy_array(a)) > 0): @@ -3187,9 +3189,7 @@ def _causality_validation(cls, val): _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() @staticmethod - def _eps_model( - eps_inf: pd.PositiveFloat, poles: Tuple[PoleAndResidue, ...], frequency: float - ) -> complex: + def _eps_model(eps_inf: PositiveFloat, poles: PolesAndResidues, frequency: float) -> complex: """Complex-valued permittivity as a function of frequency.""" omega = 2 * np.pi * frequency @@ -3206,7 +3206,7 @@ def eps_model(self, frequency: float) -> complex: """Complex-valued permittivity as a function of frequency.""" return 
self._eps_model(eps_inf=self.eps_inf, poles=self.poles, frequency=frequency) - def _pole_residue_dict(self) -> Dict: + def _pole_residue_dict(self) -> dict: """Dict representation of Medium as a pole-residue model.""" return dict( @@ -3268,8 +3268,8 @@ def to_medium(self) -> Medium: @staticmethod def lo_to_eps_model( - poles: Tuple[Tuple[float, float, float, float], ...], - eps_inf: pd.PositiveFloat, + poles: tuple[tuple[float, float, float, float], ...], + eps_inf: PositiveFloat, frequency: float, ) -> complex: """Complex permittivity as a function of frequency for a given set of LO-TO coefficients. @@ -3278,10 +3278,10 @@ def lo_to_eps_model( Parameters ---------- - poles : Tuple[Tuple[float, float, float, float], ...] + poles : tuple[tuple[float, float, float, float], ...] The LO-TO poles, given as list of tuples of the form (omega_LO, gamma_LO, omega_TO, gamma_TO). - eps_inf: pd.PositiveFloat + eps_inf: PositiveFloat The relative permittivity at infinite frequency. frequency: float Frequency at which to evaluate the permittivity. @@ -3300,7 +3300,7 @@ def lo_to_eps_model( @classmethod def from_lo_to( - cls, poles: Tuple[Tuple[float, float, float, float], ...], eps_inf: pd.PositiveFloat = 1 + cls, poles: tuple[tuple[float, float, float, float], ...], eps_inf: PositiveFloat = 1 ) -> PoleResidue: """Construct a pole residue model from the LO-TO form (longitudinal and transverse optical modes). @@ -3312,10 +3312,10 @@ def from_lo_to( Parameters ---------- - poles : Tuple[Tuple[float, float, float, float], ...] + poles : tuple[tuple[float, float, float, float], ...] The LO-TO poles, given as list of tuples of the form (omega_LO, gamma_LO, omega_TO, gamma_TO). - eps_inf: pd.PositiveFloat + eps_inf: PositiveFloat The relative permittivity at infinite frequency. 
Returns @@ -3372,12 +3372,12 @@ def from_lo_to( return PoleResidue(eps_inf=eps_inf, poles=list(zip(a_coeffs, c_coeffs))) @staticmethod - def imag_ep_extrema(poles: Tuple[PoleAndResidue, ...]) -> ArrayFloat1D: + def imag_ep_extrema(poles: PolesAndResidues) -> ArrayFloat1D: """Extrema of Im[eps] in the same unit as poles. Parameters ---------- - poles: Tuple[PoleAndResidue, ...] + poles: PolesAndResidues Tuple of complex-valued (``a_i, c_i``) poles for the model. """ @@ -3469,7 +3469,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM @classmethod def _real_partial_fraction_decomposition( - cls, a: np.ndarray, b: np.ndarray, tol: pd.PositiveFloat = 1e-2 + cls, a: np.ndarray, b: np.ndarray, tol: PositiveFloat = 1e-2 ) -> tuple[list[tuple[Complex, Complex]], np.ndarray]: """Computes the complex conjugate pole residue pairs given a rational expression with real coefficients. @@ -3481,7 +3481,7 @@ def _real_partial_fraction_decomposition( Coefficients of the numerator polynomial in increasing monomial order. b : np.ndarray Coefficients of the denominator polynomial in increasing monomial order. - tol : pd.PositiveFloat + tol : PositiveFloat Tolerance for pole finding. Two poles are considered equal, if their spacing is less than ``tol``. @@ -3492,6 +3492,7 @@ def _real_partial_fraction_decomposition( ``tuple`` is an array of coefficients representing any direct polynomial term. 
""" + from scipy import signal if a.ndim != 1 or np.any(np.iscomplex(a)): raise ValidationError( @@ -3537,7 +3538,7 @@ def _real_partial_fraction_decomposition( r_filtered.append(res) p_filtered.append(pole) - poles_residues = list(zip(p_filtered, r_filtered)) + poles_residues = tuple(zip(p_filtered, r_filtered)) k_increasing_order = np.flip(k) return (poles_residues, k_increasing_order) @@ -3546,8 +3547,8 @@ def from_admittance_coeffs( cls, a: np.ndarray, b: np.ndarray, - eps_inf: pd.PositiveFloat = 1, - pole_tol: pd.PositiveFloat = 1e-2, + eps_inf: PositiveFloat = 1, + pole_tol: PositiveFloat = 1e-2, ) -> PoleResidue: """Construct a :class:`.PoleResidue` model from an admittance function defining the relationship between the electric field and the polarization current density in the @@ -3559,9 +3560,9 @@ def from_admittance_coeffs( Coefficients of the numerator polynomial in increasing monomial order. b : np.ndarray Coefficients of the denominator polynomial in increasing monomial order. - eps_inf: pd.PositiveFloat + eps_inf: PositiveFloat The relative permittivity at infinite frequency. - pole_tol: pd.PositiveFloat + pole_tol: PositiveFloat Tolerance for the pole finding algorithm in Hertz. Two poles are considered equal, if their spacing is closer than ``pole_tol`. Returns @@ -3714,15 +3715,14 @@ class CustomPoleResidue(CustomDispersiveMedium, PoleResidue): * `Modeling dispersive material in FDTD `_ """ - eps_inf: CustomSpatialDataTypeAnnotated = pd.Field( - ..., + eps_inf: CustomSpatialDataTypeAnnotated = Field( title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - poles: Tuple[Tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( - pd.Field( + poles: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] 
= ( + Field( (), title="Poles", description="Tuple of complex-valued (:math:`a_i, c_i`) poles for the model.", @@ -3730,12 +3730,11 @@ class CustomPoleResidue(CustomDispersiveMedium, PoleResidue): ) ) - _no_nans_eps_inf = validate_no_nans("eps_inf") - _no_nans_poles = validate_no_nans("poles") + _no_nans = validate_no_nans("eps_inf", "poles") _warn_if_none = CustomDispersiveMedium._warn_if_data_none("poles") - @pd.validator("eps_inf", always=True) - def _eps_inf_positive(cls, val): + @field_validator("eps_inf") + def _eps_inf_positive(val): """eps_inf must be positive""" if not CustomDispersiveMedium._validate_isreal_dataarray(val): raise SetupError("'eps_inf' must be real.") @@ -3743,19 +3742,19 @@ def _eps_inf_positive(cls, val): raise SetupError("'eps_inf' must be positive.") return val - @pd.validator("poles", always=True) - @skip_if_fields_missing(["eps_inf"]) - def _poles_correct_shape(cls, val, values): + @model_validator(mode="after") + def _poles_correct_shape(self): """poles must have the same shape.""" + val = self.poles for coeffs in val: for coeff in coeffs: - if not _check_same_coordinates(coeff, values["eps_inf"]): + if not _check_same_coordinates(coeff, self.eps_inf): raise SetupError( "All pole coefficients 'a' and 'c' must have the same coordinates; " "The coordinates must also be consistent with 'eps_inf'." ) - return val + return self @cached_property def is_spatially_uniform(self) -> bool: @@ -3771,7 +3770,7 @@ def is_spatially_uniform(self) -> bool: def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. 
Parameters @@ -3781,7 +3780,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -3803,7 +3802,7 @@ def eps_dataarray_freq( eps = PoleResidue.eps_model(self, frequency) return (eps, eps, eps) - def poles_on_grid(self, coords: Coords) -> Tuple[Tuple[ArrayComplex3D, ArrayComplex3D], ...]: + def poles_on_grid(self, coords: Coords) -> tuple[tuple[ArrayComplex3D, ArrayComplex3D], ...]: """Spatial profile of poles interpolated at the supplied coordinates. Parameters @@ -3813,7 +3812,7 @@ def poles_on_grid(self, coords: Coords) -> Tuple[Tuple[ArrayComplex3D, ArrayComp Returns ------- - Tuple[Tuple[ArrayComplex3D, ArrayComplex3D], ...] + tuple[tuple[ArrayComplex3D, ArrayComplex3D], ...] The poles interpolated at the supplied coordinate. """ @@ -3876,7 +3875,7 @@ def _sel_custom_data_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -3897,7 +3896,7 @@ def _sel_custom_data_inside(self, bounds: Bound): poles_reduced.append((pole.sel_inside(bounds), residue.sel_inside(bounds))) - return self.updated_copy(eps_inf=eps_inf_reduced, poles=poles_reduced) + return self.updated_copy(eps_inf=eps_inf_reduced, poles=tuple(poles_reduced)) def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: """Compute adjoint derivatives for each of the ``fields`` given the multiplied E and D.""" @@ -4007,18 +4006,18 @@ class Sellmeier(DispersiveMedium): * `Modeling dispersive material in FDTD `_ """ - coeffs: Tuple[Tuple[float, pd.PositiveFloat], ...] = pd.Field( + coeffs: tuple[tuple[float, PositiveFloat], ...] 
= Field( title="Coefficients", description="List of Sellmeier (:math:`B_i, C_i`) coefficients.", units=(None, MICROMETER + "^2"), ) - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if `allow_gain` is False.""" - if values.get("allow_gain"): - return val + val = self.coeffs + if self.allow_gain: + return self for B, _ in val: if B < 0: raise ValidationError( @@ -4027,10 +4026,10 @@ def _passivity_validation(cls, val, values): "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." ) - return val + return self - @pd.validator("modulation_spec", always=True) - def _validate_permittivity_modulation(cls, val): + @field_validator("modulation_spec") + def _validate_permittivity_modulation(val): """Assert modulated permittivity cannot be <= 0.""" if val is None or val.permittivity is None: @@ -4062,7 +4061,7 @@ def eps_model(self, frequency: float) -> complex: n = self._n_model(frequency) return AbstractMedium.nk_to_eps_complex(n) - def _pole_residue_dict(self) -> Dict: + def _pole_residue_dict(self) -> dict: """Dict representation of Medium as a pole-residue model""" poles = [] for B, C in self.coeffs: @@ -4146,9 +4145,8 @@ class CustomSellmeier(CustomDispersiveMedium, Sellmeier): * `Modeling dispersive material in FDTD `_ """ - coeffs: Tuple[Tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( - pd.Field( - ..., + coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] 
= ( + Field( title="Coefficients", description="List of Sellmeier (:math:`B_i, C_i`) coefficients.", units=(None, MICROMETER + "^2"), @@ -4156,11 +4154,10 @@ class CustomSellmeier(CustomDispersiveMedium, Sellmeier): ) _no_nans = validate_no_nans("coeffs") - _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs") - @pd.validator("coeffs", always=True) - def _correct_shape_and_sign(cls, val): + @field_validator("coeffs") + def _correct_shape_and_sign(val): """every term in coeffs must have the same shape, and B>=0 and C>0.""" if len(val) == 0: return val @@ -4175,12 +4172,12 @@ def _correct_shape_and_sign(cls, val): raise SetupError("'C' must be positive.") return val - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if `allow_gain` is False.""" - if values.get("allow_gain"): - return val + val = self.coeffs + if self.allow_gain: + return self for B, _ in val: if np.any(_get_numpy_array(B) < 0): raise ValidationError( @@ -4189,7 +4186,7 @@ def _passivity_validation(cls, val, values): "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." ) - return val + return self @cached_property def is_spatially_uniform(self) -> bool: @@ -4200,7 +4197,7 @@ def is_spatially_uniform(self) -> bool: return False return True - def _pole_residue_dict(self) -> Dict: + def _pole_residue_dict(self) -> dict: """Dict representation of Medium as a pole-residue model.""" poles_dict = Sellmeier._pole_residue_dict(self) if len(self.coeffs) > 0: @@ -4209,7 +4206,7 @@ def _pole_residue_dict(self) -> Dict: def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. 
Parameters @@ -4219,7 +4216,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -4303,7 +4300,7 @@ def _sel_custom_data_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -4325,7 +4322,7 @@ def _sel_custom_data_inside(self, bounds: Bound): coeffs_reduced.append((b_coeff.sel_inside(bounds), c_coeff.sel_inside(bounds))) - return self.updated_copy(coeffs=coeffs_reduced) + return self.updated_copy(coeffs=tuple(coeffs_reduced)) class Lorentz(DispersiveMedium): @@ -4356,21 +4353,20 @@ class Lorentz(DispersiveMedium): * `Modeling dispersive material in FDTD `_ """ - eps_inf: pd.PositiveFloat = pd.Field( + eps_inf: PositiveFloat = Field( 1.0, title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - coeffs: Tuple[Tuple[float, float, pd.NonNegativeFloat], ...] = pd.Field( - ..., + coeffs: tuple[tuple[float, float, NonNegativeFloat], ...] 
= Field( title="Coefficients", description="List of (:math:`\\Delta\\epsilon_i, f_i, \\delta_i`) values for model.", units=(PERMITTIVITY, HERTZ, HERTZ), ) - @pd.validator("coeffs", always=True) + @field_validator("coeffs") def _coeffs_unequal_f_delta(cls, val): """f**2 and delta**2 cannot be exactly the same.""" for _, f, delta in val: @@ -4378,12 +4374,12 @@ def _coeffs_unequal_f_delta(cls, val): raise SetupError("'f' and 'delta' cannot take equal values.") return val - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if ``allow_gain`` is False.""" - if values.get("allow_gain"): - return val + val = self.coeffs + if self.allow_gain: + return self for del_ep, _, _ in val: if del_ep < 0: raise ValidationError( @@ -4392,7 +4388,7 @@ def _passivity_validation(cls, val, values): "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." 
) - return val + return self _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() @@ -4406,7 +4402,7 @@ def eps_model(self, frequency: float) -> complex: eps = eps + (de * f**2) / (f**2 - 2j * frequency * delta - frequency**2) return eps - def _pole_residue_dict(self) -> Dict: + def _pole_residue_dict(self) -> dict: """Dict representation of Medium as a pole-residue model.""" poles = [] @@ -4537,34 +4533,30 @@ class CustomLorentz(CustomDispersiveMedium, Lorentz): * `Modeling dispersive material in FDTD `_ """ - eps_inf: CustomSpatialDataTypeAnnotated = pd.Field( - ..., + eps_inf: CustomSpatialDataTypeAnnotated = Field( title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - coeffs: Tuple[ - Tuple[ + coeffs: tuple[ + tuple[ CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated, ], ..., - ] = pd.Field( - ..., + ] = Field( title="Coefficients", description="List of (:math:`\\Delta\\epsilon_i, f_i, \\delta_i`) values for model.", units=(PERMITTIVITY, HERTZ, HERTZ), ) - _no_nans_eps_inf = validate_no_nans("eps_inf") - _no_nans_coeffs = validate_no_nans("coeffs") - + _no_nans = validate_no_nans("eps_inf", "coeffs") _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs") - @pd.validator("eps_inf", always=True) - def _eps_inf_positive(cls, val): + @field_validator("eps_inf") + def _eps_inf_positive(val): """eps_inf must be positive""" if not CustomDispersiveMedium._validate_isreal_dataarray(val): raise SetupError("'eps_inf' must be real.") @@ -4572,7 +4564,7 @@ def _eps_inf_positive(cls, val): raise SetupError("'eps_inf' must be positive.") return val - @pd.validator("coeffs", always=True) + @field_validator("coeffs") def _coeffs_unequal_f_delta(cls, val): """f and delta cannot be exactly the same. 
Not needed for now because we have a more strict @@ -4580,15 +4572,15 @@ def _coeffs_unequal_f_delta(cls, val): """ return val - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["eps_inf"]) - def _coeffs_correct_shape(cls, val, values): + @model_validator(mode="after") + def _coeffs_correct_shape(self): """coeffs must have consistent shape.""" + val = self.coeffs for de, f, delta in val: if ( - not _check_same_coordinates(de, values["eps_inf"]) - or not _check_same_coordinates(f, values["eps_inf"]) - or not _check_same_coordinates(delta, values["eps_inf"]) + not _check_same_coordinates(de, self.eps_inf) + or not _check_same_coordinates(f, self.eps_inf) + or not _check_same_coordinates(delta, self.eps_inf) ): raise SetupError( "All terms in 'coeffs' must have the same coordinates; " @@ -4596,10 +4588,10 @@ def _coeffs_correct_shape(cls, val, values): ) if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((de, f, delta)): raise SetupError("All terms in 'coeffs' must be real.") - return val + return self - @pd.validator("coeffs", always=True) - def _coeffs_delta_all_smaller_or_larger_than_fi(cls, val): + @field_validator("coeffs") + def _coeffs_delta_all_smaller_or_larger_than_fi(val): """We restrict either all f**2>delta**2 or all f**2 bool: @@ -4641,7 +4633,7 @@ def is_spatially_uniform(self) -> bool: def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. 
Parameters @@ -4651,7 +4643,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -4680,7 +4672,7 @@ def _sel_custom_data_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -4712,7 +4704,7 @@ def _sel_custom_data_inside(self, bounds: Bound): (de.sel_inside(bounds), f.sel_inside(bounds), delta.sel_inside(bounds)) ) - return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=coeffs_reduced) + return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=tuple(coeffs_reduced)) class Drude(DispersiveMedium): @@ -4746,15 +4738,14 @@ class Drude(DispersiveMedium): * `Modeling dispersive material in FDTD `_ """ - eps_inf: pd.PositiveFloat = pd.Field( + eps_inf: PositiveFloat = Field( 1.0, title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - coeffs: Tuple[Tuple[float, pd.PositiveFloat], ...] = pd.Field( - ..., + coeffs: tuple[tuple[float, PositiveFloat], ...] 
= Field( title="Coefficients", description="List of (:math:`f_i, \\delta_i`) values for model.", units=(HERTZ, HERTZ), @@ -4772,7 +4763,7 @@ def eps_model(self, frequency: float) -> complex: eps = eps - (f**2) / (frequency**2 + 1j * frequency * delta) return eps - def _pole_residue_dict(self) -> Dict: + def _pole_residue_dict(self) -> dict: """Dict representation of Medium as a pole-residue model.""" poles = [] @@ -4839,29 +4830,25 @@ class CustomDrude(CustomDispersiveMedium, Drude): * `Modeling dispersive material in FDTD `_ """ - eps_inf: CustomSpatialDataTypeAnnotated = pd.Field( - ..., + eps_inf: CustomSpatialDataTypeAnnotated = Field( title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - coeffs: Tuple[Tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( - pd.Field( - ..., + coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( + Field( title="Coefficients", description="List of (:math:`f_i, \\delta_i`) values for model.", units=(HERTZ, HERTZ), ) ) - _no_nans_eps_inf = validate_no_nans("eps_inf") - _no_nans_coeffs = validate_no_nans("coeffs") - + _no_nans = validate_no_nans("eps_inf", "coeffs") _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs") - @pd.validator("eps_inf", always=True) - def _eps_inf_positive(cls, val): + @field_validator("eps_inf") + def _eps_inf_positive(val): """eps_inf must be positive""" if not CustomDispersiveMedium._validate_isreal_dataarray(val): raise SetupError("'eps_inf' must be real.") @@ -4869,13 +4856,13 @@ def _eps_inf_positive(cls, val): raise SetupError("'eps_inf' must be positive.") return val - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["eps_inf"]) - def _coeffs_correct_shape_and_sign(cls, val, values): + @model_validator(mode="after") + def _coeffs_correct_shape_and_sign(self): """coeffs must have consistent shape and sign.""" + val = 
self.coeffs for f, delta in val: - if not _check_same_coordinates(f, values["eps_inf"]) or not _check_same_coordinates( - delta, values["eps_inf"] + if not _check_same_coordinates(f, self.eps_inf) or not _check_same_coordinates( + delta, self.eps_inf ): raise SetupError( "All terms in 'coeffs' must have the same coordinates; " @@ -4885,7 +4872,7 @@ def _coeffs_correct_shape_and_sign(cls, val, values): raise SetupError("All terms in 'coeffs' must be real.") if np.any(_get_numpy_array(delta) <= 0): raise SetupError("For stable medium, 'delta' must be positive.") - return val + return self @cached_property def is_spatially_uniform(self) -> bool: @@ -4900,7 +4887,7 @@ def is_spatially_uniform(self) -> bool: def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. Parameters @@ -4910,7 +4897,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -4939,7 +4926,7 @@ def _sel_custom_data_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns @@ -4964,7 +4951,7 @@ def _sel_custom_data_inside(self, bounds: Bound): coeffs_reduced.append((f.sel_inside(bounds), delta.sel_inside(bounds))) - return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=coeffs_reduced) + return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=tuple(coeffs_reduced)) class Debye(DispersiveMedium): @@ -4998,26 +4985,25 @@ class Debye(DispersiveMedium): * `Modeling dispersive material in FDTD `_ """ - eps_inf: pd.PositiveFloat = pd.Field( + eps_inf: PositiveFloat = Field( 1.0, title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - coeffs: Tuple[Tuple[float, pd.PositiveFloat], ...] = pd.Field( - ..., + coeffs: tuple[tuple[float, PositiveFloat], ...] = Field( title="Coefficients", description="List of (:math:`\\Delta\\epsilon_i, \\tau_i`) values for model.", units=(PERMITTIVITY, SECOND), ) - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if `allow_gain` is False.""" - if values.get("allow_gain"): - return val + val = self.coeffs + if self.allow_gain: + return self for del_ep, _ in val: if del_ep < 0: raise ValidationError( @@ -5026,7 +5012,7 @@ def _passivity_validation(cls, val, values): "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." 
) - return val + return self _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() @@ -5096,29 +5082,25 @@ class CustomDebye(CustomDispersiveMedium, Debye): * `Modeling dispersive material in FDTD `_ """ - eps_inf: CustomSpatialDataTypeAnnotated = pd.Field( - ..., + eps_inf: CustomSpatialDataTypeAnnotated = Field( title="Epsilon at Infinity", description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", units=PERMITTIVITY, ) - coeffs: Tuple[Tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( - pd.Field( - ..., + coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( + Field( title="Coefficients", description="List of (:math:`\\Delta\\epsilon_i, \\tau_i`) values for model.", units=(PERMITTIVITY, SECOND), ) ) - _no_nans_eps_inf = validate_no_nans("eps_inf") - _no_nans_coeffs = validate_no_nans("coeffs") - + _no_nans = validate_no_nans("eps_inf", "coeffs") _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs") - @pd.validator("eps_inf", always=True) - def _eps_inf_positive(cls, val): + @field_validator("eps_inf") + def _eps_inf_positive(val): """eps_inf must be positive""" if not CustomDispersiveMedium._validate_isreal_dataarray(val): raise SetupError("'eps_inf' must be real.") @@ -5126,13 +5108,13 @@ def _eps_inf_positive(cls, val): raise SetupError("'eps_inf' must be positive.") return val - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["eps_inf"]) - def _coeffs_correct_shape(cls, val, values): + @model_validator(mode="after") + def _coeffs_correct_shape(self): """coeffs must have consistent shape.""" + val = self.coeffs for de, tau in val: - if not _check_same_coordinates(de, values["eps_inf"]) or not _check_same_coordinates( - tau, values["eps_inf"] + if not _check_same_coordinates(de, self.eps_inf) or not 
_check_same_coordinates( + tau, self.eps_inf ): raise SetupError( "All terms in 'coeffs' must have the same coordinates; " @@ -5140,13 +5122,13 @@ def _coeffs_correct_shape(cls, val, values): ) if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((de, tau)): raise SetupError("All terms in 'coeffs' must be real.") - return val + return self - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if ``allow_gain`` is False.""" - allow_gain = values.get("allow_gain") + val = self.coeffs + allow_gain = self.allow_gain for del_ep, tau in val: if np.any(_get_numpy_array(tau) <= 0): raise SetupError("For stable medium, 'tau_i' must be positive.") @@ -5157,7 +5139,7 @@ def _passivity_validation(cls, val, values): "Caution: simulations with a gain medium are unstable, " "and are likely to diverge." ) - return val + return self @cached_property def is_spatially_uniform(self) -> bool: @@ -5172,7 +5154,7 @@ def is_spatially_uniform(self) -> bool: def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. Parameters @@ -5182,7 +5164,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -5211,7 +5193,7 @@ def _sel_custom_data_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
Returns @@ -5236,7 +5218,7 @@ def _sel_custom_data_inside(self, bounds: Bound): coeffs_reduced.append((de.sel_inside(bounds), tau.sel_inside(bounds))) - return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=coeffs_reduced) + return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=tuple(coeffs_reduced)) class SurfaceImpedanceFitterParam(Tidy3dBaseModel): @@ -5244,26 +5226,26 @@ class SurfaceImpedanceFitterParam(Tidy3dBaseModel): Internally, the quantity to be fitted is surface impedance divided by ``-1j * \\omega``. """ - max_num_poles: pd.PositiveInt = pd.Field( + max_num_poles: PositiveInt = Field( LOSSY_METAL_DEFAULT_MAX_POLES, title="Maximal Number Of Poles", description="Maximal number of poles in complex-conjugate pole residue model for " "fitting surface impedance.", ) - tolerance_rms: pd.NonNegativeFloat = pd.Field( + tolerance_rms: NonNegativeFloat = Field( LOSSY_METAL_DEFAULT_TOLERANCE_RMS, title="Tolerance In Fitting", description="Tolerance in fitting.", ) - frequency_sampling_points: pd.PositiveInt = pd.Field( + frequency_sampling_points: PositiveInt = Field( LOSSY_METAL_DEFAULT_SAMPLING_FREQUENCY, title="Number Of Sampling Frequencies", description="Number of sampling frequencies used in fitting.", ) - log_sampling: bool = pd.Field( + log_sampling: bool = Field( True, title="Frequencies Sampling In Log Scale", description="Whether to sample frequencies logarithmically (``True``), " @@ -5327,14 +5309,13 @@ class HammerstadSurfaceRoughness(AbstractSurfaceRoughness): and its Effect on Transmission Line Characteristics", Signal Integrity Journal, 2018. """ - rq: pd.PositiveFloat = pd.Field( - ..., + rq: PositiveFloat = Field( title="RMS Peak-to-Valley Height", description="RMS peak-to-valley height (Rq) of the surface roughness.", units=MICROMETER, ) - roughness_factor: float = pd.Field( + roughness_factor: float = Field( 2.0, title="Roughness Factor", description="Expected maximal increase in conductor losses due to roughness effect. 
" @@ -5397,14 +5378,13 @@ class HuraySurfaceRoughness(AbstractSurfaceRoughness): J. Eric Bracken, "A Causal Huray Model for Surface Roughness", DesignCon, 2012. """ - relative_area: pd.PositiveFloat = pd.Field( + relative_area: PositiveFloat = Field( 1, title="Relative Area", description="Relative area of the matte base compared to a flat surface", ) - coeffs: Tuple[Tuple[pd.PositiveFloat, pd.PositiveFloat], ...] = pd.Field( - ..., + coeffs: tuple[tuple[PositiveFloat, PositiveFloat], ...] = Field( title="Coefficients for surface ratio and sphere radius", description="List of (:math:`f_i, r_i`) values for model, where :math:`f_i` is " "the ratio of total sphere surface area to the flat surface area, and :math:`r_i` " @@ -5487,7 +5467,7 @@ class LossyMetalMedium(Medium): """ - allow_gain: Literal[False] = pd.Field( + allow_gain: Literal[False] = Field( False, title="Allow gain medium", description="Allow the medium to be active. Caution: " @@ -5497,11 +5477,11 @@ class LossyMetalMedium(Medium): "useful in some cases.", ) - permittivity: Literal[1] = pd.Field( + permittivity: Literal[1.0] = Field( 1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY ) - roughness: SurfaceRoughnessType = pd.Field( + roughness: Optional[SurfaceRoughnessType] = Field( None, title="Surface Roughness Model", description="Surface roughness model that applies a frequency-dependent scaling " @@ -5509,22 +5489,21 @@ class LossyMetalMedium(Medium): discriminator=TYPE_TAG_STR, ) - frequency_range: FreqBound = pd.Field( - ..., + frequency_range: FreqBound = Field( title="Frequency Range", description="Frequency range of validity for the medium.", units=(HERTZ, HERTZ), ) - fit_param: SurfaceImpedanceFitterParam = pd.Field( - SurfaceImpedanceFitterParam(), + fit_param: SurfaceImpedanceFitterParam = Field( + default_factory=SurfaceImpedanceFitterParam, title="Fitting Parameters For Surface Impedance", description="Parameters for fitting surface impedance divided by 
(-1j * omega) over " "the frequency range using pole-residue pair model.", ) - @pd.validator("frequency_range") - def _validate_frequency_range(cls, val): + @field_validator("frequency_range") + def _validate_frequency_range(val): """Validate that frequency range is finite and non-zero.""" for freq in val: if not np.isfinite(freq): @@ -5533,7 +5512,7 @@ def _validate_frequency_range(cls, val): raise ValidationError("Values in 'frequency_range' must be positive.") return val - @pd.validator("conductivity", always=True) + @field_validator("conductivity") def _positive_conductivity(cls, val): """Assert conductivity>0.""" if val <= 0: @@ -5541,7 +5520,7 @@ def _positive_conductivity(cls, val): return val @cached_property - def _fitting_result(self) -> Tuple[PoleResidue, float]: + def _fitting_result(self) -> tuple[PoleResidue, float]: """Fitted scaled surface impedance and residue.""" omega_data = self.Hz_to_angular_freq(self.sampling_frequencies) @@ -5615,7 +5594,7 @@ def sampling_frequencies(self) -> ArrayFloat1D: self.fit_param.frequency_sampling_points, ) - def eps_diagonal_numerical(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor for numerical considerations such as meshing and runtime estimation. @@ -5626,7 +5605,7 @@ def eps_diagonal_numerical(self, frequency: float) -> Tuple[complex, complex, co Returns ------- - Tuple[complex, complex, complex] + tuple[complex, complex, complex] The diagonal elements of relative permittivity tensor relevant for numerical considerations evaluated at ``frequency``. 
""" @@ -5711,34 +5690,32 @@ class AnisotropicMedium(AbstractMedium): * `Thin film lithium niobate adiabatic waveguide coupler <../../notebooks/AdiabaticCouplerLN.html>`_ """ - xx: IsotropicUniformMediumType = pd.Field( - ..., + xx: IsotropicUniformMediumType = Field( title="XX Component", description="Medium describing the xx-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - yy: IsotropicUniformMediumType = pd.Field( - ..., + yy: IsotropicUniformMediumType = Field( title="YY Component", description="Medium describing the yy-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - zz: IsotropicUniformMediumType = pd.Field( - ..., + zz: IsotropicUniformMediumType = Field( title="ZZ Component", description="Medium describing the zz-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - allow_gain: bool = pd.Field( + allow_gain: Optional[bool] = Field( None, title="Allow gain medium", description="This field is ignored. Please set ``allow_gain`` in each component", ) - @pd.validator("modulation_spec", always=True) + @field_validator("modulation_spec") + @classmethod def _validate_modulation_spec(cls, val): """Check compatibility with modulation_spec.""" if val is not None: @@ -5749,17 +5726,17 @@ def _validate_modulation_spec(cls, val): ) return val - @pd.root_validator(pre=True) - def _ignored_fields(cls, values): + @model_validator(mode="after") + def _ignored_fields(self): """The field is ignored.""" - if values.get("xx") is not None and values.get("allow_gain") is not None: + if self.xx is not None and self.allow_gain is not None: log.warning( "The field 'allow_gain' is ignored. Please set 'allow_gain' in each component." 
) - return values + return self @cached_property - def components(self) -> Dict[str, Medium]: + def components(self) -> dict[str, Medium]: """Dictionary of diagonal medium components.""" return dict(xx=self.xx, yy=self.yy, zz=self.zz) @@ -5785,7 +5762,7 @@ def eps_model(self, frequency: float) -> complex: return np.mean(self.eps_diagonal(frequency), axis=0) @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor as a function of frequency.""" eps_xx = self.xx.eps_model(frequency) @@ -5869,7 +5846,7 @@ def plot(self, freqs: float, ax: Ax = None) -> Ax: return ax @property - def elements(self) -> Dict[str, IsotropicUniformMediumType]: + def elements(self) -> dict[str, IsotropicUniformMediumType]: """The diagonal elements of the medium as a dictionary.""" return dict(xx=self.xx, yy=self.yy, zz=self.zz) @@ -5889,7 +5866,7 @@ def sel_inside(self, bounds: Bound): Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -5950,14 +5927,14 @@ class FullyAnisotropicMedium(AbstractMedium): * `Defining fully anisotropic materials <../../notebooks/FullyAnisotropic.html>`_ """ - permittivity: TensorReal = pd.Field( + permittivity: TensorReal = Field( [[1, 0, 0], [0, 1, 0], [0, 0, 1]], title="Permittivity", description="Relative permittivity tensor.", units=PERMITTIVITY, ) - conductivity: TensorReal = pd.Field( + conductivity: TensorReal = Field( [[0, 0, 0], [0, 0, 0], [0, 0, 0]], title="Conductivity", description="Electric conductivity tensor. 
Defined such that the imaginary part " @@ -5965,7 +5942,8 @@ class FullyAnisotropicMedium(AbstractMedium): units=CONDUCTIVITY, ) - @pd.validator("modulation_spec", always=True) + @field_validator("modulation_spec") + @classmethod def _validate_modulation_spec(cls, val): """Check compatibility with modulation_spec.""" if val is not None: @@ -5975,8 +5953,8 @@ def _validate_modulation_spec(cls, val): ) return val - @pd.validator("permittivity", always=True) - def permittivity_spd_and_ge_one(cls, val): + @field_validator("permittivity") + def permittivity_spd_and_ge_one(val): """Check that provided permittivity tensor is symmetric positive definite with eigenvalues >= 1. """ @@ -5989,14 +5967,14 @@ def permittivity_spd_and_ge_one(cls, val): return val - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity"]) - def conductivity_commutes(cls, val, values): + @model_validator(mode="after") + def conductivity_commutes(self): """Check that the symmetric part of conductivity tensor commutes with permittivity tensor (that is, simultaneously diagonalizable). """ - perm = values.get("permittivity") + val = self.conductivity + perm = self.permittivity cond_sym = 0.5 * (val + val.T) comm_diff = np.abs(np.matmul(perm, cond_sym) - np.matmul(cond_sym, perm)) @@ -6005,14 +5983,14 @@ def conductivity_commutes(cls, val, values): "Main directions of conductivity and permittivity tensor do not coincide." 
) - return val + return self - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): + @model_validator(mode="after") + def _passivity_validation(self): """Assert passive medium if ``allow_gain`` is False.""" - if values.get("allow_gain"): - return val + val = self.conductivity + if self.allow_gain: + return self cond_sym = 0.5 * (val + val.T) if np.any(np.linalg.eigvals(cond_sym) < -fp_eps): @@ -6022,7 +6000,7 @@ def _passivity_validation(cls, val, values): "To simulate a gain medium, please set 'allow_gain=True'. " "Caution: simulations with a gain medium are unstable, and are likely to diverge." ) - return val + return self @classmethod def from_diagonal(cls, xx: Medium, yy: Medium, zz: Medium, rotation: RotationType): @@ -6086,7 +6064,7 @@ def _to_diagonal(self) -> AnisotropicMedium: @cached_property def eps_sigma_diag( self, - ) -> Tuple[Tuple[float, float, float], Tuple[float, float, float], TensorReal]: + ) -> tuple[tuple[float, float, float], tuple[float, float, float], TensorReal]: """Main components of permittivity and conductivity tensors and their directions.""" perm_diag, vecs = np.linalg.eig(self.permittivity) @@ -6106,7 +6084,7 @@ def eps_model(self, frequency: float) -> complex: return np.mean(eps_diag) @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor as a function of frequency.""" perm_diag, cond_diag, _ = self.eps_sigma_diag @@ -6233,28 +6211,25 @@ class CustomAnisotropicMedium(AbstractCustomMedium, AnisotropicMedium): * `Defining fully anisotropic materials <../../notebooks/FullyAnisotropic.html>`_ """ - xx: Union[IsotropicCustomMediumType, CustomMedium] = pd.Field( - ..., + xx: Union[IsotropicCustomMediumType, CustomMedium] = Field( title="XX Component", description="Medium 
describing the xx-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - yy: Union[IsotropicCustomMediumType, CustomMedium] = pd.Field( - ..., + yy: Union[IsotropicCustomMediumType, CustomMedium] = Field( title="YY Component", description="Medium describing the yy-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - zz: Union[IsotropicCustomMediumType, CustomMedium] = pd.Field( - ..., + zz: Union[IsotropicCustomMediumType, CustomMedium] = Field( title="ZZ Component", description="Medium describing the zz-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - interp_method: Optional[InterpMethod] = pd.Field( + interp_method: Optional[InterpMethod] = Field( None, title="Interpolation method", description="When the value is 'None', each component will follow its own " @@ -6262,52 +6237,38 @@ class CustomAnisotropicMedium(AbstractCustomMedium, AnisotropicMedium): "method specified by this field will override the one in each component.", ) - allow_gain: bool = pd.Field( + allow_gain: Optional[bool] = Field( None, title="Allow gain medium", description="This field is ignored. Please set ``allow_gain`` in each component", ) - subpixel: bool = pd.Field( + subpixel: Optional[bool] = Field( None, title="Subpixel averaging", description="This field is ignored. 
Please set ``subpixel`` in each component", ) - @pd.validator("xx", always=True) - def _isotropic_xx(cls, val): - """If it's `CustomMedium`, make sure it's isotropic.""" - if isinstance(val, CustomMedium) and not val.is_isotropic: - raise SetupError("The xx-component medium type is not isotropic.") - return val - - @pd.validator("yy", always=True) - def _isotropic_yy(cls, val): - """If it's `CustomMedium`, make sure it's isotropic.""" - if isinstance(val, CustomMedium) and not val.is_isotropic: - raise SetupError("The yy-component medium type is not isotropic.") - return val - - @pd.validator("zz", always=True) - def _isotropic_zz(cls, val): + @field_validator("xx", "yy", "zz") + def _isotropic_xx(val, info): """If it's `CustomMedium`, make sure it's isotropic.""" if isinstance(val, CustomMedium) and not val.is_isotropic: - raise SetupError("The zz-component medium type is not isotropic.") + raise SetupError(f"The {info.field_name}-component medium type is not isotropic.") return val - @pd.root_validator(pre=True) - def _ignored_fields(cls, values): + @model_validator(mode="after") + def _ignored_fields(self): """The field is ignored.""" - if values.get("xx") is not None: - if values.get("allow_gain") is not None: + if self.xx is not None: + if self.allow_gain is not None: log.warning( "The field 'allow_gain' is ignored. Please set 'allow_gain' in each component." ) - if values.get("subpixel") is not None: + if self.subpixel is not None: log.warning( "The field 'subpixel' is ignored. Please set 'subpixel' in each component." ) - return values + return self @cached_property def is_spatially_uniform(self) -> bool: @@ -6340,7 +6301,7 @@ def _interp_method(self, comp: Axis) -> InterpMethod: def eps_dataarray_freq( self, frequency: float - ) -> Tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: """Permittivity array at ``frequency``. 
Parameters @@ -6350,7 +6311,7 @@ def eps_dataarray_freq( Returns ------- - Tuple[ + tuple[ Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, @@ -6376,7 +6337,7 @@ def eps_dataarray_freq( def _eps_bounds( self, frequency: float = None, eps_component: Optional[PermittivityComponent] = None - ) -> Tuple[float, float]: + ) -> tuple[float, float]: """Returns permittivity bounds for setting the color bounds when plotting. Parameters @@ -6391,7 +6352,7 @@ def _eps_bounds( Returns ------- - Tuple[float, float] + tuple[float, float] The min and max values of the permittivity for the selected component and evaluated at ``frequency``. """ comps = ["xx", "yy", "zz"] @@ -6438,22 +6399,19 @@ class CustomAnisotropicMediumInternal(CustomAnisotropicMedium): >>> anisotropic_dielectric = CustomAnisotropicMedium(xx=medium_xx, yy=medium_yy, zz=medium_zz) """ - xx: Union[IsotropicCustomMediumInternalType, CustomMedium] = pd.Field( - ..., + xx: Union[IsotropicCustomMediumInternalType, CustomMedium] = Field( title="XX Component", description="Medium describing the xx-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - yy: Union[IsotropicCustomMediumInternalType, CustomMedium] = pd.Field( - ..., + yy: Union[IsotropicCustomMediumInternalType, CustomMedium] = Field( title="YY Component", description="Medium describing the yy-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, ) - zz: Union[IsotropicCustomMediumInternalType, CustomMedium] = pd.Field( - ..., + zz: Union[IsotropicCustomMediumInternalType, CustomMedium] = Field( title="ZZ Component", description="Medium describing the zz-component of the diagonal permittivity tensor.", discriminator=TYPE_TAG_STR, @@ -6466,7 +6424,7 @@ class CustomAnisotropicMediumInternal(CustomAnisotropicMedium): class AbstractPerturbationMedium(ABC, Tidy3dBaseModel): """Abstract class for medium perturbation.""" - subpixel: bool = pd.Field( + subpixel: bool = Field( True, title="Subpixel 
averaging", description="This value will be transferred to the resulting custom medium. That is, " @@ -6476,7 +6434,7 @@ class AbstractPerturbationMedium(ABC, Tidy3dBaseModel): "have an effect.", ) - perturbation_spec: Optional[Union[PermittivityPerturbation, IndexPerturbation]] = pd.Field( + perturbation_spec: Optional[Union[PermittivityPerturbation, IndexPerturbation]] = Field( None, title="Perturbation Spec", description="Specification of medium perturbation as one of predefined types.", @@ -6589,14 +6547,14 @@ class PerturbationMedium(Medium, AbstractPerturbationMedium): ... ) """ - permittivity_perturbation: Optional[ParameterPerturbation] = pd.Field( + permittivity_perturbation: Optional[ParameterPerturbation] = Field( None, title="Permittivity Perturbation", description="List of heat and/or charge perturbations to permittivity.", units=PERMITTIVITY, ) - conductivity_perturbation: Optional[ParameterPerturbation] = pd.Field( + conductivity_perturbation: Optional[ParameterPerturbation] = Field( None, title="Permittivity Perturbation", description="List of heat and/or charge perturbations to permittivity.", @@ -6619,15 +6577,15 @@ class PerturbationMedium(Medium, AbstractPerturbationMedium): allowed_complex=False, ) - @pd.root_validator(pre=True) - def _check_overdefining(cls, values): + @model_validator(mode="after") + def _check_overdefining(self): """Check that perturbation model is provided either directly or through ``perturbation_spec``, but not both. """ - perm_p = values.get("permittivity_perturbation") is not None - cond_p = values.get("conductivity_perturbation") is not None - p_spec = values.get("perturbation_spec") is not None + perm_p = self.permittivity_perturbation is not None + cond_p = self.conductivity_perturbation is not None + p_spec = self.perturbation_spec is not None if p_spec and (perm_p or cond_p): raise SetupError( @@ -6636,17 +6594,17 @@ def _check_overdefining(cls, values): "but not in both ways simultaneously." 
) - return values + return self - @pd.root_validator(skip_on_failure=True) - def _check_perturbation_spec_ranges(cls, values): + @model_validator(mode="after") + def _check_perturbation_spec_ranges(self): """Check perturbation ranges if defined as ``perturbation_spec``.""" - p_spec = values["perturbation_spec"] + p_spec = self.perturbation_spec if p_spec is None: - return values + return self - perm = values["permittivity"] - cond = values["conductivity"] + perm = self.permittivity + cond = self.conductivity if isinstance(p_spec, IndexPerturbation): eps_complex = Medium._eps_model( @@ -6674,7 +6632,7 @@ def _check_perturbation_spec_ranges(cls, values): allowed_real_range=(0.0, None), allowed_imag_range=None, ) - return values + return self def perturbed_copy( self, @@ -6805,7 +6763,7 @@ class PerturbationPoleResidue(PoleResidue, AbstractPerturbationMedium): ... ) """ - eps_inf_perturbation: Optional[ParameterPerturbation] = pd.Field( + eps_inf_perturbation: Optional[ParameterPerturbation] = Field( None, title="Perturbation of Epsilon at Infinity", description="Perturbations to relative permittivity at infinite frequency " @@ -6814,8 +6772,8 @@ class PerturbationPoleResidue(PoleResidue, AbstractPerturbationMedium): ) poles_perturbation: Optional[ - Tuple[Tuple[Optional[ParameterPerturbation], Optional[ParameterPerturbation]], ...] - ] = pd.Field( + tuple[tuple[Optional[ParameterPerturbation], Optional[ParameterPerturbation]], ...] + ] = Field( None, title="Perturbations of Poles", description="Perturbations to poles of the model.", @@ -6837,15 +6795,15 @@ class PerturbationPoleResidue(PoleResidue, AbstractPerturbationMedium): allowed_imag_range=[None, None], ) - @pd.root_validator(pre=True) - def _check_overdefining(cls, values): + @model_validator(mode="after") + def _check_overdefining(self): """Check that perturbation model is provided either directly or through ``perturbation_spec``, but not both. 
""" - eps_i_p = values.get("eps_inf_perturbation") is not None - poles_p = values.get("poles_perturbation") is not None - p_spec = values.get("perturbation_spec") is not None + eps_i_p = self.eps_inf_perturbation is not None + poles_p = self.poles_perturbation is not None + p_spec = self.perturbation_spec is not None if p_spec and (eps_i_p or poles_p): raise SetupError( @@ -6854,17 +6812,17 @@ def _check_overdefining(cls, values): "but not in both ways simultaneously." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def _check_perturbation_spec_ranges(cls, values): + @model_validator(mode="after") + def _check_perturbation_spec_ranges(self): """Check perturbation ranges if defined as ``perturbation_spec``.""" - p_spec = values["perturbation_spec"] + p_spec = self.perturbation_spec if p_spec is None: - return values + return self - eps_inf = values["eps_inf"] - poles = values["poles"] + eps_inf = self.eps_inf + poles = self.poles if isinstance(p_spec, IndexPerturbation): eps_complex = PoleResidue._eps_model( @@ -6885,7 +6843,7 @@ def _check_perturbation_spec_ranges(cls, values): allowed_imag_range=None, ) - return values + return self def perturbed_copy( self, @@ -7029,8 +6987,7 @@ class Medium2D(AbstractMedium): """ - ss: IsotropicUniformMediumType = pd.Field( - ..., + ss: IsotropicUniformMediumType = Field( title="SS Component", description="Medium describing the ss-component of the diagonal permittivity tensor. " "The ss-component refers to the in-plane dimension of the medium that is the first " @@ -7040,8 +6997,7 @@ class Medium2D(AbstractMedium): discriminator=TYPE_TAG_STR, ) - tt: IsotropicUniformMediumType = pd.Field( - ..., + tt: IsotropicUniformMediumType = Field( title="TT Component", description="Medium describing the tt-component of the diagonal permittivity tensor. 
" "The tt-component refers to the in-plane dimension of the medium that is the second " @@ -7051,7 +7007,8 @@ class Medium2D(AbstractMedium): discriminator=TYPE_TAG_STR, ) - @pd.validator("modulation_spec", always=True) + @field_validator("modulation_spec") + @classmethod def _validate_modulation_spec(cls, val): """Check compatibility with modulation_spec.""" if val is not None: @@ -7061,20 +7018,20 @@ def _validate_modulation_spec(cls, val): ) return val - @skip_if_fields_missing(["ss"]) - @pd.validator("tt", always=True) - def _validate_inplane_pec(cls, val, values): + @model_validator(mode="after") + def _validate_inplane_pec(self): """ss/tt components must be both PEC or non-PEC.""" - if isinstance(val, PECMedium) != isinstance(values["ss"], PECMedium): + val = self.tt + if isinstance(val, PECMedium) != isinstance(self.ss, PECMedium): raise ValidationError( "Materials describing ss- and tt-components must be " "either both 'PECMedium', or non-'PECMedium'." ) - return val + return self @classmethod def _weighted_avg( - cls, meds: List[IsotropicUniformMediumType], weights: List[float] + cls, meds: list[IsotropicUniformMediumType], weights: list[float] ) -> Union[PoleResidue, PECMedium]: """Average ``meds`` with weights ``weights``.""" eps_inf = 1 @@ -7097,8 +7054,8 @@ def _weighted_avg( def volumetric_equivalent( self, axis: Axis, - adjacent_media: Tuple[MediumType3D, MediumType3D], - adjacent_dls: Tuple[float, float], + adjacent_media: tuple[MediumType3D, MediumType3D], + adjacent_dls: tuple[float, float], ) -> AnisotropicMedium: """Produces a 3D volumetric equivalent medium. The new medium has thickness equal to the average of the ``dls`` in the ``axis`` direction. @@ -7114,11 +7071,11 @@ def volumetric_equivalent( axis : Axis Index (0, 1, or 2 for x, y, or z respectively) of the normal direction to the 2D material. 
- adjacent_media : Tuple[MediumType3D, MediumType3D] + adjacent_media : tuple[MediumType3D, MediumType3D] The neighboring media on either side of the 2D material. The first element is directly on the - side of the 2D material in the supplied axis, and the second element is directly on the + side. - adjacent_dls : Tuple[float, float] + adjacent_dls : tuple[float, float] Each dl represents twice the thickness of the desired volumetric model on the respective side of the 2D material. @@ -7302,7 +7259,7 @@ def eps_model(self, frequency: float) -> complex: return np.mean(self.eps_diagonal(frequency=frequency), axis=0) @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> Tuple[complex, complex]: + def eps_diagonal(self, frequency: float) -> tuple[complex, complex]: """Main diagonal of the complex-valued permittivity tensor as a function of frequency.""" log.warning( "The permittivity of a 'Medium2D' is unphysical. " @@ -7314,7 +7271,7 @@ def eps_diagonal(self, frequency: float) -> Tuple[complex, complex]: eps_tt = self.tt.eps_model(frequency) return (eps_ss, eps_tt) - def eps_diagonal_numerical(self, frequency: float) -> Tuple[complex, complex, complex]: + def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor for numerical considerations such as meshing and runtime estimation. @@ -7325,7 +7282,7 @@ def eps_diagonal_numerical(self, frequency: float) -> Tuple[complex, complex, co Returns ------- - Tuple[complex, complex, complex] + tuple[complex, complex, complex] The diagonal elements of relative permittivity tensor relevant for numerical considerations evaluated at ``frequency``. 
""" @@ -7390,7 +7347,7 @@ def sigma_model(self, freq: float) -> complex: return np.mean([self.ss.sigma_model(freq), self.tt.sigma_model(freq)], axis=0) @property - def elements(self) -> Dict[str, IsotropicUniformMediumType]: + def elements(self) -> dict[str, IsotropicUniformMediumType]: """The diagonal elements of the 2D medium as a dictionary.""" return dict(ss=self.ss, tt=self.tt) diff --git a/tidy3d/components/microwave/data/monitor_data.py b/tidy3d/components/microwave/data/monitor_data.py index aa4048b862..d4bb9787ba 100644 --- a/tidy3d/components/microwave/data/monitor_data.py +++ b/tidy3d/components/microwave/data/monitor_data.py @@ -4,8 +4,8 @@ from __future__ import annotations -import pydantic.v1 as pd import xarray as xr +from pydantic import Field, model_validator from tidy3d.components.data.data_array import FieldProjectionAngleDataArray, FreqDataArray from tidy3d.components.data.monitor_data import DirectivityData @@ -62,14 +62,12 @@ class AntennaMetricsData(DirectivityData): John Wiley & Sons, Chapter 2.9 (2016). """ - power_incident: FreqDataArray = pd.Field( - ..., + power_incident: FreqDataArray = Field( title="Power incident", description="Array of values representing the incident power to an antenna.", ) - power_reflected: FreqDataArray = pd.Field( - ..., + power_reflected: FreqDataArray = Field( title="Power reflected", description="Array of values representing power reflected due to an impedance mismatch with the antenna.", ) @@ -196,10 +194,10 @@ def realized_gain(self) -> FieldProjectionAngleDataArray: partial_G = self.partial_realized_gain() return partial_G.Gtheta + partial_G.Gphi - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. 
You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data diff --git a/tidy3d/components/mode/data/sim_data.py b/tidy3d/components/mode/data/sim_data.py index a01bf78375..fa3e8d65bb 100644 --- a/tidy3d/components/mode/data/sim_data.py +++ b/tidy3d/components/mode/data/sim_data.py @@ -2,17 +2,14 @@ from __future__ import annotations -from typing import Literal, Tuple +from typing import Literal -import pydantic.v1 as pd +from pydantic import Field from ...base import cached_property from ...data.monitor_data import ModeSolverData, PermittivityData from ...data.sim_data import AbstractYeeGridSimulationData -from ...types import ( - Ax, - PlotScale, -) +from ...types import Ax, PlotScale from ..simulation import ModeSimulation ModeSimulationMonitorDataType = PermittivityData @@ -21,17 +18,17 @@ class ModeSimulationData(AbstractYeeGridSimulationData): """Data associated with a mode solver simulation.""" - simulation: ModeSimulation = pd.Field( - ..., title="Mode simulation", description="Mode simulation associated with this data." + simulation: ModeSimulation = Field( + title="Mode simulation", + description="Mode simulation associated with this data.", ) - modes_raw: ModeSolverData = pd.Field( - ..., + modes_raw: ModeSolverData = Field( title="Raw Modes", description=":class:`.ModeSolverData` containing the field and effective index on unexpanded grid.", ) - data: Tuple[ModeSimulationMonitorDataType, ...] = pd.Field( + data: tuple[ModeSimulationMonitorDataType, ...] 
= Field( (), title="Monitor Data", description="List of monitor data " diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 41b32c91cd..9a879afb26 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -6,18 +6,25 @@ from functools import wraps from math import isclose -from typing import Dict, List, Tuple, Union +from typing import Union import numpy as np -import pydantic.v1 as pydantic import xarray as xr from matplotlib.collections import PatchCollection from matplotlib.patches import Rectangle +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveInt, + field_validator, + model_validator, +) from ...constants import C_0 from ...exceptions import SetupError, ValidationError from ...log import log -from ..base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from ..base import Tidy3dBaseModel, cached_property from ..boundary import PML, Absorber, Boundary, BoundarySpec, PECBoundary, StablePML from ..data.data_array import ( FreqModeDataArray, @@ -41,7 +48,6 @@ from ..structure import Structure from ..subpixel_spec import SurfaceImpedance from ..types import ( - TYPE_TAG_STR, ArrayComplex3D, ArrayComplex4D, ArrayFloat1D, @@ -56,12 +62,9 @@ Literal, PlotScale, Symmetry, + discriminated_union, ) -from ..validators import ( - validate_freqs_min, - validate_freqs_not_empty, - validate_mode_plane_radius, -) +from ..validators import validate_freqs_min, validate_freqs_not_empty, validate_mode_plane_radius from ..viz import make_ax, plot_params_pml # Importing the local solver may not work if e.g. 
scipy is not installed @@ -76,7 +79,7 @@ log.warning(IMPORT_ERROR_MSG) LOCAL_SOLVER_IMPORTED = False -FIELD = Tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] +FIELD = tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] MODE_MONITOR_NAME = "<<>>" # Warning for field intensity at edges over total field intensity larger than this value @@ -85,9 +88,9 @@ # Maximum allowed size of the field data produced by the mode solver MAX_MODES_DATA_SIZE_GB = 20 -MODE_SIMULATION_TYPE = Union[Simulation, EMESimulation] -MODE_SIMULATION_DATA_TYPE = Union[SimulationData, EMESimulationData] -MODE_PLANE_TYPE = Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor] +MODE_SIMULATION_TYPE = discriminated_union(Union[Simulation, EMESimulation]) +MODE_SIMULATION_DATA_TYPE = discriminated_union(Union[SimulationData, EMESimulationData]) +MODE_PLANE_TYPE = discriminated_union(Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor]) # When using ``angle_rotation`` without a bend, use a very large effective radius EFFECTIVE_RADIUS_FACTOR = 10_000 @@ -131,45 +134,42 @@ class ModeSolver(Tidy3dBaseModel): * `Prelude to Integrated Photonics Simulation: Mode Injection `_ """ - simulation: MODE_SIMULATION_TYPE = pydantic.Field( - ..., + simulation: MODE_SIMULATION_TYPE = Field( title="Simulation", description="Simulation or EMESimulation defining all structures and mediums.", discriminator="type", ) - plane: MODE_PLANE_TYPE = pydantic.Field( - ..., + plane: MODE_PLANE_TYPE = Field( title="Plane", description="Cross-sectional plane in which the mode will be computed.", - discriminator=TYPE_TAG_STR, ) - mode_spec: ModeSpec = pydantic.Field( - ..., + mode_spec: ModeSpec = Field( title="Mode specification", description="Container with specifications about the modes to be solved for.", ) - freqs: FreqArray = pydantic.Field( - ..., title="Frequencies", description="A list of frequencies at which to solve." 
+ freqs: FreqArray = Field( + title="Frequencies", + description="A list of frequencies at which to solve.", ) - direction: Direction = pydantic.Field( + direction: Direction = Field( "+", title="Propagation direction", description="Direction of waveguide mode propagation along the axis defined by its normal " "dimension.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " "primal grid nodes). Default is ``True``.", ) - fields: Tuple[EMField, ...] = pydantic.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor. Note that some " @@ -177,8 +177,8 @@ class ModeSolver(Tidy3dBaseModel): "like ``mode_area`` require all E-field components.", ) - @pydantic.validator("simulation", pre=True, always=True) - def _convert_to_simulation(cls, val): + @field_validator("simulation") + def _convert_to_simulation(val): """Convert to regular Simulation if e.g. 
JaxSimulation given.""" if hasattr(val, "to_simulation"): val = val.to_simulation()[0] @@ -188,8 +188,8 @@ def _convert_to_simulation(cls, val): ) return val - @pydantic.validator("plane", always=True) - def is_plane(cls, val): + @field_validator("plane") + def is_plane(val): """Raise validation error if not planar.""" if val.size.count(0.0) != 1: raise ValidationError(f"ModeSolver plane must be planar, given size={val}") @@ -198,33 +198,31 @@ def is_plane(cls, val): _freqs_not_empty = validate_freqs_not_empty() _freqs_lower_bound = validate_freqs_min() - @pydantic.validator("plane", always=True) - @skip_if_fields_missing(["simulation"]) - def plane_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def plane_in_sim_bounds(self): """Check that the plane is at least partially inside the simulation bounds.""" - sim_center = values.get("simulation").center - sim_size = values.get("simulation").size - sim_box = Box(size=sim_size, center=sim_center) - - if not sim_box.intersects(val): + sim_box = Box(size=self.simulation.size, center=self.simulation.center) + if not sim_box.intersects(self.plane): raise SetupError("'ModeSolver.plane' must intersect 'ModeSolver.simulation'.") - return val + return self - def _post_init_validators(self) -> None: - validate_mode_plane_radius( - mode_spec=self.mode_spec, plane=self.plane, msg_prefix="Mode solver" + @property + def _post_init_validators(self) -> tuple: + return ( + lambda: validate_mode_plane_radius( + mode_spec=self.mode_spec, plane=self.plane, msg_prefix="Mode solver" + ), + lambda: self._warn_thick_pml( + simulation=self.simulation, plane=self.plane, mode_spec=self.mode_spec + ), ) - self._warn_thick_pml(simulation=self.simulation, plane=self.plane, mode_spec=self.mode_spec) @classmethod def _warn_thick_pml( cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec, warn_str: str = "'ModeSolver'" ): """Warn if the pml covers a significant portion of the mode plane.""" - coord_0, coord_1 = 
cls._plane_grid( - simulation=simulation, - plane=plane, - ) + coord_0, coord_1 = cls._plane_grid(simulation=simulation, plane=plane) num_cells = [len(coord_0), len(coord_1)] effective_num_pml = cls._effective_num_pml( simulation=simulation, plane=plane, mode_spec=mode_spec @@ -258,7 +256,7 @@ def normal_axis_2d(self) -> Axis2D: return idx_plane.index(self.normal_axis) @staticmethod - def _solver_symmetry(simulation: Simulation, plane: Box) -> Tuple[Symmetry, Symmetry]: + def _solver_symmetry(simulation: Simulation, plane: Box) -> tuple[Symmetry, Symmetry]: """Get symmetry for solver for propagation along self.normal axis.""" normal_axis = plane.size.index(0.0) mode_symmetry = list(simulation.symmetry) @@ -266,10 +264,10 @@ def _solver_symmetry(simulation: Simulation, plane: Box) -> Tuple[Symmetry, Symm if simulation.center[dim] != plane.center[dim]: mode_symmetry[dim] = 0 _, solver_sym = plane.pop_axis(mode_symmetry, axis=normal_axis) - return solver_sym + return tuple(solver_sym) @cached_property - def solver_symmetry(self) -> Tuple[Symmetry, Symmetry]: + def solver_symmetry(self) -> tuple[Symmetry, Symmetry]: """Get symmetry for solver for propagation along self.normal axis.""" return self._solver_symmetry(simulation=self.simulation, plane=self.plane) @@ -337,7 +335,7 @@ def _solver_grid(self) -> Grid: ) @cached_property - def _num_cells_freqs_modes(self) -> Tuple[int, int, int]: + def _num_cells_freqs_modes(self) -> tuple[int, int, int]: """Get the number of spatial points, number of freqs, and number of modes requested.""" num_cells = np.prod(self._solver_grid.num_cells) num_modes = self.mode_spec.num_modes @@ -555,7 +553,7 @@ def rotated_structures_copy(self): return self.updated_copy(simulation=rotated_simulation, mode_spec=rotated_mode_spec) - def _rotate_structures(self) -> List[Structure]: + def _rotate_structures(self) -> list[Structure]: """Rotate the structures intersecting with modal plane by angle theta if bend_correction is enabeled for bend 
simulations.""" @@ -600,7 +598,7 @@ def _rotate_structures(self) -> List[Structure]: return rotated_structures @cached_property - def rotated_bend_center(self) -> List: + def rotated_bend_center(self) -> list: """Calculate the center at the rotated bend such that the modal plane is normal to the azimuthal direction of the bend.""" rotated_bend_center = list(self.plane.center) @@ -612,7 +610,7 @@ def rotated_bend_center(self) -> List: # # Leaving for future reference if needed # def _ref_data_straight( # self, mode_solver_data: ModeSolverData - # ) -> Dict[Union[ScalarModeFieldDataArray, ModeIndexDataArray]]: + # ) -> dict[Union[ScalarModeFieldDataArray, ModeIndexDataArray]]: # """Convert reference data to be centered at the monitor center.""" # # Reference solution stored @@ -636,7 +634,7 @@ def rotated_bend_center(self) -> List: def _car_2_cyn( self, mode_solver_data: ModeSolverData - ) -> Dict[Union[ScalarModeFieldCylindricalDataArray, ModeIndexDataArray]]: + ) -> dict[Union[ScalarModeFieldCylindricalDataArray, ModeIndexDataArray]]: """Convert cartesian fields to cylindrical fields centered at the rotated bend center.""" @@ -728,7 +726,7 @@ def _car_2_cyn( # # Leaving for future reference if needed # def _mode_rotation_straight( # self, - # solver_ref_data: Dict[Union[ModeSolverData]], + # solver_ref_data: dict[Union[ModeSolverData]], # solver: ModeSolver, # ) -> ModeSolverData: # """Rotate the mode solver solution from the reference plane @@ -836,7 +834,7 @@ def _car_2_cyn( def _mode_rotation( self, - solver_ref_data_cylindrical: Dict[ + solver_ref_data_cylindrical: dict[ Union[ScalarModeFieldCylindricalDataArray, ModeIndexDataArray] ], solver: ModeSolver, @@ -991,7 +989,7 @@ def _bend_radius(self): return EFFECTIVE_RADIUS_FACTOR * largest_dim @cached_property - def bend_center(self) -> List: + def bend_center(self) -> list: """Computes the bend center based on plane center, angle_theta and angle_phi.""" _, id_bend_uv = self.plane.pop_axis((0, 1, 2), 
axis=self.bend_axis_3d) @@ -1176,7 +1174,7 @@ def _data_on_yee_grid_relative(self, basis: ModeSolverData) -> ModeSolverData: return mode_solver_data - def _get_colocation_coordinates(self) -> Dict[str, ArrayFloat1D]: + def _get_colocation_coordinates(self) -> dict[str, ArrayFloat1D]: """Get colocation coordinates in the solver plane. Returns: @@ -1278,7 +1276,7 @@ def sim_data(self) -> MODE_SIMULATION_DATA_TYPE: :class:`.SimulationData` object containing the effective index and mode fields. """ monitor_data = self.data - new_monitors = list(self.simulation.monitors) + [monitor_data.monitor] + new_monitors = (*self.simulation.monitors, monitor_data.monitor) new_simulation = self.simulation.copy(update=dict(monitors=new_monitors)) if isinstance(new_simulation, Simulation): return SimulationData(simulation=new_simulation, data=(monitor_data,)) @@ -1364,9 +1362,9 @@ def _solver_eps(self, freq: float) -> ArrayComplex4D: def _solve_all_freqs( self, - coords: Tuple[ArrayFloat1D, ArrayFloat1D], - symmetry: Tuple[Symmetry, Symmetry], - ) -> Tuple[List[float], List[Dict[str, ArrayComplex4D]], List[EpsSpecType]]: + coords: tuple[ArrayFloat1D, ArrayFloat1D], + symmetry: tuple[Symmetry, Symmetry], + ) -> tuple[list[float], list[dict[str, ArrayComplex4D]], list[EpsSpecType]]: """Call the mode solver at all requested frequencies.""" fields = [] @@ -1383,10 +1381,10 @@ def _solve_all_freqs( def _solve_all_freqs_relative( self, - coords: Tuple[ArrayFloat1D, ArrayFloat1D], - symmetry: Tuple[Symmetry, Symmetry], - basis_fields: List[Dict[str, ArrayComplex4D]], - ) -> Tuple[List[float], List[Dict[str, ArrayComplex4D]], List[EpsSpecType]]: + coords: tuple[ArrayFloat1D, ArrayFloat1D], + symmetry: tuple[Symmetry, Symmetry], + basis_fields: list[dict[str, ArrayComplex4D]], + ) -> tuple[list[float], list[dict[str, ArrayComplex4D]], list[EpsSpecType]]: """Call the mode solver at all requested frequencies.""" fields = [] @@ -1426,9 +1424,9 @@ def _postprocess_solver_fields(solver_fields, 
normal_axis, plane, mode_spec, coo def _solve_single_freq( self, freq: float, - coords: Tuple[ArrayFloat1D, ArrayFloat1D], - symmetry: Tuple[Symmetry, Symmetry], - ) -> Tuple[float, Dict[str, ArrayComplex4D], EpsSpecType]: + coords: tuple[ArrayFloat1D, ArrayFloat1D], + symmetry: tuple[Symmetry, Symmetry], + ) -> tuple[float, dict[str, ArrayComplex4D], EpsSpecType]: """Call the mode solver at a single frequency. The fields are rotated from propagation coordinates back to global coordinates. @@ -1479,10 +1477,10 @@ def _postprocess_solver_fields_inverse(self, fields): def _solve_single_freq_relative( self, freq: float, - coords: Tuple[ArrayFloat1D, ArrayFloat1D], - symmetry: Tuple[Symmetry, Symmetry], - basis_fields: Dict[str, ArrayComplex4D], - ) -> Tuple[float, Dict[str, ArrayComplex4D], EpsSpecType]: + coords: tuple[ArrayFloat1D, ArrayFloat1D], + symmetry: tuple[Symmetry, Symmetry], + basis_fields: dict[str, ArrayComplex4D], + ) -> tuple[float, dict[str, ArrayComplex4D], EpsSpecType]: """Call the mode solver at a single frequency. Modes are computed as linear combinations of ``basis_fields``. 
""" @@ -1518,7 +1516,7 @@ def _rotate_field_coords(field: FIELD, normal_axis: Axis, plane: MODE_PLANE_TYPE @staticmethod def _weighted_coord_max( array: ArrayFloat2D, u: ArrayFloat1D, v: ArrayFloat1D - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """2D argmax for an array weighted in both directions.""" if not np.all(np.isfinite(array)): # make sure the array is valid return 0, 0 @@ -1536,7 +1534,7 @@ def _weighted_coord_max( return i, j @staticmethod - def _inverted_gauge(e_field: FIELD, diff_coords: Tuple[ArrayFloat1D, ArrayFloat1D]) -> bool: + def _inverted_gauge(e_field: FIELD, diff_coords: tuple[ArrayFloat1D, ArrayFloat1D]) -> bool: """Check if the lower xy region of the mode has a negative sign.""" dx, dy = diff_coords e_x, e_y = e_field[:2, :, :, 0] @@ -1566,11 +1564,11 @@ def _inverted_gauge(e_field: FIELD, diff_coords: Tuple[ArrayFloat1D, ArrayFloat1 @staticmethod def _process_fields( mode_fields: ArrayComplex4D, - mode_index: pydantic.NonNegativeInt, + mode_index: NonNegativeInt, normal_axis: Axis, plane: MODE_PLANE_TYPE, - diff_coords: Tuple[ArrayFloat1D, ArrayFloat1D], - ) -> Tuple[FIELD, FIELD]: + diff_coords: tuple[ArrayFloat1D, ArrayFloat1D], + ) -> tuple[FIELD, FIELD]: """Transform solver fields to simulation axes and set gauge.""" # Separate E and H fields (in solver coordinates) @@ -1697,7 +1695,7 @@ def _is_tensorial(self) -> bool: return abs(self.mode_spec.angle_theta) > 0 or self._has_fully_anisotropic_media @cached_property - def _intersecting_media(self) -> List: + def _intersecting_media(self) -> list: """List of media (including simulation background) intersecting the mode plane.""" total_structures = [self.simulation.scene.background_structure] total_structures += list(self.simulation.structures) @@ -1756,8 +1754,8 @@ def to_source( self, source_time: SourceTime, direction: Direction = None, - mode_index: pydantic.NonNegativeInt = 0, - num_freqs: pydantic.PositiveInt = 1, + mode_index: NonNegativeInt = 0, + num_freqs: PositiveInt = 1, 
**kwargs, ) -> ModeSource: """Creates :class:`.ModeSource` from a :class:`ModeSolver` instance plus additional @@ -1794,13 +1792,13 @@ def to_source( **kwargs, ) - def to_monitor(self, freqs: List[float] = None, name: str = None) -> ModeMonitor: + def to_monitor(self, freqs: list[float] = None, name: str = None) -> ModeMonitor: """Creates :class:`ModeMonitor` from a :class:`ModeSolver` instance plus additional specifications. Parameters ---------- - freqs : List[float] + freqs : list[float] Frequencies to include in Monitor (Hz). If not specified, passes ``self.freqs``. name : str @@ -1865,7 +1863,7 @@ def sim_with_source( self, source_time: SourceTime, direction: Direction = None, - mode_index: pydantic.NonNegativeInt = 0, + mode_index: NonNegativeInt = 0, ) -> Simulation: """Creates :class:`Simulation` from a :class:`ModeSolver`. Creates a copy of the ModeSolver's original simulation with a ModeSource added corresponding to @@ -1898,7 +1896,7 @@ def sim_with_source( @require_fdtd_simulation def sim_with_monitor( self, - freqs: List[float] = None, + freqs: list[float] = None, name: str = None, ) -> Simulation: """Creates :class:`.Simulation` from a :class:`ModeSolver`. Creates a copy of @@ -1907,7 +1905,7 @@ def sim_with_monitor( Parameters ---------- - freqs : List[float] = None + freqs : list[float] = None Frequencies to include in Monitor (Hz). If not specified, uses the frequencies from the mode solver. name : str @@ -2203,7 +2201,7 @@ def plot_grid( ) @classmethod - def _plane_grid(cls, simulation: Simulation, plane: Box) -> Tuple[Coords, Coords]: + def _plane_grid(cls, simulation: Simulation, plane: Box) -> tuple[Coords, Coords]: """Plane grid for mode solver.""" # Get the mode plane normal axis, center, and limits. 
_, _, _, t_axes = cls._center_and_lims(simulation=simulation, plane=plane) @@ -2219,7 +2217,7 @@ def _plane_grid(cls, simulation: Simulation, plane: Box) -> Tuple[Coords, Coords @classmethod def _effective_num_pml( cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec - ) -> Tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]: + ) -> tuple[NonNegativeFloat, NonNegativeFloat]: """Number of cells of the mode solver pml.""" coord_0, coord_1 = cls._plane_grid(simulation=simulation, plane=plane) @@ -2233,9 +2231,9 @@ def _effective_num_pml( @classmethod def _pml_thickness( cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec - ) -> Tuple[ - Tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat], - Tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat], + ) -> tuple[ + tuple[NonNegativeFloat, NonNegativeFloat], + tuple[NonNegativeFloat, NonNegativeFloat], ]: """Thickness of the mode solver pml in the form ((plus0, minus0), (plus1, minus1)) @@ -2273,7 +2271,7 @@ def _pml_thickness( @classmethod def _mode_plane_size( cls, simulation: Simulation, plane: Box - ) -> Tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]: + ) -> tuple[NonNegativeFloat, NonNegativeFloat]: """The size of the mode plane intersected with the simulation.""" _, h_lim, v_lim, _ = cls._center_and_lims(simulation=simulation, plane=plane) return h_lim[1] - h_lim[0], v_lim[1] - v_lim[0] @@ -2281,7 +2279,7 @@ def _mode_plane_size( @classmethod def _mode_plane_size_no_pml( cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec - ) -> Tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]: + ) -> tuple[NonNegativeFloat, NonNegativeFloat]: """The size of the remaining portion of the mode plane, after the pml has been removed.""" size = cls._mode_plane_size(simulation=simulation, plane=plane) @@ -2364,7 +2362,7 @@ def _plot_pml( return ax @staticmethod - def _center_and_lims(simulation: Simulation, plane: Box) -> Tuple[List, List, List, List]: + def 
_center_and_lims(simulation: Simulation, plane: Box) -> tuple[list, list, list, list]: """Get the mode plane center and limits.""" normal_axis = plane.size.index(0.0) @@ -2456,8 +2454,8 @@ def reduced_simulation_copy(self): # extract sub-simulation removing everything irrelevant new_sim = self.simulation.subsection( region=new_sim_box, - monitors=[], - sources=[], + monitors=tuple(), + sources=tuple(), grid_spec="identical", boundary_spec=new_bspec, remove_outside_custom_mediums=True, @@ -2468,10 +2466,10 @@ def reduced_simulation_copy(self): ) # Let's only validate mode solver where geometry validation is skipped: geometry replaced by its bounding # box - structures = [ + structures = tuple( strc.updated_copy(geometry=strc.geometry.bounding_box, deep=False) for strc in new_sim.structures - ] + ) # skip validation as it's validated already in subsection aux_new_sim = new_sim.updated_copy(structures=structures, deep=False, validate=False) # validate mode solver here where geometry is replaced by its bounding box diff --git a/tidy3d/components/mode/simulation.py b/tidy3d/components/mode/simulation.py index 290bae48c5..880b285a9f 100644 --- a/tidy3d/components/mode/simulation.py +++ b/tidy3d/components/mode/simulation.py @@ -2,10 +2,10 @@ from __future__ import annotations -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, field_validator, model_validator from ...constants import C_0 from ...exceptions import SetupError, ValidationError @@ -19,13 +19,7 @@ from ..monitor import ModeMonitor, ModeSolverMonitor, PermittivityMonitor from ..simulation import AbstractYeeGridSimulation, Simulation, validate_boundaries_for_zero_dims from ..source.field import ModeSource -from ..types import ( - TYPE_TAG_STR, - Ax, - Direction, - EMField, - FreqArray, -) +from ..types import Ax, Direction, EMField, FreqArray, discriminated_union from ..validators import 
validate_mode_plane_radius from .mode_solver import ModeSolver @@ -35,7 +29,7 @@ # should be very small -- otherwise, generating tmesh will fail or take a long time RUN_TIME = 1e-30 -MODE_PLANE_TYPE = Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor] +MODE_PLANE_TYPE = discriminated_union(Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor]) # attributes shared between ModeSimulation class and ModeSolver class @@ -114,31 +108,31 @@ class ModeSimulation(AbstractYeeGridSimulation): * `Prelude to Integrated Photonics Simulation: Mode Injection `_ """ - mode_spec: ModeSpec = pd.Field( - ..., + mode_spec: ModeSpec = Field( title="Mode specification", description="Container with specifications about the modes to be solved for.", ) - freqs: FreqArray = pd.Field( - ..., title="Frequencies", description="A list of frequencies at which to solve." + freqs: FreqArray = Field( + title="Frequencies", + description="A list of frequencies at which to solve.", ) - direction: Direction = pd.Field( + direction: Direction = Field( "+", title="Propagation direction", description="Direction of waveguide mode propagation along the axis defined by its normal " "dimension.", ) - colocate: bool = pd.Field( + colocate: bool = Field( True, title="Colocate fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " "primal grid nodes). Default is ``True``.", ) - fields: Tuple[EMField, ...] = pd.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor. Note that some " @@ -146,8 +140,8 @@ class ModeSimulation(AbstractYeeGridSimulation): "like ``mode_area`` require all E-field components.", ) - boundary_spec: BoundarySpec = pd.Field( - BoundarySpec(), + boundary_spec: BoundarySpec = Field( + default_factory=BoundarySpec, title="Boundaries", description="Specification of boundary conditions along each dimension. 
If ``None``, " "PML boundary conditions are applied on all sides. This behavior is for " @@ -156,27 +150,27 @@ class ModeSimulation(AbstractYeeGridSimulation): "apply PML layers in the mode solver.", ) - monitors: Tuple[ModeSimulationMonitorType, ...] = pd.Field( + monitors: tuple[ModeSimulationMonitorType, ...] = Field( (), title="Monitors", description="Tuple of monitors in the simulation. " "Note: monitor names are used to access data after simulation is run.", ) - sources: Tuple[()] = pd.Field( + sources: tuple[()] = Field( (), title="Sources", description="Sources in the simulation. Note: sources are not supported in mode " "simulations.", ) - grid_spec: GridSpec = pd.Field( - GridSpec(), + grid_spec: GridSpec = Field( + default_factory=GridSpec, title="Grid Specification", description="Specifications for the simulation grid along each of the three directions.", ) - plane: MODE_PLANE_TYPE = pd.Field( + plane: Optional[MODE_PLANE_TYPE] = Field( None, title="Plane", description="Cross-sectional plane in which the mode will be computed. " @@ -184,62 +178,62 @@ class ModeSimulation(AbstractYeeGridSimulation): "the provided ``plane`` and the simulation geometry. 
" "If ``None``, the simulation must be 2D, and the plane will be the entire " "simulation geometry.", - discriminator=TYPE_TAG_STR, ) - @pd.validator("plane", always=True) - def is_plane(cls, val, values): + @field_validator("grid_spec") + def _validate_auto_grid_wavelength(val): + # abstract override, logic is handled in post-init to ensure freqs is defined + return val + + @field_validator("plane") + def _validate_planar(val): + if val.size.count(0.0) != 1: + raise ValidationError(f"'ModeSimulation.plane' must be planar, given 'size={val.size}'") + return val + + @model_validator(mode="before") + def is_plane(data): """Raise validation error if not planar.""" - if val is None: - sim_center = values.get("center") - sim_size = values.get("size") - val = Box(size=sim_size, center=sim_center) + if data.get("plane") is None: + val = Box(size=data.get("size"), center=data.get("center")) if val.size.count(0.0) != 1: raise ValidationError( "If the 'ModeSimulation' geometry is not planar, " "then 'plane' must be specified." 
) - return val - if val.size.count(0.0) != 1: - raise ValidationError(f"'ModeSimulation.plane' must be planar, given 'size={val}'") - return val + data["plane"] = val + return data - @pd.validator("plane", always=True) - def plane_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def plane_in_sim_bounds(self): """Check that the plane is at least partially inside the simulation bounds.""" - sim_center = values.get("center") - sim_size = values.get("size") - sim_box = Box(size=sim_size, center=sim_center) - - if not sim_box.intersects(val): + sim_box = Box(size=self.size, center=self.center) + if not sim_box.intersects(self.plane): raise SetupError("'ModeSimulation.plane' must intersect 'ModeSimulation.geometry.") - return val + return self - @pd.validator("boundary_spec", always=True) - def boundaries_for_zero_dims(cls, val, values): + @model_validator(mode="after") + def boundaries_for_zero_dims(self): """Replace with periodic boundary along zero-size dimensions.""" + val = self.boundary_spec boundaries = [val.x, val.y, val.z] - size = values.get("size") - - for dim, size_dim in enumerate(size): + for dim, size_dim in enumerate(self.size): if size_dim == 0: boundaries[dim] = Boundary.periodic() - - return BoundarySpec(x=boundaries[0], y=boundaries[1], z=boundaries[2]) - - def _post_init_validators(self) -> None: - """Call validators taking `self` that get run after init.""" - validate_mode_plane_radius( - mode_spec=self.mode_spec, plane=self.plane, msg_prefix="'ModeSimulation'" + boundary_spec = BoundarySpec(x=boundaries[0], y=boundaries[1], z=boundaries[2]) + object.__setattr__(self, "boundary_spec", boundary_spec) + return self + + @property + def _post_init_validators(self): + """Return validators taking `self` that get run after init.""" + return ( + lambda: validate_mode_plane_radius( + mode_spec=self.mode_spec, plane=self.plane, msg_prefix="'ModeSimulation'" + ), + lambda: self._mode_solver, + lambda: self.grid, ) - _ = self._mode_solver - _ 
= self.grid - - @pd.validator("grid_spec", always=True) - def _validate_auto_grid_wavelength(cls, val, values): - """Handle the case where grid_spec is auto and wavelength is not provided.""" - # this is handled instead post-init to ensure freqs is defined - return val @cached_property def _mode_solver(self) -> ModeSolver: @@ -286,14 +280,14 @@ def _as_fdtd_sim(self) -> Simulation: **kwargs, run_time=RUN_TIME, grid_spec=grid_spec, - monitors=[], + monitors=(), ) @classmethod def from_simulation( cls, simulation: AbstractYeeGridSimulation, - wavelength: Optional[pd.PositiveFloat] = None, + wavelength: Optional[PositiveFloat] = None, **kwargs, ) -> ModeSimulation: """Creates :class:`.ModeSimulation` from a :class:`.AbstractYeeGridSimulation`. @@ -302,7 +296,7 @@ def from_simulation( ---------- simulation: :class:`.AbstractYeeGridSimulation` Starting simulation defining structures, grid, etc. - wavelength: Optional[pd.PositiveFloat] + wavelength: Optional[PositiveFloat] Wavelength used for automatic grid generation. Required if auto grid is used in ``grid_spec``. **kwargs @@ -350,7 +344,7 @@ def reduced_simulation_copy(self) -> ModeSimulation: @classmethod def from_mode_solver( - cls, mode_solver: ModeSolver, wavelength: Optional[pd.PositiveFloat] = None + cls, mode_solver: ModeSolver, wavelength: Optional[PositiveFloat] = None ) -> ModeSimulation: """Creates :class:`.ModeSimulation` from a :class:`.ModeSolver`. @@ -358,7 +352,7 @@ def from_mode_solver( ---------- simulation: :class:`.AbstractYeeGridSimulation` Starting simulation defining structures, grid, etc. - wavelength: Optional[pd.PositiveFloat] + wavelength: Optional[PositiveFloat] Wavelength used for automatic grid generation. Required if auto grid is used in ``grid_spec``. 
diff --git a/tidy3d/components/mode/solver.py b/tidy3d/components/mode/solver.py index 2cd12ef4de..330c4ec5c8 100644 --- a/tidy3d/components/mode/solver.py +++ b/tidy3d/components/mode/solver.py @@ -1,7 +1,5 @@ """Mode solver for propagating EM modes.""" -from typing import Tuple - import numpy as np import scipy.linalg as linalg import scipy.sparse as sp @@ -9,7 +7,7 @@ from ...constants import C_0, ETA_0, fp_eps, pec_val from ..base import Tidy3dBaseModel -from ..types import EpsSpecType, ModeSolverType, Numpy +from ..types import EpsSpecType, ModeSolverType from .derivatives import create_d_matrices as d_mats from .derivatives import create_s_matrices as s_mats from .transforms import angled_transform, radial_transform @@ -50,7 +48,7 @@ def compute_modes( direction="+", solver_basis_fields=None, plane_center: tuple[float, float] = None, - ) -> Tuple[Numpy, Numpy, EpsSpecType]: + ) -> tuple[np.ndarray, np.ndarray, EpsSpecType]: """ Solve for the modes of a waveguide cross-section. @@ -60,7 +58,7 @@ def compute_modes( Either a single 2D array defining the relative permittivity in the cross-section, or nine 2D arrays defining the permittivity at the Ex, Ey, and Ez locations of the Yee grid in the order xx, xy, xz, yx, yy, yz, zx, zy, zz. - coords : List[Numpy] + coords : List[np.ndarray] Two 1D arrays with each with size one larger than the corresponding axis of ``eps_cross``. Defines a (potentially non-uniform) Cartesian grid on which the modes are computed. @@ -92,7 +90,7 @@ def compute_modes( Returns ------- - Tuple[Numpy, Numpy, str] + tuple[np.ndarray, np.ndarray, str] The first array gives the E and H fields for all modes, the second one gives the complex effective index. The last variable describes permittivity characterization on the mode solver's plane ("diagonal", "tensorial_real", or "tensorial_complex"). 
@@ -954,19 +952,19 @@ def set_initial_vec(cls, Nx, Ny, is_tensorial=False): return vec_init.flatten("F") @classmethod - def eigs_to_effective_index(cls, eig_list: Numpy, mode_solver_type: ModeSolverType): + def eigs_to_effective_index(cls, eig_list: np.ndarray, mode_solver_type: ModeSolverType): """Convert obtained eigenvalues to n_eff and k_eff. Parameters ---------- - eig_list : Numpy + eig_list : np.ndarray Array of eigenvalues mode_solver_type : ModeSolverType The type of mode solver problems Returns ------- - Tuple[Numpy, Numpy] + tuple[np.ndarray, np.ndarray] n_eff and k_eff """ if eig_list.size == 0: @@ -991,7 +989,7 @@ def format_medium_data(mat_data): the property at the E(H)x, E(H)y, and E(H)z locations of the Yee grid in the order xx, xy, xz, yx, yy, yz, zx, zy, zz. """ - if isinstance(mat_data, Numpy): + if isinstance(mat_data, np.ndarray): return (mat_data[i, :, :] for i in range(9)) if len(mat_data) == 9: return (np.copy(e) for e in mat_data) @@ -1044,6 +1042,6 @@ def mode_plane_contain_good_conductor(material_response) -> bool: return np.any(np.abs(material_response) > GOOD_CONDUCTOR_THRESHOLD * np.abs(pec_val)) -def compute_modes(*args, **kwargs) -> Tuple[Numpy, Numpy, str]: +def compute_modes(*args, **kwargs) -> tuple[np.ndarray, np.ndarray, str]: """A wrapper around ``EigSolver.compute_modes``, which is used in ``ModeSolver``.""" return EigSolver.compute_modes(*args, **kwargs) diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 4f63465bce..1169ec0bd7 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -1,16 +1,24 @@ """Defines specification for mode solver.""" from math import isclose -from typing import Tuple, Union +from typing import Literal, Optional, Union import numpy as np -import pydantic.v1 as pd - +from pydantic import ( + Field, + NonNegativeInt, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) + +from ..compat import Self from ..constants import 
GLANCING_CUTOFF, MICROMETER, RADIAN, fp_eps from ..exceptions import SetupError, ValidationError from ..log import log -from .base import Tidy3dBaseModel, skip_if_fields_missing -from .types import Axis2D, Literal, TrackFreq +from .base import Tidy3dBaseModel +from .types import Axis2D, TrackFreq GROUP_INDEX_STEP = 0.005 @@ -57,21 +65,25 @@ class ModeSpec(Tidy3dBaseModel): """ - num_modes: pd.PositiveInt = pd.Field( - 1, title="Number of modes", description="Number of modes returned by mode solver." + num_modes: PositiveInt = Field( + 1, + title="Number of modes", + description="Number of modes returned by mode solver.", ) - target_neff: pd.PositiveFloat = pd.Field( - None, title="Target effective index", description="Guess for effective index of the mode." + target_neff: Optional[PositiveFloat] = Field( + None, + title="Target effective index", + description="Guess for effective index of the mode.", ) - num_pml: Tuple[pd.NonNegativeInt, pd.NonNegativeInt] = pd.Field( + num_pml: tuple[NonNegativeInt, NonNegativeInt] = Field( (0, 0), title="Number of PML layers", description="Number of standard pml layers to add in the two tangential axes.", ) - filter_pol: Literal["te", "tm"] = pd.Field( + filter_pol: Optional[Literal["te", "tm"]] = Field( None, title="Polarization filtering", description="The solver always computes the ``num_modes`` modes closest to the given " @@ -87,14 +99,14 @@ class ModeSpec(Tidy3dBaseModel): "``tm``-fraction uses the E field component parallel to the second plane axis.", ) - angle_theta: float = pd.Field( + angle_theta: float = Field( 0.0, title="Polar Angle", description="Polar angle of the propagation axis from the injection axis.", units=RADIAN, ) - angle_phi: float = pd.Field( + angle_phi: float = Field( 0.0, title="Azimuth Angle", description="Azimuth angle of the propagation axis in the plane orthogonal to the " @@ -102,7 +114,7 @@ class ModeSpec(Tidy3dBaseModel): units=RADIAN, ) - precision: Literal["auto", "single", "double"] = 
pd.Field( + precision: Literal["auto", "single", "double"] = Field( "auto", title="single, double, or automatic precision in mode solver", description="The solver will be faster and using less memory under " @@ -111,7 +123,7 @@ class ModeSpec(Tidy3dBaseModel): "conductor, single precision otherwise.", ) - bend_radius: float = pd.Field( + bend_radius: Optional[float] = Field( None, title="Bend radius", description="A curvature radius for simulation of waveguide bends. Can be negative, in " @@ -120,7 +132,7 @@ class ModeSpec(Tidy3dBaseModel): units=MICROMETER, ) - bend_axis: Axis2D = pd.Field( + bend_axis: Optional[Axis2D] = Field( None, title="Bend axis", description="Index into the two tangential axes defining the normal to the " @@ -129,20 +141,20 @@ class ModeSpec(Tidy3dBaseModel): "yz plane, the ``bend_axis`` is always 1 (the global z axis).", ) - angle_rotation: bool = pd.Field( + angle_rotation: bool = Field( False, - title="Use fields rotation when angle_theta is not zero", - description="Defines how modes are computed when angle_theta is not zero. " - "If 'False', a coordinate transformation is applied through the permittivity and permeability tensors." - "If 'True', the structures in the simulation are first rotated to compute a mode solution at " + title="Use fields rotation when ``angle_theta`` is not zero", + description="Defines how modes are computed when ``angle_theta`` is not zero. " + "If ``False``, a coordinate transformation is applied through the permittivity and permeability tensors." + "If ``True``, the structures in the simulation are first rotated to compute a mode solution at " "a reference plane normal to the structure's azimuthal direction. Then, the fields are rotated " - "to align with the mode plane, using the 'n_eff' calculated at the reference plane. The second option can " + "to align with the mode plane, using the ``n_eff`` calculated at the reference plane. 
The second option can " "produce more accurate results, but more care must be taken, for example, in ensuring that the " "original mode plane intersects the correct geometries in the simulation with rotated structures. " - "Note: currently only supported when 'angle_phi' is a multiple of 'np.pi'.", + "Note: currently only supported when ``angle_phi`` is a multiple of ``np.pi``.", ) - track_freq: Union[TrackFreq, None] = pd.Field( + track_freq: Union[TrackFreq, None] = Field( "central", title="Mode Tracking Frequency", description="Parameter that turns on/off mode tracking based on their similarity. " @@ -151,7 +163,7 @@ class ModeSpec(Tidy3dBaseModel): "If ``None`` no mode tracking is performed.", ) - group_index_step: Union[pd.PositiveFloat, bool] = pd.Field( + group_index_step: Union[PositiveFloat, bool] = Field( False, title="Frequency step for group index computation", description="Control the computation of the group index alongside the effective index. If " @@ -160,75 +172,74 @@ class ModeSpec(Tidy3dBaseModel): f"default of {GROUP_INDEX_STEP} is used.", ) - @pd.validator("bend_axis", always=True) - @skip_if_fields_missing(["bend_radius"]) - def bend_axis_given(cls, val, values): - """Check that ``bend_axis`` is provided if ``bend_radius`` is not ``None``""" - if val is None and values.get("bend_radius") is not None: - raise SetupError("'bend_axis' must also be defined if 'bend_radius' is defined.") + @field_validator("group_index_step", mode="before") + def _validate_group_index_step_default(val): + """If ``True``, replace with default fractional step.""" + if val is True: + return GROUP_INDEX_STEP return val - @pd.validator("bend_radius", always=True) - def bend_radius_not_zero(cls, val, values): - """Check that ``bend_raidus`` magnitude is not close to zero.`""" - if val is not None and isclose(val, 0): - raise SetupError("The magnitude of 'bend_radius' must be larger than 0.") + @field_validator("group_index_step") + def 
_validate_group_index_step_size(val): + """Ensure group-index step is < 1.""" + if val is not False and val >= 1: + raise ValidationError( + "Parameter 'group_index_step' must be a fractional value less than 1." + ) return val - @pd.validator("angle_theta", allow_reuse=True, always=True) - def glancing_incidence(cls, val): - """Warn if close to glancing incidence.""" - if np.abs(np.pi / 2 - val) < GLANCING_CUTOFF: + @field_validator("bend_radius") + def _validate_bend_radius_not_zero(v): + """`bend_radius` magnitude must be non-zero.""" + if v is not None and isclose(v, 0): + raise SetupError("The magnitude of 'bend_radius' must be larger than 0.") + return v + + @field_validator("angle_theta") + def _validate_angle_theta_glancing(val): + """Disallow incidence too close to glancing.""" + if abs(np.pi / 2 - val) < GLANCING_CUTOFF: raise SetupError( "Mode propagation axis too close to glancing angle for accurate injection. " "For best results, switch the injection axis." ) return val - # Must be executed before type validation by pydantic, otherwise True is converted to 1.0 - @pd.validator("group_index_step", pre=True) - def assign_default_on_true(cls, val): - """Assign the default fractional frequency step value if not provided.""" - if val is True: - return GROUP_INDEX_STEP - return val + @model_validator(mode="after") + def _check_bend_axis_given(self) -> Self: + """``bend_axis`` must be provided when ``bend_radius`` is set.""" + if self.bend_radius is not None and self.bend_axis is None: + raise SetupError("'bend_axis' must also be defined if 'bend_radius' is defined.") + return self - @pd.validator("group_index_step") - def check_group_step_size(cls, val): - """Ensure a reasonable group index step is used.""" - if val >= 1: + @model_validator(mode="after") + def _check_angle_rotation_with_phi(self) -> Self: + """``angle_rotation`` requires ``angle_phi`` % (π/2) == 0.""" + if self.angle_rotation and not isclose(self.angle_phi % (np.pi / 2), 0): raise 
ValidationError( - "Parameter 'group_index_step' is a fractional value. It must be less than 1." + "'angle_phi' must be a multiple of 'π/2' when 'angle_rotation' is enabled." + ) + return self + + @model_validator(mode="after") + def _check_precision(self) -> Self: + """Verify critical ``ModeSpec`` settings for group index calculation.""" + if self.group_index_step is False: + return self + + if self.track_freq is None: + log.warning( + "Group index calculation without mode tracking can lead to incorrect results " + "around mode crossings. Consider setting 'track_freq' to 'central'." ) - return val - @pd.root_validator(skip_on_failure=True) - def check_precision(cls, values): - """Verify critical ModeSpec settings for group index calculation.""" - if values["group_index_step"] > 0: - if values["track_freq"] is None: - log.warning( - "Group index calculation without mode tracking can lead to incorrect results " - "around mode crossings. Consider setting 'track_freq' to 'central'." - ) - - # multiply by 5 to be safe - if values["group_index_step"] < 5 * fp_eps and values["precision"] != "double": - log.warning( - "Group index step is too small! " - "The results might be fully corrupted by numerical errors. " - "For more accurate results, please consider using 'double' precision, " - "or increasing the value of 'group_index_step'." - ) - - return values - - @pd.validator("angle_rotation") - def angle_rotation_with_phi(cls, val, values): - """Currently ``angle_rotation`` is only supported with ``angle_phi % (np.pi / 2) == 0``.""" - if val and not isclose(values["angle_phi"] % (np.pi / 2), 0): - raise ValidationError( - "Parameter 'angle_phi' must be a multiple of 'np.pi / 2' when 'angle_rotation' is " - "enabled." + # multiply by 5 to be safe + if self.group_index_step < 5 * fp_eps and self.precision != "double": + log.warning( + "Group index step is too small! " + "The results might be fully corrupted by numerical errors. 
" + "For more accurate results, please consider using 'double' precision, " + "or increasing the value of 'group_index_step'." ) - return val + + return self diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index e16fc13c70..61a3ea7fb2 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -1,16 +1,22 @@ """Objects that define how data is recorded from simulation.""" from abc import ABC, abstractmethod -from typing import Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pydantic +from pydantic import ( + Field, + NonNegativeFloat, + PositiveInt, + field_validator, + model_validator, +) from ..constants import HERTZ, MICROMETER, RADIAN, SECOND, inf from ..exceptions import SetupError, ValidationError from ..log import log from .apodization import ApodizationSpec -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .base_sim.monitor import AbstractMonitor from .medium import MediumType from .mode_spec import ModeSpec @@ -48,7 +54,7 @@ class Monitor(AbstractMonitor): """Abstract base class for monitors.""" - interval_space: Tuple[Literal[1], Literal[1], Literal[1]] = pydantic.Field( + interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field( (1, 1, 1), title="Spatial Interval", description="Number of grid step intervals between monitor recordings. If equal to 1, " @@ -57,7 +63,7 @@ class Monitor(AbstractMonitor): "Not all monitors support values different from 1.", ) - colocate: Literal[True] = pydantic.Field( + colocate: Literal[True] = Field( True, title="Colocate Fields", description="Defines whether fields are colocated to grid cell boundaries (i.e. 
to the " @@ -83,15 +89,14 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class FreqMonitor(Monitor, ABC): """:class:`Monitor` that records data in the frequency-domain.""" - freqs: FreqArray = pydantic.Field( - ..., + freqs: FreqArray = Field( title="Frequencies", description="Array or list of frequencies stored by the field monitor.", units=HERTZ, ) - apodization: ApodizationSpec = pydantic.Field( - ApodizationSpec(), + apodization: ApodizationSpec = Field( + default_factory=ApodizationSpec, title="Apodization Specification", description="Sets parameters of (optional) apodization. Apodization applies a windowing " "function to the Fourier transform of the time-domain fields into frequency-domain ones, " @@ -103,13 +108,13 @@ class FreqMonitor(Monitor, ABC): _freqs_not_empty = validate_freqs_not_empty() _freqs_lower_bound = validate_freqs_min() - @pydantic.validator("freqs", always=True) - def _warn_num_freqs(cls, val, values): + @field_validator("freqs") + def _warn_num_freqs(val, info): """Warn if number of frequencies is too large.""" if len(val) > WARN_NUM_FREQS: log.warning( f"A large number ({len(val)}) of frequencies detected in monitor " - f"'{values['name']}'. This can lead to solver slow-down and increased cost. " + f"'{info.field_name}'. This can lead to solver slow-down and increased cost. " "Consider decreasing the number of frequencies in the monitor. This may become a " "hard limit in future Tidy3D versions.", custom_loc=["freqs"], @@ -122,7 +127,7 @@ def frequency_range(self) -> FreqBound: Returns ------- - Tuple[float, float] + tuple[float, float] Minimum and maximum frequencies of the frequency array. 
""" return (min(self.freqs), max(self.freqs)) @@ -131,14 +136,14 @@ def frequency_range(self) -> FreqBound: class TimeMonitor(Monitor, ABC): """:class:`Monitor` that records data in the time-domain.""" - start: pydantic.NonNegativeFloat = pydantic.Field( + start: NonNegativeFloat = Field( 0.0, title="Start Time", description="Time at which to start monitor recording.", units=SECOND, ) - stop: pydantic.NonNegativeFloat = pydantic.Field( + stop: Optional[NonNegativeFloat] = Field( None, title="Stop Time", description="Time at which to stop monitor recording. " @@ -146,7 +151,7 @@ class TimeMonitor(Monitor, ABC): units=SECOND, ) - interval: pydantic.PositiveInt = pydantic.Field( + interval: Optional[PositiveInt] = Field( None, title="Time Interval", description="Sampling rate of the monitor: number of time steps between each measurement. " @@ -155,14 +160,14 @@ class TimeMonitor(Monitor, ABC): "This can be useful for reducing data storage as needed by the application.", ) - @pydantic.validator("interval", always=True) - @skip_if_fields_missing(["start", "stop"]) - def _warn_interval_default(cls, val, values): + @model_validator(mode="after") + def _warn_interval_default(self): """If all defaults used for time sampler, warn and set ``interval=1`` internally.""" + val = self.interval if val is None: - start = values.get("start") - stop = values.get("stop") + start = self.start + stop = self.stop if start == 0.0 and stop is None: log.warning( "The monitor 'interval' field was left as its default value, " @@ -178,20 +183,20 @@ def _warn_interval_default(cls, val, values): ) # set 'interval = 1' for backwards compatibility - val = 1 + object.__setattr__(self, "interval", 1) - return val + return self - @pydantic.validator("stop", always=True, allow_reuse=True) - @skip_if_fields_missing(["start"]) - def stop_greater_than_start(cls, val, values): + @model_validator(mode="after") + def stop_greater_than_start(self): """Ensure sure stop is greater than or equal to start.""" - 
start = values.get("start") - if val and val < start: + stop = self.stop + start = self.start + if stop and stop < start: raise SetupError("Monitor start time is greater than stop time.") - return val + return self - def time_inds(self, tmesh: ArrayFloat1D) -> Tuple[int, int]: + def time_inds(self, tmesh: ArrayFloat1D) -> tuple[int, int]: """Compute the starting and stopping index of the monitor in a given discrete time mesh.""" tmesh = np.array(tmesh) @@ -230,23 +235,21 @@ def num_steps(self, tmesh: ArrayFloat1D) -> int: class AbstractFieldMonitor(Monitor, ABC): """:class:`Monitor` that records electromagnetic field data as a function of x,y,z.""" - fields: Tuple[EMField, ...] = pydantic.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor.", ) - interval_space: Tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals between monitor recordings. If equal to 1, " - "there will be no downsampling. If greater than 1, the step will be applied, but the " - "first and last point of the monitor grid are always included.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals between monitor recordings. If equal to 1, " + "there will be no downsampling. If greater than 1, the step will be applied, but the " + "first and last point of the monitor grid are always included.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. 
" @@ -278,24 +281,22 @@ class AbstractAuxFieldMonitor(Monitor, ABC): :class:`.TwoPhotonAbsorption` uses `Nfx`, `Nfy`, and `Nfz` for the free-carrier density.""" - fields: Tuple[AuxField, ...] = pydantic.Field( + fields: tuple[AuxField, ...] = Field( (), title="Aux Field Components", description="Collection of auxiliary field components to store in the monitor. " "Auxiliary fields which are not present in the simulation will be zero.", ) - interval_space: Tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals between monitor recordings. If equal to 1, " - "there will be no downsampling. If greater than 1, the step will be applied, but the " - "first and last point of the monitor grid are always included.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals between monitor recordings. If equal to 1, " + "there will be no downsampling. If greater than 1, the step will be applied, but the " + "first and last point of the monitor grid are always included.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. 
" @@ -330,19 +331,19 @@ def normal_axis(self) -> Axis: class AbstractModeMonitor(PlanarMonitor, FreqMonitor): """:class:`Monitor` that records mode-related data.""" - mode_spec: ModeSpec = pydantic.Field( - ModeSpec(), + mode_spec: ModeSpec = Field( + default_factory=ModeSpec, title="Mode Specification", description="Parameters to feed to mode solver which determine modes measured by monitor.", ) - store_fields_direction: Direction = pydantic.Field( + store_fields_direction: Optional[Direction] = Field( None, title="Store Fields", description="Propagation direction for the mode field profiles stored from mode solving.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " @@ -385,7 +386,7 @@ def plot( return ax @cached_property - def _dir_arrow(self) -> Tuple[float, float, float]: + def _dir_arrow(self) -> tuple[float, float, float]: """Source direction normal vector in cartesian coordinates.""" dx = np.cos(self.mode_spec.angle_phi) * np.sin(self.mode_spec.angle_theta) dy = np.sin(self.mode_spec.angle_phi) * np.sin(self.mode_spec.angle_theta) @@ -401,13 +402,13 @@ def _bend_axis(self) -> Axis: direction = self.unpop_axis(0, in_plane, axis=self.normal_axis) return direction.index(1) - @pydantic.validator("mode_spec", always=True) - def _warn_num_modes(cls, val, values): + @field_validator("mode_spec") + def _warn_num_modes(val, info): """Warn if number of modes is too large.""" if val.num_modes > WARN_NUM_MODES: log.warning( f"A large number ({val.num_modes}) of modes requested in monitor " - f"'{values['name']}'. This can lead to solver slow-down and increased cost. " + f"'{info.field_name}'. This can lead to solver slow-down and increased cost. " "Consider decreasing the number of modes and using 'ModeSpec.target_neff' " "to target the modes of interest. 
This may become a hard limit in future " "Tidy3D versions.", @@ -558,25 +559,23 @@ class PermittivityMonitor(FreqMonitor): ... name='eps_monitor') """ - colocate: Literal[False] = pydantic.Field( + colocate: Literal[False] = Field( False, title="Colocate Fields", description="Colocation turned off, since colocated permittivity values do not have a " "physical meaning - they do not correspond to the subpixel-averaged ones.", ) - interval_space: Tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals between monitor recordings. If equal to 1, " - "there will be no downsampling. If greater than 1, the step will be applied, but the " - "first and last point of the monitor grid are always included.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals between monitor recordings. If equal to 1, " + "there will be no downsampling. If greater than 1, the step will be applied, but the " + "first and last point of the monitor grid are always included.", ) - apodization: ApodizationSpec = pydantic.Field( - ApodizationSpec(), + apodization: ApodizationSpec = Field( + default_factory=ApodizationSpec, title="Apodization Specification", description="This field is ignored in this monitor.", ) @@ -591,7 +590,7 @@ class SurfaceIntegrationMonitor(Monitor, ABC): """Abstract class for monitors that perform surface integrals during the solver run, as in flux and near to far transformations.""" - normal_dir: Direction = pydantic.Field( + normal_dir: Optional[Direction] = Field( None, title="Normal Vector Orientation", description="Direction of the surface monitor's normal vector w.r.t. 
" @@ -599,7 +598,7 @@ class SurfaceIntegrationMonitor(Monitor, ABC): "Applies to surface monitors only, and defaults to ``'+'`` if not provided.", ) - exclude_surfaces: Tuple[BoxSurface, ...] = pydantic.Field( + exclude_surfaces: Optional[tuple[BoxSurface, ...]] = Field( None, title="Excluded Surfaces", description="Surfaces to exclude in the integration, if a volume monitor.", @@ -612,38 +611,35 @@ def integration_surfaces(self): return self.surfaces_with_exclusion(**self.dict()) return [self] - @pydantic.root_validator(skip_on_failure=True) - def normal_dir_exists_for_surface(cls, values): + @model_validator(mode="after") + def normal_dir_exists_for_surface(self): """If the monitor is a surface, set default ``normal_dir`` if not provided. If the monitor is a box, warn that ``normal_dir`` is relevant only for surfaces.""" - normal_dir = values.get("normal_dir") - name = values.get("name") - size = values.get("size") - if size.count(0.0) != 1: - if normal_dir is not None: + if self.size.count(0.0) != 1: + if self.normal_dir is not None: log.warning( "The ``normal_dir`` field is relevant only for surface monitors " - f"and will be ignored for monitor {name}, which is a box." + f"and will be ignored for monitor {self.name}, which is a box." 
) else: - if normal_dir is None: - values["normal_dir"] = "+" - return values + if self.normal_dir is None: + object.__setattr__(self, "normal_dir", "+") + return self - @pydantic.root_validator(skip_on_failure=True) - def check_excluded_surfaces(cls, values): + @model_validator(mode="after") + def check_excluded_surfaces(self): """Error if ``exclude_surfaces`` is provided for a surface monitor.""" - exclude_surfaces = values.get("exclude_surfaces") + exclude_surfaces = self.exclude_surfaces if exclude_surfaces is None: - return values - name = values.get("name") - size = values.get("size") + return self + name = self.name + size = self.size if size.count(0.0) > 0: raise SetupError( f"Can't specify ``exclude_surfaces`` for surface monitor {name}; " "valid for box monitors only." ) - return values + return self def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of intermediate data recorded by the monitor during a solver run.""" @@ -789,14 +785,14 @@ class ModeSolverMonitor(AbstractModeMonitor): ... name='mode_monitor') """ - direction: Direction = pydantic.Field( + direction: Direction = Field( "+", title="Propagation Direction", description="Direction of waveguide mode propagation along the axis defined by its normal " "dimension.", ) - fields: Tuple[EMField, ...] = pydantic.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor. 
Note that some " @@ -804,19 +800,19 @@ class ModeSolverMonitor(AbstractModeMonitor): "like ``mode_area`` require all E-field components.", ) - @pydantic.root_validator(skip_on_failure=True) - def set_store_fields(cls, values): + @model_validator(mode="after") + def set_store_fields(self): """Ensure 'store_fields_direction' is compatible with 'direction'.""" - store_fields_direction = values["store_fields_direction"] - direction = values["direction"] + store_fields_direction = self.store_fields_direction + direction = self.direction if store_fields_direction is None: - values["store_fields_direction"] = direction + object.__setattr__(self, "store_fields_direction", direction) elif store_fields_direction != direction: raise ValidationError( f"The values of 'direction' ({direction}) and 'store_fields_direction' " f"({store_fields_direction}) must be equal." ) - return values + return self def storage_size(self, num_cells: int, tmesh: int) -> int: """Size of monitor storage given the number of points after discretization.""" @@ -840,14 +836,12 @@ class FieldProjectionSurface(Tidy3dBaseModel): * `Performing near field to far field projections <../../notebooks/FieldProjections.html>`_ """ - monitor: FieldMonitor = pydantic.Field( - ..., + monitor: FieldMonitor = Field( title="Field Monitor", description=":class:`.FieldMonitor` on which near fields will be sampled and integrated.", ) - normal_dir: Direction = pydantic.Field( - ..., + normal_dir: Direction = Field( title="Normal Vector Orientation", description=":class:`.Direction` of the surface monitor's normal vector w.r.t.\ the positive x, y or z unit vectors. 
Must be one of '+' or '-'.", @@ -859,8 +853,8 @@ def axis(self) -> Axis: # assume that the monitor's axis is in the direction where the monitor is thinnest return self.monitor.size.index(0.0) - @pydantic.validator("monitor", always=True) - def is_plane(cls, val): + @field_validator("monitor") + def is_plane(val): """Ensures that the monitor is a plane, i.e., its ``size`` attribute has exactly 1 zero""" size = val.size if size.count(0.0) != 1: @@ -873,7 +867,7 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): and projects them to a given set of observation points. """ - custom_origin: Coordinate = pydantic.Field( + custom_origin: Optional[Coordinate] = Field( None, title="Local Origin", description="Local origin used for defining observation points. If ``None``, uses the " @@ -881,7 +875,7 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): units=MICROMETER, ) - far_field_approx: bool = pydantic.Field( + far_field_approx: bool = Field( True, title="Far Field Approximation", description="Whether to enable the far field approximation when projecting fields. " @@ -891,22 +885,20 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): "in the far field of the device.", ) - interval_space: Tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals at which near fields are recorded for " - "projection to the far field, along each direction. If equal to 1, there will be no " - "downsampling. If greater than 1, the step will be applied, but the first and last " - "point of the monitor grid are always included. 
Using values greater than 1 can " - "help speed up server-side far field projections with minimal accuracy loss, " - "especially in cases where it is necessary for the grid resolution to be high for " - "the FDTD simulation, but such a high resolution is unnecessary for the purpose of " - "projecting the recorded near fields to the far field.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals at which near fields are recorded for " + "projection to the far field, along each direction. If equal to 1, there will be no " + "downsampling. If greater than 1, the step will be applied, but the first and last " + "point of the monitor grid are always included. Using values greater than 1 can " + "help speed up server-side far field projections with minimal accuracy loss, " + "especially in cases where it is necessary for the grid resolution to be high for " + "the FDTD simulation, but such a high resolution is unnecessary for the purpose of " + "projecting the recorded near fields to the far field.", ) - window_size: Tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat] = pydantic.Field( + window_size: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, 0), title="Spatial filtering window size", description="Size of the transition region of the windowing function used to ensure that " @@ -923,7 +915,7 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): "and otherwise must remain (0, 0).", ) - medium: MediumType = pydantic.Field( + medium: Optional[MediumType] = Field( None, title="Projection medium", description="Medium through which to project fields. 
Generally, the fields should be " @@ -933,12 +925,12 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): "non-default ``medium``.", ) - @pydantic.validator("window_size", always=True) - @skip_if_fields_missing(["size", "name"]) - def window_size_for_surface(cls, val, values): + @model_validator(mode="after") + def window_size_for_surface(self): """Ensures that windowing is applied for surface monitors only.""" - size = values.get("size") - name = values.get("name") + val = self.window_size + size = self.size + name = self.name if size.count(0.0) != 1: if val != (0, 0): @@ -946,22 +938,20 @@ def window_size_for_surface(cls, val, values): f"A non-zero 'window_size' cannot be used for projection monitor '{name}'. " "Windowing can be applied only for surface projection monitors." ) - return val + return self - @pydantic.validator("window_size", always=True) - @skip_if_fields_missing(["name"]) - def window_size_leq_one(cls, val, values): + @field_validator("window_size") + def window_size_leq_one(val, info): """Ensures that each component of the window size is less than or equal to 1.""" - name = values.get("name") if val[0] > 1 or val[1] > 1: raise ValidationError( - f"Each component of 'window_size' for monitor '{name}' " + f"Each component of 'window_size' for monitor '{info.field_name}' " "must be less than or equal to 1." 
) return val @property - def projection_surfaces(self) -> Tuple[FieldProjectionSurface, ...]: + def projection_surfaces(self) -> tuple[FieldProjectionSurface, ...]: """Surfaces of the monitor where near fields will be recorded for subsequent projection.""" surfaces = self.integration_surfaces return [ @@ -985,7 +975,7 @@ def local_origin(self) -> Coordinate: return self.center return self.custom_origin - def window_parameters(self, custom_bounds: Bound = None) -> Tuple[Size, Coordinate, Coordinate]: + def window_parameters(self, custom_bounds: Bound = None) -> tuple[Size, Coordinate, Coordinate]: """Return the physical size of the window transition region based on the monitor's size and optional custom bounds (useful in case the monitor has infinite dimensions). The window size is returned in 3D. Also returns the coordinate where the transition region beings on @@ -1147,23 +1137,21 @@ class FieldProjectionAngleMonitor(AbstractFieldProjectionMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_: For far field projections in the context of perdiodic boundary conditions. 
""" - proj_distance: float = pydantic.Field( + proj_distance: float = Field( 1e6, title="Projection Distance", description="Radial distance of the projection points from ``local_origin``.", units=MICROMETER, ) - theta: ObsGridArray = pydantic.Field( - ..., + theta: ObsGridArray = Field( title="Polar Angles", description="Polar angles with respect to the global z axis, relative to the location of " "``local_origin``, at which to project fields.", units=RADIAN, ) - phi: ObsGridArray = pydantic.Field( - ..., + phi: ObsGridArray = Field( title="Azimuth Angles", description="Azimuth angles with respect to the global z axis, relative to the location of " "``local_origin``, at which to project fields.", @@ -1206,13 +1194,13 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: self.freqs ) * 6 + BYTES_REAL * len(self.freqs) - @pydantic.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): @@ -1319,13 +1307,12 @@ class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - proj_axis: Axis = pydantic.Field( - ..., + proj_axis: Axis = Field( title="Projection Plane Axis", description="Axis along which the observation plane is oriented.", ) - proj_distance: float = pydantic.Field( + proj_distance: float = Field( 1e6, title="Projection Distance", description="Signed distance of the projection plane along ``proj_axis``. 
" @@ -1333,8 +1320,7 @@ class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): units=MICROMETER, ) - x: ObsGridArray = pydantic.Field( - ..., + x: ObsGridArray = Field( title="Local x Observation Coordinates", description="Local x observation coordinates w.r.t. ``local_origin`` and ``proj_axis``. " "When ``proj_axis`` is 0, this corresponds to the global y axis. " @@ -1343,8 +1329,7 @@ class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): units=MICROMETER, ) - y: ObsGridArray = pydantic.Field( - ..., + y: ObsGridArray = Field( title="Local y Observation Coordinates", description="Local y observation coordinates w.r.t. ``local_origin`` and ``proj_axis``. " "When ``proj_axis`` is 0, this corresponds to the global z axis. " @@ -1428,21 +1413,19 @@ class FieldProjectionKSpaceMonitor(AbstractFieldProjectionMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - proj_axis: Axis = pydantic.Field( - ..., + proj_axis: Axis = Field( title="Projection Plane Axis", description="Axis along which the observation plane is oriented.", ) - proj_distance: float = pydantic.Field( + proj_distance: float = Field( 1e6, title="Projection Distance", description="Radial distance of the projection points from ``local_origin``.", units=MICROMETER, ) - ux: ObsGridArray = pydantic.Field( - ..., + ux: ObsGridArray = Field( title="Normalized kx", description="Local x component of wave vectors on the observation plane, " "relative to ``local_origin`` and oriented with respect to ``proj_axis``, " @@ -1450,8 +1433,7 @@ class FieldProjectionKSpaceMonitor(AbstractFieldProjectionMonitor): "associated with the background medium. 
Must be in the range [-1, 1].", ) - uy: ObsGridArray = pydantic.Field( - ..., + uy: ObsGridArray = Field( title="Normalized ky", description="Local y component of wave vectors on the observation plane, " "relative to ``local_origin`` and oriented with respect to ``proj_axis``, " @@ -1459,17 +1441,17 @@ class FieldProjectionKSpaceMonitor(AbstractFieldProjectionMonitor): "associated with the background medium. Must be in the range [-1, 1].", ) - @pydantic.root_validator() - def reciprocal_vector_range(cls, values): + @model_validator(mode="after") + def reciprocal_vector_range(self): """Ensure that ux, uy are in [-1, 1].""" - maxabs_ux = max(list(values.get("ux")), key=abs) - maxabs_uy = max(list(values.get("uy")), key=abs) - name = values.get("name") + maxabs_ux = max(list(self.ux), key=abs) + maxabs_uy = max(list(self.uy), key=abs) + name = self.name if maxabs_ux > 1: raise SetupError(f"Entries of 'ux' must lie in the range [-1, 1] for monitor {name}.") if maxabs_uy > 1: raise SetupError(f"Entries of 'uy' must lie in the range [-1, 1] for monitor {name}.") - return values + return self def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of monitor storage given the number of points after discretization.""" @@ -1499,7 +1481,7 @@ class DiffractionMonitor(PlanarMonitor, FreqMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - normal_dir: Direction = pydantic.Field( + normal_dir: Direction = Field( "+", title="Normal Vector Orientation", description="Direction of the surface monitor's normal vector w.r.t. " @@ -1507,7 +1489,7 @@ class DiffractionMonitor(PlanarMonitor, FreqMonitor): "Defaults to ``'+'`` if not provided.", ) - colocate: Literal[False] = pydantic.Field( + colocate: Literal[False] = Field( False, title="Colocate Fields", description="Defines whether fields are colocated to grid cell boundaries (i.e. 
to the " @@ -1515,8 +1497,8 @@ class DiffractionMonitor(PlanarMonitor, FreqMonitor): "monitors depending on their specific function.", ) - @pydantic.validator("size", always=True) - def diffraction_monitor_size(cls, val): + @field_validator("size") + def diffraction_monitor_size(val): """Ensure that the monitor is infinite in the transverse direction.""" if val.count(inf) != 2: raise SetupError( diff --git a/tidy3d/components/parameter_perturbation.py b/tidy3d/components/parameter_perturbation.py index 20b4e5d8b4..0d6c8028dd 100644 --- a/tidy3d/components/parameter_perturbation.py +++ b/tidy3d/components/parameter_perturbation.py @@ -1,27 +1,25 @@ """Defines perturbations to properties of the medium / materials""" -from __future__ import annotations - import functools from abc import ABC, abstractmethod -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union -try: - import matplotlib.pyplot as plt -except ImportError: - pass import numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field, NonNegativeFloat, model_validator -from ..components.data.validators import validate_no_nans -from ..components.types import TYPE_TAG_STR, ArrayLike, Ax, Complex, FieldVal, InterpMethod -from ..components.viz import add_ax_if_none +from ..compat import Self from ..constants import C_0, CMCUBE, EPSILON_0, HERTZ, KELVIN, PERCMCUBE, inf from ..exceptions import DataError from ..log import log from .base import Tidy3dBaseModel, cached_property -from .data.data_array import ChargeDataArray, HeatDataArray, IndexedDataArray, SpatialDataArray +from .data.data_array import ( + ChargeDataArray, + HeatDataArray, + IndexedDataArray, + PerturbationCoefficientDataArray, + SpatialDataArray, +) from .data.unstructured.base import UnstructuredGridDataset from .data.utils import ( CustomSpatialDataType, @@ -29,6 +27,17 @@ _get_numpy_array, _zeros_like, ) +from .data.validators import validate_no_nans +from .types import ( 
+ ArrayComplex, + ArrayFloat, + Ax, + Complex, + FieldVal, + InterpMethod, + discriminated_union, +) +from .viz import add_ax_if_none """ Generic perturbation classes """ @@ -38,7 +47,7 @@ class AbstractPerturbation(ABC, Tidy3dBaseModel): @cached_property @abstractmethod - def perturbation_range(self) -> Union[Tuple[float, float], Tuple[Complex, Complex]]: + def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Complex]]: """Perturbation range.""" @cached_property @@ -47,7 +56,7 @@ def is_complex(self) -> bool: """Whether perturbation is complex valued.""" @staticmethod - def _linear_range(interval: Tuple[float, float], ref: float, coeff: Union[float, Complex]): + def _linear_range(interval: tuple[float, float], ref: float, coeff: Union[float, Complex]): """Find value range for a linear perturbation.""" if coeff in (0, 0j): # to avoid 0*inf return np.array([0, 0]) @@ -55,8 +64,8 @@ def _linear_range(interval: Tuple[float, float], ref: float, coeff: Union[float, @staticmethod def _get_val( - field: Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], val: FieldVal - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + field: Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], val: FieldVal + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Get specified value from a field.""" if val == "real": @@ -85,19 +94,19 @@ def _get_val( def ensure_temp_in_range( sample: Callable[ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], + Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], ], ) -> Callable[ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], + Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], ]: """Decorate ``sample`` to log warning if temperature supplied is 
out of bounds.""" @functools.wraps(sample) def _sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """New sample function.""" if np.iscomplexobj(temperature): @@ -118,7 +127,7 @@ def _sample( class HeatPerturbation(AbstractPerturbation): """Abstract class for heat perturbation.""" - temperature_range: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + temperature_range: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, inf), title="Temperature range", description="Temperature range in which perturbation model is valid.", @@ -127,14 +136,14 @@ class HeatPerturbation(AbstractPerturbation): @abstractmethod def sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation. Parameters ---------- temperature : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -144,8 +153,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -156,7 +165,7 @@ def sample( @add_ax_if_none def plot( self, - temperature: ArrayLike[float], + temperature: ArrayFloat, val: FieldVal = "real", ax: Ax = None, ) -> Ax: @@ -164,7 +173,7 @@ def plot( Parameters ---------- - temperature : ArrayLike[float] + temperature : ArrayFloat Array of temperature sample points. val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real' Which part of the field to plot. 
@@ -221,35 +230,33 @@ class LinearHeatPerturbation(HeatPerturbation): ... ) """ - temperature_ref: pd.NonNegativeFloat = pd.Field( - ..., + temperature_ref: NonNegativeFloat = Field( title="Reference temperature", description="Temperature at which perturbation is zero.", units=KELVIN, ) - coeff: Union[float, Complex] = pd.Field( - ..., + coeff: Union[float, Complex] = Field( title="Thermo-optic Coefficient", description="Sensitivity (derivative) of perturbation with respect to temperature.", units=f"1/{KELVIN}", ) @cached_property - def perturbation_range(self) -> Union[Tuple[float, float], Tuple[Complex, Complex]]: + def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Complex]]: """Range of possible perturbation values in the provided ``temperature_range``.""" return self._linear_range(self.temperature_range, self.temperature_ref, self.coeff) @ensure_temp_in_range def sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at temperature points. Parameters ---------- temperature : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -259,8 +266,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -309,13 +316,12 @@ class CustomHeatPerturbation(HeatPerturbation): ... 
) """ - perturbation_values: HeatDataArray = pd.Field( - ..., + perturbation_values: HeatDataArray = Field( title="Perturbation Values", description="Sampled perturbation values.", ) - temperature_range: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + temperature_range: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Temperature range", description="Temperature range in which perturbation model is valid. For " @@ -324,7 +330,7 @@ class CustomHeatPerturbation(HeatPerturbation): units=KELVIN, ) - interp_method: InterpMethod = pd.Field( + interp_method: InterpMethod = Field( "linear", title="Interpolation method", description="Interpolation method to obtain perturbation values between sample points.", @@ -333,15 +339,15 @@ class CustomHeatPerturbation(HeatPerturbation): _no_nans = validate_no_nans("perturbation_values") @cached_property - def perturbation_range(self) -> Union[Tuple[float, float], Tuple[Complex, Complex]]: + def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Complex]]: """Range of possible parameter perturbation values.""" return np.min(self.perturbation_values).item(), np.max(self.perturbation_values).item() - @pd.root_validator(skip_on_failure=True) - def compute_temperature_range(cls, values): + @model_validator(mode="after") + def compute_temperature_range(self) -> Self: """Compute and set temperature range based on provided ``perturbation_values``.""" - perturbation_values = values["perturbation_values"] + perturbation_values = self.perturbation_values # .item() to convert to a scalar temperature_range = ( @@ -349,30 +355,27 @@ def compute_temperature_range(cls, values): np.max(perturbation_values.coords["T"]).item(), ) - if ( - values["temperature_range"] is not None - and values["temperature_range"] != temperature_range - ): + if self.temperature_range is not None and self.temperature_range != temperature_range: log.warning( "Temperature range for 'CustomHeatPerturbation' is 
calculated automatically " "based on provided 'perturbation_values'. Provided 'temperature_range' will be " "overwritten." ) - values.update({"temperature_range": temperature_range}) + object.__setattr__(self, "temperature_range", temperature_range) - return values + return self @ensure_temp_in_range def sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at provided temperature points. Parameters ---------- temperature : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -382,8 +385,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -415,7 +418,7 @@ def is_complex(self) -> bool: return np.iscomplexobj(self.perturbation_values) -HeatPerturbationType = Union[LinearHeatPerturbation, CustomHeatPerturbation] +HeatPerturbationType = discriminated_union(Union[LinearHeatPerturbation, CustomHeatPerturbation]) """ Elementary charge perturbation classes """ @@ -424,26 +427,26 @@ def is_complex(self) -> bool: def ensure_charge_in_range( sample: Callable[ [ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], ], - Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], + Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], ], ) -> Callable[ [ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], ], - 
Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], + Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], ]: """Decorate ``sample`` to log warning if charge supplied is out of bounds.""" @functools.wraps(sample) def _sample( self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """New sample function.""" # disable complex input @@ -480,13 +483,13 @@ def _sample( class ChargePerturbation(AbstractPerturbation): """Abstract class for charge perturbation.""" - electron_range: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + electron_range: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, inf), title="Electron Density Range", description="Range of electrons densities in which perturbation model is valid.", ) - hole_range: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + hole_range: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, inf), title="Hole Density Range", description="Range of holes densities in which perturbation model is valid.", @@ -495,22 +498,22 @@ class ChargePerturbation(AbstractPerturbation): @abstractmethod def sample( self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation. 
Parameters ---------- electron_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] Electron density sample point(s). hole_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -525,8 +528,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -537,8 +540,8 @@ def sample( @add_ax_if_none def plot( self, - electron_density: ArrayLike[float], - hole_density: ArrayLike[float], + electron_density: ArrayFloat, + hole_density: ArrayFloat, val: FieldVal = "real", ax: Ax = None, ) -> Ax: @@ -546,9 +549,9 @@ def plot( Parameters ---------- - electron_density : Union[ArrayLike[float], CustomSpatialDataType] + electron_density : Union[ArrayFloat, CustomSpatialDataType] Array of electron density sample points. - hole_density : Union[ArrayLike[float], CustomSpatialDataType] + hole_density : Union[ArrayFloat, CustomSpatialDataType] Array of hole density sample points. val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real' Which part of the field to plot. @@ -560,6 +563,7 @@ def plot( matplotlib.axes._subplots.Axes The supplied or created matplotlib axes. """ + import matplotlib.pyplot as plt values = self.sample(electron_density, hole_density) values = self._get_val(values, val) @@ -628,37 +632,33 @@ class LinearChargePerturbation(ChargePerturbation): ... 
) """ - electron_ref: pd.NonNegativeFloat = pd.Field( - ..., + electron_ref: NonNegativeFloat = Field( title="Reference Electron Density", description="Electron density value at which there is no perturbation due to electrons's " "presence.", units=PERCMCUBE, ) - hole_ref: pd.NonNegativeFloat = pd.Field( - ..., + hole_ref: NonNegativeFloat = Field( title="Reference Hole Density", description="Hole density value at which there is no perturbation due to holes' presence.", units=PERCMCUBE, ) - electron_coeff: float = pd.Field( - ..., + electron_coeff: float = Field( title="Sensitivity to Electron Density", description="Sensitivity (derivative) of perturbation with respect to electron density.", units=CMCUBE, ) - hole_coeff: float = pd.Field( - ..., + hole_coeff: float = Field( title="Sensitivity to Hole Density", description="Sensitivity (derivative) of perturbation with respect to hole density.", units=CMCUBE, ) @cached_property - def perturbation_range(self) -> Union[Tuple[float, float], Tuple[Complex, Complex]]: + def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Complex]]: """Range of possible perturbation values within provided ``electron_range`` and ``hole_range``. """ @@ -673,22 +673,22 @@ def perturbation_range(self) -> Union[Tuple[float, float], Tuple[Complex, Comple @ensure_charge_in_range def sample( self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at electron and hole density points. 
Parameters ---------- electron_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] Electron density sample point(s). hole_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -704,8 +704,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -788,13 +788,12 @@ class CustomChargePerturbation(ChargePerturbation): ... ) """ - perturbation_values: ChargeDataArray = pd.Field( - ..., + perturbation_values: ChargeDataArray = Field( title="Petrubation Values", description="2D array (vs electron and hole densities) of sampled perturbation values.", ) - electron_range: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + electron_range: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Electron Density Range", description="Range of electrons densities in which perturbation model is valid. For " @@ -802,7 +801,7 @@ class CustomChargePerturbation(ChargePerturbation): "provided ``perturbation_values``", ) - hole_range: Tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + hole_range: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Hole Density Range", description="Range of holes densities in which perturbation model is valid. 
For " @@ -810,7 +809,7 @@ class CustomChargePerturbation(ChargePerturbation): "provided ``perturbation_values``", ) - interp_method: InterpMethod = pd.Field( + interp_method: InterpMethod = Field( "linear", title="Interpolation method", description="Interpolation method to obtain perturbation values between sample points.", @@ -819,17 +818,17 @@ class CustomChargePerturbation(ChargePerturbation): _no_nans = validate_no_nans("perturbation_values") @cached_property - def perturbation_range(self) -> Union[Tuple[float, float], Tuple[complex, complex]]: + def perturbation_range(self) -> Union[tuple[float, float], tuple[complex, complex]]: """Range of possible parameter perturbation values.""" return np.min(self.perturbation_values).item(), np.max(self.perturbation_values).item() - @pd.root_validator(skip_on_failure=True) - def compute_eh_ranges(cls, values): + @model_validator(mode="after") + def compute_eh_ranges(self): """Compute and set electron and hole density ranges based on provided ``perturbation_values``. """ - perturbation_values = values["perturbation_values"] + perturbation_values = self.perturbation_values electron_range = ( np.min(perturbation_values.coords["n"]).item(), @@ -841,43 +840,44 @@ def compute_eh_ranges(cls, values): np.max(perturbation_values.coords["p"]).item(), ) - if values["electron_range"] is not None and electron_range != values["electron_range"]: + if self.electron_range is not None and electron_range != self.electron_range: log.warning( "Electron density range for 'CustomChargePerturbation' is calculated automatically " "based on provided 'perturbation_values'. Provided 'electron_range' will be " "overwritten." ) - if values["hole_range"] is not None and hole_range != values["hole_range"]: + if self.hole_range is not None and hole_range != self.hole_range: log.warning( "Hole density range for 'CustomChargePerturbation' is calculated automatically " "based on provided 'perturbation_values'. Provided 'hole_range' will be " "overwritten." 
) - values.update({"electron_range": electron_range, "hole_range": hole_range}) + object.__setattr__(self, "electron_range", electron_range) + object.__setattr__(self, "hole_range", hole_range) - return values + return self @ensure_charge_in_range def sample( self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at electron and hole density points. Parameters ---------- electron_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] Electron density sample point(s). hole_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -893,8 +893,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -963,9 +963,10 @@ def is_complex(self) -> bool: return np.iscomplexobj(self.perturbation_values) -ChargePerturbationType = Union[LinearChargePerturbation, CustomChargePerturbation] - -PerturbationType = Union[HeatPerturbationType, ChargePerturbationType] +ChargePerturbationType = discriminated_union( + Union[LinearChargePerturbation, CustomChargePerturbation] +) +PerturbationType = discriminated_union(Union[HeatPerturbationType, ChargePerturbationType]) class ParameterPerturbation(Tidy3dBaseModel): @@ -991,26 +992,24 @@ class ParameterPerturbation(Tidy3dBaseModel): >>> param_perturb = ParameterPerturbation(heat=heat_perturb, charge=charge_perturb) """ - heat: HeatPerturbationType = 
pd.Field( + heat: Optional[HeatPerturbationType] = Field( None, title="Heat Perturbation", description="Heat perturbation to apply.", - discriminator=TYPE_TAG_STR, ) - charge: ChargePerturbationType = pd.Field( + charge: Optional[ChargePerturbationType] = Field( None, title="Charge Perturbation", description="Charge perturbation to apply.", - discriminator=TYPE_TAG_STR, ) - @pd.root_validator(skip_on_failure=True) - def _check_not_empty(cls, values): + @model_validator(mode="after") + def _check_not_empty(self) -> Self: """Check that perturbation model is not empty.""" - heat = values.get("heat") - charge = values.get("charge") + heat = self.heat + charge = self.charge if heat is None and charge is None: raise DataError( @@ -1018,10 +1017,10 @@ def _check_not_empty(cls, values): "simultaneously 'None'." ) - return values + return self @cached_property - def perturbation_list(self) -> List[PerturbationType]: + def perturbation_list(self) -> list[PerturbationType]: """Provided perturbations as a list.""" perturb_list = [] for p in [self.heat, self.charge]: @@ -1030,7 +1029,7 @@ def perturbation_list(self) -> List[PerturbationType]: return perturb_list @cached_property - def perturbation_range(self) -> Union[Tuple[float, float], Tuple[Complex, Complex]]: + def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Complex]]: """Range of possible parameter perturbation values due to both heat and charge effects.""" prange = np.zeros(2) @@ -1151,24 +1150,24 @@ class PermittivityPerturbation(Tidy3dBaseModel): >>> permittivity_pb = PermittivityPerturbation(delta_eps=delta_eps, delta_sigma=delta_sigma) """ - delta_eps: Optional[ParameterPerturbation] = pd.Field( + delta_eps: Optional[ParameterPerturbation] = Field( None, title="Permittivity Perturbation", description="Perturbation model for permittivity.", ) - delta_sigma: Optional[ParameterPerturbation] = pd.Field( + delta_sigma: Optional[ParameterPerturbation] = Field( None, title="Conductivity 
Perturbation", description="Perturbation model for conductivity.", ) - @pd.root_validator(skip_on_failure=True) - def _check_not_complex(cls, values): + @model_validator(mode="after") + def _check_not_complex(self) -> Self: """Check that perturbation values are not complex.""" - delta_eps = values.get("delta_eps") - delta_sigma = values.get("delta_sigma") + delta_eps = self.delta_eps + delta_sigma = self.delta_sigma delta_eps_complex = False if delta_eps is None else delta_eps.is_complex delta_sigma_complex = False if delta_sigma is None else delta_sigma.is_complex @@ -1179,14 +1178,14 @@ def _check_not_complex(cls, values): "complex-valued." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def _check_not_empty(cls, values): + @model_validator(mode="after") + def _check_not_empty(self) -> Self: """Check that perturbation model is not empty.""" - delta_eps = values.get("delta_eps") - delta_sigma = values.get("delta_sigma") + delta_eps = self.delta_eps + delta_sigma = self.delta_sigma if delta_eps is None and delta_sigma is None: raise DataError( @@ -1194,7 +1193,7 @@ def _check_not_empty(cls, values): "simultaneously 'None'." 
) - return values + return self def _delta_eps_delta_sigma_ranges(self): """Perturbation range of permittivity.""" @@ -1252,277 +1251,261 @@ class NedeljkovicSorefMashanovich(AbstractDeltaModel): ------- """ - perturb_coeffs = xr.Dataset( - { - "a": ( - "wvl", + perturb_coeffs: PerturbationCoefficientDataArray = Field( + default_factory=lambda: PerturbationCoefficientDataArray( + np.column_stack( [ - 3.48e-22, - 8.88e-21, - 3.22e-20, - 1.67e-20, - 6.29e-21, - 3.10e-21, - 7.45e-22, - 2.16e-22, - 9.28e-23, - 4.58e-23, - 3.26e-23, - 2.70e-23, - 2.25e-23, - 1.36e-23, - 1.85e-23, - 3.05e-23, - 4.08e-23, - 4.14e-23, - 3.81e-23, - 4.23e-23, - 5.81e-23, - 8.20e-23, - 1.13e-22, - 1.22e-22, - 1.09e-22, - 1.20e-22, - 1.62e-22, + [ + 3.48e-22, + 8.88e-21, + 3.22e-20, + 1.67e-20, + 6.29e-21, + 3.10e-21, + 7.45e-22, + 2.16e-22, + 9.28e-23, + 4.58e-23, + 3.26e-23, + 2.70e-23, + 2.25e-23, + 1.36e-23, + 1.85e-23, + 3.05e-23, + 4.08e-23, + 4.14e-23, + 3.81e-23, + 4.23e-23, + 5.81e-23, + 8.20e-23, + 1.13e-22, + 1.22e-22, + 1.09e-22, + 1.20e-22, + 1.62e-22, + ], + [ + 1.229, + 1.167, + 1.149, + 1.169, + 1.193, + 1.210, + 1.245, + 1.277, + 1.299, + 1.319, + 1.330, + 1.338, + 1.345, + 1.359, + 1.354, + 1.345, + 1.340, + 1.341, + 1.344, + 1.344, + 1.338, + 1.331, + 1.325, + 1.324, + 1.328, + 1.327, + 1.321, + ], + [ + 1.02e-19, + 5.84e-20, + 6.21e-20, + 8.08e-20, + 3.40e-20, + 6.05e-20, + 5.43e-20, + 5.58e-20, + 6.65e-20, + 8.53e-20, + 1.53e-19, + 1.22e-19, + 1.29e-19, + 9.99e-20, + 1.32e-19, + 1.57e-18, + 1.45e-18, + 1.70e-18, + 1.25e-18, + 8.14e-19, + 1.55e-18, + 4.81e-18, + 4.72e-18, + 2.09e-18, + 1.16e-18, + 2.01e-18, + 7.52e-18, + ], + [ + 1.089, + 1.109, + 1.119, + 1.123, + 1.151, + 1.145, + 1.153, + 1.158, + 1.160, + 1.159, + 1.149, + 1.158, + 1.160, + 1.170, + 1.167, + 1.111, + 1.115, + 1.115, + 1.125, + 1.137, + 1.124, + 1.100, + 1.102, + 1.124, + 1.140, + 1.130, + 1.101, + ], + [ + 2.98e-22, + 5.40e-22, + 1.91e-21, + 5.70e-21, + 6.57e-21, + 6.95e-21, + 7.25e-21, + 1.19e-20, + 
2.46e-20, + 3.64e-20, + 4.96e-20, + 5.91e-20, + 5.52e-20, + 3.19e-20, + 3.56e-20, + 8.65e-20, + 2.09e-19, + 2.07e-19, + 3.01e-19, + 5.07e-19, + 1.51e-19, + 2.19e-19, + 3.04e-19, + 4.44e-19, + 6.96e-19, + 1.05e-18, + 1.45e-18, + ], + [ + 1.016, + 1.011, + 0.992, + 0.976, + 0.981, + 0.986, + 0.991, + 0.985, + 0.973, + 0.968, + 0.965, + 0.964, + 0.969, + 0.984, + 0.984, + 0.966, + 0.948, + 0.951, + 0.944, + 0.934, + 0.965, + 0.958, + 0.953, + 0.945, + 0.936, + 0.928, + 0.922, + ], + [ + 1.25e-18, + 1.53e-18, + 2.28e-18, + 5.19e-18, + 3.62e-18, + 9.28e-18, + 9.99e-18, + 1.29e-17, + 2.03e-17, + 3.31e-17, + 6.92e-17, + 8.23e-17, + 1.15e-16, + 4.81e-16, + 7.44e-16, + 7.11e-16, + 5.29e-16, + 9.72e-16, + 1.22e-15, + 1.16e-15, + 3.16e-15, + 1.51e-14, + 2.71e-14, + 2.65e-14, + 2.94e-14, + 6.85e-14, + 2.60e-13, + ], + [ + 0.835, + 0.838, + 0.841, + 0.832, + 0.849, + 0.834, + 0.839, + 0.838, + 0.833, + 0.826, + 0.812, + 0.812, + 0.807, + 0.776, + 0.769, + 0.774, + 0.783, + 0.772, + 0.769, + 0.772, + 0.750, + 0.716, + 0.704, + 0.706, + 0.705, + 0.686, + 0.656, + ], ], ), - "b": ( - "wvl", - [ - 1.229, - 1.167, - 1.149, - 1.169, - 1.193, - 1.210, - 1.245, - 1.277, - 1.299, - 1.319, - 1.330, - 1.338, - 1.345, - 1.359, - 1.354, - 1.345, - 1.340, - 1.341, - 1.344, - 1.344, - 1.338, - 1.331, - 1.325, - 1.324, - 1.328, - 1.327, - 1.321, - ], - ), - "c": ( - "wvl", - [ - 1.02e-19, - 5.84e-20, - 6.21e-20, - 8.08e-20, - 3.40e-20, - 6.05e-20, - 5.43e-20, - 5.58e-20, - 6.65e-20, - 8.53e-20, - 1.53e-19, - 1.22e-19, - 1.29e-19, - 9.99e-20, - 1.32e-19, - 1.57e-18, - 1.45e-18, - 1.70e-18, - 1.25e-18, - 8.14e-19, - 1.55e-18, - 4.81e-18, - 4.72e-18, - 2.09e-18, - 1.16e-18, - 2.01e-18, - 7.52e-18, - ], - ), - "d": ( - "wvl", - [ - 1.089, - 1.109, - 1.119, - 1.123, - 1.151, - 1.145, - 1.153, - 1.158, - 1.160, - 1.159, - 1.149, - 1.158, - 1.160, - 1.170, - 1.167, - 1.111, - 1.115, - 1.115, - 1.125, - 1.137, - 1.124, - 1.100, - 1.102, - 1.124, - 1.140, - 1.130, - 1.101, - ], - ), - "p": ( - "wvl", - 
[ - 2.98e-22, - 5.40e-22, - 1.91e-21, - 5.70e-21, - 6.57e-21, - 6.95e-21, - 7.25e-21, - 1.19e-20, - 2.46e-20, - 3.64e-20, - 4.96e-20, - 5.91e-20, - 5.52e-20, - 3.19e-20, - 3.56e-20, - 8.65e-20, - 2.09e-19, - 2.07e-19, - 3.01e-19, - 5.07e-19, - 1.51e-19, - 2.19e-19, - 3.04e-19, - 4.44e-19, - 6.96e-19, - 1.05e-18, - 1.45e-18, - ], - ), - "q": ( - "wvl", - [ - 1.016, - 1.011, - 0.992, - 0.976, - 0.981, - 0.986, - 0.991, - 0.985, - 0.973, - 0.968, - 0.965, - 0.964, - 0.969, - 0.984, - 0.984, - 0.966, - 0.948, - 0.951, - 0.944, - 0.934, - 0.965, - 0.958, - 0.953, - 0.945, - 0.936, - 0.928, - 0.922, - ], - ), - "r": ( - "wvl", - [ - 1.25e-18, - 1.53e-18, - 2.28e-18, - 5.19e-18, - 3.62e-18, - 9.28e-18, - 9.99e-18, - 1.29e-17, - 2.03e-17, - 3.31e-17, - 6.92e-17, - 8.23e-17, - 1.15e-16, - 4.81e-16, - 7.44e-16, - 7.11e-16, - 5.29e-16, - 9.72e-16, - 1.22e-15, - 1.16e-15, - 3.16e-15, - 1.51e-14, - 2.71e-14, - 2.65e-14, - 2.94e-14, - 6.85e-14, - 2.60e-13, - ], - ), - "s": ( - "wvl", - [ - 0.835, - 0.838, - 0.841, - 0.832, - 0.849, - 0.834, - 0.839, - 0.838, - 0.833, - 0.826, - 0.812, - 0.812, - 0.807, - 0.776, - 0.769, - 0.774, - 0.783, - 0.772, - 0.769, - 0.772, - 0.750, - 0.716, - 0.704, - 0.706, - 0.705, - 0.686, - 0.656, - ], - ), - }, - coords={"wvl": np.array([1.3, 1.55] + list(np.arange(2, 14.5, 0.5)))}, + dims=("wvl", "coeff"), + coords={ + "wvl": np.array([1.3, 1.55] + list(np.arange(2, 14.5, 0.5))), + "coeff": ["a", "b", "c", "d", "p", "q", "r", "s"], + }, + name="perturb_coeffs", + ) ) - ref_freq: pd.NonNegativeFloat = pd.Field( - ..., + ref_freq: NonNegativeFloat = Field( title="Reference Frequency", description="Reference frequency to evaluate perturbation at (Hz).", units=HERTZ, ) - electrons_grid: np.ndarray = pd.Field( - np.concatenate(([0], np.logspace(-6, 22, num=200))), + electrons_grid: ArrayFloat = Field( + default_factory=lambda: np.concatenate(([0], np.logspace(-6, 22, num=200))), title="Electron concentration grid.", descriptio="The model will be 
evaluated at these concentration values. Since " "the data at these locations will later be interpolated to determine perturbations " @@ -1531,8 +1514,8 @@ class NedeljkovicSorefMashanovich(AbstractDeltaModel): "i.e., `np.concatenate(([0], np.logspace(-6, 22, num=200)))`.", ) - holes_grid: np.ndarray = pd.Field( - np.concatenate(([0], np.logspace(-6, 22, num=200))), + holes_grid: ArrayFloat = Field( + default_factory=lambda: np.concatenate(([0], np.logspace(-6, 22, num=200))), title="Hole concentration grid.", descriptio="The model will be evaluated at these concentration values. Since " "the data at these locations will later be interpolated to determine perturbations " @@ -1541,14 +1524,14 @@ class NedeljkovicSorefMashanovich(AbstractDeltaModel): "i.e., `np.concatenate(([0], np.logspace(-6, 22, num=200)))`.", ) - @pd.root_validator(skip_on_failure=True) - def _check_freq_in_range(cls, values): + @model_validator(mode="after") + def _check_freq_in_range(self) -> Self: """Check that the given frequency is within validity range. If not, issue a warning. """ - freq = values.get("ref_freq") - wavelengths = list(values.get("perturb_coeffs").coords["wvl"]) + freq = self.ref_freq + wavelengths = list(self.perturb_coeffs.coords["wvl"]) freq_range = (C_0 / np.max(wavelengths), C_0 / np.min(wavelengths)) @@ -1558,7 +1541,7 @@ def _check_freq_in_range(cls, values): f"{freq_range[1]} Hz) of the Nedeljkovic-Soref-Mashanovich model." 
) - return values + return self @cached_property def ref_wavelength(self) -> float: @@ -1582,11 +1565,11 @@ def delta_k(self) -> ChargePerturbationType: Ne_mesh, Nh_mesh = np.meshgrid(Ne_range, Nh_range, indexing="ij") # get parameters a, b, c, d - ke_coeff = self._coeffs_at_ref_freq["a"].item() - ke_pow = self._coeffs_at_ref_freq["b"].item() + ke_coeff = self._coeffs_at_ref_freq.sel(coeff="a").item() + ke_pow = self._coeffs_at_ref_freq.sel(coeff="b").item() - kh_coeff = self._coeffs_at_ref_freq["c"].item() - kh_pow = self._coeffs_at_ref_freq["d"].item() + kh_coeff = self._coeffs_at_ref_freq.sel(coeff="c").item() + kh_pow = self._coeffs_at_ref_freq.sel(coeff="d").item() dk_mesh = ke_coeff * Ne_mesh**ke_pow + kh_coeff * Nh_mesh**kh_pow @@ -1616,11 +1599,11 @@ def delta_n(self) -> ChargePerturbationType: Ne_mesh, Nh_mesh = np.meshgrid(Ne_range, Nh_range, indexing="ij") # get parameters p, q, r, s - ne_coeff = self._coeffs_at_ref_freq["p"].item() - ne_pow = self._coeffs_at_ref_freq["q"].item() + ne_coeff = self._coeffs_at_ref_freq.sel(coeff="p").item() + ne_pow = self._coeffs_at_ref_freq.sel(coeff="q").item() - nh_coeff = self._coeffs_at_ref_freq["r"].item() - nh_pow = self._coeffs_at_ref_freq["s"].item() + nh_coeff = self._coeffs_at_ref_freq.sel(coeff="r").item() + nh_pow = self._coeffs_at_ref_freq.sel(coeff="s").item() dn_mesh = -ne_coeff * Ne_mesh**ne_pow - nh_coeff * Nh_mesh**nh_pow @@ -1656,31 +1639,30 @@ class IndexPerturbation(Tidy3dBaseModel): >>> index_pb = IndexPerturbation(delta_n=dn_pb, delta_k=dk_pb, freq=C_0) """ - delta_n: Optional[ParameterPerturbation] = pd.Field( + delta_n: Optional[ParameterPerturbation] = Field( None, title="Refractive Index Perturbation", description="Perturbation of the real part of refractive index.", ) - delta_k: Optional[ParameterPerturbation] = pd.Field( + delta_k: Optional[ParameterPerturbation] = Field( None, title="Exctinction Coefficient Perturbation", description="Perturbation of the imaginary part of refractive index.", 
) - freq: pd.NonNegativeFloat = pd.Field( - ..., + freq: NonNegativeFloat = Field( title="Frequency", description="Frequency to evaluate permittivity at (Hz).", units=HERTZ, ) - @pd.root_validator(skip_on_failure=True) - def _check_not_complex(cls, values): + @model_validator(mode="after") + def _check_not_complex(self) -> Self: """Check that perturbation values are not complex.""" - dn = values.get("delta_n") - dk = values.get("delta_k") + dn = self.delta_n + dk = self.delta_k dn_complex = False if dn is None else dn.is_complex dk_complex = False if dk is None else dk.is_complex @@ -1691,14 +1673,14 @@ def _check_not_complex(cls, values): "complex-valued." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def _check_not_empty(cls, values): + @model_validator(mode="after") + def _check_not_empty(self) -> Self: """Check that perturbation model is not empty.""" - dn = values.get("delta_n") - dk = values.get("delta_k") + dn = self.delta_n + dk = self.delta_k if dn is None and dk is None: raise DataError( @@ -1706,7 +1688,7 @@ def _check_not_empty(cls, values): "simultaneously 'None'." 
) - return values + return self def _delta_eps_delta_sigma_ranges(self, n: float, k: float): """Perturbation range of permittivity.""" @@ -1787,6 +1769,6 @@ def _sample_delta_eps_delta_sigma( return delta_eps, delta_sigma - def from_perturbation_delta_model(cls, deltas_model: AbstractDeltaModel) -> IndexPerturbation: + def from_perturbation_delta_model(cls, deltas_model: AbstractDeltaModel) -> Self: """Create an IndexPerturbation from a DeltaPerturbationModel.""" return IndexPerturbation(delta_n=deltas_model.delta_n, delta_k=deltas_model.delta_k) diff --git a/tidy3d/components/run_time_spec.py b/tidy3d/components/run_time_spec.py index 5299817609..dd3ef3eae7 100644 --- a/tidy3d/components/run_time_spec.py +++ b/tidy3d/components/run_time_spec.py @@ -1,5 +1,5 @@ # Defines specifications for how long to run a simulation -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from .base import Tidy3dBaseModel @@ -24,15 +24,14 @@ class RunTimeSpec(Tidy3dBaseModel): """ - quality_factor: pd.PositiveFloat = pd.Field( - ..., + quality_factor: PositiveFloat = Field( title="Quality Factor", description="Quality factor expected in the device. 
This determines how long the " "simulation will run as it assumes a field decay time that scales proportionally to " "this value.", ) - source_factor: pd.PositiveFloat = pd.Field( + source_factor: PositiveFloat = Field( 3, title="Source Factor", description="The contribution to the ``run_time`` from the longest source is computed from " diff --git a/tidy3d/components/scene.py b/tidy3d/components/scene.py index 67c29aa98a..e748394a0c 100644 --- a/tidy3d/components/scene.py +++ b/tidy3d/components/scene.py @@ -1,10 +1,9 @@ """Container holding about the geometry and medium properties common to all types of simulations.""" -from __future__ import annotations - -from typing import Dict, List, Literal, Optional, Set, Tuple, Union +from typing import Any, Literal, Optional, Union import autograd.numpy as np +from pydantic import Field, NonNegativeInt, field_validator try: import matplotlib as mpl @@ -12,17 +11,8 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable except ImportError: pass -import pydantic.v1 as pd - -from tidy3d.components.material.tcad.charge import ( - ChargeConductorMedium, - SemiconductorMedium, -) -from tidy3d.components.material.tcad.heat import SolidMedium, SolidSpec -from tidy3d.components.material.types import MultiPhysicsMediumType3D, StructureMediumType -from tidy3d.components.tcad.doping import ConstantDoping, GaussianDoping -from tidy3d.components.tcad.viz import HEAT_SOURCE_CMAP +from ..compat import Self from ..constants import CONDUCTIVITY, THERMAL_CONDUCTIVITY, inf from ..exceptions import SetupError, Tidy3dError from ..log import log @@ -38,6 +28,12 @@ from .geometry.utils import flatten_groups, merging_geometries_on_plane, traverse_geometries from .grid.grid import Coords, Grid from .material.multi_physics import MultiPhysicsMedium +from .material.tcad.charge import ( + ChargeConductorMedium, + SemiconductorMedium, +) +from .material.tcad.heat import SolidMedium, SolidSpec +from .material.types import MultiPhysicsMediumType3D, 
StructureMediumType from .medium import ( AbstractCustomMedium, AbstractMedium, @@ -46,6 +42,8 @@ Medium2D, ) from .structure import Structure +from .tcad.doping import ConstantDoping, GaussianDoping +from .tcad.viz import HEAT_SOURCE_CMAP from .types import ( TYPE_TAG_STR, Ax, @@ -95,14 +93,14 @@ class Scene(Tidy3dBaseModel): ... ) """ - medium: MultiPhysicsMediumType3D = pd.Field( - Medium(), + medium: MultiPhysicsMediumType3D = Field( + default_factory=Medium, title="Background Medium", description="Background medium of scene, defaults to vacuum if not specified.", discriminator=TYPE_TAG_STR, ) - structures: Tuple[Structure, ...] = pd.Field( + structures: Optional[tuple[Structure, ...]] = Field( (), title="Structures", description="Tuple of structures present in scene. " @@ -110,7 +108,7 @@ class Scene(Tidy3dBaseModel): "simulation material properties in regions of spatial overlap.", ) - plot_length_units: Optional[LengthUnit] = pd.Field( + plot_length_units: Optional[LengthUnit] = Field( "μm", title="Plot Units", description="When set to a supported ``LengthUnit``, " @@ -120,11 +118,10 @@ class Scene(Tidy3dBaseModel): """ Validating setup """ - # make sure all names are unique _unique_structure_names = assert_unique_names("structures") - @pd.validator("structures", always=True) - def _validate_num_mediums(cls, val): + @field_validator("structures") + def _validate_num_mediums(val): """Error if too many mediums present.""" if val is None: @@ -139,8 +136,8 @@ def _validate_num_mediums(cls, val): return val - @pd.validator("structures", always=True) - def _validate_num_geometries(cls, val): + @field_validator("structures") + def _validate_num_geometries(val): """Error if too many geometries in a single structure.""" if val is None: @@ -173,7 +170,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float, float] + tuple[float, float, float], tuple[float, float, float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, 
maxy, maxz)``. """ @@ -189,7 +186,7 @@ def size(self) -> Size: Returns ------- - Tuple[float, float, float] + tuple[float, float, float] Scene's size. """ @@ -201,7 +198,7 @@ def center(self) -> Coordinate: Returns ------- - Tuple[float, float, float] + tuple[float, float, float] Scene's center. """ @@ -220,12 +217,12 @@ def box(self) -> Box: return Box(center=self.center, size=self.size) @cached_property - def mediums(self) -> Set[StructureMediumType]: + def mediums(self) -> set[StructureMediumType]: """Returns set of distinct :class:`.AbstractMedium` in scene. Returns ------- - List[:class:`.AbstractMedium`] + list[:class:`.AbstractMedium`] Set of distinct mediums in the scene. """ medium_dict = {self.medium: None} @@ -233,13 +230,13 @@ def mediums(self) -> Set[StructureMediumType]: return list(medium_dict.keys()) @cached_property - def medium_map(self) -> Dict[StructureMediumType, pd.NonNegativeInt]: + def medium_map(self) -> dict[StructureMediumType, NonNegativeInt]: """Returns dict mapping medium to index in material. ``medium_map[medium]`` returns unique global index of :class:`.AbstractMedium` in scene. Returns ------- - Dict[:class:`.AbstractMedium`, int] + dict[:class:`.AbstractMedium`, int] Mapping between distinct mediums to index in scene. """ @@ -252,14 +249,14 @@ def background_structure(self) -> Structure: return Structure(geometry=geometry, medium=self.medium) @cached_property - def all_structures(self) -> List[Structure]: + def all_structures(self) -> list[Structure]: """List of all structures in the simulation including the background.""" return [self.background_structure] + list(self.structures) @staticmethod def intersecting_media( - test_object: Box, structures: Tuple[Structure, ...] - ) -> Tuple[StructureMediumType, ...]: + test_object: Box, structures: tuple[Structure, ...] 
+ ) -> tuple[StructureMediumType, ...]: """From a given list of structures, returns a list of :class:`.AbstractMedium` associated with those structures that intersect with the ``test_object``, if it is a surface, or its surfaces, if it is a volume. @@ -268,12 +265,12 @@ def intersecting_media( ------- test_object : :class:`.Box` Object for which intersecting media are to be detected. - structures : List[:class:`.AbstractMedium`] + structures : list[:class:`.AbstractMedium`] List of structures whose media will be tested. Returns ------- - List[:class:`.AbstractMedium`] + list[:class:`.AbstractMedium`] Set of distinct mediums that intersect with the given planar object. """ structures = [s.to_static() for s in structures] @@ -293,8 +290,8 @@ def intersecting_media( @staticmethod def intersecting_structures( - test_object: Box, structures: Tuple[Structure, ...] - ) -> Tuple[Structure, ...]: + test_object: Box, structures: tuple[Structure, ...] + ) -> tuple[Structure, ...]: """From a given list of structures, returns a list of :class:`.Structure` that intersect with the ``test_object``, if it is a surface, or its surfaces, if it is a volume. @@ -302,12 +299,12 @@ def intersecting_structures( ------- test_object : :class:`.Box` Object for which intersecting media are to be detected. - structures : List[:class:`.AbstractMedium`] + structures : list[:class:`.AbstractMedium`] List of structures whose media will be tested. Returns ------- - List[:class:`.Structure`] + list[:class:`.Structure`] Set of distinct structures that intersect with the given surface, or with the surfaces of the given volume. 
""" @@ -340,9 +337,9 @@ def _get_plot_lims( x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, - ) -> Tuple[Tuple[float, float], Tuple[float, float]]: + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, + ) -> tuple[tuple[float, float], tuple[float, float]]: # if no hlim and/or vlim given, the bounds will then be the usual pml bounds axis, _ = Box.parse_xyz_kwargs(x=x, y=y, z=z) _, (hmin, vmin) = Box.pop_axis(bounds[0], axis=axis) @@ -369,8 +366,8 @@ def plot( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, fill_structures: bool = True, **patch_kwargs, ) -> Ax: @@ -386,9 +383,9 @@ def plot( position of plane in z direction, only one of x, y, z must be specified to define plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill_structures : bool = True Whether to fill structures with color or just draw outlines. @@ -413,8 +410,8 @@ def plot_structures( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, fill: bool = True, ) -> Ax: """Plot each of scene's structures on a plane defined by one nonzero x,y,z coordinate. @@ -429,9 +426,9 @@ def plot_structures( position of plane in z direction, only one of x, y, z must be specified to define plane. 
ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill : bool = True Whether to fill structures with color or just draw outlines. @@ -554,8 +551,8 @@ def _set_plot_bounds( x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Sets the xy limits of the scene at a plane, useful after plotting. @@ -569,9 +566,9 @@ def _set_plot_bounds( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns ------- @@ -586,18 +583,18 @@ def _set_plot_bounds( def _get_structures_2dbox( self, - structures: List[Structure], + structures: list[Structure], x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, - ) -> List[Tuple[Medium, Shapely]]: + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, + ) -> list[tuple[Medium, Shapely]]: """Compute list of shapes to plot on 2d box specified by (x_min, x_max), (y_min, y_max). Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] list of structures to filter on the plane. 
x : float = None position of plane in x direction, only one of x, y, z must be specified to define plane. @@ -605,14 +602,14 @@ def _get_structures_2dbox( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] + list[tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] List of shapes and mediums on the plane. """ # if no hlim and/or vlim given, the bounds will then be the usual pml bounds @@ -647,21 +644,21 @@ def _get_structures_2dbox( @staticmethod def _filter_structures_plane_medium( - structures: List[Structure], plane: Box - ) -> List[Tuple[Medium, Shapely]]: + structures: list[Structure], plane: Box + ) -> list[tuple[Medium, Shapely]]: """Compute list of shapes to plot on plane. Overlaps are removed or merged depending on medium. Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] List of structures to filter on the plane. plane : Box Plane specification. Returns ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] + list[tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] List of shapes and mediums on the plane after merging. 
""" @@ -672,16 +669,16 @@ def _filter_structures_plane_medium( @staticmethod def _filter_structures_plane( - structures: List[Structure], + structures: list[Structure], plane: Box, - property_list: List, - ) -> List[Tuple[Medium, Shapely]]: + property_list: list[Any], + ) -> list[tuple[Medium, Shapely]]: """Compute list of shapes to plot on plane. Overlaps are removed or merged depending on provided property_list. Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] List of structures to filter on the plane. plane : Box Plane specification. @@ -690,7 +687,7 @@ def _filter_structures_plane( Returns ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] + list[tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] List of shapes and their property value on the plane after merging. """ return merging_geometries_on_plane( @@ -709,8 +706,8 @@ def plot_eps( freq: float = None, alpha: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of scene's components on a plane defined by one nonzero x,y,z coordinate. The permittivity is plotted in grayscale based on its value at the specified frequency. @@ -731,9 +728,9 @@ def plot_eps( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
Returns @@ -761,10 +758,10 @@ def plot_structures_eps( alpha: float = None, cbar: bool = True, reverse: bool = False, - eps_lim: Tuple[Union[float, None], Union[float, None]] = (None, None), + eps_lim: tuple[Union[float, None], Union[float, None]] = (None, None), ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, grid: Grid = None, eps_component: Optional[PermittivityComponent] = None, ) -> Ax: @@ -790,13 +787,13 @@ def plot_structures_eps( alpha : float = None Opacity of the structures being plotted. Defaults to the structure default alpha. - eps_lim : Tuple[float, float] = None + eps_lim : tuple[float, float] = None Custom limits for eps coloring. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. eps_component : Optional[PermittivityComponent] = None Component of the permittivity tensor to plot for anisotropic materials, @@ -837,10 +834,10 @@ def plot_structures_property( alpha: float = None, cbar: bool = True, reverse: bool = False, - limits: Tuple[Union[float, None], Union[float, None]] = (None, None), + limits: tuple[Union[float, None], Union[float, None]] = (None, None), ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, grid: Grid = None, property: Literal["eps", "doping", "N_a", "N_d"] = "eps", eps_component: Optional[PermittivityComponent] = None, @@ -867,13 +864,13 @@ def plot_structures_property( alpha : float = None Opacity of the structures being plotted. 
Defaults to the structure default alpha. - limits : Tuple[float, float] = None + limits : tuple[float, float] = None Custom coloring limits for the property to plot. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. property: Literal["eps", "doping", "N_a", "N_d"] = "eps" Indicates the property to plot for the structures. Currently supported properties @@ -1035,7 +1032,7 @@ def _eps_bounds( medium_list: list[Medium], freq: float = None, eps_component: Optional[PermittivityComponent] = None, - ) -> Tuple[float, float]: + ) -> tuple[float, float]: """Compute range of (real) permittivity present in the mediums at frequency "freq".""" medium_list = [medium for medium in medium_list if not medium.is_pec] eps_list = [medium._eps_plot(freq, eps_component) for medium in medium_list] @@ -1050,7 +1047,7 @@ def _eps_bounds( eps_max = max(eps_max, mat_epsmax) return eps_min, eps_max - def eps_bounds(self, freq: float = None, eps_component: str = None) -> Tuple[float, float]: + def eps_bounds(self, freq: float = None, eps_component: str = None) -> tuple[float, float]: """Compute range of (real) permittivity present in the scene at frequency "freq". Parameters @@ -1065,7 +1062,7 @@ def eps_bounds(self, freq: float = None, eps_component: str = None) -> Tuple[flo Returns ------- - Tuple[float, float] + tuple[float, float] Minimal and maximal values of relative permittivity in scene. 
""" @@ -1311,8 +1308,8 @@ def plot_heat_charge_property( cbar: bool = True, property: str = "heat_conductivity", ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of scebe's components on a plane defined by one nonzero x,y,z coordinate. The thermal conductivity is plotted in grayscale based on its value. @@ -1335,9 +1332,9 @@ def plot_heat_charge_property( ["heat_conductivity", "electric_conductivity"] ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -1365,8 +1362,8 @@ def plot_structures_heat_conductivity( cbar: bool = True, reverse: bool = False, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of scene's structures on a plane defined by one nonzero x,y,z coordinate. The thermal conductivity is plotted in grayscale based on its value. @@ -1389,9 +1386,9 @@ def plot_structures_heat_conductivity( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
Returns @@ -1431,8 +1428,8 @@ def plot_structures_heat_charge_property( property: str = "heat_conductivity", reverse: bool = False, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of scene's structures on a plane defined by one nonzero x,y,z coordinate. The thermal conductivity is plotted in grayscale based on its value. @@ -1455,9 +1452,9 @@ def plot_structures_heat_charge_property( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -1524,12 +1521,12 @@ def plot_structures_heat_charge_property( ) return ax - def heat_charge_property_bounds(self, property) -> Tuple[float, float]: + def heat_charge_property_bounds(self, property) -> tuple[float, float]: """Compute range of the heat-charge simulation property present in the scene. Returns ------- - Tuple[float, float] + tuple[float, float] Minimal and maximal values of thermal conductivity in scene. """ @@ -1553,12 +1550,12 @@ def heat_charge_property_bounds(self, property) -> Tuple[float, float]: cond_max = max(cond_list) return cond_min, cond_max - def heat_conductivity_bounds(self) -> Tuple[float, float]: + def heat_conductivity_bounds(self) -> tuple[float, float]: """Compute range of thermal conductivities present in the scene. Returns ------- - Tuple[float, float] + tuple[float, float] Minimal and maximal values of thermal conductivity in scene. 
""" log.warning( @@ -1648,8 +1645,8 @@ def plot_heat_conductivity( alpha: float = None, cbar: bool = True, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ): """Plot each of scebe's components on a plane defined by one nonzero x,y,z coordinate. The thermal conductivity is plotted in grayscale based on its value. @@ -1669,9 +1666,9 @@ def plot_heat_conductivity( Whether to plot a colorbar for the thermal conductivity. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -1706,7 +1703,7 @@ def perturbed_mediums_copy( electron_density: CustomSpatialDataType = None, hole_density: CustomSpatialDataType = None, interp_method: InterpMethod = "linear", - ) -> Scene: + ) -> Self: """Return a copy of the scene with heat and/or charge data applied to all mediums that have perturbation models specified. That is, such mediums will be replaced with spatially dependent custom mediums that reflect perturbation effects. 
Any of temperature, diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 882aeeef4f..2d25f42bb0 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -6,24 +6,30 @@ import pathlib from abc import ABC, abstractmethod from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple, Union, get_args +from typing import Optional, Union, get_args import autograd.numpy as np +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + field_validator, + model_validator, +) try: import matplotlib as mpl except ImportError: pass - -import pydantic.v1 as pydantic import xarray as xr from ..constants import C_0, SECOND, fp_eps, inf from ..exceptions import SetupError, Tidy3dError, Tidy3dImportError, ValidationError from ..log import log from ..updater import Updater -from .base import cached_property, skip_if_fields_missing +from .base import cached_property from .base_sim.simulation import AbstractSimulation from .boundary import ( PML, @@ -111,7 +117,7 @@ Literal, PermittivityComponent, Symmetry, - annotate_type, + discriminated_union, ) from .validators import ( assert_objects_contained_in_sim_bounds, @@ -131,6 +137,11 @@ plot_sim_3d, ) +try: + import matplotlib as mpl +except ImportError: + pass + try: gdstk_available = True import gdstk @@ -182,13 +193,13 @@ def validate_boundaries_for_zero_dims(): """Error if absorbing boundaries, bloch boundaries, unmatching pec/pmc, or symmetry is used along a zero dimension.""" - @pydantic.validator("boundary_spec", allow_reuse=True, always=True) - @skip_if_fields_missing(["size", "symmetry"]) - def boundaries_for_zero_dims(cls, val, values): + @model_validator(mode="after") + def boundaries_for_zero_dims(self): """Error if absorbing boundaries, bloch boundaries, unmatching pec/pmc, or symmetry is used along a zero dimension.""" + val = self.boundary_spec boundaries = val.to_list - size = values.get("size") - symmetry 
= values.get("symmetry") + size = self.size + symmetry = self.symmetry axis_names = "xyz" for dim, (boundary, symmetry_dim, size_dim) in enumerate(zip(boundaries, symmetry, size)): @@ -227,7 +238,7 @@ def boundaries_for_zero_dims(cls, val, values): "minus must be the same." ) - return val + return self return boundaries_for_zero_dims @@ -237,7 +248,7 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): Abstract class for a simulation involving electromagnetic fields defined on a Yee grid. """ - lumped_elements: Tuple[LumpedElementType, ...] = pydantic.Field( + lumped_elements: tuple[LumpedElementType, ...] = Field( (), title="Lumped Elements", description="Tuple of lumped elements in the simulation. " @@ -247,8 +258,8 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): Tuple of lumped elements in the simulation. """ - grid_spec: GridSpec = pydantic.Field( - GridSpec(), + grid_spec: GridSpec = Field( + default_factory=GridSpec, title="Grid Specification", description="Specifications for the simulation grid along each of the three directions.", ) @@ -287,8 +298,8 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): * `Using automatic nonuniform meshing <../../notebooks/AutoGrid.html>`_ """ - subpixel: Union[bool, SubpixelSpec] = pydantic.Field( - SubpixelSpec(), + subpixel: Union[bool, SubpixelSpec] = Field( + default_factory=SubpixelSpec, title="Subpixel Averaging", description="Apply subpixel averaging methods of the permittivity on structure interfaces " "to result in much higher accuracy for a given grid size. 
Supply a :class:`SubpixelSpec` " @@ -298,16 +309,14 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): ", or ``False`` to apply staircasing.", ) - simulation_type: Optional[Literal["autograd_fwd", "autograd_bwd", "tidy3d", None]] = ( - pydantic.Field( - "tidy3d", - title="Simulation Type", - description="Tag used internally to distinguish types of simulations for " - "``autograd`` gradient processing.", - ) + simulation_type: Optional[Literal["autograd_fwd", "autograd_bwd", "tidy3d"]] = Field( + "tidy3d", + title="Simulation Type", + description="Tag used internally to distinguish types of simulations for " + "``autograd`` gradient processing.", ) - post_norm: Union[float, FreqDataArray] = pydantic.Field( + post_norm: Union[float, FreqDataArray] = Field( 1.0, title="Post Normalization Values", description="Factor to multiply the fields by after running, " @@ -358,21 +367,18 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): * `Dielectric constant assignment on Yee grids `_ """ - @pydantic.validator("simulation_type", always=True) - def _validate_simulation_type_tidy3d(cls, val): + @field_validator("simulation_type") + def _validate_simulation_type_tidy3d(val, mode="before"): """Enforce the simulation_type is 'tidy3d' if passed as None for bkwrds compatibility.""" - if val is None: - return "tidy3d" - return val + return "tidy3d" if val is None else val - @pydantic.validator("lumped_elements", always=True) - @skip_if_fields_missing(["structures"]) - def _validate_num_lumped_elements(cls, val, values): + @model_validator(mode="after") + def _validate_num_lumped_elements(self): """Error if too many lumped elements present.""" - + val = self.lumped_elements if val is None: - return val - structures = values.get("structures") + return self + structures = self.structures mediums = {structure.medium for structure in structures} total_num_mediums = len(val) + len(mediums) if total_num_mediums > MAX_NUM_MEDIUMS: @@ -381,24 +387,22 @@ def 
_validate_num_lumped_elements(cls, val, values): f"{total_num_mediums} were supplied." ) - return val + return self - @pydantic.validator("lumped_elements") - @skip_if_fields_missing(["size"]) - def _check_3d_simulation_with_lumped_elements(cls, val, values): + @model_validator(mode="after") + def _check_3d_simulation_with_lumped_elements(self): """Error if Simulation contained lumped elements and is not a 3D simulation""" - size = values.get("size") + val = self.lumped_elements + size = self.size if val and size.count(0.0) > 0: raise ValidationError( - f"'{cls.__name__}' must be a 3D simulation when a 'LumpedElement' is present." + f"'{self.__cls__.__name__}' must be a 3D simulation when a 'LumpedElement' is present." ) - return val + return self - @pydantic.validator("grid_spec", always=True) @abstractmethod - def _validate_auto_grid_wavelength(cls, val, values): + def _validate_auto_grid_wavelength(val): """Check that wavelength can be defined if there is auto grid spec.""" - pass def _monitor_num_cells(self, monitor: Monitor) -> int: """Total number of cells included in monitor based on simulation grid.""" @@ -418,18 +422,19 @@ def num_cells_in_monitor(monitor: Monitor) -> int: return sum(num_cells_in_monitor(mnt) for mnt in monitor.integration_surfaces) return num_cells_in_monitor(monitor) - @pydantic.validator("boundary_spec") - def _validate_boundary_spec_symmetry(cls, val, values): + @model_validator(mode="after") + def _validate_boundary_spec_symmetry(self): """Error if symmetry is imposed along an axis but the boundary conditions are not the same on both sides.""" + val = self.boundary_spec boundaries = [val.x, val.y, val.z] - for ax, symmetry, ax_bounds in zip("xyz", values.get("symmetry"), boundaries): + for ax, symmetry, ax_bounds in zip("xyz", self.symmetry, boundaries): if symmetry != 0 and ax_bounds.plus != ax_bounds.minus: raise ValidationError( f"Symmetry '{symmetry}' along axis {ax} requires the same boundary " f"condition on both sides of the 
axis." ) - return val + return self @cached_property def _subpixel(self) -> SubpixelSpec: @@ -455,8 +460,8 @@ def plot( source_alpha: float = None, monitor_alpha: float = None, lumped_element_alpha: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, fill_structures: bool = True, **patch_kwargs, ) -> Ax: @@ -480,9 +485,9 @@ def plot( Opacity of the lumped elements. If ``None``, uses Tidy3d default. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -537,8 +542,8 @@ def plot_eps( source_alpha: float = None, monitor_alpha: float = None, lumped_element_alpha: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ax: Ax = None, eps_component: Optional[PermittivityComponent] = None, ) -> Ax: @@ -569,9 +574,9 @@ def plot_eps( Opacity of the lumped elements. If ``None``, uses Tidy3d default. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
eps_component : Optional[PermittivityComponent] = None Component of the permittivity tensor to plot for anisotropic materials, @@ -640,8 +645,8 @@ def plot_structures_eps( cbar: bool = True, reverse: bool = False, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, eps_component: Optional[PermittivityComponent] = None, ) -> Ax: """Plot each of simulation's structures on a plane defined by one nonzero x,y,z coordinate. @@ -670,9 +675,9 @@ def plot_structures_eps( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. eps_component : Optional[PermittivityComponent] = None Component of the permittivity tensor to plot for anisotropic materials, @@ -722,8 +727,8 @@ def plot_pml( x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ax: Ax = None, ) -> Ax: """Plot each of simulation's absorbing boundaries @@ -737,9 +742,9 @@ def plot_pml( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. @@ -764,7 +769,7 @@ def plot_pml( # candidate for removal in 3.0 @cached_property - def bounds_pml(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]: + def bounds_pml(self) -> tuple[tuple[float, float, float], tuple[float, float, float]]: """Simulation bounds including the PML regions.""" log.warning( "'Simulation.bounds_pml' will be removed in Tidy3D 3.0. " @@ -773,7 +778,7 @@ def bounds_pml(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, fl return self.simulation_bounds @cached_property - def simulation_bounds(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]: + def simulation_bounds(self) -> tuple[tuple[float, float, float], tuple[float, float, float]]: """Simulation bounds including the PML regions.""" pml_thick = self.pml_thicknesses bounds_in = self.bounds @@ -782,7 +787,7 @@ def simulation_bounds(self) -> Tuple[Tuple[float, float, float], Tuple[float, fl return (bounds_min, bounds_max) - def _make_pml_boxes(self, normal_axis: Axis) -> List[Box]: + def _make_pml_boxes(self, normal_axis: Axis) -> list[Box]: """make a list of Box objects representing the pml to plot on plane.""" pml_boxes = [] pml_thicks = self.pml_thicknesses @@ -815,7 +820,7 @@ def _make_pml_box(self, pml_axis: Axis, pml_height: float, sign: int) -> Box: return pml_box # candidate for removal in 3.0 - def eps_bounds(self, freq: float = None) -> Tuple[float, float]: + def eps_bounds(self, freq: float = None) -> tuple[float, float]: """Compute range of (real) permittivity present in the simulation at frequency "freq".""" log.warning( @@ -825,12 +830,12 @@ def eps_bounds(self, freq: float = None) -> Tuple[float, float]: return self.scene.eps_bounds(freq=freq) @cached_property - def pml_thicknesses(self) -> List[Tuple[float, float]]: + def pml_thicknesses(self) -> list[tuple[float, float]]: """Thicknesses (um) of absorbers in all three 
axes and directions (-, +) Returns ------- - List[Tuple[float, float]] + list[tuple[float, float]] List containing the absorber thickness (micron) in - and + boundaries. """ num_layers = self.num_pml_layers @@ -843,12 +848,12 @@ def pml_thicknesses(self) -> List[Tuple[float, float]]: return pml_thicknesses @cached_property - def internal_override_structures(self) -> List[MeshOverrideStructure]: + def internal_override_structures(self) -> list[MeshOverrideStructure]: """Internal mesh override structures. So far, internal override structures all come from `layer_refinement_specs`. Returns ------- - List[MeshOverrideSructure] + list[MeshOverrideSructure] List of override structures. """ wavelength = self.grid_spec.get_wavelength(self.sources) @@ -860,12 +865,12 @@ def internal_override_structures(self) -> List[MeshOverrideStructure]: ) @cached_property - def internal_snapping_points(self) -> List[CoordinateOptional]: + def internal_snapping_points(self) -> list[CoordinateOptional]: """Internal snapping points. So far, internal snapping points are generated by `layer_refinement_specs`. Returns ------- - List[CoordinateOptional] + list[CoordinateOptional] List of snapping points coordinates. """ return self.grid_spec.internal_snapping_points( @@ -879,8 +884,8 @@ def plot_lumped_elements( x: float = None, y: float = None, z: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, alpha: float = None, ax: Ax = None, ) -> Ax: @@ -895,9 +900,9 @@ def plot_lumped_elements( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. 
- vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. alpha : float = None Opacity of the lumped element, If ``None`` uses Tidy3d default. @@ -925,8 +930,8 @@ def plot_grid( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, override_structures_alpha: float = 1, snapping_points_alpha: float = 1, **kwargs, @@ -941,9 +946,9 @@ def plot_grid( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. override_structures_alpha : float = 1 Opacity of the override structures. @@ -1193,12 +1198,12 @@ def set_plot_params(boundary_edge, lim, side, thickness): # return plot_sim_3d(self, width=width, height=height) @cached_property - def _grid_and_snapping_lines(self) -> Tuple[Grid, List[CoordinateOptional]]: + def _grid_and_snapping_lines(self) -> tuple[Grid, list[CoordinateOptional]]: """FDTD grid spatial locations and information. Returns ------- - Tuple[:class:`.Grid`, List[CoordinateOptional]] + tuple[:class:`.Grid`, list[CoordinateOptional]] :class:`.Grid` storing the spatial locations relevant to the simulation the list of snapping points generated during iterative gap meshing. 
""" @@ -1250,12 +1255,12 @@ def grid(self) -> Grid: return grid @cached_property - def _gap_meshing_snapping_lines(self) -> List[CoordinateOptional]: + def _gap_meshing_snapping_lines(self) -> list[CoordinateOptional]: """Snapping points resulted from iterative gap meshing. Returns ------- - List[CoordinateOptional] + list[CoordinateOptional] List of snapping lines resolving thin gaps and strips. """ @@ -1281,7 +1286,7 @@ def num_cells(self) -> int: return np.prod(self.grid.num_cells, dtype=np.int64) @cached_property - def grid_info(self) -> Dict: + def grid_info(self) -> dict: """Dictionary collecting various properties of the grids in the simulation.""" return self.grid.info @@ -1302,7 +1307,7 @@ def _subgrid(self, span_inds: np.ndarray, grid: Grid = None): return Grid(boundaries=Coords(**boundary_dict)) @cached_property - def _periodic(self) -> Tuple[bool, bool, bool]: + def _periodic(self) -> tuple[bool, bool, bool]: """For each dimension, ``True`` if periodic/Bloch boundaries and ``False`` otherwise. We check on both sides but in practice there should be no cases in which a periodic/Bloch BC is on one side only. This is explicitly validated for Bloch, and implicitly done for @@ -1314,12 +1319,12 @@ def _periodic(self) -> Tuple[bool, bool, bool]: return periodic @cached_property - def num_pml_layers(self) -> List[Tuple[float, float]]: + def num_pml_layers(self) -> list[tuple[float, float]]: """Number of absorbing layers in all three axes and directions (-, +). Returns ------- - List[Tuple[float, float]] + list[tuple[float, float]] List containing the number of absorber layers in - and + boundaries. 
""" num_layers = [[0, 0], [0, 0], [0, 0]] @@ -1559,7 +1564,7 @@ def make_eps_data(coords: Coords): coords = grid[coord_key] return make_eps_data(coords) - def _volumetric_structures_grid(self, grid: Grid) -> Tuple[Structure]: + def _volumetric_structures_grid(self, grid: Grid) -> tuple[Structure]: """Generate a tuple of structures wherein any 2D materials are converted to 3D volumetric equivalents, using ``grid`` as the simulation grid.""" @@ -1569,7 +1574,7 @@ def _volumetric_structures_grid(self, grid: Grid) -> Tuple[Structure]: ): return self.structures - def get_dls(geom: Geometry, axis: Axis, num_dls: int) -> List[float]: + def get_dls(geom: Geometry, axis: Axis, num_dls: int) -> list[float]: """Get grid size around the 2D material.""" dls = self._discretize_grid(Box.from_bounds(*geom.bounds), grid=grid).sizes.to_list[ axis @@ -1657,12 +1662,12 @@ def snap_to_grid(geom: Geometry, axis: Axis) -> Geometry: return tuple(new_structures) @cached_property - def volumetric_structures(self) -> Tuple[Structure]: + def volumetric_structures(self) -> tuple[Structure]: """Generate a tuple of structures wherein any 2D materials are converted to 3D volumetric equivalents.""" return self._volumetric_structures_grid(self.grid) - def suggest_mesh_overrides(self, **kwargs) -> List[MeshOverrideStructure]: + def suggest_mesh_overrides(self, **kwargs) -> list[MeshOverrideStructure]: """Generate a :class:`.MeshOverrideStructure` `List` which is automatically generated from structures in the simulation. """ @@ -1679,9 +1684,9 @@ def subsection( region: Box, boundary_spec: BoundarySpec = None, grid_spec: Union[GridSpec, Literal["identical"]] = None, - symmetry: Tuple[Symmetry, Symmetry, Symmetry] = None, - sources: Tuple[SourceType, ...] = None, - monitors: Tuple[MonitorType, ...] = None, + symmetry: tuple[Symmetry, Symmetry, Symmetry] = None, + sources: tuple[SourceType, ...] = None, + monitors: tuple[MonitorType, ...] 
= None, remove_outside_structures: bool = True, remove_outside_custom_mediums: bool = False, include_pml_cells: bool = False, @@ -1703,14 +1708,14 @@ def subsection( simulation. If ``identical``, then the original grid is transferred directly as a :class:`.CustomGrid`. Note that in the latter case the region of the new simulation is snapped to the original grid lines. - symmetry : Tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None + symmetry : tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None New simulation symmetry. If ``None``, then it is inherited from the original simulation. Note that in this case the size and placement of new simulation domain must be commensurate with the original symmetry. - sources : Tuple[SourceType, ...] = None + sources : tuple[SourceType, ...] = None New list of sources. If ``None``, then the sources intersecting the new simulation domain are inherited from the original simulation. - monitors : Tuple[MonitorType, ...] = None + monitors : tuple[MonitorType, ...] = None New list of monitors. If ``None``, then the monitors intersecting the new simulation domain are inherited from the original simulation. 
remove_outside_structures : bool = True @@ -1868,10 +1873,10 @@ def subsection( size=new_box.size, grid_spec=grid_spec, boundary_spec=boundary_spec, - monitors=[], - sources=sources, # need wavelength in case of auto grid - symmetry=symmetry, - structures=aux_new_structures, + monitors=(), + sources=tuple(sources), # need wavelength in case of auto grid + symmetry=tuple(symmetry), + structures=tuple(aux_new_structures), deep=deep_copy, ) @@ -1929,11 +1934,11 @@ def subsection( medium=new_sim_medium, grid_spec=grid_spec, boundary_spec=boundary_spec, - monitors=monitors, - sources=sources, - symmetry=symmetry, - structures=aux_new_structures, - lumped_elements=new_lumped_elements, + monitors=tuple(monitors), + sources=tuple(sources), + symmetry=tuple(symmetry), + structures=tuple(aux_new_structures), + lumped_elements=tuple(new_lumped_elements), **kwargs, ) @@ -1942,7 +1947,9 @@ def subsection( # 1) Perform validators not directly related to geometries new_sim = self.updated_copy(**new_sim_dict, deep=deep_copy, validate=True) # 2) Assemble the full simulation without validation - return new_sim.updated_copy(structures=new_structures, deep=deep_copy, validate=False) + return new_sim.updated_copy( + structures=tuple(new_structures), deep=deep_copy, validate=False + ) class Simulation(AbstractYeeGridSimulation): @@ -2048,8 +2055,8 @@ class Simulation(AbstractYeeGridSimulation): * `FDTD Walkthrough `_ """ - boundary_spec: BoundarySpec = pydantic.Field( - BoundarySpec(), + boundary_spec: BoundarySpec = Field( + default_factory=BoundarySpec, title="Boundaries", description="Specification of boundary conditions along each dimension. 
If ``None``, " "PML boundary conditions are applied on all sides.", @@ -2092,7 +2099,7 @@ class Simulation(AbstractYeeGridSimulation): * `Using FDTD to Compute a Transmission Spectrum `__ """ - courant: float = pydantic.Field( + courant: float = Field( 0.99, title="Normalized Courant Factor", description="Normalized Courant stability factor that is no larger than 1 when CFL " @@ -2174,7 +2181,7 @@ class Simulation(AbstractYeeGridSimulation): * `Numerical dispersion in FDTD `_ """ - lumped_elements: Tuple[LumpedElementType, ...] = pydantic.Field( + lumped_elements: tuple[LumpedElementType, ...] = Field( (), title="Lumped Elements", description="Tuple of lumped elements in the simulation. ", @@ -2212,8 +2219,8 @@ class Simulation(AbstractYeeGridSimulation): * `Using lumped elements in Tidy3D simulations <../../notebooks/LinearLumpedElements.html>`_ """ - grid_spec: GridSpec = pydantic.Field( - GridSpec(), + grid_spec: GridSpec = Field( + default_factory=GridSpec, title="Grid Specification", description="Specifications for the simulation grid along each of the three directions.", ) @@ -2359,8 +2366,8 @@ class Simulation(AbstractYeeGridSimulation): * `Numerical dispersion in FDTD `_ """ - medium: MediumType3D = pydantic.Field( - Medium(), + medium: MediumType3D = Field( + default_factory=Medium, title="Background Medium", description="Background medium of simulation, defaults to vacuum if not specified.", discriminator=TYPE_TAG_STR, @@ -2391,7 +2398,7 @@ class Simulation(AbstractYeeGridSimulation): """ - normalize_index: Union[pydantic.NonNegativeInt, None] = pydantic.Field( + normalize_index: Optional[NonNegativeInt] = Field( 0, title="Normalization index", description="Index of the source in the tuple of sources whose spectrum will be used to " @@ -2403,7 +2410,7 @@ class Simulation(AbstractYeeGridSimulation): data. If ``None``, the raw field data is returned. If ``None``, the raw field data is returned unnormalized. 
""" - monitors: Tuple[annotate_type(MonitorType), ...] = pydantic.Field( + monitors: tuple[discriminated_union(MonitorType), ...] = Field( (), title="Monitors", description="Tuple of monitors in the simulation. " @@ -2419,7 +2426,7 @@ class Simulation(AbstractYeeGridSimulation): All the monitor implementations. """ - sources: Tuple[annotate_type(SourceType), ...] = pydantic.Field( + sources: tuple[discriminated_union(SourceType), ...] = Field( (), title="Sources", description="Tuple of electric current sources injecting fields into the simulation.", @@ -2456,7 +2463,7 @@ class Simulation(AbstractYeeGridSimulation): Frequency and time domain source models. """ - shutoff: pydantic.NonNegativeFloat = pydantic.Field( + shutoff: NonNegativeFloat = Field( 1e-5, title="Shutoff Condition", description="Ratio of the instantaneous integrated E-field intensity to the maximum value " @@ -2471,7 +2478,7 @@ class Simulation(AbstractYeeGridSimulation): Set to ``0`` to disable this feature. """ - structures: Tuple[Structure, ...] = pydantic.Field( + structures: tuple[Structure, ...] = Field( (), title="Structures", description="Tuple of structures present in simulation. " @@ -2536,7 +2543,7 @@ class Simulation(AbstractYeeGridSimulation): * `Structures `_ """ - symmetry: Tuple[Symmetry, Symmetry, Symmetry] = pydantic.Field( + symmetry: tuple[Symmetry, Symmetry, Symmetry] = Field( (0, 0, 0), title="Symmetries", description="Tuple of integers defining reflection symmetry across a plane " @@ -2566,8 +2573,7 @@ class Simulation(AbstractYeeGridSimulation): """ # TODO: at a later time (once well tested) we could consider making default of RunTimeSpec() - run_time: Union[pydantic.PositiveFloat, RunTimeSpec] = pydantic.Field( - ..., + run_time: Union[PositiveFloat, RunTimeSpec] = Field( title="Run Time", description="Total electromagnetic evolution time in seconds. 
" "Note: If simulation 'shutoff' is specified, " @@ -2629,25 +2635,25 @@ class Simulation(AbstractYeeGridSimulation): """ Validating setup """ - @pydantic.root_validator(pre=True) - def _update_simulation(cls, values): + @model_validator(mode="before") + def _update_simulation(data): """Update the simulation if it is an earlier version.""" # if no version, assume it's already updated - if "version" not in values: - return values + if "version" not in data: + return data # otherwise, call the updator to update the values dictionary - updater = Updater(sim_dict=values) + updater = Updater(sim_dict=data) return updater.update_to_current() - @pydantic.validator("grid_spec", always=True) - @skip_if_fields_missing(["sources"]) - def _validate_auto_grid_wavelength(cls, val, values): + @model_validator(mode="after") + def _validate_auto_grid_wavelength(self): """Check that wavelength can be defined if there is auto grid spec.""" + val = self.grid_spec if val.wavelength is None and val.auto_grid_used: - _ = val.wavelength_from_sources(sources=values.get("sources")) - return val + _ = val.wavelength_from_sources(sources=self.sources) + return self _sources_in_bounds = assert_objects_in_sim_bounds("sources", strict_inequality=True) _lumped_elements_in_bounds = assert_objects_contained_in_sim_bounds( @@ -2663,34 +2669,33 @@ def _validate_auto_grid_wavelength(cls, val, values): # _resolution_fine_enough = validate_resolution() # _plane_waves_in_homo = validate_plane_wave_intersections() - @pydantic.validator("boundary_spec", always=True) - @skip_if_fields_missing(["symmetry"]) - def bloch_with_symmetry(cls, val, values): + @model_validator(mode="after") + def bloch_with_symmetry(self): """Error if a Bloch boundary is applied with symmetry""" + val = self.boundary_spec boundaries = val.to_list - symmetry = values.get("symmetry") + symmetry = self.symmetry for dim, boundary in enumerate(boundaries): num_bloch = sum(isinstance(bnd, BlochBoundary) for bnd in boundary) if num_bloch 
> 0 and symmetry[dim] != 0: raise SetupError( f"Bloch boundaries cannot be used with a symmetry along dimension {dim}." ) - return val + return self - @pydantic.validator("boundary_spec", always=True) - @skip_if_fields_missing(["medium", "size", "structures", "sources"]) - def plane_wave_boundaries(cls, val, values): + @model_validator(mode="after") + def plane_wave_boundaries(self): """Error if there are plane wave sources incompatible with boundary conditions.""" - boundaries = val.to_list - sources = values.get("sources") - size = values.get("size") - sim_medium = values.get("medium") - structures = values.get("structures") + boundaries = self.boundary_spec.to_list + sources = self.sources + size = self.size + sim_medium = self.medium + structures = self.structures for source_ind, source in enumerate(sources): if not isinstance(source, PlaneWave): continue - _, tan_dirs = cls.pop_axis([0, 1, 2], axis=source.injection_axis) + _, tan_dirs = self.pop_axis([0, 1, 2], axis=source.injection_axis) medium_set = Scene.intersecting_media(source, structures) medium = medium_set.pop() if medium_set else sim_medium @@ -2721,7 +2726,7 @@ def plane_wave_boundaries(cls, val, values): else: num_bloch = sum(isinstance(bnd, (Periodic, BlochBoundary)) for bnd in boundary) if num_bloch > 0: - cls._check_bloch_vec( + self._check_bloch_vec( source=source, source_ind=source_ind, bloch_vec=boundary[0].bloch_vec, @@ -2729,28 +2734,27 @@ def plane_wave_boundaries(cls, val, values): medium=medium, domain_size=size[tan_dir], ) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["boundary_spec", "medium", "size", "structures", "sources"]) - def bloch_boundaries_diff_mnt(cls, val, values): + @model_validator(mode="after") + def bloch_boundaries_diff_mnt(self): """Error if there are diffraction monitors incompatible with boundary conditions.""" - monitors = val + monitors = self.monitors - if not val or not any(isinstance(mnt, 
DiffractionMonitor) for mnt in monitors): - return val + if not monitors or not any(isinstance(mnt, DiffractionMonitor) for mnt in monitors): + return self - boundaries = values.get("boundary_spec").to_list - sources = values.get("sources") - size = values.get("size") - sim_medium = values.get("medium") - structures = values.get("structures") + boundaries = self.boundary_spec.to_list + sources = self.sources + size = self.size + sim_medium = self.medium + structures = self.structures for source_ind, source in enumerate(sources): if not isinstance(source, PlaneWave): continue - _, tan_dirs = cls.pop_axis([0, 1, 2], axis=source.injection_axis) + _, tan_dirs = self.pop_axis([0, 1, 2], axis=source.injection_axis) medium_set = Scene.intersecting_media(source, structures) medium = medium_set.pop() if medium_set else sim_medium @@ -2760,7 +2764,7 @@ def bloch_boundaries_diff_mnt(cls, val, values): # check the Bloch boundary + angled plane wave case num_bloch = sum(isinstance(bnd, (Periodic, BlochBoundary)) for bnd in boundary) if num_bloch > 0: - cls._check_bloch_vec( + self._check_bloch_vec( source=source, source_ind=source_ind, bloch_vec=boundary[0].bloch_vec, @@ -2769,18 +2773,17 @@ def bloch_boundaries_diff_mnt(cls, val, values): domain_size=size[tan_dir], has_diff_mnt=True, ) - return val + return self - @pydantic.validator("boundary_spec", always=True) - @skip_if_fields_missing(["medium", "center", "size", "structures", "sources"]) - def tfsf_boundaries(cls, val, values): + @model_validator(mode="after") + def tfsf_boundaries(self): """Error if the boundary conditions are incompatible with TFSF sources, if any.""" - boundaries = val.to_list - sources = values.get("sources") - size = values.get("size") - center = values.get("center") - sim_medium = values.get("medium") - structures = values.get("structures") + boundaries = self.boundary_spec.to_list + sources = self.sources + size = self.size + center = self.center + sim_medium = self.medium + structures = 
self.structures sim_bounds = [ [c - s / 2.0 for c, s in zip(center, size)], [c + s / 2.0 for c, s in zip(center, size)], @@ -2789,7 +2792,7 @@ def tfsf_boundaries(cls, val, values): if not isinstance(source, TFSF): continue - norm_dir, tan_dirs = cls.pop_axis([0, 1, 2], axis=source.injection_axis) + norm_dir, tan_dirs = self.pop_axis([0, 1, 2], axis=source.injection_axis) src_bounds = source.bounds # make a dummy source that represents the injection surface to get the intersecting @@ -2826,7 +2829,7 @@ def tfsf_boundaries(cls, val, values): # Bloch vector has been correctly set, similar to the check for plane waves num_bloch = sum(isinstance(bnd, (Periodic, BlochBoundary)) for bnd in boundary) if num_bloch == 2: - cls._check_bloch_vec( + self._check_bloch_vec( source=source, source_ind=src_idx, bloch_vec=boundary[0].bloch_vec, @@ -2843,33 +2846,30 @@ def tfsf_boundaries(cls, val, values): "unless that boundary is 'Periodic' or 'BlochBoundary'." ) - return val + return self - @pydantic.validator("sources", always=True) - @skip_if_fields_missing(["symmetry"]) - def tfsf_with_symmetry(cls, val, values): + @model_validator(mode="after") + def tfsf_with_symmetry(self): """Error if a TFSF source is applied with symmetry""" - symmetry = values.get("symmetry") - for source in val: - if isinstance(source, TFSF) and not all(sym == 0 for sym in symmetry): + for source in self.sources: + if isinstance(source, TFSF) and not all(sym == 0 for sym in self.symmetry): raise SetupError("TFSF sources cannot be used with symmetries.") - return val + return self @staticmethod - def _get_fixed_angle_sources(sources: Tuple[SourceType, ...]) -> Tuple[SourceType, ...]: + def _get_fixed_angle_sources(sources: tuple[SourceType, ...]) -> tuple[SourceType, ...]: """Get list of plane wave sources with ``FixedAngleSpec``.""" return [ source for source in sources if isinstance(source, PlaneWave) and source._is_fixed_angle ] - @pydantic.root_validator() - @skip_if_fields_missing(["sources", 
"structures", "medium", "monitors"], root=True) - def check_fixed_angle_components(cls, values): + @model_validator(mode="after") + def check_fixed_angle_components(self): """Error if a fixed-angle plane wave is combined with other sources or fully anisotropic mediums or gain mediums.""" - fixed_angle_sources = cls._get_fixed_angle_sources(values["sources"]) + fixed_angle_sources = self._get_fixed_angle_sources(self.sources) if len(fixed_angle_sources) > 0: if len(fixed_angle_sources) > 1: @@ -2877,9 +2877,9 @@ def check_fixed_angle_components(cls, values): "A fixed-angle plane wave source cannot be combined with other sources." ) - structures = values.get("structures") + structures = self.structures structures = structures or [] - medium_bg = values.get("medium") + medium_bg = self.medium mediums = [medium_bg] + [structure.medium for structure in structures] if any(med.is_fully_anisotropic for med in mediums): @@ -2902,13 +2902,60 @@ def check_fixed_angle_components(cls, values): "Fixed-angle plane wave sources cannot be used in the presence of gain materials." 
) - if any(isinstance(mnt, TimeMonitor) for mnt in values["monitors"]): + if any(isinstance(mnt, TimeMonitor) for mnt in self.monitors): raise SetupError("Time monitors cannot be used in fixed-angle simulations.") - return values + return self + + @model_validator(mode="after") + def boundaries_for_zero_dims(self): + """Error if absorbing boundaries, bloch boundaries, unmatching pec/pmc, or symmetry is used along a zero dimension.""" + val = self.boundary_spec + boundaries = val.to_list + size = self.size + symmetry = self.symmetry + axis_names = "xyz" + + for dim, (boundary, symmetry_dim, size_dim) in enumerate(zip(boundaries, symmetry, size)): + if size_dim == 0: + axis = axis_names[dim] + num_absorbing_bdries = sum(isinstance(bnd, AbsorberSpec) for bnd in boundary) + num_bloch_bdries = sum(isinstance(bnd, BlochBoundary) for bnd in boundary) + + if num_absorbing_bdries > 0: + raise SetupError( + f"The simulation has zero size along the {axis} axis, so " + "using a PML or absorbing boundary along that axis is incorrect. " + f"Use either 'Periodic' or 'BlochBoundary' along {axis}." + ) + + if num_bloch_bdries > 0: + raise SetupError( + f"The simulation has zero size along the {axis} axis, " + "using a Bloch boundary along such an axis is not supported because of " + "the Bloch vector definition in units of '2 * pi / (size along dimension)'. Use a small " + "but nonzero size along the dimension instead." + ) + + if symmetry_dim != 0: + raise SetupError( + f"The simulation has zero size along the {axis} axis, so " + "using symmetry along that axis is incorrect. Use 'PECBoundary' " + "or 'PMCBoundary' to select source polarization if needed and set " + f"Simulation.symmetry to 0 along {axis}." + ) - @pydantic.validator("sources", always=True) - def _validate_num_sources(cls, val): + if boundary[0] != boundary[1]: + raise SetupError( + f"The simulation has zero size along the {axis} axis. 
" + f"The boundary condition for {axis} plus and {axis} " + "minus must be the same." + ) + + return self + + @field_validator("sources") + def _validate_num_sources(val): """Error if too many sources present.""" if val is None: @@ -2923,8 +2970,8 @@ def _validate_num_sources(cls, val): return val - @pydantic.validator("structures", always=True) - def _validate_2d_geometry_has_2d_medium(cls, val, values): + @field_validator("structures") + def _validate_2d_geometry_has_2d_medium(val): """Warn if a geometry bounding box has zero size in a certain dimension.""" if val is None: @@ -2947,8 +2994,8 @@ def _validate_2d_geometry_has_2d_medium(cls, val, values): return val - @pydantic.validator("structures", always=True) - def _validate_incompatible_material_intersections(cls, val, values): + @field_validator("structures") + def _validate_incompatible_material_intersections(val): """Check for intersections of incompatible materials.""" structures = val incompatible_indices = [] @@ -2977,20 +3024,20 @@ def _validate_incompatible_material_intersections(cls, val, values): ) return val - @pydantic.validator("boundary_spec", always=True) - @skip_if_fields_missing(["sources", "center", "size", "structures"]) - def _structures_not_close_pml(cls, val, values): + @model_validator(mode="after") + def _structures_not_close_pml(self): """Warn if any structures lie at the simulation boundaries.""" + val = self.boundary_spec - sim_box = Box(size=values.get("size"), center=values.get("center")) + sim_box = Box(size=self.size, center=self.center) sim_bound_min, sim_bound_max = sim_box.bounds boundaries = val.to_list - structures = values.get("structures") - sources = values.get("sources") + structures = self.structures + sources = self.sources if (not structures) or (not sources): - return val + return self with log as consolidated_logger: @@ -3034,19 +3081,18 @@ def warn(istruct, side): ): warn(istruct, axis + "-max") - return val + return self - @pydantic.validator("monitors", 
always=True) - @skip_if_fields_missing(["medium", "structures"]) - def _warn_monitor_mediums_frequency_range(cls, val, values): + @model_validator(mode="after") + def _warn_monitor_mediums_frequency_range(self): """Warn user if any DFT monitors have frequencies outside of medium frequency range.""" + val = self.monitors if val is None: - return val + return self - structures = values.get("structures") - structures = structures or [] - medium_bg = values.get("medium") + structures = self.structures or [] + medium_bg = self.medium mediums = [medium_bg] + [structure.medium for structure in structures] with log as consolidated_logger: @@ -3064,7 +3110,7 @@ def _warn_monitor_mediums_frequency_range(cls, val, values): # make sure medium frequency range includes all monitor frequencies fmin_med, fmax_med = medium.frequency_range - sci_fmin_med, sci_fmax_med = cls._scientific_notation(fmin_med, fmax_med) + sci_fmin_med, sci_fmax_med = self._scientific_notation(fmin_med, fmax_med) if fmin_mon < fmin_med or fmax_mon > fmax_med: if medium_index == 0: @@ -3088,27 +3134,26 @@ def _warn_monitor_mediums_frequency_range(cls, val, values): "This can cause inaccuracies in the recorded results.", custom_loc=custom_loc, ) + return self - return val - - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["sources"]) - def _warn_monitor_simulation_frequency_range(cls, val, values): + @model_validator(mode="after") + def _warn_monitor_simulation_frequency_range(self): """Warn if any DFT monitors have frequencies outside of the simulation frequency range.""" + val = self.monitors if val is None: - return val + return self - source_ranges = [source.source_time.frequency_range() for source in values["sources"]] + source_ranges = [source.source_time.frequency_range() for source in self.sources] if not source_ranges: # Commented out to eliminate this message from Mode real time log in GUI # TODO: Bring it back when it doesn't interfere with mode solver # log.info("No 
sources in simulation.") - return val + return self freq_min = min((freq_range[0] for freq_range in source_ranges), default=0.0) freq_max = max((freq_range[1] for freq_range in source_ranges), default=0.0) - sci_fmin, sci_fmax = cls._scientific_notation(freq_min, freq_max) + sci_fmin, sci_fmax = self._scientific_notation(freq_min, freq_max) with log as consolidated_logger: for monitor_index, monitor in enumerate(val): @@ -3123,15 +3168,14 @@ def _warn_monitor_simulation_frequency_range(cls, val, values): "(Hz) as defined by the sources.", custom_loc=["monitors", monitor_index, "freqs"], ) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["boundary_spec"]) - def diffraction_monitor_boundaries(cls, val, values): + @model_validator(mode="after") + def diffraction_monitor_boundaries(self): """If any :class:`.DiffractionMonitor` exists, ensure boundary conditions in the transverse directions are periodic or Bloch.""" - monitors = val - boundary_spec = values.get("boundary_spec") + monitors = self.monitors + boundary_spec = self.boundary_spec for monitor in monitors: if isinstance(monitor, DiffractionMonitor): _, (n_x, n_y) = monitor.pop_axis(["x", "y", "z"], axis=monitor.normal_axis) @@ -3148,26 +3192,26 @@ def diffraction_monitor_boundaries(cls, val, values): f"The 'DiffractionMonitor' {monitor.name} requires periodic " f"or Bloch boundaries along dimensions {n_x} and {n_y}." 
) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["medium", "center", "size", "structures"]) - def _projection_monitors_homogeneous(cls, val, values): + @model_validator(mode="after") + def _projection_monitors_homogeneous(self): """Error if any field projection monitor is not in a homogeneous region.""" + val = self.monitors if val is None: - return val + return self # list of structures including background as a Box() structure_bg = Structure( geometry=Box( - size=values.get("size"), - center=values.get("center"), + size=self.size, + center=self.center, ), - medium=values.get("medium"), + medium=self.medium, ) - structures = values.get("structures") or [] + structures = self.structures or [] total_structures = [structure_bg] + list(structures) with log as consolidated_logger: @@ -3195,15 +3239,14 @@ def _projection_monitors_homogeneous(cls, val, values): custom_loc=["monitors", monitor_ind], ) - return val + return self - @pydantic.validator("monitors", always=True) - def _projection_direction(cls, val, values): + @field_validator("monitors") + def _projection_direction(val): """Warn if field projection observation points are behind surface projection monitors.""" # This validator is in simulation.py rather than monitor.py because volume monitors are # eventually converted to their bounding surface projection monitors, in which case we # do not want this validator to be triggered. 
- if val is None: return val @@ -3260,15 +3303,16 @@ def _projection_direction(cls, val, values): return val - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["size"]) - def proj_distance_for_approx(cls, val, values): + @model_validator(mode="after") + def proj_distance_for_approx(self): """Warn if projection distance for projection monitors is not large compared to monitor or, simulation size, yet far_field_approx is True.""" + val = self.monitors + if val is None: - return val + return self - sim_size = values.get("size") + sim_size = self.size with log as consolidated_logger: for monitor_ind, monitor in enumerate(val): @@ -3287,18 +3331,18 @@ def proj_distance_for_approx(cls, val, values): "size of the monitor that records near fields.", custom_loc=["monitors", monitor_ind], ) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["center", "size"]) - def _integration_surfaces_in_bounds(cls, val, values): + @model_validator(mode="after") + def _integration_surfaces_in_bounds(self): """Error if all of the integration surfaces are outside of the simulation domain.""" + val = self.monitors if val is None: - return val + return self - sim_center = values.get("center") - sim_size = values.get("size") + sim_center = self.center + sim_size = self.size sim_box = Box(size=sim_size, center=sim_center) for mnt in (mnt for mnt in val if isinstance(mnt, SurfaceIntegrationMonitor)): @@ -3308,17 +3352,17 @@ def _integration_surfaces_in_bounds(cls, val, values): "simulation bounds." 
) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["size"]) - def _projection_monitors_distance(cls, val, values): + @model_validator(mode="after") + def _projection_monitors_distance(self): """Warn if the projection distance is large for exact projections.""" + val = self.monitors if val is None: - return val + return self - sim_size = values.get("size") + sim_size = self.size with log as consolidated_logger: for idx, monitor in enumerate(val): @@ -3341,11 +3385,10 @@ def _projection_monitors_distance(cls, val, values): "available.", custom_loc=["monitors", idx, "proj_distance"], ) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["size"]) - def _projection_mnts_2d(cls, val, values): + @model_validator(mode="after") + def _projection_mnts_2d(self): """ Validate if the field projection monitor is set up for a 2D simulation and ensure the observation parameters are configured correctly. @@ -3357,16 +3400,17 @@ def _projection_mnts_2d(cls, val, values): Note: Exact far field projection is not available yet. Currently, only ``far_field_approx = True`` is supported. """ + val = self.monitors if val is None: - return val + return self - sim_size = values.get("size") + sim_size = self.size # Validation if is 3D simulation non_zero_dims = sum(1 for size in sim_size if size != 0) if non_zero_dims == 3: - return val + return self if sim_size[0] == 0: plane = "y-z" @@ -3446,15 +3490,14 @@ def _projection_mnts_2d(cls, val, values): f"'{monitor.name}' should be set to '[0]'." 
) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["medium", "structures"]) - def diffraction_and_directivity_monitor_medium(cls, val, values): + @model_validator(mode="after") + def diffraction_and_directivity_monitor_medium(self): """If any :class:`.DiffractionMonitor` or :class:`.DirectivityMonitor` exists, ensure it does not lie in a lossy medium.""" - monitors = val - structures = values.get("structures") - medium = values.get("medium") + monitors = self.monitors + structures = self.structures + medium = self.medium for monitor in monitors: if isinstance(monitor, (DiffractionMonitor, DirectivityMonitor)): medium_set = Scene.intersecting_media(monitor, structures) @@ -3465,23 +3508,23 @@ def diffraction_and_directivity_monitor_medium(cls, val, values): _, index_k = medium.nk_model(frequency=freqs) if not np.all(index_k == 0): raise SetupError(f"'{monitor.type}' must not lie in a lossy medium.") - return val + return self - @pydantic.validator("grid_spec", always=True) - @skip_if_fields_missing(["medium", "sources", "structures"]) - def _warn_grid_size_too_small(cls, val, values): + @model_validator(mode="after") + def _warn_grid_size_too_small(self): """Warn user if any grid size is too large compared to minimum wavelength in material.""" + val = self.grid_spec if val is None: - return val + return self - structures = values.get("structures") + structures = self.structures structures = structures or [] - medium_bg = values.get("medium") + medium_bg = self.medium mediums = [medium_bg] + [structure.to_static().medium for structure in structures] with log as consolidated_logger: - for source_index, source in enumerate(values.get("sources")): + for source_index, source in enumerate(self.sources): freq0 = source.source_time.freq0 for medium_index, medium in enumerate(mediums): @@ -3526,28 +3569,28 @@ def _warn_grid_size_too_small(cls, val, values): ) # TODO: warn about custom grid spec - return val + return self - 
@pydantic.validator("sources", always=True) - @skip_if_fields_missing(["medium", "center", "size", "structures"]) - def _source_homogeneous_isotropic(cls, val, values): + @model_validator(mode="after") + def _source_homogeneous_isotropic(self): """Error if a plane wave or gaussian beam source is not in a homogeneous and isotropic region. """ + val = self.sources if val is None: - return val + return self # list of structures including background as a Box() structure_bg = Structure( geometry=Box( - size=values.get("size"), - center=values.get("center"), + size=self.size, + center=self.center, ), - medium=values.get("medium"), + medium=self.medium, ) - structures = values.get("structures") or [] + structures = self.structures or [] total_structures = [structure_bg] + list(structures) # for each plane wave in the sources list @@ -3594,18 +3637,18 @@ def _source_homogeneous_isotropic(cls, val, values): "A fixed angle plane wave can only be injected into a homogeneous isotropic" "dispersionless medium." ) - return val + return self - @pydantic.validator("normalize_index", always=True) - @skip_if_fields_missing(["sources"]) - def _check_normalize_index(cls, val, values): + @model_validator(mode="after") + def _check_normalize_index(self): """Check validity of normalize index in context of simulation.sources.""" + val = self.normalize_index # not normalizing if val is None: - return val + return self - sources = values.get("sources") + sources = self.sources num_sources = len(sources) if num_sources > 0: # No check if no sources, but it should be irrelevant anyway @@ -3633,21 +3676,24 @@ def _check_normalize_index(cls, val, values): "source is only meaningful if field decay occurs." 
) - return val + return self """ Post-init validators """ + @property def _post_init_validators(self) -> None: - """Call validators taking z`self` that get run after init.""" - _ = self.scene - self._validate_no_structures_pml() - self._validate_tfsf_nonuniform_grid() - self._validate_tfsf_aux_sources() - self._validate_nonlinear_specs() - self._validate_custom_source_time() - self._validate_mode_object_bends() - self._warn_mode_object_pml() - self._warn_rf_license() + """Return validators taking z`self` that get run after init.""" + return ( + lambda: self.scene, + self._validate_no_structures_pml, + self._validate_tfsf_nonuniform_grid, + self._validate_tfsf_aux_sources, + self._validate_nonlinear_specs, + self._validate_custom_source_time, + self._validate_mode_object_bends, + self._warn_mode_object_pml, + self._warn_rf_license, + ) def _warn_rf_license(self): """ @@ -3683,6 +3729,8 @@ def _warn_rf_license(self): msg += rf_component_breakdown_msg log.warning(msg, log_once=True) + return self + def _warn_mode_object_pml(self) -> None: """Warn if any mode objects have large pml.""" from .mode.mode_solver import ModeSolver @@ -3748,7 +3796,7 @@ def _validate_no_structures_pml(self) -> None: bound_spec = self.boundary_spec.to_list with log as consolidated_logger: - for i, structure in enumerate(self.structures): + for i, structure in enumerate(self.static_structures): geo_bounds = structure.geometry.bounds warn = False # will only warn once per structure for sim_bound, geo_bound, pml_thick, bound_dim, pm_val in zip( @@ -3890,7 +3938,7 @@ def _validate_nonlinear_specs(self) -> None: ) @cached_property - def aux_fields(self) -> List[str]: + def aux_fields(self) -> list[str]: """All aux fields available in the simulation.""" fields = [] for medium in self.scene.mediums: @@ -3995,7 +4043,7 @@ def _validate_monitor_size(self) -> None: def _validate_modes_size(self) -> None: """Warn if mode sources or monitors have a large number of points.""" - def warn_mode_size(monitor: 
AbstractModeMonitor, msg_header: str, custom_loc: List): + def warn_mode_size(monitor: AbstractModeMonitor, msg_header: str, custom_loc: list): """Warn if a mode component has a large number of points.""" num_cells = np.prod(self.discretize_monitor(monitor).num_cells) if num_cells > WARN_MODE_NUM_CELLS: @@ -4034,7 +4082,7 @@ def _validate_num_cells_in_mode_objects(self) -> None: of grid cells in their transverse dimensions.""" def check_num_cells( - mode_object: Tuple[ModeSource, ModeMonitor], normal_axis: Axis, msg_header: str + mode_object: tuple[ModeSource, ModeMonitor], normal_axis: Axis, msg_header: str ): disc_grid = self.discretize(mode_object) _, check_axes = Box.pop_axis([0, 1, 2], axis=normal_axis) @@ -4106,7 +4154,7 @@ def _validate_freq_monitors_freq_range(self) -> None: ) @cached_property - def monitors_data_size(self) -> Dict[str, float]: + def monitors_data_size(self) -> dict[str, float]: """Dictionary mapping monitor names to their estimated storage size in bytes.""" data_size = {} for monitor in self.monitors: @@ -4309,12 +4357,12 @@ def _run_time(self) -> float: # candidate for removal in 3.0 @cached_property - def mediums(self) -> Set[MediumType]: + def mediums(self) -> set[MediumType]: """Returns set of distinct :class:`.AbstractMedium` in simulation. Returns ------- - List[:class:`.AbstractMedium`] + list[:class:`.AbstractMedium`] Set of distinct mediums in the simulation. """ log.warning( @@ -4325,14 +4373,14 @@ def mediums(self) -> Set[MediumType]: # candidate for removal in 3.0 @cached_property - def medium_map(self) -> Dict[MediumType, pydantic.NonNegativeInt]: + def medium_map(self) -> dict[MediumType, NonNegativeInt]: """Returns dict mapping medium to index in material. ``medium_map[medium]`` returns unique global index of :class:`.AbstractMedium` in simulation. Returns ------- - Dict[:class:`.AbstractMedium`, int] + dict[:class:`.AbstractMedium`, int] Mapping between distinct mediums to index in simulation. 
""" @@ -4354,7 +4402,7 @@ def background_structure(self) -> Structure: return self.scene.background_structure @cached_property - def _fixed_angle_sources(self) -> Tuple[SourceType, ...]: + def _fixed_angle_sources(self) -> tuple[SourceType, ...]: """List of plane wave sources with ``FixedAngleSpec``.""" return self._get_fixed_angle_sources(self.sources) @@ -4366,8 +4414,8 @@ def _is_fixed_angle(self) -> bool: # candidate for removal in 3.0 @staticmethod def intersecting_media( - test_object: Box, structures: Tuple[Structure, ...] - ) -> Tuple[MediumType, ...]: + test_object: Box, structures: tuple[Structure, ...] + ) -> tuple[MediumType, ...]: """From a given list of structures, returns a list of :class:`.AbstractMedium` associated with those structures that intersect with the ``test_object``, if it is a surface, or its surfaces, if it is a volume. @@ -4376,12 +4424,12 @@ def intersecting_media( ------- test_object : :class:`.Box` Object for which intersecting media are to be detected. - structures : List[:class:`.AbstractMedium`] + structures : list[:class:`.AbstractMedium`] List of structures whose media will be tested. Returns ------- - List[:class:`.AbstractMedium`] + list[:class:`.AbstractMedium`] Set of distinct mediums that intersect with the given planar object. """ @@ -4394,8 +4442,8 @@ def intersecting_media( # candidate for removal in 3.0 @staticmethod def intersecting_structures( - test_object: Box, structures: Tuple[Structure, ...] - ) -> Tuple[Structure, ...]: + test_object: Box, structures: tuple[Structure, ...] + ) -> tuple[Structure, ...]: """From a given list of structures, returns a list of :class:`.Structure` that intersect with the ``test_object``, if it is a surface, or its surfaces, if it is a volume. @@ -4403,12 +4451,12 @@ def intersecting_structures( ------- test_object : :class:`.Box` Object for which intersecting media are to be detected. 
- structures : List[:class:`.AbstractMedium`] + structures : list[:class:`.AbstractMedium`] List of structures whose media will be tested. Returns ------- - List[:class:`.Structure`] + list[:class:`.Structure`] Set of distinct structures that intersect with the given surface, or with the surfaces of the given volume. """ @@ -4492,12 +4540,10 @@ def to_gdstk( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer_dtype_map: Dict[ - AbstractMedium, Tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt] - ] = None, - ) -> List: + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer_dtype_map: dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] = None, + ) -> list: """Convert a simulation's planar slice to a .gds type polygon list. Parameters @@ -4562,10 +4608,8 @@ def to_gdspy( x: float = None, y: float = None, z: float = None, - gds_layer_dtype_map: Dict[ - AbstractMedium, Tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt] - ] = None, - ) -> List: + gds_layer_dtype_map: dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] = None, + ) -> list: """Convert a simulation's planar slice to a .gds type polygon list. Parameters @@ -4622,11 +4666,9 @@ def to_gds( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer_dtype_map: Dict[ - AbstractMedium, Tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt] - ] = None, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer_dtype_map: dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] = None, ) -> None: """Append the simulation structures to a .gds cell. 
@@ -4687,11 +4729,9 @@ def to_gds_file( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer_dtype_map: Dict[ - AbstractMedium, Tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt] - ] = None, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer_dtype_map: dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] = None, gds_cell_name: str = "MAIN", ) -> None: """Append the simulation structures to a .gds cell. @@ -4764,7 +4804,7 @@ def frequency_range(self) -> FreqBound: Returns ------- - Tuple[float, float] + tuple[float, float] Minimum and maximum frequencies of the power spectrum of the sources. """ source_ranges = [source.source_time.frequency_range() for source in self.sources] @@ -4856,7 +4896,7 @@ def self_structure(self) -> Structure: return self.scene.background_structure @cached_property - def all_structures(self) -> List[Structure]: + def all_structures(self) -> list[Structure]: """List of all structures in the simulation (including the ``Simulation.medium``).""" return self.scene.all_structures @@ -5059,7 +5099,7 @@ def nyquist_step(self) -> int: return nyquist_step @property - def custom_datasets(self) -> List[Dataset]: + def custom_datasets(self) -> list[Dataset]: """List of custom datasets for verification purposes. If the list is not empty, then the simulation needs to be exported to hdf5 to store the data. 
""" diff --git a/tidy3d/components/source/base.py b/tidy3d/components/source/base.py index 09472b1f6a..3b63ad8520 100644 --- a/tidy3d/components/source/base.py +++ b/tidy3d/components/source/base.py @@ -1,11 +1,8 @@ """Defines an abstract base for electromagnetic sources.""" -from __future__ import annotations - from abc import ABC -from typing import Tuple -import pydantic.v1 as pydantic +from pydantic import Field, field_validator from ..base import cached_property from ..base_sim.source import AbstractSource @@ -25,8 +22,7 @@ class Source(Box, AbstractSource, ABC): """Abstract base class for all sources.""" - source_time: SourceTimeType = pydantic.Field( - ..., + source_time: SourceTimeType = Field( title="Source Time", description="Specification of the source time-dependence.", discriminator=TYPE_TAG_STR, @@ -49,20 +45,20 @@ def _injection_axis(self): return None @cached_property - def _dir_vector(self) -> Tuple[float, float, float]: + def _dir_vector(self) -> tuple[float, float, float]: """Returns a vector indicating the source direction for arrow plotting, if not None.""" return None @cached_property - def _pol_vector(self) -> Tuple[float, float, float]: + def _pol_vector(self) -> tuple[float, float, float]: """Returns a vector indicating the source polarization for arrow plotting, if not None.""" return None _warn_traced_center = _warn_unsupported_traced_argument("center") _warn_traced_size = _warn_unsupported_traced_argument("size") - @pydantic.validator("source_time", always=True) - def _freqs_lower_bound(cls, val): + @field_validator("source_time") + def _freqs_lower_bound(val): """Raise validation error if central frequency is too low.""" _assert_min_freq(val.freq0, msg_start="'source_time.freq0'") return val diff --git a/tidy3d/components/source/current.py b/tidy3d/components/source/current.py index 4035b1fbcc..1d2a0a6b19 100644 --- a/tidy3d/components/source/current.py +++ b/tidy3d/components/source/current.py @@ -3,9 +3,9 @@ from __future__ import 
annotations from abc import ABC -from typing import Optional, Tuple +from typing import Optional -import pydantic.v1 as pydantic +from pydantic import Field from typing_extensions import Literal from ...constants import MICROMETER @@ -20,26 +20,25 @@ class CurrentSource(Source, ABC): """Source implements a current distribution directly.""" - polarization: Polarization = pydantic.Field( - ..., + polarization: Polarization = Field( title="Polarization", description="Specifies the direction and type of current component.", ) @cached_property - def _pol_vector(self) -> Tuple[float, float, float]: + def _pol_vector(self) -> tuple[float, float, float]: """Returns a vector indicating the source polarization for arrow plotting, if not None.""" component = self.polarization[-1] # 'x' 'y' or 'z' pol_axis = "xyz".index(component) pol_vec = [0, 0, 0] pol_vec[pol_axis] = 1 - return pol_vec + return tuple(pol_vec) class ReverseInterpolatedSource(Source): """Abstract source that allows reverse-interpolation along zero-sized dimensions.""" - interpolate: bool = pydantic.Field( + interpolate: bool = Field( True, title="Enable Interpolation", description="Handles reverse-interpolation of zero-size dimensions of the source. 
" @@ -48,7 +47,7 @@ class ReverseInterpolatedSource(Source): "placement at the specified location using linear interpolation.", ) - confine_to_bounds: bool = pydantic.Field( + confine_to_bounds: bool = Field( False, title="Confine to Analytical Bounds", description="If ``True``, any source amplitudes which, after discretization, fall beyond " @@ -99,7 +98,7 @@ class PointDipole(CurrentSource, ReverseInterpolatedSource): * `Adjoint optimization of quantum emitter light extraction to an integrated waveguide <../../notebooks/AdjointPlugin12LightExtractor.html>`_ """ - size: Tuple[Literal[0], Literal[0], Literal[0]] = pydantic.Field( + size: tuple[Literal[0], Literal[0], Literal[0]] = Field( (0, 0, 0), title="Size", description="Size in x, y, and z directions, constrained to ``(0, 0, 0)``.", @@ -149,8 +148,7 @@ class CustomCurrentSource(ReverseInterpolatedSource): * `Defining spatially-varying sources <../../notebooks/CustomFieldSource.html>`_ """ - current_dataset: Optional[FieldDataset] = pydantic.Field( - ..., + current_dataset: Optional[FieldDataset] = Field( title="Current Dataset", description=":class:`.FieldDataset` containing the desired frequency-domain " "electric and magnetic current patterns to inject.", diff --git a/tidy3d/components/source/field.py b/tidy3d/components/source/field.py index 60e172d76d..40c490727b 100644 --- a/tidy3d/components/source/field.py +++ b/tidy3d/components/source/field.py @@ -3,25 +3,19 @@ from __future__ import annotations from abc import ABC -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pydantic +from pydantic import Field, NonNegativeInt, PositiveFloat, field_validator, model_validator from ...constants import GLANCING_CUTOFF, MICROMETER, RADIAN, inf from ...exceptions import SetupError from ...log import log -from ..base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from ..base import Tidy3dBaseModel, cached_property from 
..data.dataset import FieldDataset from ..data.validators import validate_can_interpolate, validate_no_nans from ..mode_spec import ModeSpec -from ..types import ( - TYPE_TAG_STR, - Ax, - Axis, - Coordinate, - Direction, -) +from ..types import TYPE_TAG_STR, Ax, Axis, Coordinate, Direction from ..validators import ( assert_plane, assert_single_freq_in_range, @@ -75,28 +69,27 @@ class VolumeSource(Source, ABC): class DirectionalSource(FieldSource, ABC): """A Field source that propagates in a given direction.""" - direction: Direction = pydantic.Field( - ..., + direction: Direction = Field( title="Direction", description="Specifies propagation in the positive or negative direction of the injection " "axis.", ) @cached_property - def _dir_vector(self) -> Tuple[float, float, float]: + def _dir_vector(self) -> tuple[float, float, float]: """Returns a vector indicating the source direction for arrow plotting, if not None.""" if self._injection_axis is None: return None dir_vec = [0, 0, 0] dir_vec[int(self._injection_axis)] = 1 if self.direction == "+" else -1 - return dir_vec + return tuple(dir_vec) class BroadbandSource(Source, ABC): """A source with frequency dependent field distributions.""" # Default as for analytic beam sources; overwrriten for ModeSource below - num_freqs: int = pydantic.Field( + num_freqs: int = Field( 3, title="Number of Frequency Points", description="Number of points used to approximate the frequency dependence of the injected " @@ -121,8 +114,8 @@ def _chebyshev_freq_grid(self, freq_min, freq_max): cheb_points = np.cos(np.pi * np.flip(uni_points)) return freq_avg + freq_diff * cheb_points - @pydantic.validator("num_freqs", always=True, allow_reuse=True) - def _warn_if_large_number_of_freqs(cls, val): + @field_validator("num_freqs") + def _warn_if_large_number_of_freqs(val): """Warn if a large number of frequency points is requested.""" if val is None: @@ -235,8 +228,8 @@ class CustomFieldSource(FieldSource, PlanarSource): * `Defining 
spatially-varying sources <../../notebooks/CustomFieldSource.html>`_ """ - field_dataset: Optional[FieldDataset] = pydantic.Field( - ..., + field_dataset: Optional[FieldDataset] = Field( + None, title="Field Dataset", description=":class:`.FieldDataset` containing the desired frequency-domain " "fields patterns to inject. At least one tangential field component must be specified.", @@ -247,20 +240,19 @@ class CustomFieldSource(FieldSource, PlanarSource): _field_dataset_single_freq = assert_single_freq_in_range("field_dataset") _can_interpolate = validate_can_interpolate("field_dataset") - @pydantic.validator("field_dataset", always=True) - @skip_if_fields_missing(["size"]) - def _tangential_component_defined(cls, val: FieldDataset, values: dict) -> FieldDataset: + @model_validator(mode="after") + def _tangential_component_defined(self) -> FieldDataset: """Assert that at least one tangential field component is provided.""" + val = self.field_dataset if val is None: - return val - size = values.get("size") - normal_axis = size.index(0.0) - _, (cmp1, cmp2) = cls.pop_axis("xyz", axis=normal_axis) + return self + normal_axis = self.size.index(0.0) + _, (cmp1, cmp2) = self.pop_axis("xyz", axis=normal_axis) for field in "EH": for cmp_name in (cmp1, cmp2): tangential_field = field + cmp_name if tangential_field in val.field_components: - return val + return self raise SetupError("No tangential field found in the suppled 'field_dataset'.") @@ -281,14 +273,14 @@ class AngledFieldSource(DirectionalSource, ABC): """ - angle_theta: float = pydantic.Field( + angle_theta: float = Field( 0.0, title="Polar Angle", description="Polar angle of the propagation axis from the injection axis.", units=RADIAN, ) - angle_phi: float = pydantic.Field( + angle_phi: float = Field( 0.0, title="Azimuth Angle", description="Azimuth angle of the propagation axis in the plane orthogonal to the " @@ -296,7 +288,7 @@ class AngledFieldSource(DirectionalSource, ABC): units=RADIAN, ) - pol_angle: float = 
pydantic.Field( + pol_angle: float = Field( 0, title="Polarization Angle", description="Specifies the angle between the electric field polarization of the " @@ -310,8 +302,8 @@ class AngledFieldSource(DirectionalSource, ABC): units=RADIAN, ) - @pydantic.validator("angle_theta", allow_reuse=True, always=True) - def glancing_incidence(cls, val): + @field_validator("angle_theta") + def glancing_incidence(val): """Warn if close to glancing incidence.""" if np.abs(np.pi / 2 - val) < GLANCING_CUTOFF: log.warning( @@ -322,7 +314,7 @@ def glancing_incidence(cls, val): return val @cached_property - def _dir_vector(self) -> Tuple[float, float, float]: + def _dir_vector(self) -> tuple[float, float, float]: """Source direction normal vector in cartesian coordinates.""" # Propagation vector assuming propagation along z @@ -335,7 +327,7 @@ def _dir_vector(self) -> Tuple[float, float, float]: return self.unpop_axis(dz, (dx, dy), axis=self._injection_axis) @cached_property - def _pol_vector(self) -> Tuple[float, float, float]: + def _pol_vector(self) -> tuple[float, float, float]: """Source polarization normal vector in cartesian coordinates.""" # Polarization vector assuming propagation along z @@ -412,13 +404,13 @@ class ModeSource(DirectionalSource, PlanarSource, BroadbandSource): * `Prelude to Integrated Photonics Simulation: Mode Injection `_ """ - mode_spec: ModeSpec = pydantic.Field( - ModeSpec(), + mode_spec: ModeSpec = Field( + default_factory=ModeSpec, title="Mode Specification", description="Parameters to feed to mode solver which determine modes measured by monitor.", ) - mode_index: pydantic.NonNegativeInt = pydantic.Field( + mode_index: NonNegativeInt = Field( 0, title="Mode Index", description="Index into the collection of modes returned by mode solver. 
" @@ -427,7 +419,7 @@ class ModeSource(DirectionalSource, PlanarSource, BroadbandSource): "``num_modes`` in the solver will be set to ``mode_index + 1``.", ) - num_freqs: int = pydantic.Field( + num_freqs: int = Field( 1, title="Number of Frequency Points", description="Number of points used to approximate the frequency dependence of injected " @@ -448,7 +440,7 @@ def angle_phi(self): return self.mode_spec.angle_phi @cached_property - def _dir_vector(self) -> Tuple[float, float, float]: + def _dir_vector(self) -> tuple[float, float, float]: """Source direction normal vector in cartesian coordinates.""" radius = 1.0 if self.direction == "+" else -1.0 dx = radius * np.cos(self.angle_phi) * np.sin(self.angle_theta) @@ -513,8 +505,8 @@ class PlaneWave(AngledFieldSource, PlanarSource, BroadbandSource): * `Using FDTD to Compute a Transmission Spectrum `__ """ - angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = pydantic.Field( - FixedInPlaneKSpec(), + angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = Field( + default_factory=FixedInPlaneKSpec, title="Angular Dependence Specification", description="Specification of plane wave propagation direction dependence on wavelength.", discriminator=TYPE_TAG_STR, @@ -536,7 +528,11 @@ def frequency_grid(self) -> np.ndarray: freq_min = max(freq_min, f_crit * CRITICAL_FREQUENCY_FACTOR) return self._chebyshev_freq_grid(freq_min, freq_max) - def _post_init_validators(self) -> None: + @property + def _post_init_validators(self) -> tuple: + return (self._validate_source_frequency_range,) + + def _validate_source_frequency_range(self): """Error if a broadband plane wave with constant in-plane k is defined such that the source frequency range is entirely below ``f_crit * CRITICAL_FREQUENCY_FACTOR.""" if self._is_fixed_angle or self.num_freqs == 1: @@ -580,14 +576,14 @@ class GaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): * `Inverse taper edge coupler <../../notebooks/EdgeCoupler.html>`_ """ - waist_radius: 
pydantic.PositiveFloat = pydantic.Field( + waist_radius: PositiveFloat = Field( 1.0, title="Waist Radius", description="Radius of the beam at the waist.", units=MICROMETER, ) - waist_distance: float = pydantic.Field( + waist_distance: float = Field( 0.0, title="Waist Distance", description="Distance from the beam waist along the propagation direction. " @@ -628,14 +624,14 @@ class AstigmaticGaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): ... waist_distances = (3.0, 4.0)) """ - waist_sizes: Tuple[pydantic.PositiveFloat, pydantic.PositiveFloat] = pydantic.Field( + waist_sizes: tuple[PositiveFloat, PositiveFloat] = Field( (1.0, 1.0), title="Waist sizes", description="Size of the beam at the waist in the local x and y directions.", units=MICROMETER, ) - waist_distances: Tuple[float, float] = pydantic.Field( + waist_distances: tuple[float, float] = Field( (0.0, 0.0), title="Waist distances", description="Distance to the beam waist along the propagation direction " @@ -681,8 +677,7 @@ class TFSF(AngledFieldSource, VolumeSource, BroadbandSource): * `Nanoparticle Scattering <../../notebooks/PlasmonicNanoparticle.html>`_: To force a uniform grid in the TFSF region and avoid the warnings, a mesh override structure can be used as illustrated here. """ - injection_axis: Axis = pydantic.Field( - ..., + injection_axis: Axis = Field( title="Injection Axis", description="Specifies the injection axis. The plane of incidence is defined via this " "``injection_axis`` and the ``direction``. 
The popagation axis is defined with respect " diff --git a/tidy3d/components/source/time.py b/tidy3d/components/source/time.py index d66ac9afec..f5d24d5efc 100644 --- a/tidy3d/components/source/time.py +++ b/tidy3d/components/source/time.py @@ -6,7 +6,7 @@ from typing import Optional, Union import numpy as np -import pydantic.v1 as pydantic +from pydantic import Field, PositiveFloat, field_validator from ...constants import HERTZ from ...exceptions import ValidationError @@ -14,13 +14,7 @@ from ..data.dataset import TimeDataset from ..data.validators import validate_no_nans from ..time import AbstractTimeDependence -from ..types import ( - ArrayComplex1D, - ArrayFloat1D, - Ax, - FreqBound, - PlotVal, -) +from ..types import ArrayComplex1D, ArrayFloat1D, Ax, FreqBound, PlotVal from ..validators import warn_if_dataset_none from ..viz import add_ax_if_none @@ -77,17 +71,18 @@ def end_time(self) -> Optional[float]: class Pulse(SourceTime, ABC): """A source time that ramps up with some ``fwidth`` and oscillates at ``freq0``.""" - freq0: pydantic.PositiveFloat = pydantic.Field( - ..., title="Central Frequency", description="Central frequency of the pulse.", units=HERTZ + freq0: PositiveFloat = Field( + title="Central Frequency", + description="Central frequency of the pulse.", + units=HERTZ, ) - fwidth: pydantic.PositiveFloat = pydantic.Field( - ..., + fwidth: PositiveFloat = Field( title="", description="Standard deviation of the frequency content of the pulse.", units=HERTZ, ) - offset: float = pydantic.Field( + offset: float = Field( 5.0, title="Offset", description="Time delay of the maximum value of the " @@ -129,7 +124,7 @@ class GaussianPulse(Pulse): >>> pulse = GaussianPulse(freq0=200e12, fwidth=20e12) """ - remove_dc_component: bool = pydantic.Field( + remove_dc_component: bool = Field( True, title="Remove DC Component", description="Whether to remove the DC component in the Gaussian pulse spectrum. 
" @@ -193,7 +188,7 @@ def from_amp_complex(cls, amp: complex, **kwargs) -> GaussianPulse: @classmethod def from_frequency_range( - cls, fmin: pydantic.PositiveFloat, fmax: pydantic.PositiveFloat, **kwargs + cls, fmin: PositiveFloat, fmax: PositiveFloat, **kwargs ) -> GaussianPulse: """Create a ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax]. @@ -301,14 +296,14 @@ class CustomSourceTime(Pulse): """ - offset: float = pydantic.Field( + offset: float = Field( 0.0, title="Offset", description="Time delay of the envelope in units of 1 / (``2pi * fwidth``).", ) - source_time_dataset: Optional[TimeDataset] = pydantic.Field( - ..., + source_time_dataset: Optional[TimeDataset] = Field( + None, title="Source time dataset", description="Dataset for storing the envelope of the custom source time. " "This envelope will be modulated by a complex exponential at frequency ``freq0``.", @@ -317,8 +312,8 @@ class CustomSourceTime(Pulse): _no_nans_dataset = validate_no_nans("source_time_dataset") _source_time_dataset_none_warning = warn_if_dataset_none("source_time_dataset") - @pydantic.validator("source_time_dataset", always=True) - def _more_than_one_time(cls, val): + @field_validator("source_time_dataset") + def _more_than_one_time(val): """Must have more than one time to interpolate.""" if val is None: return val diff --git a/tidy3d/components/source/utils.py b/tidy3d/components/source/utils.py index 4996138c13..cf0fdcecf6 100644 --- a/tidy3d/components/source/utils.py +++ b/tidy3d/components/source/utils.py @@ -1,7 +1,5 @@ """Defines electric current sources for injecting light into simulation.""" -from __future__ import annotations - from typing import Union from .current import CustomCurrentSource, PointDipole, UniformCurrentSource diff --git a/tidy3d/components/spice/analysis/dc.py b/tidy3d/components/spice/analysis/dc.py index 7f6c275602..728621df9c 100644 --- a/tidy3d/components/spice/analysis/dc.py +++ 
b/tidy3d/components/spice/analysis/dc.py @@ -2,7 +2,7 @@ This class defines standard SPICE electrical_analysis types (electrical simulations configurations). """ -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, PositiveInt from tidy3d.components.base import Tidy3dBaseModel @@ -19,19 +19,19 @@ class ChargeToleranceSpec(Tidy3dBaseModel): >>> charge_settings = td.ChargeToleranceSpec(abs_tol=1e8, rel_tol=1e-10, max_iters=30) """ - abs_tol: pd.PositiveFloat = pd.Field( + abs_tol: PositiveFloat = Field( default=1e10, title="Absolute tolerance.", description="Absolute tolerance used as stop criteria when converging towards a solution.", ) - rel_tol: pd.PositiveFloat = pd.Field( + rel_tol: PositiveFloat = Field( default=1e-10, title="Relative tolerance.", description="Relative tolerance used as stop criteria when converging towards a solution.", ) - max_iters: pd.PositiveInt = pd.Field( + max_iters: PositiveInt = Field( default=30, title="Maximum number of iterations.", description="Indicates the maximum number of iterations to be run. " @@ -39,7 +39,7 @@ class ChargeToleranceSpec(Tidy3dBaseModel): "or when the tolerance criteria has been met.", ) - ramp_up_iters: pd.PositiveInt = pd.Field( + ramp_up_iters: PositiveInt = Field( default=1, title="Ramp-up iterations.", description="In order to help in start up, quantities such as doping " @@ -53,7 +53,7 @@ class IsothermalSteadyChargeDCAnalysis(Tidy3dBaseModel): Configures relevant steady-state DC simulation parameters for a charge simulation. """ - temperature: pd.PositiveFloat = pd.Field( + temperature: PositiveFloat = Field( 300, title="Temperature", description="Lattice temperature. Assumed constant throughout the device. 
" @@ -61,11 +61,11 @@ class IsothermalSteadyChargeDCAnalysis(Tidy3dBaseModel): units=KELVIN, ) - tolerance_settings: ChargeToleranceSpec = pd.Field( + tolerance_settings: ChargeToleranceSpec = Field( default=ChargeToleranceSpec(), title="Tolerance settings" ) - convergence_dv: pd.PositiveFloat = pd.Field( + convergence_dv: PositiveFloat = Field( default=1.0, title="Bias step.", description="By default, a solution is computed at 0 bias. If a bias different than " @@ -74,7 +74,7 @@ class IsothermalSteadyChargeDCAnalysis(Tidy3dBaseModel): "convergence parameter in DC computations.", ) - fermi_dirac: bool = pd.Field( + fermi_dirac: bool = Field( False, title="Fermi-Dirac statistics", description="Determines whether Fermi-Dirac statistics are used. When False, " diff --git a/tidy3d/components/spice/sources/dc.py b/tidy3d/components/spice/sources/dc.py index 7c560ff165..6814ab7d55 100644 --- a/tidy3d/components/spice/sources/dc.py +++ b/tidy3d/components/spice/sources/dc.py @@ -21,7 +21,7 @@ from typing import Optional, Union -import pydantic.v1 as pd +from pydantic import Field, FiniteFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import AMP, VOLT @@ -44,8 +44,8 @@ class DCVoltageSource(Tidy3dBaseModel): >>> voltage_source = td.DCVoltageSource(voltage=voltages) """ - name: Optional[str] - voltage: Union[pd.FiniteFloat, list[pd.FiniteFloat]] = pd.Field( + name: Optional[str] = None + voltage: Union[FiniteFloat, list[FiniteFloat]] = Field( title="Voltage", description="DC voltage usually used as source in 'VoltageBC' boundary conditions.", ) @@ -62,8 +62,8 @@ class DCCurrentSource(Tidy3dBaseModel): >>> current_source = td.DCCurrentSource(current=0.4) """ - name: Optional[str] - current: pd.FiniteFloat = pd.Field( + name: Optional[str] = None + current: FiniteFloat = Field( title="Current", description="DC current usually used as source in 'CurrentBC' boundary conditions.", ) diff --git a/tidy3d/components/structure.py 
b/tidy3d/components/structure.py index 9cf8724dd8..e98294cde7 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -4,20 +4,27 @@ import pathlib from collections import defaultdict -from typing import Optional, Tuple, Union +from typing import Optional, Union import autograd.numpy as anp import numpy as np -import pydantic.v1 as pydantic +from autograd.extend import Box as AutogradBox +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + field_validator, + model_validator, +) from ..constants import MICROMETER from ..exceptions import SetupError, Tidy3dError, Tidy3dImportError from ..log import log from .autograd.derivative_utils import DerivativeInfo from .autograd.types import AutogradFieldMap -from .autograd.types import Box as AutogradBox from .autograd.utils import get_static -from .base import Tidy3dBaseModel, skip_if_fields_missing +from .base import Tidy3dBaseModel from .data.data_array import ScalarFieldDataArray from .geometry.base import Box, Geometry from .geometry.polyslab import PolySlab @@ -48,16 +55,15 @@ class AbstractStructure(Tidy3dBaseModel): A basic structure object. 
""" - geometry: GeometryType = pydantic.Field( - ..., + geometry: GeometryType = Field( title="Geometry", description="Defines geometric properties of the structure.", discriminator=TYPE_TAG_STR, ) - name: str = pydantic.Field(None, title="Name", description="Optional name for the structure.") + name: Optional[str] = Field(None, title="Name", description="Optional name for the structure.") - background_permittivity: float = pydantic.Field( + background_permittivity: Optional[float] = Field( None, ge=1.0, title="Background Permittivity", @@ -66,7 +72,7 @@ class AbstractStructure(Tidy3dBaseModel): "when performing shape optimization with autograd.", ) - background_medium: StructureMediumType = pydantic.Field( + background_medium: Optional[StructureMediumType] = Field( None, title="Background Medium", description="Medium used for the background of this structure " @@ -75,12 +81,12 @@ class AbstractStructure(Tidy3dBaseModel): "``Simulation`` by default to compute the shape derivatives.", ) - @pydantic.root_validator(skip_on_failure=True) - def _handle_background_mediums(cls, values): + @model_validator(mode="after") + def _handle_background_mediums(self): """Handle background medium combinations, including deprecation.""" - background_permittivity = values.get("background_permittivity") - background_medium = values.get("background_medium") + background_permittivity = self.background_permittivity + background_medium = self.background_medium # old case, only permittivity supplied, warn and set the Medium automatically if background_medium is None and background_permittivity is not None: @@ -89,7 +95,9 @@ def _handle_background_mediums(cls, values): "set the 'Structure.background_medium' directly using a 'Medium'. " "Handling automatically using the supplied relative permittivity." 
) - values["background_medium"] = Medium(permittivity=background_permittivity) + object.__setattr__( + self, "background_medium", Medium(permittivity=background_permittivity) + ) # both present, just make sure they are consistent, error if not if background_medium is not None and background_permittivity is not None: @@ -100,12 +108,12 @@ def _handle_background_mediums(cls, values): "Use 'background_medium' only as 'background_permittivity' is deprecated." ) - return values + return self _name_validator = validate_name_str() - @pydantic.validator("geometry") - def _transformed_slanted_polyslabs_not_allowed(cls, val): + @field_validator("geometry") + def _transformed_slanted_polyslabs_not_allowed(val): """Prevents the creation of slanted polyslabs rotated out of plane.""" validate_no_transformed_polyslabs(val) return val @@ -182,8 +190,7 @@ class Structure(AbstractStructure): * `Structures `_ """ - medium: StructureMediumType = pydantic.Field( - ..., + medium: StructureMediumType = Field( title="Medium", description="Defines the electromagnetic properties of the structure's medium.", discriminator=TYPE_TAG_STR, @@ -193,7 +200,7 @@ class Structure(AbstractStructure): def viz_spec(self): return self.medium.viz_spec - def eps_diagonal(self, frequency: float, coords: Coords) -> Tuple[complex, complex, complex]: + def eps_diagonal(self, frequency: float, coords: Coords) -> tuple[complex, complex, complex]: """Main diagonal of the complex-valued permittivity tensor as a function of frequency. 
Parameters @@ -210,11 +217,11 @@ def eps_diagonal(self, frequency: float, coords: Coords) -> Tuple[complex, compl return self.medium.eps_diagonal_on_grid(frequency=frequency, coords=coords) return self.medium.eps_diagonal(frequency=frequency) - @pydantic.validator("medium", always=True) - @skip_if_fields_missing(["geometry"]) - def _check_2d_geometry(cls, val, values): + @model_validator(mode="after") + def _check_2d_geometry(self): """Medium2D is only consistent with certain geometry types""" - geom = values.get("geometry") + val = self.medium + geom = self.geometry if isinstance(val, Medium2D): # the geometry needs to be supported by 2d materials @@ -228,7 +235,7 @@ def _check_2d_geometry(cls, val, values): # if the geometry is not supported / not 2d _ = geom._normal_2dmaterial - return val + return self def _compatible_with(self, other: Structure) -> bool: """Whether these two structures are compatible.""" @@ -358,10 +365,10 @@ def to_gdstk( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, ) -> None: """Convert a structure's planar slice to a .gds type polygon. @@ -426,8 +433,8 @@ def to_gdspy( x: float = None, y: float = None, z: float = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, ) -> None: """Convert a structure's planar slice to a .gds type polygon. 
@@ -464,10 +471,10 @@ def to_gds( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, ) -> None: """Append a structure's planar slice to a .gds cell. @@ -528,10 +535,10 @@ def to_gds_file( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, gds_cell_name: str = "MAIN", ) -> None: """Export a structure's planar slice to a .gds file. @@ -659,18 +666,17 @@ class MeshOverrideStructure(AbstractStructure): >>> struct_override = MeshOverrideStructure(geometry=box, dl=(0.1,0.2,0.3), name='override_box') """ - dl: Tuple[ - Optional[pydantic.PositiveFloat], - Optional[pydantic.PositiveFloat], - Optional[pydantic.PositiveFloat], - ] = pydantic.Field( - ..., + dl: tuple[ + Optional[PositiveFloat], + Optional[PositiveFloat], + Optional[PositiveFloat], + ] = Field( title="Grid Size", description="Grid size along x, y, z directions.", units=MICROMETER, ) - enforce: bool = pydantic.Field( + enforce: bool = Field( False, title="Enforce Grid Size", description="If ``True``, enforce the grid size setup inside the structure " @@ -679,7 +685,7 @@ class MeshOverrideStructure(AbstractStructure): "the last added structure of ``enforce=True``.", ) - shadow: bool = pydantic.Field( + shadow: bool = Field( True, title="Grid Size Choice In Structure Overlapping Region", description="In structure intersection region, grid size is decided by the latter 
added " @@ -688,7 +694,7 @@ class MeshOverrideStructure(AbstractStructure): "the bounding box of the structure is disabled.", ) - drop_outside_sim: bool = pydantic.Field( + drop_outside_sim: bool = Field( True, title="Drop Structure Outside Simulation Domain", description="If ``True``, structure outside the simulation domain is dropped; if ``False``, " @@ -696,8 +702,8 @@ class MeshOverrideStructure(AbstractStructure): "and that of the simulation domain overlap.", ) - @pydantic.validator("geometry") - def _box_only(cls, val): + @field_validator("geometry") + def _box_only(val): """Ensure this is a box.""" if isinstance(val, Geometry): if not isinstance(val, Box): @@ -708,12 +714,12 @@ def _box_only(cls, val): return val.bounding_box return val - @pydantic.validator("shadow") - def _unshadowed_cannot_be_enforced(cls, val, values): + @model_validator(mode="after") + def _unshadowed_cannot_be_enforced(self): """Unshadowed structure cannot be enforced.""" - if not val and values["enforce"]: + if not self.shadow and self.enforce: raise SetupError("A structure cannot be simultaneously enforced and unshadowed.") - return val + return self StructureType = Union[Structure, MeshOverrideStructure] diff --git a/tidy3d/components/subpixel_spec.py b/tidy3d/components/subpixel_spec.py index 1781b04876..7b6aadae44 100644 --- a/tidy3d/components/subpixel_spec.py +++ b/tidy3d/components/subpixel_spec.py @@ -1,12 +1,12 @@ # Defines specifications for subpixel averaging -from __future__ import annotations from typing import Union -import pydantic.v1 as pd +from pydantic import Field +from ..compat import Self from .base import Tidy3dBaseModel, cached_property -from .types import TYPE_TAG_STR +from .types import discriminated_union # Default Courant number reduction rate in PEC conformal's scheme DEFAULT_COURANT_REDUCTION_PEC_CONFORMAL = 0.3 @@ -65,7 +65,9 @@ class ContourPathAveraging(AbstractSubpixelAveragingMethod): """ -DielectricSubpixelType = Union[Staircasing, 
PolarizedAveraging, ContourPathAveraging] +DielectricSubpixelType = discriminated_union( + Union[Staircasing, PolarizedAveraging, ContourPathAveraging] +) class VolumetricAveraging(AbstractSubpixelAveragingMethod): @@ -73,7 +75,7 @@ class VolumetricAveraging(AbstractSubpixelAveragingMethod): The material property is averaged in the volume surrounding the Yee grid. """ - staircase_normal_component: bool = pd.Field( + staircase_normal_component: bool = Field( True, title="Staircasing For Field Components Substantially Normal To Interface", description="Volumetric averaging works accurately if the electric field component " @@ -83,7 +85,7 @@ class VolumetricAveraging(AbstractSubpixelAveragingMethod): ) -MetalSubpixelType = Union[Staircasing, VolumetricAveraging] +MetalSubpixelType = discriminated_union(Union[Staircasing, VolumetricAveraging]) class HeuristicPECStaircasing(AbstractSubpixelAveragingMethod): @@ -110,7 +112,7 @@ class PECConformal(AbstractSubpixelAveragingMethod): IEEE Transactions on Antennas and Propagation, 54(6), 1843 (2006). """ - timestep_reduction: float = pd.Field( + timestep_reduction: float = Field( DEFAULT_COURANT_REDUCTION_PEC_CONFORMAL, title="Time Step Size Reduction Rate", description="Reduction factor between 0 and 1 such that the simulation's time step size " @@ -128,7 +130,7 @@ def courant_ratio(self) -> float: return 1 - self.timestep_reduction -PECSubpixelType = Union[Staircasing, HeuristicPECStaircasing, PECConformal] +PECSubpixelType = discriminated_union(Union[Staircasing, HeuristicPECStaircasing, PECConformal]) class SurfaceImpedance(PECConformal): @@ -136,7 +138,7 @@ class SurfaceImpedance(PECConformal): structure made of :class:`.LossyMetalMedium`. 
""" - timestep_reduction: float = pd.Field( + timestep_reduction: float = Field( DEFAULT_COURANT_REDUCTION_SIBC_CONFORMAL, title="Time Step Size Reduction Rate", description="Reduction factor between 0 and 1 such that the simulation's time step size " @@ -147,44 +149,42 @@ class SurfaceImpedance(PECConformal): ) -LossyMetalSubpixelType = Union[Staircasing, VolumetricAveraging, SurfaceImpedance] +LossyMetalSubpixelType = discriminated_union( + Union[Staircasing, VolumetricAveraging, SurfaceImpedance] +) class SubpixelSpec(Tidy3dBaseModel): """Defines specification for subpixel averaging schemes when added to ``Simulation.subpixel``.""" - dielectric: DielectricSubpixelType = pd.Field( - PolarizedAveraging(), + dielectric: DielectricSubpixelType = Field( + default_factory=PolarizedAveraging, title="Subpixel Averaging Method For Dielectric Interfaces", description="Subpixel averaging method applied to dielectric material interfaces.", - discriminator=TYPE_TAG_STR, ) - metal: MetalSubpixelType = pd.Field( - Staircasing(), + metal: MetalSubpixelType = Field( + default_factory=Staircasing, title="Subpixel Averaging Method For Metallic Interfaces", description="Subpixel averaging method applied to metallic structure interfaces. 
" "A material is considered as metallic if its real part of relative permittivity " "is less than 1 at the central frequency.", - discriminator=TYPE_TAG_STR, ) - pec: PECSubpixelType = pd.Field( - PECConformal(), + pec: PECSubpixelType = Field( + default_factory=PECConformal, title="Subpixel Averaging Method For PEC Interfaces", description="Subpixel averaging method applied to PEC structure interfaces.", - discriminator=TYPE_TAG_STR, ) - lossy_metal: LossyMetalSubpixelType = pd.Field( - SurfaceImpedance(), + lossy_metal: LossyMetalSubpixelType = Field( + default_factory=SurfaceImpedance, title="Subpixel Averaging Method for Lossy Metal Interfaces", description="Subpixel averaging method applied to ``td.LossyMetalMedium`` material interfaces.", - discriminator=TYPE_TAG_STR, ) @classmethod - def staircasing(cls) -> SubpixelSpec: + def staircasing(cls) -> Self: """Apply staircasing on all material boundaries.""" return cls( dielectric=Staircasing(), diff --git a/tidy3d/components/tcad/bandgap.py b/tidy3d/components/tcad/bandgap.py index 61b469252c..721e607211 100644 --- a/tidy3d/components/tcad/bandgap.py +++ b/tidy3d/components/tcad/bandgap.py @@ -1,4 +1,4 @@ -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import PERCMCUBE, VOLT @@ -35,27 +35,24 @@ class SlotboomBandGapNarrowing(Tidy3dBaseModel): .. [1] 'UNIFIED APPARENT BANDGAP NARROWING IN n- AND p-TYPE SILICON' Solid-State Electronics Vol. 35, No. 2, pp. 
125-129, 1992""" - v1: pd.PositiveFloat = pd.Field( - ..., + v1: PositiveFloat = Field( title=r"$V_{1,bgn}$ parameter", description=r"$V_{1,bgn}$ parameter", units=VOLT, ) - n2: pd.PositiveFloat = pd.Field( - ..., + n2: PositiveFloat = Field( title=r"$N_{2,bgn}$ parameter", description=r"$N_{2,bgn}$ parameter", units=PERCMCUBE, ) - c2: float = pd.Field( + c2: float = Field( title=r"$C_{2,bgn}$ parameter", description=r"$C_{2,bgn}$ parameter", ) - min_N: pd.NonNegativeFloat = pd.Field( - ..., + min_N: NonNegativeFloat = Field( title="Minimum total doping", description="Bandgap narrowing is applied at location where total doping " "is higher than 'min_N'.", diff --git a/tidy3d/components/tcad/boundary/abstract.py b/tidy3d/components/tcad/boundary/abstract.py index 7e46014e91..36ecf368e3 100644 --- a/tidy3d/components/tcad/boundary/abstract.py +++ b/tidy3d/components/tcad/boundary/abstract.py @@ -1,7 +1,5 @@ """Defines heat material specifications""" -from __future__ import annotations - from abc import ABC from tidy3d.components.base import Tidy3dBaseModel diff --git a/tidy3d/components/tcad/boundary/charge.py b/tidy3d/components/tcad/boundary/charge.py index b31dac4917..b6061ea5ea 100644 --- a/tidy3d/components/tcad/boundary/charge.py +++ b/tidy3d/components/tcad/boundary/charge.py @@ -1,8 +1,6 @@ """Defines heat material specifications""" -from __future__ import annotations - -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.spice.sources.types import CurrentSourceType, VoltageSourceType from tidy3d.components.tcad.boundary.abstract import HeatChargeBC @@ -28,8 +26,7 @@ class VoltageBC(HeatChargeBC): >>> voltage_bc = td.VoltageBC(source=voltage_source) """ - source: VoltageSourceType = pd.Field( - ..., + source: VoltageSourceType = Field( title="Voltage", description="Electric potential to be applied at the specified boundary.", units=VOLT, @@ -47,8 +44,7 @@ class CurrentBC(HeatChargeBC): >>> current_bc = CurrentBC(source=current_source) 
""" - source: CurrentSourceType = pd.Field( - ..., + source: CurrentSourceType = Field( title="Current Source", description="A current source", units=CURRENT_DENSITY, diff --git a/tidy3d/components/tcad/boundary/heat.py b/tidy3d/components/tcad/boundary/heat.py index 9671dee6a9..77087fcbfb 100644 --- a/tidy3d/components/tcad/boundary/heat.py +++ b/tidy3d/components/tcad/boundary/heat.py @@ -1,8 +1,6 @@ """Defines heat material specifications""" -from __future__ import annotations - -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.tcad.boundary.abstract import HeatChargeBC from tidy3d.constants import HEAT_FLUX, HEAT_TRANSFER_COEFF, KELVIN @@ -17,7 +15,7 @@ class TemperatureBC(HeatChargeBC): >>> bc = td.TemperatureBC(temperature=300) """ - temperature: pd.PositiveFloat = pd.Field( + temperature: PositiveFloat = Field( title="Temperature", description=f"Temperature value in units of {KELVIN}.", units=KELVIN, @@ -33,7 +31,7 @@ class HeatFluxBC(HeatChargeBC): >>> bc = td.HeatFluxBC(flux=1) """ - flux: float = pd.Field( + flux: float = Field( title="Heat Flux", description=f"Heat flux value in units of {HEAT_FLUX}.", units=HEAT_FLUX, @@ -49,13 +47,13 @@ class ConvectionBC(HeatChargeBC): >>> bc = td.ConvectionBC(ambient_temperature=300, transfer_coeff=1) """ - ambient_temperature: pd.PositiveFloat = pd.Field( + ambient_temperature: PositiveFloat = Field( title="Ambient Temperature", description=f"Ambient temperature value in units of {KELVIN}.", units=KELVIN, ) - transfer_coeff: pd.NonNegativeFloat = pd.Field( + transfer_coeff: NonNegativeFloat = Field( title="Heat Transfer Coefficient", description=f"Heat flux value in units of {HEAT_TRANSFER_COEFF}.", units=HEAT_TRANSFER_COEFF, diff --git a/tidy3d/components/tcad/boundary/specification.py b/tidy3d/components/tcad/boundary/specification.py index 08289fa19f..837ec66c51 100644 --- a/tidy3d/components/tcad/boundary/specification.py +++ 
b/tidy3d/components/tcad/boundary/specification.py @@ -1,8 +1,6 @@ """Defines heat material specifications""" -from __future__ import annotations - -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.bc_placement import BCPlacementType @@ -22,13 +20,13 @@ class HeatChargeBoundarySpec(Tidy3dBaseModel): ... ) """ - placement: BCPlacementType = pd.Field( + placement: BCPlacementType = Field( title="Boundary Conditions Placement", description="Location to apply boundary conditions.", discriminator=TYPE_TAG_STR, ) - condition: HeatChargeBCType = pd.Field( + condition: HeatChargeBCType = Field( title="Boundary Conditions", description="Boundary conditions to apply at the selected location.", discriminator=TYPE_TAG_STR, diff --git a/tidy3d/components/tcad/data/monitor_data/abstract.py b/tidy3d/components/tcad/data/monitor_data/abstract.py index 6deab4014f..308bb85006 100644 --- a/tidy3d/components/tcad/data/monitor_data/abstract.py +++ b/tidy3d/components/tcad/data/monitor_data/abstract.py @@ -4,23 +4,19 @@ import copy from abc import ABC, abstractmethod -from typing import Tuple, Union +from typing import Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData -from tidy3d.components.data.data_array import ( - SpatialDataArray, -) +from tidy3d.components.data.data_array import SpatialDataArray from tidy3d.components.data.utils import TetrahedralGridDataset, TriangularGridDataset -from tidy3d.components.tcad.types import ( - HeatChargeMonitorType, -) -from tidy3d.components.types import Coordinate, ScalarSymmetry, annotate_type +from tidy3d.components.tcad.types import HeatChargeMonitorType +from tidy3d.components.types import Coordinate, ScalarSymmetry, discriminated_union FieldDataset = Union[ - SpatialDataArray, annotate_type(Union[TriangularGridDataset, TetrahedralGridDataset]) + 
SpatialDataArray, discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]) ] UnstructuredFieldType = Union[TriangularGridDataset, TetrahedralGridDataset] @@ -28,19 +24,18 @@ class HeatChargeMonitorData(AbstractMonitorData, ABC): """Abstract base class of objects that store data pertaining to a single :class:`HeatChargeMonitor`.""" - monitor: HeatChargeMonitorType = pd.Field( - ..., + monitor: HeatChargeMonitorType = Field( title="Monitor", description="Monitor associated with the data.", ) - symmetry: Tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = pd.Field( + symmetry: tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = Field( (0, 0, 0), title="Symmetry", description="Symmetry of the original simulation in x, y, and z.", ) - symmetry_center: Coordinate = pd.Field( + symmetry_center: Coordinate = Field( (0, 0, 0), title="Symmetry Center", description="Symmetry center of the original simulation in x, y, and z.", diff --git a/tidy3d/components/tcad/data/monitor_data/charge.py b/tidy3d/components/tcad/data/monitor_data/charge.py index d3a3f9738c..95afd6f5c4 100644 --- a/tidy3d/components/tcad/data/monitor_data/charge.py +++ b/tidy3d/components/tcad/data/monitor_data/charge.py @@ -2,12 +2,11 @@ from __future__ import annotations -from typing import Dict, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator -from tidy3d.components.base import skip_if_fields_missing from tidy3d.components.data.data_array import ( DataArray, IndexedVoltageDataArray, @@ -22,52 +21,47 @@ SteadyFreeCarrierMonitor, SteadyPotentialMonitor, ) -from tidy3d.components.types import TYPE_TAG_STR, Ax, annotate_type +from tidy3d.components.types import Ax, discriminated_union from tidy3d.components.viz import add_ax_if_none from tidy3d.exceptions import DataError from tidy3d.log import log FieldDataset = Union[ - SpatialDataArray, annotate_type(Union[TriangularGridDataset, 
TetrahedralGridDataset]) + SpatialDataArray, discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]) ] -UnstructuredFieldType = Union[TriangularGridDataset, TetrahedralGridDataset] +UnstructuredFieldType = discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]) class SteadyPotentialData(HeatChargeMonitorData): """Stores electric potential :math:`\\psi` from a charge simulation.""" - monitor: SteadyPotentialMonitor = pd.Field( - ..., + monitor: SteadyPotentialMonitor = Field( title="Electric potential monitor", description="Electric potential monitor associated with a `charge` simulation.", ) - potential: FieldDataset = pd.Field( + potential: Optional[FieldDataset] = Field( None, title="Electric potential series", description="Contains the electric potential series.", ) @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their associated data.""" return dict(potential=self.potential) - @pd.validator("potential", always=True) - @skip_if_fields_missing(["monitor"]) - def warn_no_data(cls, val, values): + @model_validator(mode="after") + def warn_no_data(self): """Warn if no data provided.""" - - mnt = values.get("monitor") - - if val is None: + if self.potential is None: log.warning( - f"No data is available for monitor '{mnt.name}'. This is typically caused by " - "monitor not intersecting any solid medium." + f"No data is available for monitor '{self.monitor.name}'. This is " + "typically caused by monitor not intersecting any solid medium." ) - return val + return self @property def symmetry_expanded_copy(self) -> SteadyPotentialData: @@ -95,65 +89,53 @@ class SteadyFreeCarrierData(HeatChargeMonitorData): ``monitor``. 
""" - monitor: SteadyFreeCarrierMonitor = pd.Field( - ..., + monitor: SteadyFreeCarrierMonitor = Field( title="Free carrier monitor", description="Free carrier data associated with a Charge simulation.", ) - electrons: UnstructuredFieldType = pd.Field( + electrons: Optional[UnstructuredFieldType] = Field( None, title="Electrons series", description=r"Contains the computed electrons concentration $n$.", - discriminator=TYPE_TAG_STR, ) # n = electrons - holes: UnstructuredFieldType = pd.Field( + holes: Optional[UnstructuredFieldType] = Field( None, title="Holes series", description=r"Contains the computed holes concentration $p$.", - discriminator=TYPE_TAG_STR, ) # p = holes @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their associated data.""" return dict(electrons=self.electrons, holes=self.holes) - @pd.root_validator(skip_on_failure=True) - def check_correct_data_type(cls, values): + @model_validator(mode="after") + def check_correct_data_type(self): """Issue error if incorrect data type is used""" - - mnt = values.get("monitor") - field_data = {field: values.get(field) for field in ["electrons", "holes"]} - + field_data = {field: getattr(self, field) for field in ["electrons", "holes"]} for field, data in field_data.items(): if isinstance(data, TetrahedralGridDataset) or isinstance(data, TriangularGridDataset): if not isinstance(data.values, IndexedVoltageDataArray): raise ValueError( - f"In the data associated with monitor {mnt}, the field {field} does not contain " - "data associated to any voltage value." + f"In the data associated with monitor {self.monitor}, the " + f"field {field} does not contain data associated to any voltage value." 
) + return self - return values - - @pd.root_validator(skip_on_failure=True) - def warn_no_data(cls, values): + @model_validator(mode="after") + def warn_no_data(self): """Warn if no data provided.""" - mnt = values.get("monitor") - electrons = values.get("electrons") - holes = values.get("holes") - - if electrons is None or holes is None: + if self.electrons is None or self.holes is None: log.warning( - f"No data is available for monitor '{mnt.name}'. This is typically caused by " - "monitor not intersecting any solid medium." + f"No data is available for monitor '{self.monitor.name}'. This is " + "typically caused by monitor not intersecting any solid medium." ) - - return values + return self @property def symmetry_expanded_copy(self) -> SteadyFreeCarrierData: @@ -192,85 +174,77 @@ class SteadyEnergyBandData(HeatChargeMonitorData): as defined in the ``monitor``. """ - monitor: SteadyEnergyBandMonitor = pd.Field( - ..., + monitor: SteadyEnergyBandMonitor = Field( title="Energy band monitor", description="Energy bands data associated with a Charge simulation.", ) - Ec: UnstructuredFieldType = pd.Field( + Ec: Optional[UnstructuredFieldType] = Field( None, title="Conduction band series", description=r"Contains the computed energy of the bottom of the conduction band $Ec$.", - discriminator=TYPE_TAG_STR, ) - Ev: UnstructuredFieldType = pd.Field( + Ev: Optional[UnstructuredFieldType] = Field( None, title="Valence band series", description=r"Contains the computed energy of the top of the valence band $Ec$.", - discriminator=TYPE_TAG_STR, ) - Ei: UnstructuredFieldType = pd.Field( + Ei: Optional[UnstructuredFieldType] = Field( None, title="Intrinsic Fermi level series", description=r"Contains the computed intrinsic Fermi level for the material $Ei$.", - discriminator=TYPE_TAG_STR, ) - Efn: UnstructuredFieldType = pd.Field( + Efn: Optional[UnstructuredFieldType] = Field( None, title="Electron's quasi-Fermi level series", description=r"Contains the computed quasi-Fermi 
level for electrons $Efn$.", - discriminator=TYPE_TAG_STR, ) - Efp: UnstructuredFieldType = pd.Field( + Efp: Optional[UnstructuredFieldType] = Field( None, title="Hole's quasi-Fermi level series", description=r"Contains the computed quasi-Fermi level for holes $Efp$.", - discriminator=TYPE_TAG_STR, ) @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their associated data.""" return dict(Ec=self.Ec, Ev=self.Ev, Ei=self.Ei, Efn=self.Efn, Efp=self.Efp) - @pd.root_validator(skip_on_failure=True) - def check_correct_data_type(cls, values): + @model_validator(mode="after") + def check_correct_data_type(self): """Issue error if incorrect data type is used""" - mnt = values.get("monitor") - field_data = {field: values.get(field) for field in ["Ec", "Ev", "Ei", "Efn", "Efp"]} + field_data = {field: getattr(self, field) for field in ["Ec", "Ev", "Ei", "Efn", "Efp"]} for field, data in field_data.items(): if isinstance(data, TetrahedralGridDataset) or isinstance(data, TriangularGridDataset): if not isinstance(data.values, IndexedVoltageDataArray): raise ValueError( - f"In the data associated with monitor {mnt}, the field {field} does not contain " - "data associated to any voltage value." + f"In the data associated with monitor {self.monitor}, the " + f"field {field} does not contain data associated to any voltage value." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def warn_no_data(cls, values): + @model_validator(mode="after") + def warn_no_data(self): """Warn if no data provided.""" - mnt = values.get("monitor") fields = ["Ec", "Ev", "Ei", "Efn", "Efp"] for field_name in fields: - field_data = values.get(field_name) + field_data = getattr(self, field_name) if field_data is None: log.warning( - f"No data is available for monitor '{mnt.name}'. This is typically caused by " - "monitor not intersecting any solid medium." 
+ f"No data is available for monitor '{self.monitor.name}'. This " + "is typically caused by monitor not intersecting any solid medium." ) - return values + return self @property def symmetry_expanded_copy(self) -> SteadyEnergyBandData: @@ -394,40 +368,36 @@ class SteadyCapacitanceData(HeatChargeMonitorData): This is only computed when a voltage source with more than two sources is included within the simulation and determines the :math:`\\Delta V`. """ - monitor: SteadyCapacitanceMonitor = pd.Field( - ..., + monitor: SteadyCapacitanceMonitor = Field( title="Capacitance monitor", description="Capacitance data associated with a Charge simulation.", ) - hole_capacitance: SteadyVoltageDataArray = pd.Field( + hole_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Hole capacitance", description=r"Small signal capacitance ($\frac{dQ_p}{dV}$) associated to the monitor.", ) # C_p = hole_capacitance - electron_capacitance: SteadyVoltageDataArray = pd.Field( + electron_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Electron capacitance", description=r"Small signal capacitance ($\frac{dQn}{dV}$) associated to the monitor.", ) # C_n = electron_capacitance - @pd.validator("hole_capacitance", always=True) - @skip_if_fields_missing(["monitor"]) - def warn_no_data(cls, val, values): + @model_validator(mode="after") + def warn_no_data(self): """Warn if no data provided.""" - mnt = values.get("monitor") - - if val is None: + if self.hole_capacitance is None: log.warning( - f"No data is available for monitor '{mnt.name}'. This is typically caused by " - "monitor not intersecting any solid medium." + f"No data is available for monitor '{self.monitor.name}'. This is " + "typically caused by monitor not intersecting any solid medium." 
) - return val + return self def field_name(self, val: str) -> str: """Gets the name of the fields to be plotted.""" diff --git a/tidy3d/components/tcad/data/monitor_data/heat.py b/tidy3d/components/tcad/data/monitor_data/heat.py index eac51a52b0..16313f45ca 100644 --- a/tidy3d/components/tcad/data/monitor_data/heat.py +++ b/tidy3d/components/tcad/data/monitor_data/heat.py @@ -2,27 +2,20 @@ from __future__ import annotations -from typing import Dict, Optional, Union +from typing import Optional, Union -import pydantic.v1 as pd +from pydantic import Field, model_validator -from tidy3d.components.base import skip_if_fields_missing -from tidy3d.components.data.data_array import ( - DataArray, - IndexedDataArray, - SpatialDataArray, -) +from tidy3d.components.data.data_array import DataArray, IndexedDataArray, SpatialDataArray from tidy3d.components.data.utils import TetrahedralGridDataset, TriangularGridDataset from tidy3d.components.tcad.data.monitor_data.abstract import HeatChargeMonitorData -from tidy3d.components.tcad.monitors.heat import ( - TemperatureMonitor, -) -from tidy3d.components.types import annotate_type +from tidy3d.components.tcad.monitors.heat import TemperatureMonitor +from tidy3d.components.types import discriminated_union from tidy3d.constants import KELVIN from tidy3d.log import log FieldDataset = Union[ - SpatialDataArray, annotate_type(Union[TriangularGridDataset, TetrahedralGridDataset]) + SpatialDataArray, discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]) ] UnstructuredFieldType = Union[TriangularGridDataset, TetrahedralGridDataset] @@ -44,52 +37,45 @@ class TemperatureData(HeatChargeMonitorData): >>> temp_mnt_data_expanded = temp_mnt_data.symmetry_expanded_copy """ - monitor: TemperatureMonitor = pd.Field( - ..., title="Monitor", description="Temperature monitor associated with the data." 
+ monitor: TemperatureMonitor = Field( + title="Monitor", + description="Temperature monitor associated with the data.", ) - temperature: Optional[FieldDataset] = pd.Field( - ..., + temperature: Optional[FieldDataset] = Field( + None, title="Temperature", description="Spatial temperature field.", units=KELVIN, ) @property - def field_components(self) -> Dict[str, DataArray]: + def field_components(self) -> dict[str, DataArray]: """Maps the field components to their associated data.""" return dict(temperature=self.temperature) - @pd.validator("temperature", always=True) - @skip_if_fields_missing(["monitor"]) - def warn_no_data(cls, val, values): + @model_validator(mode="after") + def warn_no_data(self): """Warn if no data provided.""" - - mnt = values.get("monitor") - - if val is None: + if self.temperature is None: log.warning( - f"No data is available for monitor '{mnt.name}'. This is typically caused by " + f"No data is available for monitor '{self.monitor.name}'. This is typically caused by " "monitor not intersecting any solid medium." ) + return self - return val - - @pd.validator("temperature", always=True) - @skip_if_fields_missing(["monitor"]) - def check_correct_data_type(cls, val, values): + @model_validator(mode="after") + def check_correct_data_type(self): """Issue error if incorrect data type is used""" - - mnt = values.get("monitor") - - if isinstance(val, TetrahedralGridDataset) or isinstance(val, TriangularGridDataset): - if not isinstance(val.values, IndexedDataArray): + if isinstance(self.temperature, TetrahedralGridDataset) or isinstance( + self.temperature, TriangularGridDataset + ): + if not isinstance(self.temperature.values, IndexedDataArray): raise ValueError( - f"Monitor {mnt} of type 'TemperatureMonitor' cannot be associated with data arrays " - "of type 'IndexVoltageDataArray'." + f"Monitor {self.monitor} of type 'TemperatureMonitor' cannot be " + "associated with data arrays of type 'IndexVoltageDataArray'." 
) - - return val + return self def field_name(self, val: str = "") -> str: """Gets the name of the fields to be plot.""" diff --git a/tidy3d/components/tcad/data/sim_data.py b/tidy3d/components/tcad/data/sim_data.py index 0a9a8e4898..021b0af288 100644 --- a/tidy3d/components/tcad/data/sim_data.py +++ b/tidy3d/components/tcad/data/sim_data.py @@ -1,17 +1,12 @@ """Defines heat simulation data class""" -from __future__ import annotations - -from typing import Optional, Tuple +from typing import Optional import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator from tidy3d.components.base_sim.data.sim_data import AbstractSimulationData -from tidy3d.components.data.data_array import ( - SpatialDataArray, - SteadyVoltageDataArray, -) +from tidy3d.components.data.data_array import SpatialDataArray, SteadyVoltageDataArray from tidy3d.components.data.utils import ( TetrahedralGridDataset, TriangularGridDataset, @@ -24,7 +19,7 @@ ) from tidy3d.components.tcad.simulation.heat import HeatSimulation from tidy3d.components.tcad.simulation.heat_charge import HeatChargeSimulation -from tidy3d.components.types import Ax, Literal, RealFieldVal, annotate_type +from tidy3d.components.types import Ax, Literal, RealFieldVal, discriminated_union from tidy3d.components.viz import add_ax_if_none, equal_aspect from tidy3d.exceptions import DataError from tidy3d.log import log @@ -53,21 +48,21 @@ class DeviceCharacteristics(Tidy3dBaseModel): """ - steady_dc_hole_capacitance: Optional[SteadyVoltageDataArray] = pd.Field( + steady_dc_hole_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Steady DC hole capacitance", description="Device steady DC capacitance data based on holes. 
If the simulation " "has converged, these result should be close to that of electrons.", ) - steady_dc_electron_capacitance: Optional[SteadyVoltageDataArray] = pd.Field( + steady_dc_electron_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Steady DC electron capacitance", description="Device steady DC capacitance data based on electrons. If the simulation " "has converged, these result should be close to that of holes.", ) - steady_dc_current_voltage: Optional[SteadyVoltageDataArray] = pd.Field( + steady_dc_current_voltage: Optional[SteadyVoltageDataArray] = Field( None, title="Steady DC current-voltage", description="Device steady DC current-voltage relation for the device.", @@ -118,19 +113,18 @@ class HeatChargeSimulationData(AbstractSimulationData): ... ) """ - simulation: HeatChargeSimulation = pd.Field( + simulation: HeatChargeSimulation = Field( title="Heat-Charge Simulation", description="Original :class:`.HeatChargeSimulation` associated with the data.", ) - data: Tuple[annotate_type(TCADMonitorDataType), ...] = pd.Field( - ..., + data: tuple[discriminated_union(TCADMonitorDataType), ...] = Field( title="Monitor Data", description="List of :class:`.MonitorData` instances " "associated with the monitors of the original :class:`.Simulation`.", ) - device_characteristics: Optional[DeviceCharacteristics] = pd.Field( + device_characteristics: Optional[DeviceCharacteristics] = Field( None, title="Device characteristics", description="Data characterizing the device. Current characteristics include: " @@ -356,16 +350,16 @@ class HeatSimulationData(HeatChargeSimulationData): Consider using :class:`HeatChargeSimulationData` instead. 
""" - simulation: HeatSimulation = pd.Field( + simulation: HeatSimulation = Field( title="Heat Simulation", description="Original :class:`HeatSimulation` associated with the data.", ) - @pd.root_validator(skip_on_failure=True) - def issue_warning_deprecated(cls, values): + @model_validator(mode="before") + def issue_warning_deprecated(data): """Issue warning for 'HeatSimulations'.""" log.warning( "'HeatSimulationData' is deprecated and will be discontinued. You can use " "'HeatChargeSimulationData' instead" ) - return values + return data diff --git a/tidy3d/components/tcad/data/types.py b/tidy3d/components/tcad/data/types.py index c1865d309e..3a3a3469ea 100644 --- a/tidy3d/components/tcad/data/types.py +++ b/tidy3d/components/tcad/data/types.py @@ -1,7 +1,5 @@ """Monitor level data, store the DataArrays associated with a single heat-charge monitor.""" -from __future__ import annotations - from typing import Union from tidy3d.components.tcad.data.monitor_data.charge import ( diff --git a/tidy3d/components/tcad/doping.py b/tidy3d/components/tcad/doping.py index 9570af2a9a..66e153327e 100644 --- a/tidy3d/components/tcad/doping.py +++ b/tidy3d/components/tcad/doping.py @@ -1,7 +1,7 @@ """File containing classes required for the setup of a DEVSIM case.""" import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat, field_validator from tidy3d.components.base import cached_property from tidy3d.components.geometry.base import Box @@ -71,23 +71,20 @@ def _get_indices_in_box(self, coords: dict, meshgrid: bool = True): return indices_in_box, X, Y, Z, normal_axis - @pd.root_validator(skip_on_failure=True) - def check_dimensions(cls, values): + @field_validator("size") + def check_dimensions(val): """Make sure dimensionality is specified correctly. 
I.e., a 2D box must be defined with an inf size in the normal direction.""" - - size = values["size"] for dim in range(3): - if size[dim] == 0: + if val[dim] == 0: zero_dim_name = "xyz"[dim] - raise SetupError( f"The doping box has been set up with 0 size in the {zero_dim_name} direction. " "If this was intended to be translationally invariant, the box must have a large " "or infinite ('td.inf') size in the perpendicular direction." ) - return values + return val class ConstantDoping(AbstractDopingBox): @@ -108,7 +105,7 @@ class ConstantDoping(AbstractDopingBox): >>> constant_box2 = td.ConstantDoping.from_bounds(rmin=box_coords[0], rmax=box_coords[1], concentration=1e18) """ - concentration: pd.NonNegativeFloat = pd.Field( + concentration: NonNegativeFloat = Field( default=0, title="Doping concentration density.", description="Doping concentration density in #/cm^3.", @@ -184,25 +181,25 @@ class GaussianDoping(AbstractDopingBox): ... ) """ - ref_con: pd.PositiveFloat = pd.Field( + ref_con: PositiveFloat = Field( title="Reference concentration.", description="Reference concentration. This is the minimum concentration in the box " "and it is attained at the edges/faces of the box.", ) - concentration: pd.PositiveFloat = pd.Field( + concentration: PositiveFloat = Field( title="Concentration", description="The concentration at the center of the box.", ) - width: pd.PositiveFloat = pd.Field( + width: PositiveFloat = Field( title="Width of the gaussian.", description="Width of the gaussian. The concentration will transition from " "'concentration' at the center of the box to 'ref_con' at the edge/face " "of the box in a distance equal to 'width'. 
", ) - source: str = pd.Field( + source: str = Field( "xmin", title="Source face", description="Specifies the side of the box acting as the source, i.e., " diff --git a/tidy3d/components/tcad/generation_recombination.py b/tidy3d/components/tcad/generation_recombination.py index b33f8412e9..fbe8411e9a 100644 --- a/tidy3d/components/tcad/generation_recombination.py +++ b/tidy3d/components/tcad/generation_recombination.py @@ -1,4 +1,4 @@ -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.types import Union @@ -38,25 +38,42 @@ class FossumCarrierLifetime(Tidy3dBaseModel): """ - tau_300: pd.PositiveFloat = pd.Field( - ..., title="Tau at 300K", description="Carrier lifetime at 300K", units=SECOND + tau_300: PositiveFloat = Field( + title="Tau at 300K", + description="Carrier lifetime at 300K", + units=SECOND, ) - alpha_T: float = pd.Field( - ..., title="Exponent for thermal dependence", description="Exponent for thermal dependence" + alpha_T: float = Field( + title="Exponent for thermal dependence", + description="Exponent for thermal dependence", ) - N0: pd.PositiveFloat = pd.Field( - ..., title="Reference concentration", description="Reference concentration", units=PERCMCUBE + N0: PositiveFloat = Field( + title="Reference concentration", + description="Reference concentration", + units=PERCMCUBE, ) - A: float = pd.Field(..., title="Constant A", description="Constant A") + A: float = Field( + title="Constant A", + description="Constant A", + ) - B: float = pd.Field(..., title="Constant B", description="Constant B") + B: float = Field( + title="Constant B", + description="Constant B", + ) - C: float = pd.Field(..., title="Constant C", description="Constant C") + C: float = Field( + title="Constant C", + description="Constant C", + ) - alpha: float = pd.Field(..., title="Exponent constant", description="Exponent constant") + alpha: float = Field( + title="Exponent constant", + 
 description="Exponent constant", + ) CarrierLifetimeType = Union[FossumCarrierLifetime] @@ -85,12 +102,14 @@ class AugerRecombination(Tidy3dBaseModel): ... ) """ - c_n: pd.PositiveFloat = pd.Field( - ..., title="Constant for electrons", description="Constant for electrons in cm^6/s" + c_n: PositiveFloat = Field( + title="Constant for electrons", + description="Constant for electrons in cm^6/s", ) - c_p: pd.PositiveFloat = pd.Field( - ..., title="Constant for holes", description="Constant for holes in cm^6/s" + c_p: PositiveFloat = Field( + title="Constant for holes", + description="Constant for holes in cm^6/s", ) @@ -115,8 +134,7 @@ class RadiativeRecombination(Tidy3dBaseModel): ... ) """ - r_const: float = pd.Field( - ..., title="Radiation constant in cm^3/s", description="Radiation constant in cm^3/s", ) @@ -157,10 +175,14 @@ class ShockleyReedHallRecombination(Tidy3dBaseModel): - This model represents mid-gap traps Shockley-Reed-Hall recombination. """ - tau_n: Union[pd.PositiveFloat, CarrierLifetimeType] = pd.Field( - ..., title="Electron lifetime", description="Electron lifetime", union=SECOND + tau_n: Union[PositiveFloat, CarrierLifetimeType] = Field( + title="Electron lifetime", + description="Electron lifetime", + units=SECOND, ) - tau_p: Union[pd.PositiveFloat, CarrierLifetimeType] = pd.Field( - ..., title="Hole lifetime", description="Hole lifetime", units=SECOND + tau_p: Union[PositiveFloat, CarrierLifetimeType] = Field( + title="Hole lifetime", + description="Hole lifetime", + units=SECOND, ) diff --git a/tidy3d/components/tcad/grid.py b/tidy3d/components/tcad/grid.py index 5720f3d667..89478da12a 100644 --- a/tidy3d/components/tcad/grid.py +++ b/tidy3d/components/tcad/grid.py @@ -3,23 +3,23 @@ from __future__ import annotations from abc import ABC -from typing import Tuple, Union +from typing import Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat, field_validator, 
model_validator -from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import MICROMETER from tidy3d.exceptions import ValidationError from ..geometry.base import Box -from ..types import Coordinate, annotate_type +from ..types import Coordinate, discriminated_union class UnstructuredGrid(Tidy3dBaseModel, ABC): """Abstract unstructured grid.""" - relative_min_dl: pd.NonNegativeFloat = pd.Field( + relative_min_dl: NonNegativeFloat = Field( 1e-3, title="Relative Mesh Size Limit", description="The minimal allowed mesh size relative to the largest dimension of the simulation domain." @@ -35,14 +35,13 @@ class UniformUnstructuredGrid(UnstructuredGrid): >>> heat_grid = UniformUnstructuredGrid(dl=0.1) """ - dl: pd.PositiveFloat = pd.Field( - ..., + dl: PositiveFloat = Field( title="Grid Size", description="Grid size for uniform grid generation.", units=MICROMETER, ) - min_edges_per_circumference: pd.PositiveFloat = pd.Field( + min_edges_per_circumference: PositiveFloat = Field( 15, title="Minimum Edges per Circumference", description="Enforced minimum number of mesh segments per circumference of an object. " @@ -50,13 +49,13 @@ class UniformUnstructuredGrid(UnstructuredGrid): "is taken as 2 * pi * radius.", ) - min_edges_per_side: pd.PositiveFloat = pd.Field( + min_edges_per_side: PositiveFloat = Field( 2, title="Minimum Edges per Side", description="Enforced minimum number of mesh segments per any side of an object.", ) - non_refined_structures: Tuple[str, ...] = pd.Field( + non_refined_structures: tuple[str, ...] = Field( (), title="Structures Without Refinement", description="List of structures for which ``min_edges_per_circumference`` and " @@ -68,15 +67,13 @@ class GridRefinementRegion(Box): """Refinement region for the unstructured mesh. The cell size is enforced to be constant inside the region. 
The cell size outside of the region depends on the distance from the region.""" - dl_internal: pd.PositiveFloat = pd.Field( - ..., + dl_internal: PositiveFloat = Field( title="Internal mesh cell size", description="Mesh cell size inside the refinement region", units=MICROMETER, ) - transition_thickness: pd.NonNegativeFloat = pd.Field( - ..., + transition_thickness: NonNegativeFloat = Field( title="Interface Distance", description="Thickness of a transition layer outside the box where the mesh cell size changes from the" "internal size to the external one.", @@ -87,51 +84,39 @@ class GridRefinementRegion(Box): class GridRefinementLine(Tidy3dBaseModel, ABC): """Refinement line for the unstructured mesh. The cell size depends on the distance from the line.""" - r1: Coordinate = pd.Field( - ..., + r1: Coordinate = Field( title="Start point of the line", description="Start point of the line in x, y, and z.", units=MICROMETER, ) - r2: Coordinate = pd.Field( - ..., + r2: Coordinate = Field( title="End point of the line", description="End point of the line in x, y, and z.", units=MICROMETER, ) - @pd.validator("r1", always=True) - def _r1_not_inf(cls, val): + @field_validator("r1", "r2") + def _not_inf(val, info): """Make sure the point is not infinitiy.""" if any(np.isinf(v) for v in val): - raise ValidationError("Point can not contain td.inf terms.") + raise ValidationError("Point can not contain 'td.inf' terms.") return val - @pd.validator("r2", always=True) - def _r2_not_inf(cls, val): - """Make sure the point is not infinitiy.""" - if any(np.isinf(v) for v in val): - raise ValidationError("Point can not contain td.inf terms.") - return val - - dl_near: pd.PositiveFloat = pd.Field( - ..., + dl_near: PositiveFloat = Field( title="Mesh cell size near the line", description="Mesh cell size near the line", units=MICROMETER, ) - distance_near: pd.NonNegativeFloat = pd.Field( - ..., + distance_near: NonNegativeFloat = Field( title="Near distance", description="Distance from 
the line within which ``dl_near`` is enforced." "Typically the same as ``dl_near`` or its multiple.", units=MICROMETER, ) - distance_bulk: pd.NonNegativeFloat = pd.Field( - ..., + distance_bulk: NonNegativeFloat = Field( title="Bulk distance", description="Distance from the line outside of which ``dl_bulk`` is enforced." "Typically twice of ``dl_bulk`` or its multiple. Use larger values for a smoother " @@ -139,15 +124,13 @@ def _r2_not_inf(cls, val): units=MICROMETER, ) - @pd.validator("distance_bulk", always=True) - @skip_if_fields_missing(["distance_near"]) - def names_exist_bcs(cls, val, values): + @model_validator(mode="after") + def names_exist_bcs(self): """Error if distance_bulk is less than distance_near""" - distance_near = values.get("distance_near") - if distance_near > val: + if self.distance_near > self.distance_bulk: raise ValidationError("'distance_bulk' cannot be smaller than 'distance_near'.") - return val + return self class DistanceUnstructuredGrid(UnstructuredGrid): @@ -164,30 +147,26 @@ class DistanceUnstructuredGrid(UnstructuredGrid): ... ) """ - dl_interface: pd.PositiveFloat = pd.Field( - ..., + dl_interface: PositiveFloat = Field( title="Interface Grid Size", description="Grid size near material interfaces.", units=MICROMETER, ) - dl_bulk: pd.PositiveFloat = pd.Field( - ..., + dl_bulk: PositiveFloat = Field( title="Bulk Grid Size", description="Grid size away from material interfaces.", units=MICROMETER, ) - distance_interface: pd.NonNegativeFloat = pd.Field( - ..., + distance_interface: NonNegativeFloat = Field( title="Interface Distance", description="Distance from interface within which ``dl_interface`` is enforced." "Typically the same as ``dl_interface`` or its multiple.", units=MICROMETER, ) - distance_bulk: pd.NonNegativeFloat = pd.Field( - ..., + distance_bulk: NonNegativeFloat = Field( title="Bulk Distance", description="Distance from interface outside of which ``dl_bulk`` is enforced." 
"Typically twice of ``dl_bulk`` or its multiple. Use larger values for a smoother " @@ -195,44 +174,42 @@ class DistanceUnstructuredGrid(UnstructuredGrid): units=MICROMETER, ) - sampling: pd.PositiveFloat = pd.Field( + sampling: PositiveFloat = Field( 100, title="Surface Sampling", description="An internal advanced parameter that defines number of sampling points per " "surface when computing distance values.", ) - uniform_grid_mediums: Tuple[str, ...] = pd.Field( + uniform_grid_mediums: tuple[str, ...] = Field( (), title="Mediums With Uniform Refinement", description="List of mediums for which ``dl_interface`` will be enforced everywhere " "in the volume.", ) - non_refined_structures: Tuple[str, ...] = pd.Field( + non_refined_structures: tuple[str, ...] = Field( (), title="Structures Without Refinement", description="List of structures for which ``dl_interface`` will not be enforced. " "``dl_bulk`` is used instead.", ) - mesh_refinements: Tuple[annotate_type(Union[GridRefinementRegion, GridRefinementLine]), ...] = ( - pd.Field( - (), - title="Mesh refinement structures", - description="List of regions/lines for which the mesh refinement will be applied", - ) + mesh_refinements: tuple[ + discriminated_union(Union[GridRefinementRegion, GridRefinementLine]), ... 
+ ] = Field( + (), + title="Mesh refinement structures", + description="List of regions/lines for which the mesh refinement will be applied", ) - @pd.validator("distance_bulk", always=True) - @skip_if_fields_missing(["distance_interface"]) - def names_exist_bcs(cls, val, values): + @model_validator(mode="after") + def names_exist_bcs(self): """Error if distance_bulk is less than distance_interface""" - distance_interface = values.get("distance_interface") - if distance_interface > val: + if self.distance_interface > self.distance_bulk: raise ValidationError("'distance_bulk' cannot be smaller than 'distance_interface'.") - return val + return self UnstructuredGridType = Union[UniformUnstructuredGrid, DistanceUnstructuredGrid] diff --git a/tidy3d/components/tcad/mobility.py b/tidy3d/components/tcad/mobility.py index 127d3cca47..1c8c73372f 100644 --- a/tidy3d/components/tcad/mobility.py +++ b/tidy3d/components/tcad/mobility.py @@ -1,4 +1,4 @@ -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel @@ -12,8 +12,10 @@ class ConstantMobilityModel(Tidy3dBaseModel): >>> mobility_model = td.ConstantMobilityModel(mu=1500) """ - mu: pd.NonNegativeFloat = pd.Field( - ..., title="Mobility", description="Mobility", units="cm²/V-s" + mu: NonNegativeFloat = Field( + title="Mobility", + description="Mobility", + units="cm²/V-s", ) @@ -112,52 +114,45 @@ class CaugheyThomasMobility(Tidy3dBaseModel): """ # mobilities - mu_min: pd.PositiveFloat = pd.Field( - ..., + mu_min: PositiveFloat = Field( title=r"$\mu_{min}$ Minimum electron mobility", description="Minimum electron mobility at reference temperature (300K) in cm^2/V-s. 
", ) - mu: pd.PositiveFloat = pd.Field( - ..., + mu: PositiveFloat = Field( title="Reference mobility", description="Reference mobility at reference temperature (300K) in cm^2/V-s", ) # thermal exponent for reference mobility - exp_2: float = pd.Field( - ..., title="Exponent for temperature dependent behavior of reference mobility" + exp_2: float = Field( + title="Exponent for temperature dependent behavior of reference mobility", ) # doping exponent - exp_N: pd.PositiveFloat = pd.Field( - ..., + exp_N: PositiveFloat = Field( title="Exponent for doping dependence of mobility.", description="Exponent for doping dependence of mobility at reference temperature (300K).", ) # reference doping - ref_N: pd.PositiveFloat = pd.Field( - ..., + ref_N: PositiveFloat = Field( title="Reference doping", description="Reference doping at reference temperature (300K) in #/cm^3.", ) # temperature exponent - exp_1: float = pd.Field( - ..., + exp_1: float = Field( title="Exponent of thermal dependence of minimum mobility.", description="Exponent of thermal dependence of minimum mobility.", ) - exp_3: float = pd.Field( - ..., + exp_3: float = Field( title="Exponent of thermal dependence of reference doping.", description="Exponent of thermal dependence of reference doping.", ) - exp_4: float = pd.Field( - ..., + exp_4: float = Field( title="Exponent of thermal dependence of the doping exponent effect.", description="Exponent of thermal dependence of the doping exponent effect.", ) diff --git a/tidy3d/components/tcad/monitors/abstract.py b/tidy3d/components/tcad/monitors/abstract.py index 6628079429..12201e9067 100644 --- a/tidy3d/components/tcad/monitors/abstract.py +++ b/tidy3d/components/tcad/monitors/abstract.py @@ -2,7 +2,7 @@ from abc import ABC -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base_sim.monitor import AbstractMonitor from tidy3d.components.types import ArrayFloat1D @@ -13,13 +13,13 @@ class HeatChargeMonitor(AbstractMonitor, ABC): 
"""Abstract base class for heat-charge monitors.""" - unstructured: bool = pd.Field( + unstructured: bool = Field( False, title="Unstructured Grid", description="Return data on the original unstructured grid.", ) - conformal: bool = pd.Field( + conformal: bool = Field( False, title="Conformal Monitor Meshing", description="If ``True`` the simulation mesh will conform to the monitor's geometry. " diff --git a/tidy3d/components/tcad/monitors/charge.py b/tidy3d/components/tcad/monitors/charge.py index e575e30a86..7783d24613 100644 --- a/tidy3d/components/tcad/monitors/charge.py +++ b/tidy3d/components/tcad/monitors/charge.py @@ -2,7 +2,7 @@ from typing import Literal -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.tcad.monitors.abstract import HeatChargeMonitor @@ -33,7 +33,7 @@ class SteadyFreeCarrierMonitor(HeatChargeMonitor): """ # NOTE: for the time being supporting unstructured - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", @@ -53,7 +53,7 @@ class SteadyEnergyBandMonitor(HeatChargeMonitor): """ # NOTE: for the time being supporting unstructured - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", @@ -73,7 +73,7 @@ class SteadyCapacitanceMonitor(HeatChargeMonitor): """ # NOTE: for the time being supporting unstructured - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", diff --git a/tidy3d/components/tcad/simulation/heat.py b/tidy3d/components/tcad/simulation/heat.py index 6304ed8630..315fb14571 100644 --- a/tidy3d/components/tcad/simulation/heat.py +++ b/tidy3d/components/tcad/simulation/heat.py @@ -1,11 +1,7 @@ """Defines heat simulation class NOTE: Keeping this 
class for backward compatibility only""" -from __future__ import annotations - -from typing import Tuple - -import pydantic.v1 as pd +from pydantic import model_validator from tidy3d.components.tcad.simulation.heat_charge import HeatChargeSimulation from tidy3d.components.types import Ax @@ -47,14 +43,14 @@ class HeatSimulation(HeatChargeSimulation): ... ) """ - @pd.root_validator(skip_on_failure=True) - def issue_warning_deprecated(cls, values): + @model_validator(mode="before") + def issue_warning_deprecated(data): """Issue warning for 'HeatSimulations'.""" log.warning( "Setting up deprecated 'HeatSimulation'. " "Consider defining 'HeatChargeSimulation' instead." ) - return values + return data @equal_aspect @add_ax_if_none @@ -68,8 +64,8 @@ def plot_heat_conductivity( source_alpha: float = None, monitor_alpha: float = None, colorbar: str = "conductivity", - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of simulation's components on a plane defined by one nonzero x,y,z coordinate. 
diff --git a/tidy3d/components/tcad/simulation/heat_charge.py b/tidy3d/components/tcad/simulation/heat_charge.py index 60f57ffa1b..2193998fbd 100644 --- a/tidy3d/components/tcad/simulation/heat_charge.py +++ b/tidy3d/components/tcad/simulation/heat_charge.py @@ -1,20 +1,18 @@ -# ruff: noqa: W293, W291 """Defines heat simulation class""" from __future__ import annotations from enum import Enum -from typing import Dict, List, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, field_validator, model_validator try: from matplotlib import colormaps except ImportError: pass -from tidy3d.components.base import skip_if_fields_missing from tidy3d.components.base_sim.simulation import AbstractSimulation from tidy3d.components.bc_placement import ( MediumMediumInterface, @@ -79,7 +77,14 @@ plot_params_heat_bc, plot_params_heat_source, ) -from tidy3d.components.types import TYPE_TAG_STR, Ax, Bound, ScalarSymmetry, Shapely, annotate_type +from tidy3d.components.types import ( + TYPE_TAG_STR, + Ax, + Bound, + ScalarSymmetry, + Shapely, + discriminated_union, +) from tidy3d.components.viz import PlotParams, add_ax_if_none, equal_aspect from tidy3d.constants import VOLUMETRIC_HEAT_RATE, inf from tidy3d.exceptions import SetupError @@ -249,8 +254,8 @@ class HeatChargeSimulation(AbstractSimulation): top of the coupling heat source. """ - medium: StructureMediumType = pd.Field( - Medium(), + medium: StructureMediumType = Field( + default_factory=Medium, title="Background Medium", description="Background medium of simulation, defaults to a standard dispersion-less :class:`Medium` if not " "specified.", @@ -260,34 +265,34 @@ class HeatChargeSimulation(AbstractSimulation): Background medium of simulation, defaults to a standard dispersion-less :class:`Medium` if not specified. """ - sources: Tuple[annotate_type(HeatChargeSourceType), ...] = pd.Field( + sources: tuple[discriminated_union(HeatChargeSourceType), ...] 
= Field( (), title="Heat and Charge sources", description="List of heat and/or charge sources.", ) - monitors: Tuple[annotate_type(HeatChargeMonitorType), ...] = pd.Field( + monitors: tuple[discriminated_union(HeatChargeMonitorType), ...] = Field( (), title="Monitors", description="Monitors in the simulation.", ) - boundary_spec: Tuple[annotate_type(Union[HeatChargeBoundarySpec, HeatBoundarySpec]), ...] = ( - pd.Field( - (), - title="Boundary Condition Specifications", - description="List of boundary condition specifications.", - ) + boundary_spec: tuple[ + discriminated_union(Union[HeatChargeBoundarySpec, HeatBoundarySpec]), ... + ] = Field( + (), + title="Boundary Condition Specifications", + description="List of boundary condition specifications.", ) # NOTE: creating a union with HeatBoundarySpec for backwards compatibility - grid_spec: UnstructuredGridType = pd.Field( + grid_spec: UnstructuredGridType = Field( title="Grid Specification", description="Grid specification for heat-charge simulation.", discriminator=TYPE_TAG_STR, ) - symmetry: Tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = pd.Field( + symmetry: tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = Field( (0, 0, 0), title="Symmetries", description="Tuple of integers defining reflection symmetry across a plane " @@ -296,15 +301,15 @@ class HeatChargeSimulation(AbstractSimulation): "Each element can be ``0`` (symmetry off) or ``1`` (symmetry on).", ) - analysis_spec: AnalysisSpecType = pd.Field( + analysis_spec: Optional[AnalysisSpecType] = Field( None, title="Analysis specification.", description="The `analysis_spec` is used to validate that the simulation parameters and tolerance settings " "are correctly configured as desired by the user.", ) - @pd.validator("structures", always=True) - def check_unsupported_geometries(cls, val): + @field_validator("structures") + def check_unsupported_geometries(val): """Error if structures contain unsupported yet geometries.""" for ind, structure in 
enumerate(val): bbox = structure.geometry.bounding_box @@ -314,8 +319,7 @@ def check_unsupported_geometries(cls, val): ) return val - @staticmethod - def _check_cross_solids(objs: Tuple[Box, ...], values: Dict) -> Tuple[int, ...]: + def _check_cross_solids(self, objs: tuple[Box, ...]) -> tuple[int, ...]: """Given model dictionary ``values``, check whether objects in list ``objs`` cross a ``SolidSpec`` medium. """ @@ -324,29 +328,16 @@ def _check_cross_solids(objs: Tuple[Box, ...], values: Dict) -> Tuple[int, ...]: # will be accepted valid_electric_medium = (SemiconductorMedium, ChargeConductorMedium) - try: - size = values["size"] - center = values["center"] - medium = values["medium"] - structures = values["structures"] - except KeyError: - raise SetupError( - "Function '_check_cross_solids' assumes dictionary 'values' contains well-defined " - "'size', 'center', 'medium', and 'structures'. Thus, it should only be used in " - "validators with @skip_if_fields_missing(['medium', 'center', 'size', 'structures']) " - "or root validators with option 'skip_on_failure=True'." 
- ) - # list of structures including background as a Box() structure_bg = Structure( geometry=Box( - size=size, - center=center, + size=self.size, + center=self.center, ), - medium=medium, + medium=self.medium, ) - total_structures = [structure_bg] + list(structures) + total_structures = [structure_bg] + list(self.structures) obj_do_not_cross_solid_idx = [] obj_do_not_cross_cond_idx = [] @@ -385,15 +376,14 @@ def _check_cross_solids(objs: Tuple[Box, ...], values: Dict) -> Tuple[int, ...]: return obj_do_not_cross_solid_idx, obj_do_not_cross_cond_idx - @pd.validator("monitors", always=True) - @skip_if_fields_missing(["medium", "center", "size", "structures"]) - def _monitors_cross_solids(cls, val, values): + @model_validator(mode="after") + def _monitors_cross_solids(self): """Error if monitors does not cross any solid medium.""" - + val = self.monitors if val is None: - return val + return self - failed_solid_idx, failed_elect_idx = cls._check_cross_solids(val, values) + failed_solid_idx, failed_elect_idx = self._check_cross_solids(val) temp_monitors = [idx for idx, mnt in enumerate(val) if isinstance(mnt, TemperatureMonitor)] volt_monitors = [ @@ -419,19 +409,19 @@ def _monitors_cross_solids(cls, val, values): "materials. Thus, no information will be recorded in these monitors." 
) - return val + return self - @pd.root_validator(skip_on_failure=True) - def check_voltage_array_if_capacitance(cls, values): + @model_validator(mode="after") + def check_voltage_array_if_capacitance(self): """Make sure an array of voltages has been defined if a SteadyCapacitanceMonitor' has been defined""" - bounday_spec = values["boundary_spec"] - monitors = values["monitors"] + boundary_spec = self.boundary_spec + monitors = self.monitors is_capacitance_mnt = any(isinstance(mnt, SteadyCapacitanceMonitor) for mnt in monitors) voltage_array_present = False if is_capacitance_mnt: - for bc in bounday_spec: + for bc in boundary_spec: if isinstance(bc.condition, VoltageBC): if isinstance(bc.condition.source, DCVoltageSource): if isinstance(bc.condition.source.voltage, list) or isinstance( @@ -446,10 +436,10 @@ def check_voltage_array_if_capacitance(cls, values): "Voltage arrays can be included in a source in this manner: " "'VoltageBC(source=DCVoltageSource(voltage=yourArray))'" ) - return values + return self - @pd.validator("size", always=True) - def check_zero_dim_domain(cls, val, values): + @field_validator("size") + def check_zero_dim_domain(val): """Error if heat domain have zero dimensions.""" dim_names = ["x", "y", "z"] @@ -469,17 +459,15 @@ def check_zero_dim_domain(cls, val, values): return val - @pd.validator("boundary_spec", always=True) - @skip_if_fields_missing(["structures", "medium"]) - def names_exist_bcs(cls, val, values): + @model_validator(mode="after") + def names_exist_bcs(self): """Error if boundary conditions point to non-existing structures/media.""" - - structures = values.get("structures") + structures = self.structures structures_names = {s.name for s in structures} mediums_names = {s.medium.name for s in structures} - mediums_names.add(values.get("medium").name) + mediums_names.add(self.medium.name) - for bc_ind, bc_spec in enumerate(val): + for bc_ind, bc_spec in enumerate(self.boundary_spec): bc_place = bc_spec.placement if 
isinstance(bc_place, (StructureBoundary, StructureSimulationBoundary)): if bc_place.structure not in structures_names: @@ -504,14 +492,13 @@ def names_exist_bcs(cls, val, values): f"'boundary_spec[{bc_ind}].placement' (type '{bc_place.type}') " "is not found among simulation mediums." ) - return val + return self - @pd.validator("boundary_spec", always=True) - def check_only_one_voltage_array_provided(cls, val, values): + @field_validator("boundary_spec") + def check_only_one_voltage_array_provided(val): """Issue error if more than one voltage array is provided. Currently we only allow to sweep over one voltage array. """ - array_already_provided = False for bc in val: @@ -532,8 +519,8 @@ def check_only_one_voltage_array_provided(cls, val, values): ) return val - @pd.root_validator(skip_on_failure=True) - def check_charge_simulation(cls, values): + @model_validator(mode="after") + def check_charge_simulation(self): """Makes sure that Charge simulations are set correctly.""" ChargeMonitorType = ( @@ -542,13 +529,12 @@ def check_charge_simulation(cls, values): SteadyCapacitanceMonitor, ) - simulation_types = cls._check_simulation_types(values=values) + simulation_types = self._check_simulation_types() if TCADAnalysisTypes.CHARGE in simulation_types: # check that we have at least 2 'VoltageBC's - boundary_spec = values["boundary_spec"] voltage_bcs = 0 - for bc in boundary_spec: + for bc in self.boundary_spec: if isinstance(bc.condition, VoltageBC): voltage_bcs = voltage_bcs + 1 if voltage_bcs < 2: @@ -558,8 +544,7 @@ def check_charge_simulation(cls, values): ) # check that we have at least one charge monitor - monitors = values["monitors"] - if not any(isinstance(mnt, ChargeMonitorType) for mnt in monitors): + if not any(isinstance(mnt, ChargeMonitorType) for mnt in self.monitors): raise SetupError( "Charge simulations require the definition of, at least, one of these monitors: " "'[SteadyPotentialMonitor, SteadyFreeCarrierMonitor, SteadyCapacitanceMonitor]' " @@ 
-568,7 +553,7 @@ def check_charge_simulation(cls, values): # NOTE: SteadyPotentialMonitor and SteadyFreeCarrierMonitor are only supported # for unstructured = True - for mnt in monitors: + for mnt in self.monitors: if isinstance(mnt, SteadyPotentialMonitor) or isinstance( mnt, SteadyFreeCarrierMonitor ): @@ -578,23 +563,22 @@ def check_charge_simulation(cls, values): f"monitor '{mnt.name}' to 'unstructured = True'." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def not_all_neumann(cls, values): + @model_validator(mode="after") + def not_all_neumann(self): """Make sure not all BCs are of Neumann type""" NeumannBCsHeat = (HeatFluxBC,) NeumannBCsCharge = (CurrentBC, InsulatingBC) - simulation_types = cls._check_simulation_types(values=values) - bounday_conditions = values["boundary_spec"] + simulation_types = self._check_simulation_types() raise_error = False for sim_type in simulation_types: if sim_type == TCADAnalysisTypes.HEAT: type_bcs = [ - bc for bc in bounday_conditions if isinstance(bc.condition, HeatBCTypes) + bc for bc in self.boundary_spec if isinstance(bc.condition, HeatBCTypes) ] if len(type_bcs) == 0 or all( isinstance(bc.condition, NeumannBCsHeat) for bc in type_bcs @@ -602,7 +586,7 @@ def not_all_neumann(cls, values): raise_error = True elif sim_type == TCADAnalysisTypes.CONDUCTION: type_bcs = [ - bc for bc in bounday_conditions if isinstance(bc.condition, ElectricBCTypes) + bc for bc in self.boundary_spec if isinstance(bc.condition, ElectricBCTypes) ] if len(type_bcs) == 0 or all( isinstance(bc.condition, NeumannBCsCharge) for bc in type_bcs @@ -618,30 +602,25 @@ def not_all_neumann(cls, values): f"Current Neumann BCs are {names_neumann_Bcs}" ) - return values + return self - @pd.validator("grid_spec", always=True) - @skip_if_fields_missing(["structures"]) - def names_exist_grid_spec(cls, val, values): + @model_validator(mode="after") + def names_exist_grid_spec(self): """Warn if 'UniformUnstructuredGrid' points at a 
non-existing structure.""" - - structures = values.get("structures") - structures_names = {s.name for s in structures} - - for structure_name in val.non_refined_structures: + structures_names = {s.name for s in self.structures} + for structure_name in self.grid_spec.non_refined_structures: if structure_name not in structures_names: log.warning( f"Structure '{structure_name}' listed as a non-refined structure in " "'HeatChargeSimulation.grid_spec' is not present in 'HeatChargeSimulation.structures'" ) + return self - return val - - @pd.validator("grid_spec", always=True) - def warn_if_minimal_mesh_size_override(cls, val, values): + @model_validator(mode="after") + def warn_if_minimal_mesh_size_override(self): """Warn if minimal mesh size limit overrides desired mesh size.""" - - max_size = np.max(values.get("size")) + val = self.grid_spec + max_size = np.max(self.size) min_dl = val.relative_min_dl * max_size if isinstance(val, UniformUnstructuredGrid): @@ -655,16 +634,14 @@ def warn_if_minimal_mesh_size_override(cls, val, values): "Consider lowering parameter 'relative_min_dl' if a finer grid is required." ) - return val + return self - @pd.validator("sources", always=True) - @skip_if_fields_missing(["structures"]) - def names_exist_sources(cls, val, values): + @model_validator(mode="after") + def names_exist_sources(self): """Error if a heat-charge source point to non-existing structures.""" - structures = values.get("structures") - structures_names = {s.name for s in structures} + structures_names = {s.name for s in self.structures} - sources = [s for s in val if not isinstance(s, HeatFromElectricSource)] + sources = [s for s in self.sources if not isinstance(s, HeatFromElectricSource)] for source in sources: for name in source.structures: @@ -673,22 +650,17 @@ def names_exist_sources(cls, val, values): f"Structure '{name}' provided in a '{source.type}' " "is not found among simulation structures." 
) - return val + return self - @pd.root_validator(skip_on_failure=True) - def check_medium_specs(cls, values): + @model_validator(mode="after") + def check_medium_specs(self): """Error if no appropriate specs.""" - sim_box = ( - Box( - size=values.get("size"), - center=values.get("center"), - ), - ) + sim_box = (Box(size=self.size, center=self.center),) - failed_solid_idx, failed_elect_idx = cls._check_cross_solids(sim_box, values) + failed_solid_idx, failed_elect_idx = self._check_cross_solids(sim_box) - simulation_types = cls._check_simulation_types(values=values) + simulation_types = self._check_simulation_types() for sim_type in simulation_types: if sim_type == TCADAnalysisTypes.HEAT: @@ -702,7 +674,7 @@ def check_medium_specs(cls, values): "No conducting materials ('ChargeConductorMedium') are detected in conduction simulation. Solution domain is empty." ) - return values + return self @staticmethod def _check_if_semiconductor_present(structures) -> bool: @@ -720,9 +692,8 @@ def _check_if_semiconductor_present(structures) -> bool: charge_sim = True return charge_sim - @staticmethod def _check_simulation_types( - values: Dict, + self, HeatBCTypes=HeatBCTypes, ElectricBCTypes=ElectricBCTypes, HeatSourceTypes=HeatSourceTypes, @@ -732,17 +703,13 @@ def _check_simulation_types( """ simulation_types = [] - boundaries = list(values["boundary_spec"]) - sources = list(values["sources"]) - - structures = list(values["structures"]) semiconductor_present = HeatChargeSimulation._check_if_semiconductor_present( - structures=structures + structures=self.structures ) if semiconductor_present: simulation_types.append(TCADAnalysisTypes.CHARGE) - for boundary in boundaries: + for boundary in self.boundary_spec: if isinstance(boundary.condition, HeatBCTypes): simulation_types.append(TCADAnalysisTypes.HEAT) if isinstance(boundary.condition, ElectricBCTypes): @@ -753,26 +720,22 @@ def _check_simulation_types( else: simulation_types.append(TCADAnalysisTypes.CONDUCTION) - for source 
in sources: + for source in self.sources: if isinstance(source, HeatSourceTypes): simulation_types.append(TCADAnalysisTypes.HEAT) return set(simulation_types) - @pd.root_validator(skip_on_failure=True) - def check_coupling_source_can_be_applied(cls, values): + @model_validator(mode="after") + def check_coupling_source_can_be_applied(self): """Error if material doesn't have the right specifications""" HeatSourceTypes_noCoupling = (UniformHeatSource, HeatSource) - simulation_types = cls._check_simulation_types( - values, HeatSourceTypes=HeatSourceTypes_noCoupling - ) + simulation_types = self._check_simulation_types(HeatSourceTypes=HeatSourceTypes_noCoupling) simulation_types = list(simulation_types) - sources = list(values["sources"]) - - for source in sources: + for source in self.sources: if isinstance(source, HeatFromElectricSource) and len(simulation_types) < 2: raise SetupError( f"Using 'HeatFromElectricSource' requires the definition of both " @@ -780,24 +743,22 @@ def check_coupling_source_can_be_applied(cls, values): f"The current simulation setup contains only conditions of type {simulation_types[0].name}" ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def estimate_charge_mesh_size(cls, values): + @model_validator(mode="after") + def estimate_charge_mesh_size(self): """Make an estimate of the mesh size and raise a warning if too big. NOTE: this is a very rough estimate. 
The back-end will actually stop execution based on actual node-count.""" - if TCADAnalysisTypes.CHARGE not in cls._check_simulation_types(values=values): - return values + if TCADAnalysisTypes.CHARGE not in self._check_simulation_types(): + return self # let's raise a warning if the estimate is larger than 2M nodes max_nodes = 2e6 nodes_estimate = 0 - structures = values["structures"] - grid_spec = values["grid_spec"] - + grid_spec = self.grid_spec non_refined_structures = grid_spec.non_refined_structures if isinstance(grid_spec, UniformUnstructuredGrid): @@ -807,7 +768,7 @@ def estimate_charge_mesh_size(cls, values): dl_min = grid_spec.dl_interface dl_max = grid_spec.dl_bulk - for struct in structures: + for struct in self.structures: name = struct.name bounds = struct.geometry.bounds dl = dl_min @@ -832,7 +793,7 @@ def estimate_charge_mesh_size(cls, values): "the pipeline will be stopped. If this happens the grid specification " "may need to be modified." ) - return values + return self @equal_aspect @add_ax_if_none @@ -846,8 +807,8 @@ def plot_property( source_alpha: float = None, monitor_alpha: float = None, property: str = "heat_conductivity", - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of simulation's components on a plane defined by one nonzero x,y,z coordinate. @@ -871,9 +832,9 @@ def plot_property( property : str = "heat_conductivity" Specified the type of simulation for which the plot will be tailored. Options are ["heat_conductivity", "electric_conductivity", "source"] - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
Returns @@ -945,8 +906,8 @@ def plot_heat_conductivity( source_alpha: float = None, monitor_alpha: float = None, colorbar: str = "conductivity", - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, **kwargs, ) -> Ax: """ @@ -973,9 +934,9 @@ def plot_heat_conductivity( colorbar: str = "conductivity" Display colorbar for thermal conductivity ("conductivity") or heat source rate ("source"). - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -984,9 +945,9 @@ def plot_heat_conductivity( The supplied or created matplotlib axes. """ log.warning( - """This function `plot_heat_conductivity` is - deprecated and will be discontinued. In its place you can use - `plot_property(property="heat_conductivity")`""" + "The function 'plot_heat_conductivity' is " + "deprecated and will be discontinued. In its place you can use " + r"'plot_property(property=\"heat_conductivity\")'" ) plot_type = "heat_conductivity" @@ -1107,9 +1068,9 @@ def _plot_boundary_condition( @staticmethod def _structure_to_bc_spec_map( plane: Box, - structures: Tuple[Structure, ...], - boundary_spec: Tuple[HeatChargeBoundarySpec, ...], - ) -> Dict[str, HeatChargeBoundarySpec]: + structures: tuple[Structure, ...], + boundary_spec: tuple[HeatChargeBoundarySpec, ...], + ) -> dict[str, HeatChargeBoundarySpec]: """Construct structure name to bc spec inverse mapping. 
One structure may correspond to multiple boundary conditions.""" @@ -1143,9 +1104,9 @@ def _structure_to_bc_spec_map( @staticmethod def _medium_to_bc_spec_map( plane: Box, - structures: Tuple[Structure, ...], - boundary_spec: Tuple[HeatChargeBoundarySpec, ...], - ) -> Dict[str, HeatChargeBoundarySpec]: + structures: tuple[Structure, ...], + boundary_spec: tuple[HeatChargeBoundarySpec, ...], + ) -> dict[str, HeatChargeBoundarySpec]: """Construct medium name to bc spec inverse mapping. One medium may correspond to multiple boundary conditions.""" @@ -1168,11 +1129,11 @@ def _medium_to_bc_spec_map( @staticmethod def _construct_forward_boundaries( - shapes: Tuple[Tuple[str, str, Shapely, Tuple[float, float, float, float]], ...], - struct_to_bc_spec: Dict[str, HeatChargeBoundarySpec], - med_to_bc_spec: Dict[str, HeatChargeBoundarySpec], + shapes: tuple[tuple[str, str, Shapely, tuple[float, float, float, float]], ...], + struct_to_bc_spec: dict[str, HeatChargeBoundarySpec], + med_to_bc_spec: dict[str, HeatChargeBoundarySpec], background_structure_shape: Shapely, - ) -> Tuple[Tuple[HeatChargeBoundarySpec, Shapely], ...]: + ) -> tuple[tuple[HeatChargeBoundarySpec, Shapely], ...]: """Construct Simulation, StructureSimulation, Structure, and MediumMedium boundaries.""" # forward foop to take care of Simulation, StructureSimulation, Structure, @@ -1260,10 +1221,10 @@ def _construct_forward_boundaries( @staticmethod def _construct_reverse_boundaries( - shapes: Tuple[Tuple[str, str, Shapely, Bound], ...], - struct_to_bc_spec: Dict[str, HeatChargeBoundarySpec], + shapes: tuple[tuple[str, str, Shapely, Bound], ...], + struct_to_bc_spec: dict[str, HeatChargeBoundarySpec], background_structure_shape: Shapely, - ) -> Tuple[Tuple[HeatChargeBoundarySpec, Shapely], ...]: + ) -> tuple[tuple[HeatChargeBoundarySpec, Shapely], ...]: """Construct StructureStructure boundaries.""" # backward foop to take care of StructureStructure @@ -1328,24 +1289,24 @@ def _construct_reverse_boundaries( 
@staticmethod def _construct_heat_charge_boundaries( - structures: List[Structure], + structures: list[Structure], plane: Box, - boundary_spec: List[HeatChargeBoundarySpec], - ) -> List[Tuple[HeatChargeBoundarySpec, Shapely]]: + boundary_spec: list[HeatChargeBoundarySpec], + ) -> list[tuple[HeatChargeBoundarySpec, Shapely]]: """Compute list of boundary lines to plot on plane. Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] list of structures to filter on the plane. plane : :class:`.Box` target plane. - boundary_spec : List[HeatBoundarySpec] + boundary_spec : list[HeatBoundarySpec] list of boundary conditions associated with structures. Returns ------- - List[Tuple[:class:`.HeatBoundarySpec`, shapely.geometry.base.BaseGeometry]] + list[tuple[:class:`.HeatBoundarySpec`, shapely.geometry.base.BaseGeometry]] List of boundary lines and boundary conditions on the plane after merging. """ @@ -1399,8 +1360,8 @@ def plot_sources( y: float = None, z: float = None, property: str = "heat_conductivity", - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, alpha: float = None, ax: Ax = None, ) -> Ax: @@ -1417,9 +1378,9 @@ def plot_sources( property : str = None Specified the type of simulation for which the plot will be tailored. Options are ["heat_conductivity", "electric_conductivity"] - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. alpha : float = None Opacity of the sources, If ``None`` uses Tidy3d default. 
@@ -1506,7 +1467,7 @@ def _safe_float_conversion(self, string) -> float: except ValueError: return None - def source_bounds(self, property: str = "heat_conductivity") -> Tuple[float, float]: + def source_bounds(self, property: str = "heat_conductivity") -> tuple[float, float]: """Compute range of heat sources present in the simulation.""" if property == "heat_conductivity" or property == "source": @@ -1664,5 +1625,4 @@ def _get_simulation_types(self) -> list[TCADAnalysisTypes]: def _useHeatSourceFromConductionSim(self): """Returns True if 'HeatFromElectricSource' has been defined.""" - return any(isinstance(source, HeatFromElectricSource) for source in self.sources) diff --git a/tidy3d/components/tcad/source/abstract.py b/tidy3d/components/tcad/source/abstract.py index 9e1b7ffa71..0469e89d69 100644 --- a/tidy3d/components/tcad/source/abstract.py +++ b/tidy3d/components/tcad/source/abstract.py @@ -1,11 +1,8 @@ """Defines heat-charge material specifications for 'HeatChargeSimulation'""" -from __future__ import annotations - from abc import ABC -from typing import Tuple -import pydantic.v1 as pd +from pydantic import Field, field_validator from tidy3d.components.base import cached_property from tidy3d.components.base_sim.source import AbstractSource @@ -28,13 +25,13 @@ class StructureBasedHeatChargeSource(AbstractHeatChargeSource): """Abstract class associated with structures. Sources associated to structures must derive from this class""" - structures: Tuple[str, ...] = pd.Field( + structures: tuple[str, ...] 
= Field( title="Target Structures", description="Names of structures where to apply heat source.", ) - @pd.validator("structures", always=True) - def check_non_empty_structures(cls, val): + @field_validator("structures") + def check_non_empty_structures(val): """Error if source doesn't point at any structures.""" if len(val) == 0: raise SetupError("List of structures for heat source is empty.") diff --git a/tidy3d/components/tcad/source/coupled.py b/tidy3d/components/tcad/source/coupled.py index b63b943048..f67906cb0b 100644 --- a/tidy3d/components/tcad/source/coupled.py +++ b/tidy3d/components/tcad/source/coupled.py @@ -1,7 +1,5 @@ """Defines heat-charge material specifications for 'HeatChargeSimulation'""" -from __future__ import annotations - from tidy3d.components.tcad.source.abstract import GlobalHeatChargeSource diff --git a/tidy3d/components/tcad/source/heat.py b/tidy3d/components/tcad/source/heat.py index ba269d31ad..1af52fb13f 100644 --- a/tidy3d/components/tcad/source/heat.py +++ b/tidy3d/components/tcad/source/heat.py @@ -1,10 +1,8 @@ """Defines heat-charge material specifications for 'HeatChargeSimulation'""" -from __future__ import annotations - from typing import Union -import pydantic.v1 as pd +import pydantic as pd from tidy3d.components.tcad.source.abstract import StructureBasedHeatChargeSource from tidy3d.constants import VOLUMETRIC_HEAT_RATE diff --git a/tidy3d/components/time.py b/tidy3d/components/time.py index c26ebf88ba..65e0170d3b 100644 --- a/tidy3d/components/time.py +++ b/tidy3d/components/time.py @@ -1,11 +1,9 @@ """Defines time dependence""" -from __future__ import annotations - from abc import ABC, abstractmethod import numpy as np -import pydantic.v1 as pydantic +from pydantic import Field, NonNegativeFloat from ..constants import RADIAN from ..exceptions import SetupError @@ -20,11 +18,11 @@ class AbstractTimeDependence(ABC, Tidy3dBaseModel): """Base class describing time dependence.""" - amplitude: pydantic.NonNegativeFloat = 
pydantic.Field( + amplitude: NonNegativeFloat = Field( 1.0, title="Amplitude", description="Real-valued maximum amplitude of the time dependence." ) - phase: float = pydantic.Field( + phase: float = Field( 0.0, title="Phase", description="Phase shift of the time dependence.", units=RADIAN ) diff --git a/tidy3d/components/time_modulation.py b/tidy3d/components/time_modulation.py index 93b55cee6b..aded1f2b2d 100644 --- a/tidy3d/components/time_modulation.py +++ b/tidy3d/components/time_modulation.py @@ -1,17 +1,16 @@ """Defines time modulation to the medium""" -from __future__ import annotations - from abc import ABC, abstractmethod from math import isclose -from typing import Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, field_validator, model_validator +from ..compat import Self from ..constants import HERTZ, RADIAN from ..exceptions import ValidationError -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .data.data_array import SpatialDataArray from .data.validators import validate_no_nans from .time import AbstractTimeDependence @@ -62,8 +61,8 @@ class ContinuousWaveTimeModulation(AbstractTimeDependence): >>> cw = ContinuousWaveTimeModulation(freq0=200e12, amplitude=1, phase=0) """ - freq0: pd.PositiveFloat = pd.Field( - ..., title="Modulation Frequency", description="Modulation frequency.", units=HERTZ + freq0: PositiveFloat = Field( + title="Modulation Frequency", description="Modulation frequency.", units=HERTZ ) def amp_time(self, time: float) -> complex: @@ -127,41 +126,33 @@ class SpaceModulation(AbstractSpaceModulation): >>> space = SpaceModulation(amplitude=amp, phase=phase) """ - amplitude: Union[float, SpatialDataArray] = pd.Field( + amplitude: Union[float, SpatialDataArray] = Field( 1, title="Amplitude of modulation in space", description="Amplitude of modulation that can vary spatially. 
" "It takes the unit of whatever is being modulated.", ) - phase: Union[float, SpatialDataArray] = pd.Field( + phase: Union[float, SpatialDataArray] = Field( 0, title="Phase of modulation in space", description="Phase of modulation that can vary spatially.", units=RADIAN, ) - interp_method: InterpMethod = pd.Field( + interp_method: InterpMethod = Field( "nearest", title="Interpolation method", description="Method of interpolation to use to obtain values at spatial locations on the Yee grids.", ) - _no_nans_amplitude = validate_no_nans("amplitude") - _no_nans_phase = validate_no_nans("phase") + _no_nans = validate_no_nans("amplitude", "phase") - @pd.validator("amplitude", always=True) - def _real_amplitude(cls, val): + @field_validator("amplitude", "phase") + def _validate_fields_real(val, info): """Assert that the amplitude is real.""" if np.iscomplexobj(val): - raise ValidationError("'amplitude' must be real.") - return val - - @pd.validator("phase", always=True) - def _real_phase(cls, val): - """Assert that the phase is real.""" - if np.iscomplexobj(val): - raise ValidationError("'phase' must be real.") + raise ValidationError(f"'{info.field_name}' must be real.") return val @cached_property @@ -169,7 +160,7 @@ def max_modulation(self) -> float: """Estimated maximum modulation amplitude.""" return np.max(abs(np.array(self.amplitude))) - def sel_inside(self, bounds: Bound) -> SpaceModulation: + def sel_inside(self, bounds: Bound) -> Self: """Return a new space modulation that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. 
@@ -216,18 +207,15 @@ class SpaceTimeModulation(Tidy3dBaseModel): \\delta \\epsilon(r, t) = \\Re[amp\\_time(t) \\cdot amp\\_space(r)] """ - space_modulation: SpaceModulationType = pd.Field( - SpaceModulation(), + space_modulation: SpaceModulationType = Field( + default_factory=SpaceModulation, title="Space modulation", description="Space modulation part from the separable SpaceTimeModulation.", - # discriminator=TYPE_TAG_STR, ) - time_modulation: TimeModulationType = pd.Field( - ..., + time_modulation: TimeModulationType = Field( title="Time modulation", description="Time modulation part from the separable SpaceTimeModulation.", - # discriminator=TYPE_TAG_STR, ) @cached_property @@ -243,7 +231,7 @@ def negligible_modulation(self) -> bool: return True return False - def sel_inside(self, bounds: Bound) -> SpaceTimeModulation: + def sel_inside(self, bounds: Bound) -> Self: """Return a new space-time modulation that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. @@ -267,38 +255,36 @@ class ModulationSpec(Tidy3dBaseModel): including relative permittivity at infinite frequency and electric conductivity. 
""" - permittivity: SpaceTimeModulation = pd.Field( + permittivity: Optional[SpaceTimeModulation] = Field( None, title="Space-time modulation of relative permittivity", description="Space-time modulation of relative permittivity at infinite frequency " "applied on top of the base permittivity at infinite frequency.", ) - conductivity: SpaceTimeModulation = pd.Field( + conductivity: Optional[SpaceTimeModulation] = Field( None, title="Space-time modulation of conductivity", description="Space-time modulation of electric conductivity " "applied on top of the base conductivity.", ) - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity"]) - def _same_modulation_frequency(cls, val, values): + @model_validator(mode="after") + def _check_same_modulation_frequency(self) -> Self: """Assert same time-modulation applied to permittivity and conductivity.""" - permittivity = values.get("permittivity") - if val is not None and permittivity is not None: - if val.time_modulation != permittivity.time_modulation: + if self.conductivity is not None and self.permittivity is not None: + if self.conductivity.time_modulation != self.permittivity.time_modulation: raise ValidationError( "'permittivity' and 'conductivity' should have the same time modulation." ) - return val + return self @cached_property def applied_modulation(self) -> bool: """Check if any modulation has been applied to ``permittivity`` or ``conductivity``.""" return self.permittivity is not None or self.conductivity is not None - def sel_inside(self, bounds: Bound) -> ModulationSpec: + def sel_inside(self, bounds: Bound) -> Self: """Return a new modulation specficiation that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. 
diff --git a/tidy3d/components/transformation.py b/tidy3d/components/transformation.py index ea7c1ee494..2fb89e3914 100644 --- a/tidy3d/components/transformation.py +++ b/tidy3d/components/transformation.py @@ -1,12 +1,10 @@ """Defines geometric transformation classes""" -from __future__ import annotations - from abc import ABC, abstractmethod from typing import Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, field_validator from ..constants import RADIAN from ..exceptions import ValidationError @@ -73,30 +71,30 @@ def rotate_tensor(self, tensor: TensorReal) -> TensorReal: class RotationAroundAxis(AbstractRotation): """Rotation of vectors and tensors around a given vector.""" - axis: Union[Axis, Coordinate] = pd.Field( + axis: Union[Axis, Coordinate] = Field( 0, title="Axis of Rotation", description="A vector that specifies the axis of rotation, or a single int: 0, 1, or 2, " "indicating x, y, or z.", ) - angle: TracedFloat = pd.Field( + angle: TracedFloat = Field( 0.0, title="Angle of Rotation", description="Angle of rotation in radians.", units=RADIAN, ) - @pd.validator("axis", always=True) - def _convert_axis_index_to_vector(cls, val): + @field_validator("axis") + def _validate_axis_vector(val): if not isinstance(val, tuple): axis = [0.0, 0.0, 0.0] axis[val] = 1.0 val = tuple(axis) return val - @pd.validator("axis") - def _guarantee_nonzero_axis(cls, val): + @field_validator("axis") + def _validate_axis_nonzero_norm(val): norm = np.linalg.norm(val) if np.isclose(norm, 0): raise ValidationError( @@ -174,14 +172,14 @@ def reflect_tensor(self, tensor: TensorReal) -> TensorReal: class ReflectionFromPlane(AbstractReflection): """Reflection of vectors and tensors around a given vector.""" - normal: Coordinate = pd.Field( + normal: Coordinate = Field( (1, 0, 0), title="Normal of the reflecting plane", description="A vector that specifies the normal of the plane of reflection", ) - @pd.validator("normal") - def 
_guarantee_nonzero_normal(cls, val): + @field_validator("normal") + def _validate_normal_nonzero_norm(val): norm = np.linalg.norm(val) if np.isclose(norm, 0): raise ValidationError( diff --git a/tidy3d/components/type_util.py b/tidy3d/components/type_util.py index 5bbb64f975..1637dce78a 100644 --- a/tidy3d/components/type_util.py +++ b/tidy3d/components/type_util.py @@ -1,12 +1,18 @@ """Utilities for type & schema creation.""" +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + def _add_schema(arbitrary_type: type, title: str, field_type_str: str) -> None: """Adds a schema to the ``arbitrary_type`` class without subclassing.""" @classmethod - def mod_schema_fn(cls, field_schema: dict) -> None: - """Function that gets set to ``arbitrary_type.__modify_schema__``.""" - field_schema.update(dict(title=title, type=field_type_str)) + def __get_pydantic_core_schema__( + cls, _source_type: type, _handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + # `any_schema()` is enough because we only want to accept the value + # the metadata will later show up in JSON-schema generation. + return core_schema.any_schema(metadata={"title": title, "type": field_type_str}) - arbitrary_type.__modify_schema__ = mod_schema_fn + arbitrary_type.__get_pydantic_core_schema__ = __get_pydantic_core_schema__ diff --git a/tidy3d/components/types.py b/tidy3d/components/types.py index 61697f87d4..8bfbacc558 100644 --- a/tidy3d/components/types.py +++ b/tidy3d/components/types.py @@ -1,183 +1,171 @@ """Defines 'types' that various fields can be""" -from typing import ( - Literal, # We support py3.9+, so direct typing import is fine. 
- Optional, - Tuple, - Union, +import numbers +from typing import Annotated, Any, Literal, Optional, Union + +import numpy as np +from pydantic import ( + BeforeValidator, + Field, + NonNegativeFloat, + PositiveFloat, ) - -import autograd.numpy as np -import pydantic.v1 as pydantic +from pydantic.functional_serializers import PlainSerializer try: from matplotlib.axes import Axes except ImportError: Axes = None -from shapely.geometry.base import BaseGeometry -from typing_extensions import Annotated -from ..exceptions import ValidationError +from shapely.geometry.base import BaseGeometry # type tag default name TYPE_TAG_STR = "type" -def annotate_type(UnionType): - """Annotated union type using TYPE_TAG_STR as discriminator.""" - return Annotated[UnionType, pydantic.Field(discriminator=TYPE_TAG_STR)] +def discriminated_union(union, discriminator=TYPE_TAG_STR): + return Annotated[union, Field(discriminator=discriminator)] """ Numpy Arrays """ -def _totuple(arr: np.ndarray) -> tuple: - """Convert a numpy array to a nested tuple.""" - if arr.ndim > 1: - return tuple(_totuple(val) for val in arr) - return tuple(arr) - - -# generic numpy array -Numpy = np.ndarray - - -class ArrayLike: - """Type that stores a numpy array.""" - - ndim = None - dtype = None - shape = None - - @classmethod - def __get_validators__(cls): - yield cls.load_complex - yield cls.convert_to_numpy - yield cls.check_dims - yield cls.check_shape - yield cls.assert_non_null - - @classmethod - def load_complex(cls, val): - """Special handling to load a complex-valued np.ndarray saved to file.""" - if not isinstance(val, dict): - return val - if "real" not in val or "imag" not in val: - raise ValueError("ArrayLike real and imaginary parts not stored properly.") - arr_real = np.array(val["real"]) - arr_imag = np.array(val["imag"]) - return arr_real + 1j * arr_imag - - @classmethod - def convert_to_numpy(cls, val): - """Convert the value to np.ndarray and provide some casting.""" - arr_numpy = np.array(val, 
ndmin=1, dtype=cls.dtype, copy=True) - return arr_numpy - - @classmethod - def check_dims(cls, val): - """Make sure the number of dimensions is correct.""" - if cls.ndim and val.ndim != cls.ndim: - raise ValidationError(f"Expected {cls.ndim} dimensions for ArrayLike, got {val.ndim}.") - return val - - @classmethod - def check_shape(cls, val): - """Make sure the shape is correct.""" - if cls.shape and val.shape != cls.shape: - raise ValidationError(f"Expected shape {cls.shape} for ArrayLike, got {val.shape}.") - return val - - @classmethod - def assert_non_null(cls, val): - """Make sure array is not None.""" - if np.any(np.isnan(val)): - raise ValidationError("'ArrayLike' field contained None or nan values.") - return val - - @classmethod - def __modify_schema__(cls, field_schema): - """Sets the schema of DataArray object.""" - - schema = dict( - type="ArrayLike", - ) - field_schema.update(schema) - - -def constrained_array( - dtype: type = None, ndim: int = None, shape: Tuple[pydantic.NonNegativeInt, ...] 
= None -) -> type: - """Generate an ArrayLike sub-type with constraints built in.""" - - # note, a unique name is required for each subclass of ArrayLike with constraints - type_name = "ArrayLike" - - meta_args = [] - if dtype is not None: - meta_args.append(f"dtype={dtype.__name__}") - if ndim is not None: - meta_args.append(f"ndim={ndim}") - if shape is not None: - meta_args.append(f"shape={shape}") - type_name += "[" + ", ".join(meta_args) + "]" - - return type(type_name, (ArrayLike,), dict(dtype=dtype, ndim=ndim, shape=shape)) - - -# pre-define a set of commonly used array like instances for import and use in type hints -ArrayInt1D = constrained_array(dtype=int, ndim=1) -ArrayFloat1D = constrained_array(dtype=float, ndim=1) -ArrayFloat2D = constrained_array(dtype=float, ndim=2) -ArrayFloat3D = constrained_array(dtype=float, ndim=3) -ArrayFloat4D = constrained_array(dtype=float, ndim=4) -ArrayComplex1D = constrained_array(dtype=complex, ndim=1) -ArrayComplex2D = constrained_array(dtype=complex, ndim=2) -ArrayComplex3D = constrained_array(dtype=complex, ndim=3) -ArrayComplex4D = constrained_array(dtype=complex, ndim=4) - -TensorReal = constrained_array(dtype=float, ndim=2, shape=(3, 3)) -MatrixReal4x4 = constrained_array(dtype=float, ndim=2, shape=(4, 4)) +def _from_complex_dict(v): + if isinstance(v, dict) and "real" in v and "imag" in v: + return np.asarray(v["real"]) + 1j * np.asarray(v["imag"]) + return v + + +def _coerce(v, *, dtype, ndim, shape, forbid_nan, scalar_to_1d): + """Convert input to a NumPy array with constraints. + + Raises + ------ + ValueError + - If conversion to an array fails. + - If the array ends up with dtype=object (unsupported element type). + - If the number of dimensions or shape does not match the expectations. + - If ``forbid_nan`` is ``True`` and the array contains NaN values. 
+ """ + try: + arr = np.asarray(v) if dtype is None else np.asarray(v, dtype=dtype) + except Exception as e: + raise ValueError(f"cannot convert {type(v).__name__!r} to a NumPy array") from e + if arr.dtype == np.dtype("object"): + raise ValueError(f"unsupported element type {type(v).__name__!r} for array coercion") + + if arr.ndim == 0 and scalar_to_1d and ndim == 1: + arr = arr.reshape(1) + if ndim is not None and arr.ndim != ndim: + raise ValueError(f"expected {ndim}-D, got {arr.ndim}-D") + if shape is not None and tuple(arr.shape) != shape: + raise ValueError(f"expected shape {shape}, got {tuple(arr.shape)}") + if forbid_nan and np.any(np.isnan(arr)): + raise ValueError("array contains NaN") + return arr + + +def _auto_serializer(a, _): + """Serializes numpy arrays and scalars for JSON.""" + if isinstance(a, complex) or ( + hasattr(np, "complexfloating") and isinstance(a, np.complexfloating) + ): + return {"real": float(a.real), "imag": float(a.imag)} + if isinstance(a, np.ndarray): + if np.iscomplexobj(a): + return {"real": a.real.tolist(), "imag": a.imag.tolist()} + else: + return a.tolist() + if isinstance(a, float) or (hasattr(np, "floating") and isinstance(a, np.floating)): + return float(a) # Ensure basic Python float + if isinstance(a, int) or (hasattr(np, "integer") and isinstance(a, np.integer)): + return int(a) # Ensure basic Python int + if hasattr(np, "number") and isinstance(a, np.number): + return a.item() + return a + + +def array_alias( + *, + dtype: Optional[type] = None, + ndim: Optional[int] = None, + shape: Optional[tuple[int, ...]] = None, + forbid_nan: bool = True, + scalar_to_1d: bool = False, +): + """Return an `Annotated[np.ndarray, ...]` with checks.""" + validators = [ + BeforeValidator(_from_complex_dict), + BeforeValidator( + lambda v: _coerce( + v, + dtype=np.dtype(dtype) if dtype is not None else None, + ndim=ndim, + shape=shape, + forbid_nan=forbid_nan, + scalar_to_1d=scalar_to_1d, + ) + ), + ] + + serializer = 
PlainSerializer(_auto_serializer, when_used="json") + + return Annotated[np.ndarray, validators[0], validators[1], serializer] + + +ArrayLike = array_alias() + +ArrayInt1D = array_alias(dtype=int, ndim=1, scalar_to_1d=True) + +ArrayFloat = array_alias(dtype=float) +ArrayFloat1D = array_alias(dtype=float, ndim=1, scalar_to_1d=True) +ArrayFloat2D = array_alias(dtype=float, ndim=2) +ArrayFloat3D = array_alias(dtype=float, ndim=3) +ArrayFloat4D = array_alias(dtype=float, ndim=4) + +ArrayComplex = array_alias(dtype=complex) +ArrayComplex1D = array_alias(dtype=complex, ndim=1, scalar_to_1d=True) +ArrayComplex2D = array_alias(dtype=complex, ndim=2) +ArrayComplex3D = array_alias(dtype=complex, ndim=3) +ArrayComplex4D = array_alias(dtype=complex, ndim=4) + +TensorReal = array_alias(dtype=float, ndim=2, shape=(3, 3)) +MatrixReal4x4 = array_alias(dtype=float, ndim=2, shape=(4, 4)) """ Complex Values """ -class ComplexNumber(pydantic.BaseModel): - """Complex number with a well defined schema.""" +def _parse_complex(v: Any) -> complex: + if isinstance(v, complex): + return v - real: float - imag: float + if isinstance(v, dict) and "real" in v and "imag" in v: + return complex(v["real"], v["imag"]) - @property - def as_complex(self): - """return complex representation of ComplexNumber.""" - return self.real + 1j * self.imag + if isinstance(v, numbers.Number): + return complex(v) + if hasattr(v, "__complex__"): + try: + return complex(v.__complex__()) + except Exception: + pass -class tidycomplex(complex): - """complex type that we can use in our models.""" + if isinstance(v, (list, tuple)) and len(v) == 2: + return complex(v[0], v[1]) - @classmethod - def __get_validators__(cls): - """Defines which validator function to use for ComplexNumber.""" - yield cls.validate + return v - @classmethod - def validate(cls, value): - """What gets called when you construct a tidycomplex.""" - - if isinstance(value, ComplexNumber): - return value.as_complex - if isinstance(value, dict): - c = 
ComplexNumber(**value) - return c.as_complex - return cls(value) - - @classmethod - def __modify_schema__(cls, field_schema): - """Sets the schema of ComplexNumber.""" - field_schema.update(ComplexNumber.schema()) +Complex = Annotated[ + complex, + BeforeValidator(_parse_complex), + PlainSerializer( + lambda z, _: {"real": z.real, "imag": z.imag}, + when_used="json", + return_type=dict, + ), +] """ symmetry """ @@ -186,13 +174,13 @@ def __modify_schema__(cls, field_schema): """ geometric """ -Size1D = pydantic.NonNegativeFloat -Size = Tuple[Size1D, Size1D, Size1D] -Coordinate = Tuple[float, float, float] -CoordinateOptional = Tuple[Optional[float], Optional[float], Optional[float]] -Coordinate2D = Tuple[float, float] -Bound = Tuple[Coordinate, Coordinate] -GridSize = Union[pydantic.PositiveFloat, Tuple[pydantic.PositiveFloat, ...]] +Size1D = NonNegativeFloat +Size = tuple[Size1D, Size1D, Size1D] +Coordinate = tuple[float, float, float] +CoordinateOptional = tuple[Optional[float], Optional[float], Optional[float]] +Coordinate2D = tuple[float, float] +Bound = tuple[Coordinate, Coordinate] +GridSize = Union[PositiveFloat, tuple[PositiveFloat, ...]] Axis = Literal[0, 1, 2] Axis2D = Literal[0, 1] Shapely = BaseGeometry @@ -206,14 +194,11 @@ def __modify_schema__(cls, field_schema): # custom medium InterpMethod = Literal["nearest", "linear"] -# Complex = Union[complex, ComplexNumber] -Complex = Union[tidycomplex, ComplexNumber] -PoleAndResidue = Tuple[Complex, Complex] - -# PoleAndResidue = Tuple[Tuple[float, float], Tuple[float, float]] +PoleAndResidue = tuple[Complex, Complex] +PolesAndResidues = tuple[PoleAndResidue, ...] 
FreqBoundMax = float FreqBoundMin = float -FreqBound = Tuple[FreqBoundMin, FreqBoundMax] +FreqBound = tuple[FreqBoundMin, FreqBoundMax] PermittivityComponent = Literal["xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz"] @@ -226,8 +211,8 @@ def __modify_schema__(cls, field_schema): EMField = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] FieldType = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] -FreqArray = Union[Tuple[float, ...], ArrayFloat1D] -ObsGridArray = Union[Tuple[float, ...], ArrayFloat1D] +FreqArray = ArrayFloat1D +ObsGridArray = ArrayFloat1D PolarizationBasis = Literal["linear", "circular"] AuxField = Literal["Nfx", "Nfy", "Nfz"] diff --git a/tidy3d/components/validators.py b/tidy3d/components/validators.py index 55bca5575e..6572b5d54f 100644 --- a/tidy3d/components/validators.py +++ b/tidy3d/components/validators.py @@ -1,17 +1,18 @@ """Defines various validation functions that get used to ensure inputs are legit""" +from typing import Optional + import numpy as np -import pydantic.v1 as pydantic from autograd.tracer import isbox +from pydantic import field_validator, model_validator +from ..compat import Self from ..exceptions import SetupError, ValidationError from ..log import log from .autograd.utils import get_static -from .base import DATA_ARRAY_MAP, skip_if_fields_missing -from .data.dataset import Dataset, FieldDataset +from .base import DATA_ARRAY_MAP from .geometry.base import Box from .mode_spec import ModeSpec -from .types import Tuple """ Explanation of pydantic validators: @@ -52,7 +53,8 @@ def assert_line(): """makes sure a field's ``size`` attribute has exactly 2 zeros""" - @pydantic.validator("size", allow_reuse=True, always=True) + @field_validator("size") + @classmethod def is_line(cls, val): """Raise validation error if not 1 dimensional.""" if val.count(0.0) != 2: @@ -65,7 +67,8 @@ def is_line(cls, val): def assert_plane(): """makes sure a field's ``size`` attribute has exactly 1 zero""" - @pydantic.validator("size", allow_reuse=True, 
always=True) + @field_validator("size") + @classmethod def is_plane(cls, val): """Raise validation error if not planar.""" if val.count(0.0) != 1: @@ -78,7 +81,8 @@ def is_plane(cls, val): def assert_line_or_plane(): """makes sure a field's ``size`` attribute has either 1 or 2 zeros""" - @pydantic.validator("size", allow_reuse=True, always=True) + @field_validator("size") + @classmethod def is_line_or_plane(cls, val): """Raise validation error if not a line or plane.""" if val.count(0.0) == 0 or val.count(0.0) == 3: @@ -93,7 +97,7 @@ def is_line_or_plane(cls, val): def assert_volumetric(): """makes sure a field's ``size`` attribute has no zero entry""" - @pydantic.validator("size", allow_reuse=True, always=True) + @field_validator("size") def is_volumetric(cls, val): """Raise validation error if volume is 0.""" if val.count(0.0) > 0: @@ -107,11 +111,12 @@ def is_volumetric(cls, val): return is_volumetric +# FIXME: this validator doesn't do anything def validate_name_str(): """make sure the name does not include [, ] (used for default names)""" - @pydantic.validator("name", allow_reuse=True, always=True, pre=True) - def field_has_unique_names(cls, val): + @field_validator("name") + def field_has_unique_names(val): """raise exception if '[' or ']' in name""" # if val and ('[' in val or ']' in val): # raise SetupError(f"'[' or ']' not allowed in name: {val} (used for defaults)") @@ -120,14 +125,14 @@ def field_has_unique_names(cls, val): return field_has_unique_names -def validate_unique(field_name: str): +def validate_unique(*field_names: str): """Make sure the given field has unique entries.""" - @pydantic.validator(field_name, always=True, allow_reuse=True) - def field_has_unique_entries(cls, val): + @field_validator(*field_names) + def field_has_unique_entries(val, info): """Check if the field has unique entries.""" if len(set(val)) != len(val): - raise SetupError(f"Entries of '{field_name}' must be unique.") + raise SetupError(f"Entries of '{info.field_name}' 
must be unique.") return val return field_has_unique_entries @@ -139,15 +144,15 @@ def validate_mode_objects_symmetry(field_name: str): obj_type = "ModeSource" if field_name == "sources" else "ModeMonitor" - @pydantic.validator(field_name, allow_reuse=True, always=True) - @skip_if_fields_missing(["center", "symmetry"]) - def check_symmetry(cls, val, values): + @model_validator(mode="after") + def check_symmetry(self): """check for intersection of each structure with simulation bounds.""" - sim_center = values.get("center") + val = getattr(self, field_name) + sim_center = self.center for position_index, geometric_object in enumerate(val): if geometric_object.type == obj_type: bounds_min, _ = geometric_object.bounds - for dim, sym in enumerate(values.get("symmetry")): + for dim, sym in enumerate(self.symmetry): if ( sym != 0 and bounds_min[dim] < sim_center[dim] @@ -159,21 +164,21 @@ def check_symmetry(cls, val, values): "or centered on the symmetry axis." ) - return val + return self return check_symmetry -def assert_unique_names(field_name: str): +def assert_unique_names(*field_names: str): """makes sure all elements of a field have unique .name values""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - def field_has_unique_names(cls, val, values): + @field_validator(*field_names) + def field_has_unique_names(val, info): """make sure each element of val has a unique name (if specified).""" field_names = [field.name for field in val if field.name] unique_names = set(field_names) if len(unique_names) != len(field_names): - raise SetupError(f"'{field_name}' names are not unique, given {field_names}.") + raise SetupError(f"'{info.field_name}' names are not unique, given {field_names}.") return val return field_has_unique_names @@ -184,12 +189,12 @@ def assert_objects_in_sim_bounds( ): """Makes sure all objects in field are at least partially inside of simulation bounds.""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - 
@skip_if_fields_missing(["center", "size"]) - def objects_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def objects_in_sim_bounds(self): """check for intersection of each structure with simulation bounds.""" - sim_center = values.get("center") - sim_size = values.get("size") + val = getattr(self, field_name) + sim_center = self.center + sim_size = self.size sim_box = Box(size=sim_size, center=sim_center) # Do a strict check, unless simulation is 0D along a dimension @@ -208,7 +213,7 @@ def objects_in_sim_bounds(cls, val, values): raise SetupError(message) consolidated_logger.warning(message, custom_loc=custom_loc) - return val + return self return objects_in_sim_bounds @@ -218,12 +223,12 @@ def assert_objects_contained_in_sim_bounds( ): """Makes sure all objects in field are completely inside the simulation bounds.""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - @skip_if_fields_missing(["center", "size"]) - def objects_contained_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def objects_contained_in_sim_bounds(self): """check for containment of each structure with simulation bounds.""" - sim_center = values.get("center") - sim_size = values.get("size") + val = getattr(self, field_name) + sim_center = self.center + sim_size = self.size sim_box = Box(size=sim_size, center=sim_center) # Do a strict check, unless simulation is 0D along a dimension @@ -242,7 +247,7 @@ def objects_contained_in_sim_bounds(cls, val, values): raise SetupError(message) consolidated_logger.warning(message, custom_loc=custom_loc) - return val + return self return objects_contained_in_sim_bounds @@ -250,13 +255,13 @@ def objects_contained_in_sim_bounds(cls, val, values): def enforce_monitor_fields_present(): """Make sure all of the fields in the monitor are present in the corresponding data.""" - @pydantic.root_validator(skip_on_failure=True, allow_reuse=True) - def _contains_fields(cls, values): + @model_validator(mode="after") + 
def _contains_fields(self): """Make sure the initially specified fields are here.""" - for field_name in values.get("monitor").fields: - if values.get(field_name) is None: + for field_name in self.monitor.fields: + if getattr(self, field_name) is None: raise SetupError(f"missing field {field_name}") - return values + return self return _contains_fields @@ -264,14 +269,14 @@ def _contains_fields(cls, values): def required_if_symmetry_present(field_name: str): """Make a field required (not None) if any non-zero symmetry eigenvalue is present.""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - @skip_if_fields_missing(["symmetry"]) - def _make_required(cls, val, values): + @model_validator(mode="after") + def _make_required(self): """Ensure val is not None if the symmetry is non-zero along any dimension.""" - symmetry = values.get("symmetry") + val = getattr(self, field_name) + symmetry = self.symmetry if any(sym_val != 0 for sym_val in symmetry) and val is None: raise SetupError(f"'{field_name}' must be provided if symmetry present.") - return val + return self return _make_required @@ -279,8 +284,8 @@ def _make_required(cls, val, values): def warn_if_dataset_none(field_name: str): """Warn if a Dataset field has None in its dictionary.""" - @pydantic.validator(field_name, pre=True, always=True, allow_reuse=True) - def _warn_if_none(cls, val: Dataset) -> Dataset: + @field_validator(field_name, mode="before") + def _warn_if_none(val: dict) -> Optional[dict]: """Warn if the DataArrays fail to load.""" if isinstance(val, dict): if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): @@ -294,13 +299,13 @@ def _warn_if_none(cls, val: Dataset) -> Dataset: def assert_single_freq_in_range(field_name: str): """Assert only one frequency supplied in source and it's in source time range.""" - @pydantic.validator(field_name, always=True, allow_reuse=True) - @skip_if_fields_missing(["source_time"]) - def _single_frequency_in_range(cls, val: 
FieldDataset, values: dict) -> FieldDataset: + @model_validator(mode="after") + def _single_frequency_in_range(self) -> Self: """Assert only one frequency supplied and it's in source time range.""" + val = getattr(self, field_name, None) if val is None: - return val - source_time = values.get("source_time") + return self + source_time = self.source_time fmin, fmax = source_time.frequency_range() for name, scalar_field in val.field_components.items(): freqs = scalar_field.f @@ -315,7 +320,7 @@ def _single_frequency_in_range(cls, val: FieldDataset, values: dict) -> FieldDat f"'{field_name}.{name}' contains frequency: {freq:.2e} Hz, which is outside " f"of the 'source_time' frequency range [{fmin:.2e}-{fmax:.2e}] Hz." ) - return val + return self return _single_frequency_in_range @@ -323,9 +328,9 @@ def _single_frequency_in_range(cls, val: FieldDataset, values: dict) -> FieldDat def _warn_potential_error( field_name: str, base_value: float, - val_change_range: Tuple[float, float], - allowed_real_range: Tuple[float, float], - allowed_imag_range: Tuple[float, float], + val_change_range: tuple[float, float], + allowed_real_range: tuple[float, float], + allowed_imag_range: tuple[float, float], ): """Basic validation that perturbations do not drive a parameter out of physical bounds.""" @@ -358,19 +363,19 @@ def _warn_potential_error( def validate_parameter_perturbation( field_name: str, base_field_name: str, - allowed_real_range: Tuple[Tuple[float, float], ...], - allowed_imag_range: Tuple[Tuple[float, float], ...] = None, + allowed_real_range: tuple[tuple[float, float], ...], + allowed_imag_range: tuple[tuple[float, float], ...] 
= None, allowed_complex: bool = True, ): """Assert perturbations do not drive a parameter out of physical bounds.""" - @pydantic.validator(field_name, always=True, allow_reuse=True) - def _warn_perturbed_val_range(cls, val, values): + @field_validator(field_name) + def _warn_perturbed_val_range(val, info): """Assert perturbations do not drive a parameter out of physical bounds.""" if val is not None: # get base values - base_values = values[base_field_name] + base_values = info.data[base_field_name] # check that shapes of base parameter and perturbations coincide if np.shape(base_values) != np.shape(val): @@ -428,7 +433,8 @@ def _assert_min_freq(freqs, msg_start: str): def validate_freqs_min(): """Validate lower bound for monitor, and mode solver frequencies.""" - @pydantic.validator("freqs", always=True, allow_reuse=True) + @field_validator("freqs") + @classmethod def freqs_lower_bound(cls, val): """Raise validation error if any of ``freqs`` is lower than ``MIN_FREQUENCY``.""" _assert_min_freq(val, msg_start=f"All of '{cls.__name__}.freqs'") @@ -440,7 +446,8 @@ def freqs_lower_bound(cls, val): def validate_freqs_not_empty(): """Validate that the array of frequencies is not empty.""" - @pydantic.validator("freqs", always=True, allow_reuse=True) + @field_validator("freqs") + @classmethod def freqs_not_empty(cls, val): """Raise validation error if ``freqs`` is an empty Tuple.""" if len(val) == 0: @@ -468,12 +475,13 @@ def validate_mode_plane_radius(mode_spec: ModeSpec, plane: Box, msg_prefix: str ) -def _warn_unsupported_traced_argument(name: str): - @pydantic.validator(name, always=True, allow_reuse=True) - def _warn_traced_arg(cls, val, values): +def _warn_unsupported_traced_argument(*names: str): + @field_validator(*names) + @classmethod + def _warn_traced_arg(cls, val, info): if isbox(val): log.warning( - f"Field '{name}' of '{cls.__name__}' received an autograd tracer " + f"Field '{info.field_name}' of '{cls.__name__}' received an autograd tracer " f"(i.e., a 
value being tracked for automatic differentiation). " f"Automatic differentiation through this field is unsupported, " f"so the tracer has been converted to its static value. " diff --git a/tidy3d/components/viz.py b/tidy3d/components/viz.py index 55a1ce0243..84d9b053c8 100644 --- a/tidy3d/components/viz.py +++ b/tidy3d/components/viz.py @@ -1,12 +1,11 @@ """utilities for plotting""" -from __future__ import annotations - from functools import wraps from html import escape -from typing import Any, Dict, Optional +from typing import Any, Optional -import pydantic.v1 as pd +from numpy import array, concatenate, inf, ones +from pydantic import Field, NonNegativeFloat, field_validator try: import matplotlib.pyplot as plt @@ -20,10 +19,9 @@ except ImportError: arrow_style = None -from numpy import array, concatenate, inf, ones - +from ..compat import Self from ..constants import UnitScaling -from ..exceptions import SetupError, Tidy3dKeyError +from ..exceptions import SetupError, Tidy3dKeyError, ValidationError from .base import Tidy3dBaseModel from .types import Ax, Axis, LengthUnit @@ -97,19 +95,19 @@ class AbstractPlotParams(Tidy3dBaseModel): Corresponds with select properties of ``matplotlib.artist.Artist``. 
""" - alpha: Any = pd.Field(1.0, title="Opacity") - zorder: float = pd.Field(None, title="Display Order") + alpha: float = Field(1.0, title="Opacity", ge=0, le=1) + zorder: Optional[float] = Field(None, title="Display Order") - def include_kwargs(self, **kwargs) -> AbstractPlotParams: + def include_kwargs(self, **kwargs) -> Self: """Update the plot params with supplied kwargs.""" update_dict = { key: value for key, value in kwargs.items() - if key not in ("type",) and value is not None and key in self.__fields__ + if key not in ("type",) and value is not None and key in self.model_fields } - return self.copy(update=update_dict) + return self.model_copy(update=update_dict) - def override_with_viz_spec(self, viz_spec) -> AbstractPlotParams: + def override_with_viz_spec(self, viz_spec) -> Self: """Override plot params with supplied VisualizationSpec.""" return self.include_kwargs(**dict(viz_spec)) @@ -126,13 +124,13 @@ class PathPlotParams(AbstractPlotParams): Corresponds with select properties of ``matplotlib.lines.Line2D``. 
""" - color: Any = pd.Field(None, title="Color", alias="c") - linewidth: pd.NonNegativeFloat = pd.Field(2, title="Line Width", alias="lw") - linestyle: str = pd.Field("--", title="Line Style", alias="ls") - marker: Any = pd.Field("o", title="Marker Style") - markeredgecolor: Any = pd.Field(None, title="Marker Edge Color", alias="mec") - markerfacecolor: Any = pd.Field(None, title="Marker Face Color", alias="mfc") - markersize: pd.NonNegativeFloat = pd.Field(10, title="Marker Size", alias="ms") + color: Optional[Any] = Field(None, title="Color", alias="c") + linewidth: NonNegativeFloat = Field(2, title="Line Width", alias="lw") + linestyle: str = Field("--", title="Line Style", alias="ls") + marker: Any = Field("o", title="Marker Style") + markeredgecolor: Optional[Any] = Field(None, title="Marker Edge Color", alias="mec") + markerfacecolor: Optional[Any] = Field(None, title="Marker Face Color", alias="mfc") + markersize: NonNegativeFloat = Field(10, title="Marker Size", alias="ms") class PlotParams(AbstractPlotParams): @@ -140,11 +138,11 @@ class PlotParams(AbstractPlotParams): Corresponds with select properties of ``matplotlib.patches.Patch``. 
""" - edgecolor: Any = pd.Field(None, title="Edge Color", alias="ec") - facecolor: Any = pd.Field(None, title="Face Color", alias="fc") - fill: bool = pd.Field(True, title="Is Filled") - hatch: str = pd.Field(None, title="Hatch Style") - linewidth: pd.NonNegativeFloat = pd.Field(1, title="Line Width", alias="lw") + edgecolor: Optional[Any] = Field(None, title="Edge Color", alias="ec") + facecolor: Optional[Any] = Field(None, title="Face Color", alias="fc") + fill: bool = Field(True, title="Is Filled") + hatch: Optional[str] = Field(None, title="Hatch Style") + linewidth: NonNegativeFloat = Field(1, title="Line Width", alias="lw") # defaults for different tidy3d objects @@ -186,7 +184,7 @@ class PlotParams(AbstractPlotParams): def is_valid_color(value: str) -> str: if not is_color_like(value): - raise pd.ValidationError(f"{value} is not a valid plotting color") + raise ValidationError(f"'{value}' is not a valid plotting color") return value @@ -194,32 +192,34 @@ def is_valid_color(value: str) -> str: class VisualizationSpec(Tidy3dBaseModel): """Defines specification for visualization when used with plotting functions.""" - facecolor: str = pd.Field( + facecolor: str = Field( "", title="Face color", description="Color applied to the faces in visualization.", ) - edgecolor: Optional[str] = pd.Field( + edgecolor: Optional[str] = Field( "", title="Edge color", description="Color applied to the edges in visualization.", ) - alpha: Optional[pd.confloat(ge=0.0, le=1.0)] = pd.Field( + alpha: Optional[float] = Field( 1.0, title="Opacity", description="Opacity/alpha value in plotting between 0 and 1.", + ge=0, + le=1, ) - @pd.validator("facecolor", always=True) - def validate_color(value: str) -> str: + @field_validator("facecolor") + def _validate_facecolor(value): return is_valid_color(value) - @pd.validator("edgecolor", always=True) - def validate_and_copy_color(value: str, values: Dict[str, Any]) -> str: - if (value == "") and "facecolor" in values: - return 
is_valid_color(values["facecolor"]) + @field_validator("edgecolor") + def _validate_edgecolor(value, info): + if value == "" and "facecolor" in info.data: + value = info.data["facecolor"] return is_valid_color(value) diff --git a/tidy3d/config.py b/tidy3d/config.py index 3598bc27a7..81893a8bb6 100644 --- a/tidy3d/config.py +++ b/tidy3d/config.py @@ -1,24 +1,23 @@ """Sets the configuration of the script, can be changed with `td.config.config_name = new_val`.""" -import pydantic.v1 as pd +from pydantic import BaseModel, ConfigDict, Field, field_validator from .log import DEFAULT_LEVEL, LogLevel, set_log_suppression, set_logging_level -class Tidy3dConfig(pd.BaseModel): +class Tidy3dConfig(BaseModel): """configuration of tidy3d""" - class Config: - """Config of the config.""" - - arbitrary_types_allowed = False - validate_all = True - extra = "forbid" - validate_assignment = True - allow_population_by_field_name = True - frozen = False + model_config = ConfigDict( + arbitrary_types_allowed=False, + validate_default=True, + extra="forbid", + validate_assignment=True, + populate_by_name=True, + frozen=False, + ) - logging_level: LogLevel = pd.Field( + logging_level: LogLevel = Field( DEFAULT_LEVEL, title="Logging Level", description="The lowest level of logging output that will be displayed. 
" @@ -26,21 +25,21 @@ class Config: 'Note: "SUPPORT" and "USER" levels are only used in backend solver logging.', ) - log_suppression: bool = pd.Field( + log_suppression: bool = Field( True, title="Log suppression", description="Enable or disable suppression of certain log messages when they are repeated " "for several elements.", ) - @pd.validator("logging_level", pre=True, always=True) - def _set_logging_level(cls, val): + @field_validator("logging_level") + def _set_logging_level(val): """Set the logging level if logging_level is changed.""" set_logging_level(val) return val - @pd.validator("log_suppression", pre=True, always=True) - def _set_log_suppression(cls, val): + @field_validator("log_suppression") + def _set_log_suppression(val): """Control log suppression when log_suppression is changed.""" set_log_suppression(val) return val diff --git a/tidy3d/exceptions.py b/tidy3d/exceptions.py index 32ccc88e49..31ace387e9 100644 --- a/tidy3d/exceptions.py +++ b/tidy3d/exceptions.py @@ -1,8 +1,15 @@ """Custom Tidy3D exceptions""" +from pydantic_core import PydanticCustomError + from .log import log +class PostInitValidationError(PydanticCustomError): + code = "post_init_validator" + msg_template = 'post-init validator "{validator}" failed: {msg}' + + class Tidy3dError(ValueError): """Any error in tidy3d""" diff --git a/tidy3d/log.py b/tidy3d/log.py index 9bbb0e04be..8e99423151 100644 --- a/tidy3d/log.py +++ b/tidy3d/log.py @@ -2,11 +2,10 @@ import inspect from datetime import datetime -from typing import Callable, List, Tuple, Union +from typing import Callable, Literal, Union from rich.console import Console from rich.text import Text -from typing_extensions import Literal # Note: "SUPPORT" and "USER" levels are meant for backend runs only. # Logging in frontend code should just use the standard debug/info/warning/error/critical. 
@@ -42,7 +41,7 @@ CONSOLE_WIDTH = 80 -def _default_log_level_format(level: str, message: str) -> Tuple[str, str]: +def _default_log_level_format(level: str, message: str) -> tuple[str, str]: """By default just return unformatted prefix and message.""" return level, message @@ -242,7 +241,7 @@ def _log( message: str, *args, log_once: bool = False, - custom_loc: List = None, + custom_loc: list = None, capture: bool = True, ) -> None: """Distribute log messages to all handlers""" @@ -311,7 +310,7 @@ def warning( message: str, *args, log_once: bool = False, - custom_loc: List = None, + custom_loc: list = None, capture: bool = True, ) -> None: """Log (message) % (args) at warning level""" diff --git a/tidy3d/material_library/material_library.py b/tidy3d/material_library/material_library.py index 18550e04cf..cd9951a86c 100644 --- a/tidy3d/material_library/material_library.py +++ b/tidy3d/material_library/material_library.py @@ -1,22 +1,21 @@ """Holds dispersive models for several commonly used optical materials.""" import json -from typing import Dict, List, Union +from typing import Optional, Union -import pydantic.v1 as pd +from pydantic import Field, model_validator -from tidy3d.components.material.multi_physics import MultiPhysicsMedium -from tidy3d.components.material.tcad.charge import SemiconductorMedium -from tidy3d.components.tcad.types import ( +from ..components.base import Tidy3dBaseModel +from ..components.material.multi_physics import MultiPhysicsMedium +from ..components.material.tcad.charge import SemiconductorMedium +from ..components.medium import AnisotropicMedium, Medium2D, PoleResidue, Sellmeier +from ..components.tcad.types import ( AugerRecombination, CaugheyThomasMobility, RadiativeRecombination, ShockleyReedHallRecombination, SlotboomBandGapNarrowing, ) - -from ..components.base import Tidy3dBaseModel -from ..components.medium import AnisotropicMedium, Medium2D, PoleResidue, Sellmeier from ..components.types import Axis from ..exceptions import 
SetupError from ..log import log @@ -66,13 +65,13 @@ def export_matlib_to_file(fname: str = "matlib.json") -> None: class AbstractVariantItem(Tidy3dBaseModel): """Reference, and data_source for a variant of a material.""" - reference: List[ReferenceData] = pd.Field( + reference: Optional[list[ReferenceData]] = Field( None, title="Reference information", description="A list of references related to this variant model.", ) - data_url: str = pd.Field( + data_url: Optional[str] = Field( None, title="Dispersion data URL", description="The URL to access the dispersion data upon which the material " @@ -80,7 +79,7 @@ class AbstractVariantItem(Tidy3dBaseModel): ) @property - def summarize_mediums(self) -> Dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: + def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: return {} def __str__(self): @@ -96,40 +95,36 @@ def _repr_pretty_(self, p, cycle): class VariantItem(AbstractVariantItem): """Reference, data_source, and material model for a variant of a material.""" - medium: Union[PoleResidue, MultiPhysicsMedium] = pd.Field( - ..., + medium: Union[PoleResidue, MultiPhysicsMedium] = Field( title="Material dispersion model", description="A dispersive medium described by the pole-residue pair model.", ) @property - def summarize_mediums(self) -> Dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: + def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: return {"medium": self.medium} class MaterialItem(Tidy3dBaseModel): """A material that includes several variants.""" - name: str = pd.Field(..., title="Name", description="Unique name for the medium.") - variants: Dict[str, VariantItem] = pd.Field( - ..., + name: str = Field(title="Name", description="Unique name for the medium.") + variants: dict[str, VariantItem] = Field( title="Dictionary of available variants for this material", description="A dictionary of available variants for this 
material " "that maps from a key to the variant model.", ) - default: str = pd.Field( - ..., title="default variant", description="The default type of variant." - ) + default: str = Field(title="default variant", description="The default type of variant.") - @pd.validator("default", always=True) - def _default_in_variants(cls, val, values): + @model_validator(mode="after") + def _default_in_variants(self): """Make sure the default variant is already included in the ``variants``.""" - if val not in values["variants"]: + if self.default not in self.variants: raise SetupError( - f"The data of the default variant '{val}' is not supplied; " + f"The data of the default variant '{self.default}' is not supplied; " "please include it in the 'variants'." ) - return val + return self def __getitem__(self, variant_name): """Helper function to easily access the medium of a variant""" @@ -158,8 +153,7 @@ def _repr_pretty_(self, p, cycle): class VariantItem2D(AbstractVariantItem): """Reference, data_source, and material model for a variant of a 2D material.""" - medium: Medium2D = pd.Field( - ..., + medium: Medium2D = Field( title="Material dispersion model", description="A dispersive 2D medium described by a surface conductivity model, " "which is handled as an anisotropic medium with pole-residue pair models " @@ -167,15 +161,14 @@ class VariantItem2D(AbstractVariantItem): ) @property - def summarize_mediums(self) -> Dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: + def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: return {"medium": self.medium} class MaterialItem2D(MaterialItem): """A 2D material that includes several variants.""" - variants: Dict[str, VariantItem2D] = pd.Field( - ..., + variants: dict[str, VariantItem2D] = Field( title="Dictionary of available variants for this material", description="A dictionary of available variants for this material " "that maps from a key to the variant model.", @@ -185,14 +178,12 @@ 
class MaterialItem2D(MaterialItem): class VariantItemUniaxial(AbstractVariantItem): """Reference, data_source, and material model for a variant of an uniaxial material.""" - ordinary: PoleResidue = pd.Field( - ..., + ordinary: PoleResidue = Field( title="Ordinary Component", description="Medium describing the ordinary component.", ) - extraordinary: PoleResidue = pd.Field( - ..., + extraordinary: PoleResidue = Field( title="Extraordinary Component", description="Medium describing the extraordinary component.", ) @@ -218,15 +209,14 @@ def medium(self, optical_axis: Axis) -> AnisotropicMedium: return AnisotropicMedium.parse_obj(mat_dict) @property - def summarize_mediums(self) -> Dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: + def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: return {"ordinary": self.ordinary, "extraordinary": self.extraordinary} class MaterialItemUniaxial(MaterialItem): """A material that includes several variants.""" - variants: Dict[str, VariantItemUniaxial] = pd.Field( - ..., + variants: dict[str, VariantItemUniaxial] = Field( title="Dictionary of available variants for this material", description="A dictionary of available variants for this material " "that maps from a key to the variant model.", diff --git a/tidy3d/material_library/material_reference.py b/tidy3d/material_library/material_reference.py index ac27867dfe..a4930812c4 100644 --- a/tidy3d/material_library/material_reference.py +++ b/tidy3d/material_library/material_reference.py @@ -1,6 +1,8 @@ """Holds the reference materials for Tidy3D material library.""" -import pydantic.v1 as pd +from typing import Optional + +from pydantic import Field from ..components.base import Tidy3dBaseModel @@ -8,23 +10,23 @@ class ReferenceData(Tidy3dBaseModel): """Reference data.""" - doi: str = pd.Field(None, title="DOI", description="DOI of the reference.") - journal: str = pd.Field( + doi: Optional[str] = Field(None, title="DOI", 
description="DOI of the reference.") + journal: Optional[str] = Field( None, title="Journal publication info", description="Publication info in the order of author, title, journal volume, and year.", ) - url: str = pd.Field( + url: Optional[str] = Field( None, title="URL link", description="Some reference can be accessed through a url link to its pdf etc.", ) - manufacturer: str = pd.Field( + manufacturer: Optional[str] = Field( None, title="Manufacturer", description="Name of the manufacturer, e.g., Rogers, Arlon.", ) - datasheet_title: str = pd.Field( + datasheet_title: Optional[str] = Field( None, title="Datasheet Title", description="Title of the datasheet.", diff --git a/tidy3d/material_library/parametric_materials.py b/tidy3d/material_library/parametric_materials.py index 02608aca9d..b62b4ea1eb 100644 --- a/tidy3d/material_library/parametric_materials.py +++ b/tidy3d/material_library/parametric_materials.py @@ -2,23 +2,16 @@ import warnings from abc import ABC, abstractmethod -from typing import List, Tuple +from typing import Optional import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeInt from ..components.base import Tidy3dBaseModel from ..components.medium import Drude, Medium2D, PoleResidue from ..constants import ELECTRON_VOLT, EPSILON_0, HBAR, K_B, KELVIN, Q_e from ..log import log -try: - from scipy import integrate - - INTEGRATE_AVAILABLE = True -except ImportError: - INTEGRATE_AVAILABLE = False - # default values of the physical parameters for graphene # scattering rate in eV GRAPHENE_DEF_GAMMA = 0.00041 @@ -74,35 +67,35 @@ class Graphene(ParametricVariantItem2D): """ - mu_c: float = pd.Field( + mu_c: float = Field( GRAPHENE_DEF_MU_C, title="Chemical potential in eV", description="Chemical potential in eV.", units=ELECTRON_VOLT, ) - temp: float = pd.Field( + temp: float = Field( GRAPHENE_DEF_TEMP, title="Temperature in K", description="Temperature in K.", units=KELVIN ) - gamma: float = pd.Field( + gamma: float = 
Field( GRAPHENE_DEF_GAMMA, title="Scattering rate in eV", description="Scattering rate in eV. Must be small compared to the optical frequency.", units=ELECTRON_VOLT, ) - scaling: float = pd.Field( + scaling: float = Field( 1, title="Scaling factor", description="Scaling factor used to model multiple layers of graphene.", ) - include_interband: bool = pd.Field( + include_interband: bool = Field( True, title="Include interband terms", description="Include interband terms, relevant at high frequency (IR). " "Otherwise, the intraband terms only give a simpler Drude-type model relevant " "only at low frequency (THz).", ) - interband_fit_freq_nodes: List[Tuple[float, float]] = pd.Field( + interband_fit_freq_nodes: Optional[list[tuple[float, float]]] = Field( None, title="Interband fitting frequency nodes", description="Frequency nodes for fitting interband term. " @@ -114,7 +107,7 @@ class Graphene(ParametricVariantItem2D): "of frequencies; consider changing the nodes to obtain a better fit for a " "narrow-band simulation.", ) - interband_fit_num_iters: pd.NonNegativeInt = pd.Field( + interband_fit_num_iters: NonNegativeInt = Field( GRAPHENE_FIT_NUM_ITERS, title="Interband fitting number of iterations", description="Number of iterations for optimizing each Pade approximant when " @@ -199,7 +192,7 @@ def interband_pole_residue(self) -> PoleResidue: ) return pole_residue_filtered - def numerical_conductivity(self, freqs: List[float]) -> List[complex]: + def numerical_conductivity(self, freqs: list[float]) -> list[complex]: """Numerically calculate the conductivity. If this differs from the conductivity of the :class:`.Medium2D`, it is due to error while fitting the interband term, and you may try values of ``interband_fit_freq_nodes`` @@ -207,31 +200,39 @@ def numerical_conductivity(self, freqs: List[float]) -> List[complex]: Parameters ---------- - freqs : List[float] + freqs : list[float] The list of frequencies. 
Returns ------- - List[complex] + list[complex] The list of corresponding conductivities, in S. """ intra_sigma = self.intraband_drude.sigma_model(freqs) inter_sigma = self.interband_conductivity(freqs) return intra_sigma + inter_sigma - def interband_conductivity(self, freqs: List[float]) -> List[complex]: + def interband_conductivity(self, freqs: list[float]) -> list[complex]: """Numerically integrate interband term. Parameters ---------- - freqs : List[float] + freqs : list[float] The list of frequencies. Returns ------- - List[complex] + list[complex] The list of corresponding interband conductivities, in S. """ + try: + from scipy import integrate + except ImportError: + raise ImportError( + "The package 'scipy' was not found. Please install the 'core' " + "dependencies to calculate the interband term of graphene. For example: " + "pip install tidy3d" + ) def fermi(E: float) -> float: """Fermi distribution.""" @@ -248,13 +249,6 @@ def integrand(E: float, omega: float) -> float: """Integrand for interband term.""" return (fermi_g(E * HBAR) - fermi_g(HBAR * omega / 2)) / (omega**2 - 4 * E**2) - if not INTEGRATE_AVAILABLE: - raise ImportError( - "The package 'scipy' was not found. Please install the 'core' " - "dependencies to calculate the interband term of graphene. For example: " - "pip install tidy3d" - ) - omegas = 2 * np.pi * np.array(freqs) sigma = np.zeros(len(omegas), dtype=complex) integration_min = GRAPHENE_INT_MIN @@ -271,9 +265,9 @@ def integrand(E: float, omega: float) -> float: def _fit_interband_conductivity( self, - freqs: List[float], - sigma: List[complex], - indslist: List[Tuple[int, int]], + freqs: list[float], + sigma: list[complex], + indslist: list[tuple[int, int]], ): """Fit the interband conductivity with a Pade approximation, as described in @@ -283,11 +277,11 @@ def _fit_interband_conductivity( Parameters ---------- - freqs : List[float] + freqs : list[float] The input frequencies. 
- sigma : List[complex] + sigma : list[complex] The interband conductivity to fit. - indslist : List[Tuple[int, int]] + indslist : list[tuple[int, int]] The indices at which to sample the data for fitting. The length of this list determines the number of Pade terms used. Returns @@ -296,7 +290,7 @@ def _fit_interband_conductivity( A pole-residue model approximating the interband conductivity. """ - def evaluate_coeffslist(omega: List[float], coeffslist: List[List[float]]) -> List[float]: + def evaluate_coeffslist(omega: list[float], coeffslist: list[list[float]]) -> list[float]: """Evaluate the Pade approximants given by ``coeffslist` to ``omega``. Each item in ``coeffslist`` is a list of four coefficients corresponding to a single Pade term.""" @@ -308,8 +302,8 @@ def evaluate_coeffslist(omega: List[float], coeffslist: List[List[float]]) -> Li return res def fit_single( - omega: List[float], sigma: List[complex], inds: Tuple[int, int] - ) -> List[float]: + omega: list[float], sigma: list[complex], inds: tuple[int, int] + ) -> list[float]: """Fit a single Pade approximant of degree (1, 2) to ``sigma`` as a real function of i ``omega``. 
The method is described in @@ -331,11 +325,11 @@ def fit_single( return np.linalg.pinv(matrix) @ np.array([gamma[0], eta[0], gamma[1], eta[1]]) def optimize( - omega: List[float], - sigma: List[complex], - indslist: List[Tuple[int, int]], - coeffslist: List[List[float]], - ) -> List[float]: + omega: list[float], + sigma: list[complex], + indslist: list[tuple[int, int]], + coeffslist: list[list[float]], + ) -> list[float]: """Optimize the coefficients in ``coeffslist`` by sampling ``omega`` and ``sigma`` at the indices in ``indslist``.""" for _ in range(self.interband_fit_num_iters): @@ -346,7 +340,7 @@ def optimize( coeffslist[j] = fit_single(omega, curr_res, indslist[j]) return coeffslist - def get_pole_residue(coeffslist: List[List[float]]) -> PoleResidue: + def get_pole_residue(coeffslist: list[list[float]]) -> PoleResidue: """Convert a list of Pade coefficients into a :class:`.PoleResidue` model.""" poles = [] for alpha0, alpha1, beta1, beta2 in coeffslist: diff --git a/tidy3d/plugins/adjoint/components/__init__.py b/tidy3d/plugins/adjoint/components/__init__.py index c9d9cdb5cc..392cd053c9 100644 --- a/tidy3d/plugins/adjoint/components/__init__.py +++ b/tidy3d/plugins/adjoint/components/__init__.py @@ -4,8 +4,7 @@ from .data.data_array import JaxDataArray from .data.dataset import JaxPermittivityDataset from .data.monitor_data import JaxModeData -from .data.sim_data import JaxSimulationData -from .geometry import JaxBox, JaxComplexPolySlab, JaxPolySlab +from .geometry import JaxBox, JaxComplexPolySlab, JaxGeometryGroup, JaxPolySlab from .medium import JaxAnisotropicMedium, JaxCustomMedium, JaxMedium from .simulation import JaxSimulation from .structure import JaxStructure, JaxStructureStaticGeometry, JaxStructureStaticMedium @@ -22,7 +21,6 @@ "JaxStructureStaticMedium", "JaxStructureStaticGeometry", "JaxSimulation", - "JaxSimulationData", "JaxModeData", "JaxPermittivityDataset", "JaxDataArray", diff --git a/tidy3d/plugins/adjoint/components/base.py 
b/tidy3d/plugins/adjoint/components/base.py index 1fcd6baaa7..a997b6cd86 100644 --- a/tidy3d/plugins/adjoint/components/base.py +++ b/tidy3d/plugins/adjoint/components/base.py @@ -3,15 +3,16 @@ from __future__ import annotations import json -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, Iterable import jax import numpy as np -import pydantic.v1 as pd from jax.tree_util import tree_flatten as jax_tree_flatten from jax.tree_util import tree_unflatten as jax_tree_unflatten +from pydantic import model_validator + +from tidy3d.components.base import Tidy3dBaseModel -from ....components.base import Tidy3dBaseModel from .data.data_array import JAX_DATA_ARRAY_TAG, JaxDataArray # end of the error message when a ``_validate_web_adjoint`` exception is raised @@ -36,27 +37,27 @@ class JaxObject(Tidy3dBaseModel): """Shortcut to get names of fields with certain properties.""" @classmethod - def _get_field_names(cls, field_key: str) -> List[str]: + def _get_field_names(cls, field_key: str) -> list[str]: """Get all fields where ``field_key`` defined in the ``pydantic.Field``.""" fields = [] - for field_name, model_field in cls.__fields__.items(): - field_value = model_field.field_info.extra.get(field_key) - if field_value: - fields.append(field_name) + for name, field in cls.model_fields.items(): + extra = field.json_schema_extra or {} + if extra.get(field_key): + fields.append(name) return fields @classmethod - def get_jax_field_names(cls) -> List[str]: + def get_jax_field_names(cls) -> list[str]: """Returns list of field names where ``jax_field=True``.""" return cls._get_field_names("jax_field") @classmethod - def get_jax_leaf_names(cls) -> List[str]: + def get_jax_leaf_names(cls) -> list[str]: """Returns list of field names where ``stores_jax_for`` defined.""" return cls._get_field_names("stores_jax_for") @classmethod - def get_jax_field_names_all(cls) -> List[str]: + def get_jax_field_names_all(cls) -> list[str]: """Returns list of field 
names where ``jax_field=True`` or ``stores_jax_for`` defined.""" jax_field_names = cls.get_jax_field_names() jax_leaf_names = cls.get_jax_leaf_names() @@ -76,7 +77,7 @@ def _validate_web_adjoint(self) -> None: """Methods needed for jax to register arbitrary classes.""" - def tree_flatten(self) -> Tuple[list, dict]: + def tree_flatten(self) -> tuple[list, dict]: """How to flatten a :class:`.JaxObject` instance into a ``pytree``.""" children = [] aux_data = self.dict() @@ -135,20 +136,22 @@ def to_tidy3d(self: JaxObject) -> Tidy3dBaseModel: @classmethod def from_tidy3d(cls, tidy3d_obj: Tidy3dBaseModel) -> JaxObject: """Convert :class:`.Tidy3dBaseModel` instance to :class:`.JaxObject`.""" - obj_dict = tidy3d_obj.dict(exclude={"type"}) + obj_dict = tidy3d_obj.model_dump(exclude={"type"}) for key in cls.get_jax_field_names(): - sub_field_type = cls.__fields__[key].type_ + sub_field_type = cls.model_fields[key].annotation tidy3d_sub_field = getattr(tidy3d_obj, key) # TODO: simplify this logic - if isinstance(tidy3d_sub_field, (tuple, list)): - obj_dict[key] = [sub_field_type.from_tidy3d(x) for x in tidy3d_sub_field] + if isinstance(tidy3d_sub_field, Iterable) and not isinstance( + tidy3d_sub_field, (str, bytes) + ): + obj_dict[key] = [sub_field_type.from_tidy3d(v) for v in tidy3d_sub_field] else: obj_dict[key] = sub_field_type.from_tidy3d(tidy3d_sub_field) # end TODO - return cls.parse_obj(obj_dict) + return cls.model_validate(obj_dict) @property def exclude_fields_leafs_only(self) -> set: @@ -157,38 +160,42 @@ def exclude_fields_leafs_only(self) -> set: """Accounting with jax and regular fields.""" - @pd.root_validator(pre=True) - def handle_jax_kwargs(cls, values: dict) -> dict: + @model_validator(mode="before") + @classmethod + def handle_jax_kwargs(cls, data: dict[str, Any]) -> dict[str, Any]: """Pass jax inputs to the jax fields and pass untraced values to the regular fields.""" - # for all jax-traced fields for jax_name in cls.get_jax_leaf_names(): - # if a value 
was passed to the object for the regular field - orig_name = cls.__fields__[jax_name].field_info.extra.get("stores_jax_for") - val = values.get(orig_name) - if val is not None: - # try adding the sanitized (no trace) version to the regular field - try: - values[orig_name] = jax.lax.stop_gradient(val) - - # if it doesnt work, just pass the raw value (necessary to handle inf strings) - except TypeError: - values[orig_name] = val - - # if the jax name was not specified directly, use the original traced value - if jax_name not in values: - values[jax_name] = val - - return values - - @pd.root_validator(pre=True) - def handle_array_jax_leafs(cls, values: dict) -> dict: + # where the un-traced value should go + meta = cls.model_fields[jax_name].json_schema_extra or {} + orig_name = meta.get("stores_jax_for") + + if orig_name is None: + continue + + val = data.get(orig_name) + if val is None: + continue # nothing supplied for the plain field + + # put a non-traced version on the regular field + try: + data[orig_name] = jax.lax.stop_gradient(val) + except TypeError: + data[orig_name] = val + + data.setdefault(jax_name, val) + + return data + + @model_validator(mode="before") + @classmethod + def handle_array_jax_leafs(cls, data) -> dict: """Convert jax_leafs that are passed as numpy arrays.""" for jax_name in cls.get_jax_leaf_names(): - val = values.get(jax_name) + val = data.get(jax_name) if isinstance(val, np.ndarray): - values[jax_name] = val.tolist() - return values + data[jax_name] = val.tolist() + return data """ IO """ @@ -218,14 +225,14 @@ def strip_data_array(val: Any) -> Any: # TODO: replace with implementing these in DataArray - def to_hdf5(self, fname: str, custom_encoders: List[Callable] = None) -> None: + def to_hdf5(self, fname: str, custom_encoders: list[Callable] = None) -> None: """Exports :class:`JaxObject` instance to .hdf5 file. Parameters ---------- fname : str Full path to the .hdf5 file to save the :class:`JaxObject` to. 
- custom_encoders : List[Callable] + custom_encoders : list[Callable] List of functions accepting (fname: str, group_path: str, value: Any) that take the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. @@ -249,7 +256,7 @@ def data_array_encoder(fname: str, group_path: str, value: Any) -> None: @classmethod def dict_from_hdf5( - cls, fname: str, group_path: str = "", custom_decoders: List[Callable] = None + cls, fname: str, group_path: str = "", custom_decoders: list[Callable] = None ) -> dict: """Loads a dictionary containing the model contents from a .hdf5 file. @@ -259,7 +266,7 @@ def dict_from_hdf5( Full path to the .hdf5 file to load the :class:`JaxObject` from. group_path : str, optional Path to a group inside the file to selectively load a sub-element of the model only. - custom_decoders : List[Callable] + custom_decoders : list[Callable] List of functions accepting (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the value in the model dict after a custom decoding. 
diff --git a/tidy3d/plugins/adjoint/components/data/data_array.py b/tidy3d/plugins/adjoint/components/data/data_array.py index 083dd9d2e9..4488bff15e 100644 --- a/tidy3d/plugins/adjoint/components/data/data_array.py +++ b/tidy3d/plugins/adjoint/components/data/data_array.py @@ -2,18 +2,18 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, Sequence, Tuple, Union +from typing import Any, Literal, Sequence, Union import h5py import jax import jax.numpy as jnp import numpy as np -import pydantic.v1 as pd import xarray as xr from jax.tree_util import register_pytree_node_class +from pydantic import Field, field_validator, model_validator -from .....components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from .....exceptions import AdjointError, DataError, Tidy3dKeyError +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.exceptions import AdjointError, DataError, Tidy3dKeyError # condition setting when to set value in DataArray to zero: # if abs(val) <= VALUE_FILTER_THRESHOLD * max(abs(val)) @@ -27,15 +27,13 @@ class JaxDataArray(Tidy3dBaseModel): """A :class:`.DataArray`-like class that only wraps xarray for jax compatibility.""" - values: Any = pd.Field( - ..., + values: Any = Field( title="Values", description="Nested list containing the raw values, which can be tracked by jax.", jax_field=True, ) - coords: Dict[str, list] = pd.Field( - ..., + coords: dict[str, list] = Field( title="Coords", description="Dictionary storing the coordinates, namely ``(direction, f, mode_index)``.", ) @@ -51,19 +49,18 @@ def from_tidy3d(cls, tidy3d_obj: xr.DataArray) -> JaxDataArray: coords = {k: np.array(v).tolist() for k, v in tidy3d_obj.coords.items()} return cls(values=tidy3d_obj.data, coords=coords) - @pd.validator("values", always=True) - def _convert_values_to_np(cls, val): + @field_validator("values") + def _convert_values_to_np(val): """Convert supplied values to numpy if they are list (from 
file).""" if isinstance(val, list): return np.array(val) return val - @pd.validator("coords", always=True) - @skip_if_fields_missing(["values"]) - def _coords_match_values(cls, val, values): + @model_validator(mode="after") + def _coords_match_values(self): """Make sure the coordinate dimensions and shapes match the values data.""" - _values = values.get("values") + _values = self.values # get the shape, handling both regular and jax objects try: @@ -71,7 +68,7 @@ def _coords_match_values(cls, val, values): except TypeError: values_shape = jnp.array(_values).shape - for (key, coord_val), size_dim in zip(val.items(), values_shape): + for (key, coord_val), size_dim in zip(self.coords.items(), values_shape): if len(coord_val) != size_dim: raise ValueError( f"JaxDataArray coord {key} has {len(coord_val)} elements, " @@ -79,11 +76,11 @@ def _coords_match_values(cls, val, values): f"with size {size_dim} along that dimension." ) - return val + return self - @pd.validator("coords", always=True) - def _convert_coords_to_list(cls, val): - """Convert supplied coordinates to Dict[str, list].""" + @field_validator("coords") + def _convert_coords_to_list(val): + """Convert supplied coordinates to dict[str, list].""" return {coord_name: list(coord_list) for coord_name, coord_list in val.items()} def __eq__(self, other) -> bool: @@ -407,7 +404,7 @@ def assign_coords(self, coords: dict = None, **coords_kwargs) -> JaxDataArray: update_kwargs = {key: np.array(value).tolist() for key, value in update_kwargs.items()} return self.updated_copy(coords=update_kwargs) - def multiply_at(self, value: complex, coord_name: str, indices: List[int]) -> JaxDataArray: + def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> JaxDataArray: """Multiply self by value at indices into .""" axis = list(self.coords.keys()).index(coord_name) scalar_data_arr = self.as_jnp_array @@ -498,7 +495,7 @@ def interp(self, kwargs=None, assume_sorted=None, **interp_kwargs) -> JaxDataArr return
ret_value @cached_property - def nonzero_val_coords(self) -> Tuple[List[complex], Dict[str, Any]]: + def nonzero_val_coords(self) -> tuple[list[complex], dict[str, Any]]: """The value and coordinate associated with the only non-zero element of ``self.values``.""" values = np.nan_to_num(self.as_ndarray) @@ -519,7 +516,7 @@ def nonzero_val_coords(self) -> Tuple[List[complex], Dict[str, Any]]: return nonzero_values, nonzero_coords - def tree_flatten(self) -> Tuple[list, dict]: + def tree_flatten(self) -> tuple[list, dict]: """Jax works on the values, stash the coords for reconstruction.""" return self.values, self.coords diff --git a/tidy3d/plugins/adjoint/components/data/dataset.py b/tidy3d/plugins/adjoint/components/data/dataset.py index 6e4006816b..949c5ff9c3 100644 --- a/tidy3d/plugins/adjoint/components/data/dataset.py +++ b/tidy3d/plugins/adjoint/components/data/dataset.py @@ -1,9 +1,10 @@ """Defines jax-compatible datasets.""" -import pydantic.v1 as pd from jax.tree_util import register_pytree_node_class +from pydantic import Field + +from tidy3d.components.data.dataset import PermittivityDataset -from .....components.data.dataset import PermittivityDataset from ..base import JaxObject from .data_array import JaxDataArray @@ -14,20 +15,17 @@ class JaxPermittivityDataset(PermittivityDataset, JaxObject): _tidy3d_class = PermittivityDataset - eps_xx: JaxDataArray = pd.Field( - ..., + eps_xx: JaxDataArray = Field( title="Epsilon xx", description="Spatial distribution of the xx-component of the relative permittivity.", jax_field=True, ) - eps_yy: JaxDataArray = pd.Field( - ..., + eps_yy: JaxDataArray = Field( title="Epsilon yy", description="Spatial distribution of the yy-component of the relative permittivity.", jax_field=True, ) - eps_zz: JaxDataArray = pd.Field( - ..., + eps_zz: JaxDataArray = Field( title="Epsilon zz", description="Spatial distribution of the zz-component of the relative permittivity.", jax_field=True, diff --git 
a/tidy3d/plugins/adjoint/components/data/monitor_data.py b/tidy3d/plugins/adjoint/components/data/monitor_data.py index 23ed786958..7c4deba6c4 100644 --- a/tidy3d/plugins/adjoint/components/data/monitor_data.py +++ b/tidy3d/plugins/adjoint/components/data/monitor_data.py @@ -3,35 +3,36 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union +from typing import Any, Optional, Union import jax.numpy as jnp import numpy as np -import pydantic.v1 as pd from jax.tree_util import register_pytree_node_class +from pydantic import Field -from .....components.base import cached_property -from .....components.data.data_array import ( +from tidy3d.components.base import cached_property +from tidy3d.components.data.data_array import ( FreqModeDataArray, MixedModeDataArray, ModeAmpsDataArray, ScalarFieldDataArray, ) -from .....components.data.dataset import FieldDataset -from .....components.data.monitor_data import ( +from tidy3d.components.data.dataset import FieldDataset +from tidy3d.components.data.monitor_data import ( DiffractionData, FieldData, ModeData, ModeSolverData, MonitorData, ) -from .....components.geometry.base import Box -from .....components.source.base import Source -from .....components.source.current import CustomCurrentSource, PointDipole -from .....components.source.field import CustomFieldSource, ModeSource, PlaneWave -from .....components.source.time import GaussianPulse -from .....constants import C_0, ETA_0, MU_0 -from .....exceptions import AdjointError +from tidy3d.components.geometry.base import Box +from tidy3d.components.source.base import Source +from tidy3d.components.source.current import CustomCurrentSource, PointDipole +from tidy3d.components.source.field import CustomFieldSource, ModeSource, PlaneWave +from tidy3d.components.source.time import GaussianPulse +from tidy3d.constants import C_0, ETA_0, MU_0 +from tidy3d.exceptions import AdjointError + from ..base import JaxObject from 
.data_array import JaxDataArray @@ -54,7 +55,7 @@ def from_monitor_data(cls, mnt_data: MonitorData) -> JaxMonitorData: return cls.parse_obj(self_dict) @abstractmethod - def to_adjoint_sources(self, fwidth: float) -> List[Source]: + def to_adjoint_sources(self, fwidth: float) -> list[Source]: """Construct a list of adjoint sources from this :class:`.JaxMonitorData`.""" @staticmethod @@ -80,14 +81,13 @@ def flip_direction(direction: str) -> str: class JaxModeData(JaxMonitorData, ModeData): """A :class:`.ModeData` registered with jax.""" - amps: JaxDataArray = pd.Field( - ..., + amps: JaxDataArray = Field( title="Amplitudes", description="Jax-compatible modal amplitude data associated with an output monitor.", jax_field=True, ) - def to_adjoint_sources(self, fwidth: float) -> List[ModeSource]: + def to_adjoint_sources(self, fwidth: float) -> list[ModeSource]: """Converts a :class:`.ModeData` to a list of adjoint :class:`.ModeSource`.""" amps, sel_coords = self.amps.nonzero_val_coords @@ -121,37 +121,37 @@ def to_adjoint_sources(self, fwidth: float) -> List[ModeSource]: class JaxFieldData(JaxMonitorData, FieldData): """A :class:`.FieldData` registered with jax.""" - Ex: JaxDataArray = pd.Field( + Ex: Optional[JaxDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", jax_field=True, ) - Ey: JaxDataArray = pd.Field( + Ey: Optional[JaxDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", jax_field=True, ) - Ez: JaxDataArray = pd.Field( + Ez: Optional[JaxDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", jax_field=True, ) - Hx: JaxDataArray = pd.Field( + Hx: Optional[JaxDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", jax_field=True, ) - Hy: JaxDataArray = pd.Field( + Hy: Optional[JaxDataArray] = Field( None, title="Hy", 
description="Spatial distribution of the y-component of the magnetic field.", jax_field=True, ) - Hz: JaxDataArray = pd.Field( + Hz: Optional[JaxDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ -167,7 +167,7 @@ def __contains__(self, item: str) -> bool: def __getitem__(self, item: str) -> bool: return self.field_components[item] - def package_colocate_results(self, centered_fields: Dict[str, ScalarFieldDataArray]) -> Any: + def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArray]) -> Any: """How to package the dictionary of fields computed via self.colocate().""" return self.updated_copy(**centered_fields) @@ -243,7 +243,7 @@ def time_reversed_copy(self) -> FieldData: "'time_reversed_copy' is not yet supported in the adjoint plugin." ) - def to_adjoint_sources(self, fwidth: float) -> List[CustomFieldSource]: + def to_adjoint_sources(self, fwidth: float) -> list[CustomFieldSource]: """Converts a :class:`.JaxFieldData` to a list of adjoint :class:`.CustomFieldSource.""" interpolate_source = True @@ -335,38 +335,32 @@ def shift_value(coords) -> float: class JaxDiffractionData(JaxMonitorData, DiffractionData): """A :class:`.DiffractionData` registered with jax.""" - Er: JaxDataArray = pd.Field( - ..., + Er: JaxDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", jax_field=True, ) - Etheta: JaxDataArray = pd.Field( - ..., + Etheta: JaxDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", jax_field=True, ) - Ephi: JaxDataArray = pd.Field( - ..., + Ephi: JaxDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", jax_field=True, ) - Hr: JaxDataArray = pd.Field( - ..., + Hr: JaxDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", jax_field=True, ) - Htheta: 
JaxDataArray = pd.Field( - ..., + Htheta: JaxDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", jax_field=True, ) - Hphi: JaxDataArray = pd.Field( - ..., + Hphi: JaxDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", jax_field=True, @@ -408,7 +402,7 @@ def power(self) -> JaxDataArray: return JaxDataArray(values=power_values, coords=power_coords) - def to_adjoint_sources(self, fwidth: float) -> List[PlaneWave]: + def to_adjoint_sources(self, fwidth: float) -> list[PlaneWave]: """Converts a :class:`.DiffractionData` to a list of adjoint :class:`.PlaneWave`.""" # extract the values coordinates of the non-zero amplitudes diff --git a/tidy3d/plugins/adjoint/components/data/sim_data.py b/tidy3d/plugins/adjoint/components/data/sim_data.py index 7568c9135e..c8f213c252 100644 --- a/tidy3d/plugins/adjoint/components/data/sim_data.py +++ b/tidy3d/plugins/adjoint/components/data/sim_data.py @@ -2,18 +2,19 @@ from __future__ import annotations -from typing import Dict, List, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd import xarray as xr from jax.tree_util import register_pytree_node_class +from pydantic import Field + +from tidy3d.components.data.monitor_data import FieldData, MonitorDataType, PermittivityData +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.source.current import PointDipole +from tidy3d.components.source.time import GaussianPulse +from tidy3d.log import log -from .....components.data.monitor_data import FieldData, MonitorDataType, PermittivityData -from .....components.data.sim_data import SimulationData -from .....components.source.current import PointDipole -from .....components.source.time import GaussianPulse -from .....log import log from ..base import JaxObject from ..simulation import JaxInfo, JaxSimulation from .monitor_data import 
JAX_MONITOR_DATA_MAP, JaxMonitorDataType @@ -23,32 +24,31 @@ class JaxSimulationData(SimulationData, JaxObject): """A :class:`.SimulationData` registered with jax.""" - output_data: Tuple[JaxMonitorDataType, ...] = pd.Field( + output_data: tuple[JaxMonitorDataType, ...] = Field( (), title="Jax Data", description="Tuple of Jax-compatible data associated with output monitors.", jax_field=True, ) - grad_data: Tuple[FieldData, ...] = pd.Field( + grad_data: tuple[FieldData, ...] = Field( (), title="Gradient Field Data", description="Tuple of monitor data storing fields associated with the input structures.", ) - grad_eps_data: Tuple[PermittivityData, ...] = pd.Field( + grad_eps_data: tuple[PermittivityData, ...] = Field( (), title="Gradient Permittivity Data", description="Tuple of monitor data storing epsilon associated with the input structures.", ) - simulation: JaxSimulation = pd.Field( - ..., + simulation: JaxSimulation = Field( title="Simulation", description="The jax-compatible simulation corresponding to the data.", ) - task_id: str = pd.Field( + task_id: Optional[str] = Field( None, title="Task ID", description="Optional field storing the task_id for the original JaxSimulation.", @@ -83,22 +83,22 @@ def get_poynting_vector(self, field_monitor_name: str) -> xr.Dataset: return super().get_poynting_vector(field_monitor_name) @property - def grad_data_symmetry(self) -> Tuple[FieldData, ...]: + def grad_data_symmetry(self) -> tuple[FieldData, ...]: """``self.grad_data`` but with ``symmetry_expanded_copy`` applied.""" return tuple(data.symmetry_expanded_copy for data in self.grad_data) @property - def grad_eps_data_symmetry(self) -> Tuple[FieldData, ...]: + def grad_eps_data_symmetry(self) -> tuple[FieldData, ...]: """``self.grad_eps_data`` but with ``symmetry_expanded_copy`` applied.""" return tuple(data.symmetry_expanded_copy for data in self.grad_eps_data) @property - def output_monitor_data(self) -> Dict[str, JaxMonitorDataType]: + def output_monitor_data(self) 
-> dict[str, JaxMonitorDataType]: """Dictionary of ``.output_data`` monitor ``.name`` to the corresponding data.""" return {monitor_data.monitor.name: monitor_data for monitor_data in self.output_data} @property - def monitor_data(self) -> Dict[str, Union[JaxMonitorDataType, MonitorDataType]]: + def monitor_data(self) -> dict[str, Union[JaxMonitorDataType, MonitorDataType]]: """Dictionary of ``.output_data`` monitor ``.name`` to the corresponding data.""" reg_mnt_data = {monitor_data.monitor.name: monitor_data for monitor_data in self.data} reg_mnt_data.update(self.output_monitor_data) @@ -106,8 +106,8 @@ def monitor_data(self) -> Dict[str, Union[JaxMonitorDataType, MonitorDataType]]: @staticmethod def split_data( - mnt_data: List[MonitorDataType], jax_info: JaxInfo - ) -> Dict[str, List[MonitorDataType]]: + mnt_data: list[MonitorDataType], jax_info: JaxInfo + ) -> dict[str, list[MonitorDataType]]: """Split list of monitor data into data, output_data, grad_data, and grad_eps_data.""" # Get information needed to split the full data list len_output_data = jax_info.num_output_monitors @@ -166,7 +166,7 @@ def from_sim_data( @classmethod def split_fwd_sim_data( cls, sim_data: SimulationData, jax_info: JaxInfo - ) -> Tuple[SimulationData, SimulationData]: + ) -> tuple[SimulationData, SimulationData]: """Split a :class:`.SimulationData` into two parts, containing user and gradient data.""" sim = sim_data.simulation diff --git a/tidy3d/plugins/adjoint/components/geometry.py b/tidy3d/plugins/adjoint/components/geometry.py index 578d8f53aa..afa835585a 100644 --- a/tidy3d/plugins/adjoint/components/geometry.py +++ b/tidy3d/plugins/adjoint/components/geometry.py @@ -3,32 +3,33 @@ from __future__ import annotations from abc import ABC -from typing import Dict, List, Tuple, Union +from typing import Union import jax import jax.numpy as jnp import numpy as np -import pydantic.v1 as pd import shapely import xarray as xr from jax.tree_util import register_pytree_node_class from 
joblib import Parallel, delayed +from pydantic import Field, field_validator -from ....components.base import cached_property -from ....components.data.data_array import ScalarFieldDataArray -from ....components.data.monitor_data import FieldData, PermittivityData -from ....components.geometry.base import Box, Geometry, GeometryGroup -from ....components.geometry.polyslab import ( +from tidy3d.components.base import cached_property +from tidy3d.components.data.data_array import ScalarFieldDataArray +from tidy3d.components.data.monitor_data import FieldData, PermittivityData +from tidy3d.components.geometry.base import Box, Geometry, GeometryGroup +from tidy3d.components.geometry.polyslab import ( _COMPLEX_POLYSLAB_DIVISIONS_WARN, _IS_CLOSE_RTOL, PolySlab, ) -from ....components.monitor import FieldMonitor, PermittivityMonitor -from ....components.types import ArrayFloat2D, Bound, Coordinate2D # , annotate_type -from ....constants import MICROMETER, fp_eps -from ....exceptions import AdjointError -from ....log import log -from ...polyslab import ComplexPolySlab +from tidy3d.components.monitor import FieldMonitor, PermittivityMonitor +from tidy3d.components.types import ArrayFloat2D, Bound, Coordinate2D +from tidy3d.constants import MICROMETER, fp_eps +from tidy3d.exceptions import AdjointError +from tidy3d.log import log +from tidy3d.plugins.polyslab import ComplexPolySlab + from .base import WEB_ADJOINT_MESSAGE, JaxObject from .types import JaxFloat @@ -46,13 +47,13 @@ class JaxGeometry(Geometry, ABC): """Abstract :class:`.Geometry` with methods useful for all Jax subclasses.""" @property - def bound_size(self) -> Tuple[float, float, float]: + def bound_size(self) -> tuple[float, float, float]: """Size of the bounding box of this geometry.""" rmin, rmax = self.bounds return tuple(abs(pt_max - pt_min) for (pt_min, pt_max) in zip(rmin, rmax)) @property - def bound_center(self) -> Tuple[float, float, float]: + def bound_center(self) -> tuple[float, float, float]: 
"""Size of the bounding box of this geometry.""" rmin, rmax = self.bounds @@ -76,8 +77,8 @@ def bounding_box(self): return JaxBox.from_bounds(*self.bounds) def make_grad_monitors( - self, freqs: List[float], name: str - ) -> Tuple[FieldMonitor, PermittivityMonitor]: + self, freqs: list[float], name: str + ) -> tuple[FieldMonitor, PermittivityMonitor]: """Return gradient monitor associated with this object.""" size_enlarged = tuple(s + 2 * GRAD_MONITOR_EXPANSION for s in self.bound_size) field_mnt = FieldMonitor( @@ -100,7 +101,7 @@ def make_grad_monitors( @staticmethod def compute_dotted_e_d_fields( grad_data_fwd: FieldData, grad_data_adj: FieldData, grad_data_eps: PermittivityData - ) -> Tuple[Dict[str, ScalarFieldDataArray], Dict[str, ScalarFieldDataArray]]: + ) -> tuple[dict[str, ScalarFieldDataArray], dict[str, ScalarFieldDataArray]]: """Get the (x,y,z) components of E_fwd * E_adj and D_fwd * D_adj fields in the domain.""" e_mult_xyz = {} @@ -133,7 +134,7 @@ class JaxBox(JaxGeometry, Box, JaxObject): _tidy3d_class = Box - center_jax: Tuple[JaxFloat, JaxFloat, JaxFloat] = pd.Field( + center_jax: tuple[JaxFloat, JaxFloat, JaxFloat] = Field( (0.0, 0.0, 0.0), title="Center (Jax)", description="Jax traced value for the center of the box in (x, y, z).", @@ -141,8 +142,7 @@ class JaxBox(JaxGeometry, Box, JaxObject): stores_jax_for="center", ) - size_jax: Tuple[JaxFloat, JaxFloat, JaxFloat] = pd.Field( - ..., + size_jax: tuple[JaxFloat, JaxFloat, JaxFloat] = Field( title="Size (Jax)", description="Jax-traced value for the size of the box in (x, y, z).", units=MICROMETER, @@ -275,8 +275,7 @@ class JaxPolySlab(JaxGeometry, PolySlab, JaxObject): _tidy3d_class = PolySlab - vertices_jax: Tuple[Tuple[JaxFloat, JaxFloat], ...] = pd.Field( - ..., + vertices_jax: tuple[tuple[JaxFloat, JaxFloat], ...] = Field( title="Vertices (Jax)", description="Jax-traced list of (d1, d2) defining the 2 dimensional positions of the " "polygon face vertices at the ``reference_plane``. 
" @@ -286,8 +285,7 @@ class JaxPolySlab(JaxGeometry, PolySlab, JaxObject): stores_jax_for="vertices", ) - slab_bounds_jax: Tuple[JaxFloat, JaxFloat] = pd.Field( - ..., + slab_bounds_jax: tuple[JaxFloat, JaxFloat] = Field( title="Slab bounds (Jax)", description="Jax-traced list of (h1, h2) defining the minimum and maximum positions " "of the slab along the ``axis`` dimension. ", @@ -295,7 +293,7 @@ class JaxPolySlab(JaxGeometry, PolySlab, JaxObject): stores_jax_for="slab_bounds", ) - sidewall_angle_jax: JaxFloat = pd.Field( + sidewall_angle_jax: JaxFloat = Field( default=0.0, title="Sidewall angle (Jax)", description="Jax-traced float defining the sidewall angle of the slab " @@ -304,7 +302,7 @@ class JaxPolySlab(JaxGeometry, PolySlab, JaxObject): stores_jax_for="sidewall_angle", ) - dilation_jax: JaxFloat = pd.Field( + dilation_jax: JaxFloat = Field( default=0.0, title="Dilation (Jax)", description="Jax-traced float defining the dilation.", @@ -312,8 +310,8 @@ class JaxPolySlab(JaxGeometry, PolySlab, JaxObject): stores_jax_for="dilation", ) - @pd.validator("sidewall_angle", always=True) - def no_sidewall(cls, val): + @field_validator("sidewall_angle") + def no_sidewall(val): """Warn if sidewall angle present.""" if not np.isclose(val, 0.0): log.warning( @@ -383,7 +381,7 @@ def _area(vertices: jnp.ndarray) -> float: @staticmethod def _shift_vertices( vertices: jnp.ndarray, dist - ) -> Tuple[jnp.ndarray, jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: + ) -> tuple[jnp.ndarray, jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray]]: """Shifts the vertices of a polygon outward uniformly by distances `dists`. @@ -396,7 +394,7 @@ def _shift_vertices( Returns ------- - Tuple[jnp.ndarray, jnp.narray, Tuple[jnp.ndarray, jnp.ndarray]] + tuple[jnp.ndarray, jnp.narray, tuple[jnp.ndarray, jnp.ndarray]] New polygon vertices; and the shift of vertices in direction parallel to the edges. Shift along x and y direction. 
@@ -468,7 +466,7 @@ def _neighbor_vertices_crossing_detection( return None @staticmethod - def _edge_length_and_reduction_rate(vertices: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + def _edge_length_and_reduction_rate(vertices: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: """Edge length of reduction rate of each edge with unit offset length. Parameters @@ -478,7 +476,7 @@ def _edge_length_and_reduction_rate(vertices: jnp.ndarray) -> Tuple[jnp.ndarray, Returns ------- - Tuple[jnp.ndarray, jnp.narray] + tuple[jnp.ndarray, jnp.narray] edge length, and reduction rate """ @@ -565,8 +563,8 @@ def edge_contrib( vertex_grad: Coordinate2D, vertex_stat: Coordinate2D, is_next: bool, - e_mult_xyz: Tuple[Dict[str, ScalarFieldDataArray]], - d_mult_xyz: Tuple[Dict[str, ScalarFieldDataArray]], + e_mult_xyz: tuple[dict[str, ScalarFieldDataArray]], + d_mult_xyz: tuple[dict[str, ScalarFieldDataArray]], sim_bounds: Bound, wvl_mat: float, eps_out: complex, @@ -606,8 +604,8 @@ def edge_position(s: np.array) -> np.array: return (1 - s) * vertex_stat[:, None] + s * vertex_grad[:, None] def edge_basis( - xyz_components: Tuple[FieldData, FieldData, FieldData], - ) -> Tuple[FieldData, FieldData, FieldData]: + xyz_components: tuple[FieldData, FieldData, FieldData], + ) -> tuple[FieldData, FieldData, FieldData]: """Puts a field component from the (x, y, z) basis to the (t, n, z) basis.""" cmp_z, (cmp_x_edge, cmp_y_edge) = self.pop_axis(xyz_components, axis=self.axis) @@ -696,8 +694,8 @@ def evaluate(scalar_field: ScalarFieldDataArray) -> float: def vertex_vjp( self, i_vertex, - e_mult_xyz: Tuple[Dict[str, ScalarFieldDataArray]], - d_mult_xyz: Tuple[Dict[str, ScalarFieldDataArray]], + e_mult_xyz: tuple[dict[str, ScalarFieldDataArray]], + d_mult_xyz: tuple[dict[str, ScalarFieldDataArray]], sim_bounds: Bound, wvl_mat: float, eps_out: complex, @@ -840,8 +838,8 @@ class JaxComplexPolySlab(JaxPolySlab, ComplexPolySlab): _tidy3d_class = ComplexPolySlab - @pd.validator("vertices", 
always=True) - def no_self_intersecting_polygon_during_extrusion(cls, val, values): + @field_validator("vertices") + def no_self_intersecting_polygon_during_extrusion(val): """Turn off the validation for this class.""" return val @@ -870,14 +868,14 @@ def _dilation_value_at_reference_to_coord(self, dilation: float) -> float: return z_coord @property - def sub_polyslabs(self) -> List[JaxPolySlab]: + def sub_polyslabs(self) -> list[JaxPolySlab]: """Divide a complex polyslab into a list of simple polyslabs. Only neighboring vertex-vertex crossing events are treated in this version. Returns ------- - List[JaxPolySlab] + list[JaxPolySlab] A list of simple jax polyslabs. """ sub_polyslab_list = [] @@ -992,7 +990,7 @@ class JaxGeometryGroup(JaxGeometry, GeometryGroup, JaxObject): _tidy3d_class = GeometryGroup - geometries: Tuple[JaxPolySlab, ...] = pd.Field( + geometries: tuple[JaxPolySlab, ...] = Field( ..., title="Geometries", description="Tuple of jax geometries in a single grouping. " diff --git a/tidy3d/plugins/adjoint/components/medium.py b/tidy3d/plugins/adjoint/components/medium.py index f87a3ec9a2..bd953508fd 100644 --- a/tidy3d/plugins/adjoint/components/medium.py +++ b/tidy3d/plugins/adjoint/components/medium.py @@ -3,19 +3,20 @@ from __future__ import annotations from abc import ABC -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Union import numpy as np -import pydantic.v1 as pd import xarray as xr from jax.tree_util import register_pytree_node_class +from pydantic import Field, field_validator, model_validator + +from tidy3d.components.data.monitor_data import FieldData +from tidy3d.components.geometry.base import Geometry +from tidy3d.components.medium import AnisotropicMedium, CustomMedium, Medium +from tidy3d.components.types import Bound, Literal +from tidy3d.constants import CONDUCTIVITY +from tidy3d.exceptions import SetupError -from ....components.data.monitor_data import FieldData -from 
....components.geometry.base import Geometry -from ....components.medium import AnisotropicMedium, CustomMedium, Medium -from ....components.types import Bound, Literal -from ....constants import CONDUCTIVITY -from ....exceptions import SetupError from .base import WEB_ADJOINT_MESSAGE, JaxObject from .data.data_array import JaxDataArray from .data.dataset import JaxPermittivityDataset @@ -33,7 +34,7 @@ class AbstractJaxMedium(ABC, JaxObject): def _get_volume_disc( self, grad_data: FieldData, sim_bounds: Bound, wvl_mat: float - ) -> Tuple[Dict[str, np.ndarray], float]: + ) -> tuple[dict[str, np.ndarray], float]: """Get the coordinates and volume element for the inside of the corresponding structure.""" # find intersecting volume between structure and simulation @@ -63,7 +64,7 @@ def _get_volume_disc( return vol_coords, d_vol @staticmethod - def make_inside_mask(vol_coords: Dict[str, np.ndarray], inside_fn: Callable) -> xr.DataArray: + def make_inside_mask(vol_coords: dict[str, np.ndarray], inside_fn: Callable) -> xr.DataArray: """Make a 3D mask of where the volume coordinates are inside a supplied function.""" meshgrid_args = [vol_coords[dim] for dim in "xyz" if dim in vol_coords] @@ -77,7 +78,7 @@ def e_mult_volume( field: Literal["Ex", "Ey", "Ez"], grad_data_fwd: FieldData, grad_data_adj: FieldData, - vol_coords: Dict[str, np.ndarray], + vol_coords: dict[str, np.ndarray], d_vol: float, inside_fn: Callable, ) -> xr.DataArray: @@ -141,14 +142,14 @@ class JaxMedium(Medium, AbstractJaxMedium): _tidy3d_class = Medium - permittivity_jax: JaxFloat = pd.Field( + permittivity_jax: JaxFloat = Field( 1.0, title="Permittivity", description="Relative permittivity of the medium. May be a ``jax`` ``Array``.", stores_jax_for="permittivity", ) - conductivity_jax: JaxFloat = pd.Field( + conductivity_jax: JaxFloat = Field( 0.0, title="Conductivity", description="Electric conductivity. 
Defined such that the imaginary part of the complex " @@ -201,22 +202,19 @@ class JaxAnisotropicMedium(AnisotropicMedium, AbstractJaxMedium): _tidy3d_class = AnisotropicMedium - xx: JaxMedium = pd.Field( - ..., + xx: JaxMedium = Field( title="XX Component", description="Medium describing the xx-component of the diagonal permittivity tensor.", jax_field=True, ) - yy: JaxMedium = pd.Field( - ..., + yy: JaxMedium = Field( title="YY Component", description="Medium describing the yy-component of the diagonal permittivity tensor.", jax_field=True, ) - zz: JaxMedium = pd.Field( - ..., + zz: JaxMedium = Field( title="ZZ Component", description="Medium describing the zz-component of the diagonal permittivity tensor.", jax_field=True, @@ -283,7 +281,7 @@ class JaxCustomMedium(CustomMedium, AbstractJaxMedium): _tidy3d_class = CustomMedium - eps_dataset: Optional[JaxPermittivityDataset] = pd.Field( + eps_dataset: Optional[JaxPermittivityDataset] = Field( None, title="Permittivity Dataset", description="User-supplied dataset containing complex-valued permittivity " @@ -292,16 +290,16 @@ class JaxCustomMedium(CustomMedium, AbstractJaxMedium): jax_field=True, ) - @pd.root_validator(pre=True) - def _pre_deprecation_dataset(cls, values): + @model_validator(mode="before") + def _pre_deprecation_dataset(data): """Don't allow permittivity as a field until we support it.""" - if values.get("permittivity") or values.get("conductivity"): + if data.get("permittivity") or data.get("conductivity"): raise SetupError( "'permittivity' and 'conductivity' are not yet supported in adjoint plugin. " "Please continue to use the 'eps_dataset' field to define the component " "of the permittivity tensor." 
) - return values + return data def _validate_web_adjoint(self) -> None: """Run validators for this component, only if using ``tda.web.run()``.""" @@ -325,12 +323,12 @@ def _is_not_too_large(self): + WEB_ADJOINT_MESSAGE ) - @pd.validator("eps_dataset", always=True) - def _eps_dataset_single_frequency(cls, val): + @field_validator("eps_dataset") + def _eps_dataset_single_frequency(val): """Override of inherited validator. (still needed)""" return val - @pd.validator("eps_dataset", always=True) + @field_validator("eps_dataset") def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, values): """Override of inherited validator.""" return val diff --git a/tidy3d/plugins/adjoint/components/simulation.py b/tidy3d/plugins/adjoint/components/simulation.py index 93088536d4..dd00e61195 100644 --- a/tidy3d/plugins/adjoint/components/simulation.py +++ b/tidy3d/plugins/adjoint/components/simulation.py @@ -2,32 +2,40 @@ from __future__ import annotations -from typing import Dict, List, Literal, Tuple, Union +from typing import Literal, Optional, Union import numpy as np -import pydantic.v1 as pd import xarray as xr from jax.tree_util import register_pytree_node_class from joblib import Parallel, delayed +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + field_validator, + model_validator, +) -from ....components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from ....components.data.monitor_data import FieldData, PermittivityData -from ....components.geometry.base import Box -from ....components.medium import AbstractMedium -from ....components.monitor import ( +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.data.monitor_data import FieldData, PermittivityData +from tidy3d.components.geometry.base import Box +from tidy3d.components.medium import AbstractMedium +from tidy3d.components.monitor import ( DiffractionMonitor, FieldMonitor, ModeMonitor, Monitor, 
PermittivityMonitor, ) -from ....components.simulation import Simulation -from ....components.structure import Structure -from ....components.subpixel_spec import Staircasing, SubpixelSpec -from ....components.types import Ax, annotate_type -from ....constants import HERTZ, SECOND -from ....exceptions import AdjointError -from ....log import log +from tidy3d.components.simulation import Simulation +from tidy3d.components.structure import Structure +from tidy3d.components.subpixel_spec import Staircasing, SubpixelSpec +from tidy3d.components.types import Ax, discriminated_union +from tidy3d.constants import HERTZ, SECOND +from tidy3d.exceptions import AdjointError +from tidy3d.log import log + from .base import WEB_ADJOINT_MESSAGE, JaxObject from .geometry import JaxGeometryGroup, JaxPolySlab from .structure import ( @@ -61,53 +69,49 @@ ) OutputMonitorTypes = (DiffractionMonitor, FieldMonitor, ModeMonitor) -OutputMonitorType = Tuple[annotate_type(Union[OutputMonitorTypes]), ...] +OutputMonitorType = tuple[discriminated_union(Union[OutputMonitorTypes]), ...] 
class JaxInfo(Tidy3dBaseModel): """Class to store information when converting between jax and tidy3d.""" - num_input_structures: pd.NonNegativeInt = pd.Field( - ..., + num_input_structures: NonNegativeInt = Field( title="Number of Input Structures", description="Number of input structures in the original JaxSimulation.", ) - num_output_monitors: pd.NonNegativeInt = pd.Field( - ..., + num_output_monitors: NonNegativeInt = Field( title="Number of Output Monitors", description="Number of output monitors in the original JaxSimulation.", ) - num_grad_monitors: pd.NonNegativeInt = pd.Field( - ..., + num_grad_monitors: NonNegativeInt = Field( title="Number of Gradient Monitors", description="Number of gradient monitors in the original JaxSimulation.", ) - num_grad_eps_monitors: pd.NonNegativeInt = pd.Field( - ..., + num_grad_eps_monitors: NonNegativeInt = Field( title="Number of Permittivity Monitors", description="Number of permittivity monitors in the original JaxSimulation.", ) - fwidth_adjoint: float = pd.Field( + fwidth_adjoint: Optional[float] = Field( None, title="Adjoint Frequency Width", description="Custom frequency width of the original JaxSimulation.", units=HERTZ, ) - run_time_adjoint: float = pd.Field( + run_time_adjoint: float = Field( None, title="Adjoint Run Time", description="Custom run time of the original JaxSimulation.", units=SECOND, ) - input_structure_types: Tuple[ + input_structure_types: tuple[ Literal["JaxStructure", "JaxStructureStaticMedium", "JaxStructureStaticGeometry"], ... - ] = pd.Field( + ] = Field( (), title="Input Structure Types", description="Type of the original input_structures (as strings).", @@ -118,7 +122,7 @@ class JaxInfo(Tidy3dBaseModel): class JaxSimulation(Simulation, JaxObject): """A :class:`.Simulation` registered with jax.""" - input_structures: Tuple[annotate_type(JaxStructureType), ...] = pd.Field( + input_structures: tuple[discriminated_union(JaxStructureType), ...] 
= Field( (), title="Input Structures", description="Tuple of jax-compatible structures" @@ -126,25 +130,25 @@ class JaxSimulation(Simulation, JaxObject): jax_field=True, ) - output_monitors: OutputMonitorType = pd.Field( + output_monitors: OutputMonitorType = Field( (), title="Output Monitors", description="Tuple of monitors whose data the differentiable output depends on.", ) - grad_monitors: Tuple[FieldMonitor, ...] = pd.Field( + grad_monitors: tuple[FieldMonitor, ...] = Field( (), title="Gradient Field Monitors", description="Tuple of monitors used for storing fields, used internally for gradients.", ) - grad_eps_monitors: Tuple[PermittivityMonitor, ...] = pd.Field( + grad_eps_monitors: tuple[PermittivityMonitor, ...] = Field( (), title="Gradient Permittivity Monitors", description="Tuple of monitors used for storing epsilon, used internally for gradients.", ) - fwidth_adjoint: pd.PositiveFloat = pd.Field( + fwidth_adjoint: Optional[PositiveFloat] = Field( None, title="Adjoint Frequency Width", description="Custom frequency width to use for ``source_time`` of adjoint sources. " @@ -152,7 +156,7 @@ class JaxSimulation(Simulation, JaxObject): units=HERTZ, ) - run_time_adjoint: pd.PositiveFloat = pd.Field( + run_time_adjoint: Optional[PositiveFloat] = Field( None, title="Adjoint Run Time", description="Custom ``run_time`` to use for adjoint simulation. 
" @@ -160,8 +164,8 @@ class JaxSimulation(Simulation, JaxObject): units=SECOND, ) - @pd.validator("output_monitors", always=True) - def _output_monitors_colocate_false(cls, val): + @field_validator("output_monitors") + def _output_monitors_colocate_false(val): """Make sure server-side colocation is off.""" new_vals = [] for mnt in val: @@ -177,8 +181,8 @@ def _output_monitors_colocate_false(cls, val): new_vals.append(mnt) return new_vals - @pd.validator("subpixel", always=True) - def _subpixel_is_on(cls, val): + @field_validator("subpixel") + def _subpixel_is_on(val): """Assert dielectric subpixel is on.""" if (isinstance(val, SubpixelSpec) and isinstance(val.dielectric, Staircasing)) or not val: raise AdjointError( @@ -187,14 +191,13 @@ def _subpixel_is_on(cls, val): ) return val - @pd.validator("input_structures", always=True) - @skip_if_fields_missing(["structures"]) - def _warn_overlap(cls, val, values): + @model_validator(mode="after") + def _warn_overlap(self): """Print appropriate warning if structures intersect in ways that cause gradient error.""" - + val = self.input_structures input_structures = [s for s in val if "geometry" in s._differentiable_fields] - structures = list(values.get("structures")) + structures = list(self.structures) # if the center and size of all structure geometries do not contain all numbers, skip check for struct in input_structures: @@ -202,7 +205,7 @@ def _warn_overlap(cls, val, values): size_all_floats = all(isinstance(s, (float, int)) for s in geometry.bound_size) cent_all_floats = all(isinstance(c, (float, int)) for c in geometry.bound_center) if not (size_all_floats and cent_all_floats): - return val + return self with log as consolidated_logger: # check intersections with other input_structures @@ -231,10 +234,10 @@ def _warn_overlap(cls, val, values): "when 'JaxPolySlab' intersects with background structures." 
) - return val + return self - @pd.validator("output_monitors", always=True) - def _warn_if_colocate(cls, val): + @field_validator("output_monitors") + def _warn_if_colocate(val): """warn if any colocate=True in output FieldMonitors.""" for index, mnt in enumerate(val): if isinstance(mnt, FieldMonitor): @@ -247,8 +250,8 @@ def _warn_if_colocate(cls, val): return val return val - @pd.validator("medium", always=True) - def _warn_nonlinear_medium(cls, val): + @field_validator("medium") + def _warn_nonlinear_medium(val): """warn if the jax simulation medium is nonlinear.""" # hasattr is just an additional check to avoid unnecessary bugs # if a medium is encountered that does not support nonlinear spec, or things change. @@ -258,8 +261,8 @@ def _warn_nonlinear_medium(cls, val): ) return val - @pd.validator("structures", always=True) - def _warn_nonlinear_structure(cls, val): + @field_validator("structures") + def _warn_nonlinear_structure(val): """warn if a jax simulation structure.medium is nonlinear.""" for i, struct in enumerate(val): medium = struct.medium @@ -269,8 +272,8 @@ def _warn_nonlinear_structure(cls, val): log.warning(f"Nonlinear medium detected in structures[{i}]. 
" + NL_WARNING) return val - @pd.validator("input_structures", always=True) - def _warn_nonlinear_input_structure(cls, val): + @field_validator("input_structures") + def _warn_nonlinear_input_structure(val): """warn if a jax simulation input_structure.medium is nonlinear.""" for i, struct in enumerate(val): medium = struct.medium @@ -296,7 +299,7 @@ def _validate_web_adjoint(self) -> None: structure._validate_web_adjoint() @staticmethod - def get_freqs_adjoint(output_monitors: List[Monitor]) -> List[float]: + def get_freqs_adjoint(output_monitors: list[Monitor]) -> list[float]: """Return sorted list of unique frequencies stripped from a collection of monitors.""" if len(output_monitors) == 0: @@ -310,7 +313,7 @@ def get_freqs_adjoint(output_monitors: List[Monitor]) -> List[float]: return np.unique(output_freqs).tolist() @cached_property - def freqs_adjoint(self) -> List[float]: + def freqs_adjoint(self) -> list[float]: """Return sorted list of frequencies stripped from the output monitors.""" return self.get_freqs_adjoint(output_monitors=self.output_monitors) @@ -400,7 +403,7 @@ def num_time_steps_adjoint(self) -> int: """Number of time steps in the adjoint simulation.""" return len(self.tmesh_adjoint) - def to_simulation(self) -> Tuple[Simulation, JaxInfo]: + def to_simulation(self) -> tuple[Simulation, JaxInfo]: """Convert :class:`.JaxSimulation` instance to :class:`.Simulation` with an info dict.""" sim_dict = self.dict( @@ -448,11 +451,9 @@ def to_gds( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pd.NonNegativeFloat = 1, - frequency: pd.PositiveFloat = 0, - gds_layer_dtype_map: Dict[ - AbstractMedium, Tuple[pd.NonNegativeInt, pd.NonNegativeInt] - ] = None, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer_dtype_map: dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] = None, ) -> None: """Append the simulation structures to a .gds cell. 
Parameters @@ -489,12 +490,10 @@ def to_gdstk( x: float = None, y: float = None, z: float = None, - permittivity_threshold: pd.NonNegativeFloat = 1, - frequency: pd.PositiveFloat = 0, - gds_layer_dtype_map: Dict[ - AbstractMedium, Tuple[pd.NonNegativeInt, pd.NonNegativeInt] - ] = None, - ) -> List: + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer_dtype_map: dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] = None, + ) -> list: """Convert a simulation's planar slice to a .gds type polygon list. Parameters ---------- @@ -531,10 +530,8 @@ def to_gdspy( x: float = None, y: float = None, z: float = None, - gds_layer_dtype_map: Dict[ - AbstractMedium, Tuple[pd.NonNegativeInt, pd.NonNegativeInt] - ] = None, - ) -> List: + gds_layer_dtype_map: dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] = None, + ) -> list: """Convert a simulation's planar slice to a .gds type polygon list. Parameters ---------- @@ -562,8 +559,8 @@ def plot( ax: Ax = None, source_alpha: float = None, monitor_alpha: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, **patch_kwargs, ) -> Ax: """Wrapper around regular :class:`.Simulation` structure plotting.""" @@ -589,8 +586,8 @@ def plot_eps( alpha: float = None, source_alpha: float = None, monitor_alpha: float = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ax: Ax = None, ) -> Ax: """Wrapper around regular :class:`.Simulation` permittivity plotting.""" @@ -612,8 +609,8 @@ def plot_structures( y: float = None, z: float = None, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of simulation's structures on a plane defined by one nonzero x,y,z coordinate. 
@@ -627,9 +624,9 @@ def plot_structures( position of plane in z direction, only one of x, y, z must be specified to define plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -657,8 +654,8 @@ def plot_structures_eps( cbar: bool = True, reverse: bool = False, ax: Ax = None, - hlim: Tuple[float, float] = None, - vlim: Tuple[float, float] = None, + hlim: tuple[float, float] = None, + vlim: tuple[float, float] = None, ) -> Ax: """Plot each of simulation's structures on a plane defined by one nonzero x,y,z coordinate. The permittivity is plotted in grayscale based on its value at the specified frequency. @@ -684,9 +681,9 @@ def plot_structures_eps( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
Returns @@ -753,7 +750,7 @@ def __eq__(self, other: JaxSimulation) -> bool: return self.to_simulation()[0] == other.to_simulation()[0] @classmethod - def split_monitors(cls, monitors: List[Monitor], jax_info: JaxInfo) -> Dict[str, Monitor]: + def split_monitors(cls, monitors: list[Monitor], jax_info: JaxInfo) -> dict[str, Monitor]: """Split monitors into user and adjoint required based on jax info.""" all_monitors = list(monitors) @@ -785,8 +782,8 @@ def split_monitors(cls, monitors: List[Monitor], jax_info: JaxInfo) -> Dict[str, @classmethod def split_structures( - cls, structures: List[Structure], jax_info: JaxInfo - ) -> Dict[str, Structure]: + cls, structures: list[Structure], jax_info: JaxInfo + ) -> dict[str, Structure]: """Split structures into regular and input based on jax info.""" all_structures = list(structures) @@ -838,7 +835,7 @@ def from_simulation(cls, simulation: Simulation, jax_info: JaxInfo) -> JaxSimula return cls.parse_obj(sim_dict) @classmethod - def make_sim_fwd(cls, simulation: Simulation, jax_info: JaxInfo) -> Tuple[Simulation, JaxInfo]: + def make_sim_fwd(cls, simulation: Simulation, jax_info: JaxInfo) -> tuple[Simulation, JaxInfo]: """Make the forward :class:`.JaxSimulation` from the supplied :class:`.Simulation`.""" mnt_dict = JaxSimulation.split_monitors(monitors=simulation.monitors, jax_info=jax_info) @@ -871,7 +868,7 @@ def make_sim_fwd(cls, simulation: Simulation, jax_info: JaxInfo) -> Tuple[Simula return sim_fwd, jax_info - def to_simulation_fwd(self) -> Tuple[Simulation, JaxInfo, JaxInfo]: + def to_simulation_fwd(self) -> tuple[Simulation, JaxInfo, JaxInfo]: """Like ``to_simulation()`` but the gradient monitors are included.""" simulation, jax_info = self.to_simulation() sim_fwd, jax_info_fwd = self.make_sim_fwd(simulation=simulation, jax_info=jax_info) @@ -879,7 +876,7 @@ def to_simulation_fwd(self) -> Tuple[Simulation, JaxInfo, JaxInfo]: @staticmethod def get_grad_monitors( - input_structures: List[Structure], freqs_adjoint: 
List[float], include_eps_mnts: bool = True + input_structures: list[Structure], freqs_adjoint: list[float], include_eps_mnts: bool = True ) -> dict: """Return dictionary of gradient monitors for simulation.""" grad_mnts = [] @@ -916,9 +913,9 @@ def _store_vjp_structure( def store_vjp( self, - grad_data_fwd: Tuple[FieldData], - grad_data_adj: Tuple[FieldData], - grad_eps_data: Tuple[PermittivityData], + grad_data_fwd: tuple[FieldData], + grad_data_adj: tuple[FieldData], + grad_eps_data: tuple[PermittivityData], num_proc: int = NUM_PROC_LOCAL, ) -> JaxSimulation: """Store the vjp w.r.t. each input_structure as a sim using fwd and adj grad_data.""" @@ -939,9 +936,9 @@ def store_vjp( def store_vjp_sequential( self, - grad_data_fwd: Tuple[FieldData], - grad_data_adj: Tuple[FieldData], - grad_eps_data: Tuple[PermittivityData], + grad_data_fwd: tuple[FieldData], + grad_data_adj: tuple[FieldData], + grad_eps_data: tuple[PermittivityData], ) -> JaxSimulation: """Store the vjp w.r.t. each input_structure without multiprocessing.""" map_args = [self.input_structures, grad_data_fwd, grad_data_adj, grad_eps_data] @@ -955,9 +952,9 @@ def store_vjp_sequential( def store_vjp_parallel( self, - grad_data_fwd: Tuple[FieldData], - grad_data_adj: Tuple[FieldData], - grad_eps_data: Tuple[PermittivityData], + grad_data_fwd: tuple[FieldData], + grad_data_adj: tuple[FieldData], + grad_eps_data: tuple[PermittivityData], num_proc: int, ) -> JaxSimulation: """Store the vjp w.r.t. 
each input_structure as a sim using fwd and adj grad_data, and diff --git a/tidy3d/plugins/adjoint/components/structure.py b/tidy3d/plugins/adjoint/components/structure.py index d562816722..51b708f9cd 100644 --- a/tidy3d/plugins/adjoint/components/structure.py +++ b/tidy3d/plugins/adjoint/components/structure.py @@ -2,19 +2,20 @@ from __future__ import annotations -from typing import Dict, List, Union +from typing import Union import numpy as np -import pydantic.v1 as pd from jax.tree_util import register_pytree_node_class +from pydantic import Field, field_validator + +from tidy3d.components.data.monitor_data import FieldData, PermittivityData +from tidy3d.components.geometry.utils import GeometryType +from tidy3d.components.medium import MediumType +from tidy3d.components.monitor import FieldMonitor +from tidy3d.components.structure import Structure +from tidy3d.components.types import TYPE_TAG_STR, Bound +from tidy3d.constants import C_0 -from ....components.data.monitor_data import FieldData, PermittivityData -from ....components.geometry.utils import GeometryType -from ....components.medium import MediumType -from ....components.monitor import FieldMonitor -from ....components.structure import Structure -from ....components.types import TYPE_TAG_STR, Bound -from ....constants import C_0 from .base import JaxObject from .geometry import JAX_GEOMETRY_MAP, JaxBox, JaxGeometryType from .medium import JAX_MEDIUM_MAP, JaxMediumType @@ -33,8 +34,8 @@ class AbstractJaxStructure(Structure, JaxObject): geometry: Union[JaxGeometryType, GeometryType] medium: Union[JaxMediumType, MediumType] - @pd.validator("medium", always=True) - def _check_2d_geometry(cls, val, values): + @field_validator("medium") + def _check_2d_geometry(val): """Override validator checking 2D geometry, which triggers unnecessarily for gradients.""" return val @@ -83,7 +84,7 @@ def from_structure(cls, structure: Structure) -> JaxStructure: return cls.parse_obj(struct_dict) - def 
make_grad_monitors(self, freqs: List[float], name: str) -> FieldMonitor: + def make_grad_monitors(self, freqs: list[float], name: str) -> FieldMonitor: """Return gradient monitor associated with this object.""" if "geometry" not in self._differentiable_fields: # make a fake JaxBox to be able to call .make_grad_monitors @@ -96,7 +97,7 @@ def make_grad_monitors(self, freqs: List[float], name: str) -> FieldMonitor: def _get_medium_params( self, grad_data_eps: PermittivityData, - ) -> Dict[str, float]: + ) -> dict[str, float]: """Compute params in the material of this structure.""" freq_max = float(max(grad_data_eps.eps_xx.f)) eps_in = self.medium.eps_model(frequency=freq_max) @@ -151,7 +152,7 @@ def medium_vjp( def store_vjp( self, - # field_keys: List[Literal["medium", "geometry"]], + # field_keys: list[Literal["medium", "geometry"]], grad_data_fwd: FieldData, grad_data_adj: FieldData, grad_data_eps: PermittivityData, @@ -193,16 +194,14 @@ def store_vjp( class JaxStructure(AbstractJaxStructure, JaxObject): """A :class:`.Structure` registered with jax.""" - geometry: JaxGeometryType = pd.Field( - ..., + geometry: JaxGeometryType = Field( title="Geometry", description="Geometry of the structure, which is jax-compatible.", jax_field=True, discriminator=TYPE_TAG_STR, ) - medium: JaxMediumType = pd.Field( - ..., + medium: JaxMediumType = Field( title="Medium", description="Medium of the structure, which is jax-compatible.", jax_field=True, @@ -216,15 +215,14 @@ class JaxStructure(AbstractJaxStructure, JaxObject): class JaxStructureStaticMedium(AbstractJaxStructure, JaxObject): """A :class:`.Structure` registered with jax.""" - geometry: JaxGeometryType = pd.Field( - ..., + geometry: JaxGeometryType = Field( title="Geometry", description="Geometry of the structure, which is jax-compatible.", jax_field=True, discriminator=TYPE_TAG_STR, ) - medium: MediumType = pd.Field( + medium: MediumType = Field( ..., title="Medium", description="Regular ``tidy3d`` medium of the 
structure, non differentiable. " @@ -240,8 +238,7 @@ class JaxStructureStaticMedium(AbstractJaxStructure, JaxObject): class JaxStructureStaticGeometry(AbstractJaxStructure, JaxObject): """A :class:`.Structure` registered with jax.""" - geometry: GeometryType = pd.Field( - ..., + geometry: GeometryType = Field( title="Geometry", description="Regular ``tidy3d`` geometry of the structure, non differentiable. " "Supports angled sidewalls and other complex geometries.", @@ -249,8 +246,7 @@ class JaxStructureStaticGeometry(AbstractJaxStructure, JaxObject): discriminator=TYPE_TAG_STR, ) - medium: JaxMediumType = pd.Field( - ..., + medium: JaxMediumType = Field( title="Medium", description="Medium of the structure, which is jax-compatible.", jax_field=True, diff --git a/tidy3d/plugins/adjoint/components/types.py b/tidy3d/plugins/adjoint/components/types.py index fb1f01b580..a980b5ee02 100644 --- a/tidy3d/plugins/adjoint/components/types.py +++ b/tidy3d/plugins/adjoint/components/types.py @@ -3,6 +3,7 @@ from typing import Any, Union import numpy as np +from pydantic_core import core_schema from tidy3d.components.type_util import _add_schema @@ -29,14 +30,16 @@ class NumpyArrayType(np.ndarray): """Subclass of ``np.ndarray`` with a schema defined for pydantic.""" @classmethod - def __modify_schema__(cls, field_schema): - """Sets the schema of np.ndarray object.""" - - schema = dict( - title="npdarray", - type="numpy.ndarray", - ) - field_schema.update(schema) + def __get_pydantic_core_schema__(cls, source, handler): + return core_schema.no_info_plain_validator_function(lambda v, _: np.asarray(v)) + + @classmethod + def __get_pydantic_json_schema__(cls, core_schema, handler): + return { + "title": "npdarray", + "type": "numpy.ndarray", + "items": {}, + } _add_schema(JaxArrayType, title="JaxArray", field_type_str="jax.numpy.ndarray") diff --git a/tidy3d/plugins/adjoint/utils/filter.py b/tidy3d/plugins/adjoint/utils/filter.py index 139aa5678b..922a9e7944 100644 --- 
a/tidy3d/plugins/adjoint/utils/filter.py +++ b/tidy3d/plugins/adjoint/utils/filter.py @@ -5,11 +5,12 @@ import jax.numpy as jnp import jax.scipy as jsp import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator -from ....components.base import Tidy3dBaseModel -from ....constants import MICROMETER -from ....log import log +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.constants import MICROMETER +from tidy3d.exceptions import ValidationError +from tidy3d.log import log class Filter(Tidy3dBaseModel, ABC): @@ -23,7 +24,7 @@ def evaluate(self, spatial_data: jnp.array) -> jnp.array: class AbstractCircularFilter(Filter, ABC): """Abstract circular filter class. Initializes with parameters and .evaluate() on a design.""" - radius: float = pd.Field( + radius: float = Field( ..., title="Filter Radius", description="Radius of the filter to convolve with supplied spatial data. " @@ -33,7 +34,7 @@ class AbstractCircularFilter(Filter, ABC): units=MICROMETER, ) - design_region_dl: float = pd.Field( + design_region_dl: float = Field( ..., title="Grid Size in Design Region", description="Grid size in the design region. " @@ -46,16 +47,16 @@ def filter_radius_pixels(self) -> int: """Filter radius in pixels.""" return np.ceil(self.radius / self.design_region_dl) - @pd.root_validator(pre=True) - def _deprecate_feature_size(cls, values): + @model_validator(mode="before") + def _deprecate_feature_size(data): """Extra warning for user using ``feature_size`` field.""" - if "feature_size" in values: - raise pd.ValidationError( + if "feature_size" in data: + raise ValidationError( "The 'feature_size' field of circular filters available in 2.4 pre-releases was " "renamed to 'radius' for the official 2.4.0 release. " "If you're seeing this message, please change your script to use that field name." 
) - return values + return data @abstractmethod def make_kernel(self, coords_rad: jnp.array) -> jnp.array: @@ -176,11 +177,17 @@ class BinaryProjector(Filter): """ - vmin: float = pd.Field(..., title="Min Value", description="Minimum value to project to.") + vmin: float = Field( + title="Min Value", + description="Minimum value to project to.", + ) - vmax: float = pd.Field(..., title="Max Value", description="Maximum value to project to.") + vmax: float = Field( + title="Max Value", + description="Maximum value to project to.", + ) - beta: float = pd.Field( + beta: float = Field( 1.0, title="Beta", description="Steepness of the binarization, " @@ -189,9 +196,9 @@ class BinaryProjector(Filter): "Can be useful to ramp up in a scheduled way during optimization.", ) - eta: float = pd.Field(0.5, title="Eta", description="Halfway point in projection function.") + eta: float = Field(0.5, title="Eta", description="Halfway point in projection function.") - strict_binarize: bool = pd.Field( + strict_binarize: bool = Field( False, title="Binarize strictly", description="If ``False``, the binarization is still continuous between min and max. 
" diff --git a/tidy3d/plugins/adjoint/utils/penalty.py b/tidy3d/plugins/adjoint/utils/penalty.py index ae4892a3bf..014a3866dc 100644 --- a/tidy3d/plugins/adjoint/utils/penalty.py +++ b/tidy3d/plugins/adjoint/utils/penalty.py @@ -3,12 +3,13 @@ from abc import ABC, abstractmethod import jax.numpy as jnp -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat + +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.types import ArrayFloat2D +from tidy3d.constants import MICROMETER +from tidy3d.log import log -from ....components.base import Tidy3dBaseModel -from ....components.types import ArrayFloat2D -from ....constants import MICROMETER -from ....log import log from .filter import BinaryProjector, ConicFilter # Radius of Curvature Calculation @@ -48,27 +49,27 @@ class RadiusPenalty(Penalty): """ - min_radius: float = pd.Field( + min_radius: float = Field( 0.150, title="Minimum Radius", description="Radius of curvature value below which the penalty ramps to its maximum value.", units=MICROMETER, ) - alpha: float = pd.Field( + alpha: float = Field( 1.0, title="Alpha", description="Parameter controlling the strength of the penalty.", ) - kappa: float = pd.Field( + kappa: float = Field( 10.0, title="Kappa", description="Parameter controlling the steepness of the penalty evaluation.", units="1/" + MICROMETER, ) - wrap: bool = pd.Field( + wrap: bool = Field( False, title="Wrap", description="Whether to consider the first set of points as connected to the last.", @@ -180,8 +181,7 @@ class ErosionDilationPenalty(Penalty): """ - length_scale: pd.NonNegativeFloat = pd.Field( - ..., + length_scale: NonNegativeFloat = Field( title="Length Scale", description="Length scale of erosion and dilation. " "Corresponds to ``radius`` in the :class:`ConicFilter` used for filtering. 
" @@ -190,15 +190,14 @@ class ErosionDilationPenalty(Penalty): units=MICROMETER, ) - pixel_size: pd.PositiveFloat = pd.Field( - ..., + pixel_size: PositiveFloat = Field( title="Pixel Size", description="Size of each pixel in the array (must be the same along all dimensions). " "Corresponds to ``design_region_dl`` in the :class:`ConicFilter` used for filtering.", units=MICROMETER, ) - beta: pd.PositiveFloat = pd.Field( + beta: PositiveFloat = Field( 100.0, title="Projection Beta", description="Strength of the ``tanh`` projection. " @@ -206,7 +205,7 @@ class ErosionDilationPenalty(Penalty): "Higher values correspond to stronger discretization.", ) - eta0: pd.PositiveFloat = pd.Field( + eta0: PositiveFloat = Field( 0.5, title="Projection Midpoint", description="Value between 0 and 1 that sets the projection midpoint. In other words, " @@ -214,7 +213,7 @@ class ErosionDilationPenalty(Penalty): "Corresponds to ``eta`` in the :class:`BinaryProjector`.", ) - delta_eta: pd.PositiveFloat = pd.Field( + delta_eta: PositiveFloat = Field( 0.01, title="Delta Eta Cutoff", description="The binarization threshold for erosion and dilation operations " diff --git a/tidy3d/plugins/adjoint/web.py b/tidy3d/plugins/adjoint/web.py index f1031ed211..c58f7bd010 100644 --- a/tidy3d/plugins/adjoint/web.py +++ b/tidy3d/plugins/adjoint/web.py @@ -3,22 +3,21 @@ import os import tempfile from functools import partial -from typing import Dict, List, Tuple +from typing import Literal, Optional -import pydantic.v1 as pd from jax import custom_vjp from jax.tree_util import register_pytree_node_class +from pydantic import Field import tidy3d as td +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.simulation import Simulation from tidy3d.web.api.asynchronous import run_async as web_run_async +from tidy3d.web.api.container import DEFAULT_DATA_DIR, Batch, BatchData, Job from tidy3d.web.api.webapi import run as web_run from tidy3d.web.api.webapi import wait_for_connection 
from tidy3d.web.core.s3utils import download_file, upload_file -from ...components.data.sim_data import SimulationData -from ...components.simulation import Simulation -from ...components.types import Literal -from ...web.api.container import DEFAULT_DATA_DIR, Batch, BatchData, Job from .components.base import JaxObject from .components.data.sim_data import JaxSimulationData from .components.simulation import NUM_PROC_LOCAL, JaxInfo, JaxSimulation @@ -32,8 +31,9 @@ class RunResidual(JaxObject): """Class to store extra data needed to pass between the forward and backward adjoint run.""" - fwd_task_id: str = pd.Field( - ..., title="Forward task_id", description="task_id of the forward simulation." + fwd_task_id: str = Field( + title="Forward task_id", + description="task_id of the forward simulation.", ) @@ -41,8 +41,9 @@ class RunResidual(JaxObject): class RunResidualBatch(JaxObject): """Class to store extra data needed to pass between the forward and backward adjoint run.""" - fwd_task_ids: Tuple[str, ...] = pd.Field( - ..., title="Forward task_ids", description="task_ids of the forward simulations." + fwd_task_ids: tuple[str, ...] = Field( + title="Forward task_ids", + description="task_ids of the forward simulations.", ) @@ -50,8 +51,9 @@ class RunResidualBatch(JaxObject): class RunResidualAsync(JaxObject): """Class to store extra data needed to pass between the forward and backward adjoint run.""" - fwd_task_ids: Dict[str, str] = pd.Field( - ..., title="Forward task_ids", description="task_ids of the forward simulation for async." 
+ fwd_task_ids: dict[str, str] = Field( + title="Forward task_ids", + description="task_ids of the forward simulation for async.", ) @@ -70,7 +72,7 @@ def tidy3d_run_fn(simulation: Simulation, task_name: str, **kwargs) -> Simulatio return web_run(simulation=simulation, task_name=task_name, **kwargs) -def tidy3d_run_async_fn(simulations: Dict[str, Simulation], **kwargs) -> BatchData: +def tidy3d_run_async_fn(simulations: dict[str, Simulation], **kwargs) -> BatchData: """Run a set of regular :class:`.Simulation` objects after conversion from jax type.""" return web_run_async(simulations=simulations, **kwargs) @@ -158,7 +160,7 @@ def run_fwd( path: str, callback_url: str, verbose: bool, -) -> Tuple[JaxSimulationData, Tuple[RunResidual]]: +) -> tuple[JaxSimulationData, tuple[RunResidual]]: """Run forward pass and stash extra objects for the backwards pass.""" simulation._validate_web_adjoint() @@ -191,7 +193,7 @@ def run_bwd( verbose: bool, res: tuple, sim_data_vjp: JaxSimulationData, -) -> Tuple[JaxSimulation]: +) -> tuple[JaxSimulation]: """Run backward pass and return simulation storing vjp of the objective w.r.t. 
the sim.""" fwd_task_id = res[0].fwd_task_id @@ -259,13 +261,13 @@ def download_sim_vjp(task_id: str, verbose: bool) -> JaxSimulation: class AdjointJob(Job): """Job that uploads a jax_info object and also includes new fields for adjoint tasks.""" - simulation_type: AdjointSimulationType = pd.Field( + simulation_type: AdjointSimulationType = Field( "tidy3d", title="Simulation Type", description="Type of simulation, used internally only.", ) - jax_info: JaxInfo = pd.Field( + jax_info: Optional[JaxInfo] = Field( None, title="Jax Info", description="Container of information needed to reconstruct jax simulation.", @@ -286,19 +288,18 @@ def start(self) -> None: class AdjointBatch(Batch): """Batch that uploads a jax_info object and also includes new fields for adjoint tasks.""" - simulation_type: AdjointSimulationType = pd.Field( + simulation_type: AdjointSimulationType = Field( "tidy3d", title="Simulation Type", description="Type of simulation, used internally only.", ) - jax_infos: Dict[str, JaxInfo] = pd.Field( - ..., + jax_infos: dict[str, JaxInfo] = Field( title="Jax Info Dict", description="Containers of information needed to reconstruct JaxSimulation for each item.", ) - jobs_cached: Dict[str, AdjointJob] = pd.Field( + jobs_cached: Optional[dict[str, AdjointJob]] = Field( None, title="Jobs (Cached)", description="Optional field to specify ``jobs``. 
Only used as a workaround internally " @@ -330,7 +331,7 @@ def webapi_run_adjoint_fwd( path: str, callback_url: str, verbose: bool, -) -> Dict[str, float]: +) -> dict[str, float]: """Runs the forward simulation on our servers, stores the gradient data for later.""" job = AdjointJob( @@ -390,20 +391,20 @@ def _task_name_orig(index: int): @partial(custom_vjp, nondiff_argnums=tuple(range(1, 6))) def run_async( - simulations: Tuple[JaxSimulation, ...], + simulations: tuple[JaxSimulation, ...], folder_name: str = "default", path_dir: str = DEFAULT_DATA_DIR, callback_url: str = None, verbose: bool = True, num_workers: int = None, -) -> Tuple[JaxSimulationData, ...]: +) -> tuple[JaxSimulationData, ...]: """Submits a set of :class:`.JaxSimulation` objects to server, starts running, monitors progress, downloads, and loads results as a tuple of :class:`.JaxSimulationData` objects. Parameters ---------- - simulations : Tuple[:class:`.JaxSimulation`, ...] + simulations : tuple[:class:`.JaxSimulation`, ...] Collection of :class:`.JaxSimulations` to run asynchronously. folder_name : str = "default" Name of folder to store each task on web UI. @@ -424,7 +425,7 @@ def run_async( Returns ------ - Tuple[:class:`.JaxSimulationData`, ...] + tuple[:class:`.JaxSimulationData`, ...] Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`. 
""" @@ -464,13 +465,13 @@ def run_async( def run_async_fwd( - simulations: Tuple[JaxSimulation, ...], + simulations: tuple[JaxSimulation, ...], folder_name: str, path_dir: str, callback_url: str, verbose: bool, num_workers: int, -) -> Tuple[Tuple[JaxSimulationData, ...], RunResidualBatch]: +) -> tuple[tuple[JaxSimulationData, ...], RunResidualBatch]: """Run forward pass and stash extra objects for the backwards pass.""" for simulation in simulations: @@ -514,8 +515,8 @@ def run_async_bwd( verbose: bool, num_workers: int, res: tuple, - batch_data_vjp: Tuple[JaxSimulationData, ...], -) -> Tuple[Dict[str, JaxSimulation]]: + batch_data_vjp: tuple[JaxSimulationData, ...], +) -> tuple[dict[str, JaxSimulation]]: """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim.""" fwd_task_ids = res[0].fwd_task_ids @@ -553,13 +554,13 @@ def run_async_bwd( def webapi_run_async_adjoint_fwd( - simulations: Tuple[Simulation, ...], - jax_infos: Tuple[JaxInfo, ...], + simulations: tuple[Simulation, ...], + jax_infos: tuple[JaxInfo, ...], folder_name: str, path_dir: str, callback_url: str, verbose: bool, -) -> Tuple[BatchData, Dict[str, str]]: +) -> tuple[BatchData, dict[str, str]]: """Runs the forward simulations on our servers, stores the gradient data for later.""" task_names = [str(_task_name_orig(i)) for i in range(len(simulations))] @@ -581,14 +582,14 @@ def webapi_run_async_adjoint_fwd( def webapi_run_async_adjoint_bwd( - simulations: Tuple[Simulation, ...], - jax_infos: Tuple[JaxInfo, ...], + simulations: tuple[Simulation, ...], + jax_infos: tuple[JaxInfo, ...], folder_name: str, path_dir: str, callback_url: str, verbose: bool, - parent_tasks: List[List[str]], -) -> List[JaxSimulation]: + parent_tasks: list[list[str]], +) -> list[JaxSimulation]: """Runs the forward simulations on our servers, stores the gradient data for later.""" task_names = [str(i) for i in range(len(simulations))] @@ -690,7 +691,7 @@ def run_local_fwd( callback_url: str, verbose: 
bool, num_proc: int, -) -> Tuple[JaxSimulationData, tuple]: +) -> tuple[JaxSimulationData, tuple]: """Run forward pass and stash extra objects for the backwards pass.""" # add the gradient monitors and run the forward simulation @@ -721,7 +722,7 @@ def run_local_bwd( num_proc: int, res: tuple, sim_data_vjp: JaxSimulationData, -) -> Tuple[JaxSimulation]: +) -> tuple[JaxSimulation]: """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim.""" # grab the forward simulation and its gradient monitor data @@ -773,14 +774,14 @@ def _task_name_orig_local(index: int, task_name_suffix: str = None): @partial(custom_vjp, nondiff_argnums=tuple(range(1, 7))) def run_async_local( - simulations: Tuple[JaxSimulation, ...], + simulations: tuple[JaxSimulation, ...], folder_name: str = "default", path_dir: str = DEFAULT_DATA_DIR, callback_url: str = None, verbose: bool = True, num_workers: int = None, task_name_suffix: str = None, -) -> Tuple[JaxSimulationData, ...]: +) -> tuple[JaxSimulationData, ...]: """Submits a set of :class:`.JaxSimulation` objects to server, starts running, monitors progress, downloads, and loads results as a tuple of :class:`.JaxSimulationData` objects. @@ -788,7 +789,7 @@ def run_async_local( Parameters ---------- - simulations : Tuple[:class:`.JaxSimulation`, ...] + simulations : tuple[:class:`.JaxSimulation`, ...] Collection of :class:`.JaxSimulations` to run asynchronously. folder_name : str = "default" Name of folder to store each task on web UI. @@ -808,7 +809,7 @@ def run_async_local( Returns ------ - Tuple[:class:`.JaxSimulationData`, ...] + tuple[:class:`.JaxSimulationData`, ...] Contains the :class:`.JaxSimulationData` of each :class:`.JaxSimulation`. 
""" @@ -847,14 +848,14 @@ def run_async_local( def run_async_local_fwd( - simulations: Tuple[JaxSimulation, ...], + simulations: tuple[JaxSimulation, ...], folder_name: str, path_dir: str, callback_url: str, verbose: bool, num_workers: int, task_name_suffix: str, -) -> Tuple[Dict[str, JaxSimulationData], tuple]: +) -> tuple[dict[str, JaxSimulationData], tuple]: """Run forward pass and stash extra objects for the backwards pass.""" task_name_suffix_fwd = _task_name_fwd("") @@ -895,8 +896,8 @@ def run_async_local_bwd( num_workers: int, task_name_suffix: str, res: tuple, - batch_data_vjp: Tuple[JaxSimulationData, ...], -) -> Tuple[Dict[str, JaxSimulation]]: + batch_data_vjp: tuple[JaxSimulationData, ...], +) -> tuple[dict[str, JaxSimulation]]: """Run backward pass and return simulation storing vjp of the objective w.r.t. the sim.""" # grab the forward simulation and its gradient monitor data diff --git a/tidy3d/plugins/autograd/functions.py b/tidy3d/plugins/autograd/functions.py index 0cdb3c952b..6516dda668 100644 --- a/tidy3d/plugins/autograd/functions.py +++ b/tidy3d/plugins/autograd/functions.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, List, Literal, Tuple, Union +from typing import Callable, Iterable, Literal, Union import autograd.numpy as np from autograd import jacobian @@ -30,14 +30,14 @@ ] -def _pad_indices(n: int, pad_width: Tuple[int, int], *, mode: PaddingType) -> NDArray: +def _pad_indices(n: int, pad_width: tuple[int, int], *, mode: PaddingType) -> NDArray: """Compute the indices to pad an array along a single axis based on the padding mode. Parameters ---------- n : int The size of the axis to pad. - pad_width : Tuple[int, int] + pad_width : tuple[int, int] The number of values padded to the edges of the axis. mode : PaddingType The padding mode to use. 
@@ -78,7 +78,7 @@ def _pad_indices(n: int, pad_width: Tuple[int, int], *, mode: PaddingType) -> ND def _pad_axis( array: NDArray, - pad_width: Tuple[int, int], + pad_width: tuple[int, int], axis: int, *, mode: PaddingType = "constant", @@ -90,7 +90,7 @@ def _pad_axis( ---------- array : np.ndarray The input array to pad. - pad_width : Tuple[int, int] + pad_width : tuple[int, int] The number of values padded to the edges of the axis. axis : int The axis along which to pad. @@ -117,7 +117,7 @@ def _pad_axis( def pad( array: NDArray, - pad_width: Union[int, Tuple[int, int]], + pad_width: Union[int, tuple[int, int]], *, mode: PaddingType = "constant", axis: Union[int, Iterable[int], None] = None, @@ -129,7 +129,7 @@ def pad( ---------- array : np.ndarray The input array to pad. - pad_width : Union[int, Tuple[int, int]] + pad_width : Union[int, tuple[int, int]] The number of values padded to the edges of each axis. If an integer is provided, it is used for both the left and right sides. If a tuple is provided, it specifies the padding for the left and right sides respectively. @@ -188,7 +188,7 @@ def convolve( kernel: NDArray, *, padding: PaddingType = "constant", - axes: Union[Tuple[List[int], List[int]], None] = None, + axes: Union[tuple[list[int], list[int]], None] = None, mode: Literal["full", "valid", "same"] = "same", ) -> NDArray: """Convolve an array with a given kernel. @@ -201,7 +201,7 @@ def convolve( The kernel to convolve with the input array. All dimensions of the kernel must be odd. padding : PaddingType = "constant" The padding mode to use. - axes : Union[Tuple[List[int], List[int]], None] = None + axes : Union[tuple[list[int], list[int]], None] = None The axes along which to perform the convolution. mode : Literal["full", "valid", "same"] = "same" The convolution mode. 
@@ -237,7 +237,7 @@ def convolve( def grey_dilation( array: NDArray, - size: Union[Union[int, Tuple[int, int]], None] = None, + size: Union[Union[int, tuple[int, int]], None] = None, structure: Union[NDArray, None] = None, *, mode: PaddingType = "reflect", @@ -249,7 +249,7 @@ def grey_dilation( ---------- array : np.ndarray The input array to perform grey dilation on. - size : Union[Union[int, Tuple[int, int]], None] = None + size : Union[Union[int, tuple[int, int]], None] = None The size of the structuring element. If None, `structure` must be provided. structure : Union[np.ndarray, None] = None The structuring element. If None, `size` must be provided. @@ -291,7 +291,7 @@ def grey_dilation( def grey_erosion( array: NDArray, - size: Union[Union[int, Tuple[int, int]], None] = None, + size: Union[Union[int, tuple[int, int]], None] = None, structure: Union[NDArray, None] = None, *, mode: PaddingType = "reflect", @@ -303,7 +303,7 @@ def grey_erosion( ---------- array : np.ndarray The input array to perform grey dilation on. - size : Union[Union[int, Tuple[int, int]], None] = None + size : Union[Union[int, tuple[int, int]], None] = None The size of the structuring element. If None, `structure` must be provided. structure : Union[np.ndarray, None] = None The structuring element. If None, `size` must be provided. @@ -345,7 +345,7 @@ def grey_erosion( def grey_opening( array: NDArray, - size: Union[Union[int, Tuple[int, int]], None] = None, + size: Union[Union[int, tuple[int, int]], None] = None, structure: Union[NDArray, None] = None, *, mode: PaddingType = "reflect", @@ -357,7 +357,7 @@ def grey_opening( ---------- array : np.ndarray The input array to perform grey opening on. - size : Union[Union[int, Tuple[int, int]], None] = None + size : Union[Union[int, tuple[int, int]], None] = None The size of the structuring element. If None, `structure` must be provided. structure : Union[np.ndarray, None] = None The structuring element. If None, `size` must be provided. 
@@ -378,7 +378,7 @@ def grey_opening( def grey_closing( array: NDArray, - size: Union[Union[int, Tuple[int, int]], None] = None, + size: Union[Union[int, tuple[int, int]], None] = None, structure: Union[NDArray, None] = None, *, mode: PaddingType = "reflect", @@ -390,7 +390,7 @@ def grey_closing( ---------- array : np.ndarray The input array to perform grey closing on. - size : Union[Union[int, Tuple[int, int]], None] = None + size : Union[Union[int, tuple[int, int]], None] = None The size of the structuring element. If None, `structure` must be provided. structure : Union[np.ndarray, None] = None The structuring element. If None, `size` must be provided. @@ -411,7 +411,7 @@ def grey_closing( def morphological_gradient( array: NDArray, - size: Union[Union[int, Tuple[int, int]], None] = None, + size: Union[Union[int, tuple[int, int]], None] = None, structure: Union[NDArray, None] = None, *, mode: PaddingType = "reflect", @@ -423,7 +423,7 @@ def morphological_gradient( ---------- array : np.ndarray The input array to compute the morphological gradient of. - size : Union[Union[int, Tuple[int, int]], None] = None + size : Union[Union[int, tuple[int, int]], None] = None The size of the structuring element. If None, `structure` must be provided. structure : Union[np.ndarray, None] = None The structuring element. If None, `size` must be provided. @@ -444,7 +444,7 @@ def morphological_gradient( def morphological_gradient_internal( array: NDArray, - size: Union[Union[int, Tuple[int, int]], None] = None, + size: Union[Union[int, tuple[int, int]], None] = None, structure: Union[NDArray, None] = None, *, mode: PaddingType = "reflect", @@ -456,7 +456,7 @@ def morphological_gradient_internal( ---------- array : np.ndarray The input array to compute the internal morphological gradient of. - size : Union[Union[int, Tuple[int, int]], None] = None + size : Union[Union[int, tuple[int, int]], None] = None The size of the structuring element. If None, `structure` must be provided. 
structure : Union[np.ndarray, None] = None The structuring element. If None, `size` must be provided. @@ -475,7 +475,7 @@ def morphological_gradient_internal( def morphological_gradient_external( array: NDArray, - size: Union[Union[int, Tuple[int, int]], None] = None, + size: Union[Union[int, tuple[int, int]], None] = None, structure: Union[NDArray, None] = None, *, mode: PaddingType = "reflect", @@ -487,7 +487,7 @@ def morphological_gradient_external( ---------- array : np.ndarray The input array to compute the external morphological gradient of. - size : Union[Union[int, Tuple[int, int]], None] = None + size : Union[Union[int, tuple[int, int]], None] = None The size of the structuring element. If None, `structure` must be provided. structure : Union[np.ndarray, None] = None The structuring element. If None, `size` must be provided. @@ -581,7 +581,7 @@ def threshold( def smooth_max( - x: NDArray, tau: float = 1.0, axis: Union[int, Tuple[int, ...], None] = None + x: NDArray, tau: float = 1.0, axis: Union[int, tuple[int, ...], None] = None ) -> float: """Compute the smooth maximum of an array using temperature parameter tau. @@ -591,7 +591,7 @@ def smooth_max( Input array. tau : float = 1.0 Temperature parameter controlling smoothness. Larger values make the maximum smoother. - axis : Union[int, Tuple[int, ...], None] = None + axis : Union[int, tuple[int, ...], None] = None Axis or axes over which the smooth maximum is computed. By default, the smooth maximum is computed over the entire array. Returns @@ -603,7 +603,7 @@ def smooth_max( def smooth_min( - x: NDArray, tau: float = 1.0, axis: Union[int, Tuple[int, ...], None] = None + x: NDArray, tau: float = 1.0, axis: Union[int, tuple[int, ...], None] = None ) -> float: """Compute the smooth minimum of an array using temperature parameter tau. @@ -613,7 +613,7 @@ def smooth_min( Input array. tau : float = 1.0 Temperature parameter controlling smoothness. Larger values make the minimum smoother. 
- axis : Union[int, Tuple[int, ...], None] = None + axis : Union[int, tuple[int, ...], None] = None Axis or axes over which the smooth minimum is computed. By default, the smooth minimum is computed over the entire array. Returns @@ -628,7 +628,7 @@ def least_squares( func: Callable[[NDArray, float], NDArray], x: NDArray, y: NDArray, - initial_guess: Tuple[float, ...], + initial_guess: tuple[float, ...], max_iterations: int = 100, tol: float = 1e-6, ) -> NDArray: @@ -643,7 +643,7 @@ def least_squares( Independent variable data. y : np.ndarray Dependent variable data. - initial_guess : Tuple[float, ...] + initial_guess : tuple[float, ...] Initial guess for the parameters to be optimized. max_iterations : int = 100 Maximum number of iterations for the optimization process. diff --git a/tidy3d/plugins/autograd/invdes/filters.py b/tidy3d/plugins/autograd/invdes/filters.py index 2bdf45e2bf..310a7b28f7 100644 --- a/tidy3d/plugins/autograd/invdes/filters.py +++ b/tidy3d/plugins/autograd/invdes/filters.py @@ -2,11 +2,11 @@ import abc from functools import lru_cache, partial -from typing import Annotated, Callable, Iterable, Tuple, Union +from typing import Annotated, Callable, Iterable, Union import numpy as np -import pydantic.v1 as pd from numpy.typing import NDArray +from pydantic import Field, PositiveInt import tidy3d as td from tidy3d.components.base import Tidy3dBaseModel @@ -20,27 +20,32 @@ class AbstractFilter(Tidy3dBaseModel, abc.ABC): """An abstract class for creating and applying convolution filters.""" - kernel_size: Union[pd.PositiveInt, Tuple[pd.PositiveInt, ...]] = pd.Field( - ..., title="Kernel Size", description="Size of the kernel in pixels for each dimension." + kernel_size: Union[PositiveInt, tuple[PositiveInt, ...]] = Field( + title="Kernel Size", + description="Size of the kernel in pixels for each dimension.", ) - normalize: bool = pd.Field( - True, title="Normalize", description="Whether to normalize the kernel so that it sums to 1." 
+ normalize: bool = Field( + True, + title="Normalize", + description="Whether to normalize the kernel so that it sums to 1.", ) - padding: PaddingType = pd.Field( - "reflect", title="Padding", description="The padding mode to use." + padding: PaddingType = Field( + "reflect", + title="Padding", + description="The padding mode to use.", ) @classmethod def from_radius_dl( - cls, radius: Union[float, Tuple[float, ...]], dl: Union[float, Tuple[float, ...]], **kwargs + cls, radius: Union[float, tuple[float, ...]], dl: Union[float, tuple[float, ...]], **kwargs ) -> AbstractFilter: """Create a filter from radius and grid spacing. Parameters ---------- - radius : Union[float, Tuple[float, ...]] + radius : Union[float, tuple[float, ...]] The radius of the kernel. Can be a scalar or a tuple. - dl : Union[float, Tuple[float, ...]] + dl : Union[float, tuple[float, ...]] The grid spacing. Can be a scalar or a tuple. **kwargs Additional keyword arguments to pass to the filter constructor. @@ -125,24 +130,24 @@ def get_kernel(size_px: Iterable[int], normalize: bool) -> NDArray: def _get_kernel_size( - radius: Union[float, Tuple[float, ...]], - dl: Union[float, Tuple[float, ...]], - size_px: Union[int, Tuple[int, ...]], -) -> Tuple[int, ...]: + radius: Union[float, tuple[float, ...]], + dl: Union[float, tuple[float, ...]], + size_px: Union[int, tuple[int, ...]], +) -> tuple[int, ...]: """Determine the kernel size based on the provided radius, grid spacing, or size in pixels. Parameters ---------- - radius : Union[float, Tuple[float, ...]] + radius : Union[float, tuple[float, ...]] The radius of the kernel. Can be a scalar or a tuple. - dl : Union[float, Tuple[float, ...]] + dl : Union[float, tuple[float, ...]] The grid spacing. Can be a scalar or a tuple. - size_px : Union[int, Tuple[int, ...]] + size_px : Union[int, tuple[int, ...]] The size of the kernel in pixels for each dimension. Can be a scalar or a tuple. Returns ------- - Tuple[int, ...] + tuple[int, ...] 
The size of the kernel in pixels for each dimension. Raises @@ -164,10 +169,10 @@ def _get_kernel_size( def make_filter( - radius: Union[float, Tuple[float, ...]] = None, - dl: Union[float, Tuple[float, ...]] = None, + radius: Union[float, tuple[float, ...]] = None, + dl: Union[float, tuple[float, ...]] = None, *, - size_px: Union[int, Tuple[int, ...]] = None, + size_px: Union[int, tuple[int, ...]] = None, normalize: bool = True, padding: PaddingType = "reflect", filter_type: KernelType, @@ -176,11 +181,11 @@ def make_filter( Parameters ---------- - radius : Union[float, Tuple[float, ...]] = None + radius : Union[float, tuple[float, ...]] = None The radius of the kernel. Can be a scalar or a tuple. - dl : Union[float, Tuple[float, ...]] = None + dl : Union[float, tuple[float, ...]] = None The grid spacing. Can be a scalar or a tuple. - size_px : Union[int, Tuple[int, ...]] = None + size_px : Union[int, tuple[int, ...]] = None The size of the kernel in pixels for each dimension. Can be a scalar or a tuple. normalize : bool = True Whether to normalize the kernel so that it sums to 1. @@ -226,4 +231,4 @@ def make_filter( :func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size. 
""" -FilterType = Annotated[Union[ConicFilter, CircularFilter], pd.Field(discriminator=TYPE_TAG_STR)] +FilterType = Annotated[Union[ConicFilter, CircularFilter], Field(discriminator=TYPE_TAG_STR)] diff --git a/tidy3d/plugins/autograd/invdes/parametrizations.py b/tidy3d/plugins/autograd/invdes/parametrizations.py index 22ef6d6f61..773189030b 100644 --- a/tidy3d/plugins/autograd/invdes/parametrizations.py +++ b/tidy3d/plugins/autograd/invdes/parametrizations.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Callable, Tuple, Union +from typing import Callable, Optional, Union -import pydantic.v1 as pd from numpy.typing import NDArray +from pydantic import Field, NonNegativeFloat from tidy3d.components.base import Tidy3dBaseModel @@ -16,26 +16,38 @@ class FilterAndProject(Tidy3dBaseModel): """A class that combines filtering and projection operations.""" - radius: Union[float, Tuple[float, ...]] = pd.Field( - ..., title="Radius", description="The radius of the kernel." + radius: Union[float, tuple[float, ...]] = Field( + title="Radius", + description="The radius of the kernel.", ) - dl: Union[float, Tuple[float, ...]] = pd.Field( - ..., title="Grid Spacing", description="The grid spacing." + dl: Union[float, tuple[float, ...]] = Field( + title="Grid Spacing", + description="The grid spacing.", ) - size_px: Union[int, Tuple[int, ...]] = pd.Field( - None, title="Size in Pixels", description="The size of the kernel in pixels." + size_px: Optional[Union[int, tuple[int, ...]]] = Field( + None, + title="Size in Pixels", + description="The size of the kernel in pixels.", ) - beta: pd.NonNegativeFloat = pd.Field( - BETA_DEFAULT, title="Beta", description="The beta parameter for the tanh projection." + beta: NonNegativeFloat = Field( + BETA_DEFAULT, + title="Beta", + description="The beta parameter for the tanh projection.", ) - eta: pd.NonNegativeFloat = pd.Field( - ETA_DEFAULT, title="Eta", description="The eta parameter for the tanh projection." 
+ eta: NonNegativeFloat = Field( + ETA_DEFAULT, + title="Eta", + description="The eta parameter for the tanh projection.", ) - filter_type: KernelType = pd.Field( - "conic", title="Filter Type", description="The type of filter to create." + filter_type: KernelType = Field( + "conic", + title="Filter Type", + description="The type of filter to create.", ) - padding: PaddingType = pd.Field( - "reflect", title="Padding", description="The padding mode to use." + padding: PaddingType = Field( + "reflect", + title="Padding", + description="The padding mode to use.", ) def __call__(self, array: NDArray, beta: float = None, eta: float = None) -> NDArray: @@ -70,10 +82,10 @@ def __call__(self, array: NDArray, beta: float = None, eta: float = None) -> NDA def make_filter_and_project( - radius: Union[float, Tuple[float, ...]] = None, - dl: Union[float, Tuple[float, ...]] = None, + radius: Union[float, tuple[float, ...]] = None, + dl: Union[float, tuple[float, ...]] = None, *, - size_px: Union[int, Tuple[int, ...]] = None, + size_px: Union[int, tuple[int, ...]] = None, beta: float = BETA_DEFAULT, eta: float = ETA_DEFAULT, filter_type: KernelType = "conic", diff --git a/tidy3d/plugins/autograd/invdes/penalties.py b/tidy3d/plugins/autograd/invdes/penalties.py index 9fc27d40ae..fe68a1748f 100644 --- a/tidy3d/plugins/autograd/invdes/penalties.py +++ b/tidy3d/plugins/autograd/invdes/penalties.py @@ -1,8 +1,8 @@ -from typing import Callable, Tuple, Union +from typing import Callable, Optional, Union import autograd.numpy as np -import pydantic.v1 as pd from numpy.typing import NDArray +from pydantic import Field, NonNegativeFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.types import ArrayFloat2D @@ -14,28 +14,40 @@ class ErosionDilationPenalty(Tidy3dBaseModel): """A class that computes a penalty for erosion/dilation of a parameter map not being unity.""" - radius: Union[float, Tuple[float, ...]] = pd.Field( - ..., title="Radius", description="The 
radius of the kernel." + radius: Union[float, tuple[float, ...]] = Field( + title="Radius", + description="The radius of the kernel.", ) - dl: Union[float, Tuple[float, ...]] = pd.Field( - ..., title="Grid Spacing", description="The grid spacing." + dl: Union[float, tuple[float, ...]] = Field( + title="Grid Spacing", + description="The grid spacing.", ) - size_px: Union[int, Tuple[int, ...]] = pd.Field( - None, title="Size in Pixels", description="The size of the kernel in pixels." + size_px: Optional[Union[int, tuple[int, ...]]] = Field( + None, + title="Size in Pixels", + description="The size of the kernel in pixels.", ) - beta: pd.NonNegativeFloat = pd.Field( - 20.0, title="Beta", description="The beta parameter for the tanh projection." + beta: NonNegativeFloat = Field( + 20.0, + title="Beta", + description="The beta parameter for the tanh projection.", ) - eta: pd.NonNegativeFloat = pd.Field( - 0.5, title="Eta", description="The eta parameter for the tanh projection." + eta: NonNegativeFloat = Field( + 0.5, + title="Eta", + description="The eta parameter for the tanh projection.", ) - filter_type: str = pd.Field( - "conic", title="Filter Type", description="The type of filter to create." + filter_type: str = Field( + "conic", + title="Filter Type", + description="The type of filter to create.", ) - padding: PaddingType = pd.Field( - "reflect", title="Padding", description="The padding mode to use." 
+ padding: PaddingType = Field( + "reflect", + title="Padding", + description="The padding mode to use.", ) - delta_eta: float = pd.Field( + delta_eta: float = Field( 0.01, title="Delta Eta", description="The binarization threshold for erosion and dilation operations.", @@ -88,10 +100,10 @@ def _close(arr: NDArray): def make_erosion_dilation_penalty( - radius: Union[float, Tuple[float, ...]], - dl: Union[float, Tuple[float, ...]], + radius: Union[float, tuple[float, ...]], + dl: Union[float, tuple[float, ...]], *, - size_px: Union[int, Tuple[int, ...]] = None, + size_px: Union[int, tuple[int, ...]] = None, beta: float = 20.0, eta: float = 0.5, delta_eta: float = 0.01, diff --git a/tidy3d/plugins/autograd/utilities.py b/tidy3d/plugins/autograd/utilities.py index 1303f9e1b6..ae824f8a07 100644 --- a/tidy3d/plugins/autograd/utilities.py +++ b/tidy3d/plugins/autograd/utilities.py @@ -1,5 +1,5 @@ from functools import reduce, wraps -from typing import Any, Callable, Iterable, List, Union +from typing import Any, Callable, Iterable, Union import autograd.numpy as anp import numpy as np @@ -84,7 +84,7 @@ def make_kernel(kernel_type: KernelType, size: Iterable[int], normalize: bool = def get_kernel_size_px( radius: Union[float, Iterable[float]] = None, dl: Union[float, Iterable[float]] = None -) -> Union[int, List[int]]: +) -> Union[int, list[int]]: """Calculate the kernel size in pixels based on the provided radius and grid spacing. Parameters @@ -96,7 +96,7 @@ def get_kernel_size_px( Returns ------- - Union[int, List[int]] + Union[int, list[int]] The size of the kernel in pixels for each dimension. Returns an integer if the radius is scalar, otherwise a list of integers. 
Raises diff --git a/tidy3d/plugins/design/design.py b/tidy3d/plugins/design/design.py index cf3333fcb4..ea754b9a8f 100644 --- a/tidy3d/plugins/design/design.py +++ b/tidy3d/plugins/design/design.py @@ -3,15 +3,16 @@ from __future__ import annotations import inspect -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Optional, Union -import pydantic.v1 as pd +from pydantic import Field + +from tidy3d.components.base import TYPE_TAG_STR, Tidy3dBaseModel, cached_property +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.simulation import Simulation +from tidy3d.log import Console, get_logging_console, log +from tidy3d.web.api.container import Batch, BatchData, Job -from ...components.base import TYPE_TAG_STR, Tidy3dBaseModel, cached_property -from ...components.data.sim_data import SimulationData -from ...components.simulation import Simulation -from ...log import Console, get_logging_console, log -from ...web.api.container import Batch, BatchData, Job from .method import ( MethodBayOpt, MethodGenAlg, @@ -64,20 +65,19 @@ class DesignSpace(Tidy3dBaseModel): """ - parameters: Tuple[ParameterType, ...] = pd.Field( + parameters: tuple[ParameterType, ...] 
= Field( (), title="Parameters", description="Set of parameters defining the dimensions and allowed values for the design space.", ) - method: MethodType = pd.Field( - ..., + method: MethodType = Field( title="Search Type", description="Specifications for the procedure used to explore the parameter space.", discriminator=TYPE_TAG_STR, # Stops pydantic trying to validate every method whilst checking MethodType ) - task_name: str = pd.Field( + task_name: str = Field( "", title="Task Name", description="Task name assigned to tasks along with a simulation counter in the form of {task_name}_{sim_index}_{counter} where ``sim_index`` is " @@ -86,33 +86,37 @@ class DesignSpace(Tidy3dBaseModel): "Only used when pre-post functions are supplied.", ) - name: str = pd.Field(None, title="Name", description="Optional name for the design space.") + name: Optional[str] = Field( + None, + title="Name", + description="Optional name for the design space.", + ) - path_dir: str = pd.Field( + path_dir: str = Field( ".", title="Path Directory", description="Directory where simulation data files will be locally saved to. Only used when pre and post functions are supplied.", ) - folder_name: str = pd.Field( + folder_name: str = Field( "default", title="Folder Name", description="Folder path where the simulation will be uploaded in the Tidy3D Workspace. 
Will use 'default' if no path is set.", ) @cached_property - def dims(self) -> Tuple[str]: + def dims(self) -> tuple[str]: """dimensions defined by the design parameter names.""" return tuple(param.name for param in self.parameters) def _package_run_results( self, fn_args: list[dict[str, Any]], - fn_values: List[Any], + fn_values: list[Any], fn_source: str, - task_names: Tuple[str] = None, + task_names: tuple[str] = None, task_paths: list = None, - aux_values: List[Any] = None, + aux_values: list[Any] = None, opt_output: Any = None, ) -> Result: """How to package results from ``method.run`` and ``method.run_batch``""" @@ -245,14 +249,14 @@ def run(self, fn: Callable, fn_post: Callable = None, verbose: bool = True) -> R opt_output=opt_output, ) - def run_single(self, fn: Callable, console: Console) -> Tuple(list[dict], list, list[Any]): + def run_single(self, fn: Callable, console: Console) -> tuple[list[dict], list, list[Any]]: """Run a single function of parameter inputs.""" evaluate_fn = self._get_evaluate_fn_single(fn=fn) return self.method._run(run_fn=evaluate_fn, parameters=self.parameters, console=console) - def run_pre_post(self, fn_pre: Callable, fn_post: Callable, console: Console) -> Tuple( - list[dict], list[dict], list[Any] - ): + def run_pre_post( + self, fn_pre: Callable, fn_post: Callable, console: Console + ) -> tuple[list[dict], list[dict], list[Any]]: """Run a function with Tidy3D implicitly called in between.""" handler = self._get_evaluate_fn_pre_post( fn_pre=fn_pre, fn_post=fn_post, fn_mid=self._fn_mid, console=console @@ -454,9 +458,9 @@ def _remove_or_replace(search_dict: dict, attr_name: str) -> dict: def run_batch( self, - fn_pre: Callable[Any, Union[Simulation, List[Simulation], Dict[str, Simulation]]], + fn_pre: Callable[Any, Union[Simulation, list[Simulation], dict[str, Simulation]]], fn_post: Callable[ - Union[SimulationData, List[SimulationData], Dict[str, SimulationData]], Any + Union[SimulationData, list[SimulationData], dict[str, 
SimulationData]], Any ], path_dir: str = ".", **batch_kwargs, @@ -592,8 +596,8 @@ def summarize(self, fn_pre: Callable = None, verbose: bool = True) -> dict[str, # If check stops it printing standard attributes arg_values = [ f"{field}: {getattr(self.method, field)}\n" - for field in self.method.__fields__ - if field not in MethodOptimize.__fields__ + for field in self.method.model_fields + if field not in MethodOptimize.model_fields ] param_values = [] diff --git a/tidy3d/plugins/design/method.py b/tidy3d/plugins/design/method.py index 23fcafd217..b3fb233445 100644 --- a/tidy3d/plugins/design/method.py +++ b/tidy3d/plugins/design/method.py @@ -1,14 +1,15 @@ """Defines the methods used for parameter sweep.""" from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Literal, Tuple, Union +from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np -import pydantic.v1 as pd import scipy.stats.qmc as qmc +from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt + +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.constants import inf -from ...components.base import Tidy3dBaseModel -from ...constants import inf from .parameter import ParameterAny, ParameterFloat, ParameterInt, ParameterType DEFAULT_MONTE_CARLO_SAMPLER_TYPE = qmc.LatinHypercube @@ -17,10 +18,12 @@ class Method(Tidy3dBaseModel, ABC): """Spec for a sweep algorithm, with a method to run it.""" - name: str = pd.Field(None, title="Name", description="Optional name for the sweep method.") + name: Optional[str] = Field( + None, title="Name", description="Optional name for the sweep method." 
+ ) @abstractmethod - def _run(self, parameters: Tuple[ParameterType, ...], run_fn: Callable) -> Tuple[Any]: + def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable) -> tuple[Any]: """Defines the search algorithm.""" @abstractmethod @@ -36,7 +39,7 @@ def _force_int(self, next_point: dict, parameters: list) -> None: next_point[param.name] = int(round(next_point[param.name], 0)) @staticmethod - def _extract_output(output: list, sampler: bool = False) -> Tuple: + def _extract_output(output: list, sampler: bool = False) -> tuple: """Format the user function output for further optimization and result storage.""" # Light check if all the outputs are the same type @@ -57,7 +60,7 @@ def _extract_output(output: list, sampler: bool = False) -> Tuple: none_aux = [None for _ in range(len(output))] return (output, none_aux) - if all(isinstance(val, (list, Tuple)) for val in output): + if all(isinstance(val, (list, tuple)) for val in output): if all(isinstance(val[0], (float, int)) for val in output): float_out = [] aux_out = [] @@ -95,13 +98,13 @@ class MethodSample(Method, ABC): """A sweep method where all points are independently computed in one iteration.""" @abstractmethod - def sample(self, parameters: Tuple[ParameterType, ...], **kwargs) -> Dict[str, Any]: + def sample(self, parameters: tuple[ParameterType, ...], **kwargs) -> dict[str, Any]: """Defines how the design parameters are sampled.""" def _assemble_args( self, - parameters: Tuple[ParameterType, ...], - ) -> Tuple[dict, int]: + parameters: tuple[ParameterType, ...], + ) -> tuple[dict, int]: """Sample design parameters, check the args are hashable and compute number of points.""" fn_args = self.sample(parameters) @@ -109,7 +112,7 @@ def _assemble_args( self._force_int(arg_dict, parameters) return fn_args - def _run(self, parameters: Tuple[ParameterType, ...], run_fn: Callable, console) -> Tuple[Any]: + def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]: 
"""Defines the search algorithm.""" # get all function inputs @@ -139,7 +142,7 @@ def _get_run_count(self, parameters: list) -> int: return len(self.sample(parameters)) @staticmethod - def sample(parameters: Tuple[ParameterType, ...]) -> Dict[str, Any]: + def sample(parameters: tuple[ParameterType, ...]) -> dict[str, Any]: """Defines how the design parameters are sampled on the grid.""" # sample each dimension individually @@ -161,8 +164,8 @@ class MethodOptimize(Method, ABC): """A method for handling design searches that optimize the design.""" # NOTE: We could move this to the Method base class but it's not relevant to MethodGrid - seed: pd.PositiveInt = pd.Field( - default=None, + seed: Optional[PositiveInt] = Field( + None, title="Seed for random number generation", description="Set the seed used by the optimizers to ensure consistant random number generation.", ) @@ -200,31 +203,29 @@ class MethodBayOpt(MethodOptimize, ABC): >>> method = tdd.MethodBayOpt(initial_iter=4, n_iter=10) """ - initial_iter: pd.PositiveInt = pd.Field( - ..., + initial_iter: PositiveInt = Field( title="Number of Initial Random Search Iterations", description="The number of search runs to be done initialially with parameter values picked randomly. This provides a starting point for the Gaussian processor to optimize from. These solutions can be computed as a single ``Batch`` if the pre function generates ``Simulation`` objects.", ) - n_iter: pd.PositiveInt = pd.Field( - ..., + n_iter: PositiveInt = Field( title="Number of Bayesian Optimization Iterations", description="Following the initial search, this is number of iterations the Gaussian processor should be sequentially called to suggest parameter values and register the results.", ) - acq_func: Literal["ucb", "ei", "poi"] = pd.Field( + acq_func: Literal["ucb", "ei", "poi"] = Field( default="ucb", title="Type of Acquisition Function", description="The type of acquisition function that should be used to suggest parameter values. 
More detail available in the `package docs `_.", ) - kappa: pd.PositiveFloat = pd.Field( + kappa: PositiveFloat = Field( default=2.5, title="Kappa", description="The kappa coefficient used by the ``ucb`` acquisition function. More detail available in the `package docs `_.", ) - xi: pd.NonNegativeFloat = pd.Field( + xi: NonNegativeFloat = Field( default=0.0, title="Xi", description="The Xi coefficient used by the ``ei`` and ``poi`` acquisition functions. More detail available in the `package docs `_.", @@ -234,7 +235,7 @@ def _get_run_count(self, parameters: list = None) -> int: """Return the maximum number of runs for the method based on current method arguments.""" return self.initial_iter + self.n_iter - def _run(self, parameters: Tuple[ParameterType, ...], run_fn: Callable, console) -> Tuple[Any]: + def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]: """Defines the Bayesian optimization search algorithm for the method. Uses the ``bayes_opt`` package to carry out a Bayesian optimization. Utilizes the ``.suggest`` and ``.register`` methods instead of @@ -344,98 +345,92 @@ class MethodGenAlg(MethodOptimize, ABC): >>> method = tdd.MethodGenAlg(solutions_per_pop=2, n_generations=1, n_parents_mating=2) """ - # Args for the user - solutions_per_pop: pd.PositiveInt = pd.Field( - ..., + solutions_per_pop: PositiveInt = Field( title="Solutions per Population", description="The number of solutions to be generated for each population.", ) - n_generations: pd.PositiveInt = pd.Field( - ..., + n_generations: PositiveInt = Field( title="Number of Generations", description="The maximum number of generations to run the genetic algorithm.", ) - n_parents_mating: pd.PositiveInt = pd.Field( - ..., + n_parents_mating: PositiveInt = Field( title="Number of Parents Mating", description="The number of solutions to be selected as parents for the next generation. 
Crossovers of these parents will produce the next population.", ) - stop_criteria_type: Literal["reach", "saturate"] = pd.Field( + stop_criteria_type: Optional[Literal["reach", "saturate"]] = Field( default=None, title="Early Stopping Criteria Type", description="Define the early stopping criteria. Supported words are 'reach' or 'saturate'. 'reach' stops at a desired fitness, 'saturate' stops when the fitness stops improving. Must set ``stop_criteria_number``. See the `PyGAD docs `_ for more details.", ) - stop_criteria_number: pd.PositiveFloat = pd.Field( + stop_criteria_number: Optional[PositiveFloat] = Field( default=None, title="Early Stopping Criteria Number", description="Must set ``stop_criteria_type``. If type is 'reach' the number is acceptable fitness value to stop the optimization. If type is 'saturate' the number is the number generations where the fitness doesn't improve before optimization is stopped. See the `PyGAD docs `_ for more details.", ) - parent_selection_type: Literal["sss", "rws", "sus", "rank", "random", "tournament"] = pd.Field( + parent_selection_type: Literal["sss", "rws", "sus", "rank", "random", "tournament"] = Field( default="sss", title="Parent Selection Type", description="The style of parent selector. See the `PyGAD docs `_ for more details.", ) - keep_parents: Union[pd.PositiveInt, Literal[-1, 0]] = pd.Field( + keep_parents: Union[PositiveInt, Literal[-1, 0]] = Field( default=-1, title="Keep Parents", description="The number of parents to keep unaltered in the population of the next generation. Default value of -1 keeps all current parents for the next generation. This value is overwritten if ``keep_parents`` is > 0. See the `PyGAD docs `_ for more details.", ) - keep_elitism: Union[pd.PositiveInt, Literal[0]] = pd.Field( + keep_elitism: Union[PositiveInt, Literal[0]] = Field( default=1, title="Keep Elitism", description="The number of top solutions to be included in the population of the next generation. 
Overwrites ``keep_parents`` if value is > 0. See the `PyGAD docs `_ for more details.", ) - crossover_type: Union[None, Literal["single_point", "two_points", "uniform", "scattered"]] = ( - pd.Field( - default="single_point", - title="Crossover Type", - description="The style of crossover operation. See the `PyGAD docs `_ for more details.", - ) + crossover_type: Optional[Literal["single_point", "two_points", "uniform", "scattered"]] = Field( + default="single_point", + title="Crossover Type", + description="The style of crossover operation. See the `PyGAD docs `_ for more details.", ) - crossover_prob: pd.confloat(ge=0, le=1) = pd.Field( + crossover_prob: float = Field( default=0.8, title="Crossover Probability", description="The probability of performing a crossover between two parents.", + ge=0, + le=1, ) - mutation_type: Union[None, Literal["random", "swap", "inversion", "scramble", "adaptive"]] = ( - pd.Field( - default="random", - title="Mutation Type", - description="The style of gene mutation. See the `PyGAD docs `_ for more details.", - ) + mutation_type: Optional[Literal["random", "swap", "inversion", "scramble", "adaptive"]] = Field( + default="random", + title="Mutation Type", + description="The style of gene mutation. See the `PyGAD docs `_ for more details.", ) - mutation_prob: Union[pd.confloat(ge=0, le=1), Literal[None]] = pd.Field( + mutation_prob: Optional[float] = Field( default=0.2, title="Mutation Probability", description="The probability of mutating a gene.", + ge=0, + le=1, ) - save_solution: pd.StrictBool = pd.Field( + save_solution: bool = Field( default=False, title="Save Solutions", description="Save all solutions from all generations within a numpy array. Can be accessed from the optimizer object stored in the Result. May cause memory issues with large populations or many generations. See the `PyGAD docs _` for more details.", ) - # TODO: See if anyone is interested in having the full suite of PyGAD options - there's a lot! 
- def _get_run_count(self, parameters: list = None) -> int: """Return the maximum number of runs for the method based on current method arguments.""" # +1 to generations as pygad creates an initial population which is effectively "Generation 0" run_count = self.solutions_per_pop * (self.n_generations + 1) return run_count - def _run(self, parameters: Tuple[ParameterType, ...], run_fn: Callable, console) -> Tuple[Any]: + def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]: """Defines the genetic algorithm for the method. Uses the ``pygad`` package to carry out a particle search optimization. Additional development has ensured that @@ -632,49 +627,47 @@ class MethodParticleSwarm(MethodOptimize, ABC): >>> method = tdd.MethodParticleSwarm(n_particles=5, n_iter=3) """ - n_particles: pd.PositiveInt = pd.Field( - ..., + n_particles: PositiveInt = Field( title="Number of Particles", description="The number of particles to be used in the swarm for the optimization.", ) - n_iter: pd.PositiveInt = pd.Field( - ..., + n_iter: PositiveInt = Field( title="Number of Iterations", description="The maxmium number of iterations to run the optimization.", ) - cognitive_coeff: pd.PositiveFloat = pd.Field( + cognitive_coeff: PositiveFloat = Field( default=1.5, title="Cognitive Coefficient", description="The cognitive parameter decides how attracted the particle is to its previous best position.", ) - social_coeff: pd.PositiveFloat = pd.Field( + social_coeff: PositiveFloat = Field( default=1.5, title="Social Coefficient", description="The social parameter decides how attracted the particle is to the global best position found by the swarm.", ) - weight: pd.PositiveFloat = pd.Field( + weight: PositiveFloat = Field( default=0.9, title="Weight", description="The weight or inertia of particles in the optimization.", ) - ftol: Union[pd.confloat(ge=0, le=1), Literal[-inf]] = pd.Field( + ftol: Union[Annotated[float, Field(ge=0, le=1)], Literal[-inf]] = 
Field( default=-inf, title="Relative Error for Convergence", description="Relative error in ``objective_func(best_solution)`` acceptable for convergence. See the `PySwarms docs `_ for details. Off by default.", ) - ftol_iter: pd.PositiveInt = pd.Field( + ftol_iter: PositiveInt = Field( default=1, title="Number of Iterations Before Convergence", description="Number of iterations over which the relative error in the objective_func is acceptable for convergence.", ) - init_pos: np.ndarray = pd.Field( + init_pos: Optional[np.ndarray] = Field( default=None, title="Initial Swarm Positions", description="Set the initial positions of the swarm using a numpy array of appropriate size.", @@ -684,7 +677,7 @@ def _get_run_count(self, parameters: list = None) -> int: """Return the maximum number of runs for the method based on current method arguments.""" return self.n_particles * self.n_iter - def _run(self, parameters: Tuple[ParameterType, ...], run_fn: Callable, console) -> Tuple[Any]: + def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]: """Defines the particle search optimization algorithm for the method. Uses the ``pyswarms`` package to carry out a particle search optimization. 
@@ -772,27 +765,26 @@ def fitness_function(solution: np.array) -> np.array: class AbstractMethodRandom(MethodSample, ABC): """Select parameters with an object with a ``random`` method.""" - num_points: pd.PositiveInt = pd.Field( - ..., + num_points: PositiveInt = Field( title="Number of Sampling Points", description="The number of points to be generated for sampling.", ) - seed: pd.PositiveInt = pd.Field( + seed: Optional[PositiveInt] = Field( default=None, title="Seed", description="Sets the seed used by the optimizers to set constant random number generation.", ) @abstractmethod - def _get_sampler(self, parameters: Tuple[ParameterType, ...]) -> qmc.QMCEngine: + def _get_sampler(self, parameters: tuple[ParameterType, ...]) -> qmc.QMCEngine: """Sampler for this ``Method`` class. If ``None``, sets a default.""" def _get_run_count(self, parameters: list = None) -> int: """Return the maximum number of runs for the method based on current method arguments.""" return self.num_points - def sample(self, parameters: Tuple[ParameterType, ...], **kwargs) -> Dict[str, Any]: + def sample(self, parameters: tuple[ParameterType, ...], **kwargs) -> dict[str, Any]: """Defines how the design parameters are sampled on grid.""" sampler = self._get_sampler(parameters) @@ -822,7 +814,7 @@ class MethodMonteCarlo(AbstractMethodRandom): >>> method = tdd.MethodMonteCarlo(num_points=20) """ - def _get_sampler(self, parameters: Tuple[ParameterType, ...]) -> qmc.QMCEngine: + def _get_sampler(self, parameters: tuple[ParameterType, ...]) -> qmc.QMCEngine: """Sampler for this ``Method`` class.""" d = len(parameters) diff --git a/tidy3d/plugins/design/parameter.py b/tidy3d/plugins/design/parameter.py index 77e5cf3c70..36cff824db 100644 --- a/tidy3d/plugins/design/parameter.py +++ b/tidy3d/plugins/design/parameter.py @@ -3,52 +3,51 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Union +from typing import Any, Optional, Union import 
numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveInt, field_validator -from ...components.base import Tidy3dBaseModel +from tidy3d.components.base import Tidy3dBaseModel class Parameter(Tidy3dBaseModel, ABC): """Specification for a single variable / dimension in a design problem.""" - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for the variable. Used as a key into the parameter sweep results.", ) - values: Tuple[Any, ...] = pd.Field( + values: Optional[tuple[Any, ...]] = Field( None, title="Custom Values", description="If specified, the parameter scan uses these values for grid search methods.", ) - @pd.validator("values", always=True) - def _values_unique(cls, val): + @field_validator("values") + def _values_unique(val): """Supplied unique values.""" if (val is not None) and (len(set(val)) != len(val)): raise ValueError("Supplied 'values' were not unique.") return val - def sample_grid(self) -> List[Any]: + def sample_grid(self) -> list[Any]: """Sample design variable on grid, checking for custom values.""" if self.values is not None: return self.values return self._sample_grid() @abstractmethod - def sample_random(self, num_samples: int) -> List[Any]: + def sample_random(self, num_samples: int) -> list[Any]: """Sample this design variable randomly 'num_samples' times.""" @abstractmethod - def _sample_grid(self) -> List[Any]: + def _sample_grid(self) -> list[Any]: """Sample this design variable on a grid.""" @abstractmethod - def select_from_01(self, pts_01: np.ndarray) -> List[Any]: + def select_from_01(self, pts_01: np.ndarray) -> list[Any]: """Select values given a set of points between 0, 1.""" @abstractmethod @@ -59,14 +58,13 @@ def sample_first(self) -> Any: class ParameterNumeric(Parameter, ABC): """A variable with numeric values.""" - span: Tuple[Union[float, int], Union[float, int]] = pd.Field( - ..., + span: tuple[Union[float, int], Union[float, int]] = Field( title="Span", 
description="(min, max) range within which are allowed values for the variable. Is inclusive of max value.", ) - @pd.validator("span", always=True) - def _span_valid(cls, val): + @field_validator("span") + def _span_valid(val): """Span min <= span max.""" span_min, span_max = val if span_min > span_max: @@ -96,25 +94,25 @@ class ParameterFloat(ParameterNumeric): >>> var = tdd.ParameterFloat(name="x", num_points=10, span=(1, 2.5)) """ - num_points: pd.PositiveInt = pd.Field( + num_points: Optional[PositiveInt] = Field( None, title="Number of Points", description="Number of uniform sampling points for this variable. " "Only used for 'MethodGrid'. ", ) - @pd.validator("span", always=True) - def _span_is_float(cls, val): + @field_validator("span") + def _span_is_float(val): """Make sure the span contains floats.""" low, high = val return float(low), float(high) - def sample_random(self, num_samples: int) -> List[float]: + def sample_random(self, num_samples: int) -> list[float]: """Sample this design variable randomly 'num_samples' times.""" low, high = self.span return np.random.uniform(low=low, high=high, size=num_samples).tolist() - def _sample_grid(self) -> List[float]: + def _sample_grid(self) -> list[float]: """Sample this design variable on a grid.""" if self.num_points is None: raise ValueError( @@ -123,7 +121,7 @@ def _sample_grid(self) -> List[float]: low, high = self.span return np.linspace(low, high, self.num_points).tolist() - def select_from_01(self, pts_01: np.ndarray) -> List[Any]: + def select_from_01(self, pts_01: np.ndarray) -> list[Any]: """Select values given a set of points between 0, 1.""" return (min(self.span) + pts_01 * self.span_size).tolist() @@ -138,31 +136,30 @@ class ParameterInt(ParameterNumeric): >>> var = tdd.ParameterInt(name="x", span=(1, 4)) """ - span: Tuple[int, int] = pd.Field( - ..., + span: tuple[int, int] = Field( title="Span", description="``(min, max)`` range within which are allowed values for the variable. 
" "The ``min`` value is inclusive and the ``max`` value is exclusive. In other words, " "a grid search over this variable will iterate over ``np.arange(min, max)``.", ) - @pd.validator("span", always=True) - def _span_is_int(cls, val): + @field_validator("span") + def _span_is_int(val): """Make sure the span contains ints.""" low, high = val return int(low), int(high) - def sample_random(self, num_samples: int) -> List[int]: + def sample_random(self, num_samples: int) -> list[int]: """Sample this design variable randomly 'num_samples' times.""" low, high = self.span return np.random.randint(low=low, high=high, size=num_samples).tolist() - def _sample_grid(self) -> List[float]: + def _sample_grid(self) -> list[float]: """Sample this design variable on a grid.""" low, high = self.span return np.arange(low, high).tolist() - def select_from_01(self, pts_01: np.ndarray) -> List[Any]: + def select_from_01(self, pts_01: np.ndarray) -> list[Any]: """Select values given a set of points between 0, 1.""" pts_continuous = min(self.span) + pts_01 * self.span_size return np.floor(pts_continuous).astype(int).tolist() @@ -177,35 +174,34 @@ class ParameterAny(Parameter): >>> var = tdd.ParameterAny(name="x", allowed_values=("a", "b", "c")) """ - allowed_values: Tuple[Any, ...] = pd.Field( - ..., + allowed_values: tuple[Any, ...] = Field( title="Allowed Values", description="The discrete set of values that this variable can take on.", ) - @pd.validator("allowed_values", always=True) - def _given_any_allowed_values(cls, val): + @field_validator("allowed_values") + def _given_any_allowed_values(val): """Need at least one allowed value.""" if not len(val): raise ValueError("Given empty tuple of allowed values. 
Must have at least one.") return val - @pd.validator("allowed_values", always=True) + @field_validator("allowed_values") def _no_duplicate_allowed_values(cls, val): """No duplicates in allowed_values.""" if len(val) != len(set(val)): raise ValueError("'allowed_values' has duplicate entries, must be unique.") return val - def sample_random(self, num_samples: int) -> List[Any]: + def sample_random(self, num_samples: int) -> list[Any]: """Sample this design variable randomly 'num_samples' times.""" return np.random.choice(self.allowed_values, size=int(num_samples)).tolist() - def _sample_grid(self) -> List[Any]: + def _sample_grid(self) -> list[Any]: """Sample this design variable uniformly, ie just take all allowed values.""" return list(self.allowed_values) - def select_from_01(self, pts_01: np.ndarray) -> List[Any]: + def select_from_01(self, pts_01: np.ndarray) -> list[Any]: """Select values given a set of points between 0, 1.""" pts_continuous = pts_01 * len(self.allowed_values) indices = np.floor(pts_continuous).astype(int) diff --git a/tidy3d/plugins/design/result.py b/tidy3d/plugins/design/result.py index 5f703cecfc..8ac7d2f84b 100644 --- a/tidy3d/plugins/design/result.py +++ b/tidy3d/plugins/design/result.py @@ -2,13 +2,13 @@ from __future__ import annotations -from typing import Any, Dict, List, Tuple +from typing import Any, Optional import numpy as np import pandas -import pydantic.v1 as pd +from pydantic import Field, model_validator -from ...components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.base import Tidy3dBaseModel, cached_property # NOTE: Coords are args_dict from method and design. This may be changed in future to unify naming @@ -30,39 +30,39 @@ class Result(Tidy3dBaseModel): >>> # df.head() # print out first 5 elements of data """ - dims: Tuple[str, ...] = pd.Field( + dims: tuple[str, ...] 
= Field( (), title="Dimensions", description="The dimensions of the design variables (indexed by 'name').", ) - values: Tuple[Any, ...] = pd.Field( + values: tuple[Any, ...] = Field( (), title="Values", description="The return values from the design problem function.", ) - coords: Tuple[Tuple[Any, ...], ...] = pd.Field( + coords: tuple[tuple[Any, ...], ...] = Field( (), title="Coordinates", description="The values of the coordinates corresponding to each of the dims." "Note: shaped (D, N) where D is the ``len(dims)`` and N is the ``len(values)``", ) - output_names: Tuple[str, ...] = pd.Field( + output_names: Optional[tuple[str, ...]] = Field( None, title="Output Names", description="Names for each of the outputs stored in ``values``. If not specified, default " "values are assigned.", ) - fn_source: str = pd.Field( + fn_source: Optional[str] = Field( None, title="Function Source Code", description="Source code for the function evaluated in the parameter sweep.", ) - task_names: list = pd.Field( + task_names: Optional[list] = Field( None, title="Task Names", description="Task name of every simulation run during ``DesignSpace.run``. Only available if " @@ -70,7 +70,7 @@ class Result(Tidy3dBaseModel): "Stored in the same format as the output of fn_pre i.e. if pre outputs a dict, this output is a dict with the keys preserved.", ) - task_paths: list = pd.Field( + task_paths: Optional[list] = Field( None, title="Task Paths", description="Task paths of every simulation run during ``DesignSpace.run``. Useful for loading download ``SimulationData`` hdf5 files." @@ -78,50 +78,48 @@ class Result(Tidy3dBaseModel): "Stored in the same format as the output of fn_pre i.e. if pre outputs a dict, this output is a dict with the keys preserved.", ) - aux_values: Tuple[Any, ...] = pd.Field( + aux_values: Optional[tuple[Any, ...]] = Field( None, title="Auxiliary values output from the user function", description="The auxiliary return values from the design problem function. 
" This is the collection of objects returned " "alongside the float value used for the optimization. These weren't used to inform the optimizer, if one was used.", ) - optimizer: Any = pd.Field( + optimizer: Any = Field( None, title="Optimizer object", description="The optimizer returned at the end of an optimizer run. Can be used to analyze and plot how the optimization progressed. " "Attributes depend on the optimizer used; a full explaination of the optimizer can be found on associated library doc pages. Will be ``None`` for sampling based methods.", ) - @pd.validator("coords", always=True) - def _coords_and_dims_shape(cls, val, values): + @model_validator(mode="after") + def _coords_and_dims_shape(self): """Make sure coords and dims have same size.""" - dims = values.get("dims") - - if val is None or dims is None: - return + if self.coords is None or self.dims is None: + return self - num_dims = len(dims) - for i, _val in enumerate(val): + num_dims = len(self.dims) + for i, _val in enumerate(self.coords): if len(_val) != num_dims: raise ValueError( f"Number of 'coords' at index '{i}' ({len(_val)}) " f"doesn't match the number of 'dims' ({num_dims})." ) - return val + return self - @pd.validator("coords", always=True) - def _coords_and_values_shape(cls, val, values): + @model_validator(mode="after") + def _coords_and_values_shape(self): """Make sure coords and values have same length.""" - _values = values.get("values") + _values = self.values - if val is None or _values is None: - return + if self.coords is None or _values is None: + return self num_values = len(_values) - num_coords = len(val) + num_coords = len(self.coords) if num_values != num_coords: raise ValueError( @@ -129,9 +127,9 @@ def _coords_and_values_shape(cls, val, values): f"Have {num_coords} and {num_values} elements, respectively." 
) - return val + return self - def value_as_dict(self, value) -> Dict[str, Any]: + def value_as_dict(self, value) -> dict[str, Any]: """How to convert an output function value as a dictionary.""" if isinstance(value, dict): return value @@ -141,7 +139,7 @@ def value_as_dict(self, value) -> Dict[str, Any]: return dict(zip(keys, value)) @staticmethod - def default_value_keys(value) -> Tuple[str, ...]: + def default_value_keys(value) -> tuple[str, ...]: """The default keys for a given value.""" # if a dict already, just use the existing keys as labels @@ -155,7 +153,7 @@ def default_value_keys(value) -> Tuple[str, ...]: # if simply single value (float, int, bool, etc) just label "output" return ("output",) - def items(self) -> Tuple[dict, Any]: + def items(self) -> tuple[dict, Any]: """Iterate through coordinates (args) and values (outputs) one by one.""" for coord_tuple, val in zip(self.coords, self.values): @@ -163,7 +161,7 @@ def items(self) -> Tuple[dict, Any]: yield coord_dict, val @cached_property - def data(self) -> Dict[tuple, Any]: + def data(self) -> dict[tuple, Any]: """Dict mapping tuple of fn args to their value.""" result = {} @@ -246,14 +244,14 @@ def to_dataframe(self, include_aux: bool = False) -> pandas.DataFrame: return df @classmethod - def from_dataframe(cls, df: pandas.DataFrame, dims: List[str] = None) -> Result: + def from_dataframe(cls, df: pandas.DataFrame, dims: list[str] = None) -> Result: """Load a result directly from a `pandas.DataFrame` object. Parameters ---------- df : ``pandas.DataFrame`` ```DataFrame`` object to load into a :class:`.Result`. - dims : List[str] = None + dims : list[str] = None Set of dimensions corresponding to the function arguments. Not required if this dataframe was generated directly from a :class:`.Result` without modification. In that case, it contains the dims in its ``.attrs`` metadata. 
@@ -346,19 +344,19 @@ def __add__(self, other): """Special syntax for design_result1 + design_result2.""" return self.combine(other) - def get_index(self, fn_args: Dict[str, float]) -> int: + def get_index(self, fn_args: dict[str, float]) -> int: """Get index into the data for a specific set of arguments.""" key_list = list(self.coords) arg_key = tuple(fn_args[dim] for dim in self.dims) return key_list.index(arg_key) - def delete(self, fn_args: Dict[str, float]) -> Result: + def delete(self, fn_args: dict[str, float]) -> Result: """Delete a specific set of arguments from the result. Parameters ---------- - fn_args : Dict[str, float] + fn_args : dict[str, float] ``dict`` containing the function arguments one wishes to delete. Returns @@ -392,12 +390,12 @@ def delete(self, fn_args: Dict[str, float]) -> Result: return self.updated_copy(values=new_values, coords=new_coords) - def add(self, fn_args: Dict[str, float], value: Any) -> Result: + def add(self, fn_args: dict[str, float], value: Any) -> Result: """Add a specific argument and value the result. Parameters ---------- - fn_args : Dict[str, float] + fn_args : dict[str, float] ``dict`` containing the function arguments one wishes to add. value : Any Data point value corresponding to these arguments. 
diff --git a/tidy3d/plugins/dispersion/fit.py b/tidy3d/plugins/dispersion/fit.py index 24d318fa12..bbc07094b6 100644 --- a/tidy3d/plugins/dispersion/fit.py +++ b/tidy3d/plugins/dispersion/fit.py @@ -4,49 +4,46 @@ import codecs import csv -from typing import List, Optional, Tuple +from typing import Optional import numpy as np import requests import scipy.optimize as opt -from pydantic.v1 import Field, validator +from pydantic import Field, field_validator, model_validator from rich.progress import Progress +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.medium import AbstractMedium, PoleResidue +from tidy3d.components.types import ArrayFloat1D, Ax +from tidy3d.components.viz import add_ax_if_none +from tidy3d.constants import C_0, HBAR, MICROMETER +from tidy3d.exceptions import SetupError, ValidationError, WebError +from tidy3d.log import get_logging_console, log from tidy3d.web.core.environment import Env -from ...components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from ...components.medium import AbstractMedium, PoleResidue -from ...components.types import ArrayFloat1D, Ax -from ...components.viz import add_ax_if_none -from ...constants import C_0, HBAR, MICROMETER -from ...exceptions import SetupError, ValidationError, WebError -from ...log import get_logging_console, log - class DispersionFitter(Tidy3dBaseModel): """Tool for fitting refractive index data to get a dispersive medium described by :class:`.PoleResidue` model.""" wvl_um: ArrayFloat1D = Field( - ..., title="Wavelength data", description="Wavelength data in micrometers.", units=MICROMETER, ) n_data: ArrayFloat1D = Field( - ..., title="Index of refraction data", description="Real part of the complex index of refraction.", ) - k_data: ArrayFloat1D = Field( + k_data: Optional[ArrayFloat1D] = Field( None, title="Extinction coefficient data", description="Imaginary part of the complex index of refraction.", ) - wvl_range: 
Tuple[Optional[float], Optional[float]] = Field( + wvl_range: tuple[Optional[float], Optional[float]] = Field( (None, None), title="Wavelength range [wvl_min,wvl_max] for fitting", description="Truncate the wavelength, n and k data to the wavelength range '[wvl_min, " @@ -54,40 +51,36 @@ class DispersionFitter(Tidy3dBaseModel): units=MICROMETER, ) - @validator("wvl_um", always=True) - def _setup_wvl(cls, val): + @field_validator("wvl_um") + def _setup_wvl(val): """Convert wvl_um to a numpy array.""" if val.size == 0: raise ValidationError("Wavelength data cannot be empty.") return val - @validator("n_data", always=True) - @skip_if_fields_missing(["wvl_um"]) - def _ndata_length_match_wvl(cls, val, values): + @model_validator(mode="after") + def _ndata_length_match_wvl(self): """Validate n_data""" - - if val.shape != values["wvl_um"].shape: + if self.n_data.shape != self.wvl_um.shape: raise ValidationError("The length of 'n_data' doesn't match 'wvl_um'.") - return val + return self - @validator("k_data", always=True) - @skip_if_fields_missing(["wvl_um"]) - def _kdata_setup_and_length_match(cls, val, values): + @model_validator(mode="after") + def _kdata_setup_and_length_match(self): """Validate the length of k_data, or setup k if it's None.""" - - if val is None: - return np.zeros_like(values["wvl_um"]) - if val.shape != values["wvl_um"].shape: + if self.k_data is None: + self.k_data = np.zeros_like(self.wvl_um) + if self.k_data.shape != self.wvl_um.shape: raise ValidationError("The length of 'k_data' doesn't match 'wvl_um'.") - return val + return self @cached_property - def data_in_range(self) -> Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D]: + def data_in_range(self) -> tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D]: """Filter the wavelength-nk data to wavelength range for fitting. 
Returns ------- - Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] + tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] Filtered wvl_um, n_data, k_data """ @@ -129,12 +122,12 @@ def eps_data(self) -> complex: return AbstractMedium.nk_to_eps_complex(n=n_data, k=k_data) @property - def freqs(self) -> Tuple[float, ...]: + def freqs(self) -> tuple[float, ...]: """Convert filtered input wavelength data to frequency. Returns ------- - Tuple[float, ...] + tuple[float, ...] Frequency array converted from filtered input wavelength data """ @@ -142,12 +135,12 @@ def freqs(self) -> Tuple[float, ...]: return C_0 / wvl_um @property - def frequency_range(self) -> Tuple[float, float]: + def frequency_range(self) -> tuple[float, float]: """Frequency range of filtered input data Returns ------- - Tuple[float, float] + tuple[float, float] The minimal frequency and the maximal frequency """ @@ -164,7 +157,7 @@ def _unpack_coeffs(coeffs): Returns ------- - Tuple[np.ndarray[complex], np.ndarray[complex]] + tuple[np.ndarray[complex], np.ndarray[complex]] "a" and "c" poles for the PoleResidue model. """ if len(coeffs) % 4 != 0: @@ -209,7 +202,7 @@ def _coeffs_to_poles(coeffs): Returns ------- - List[Tuple[complex, complex]] + list[tuple[complex, complex]] List of complex poles (a, c) """ coeffs_scaled = coeffs / HBAR @@ -222,7 +215,7 @@ def _poles_to_coeffs(poles): Parameters ---------- - poles : List[Tuple[complex, complex]] + poles : list[tuple[complex, complex]] List of complex poles (a, c) Returns @@ -262,7 +255,7 @@ def fit( num_tries: int = 50, tolerance_rms: float = 1e-2, guess: PoleResidue = None, - ) -> Tuple[PoleResidue, float]: + ) -> tuple[PoleResidue, float]: """Fit data a number of times and returns best results. Parameters @@ -279,7 +272,7 @@ def fit( Returns ------- - Tuple[:class:`.PoleResidue`, float] + tuple[:class:`.PoleResidue`, float] Best results of multiple fits: (dispersive medium, RMS error). 
""" @@ -347,7 +340,7 @@ def _fit_single( self, num_poles: int = 3, guess: PoleResidue = None, - ) -> Tuple[PoleResidue, float]: + ) -> tuple[PoleResidue, float]: """Perform a single fit to the data and return optimization result. Parameters @@ -359,7 +352,7 @@ def _fit_single( Returns ------- - Tuple[:class:`.PoleResidue`, float] + tuple[:class:`.PoleResidue`, float] Results of single fit: (dispersive medium, RMS error). """ @@ -557,7 +550,7 @@ def plot( return ax @staticmethod - def _validate_url_load(data_load: List): + def _validate_url_load(data_load: list): """Validate if the loaded data from URL is valid The data list should be in this format: [["wl", "n"], @@ -753,7 +746,7 @@ def from_complex_permittivity( wvl_um: ArrayFloat1D, eps_real: ArrayFloat1D, eps_imag: ArrayFloat1D = None, - wvl_range: Tuple[Optional[float], Optional[float]] = (None, None), + wvl_range: tuple[Optional[float], Optional[float]] = (None, None), ) -> DispersionFitter: """Loads :class:`DispersionFitter` from wavelength and complex relative permittivity data @@ -765,7 +758,7 @@ def from_complex_permittivity( Real parts of relative permittivity data eps_imag : Optional[ArrayFloat1D] Imaginary parts of relative permittivity data; `None` for lossless medium. - wvg_range : Tuple[Optional[float], Optional[float]] + wvg_range : tuple[Optional[float], Optional[float]] Wavelength range [wvl_min,wvl_max] for fitting. Returns @@ -785,7 +778,7 @@ def from_loss_tangent( wvl_um: ArrayFloat1D, eps_real: ArrayFloat1D, loss_tangent: ArrayFloat1D, - wvl_range: Tuple[Optional[float], Optional[float]] = (None, None), + wvl_range: tuple[Optional[float], Optional[float]] = (None, None), ) -> DispersionFitter: """Loads :class:`DispersionFitter` from wavelength and loss tangent data. @@ -797,7 +790,7 @@ def from_loss_tangent( Real parts of relative permittivity data loss_tangent : Optional[ArrayFloat1D] Loss tangent data, defined as the ratio of imaginary and real parts of permittivity. 
- wvl_range : Tuple[Optional[float], Optional[float]] + wvl_range : tuple[Optional[float], Optional[float]] Wavelength range [wvl_min,wvl_max] for fitting. Returns diff --git a/tidy3d/plugins/dispersion/fit_fast.py b/tidy3d/plugins/dispersion/fit_fast.py index 10564093ea..b983454e62 100644 --- a/tidy3d/plugins/dispersion/fit_fast.py +++ b/tidy3d/plugins/dispersion/fit_fast.py @@ -2,14 +2,13 @@ from __future__ import annotations -from typing import Tuple - import numpy as np -from pydantic.v1 import NonNegativeFloat, PositiveInt +from pydantic import NonNegativeFloat, PositiveInt + +from tidy3d.components.dispersion_fitter import AdvancedFastFitterParam, fit +from tidy3d.components.medium import PoleResidue +from tidy3d.constants import C_0, HBAR -from ...components.dispersion_fitter import AdvancedFastFitterParam, fit -from ...components.medium import PoleResidue -from ...constants import C_0, HBAR from .fit import DispersionFitter # numerical tolerance for pole relocation for fast fitter @@ -44,7 +43,7 @@ def fit( eps_inf: float = None, tolerance_rms: NonNegativeFloat = DEFAULT_TOLERANCE_RMS, advanced_param: AdvancedFastFitterParam = None, - ) -> Tuple[PoleResidue, float]: + ) -> tuple[PoleResidue, float]: """Fit data using a fast fitting algorithm. Note @@ -87,7 +86,7 @@ def fit( Returns ------- - Tuple[:class:`.PoleResidue`, float] + tuple[:class:`.PoleResidue`, float] Best fitting result: (dispersive medium, weighted RMS error). """ @@ -117,7 +116,7 @@ def constant_loss_tangent_model( cls, eps_real: float, loss_tangent: float, - frequency_range: Tuple[float, float], + frequency_range: tuple[float, float], max_num_poles: PositiveInt = DEFAULT_MAX_POLES, number_sampling_frequency: PositiveInt = 10, tolerance_rms: NonNegativeFloat = DEFAULT_TOLERANCE_RMS, @@ -130,7 +129,7 @@ def constant_loss_tangent_model( Real part of permittivity loss_tangent : float Loss tangent. 
- frequency_range : Tuple[float, float] + frequency_range : tuple[float, float] Freqquency range for the material to exhibit constant loss tangent response. max_num_poles : PositiveInt, optional Maximum number of poles in the model. diff --git a/tidy3d/plugins/dispersion/fit_web.py b/tidy3d/plugins/dispersion/fit_web.py index 4eac610734..63bfa9c864 100644 --- a/tidy3d/plugins/dispersion/fit_web.py +++ b/tidy3d/plugins/dispersion/fit_web.py @@ -1,6 +1,6 @@ """Deprecated module""" -from ...log import log +from tidy3d.log import log log.warning( "The module 'plugins.dispersion.fit_web' has been deprecated in favor of " diff --git a/tidy3d/plugins/dispersion/web.py b/tidy3d/plugins/dispersion/web.py index 275ff076e7..a5637a64ca 100644 --- a/tidy3d/plugins/dispersion/web.py +++ b/tidy3d/plugins/dispersion/web.py @@ -4,21 +4,20 @@ import ssl from enum import Enum -from typing import Optional, Tuple +from typing import Optional -import pydantic.v1 as pydantic import requests -from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator - +from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt, model_validator + +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.medium import PoleResidue +from tidy3d.components.types import Literal +from tidy3d.constants import HERTZ, MICROMETER +from tidy3d.exceptions import SetupError, Tidy3dError, WebError +from tidy3d.log import log from tidy3d.web.core.environment import Env from tidy3d.web.core.http_util import get_headers -from ...components.base import Tidy3dBaseModel, skip_if_fields_missing -from ...components.medium import PoleResidue -from ...components.types import Literal -from ...constants import HERTZ, MICROMETER -from ...exceptions import SetupError, Tidy3dError, WebError -from ...log import log from .fit import DispersionFitter BOUND_MAX_FACTOR = 10 @@ -40,7 +39,7 @@ class ExceptionCodes(Enum): class AdvancedFitterParam(Tidy3dBaseModel): """Advanced 
fitter parameters""" - bound_amp: NonNegativeFloat = Field( + bound_amp: Optional[NonNegativeFloat] = Field( None, title="Upper bound of oscillator strength", description="Upper bound of real and imagniary part of oscillator " @@ -48,7 +47,7 @@ class AdvancedFitterParam(Tidy3dBaseModel): "automatic setup based on the frequency range of interest).", units=HERTZ, ) - bound_f: NonNegativeFloat = Field( + bound_f: Optional[NonNegativeFloat] = Field( None, title="Upper bound of pole frequency", description="Upper bound of real and imaginary part of ``a`` that corresponds to pole " @@ -96,39 +95,38 @@ class AdvancedFitterParam(Tidy3dBaseModel): lt=2**32, ) - @validator("bound_f_lower", always=True) - @skip_if_fields_missing(["bound_f"]) - def _validate_lower_frequency_bound(cls, val, values): + @model_validator(mode="after") + def _validate_lower_frequency_bound(self): """bound_f_lower cannot be larger than bound_f.""" - if values["bound_f"] is not None and val > values["bound_f"]: + if self.bound_f is not None and self.bound_f_lower > self.bound_f: raise SetupError( "The upper bound 'bound_f' cannot be smaller " "than the lower bound 'bound_f_lower'." ) - return val + return self class FitterData(AdvancedFitterParam): """Data class for request body of Fitter where dipsersion data is input through tuple.""" - wvl_um: Tuple[float, ...] = Field( - ..., + wvl_um: tuple[float, ...] = Field( title="Wavelengths", description="A set of wavelengths for dispersion data.", units=MICROMETER, ) - n_data: Tuple[float, ...] = Field( - ..., + n_data: tuple[float, ...] = Field( title="Index of refraction", description="Real part of the complex index of refraction at each wavelength.", ) - k_data: Tuple[float, ...] 
= Field( + k_data: Optional[tuple[float, ...]] = Field( None, title="Extinction coefficient", description="Imaginary part of the complex index of refraction at each wavelength.", ) num_poles: PositiveInt = Field( - 1, title="Number of poles", description="Number of poles in model." + 1, + title="Number of poles", + description="Number of poles in model.", ) num_tries: PositiveInt = Field( 50, @@ -255,12 +253,12 @@ def _setup_server(url_server: str): return get_headers() - def run(self) -> Tuple[PoleResidue, float]: + def run(self) -> tuple[PoleResidue, float]: """Execute the data fit using the stable fitter in the server. Returns ------- - Tuple[:class:`.PoleResidue`, float] + tuple[:class:`.PoleResidue`, float] Best results of multiple fits: (dispersive medium, RMS error). """ @@ -317,7 +315,7 @@ def run( num_tries: PositiveInt = 50, tolerance_rms: NonNegativeFloat = 1e-2, advanced_param: AdvancedFitterParam = AdvancedFitterParam(), -) -> Tuple[PoleResidue, float]: +) -> tuple[PoleResidue, float]: """Execute the data fit using the stable fitter in the server. Parameters @@ -335,7 +333,7 @@ def run( Returns ------- - Tuple[:class:`.PoleResidue`, float] + tuple[:class:`.PoleResidue`, float] Best results of multiple fits: (dispersive medium, RMS error). """ task = FitterData.create(fitter, num_poles, num_tries, tolerance_rms, advanced_param) @@ -345,13 +343,13 @@ def run( class StableDispersionFitter(DispersionFitter): """Deprecated.""" - @pydantic.root_validator() - def _deprecate_stable_fitter(cls, values): + @model_validator(mode="before") + def _deprecate_stable_fitter(data): log.warning( "'StableDispersionFitter' has been deprecated. Use 'DispersionFitter' with " "'tidy3d.plugins.dispersion.web.run' to access the stable fitter from the web server." 
) - return values + return data def fit( self, @@ -360,6 +358,6 @@ def fit( tolerance_rms: NonNegativeFloat = 1e-2, guess: PoleResidue = None, advanced_param: AdvancedFitterParam = AdvancedFitterParam(), - ) -> Tuple[PoleResidue, float]: + ) -> tuple[PoleResidue, float]: """Deprecated.""" return run(self, num_poles, num_tries, tolerance_rms, advanced_param) diff --git a/tidy3d/plugins/expressions/__init__.py b/tidy3d/plugins/expressions/__init__.py index 1f8f733f3a..3805669331 100644 --- a/tidy3d/plugins/expressions/__init__.py +++ b/tidy3d/plugins/expressions/__init__.py @@ -1,22 +1,44 @@ from .base import Expression from .functions import Cos, Exp, Log, Log10, Sin, Sqrt, Tan from .metrics import ModeAmp, ModePower, generate_validation_data +from .operators import ( + Abs, + Add, + Divide, + FloorDivide, + MatMul, + Modulus, + Multiply, + Negate, + Power, + Subtract, +) from .variables import Constant, Variable __all__ = [ - "Expression", + "Abs", + "Add", "Constant", - "Variable", - "ModeAmp", - "ModePower", - "generate_validation_data", - "Sin", "Cos", - "Tan", + "Divide", "Exp", + "Expression", + "FloorDivide", "Log", "Log10", + "MatMul", + "ModeAmp", + "ModePower", + "Modulus", + "Multiply", + "Negate", + "Power", + "Sin", "Sqrt", + "Subtract", + "Tan", + "Variable", + "generate_validation_data", ] # The following code dynamically collects all classes that are subclasses of Expression @@ -41,4 +63,4 @@ _local_vars[name] = obj for cls in _model_classes: - cls.update_forward_refs(**_local_vars) + cls.model_rebuild(force=True) diff --git a/tidy3d/plugins/expressions/base.py b/tidy3d/plugins/expressions/base.py index f24cd57cc3..5c8425981e 100644 --- a/tidy3d/plugins/expressions/base.py +++ b/tidy3d/plugins/expressions/base.py @@ -33,9 +33,6 @@ class Expression(Tidy3dBaseModel, ABC): It provides common functionality and operator overloading for derived classes. 
""" - class Config: - smart_union = True - @abstractmethod def evaluate(self, *args: Any, **kwargs: Any) -> NumberType: pass @@ -45,7 +42,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> NumberType: def __init_subclass__(cls, **kwargs: dict[str, Any]) -> None: super().__init_subclass__(**kwargs) - type_value = cls.__fields__.get(TYPE_TAG_STR) + type_value = cls.model_fields.get(TYPE_TAG_STR) if type_value and type_value.default: TYPE_TO_CLASS_MAP[type_value.default] = cls @@ -88,8 +85,8 @@ def _find_instances(expr: Expression): yield value else: yield expr - for field in expr.__fields__.values(): - value = getattr(expr, field.name) + for name in expr.model_fields: + value = getattr(expr, name) if isinstance(value, Expression): yield from _find_instances(value) elif isinstance(value, list): diff --git a/tidy3d/plugins/expressions/functions.py b/tidy3d/plugins/expressions/functions.py index 5bcfcf5c1a..dcbcb467b5 100644 --- a/tidy3d/plugins/expressions/functions.py +++ b/tidy3d/plugins/expressions/functions.py @@ -1,7 +1,7 @@ from typing import Any import autograd.numpy as anp -import pydantic.v1 as pd +from pydantic import Field, field_validator from .base import Expression from .types import NumberOrExpression, NumberType @@ -12,15 +12,15 @@ class Function(Expression): Base class for mathematical functions in expressions. """ - operand: NumberOrExpression = pd.Field( - ..., + operand: NumberOrExpression = Field( title="Operand", description="The operand for the function.", ) _format: str = "{func}({operand})" - @pd.validator("operand", pre=True, always=True) + @field_validator("operand") + @classmethod def validate_operand(cls, v): """ Validate and convert operand to an expression. 
diff --git a/tidy3d/plugins/expressions/metrics.py b/tidy3d/plugins/expressions/metrics.py index 04aa79c405..7086421c74 100644 --- a/tidy3d/plugins/expressions/metrics.py +++ b/tidy3d/plugins/expressions/metrics.py @@ -2,8 +2,8 @@ from typing import Any, Optional, Union import autograd.numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field, NonNegativeInt from tidy3d.components.monitor import ModeMonitor from tidy3d.components.types import Direction, FreqArray @@ -62,23 +62,22 @@ class ModeAmp(Metric): (abs(ModeAmp("monitor1")) ** 2) """ - monitor_name: str = pd.Field( - ..., + monitor_name: str = Field( title="Monitor Name", description="The name of the mode monitor. This needs to match the name of the monitor in the simulation.", ) - f: Optional[Union[float, FreqArray]] = pd.Field( # type: ignore + f: Optional[Union[float, FreqArray]] = Field( None, title="Frequency Array", description="The frequency array. If None, all frequencies in the monitor will be used.", alias="freqs", ) - direction: Direction = pd.Field( + direction: Direction = Field( "+", title="Direction", description="The direction of propagation of the mode.", ) - mode_index: pd.NonNegativeInt = pd.Field( + mode_index: NonNegativeInt = Field( 0, title="Mode Index", description="The index of the mode.", diff --git a/tidy3d/plugins/expressions/operators.py b/tidy3d/plugins/expressions/operators.py index e25004180e..62b74225f8 100644 --- a/tidy3d/plugins/expressions/operators.py +++ b/tidy3d/plugins/expressions/operators.py @@ -2,7 +2,7 @@ from typing import Any -import pydantic.v1 as pd +from pydantic import Field, field_validator from .base import Expression from .types import NumberOrExpression, NumberType @@ -16,8 +16,7 @@ class UnaryOperator(Expression): Subclasses should implement the evaluate method to define the specific operation. 
""" - operand: NumberOrExpression = pd.Field( - ..., + operand: NumberOrExpression = Field( title="Operand", description="The operand for the unary operator.", ) @@ -25,7 +24,8 @@ class UnaryOperator(Expression): _symbol: str _format: str = "({symbol}{operand})" - @pd.validator("operand", pre=True, always=True) + @field_validator("operand") + @classmethod def validate_operand(cls, v): return cls._to_expression(v) @@ -41,13 +41,11 @@ class BinaryOperator(Expression): Subclasses should implement the evaluate method to define the specific operation. """ - left: NumberOrExpression = pd.Field( - ..., + left: NumberOrExpression = Field( title="Left", description="The left operand for the binary operator.", ) - right: NumberOrExpression = pd.Field( - ..., + right: NumberOrExpression = Field( title="Right", description="The right operand for the binary operator.", ) @@ -55,7 +53,8 @@ class BinaryOperator(Expression): _symbol: str _format: str = "({left} {symbol} {right})" - @pd.validator("left", "right", pre=True, always=True) + @field_validator("left", "right") + @classmethod def validate_operands(cls, v): return cls._to_expression(v) diff --git a/tidy3d/plugins/expressions/types.py b/tidy3d/plugins/expressions/types.py index 861e86353b..5aaa068604 100644 --- a/tidy3d/plugins/expressions/types.py +++ b/tidy3d/plugins/expressions/types.py @@ -1,8 +1,6 @@ -from typing import TYPE_CHECKING, Annotated, Union +from typing import TYPE_CHECKING, Union -from pydantic.v1 import Field - -from tidy3d.components.types import TYPE_TAG_STR, ArrayLike, Complex +from tidy3d.components.types import ArrayLike, Complex, discriminated_union if TYPE_CHECKING: from .functions import Cos, Exp, Log, Log10, Sin, Sqrt, Tan @@ -23,7 +21,7 @@ NumberType = Union[int, float, Complex, ArrayLike] -OperatorType = Annotated[ +OperatorType = discriminated_union( Union[ "Add", "Subtract", @@ -35,11 +33,10 @@ "MatMul", "Negate", "Abs", - ], - Field(discriminator=TYPE_TAG_STR), -] + ] +) -FunctionType = 
Annotated[ +FunctionType = discriminated_union( Union[ "Sin", "Cos", @@ -48,19 +45,17 @@ "Log", "Log10", "Sqrt", - ], - Field(discriminator=TYPE_TAG_STR), -] + ] +) -MetricType = Annotated[ +MetricType = discriminated_union( Union[ "Constant", "Variable", "ModeAmp", "ModePower", - ], - Field(discriminator=TYPE_TAG_STR), -] + ] +) ExpressionType = Union[ OperatorType, @@ -68,4 +63,4 @@ MetricType, ] -NumberOrExpression = Union[NumberType, ExpressionType] +NumberOrExpression = Union[ExpressionType, NumberType] diff --git a/tidy3d/plugins/expressions/variables.py b/tidy3d/plugins/expressions/variables.py index ae32591c1e..2d99318bfb 100644 --- a/tidy3d/plugins/expressions/variables.py +++ b/tidy3d/plugins/expressions/variables.py @@ -1,6 +1,6 @@ from typing import Any, Optional -import pydantic.v1 as pd +from pydantic import Field from .base import Expression from .types import NumberType @@ -37,7 +37,7 @@ class Variable(Expression): 10 """ - name: Optional[str] = pd.Field( + name: Optional[str] = Field( None, title="Name", description="The name of the variable used for lookup during evaluation.", @@ -80,8 +80,7 @@ class Constant(Variable): 5 """ - value: NumberType = pd.Field( - ..., + value: NumberType = Field( title="Value", description="The fixed value of the constant.", ) diff --git a/tidy3d/plugins/invdes/base.py b/tidy3d/plugins/invdes/base.py index bde96efcf0..41638225f0 100644 --- a/tidy3d/plugins/invdes/base.py +++ b/tidy3d/plugins/invdes/base.py @@ -1,10 +1,8 @@ # base class for all of the invdes fields -from __future__ import annotations +from abc import ABC -import abc +from tidy3d.components.base import Tidy3dBaseModel -import tidy3d as td - -class InvdesBaseModel(td.components.base.Tidy3dBaseModel, abc.ABC): +class InvdesBaseModel(Tidy3dBaseModel, ABC): """Base class for ``invdes`` components, in case we need it.""" diff --git a/tidy3d/plugins/invdes/design.py b/tidy3d/plugins/invdes/design.py index f8a4487b07..12fd632352 100644 --- 
a/tidy3d/plugins/invdes/design.py +++ b/tidy3d/plugins/invdes/design.py @@ -3,11 +3,11 @@ from __future__ import annotations import abc -import typing +from typing import Callable, Optional, Union import autograd.numpy as anp import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator import tidy3d as td from tidy3d.components.autograd import get_static @@ -19,39 +19,37 @@ from .region import DesignRegionType from .validators import check_pixel_size -PostProcessFnType = typing.Callable[[td.SimulationData], float] +PostProcessFnType = Callable[[td.SimulationData], float] class AbstractInverseDesign(InvdesBaseModel, abc.ABC): """Container for an inverse design problem.""" - design_region: DesignRegionType = pd.Field( - ..., + design_region: DesignRegionType = Field( title="Design Region", description="Region within which we will optimize the simulation.", ) - task_name: str = pd.Field( - ..., + task_name: str = Field( title="Task Name", description="Task name to use in the objective function when running the ``JaxSimulation``.", ) - verbose: bool = pd.Field( + verbose: bool = Field( False, title="Task Verbosity", description="If ``True``, will print the regular output from ``web`` functions.", ) - metric: typing.Optional[ExpressionType] = pd.Field( + metric: Optional[ExpressionType] = Field( None, title="Objective Metric", description="Serializable expression defining the objective function.", ) def make_objective_fn( - self, post_process_fn: typing.Optional[typing.Callable] = None, maximize: bool = True - ) -> typing.Callable[[anp.ndarray], tuple[float, dict]]: + self, post_process_fn: Optional[Callable] = None, maximize: bool = True + ) -> Callable[[anp.ndarray], tuple[float, dict]]: """Construct the objective function for this InverseDesign object.""" if (post_process_fn is None) and (self.metric is None): @@ -118,13 +116,12 @@ def run_async(self, simulations, **kwargs) -> web.BatchData: # noqa: F821 class 
InverseDesign(AbstractInverseDesign): """Container for an inverse design problem.""" - simulation: td.Simulation = pd.Field( - ..., + simulation: td.Simulation = Field( title="Base Simulation", description="Simulation without the design regions or monitors used in the objective fn.", ) - output_monitor_names: typing.Tuple[str, ...] = pd.Field( + output_monitor_names: Optional[tuple[str, ...]] = Field( None, title="Output Monitor Names", description="Optional names of monitors whose data the differentiable output depends on." @@ -136,23 +133,16 @@ class InverseDesign(AbstractInverseDesign): _check_sim_pixel_size = check_pixel_size("simulation") - @pd.root_validator(pre=False) - def _validate_model(cls, values: dict) -> dict: - cls._validate_metric(values) - return values - - @staticmethod - def _validate_metric(values: dict) -> dict: - metric_expr = values.get("metric") - if not metric_expr: - return values - simulation = values.get("simulation") - for metric in metric_expr.filter(Metric): - InverseDesign._validate_metric_monitor_name(metric, simulation) - InverseDesign._validate_metric_mode_index(metric, simulation) - InverseDesign._validate_metric_f(metric, simulation) - InverseDesign._validate_metric_data(metric_expr, simulation) - return values + @model_validator(mode="after") + def _validate_model(self): + if not self.metric: + return self + for metric in self.metric.filter(Metric): + InverseDesign._validate_metric_monitor_name(metric, self.simulation) + InverseDesign._validate_metric_mode_index(metric, self.simulation) + InverseDesign._validate_metric_f(metric, self.simulation) + InverseDesign._validate_metric_data(metric, self.simulation) + return self @staticmethod def _validate_metric_monitor_name(metric: Metric, simulation: td.Simulation) -> None: @@ -221,7 +211,7 @@ def is_output_monitor(self, monitor: td.Monitor) -> bool: return monitor.name in self.output_monitor_names - def separate_output_monitors(self, monitors: typing.Tuple[td.Monitor]) -> dict: + 
def separate_output_monitors(self, monitors: tuple[td.Monitor]) -> dict: """Separate monitors into output_monitors and regular monitors.""" monitor_fields = dict(monitors=[], output_monitors=[]) @@ -260,13 +250,12 @@ def to_simulation_data(self, params: anp.ndarray, **kwargs) -> td.SimulationData class InverseDesignMulti(AbstractInverseDesign): """``InverseDesign`` with multiple simulations and corresponding postprocess functions.""" - simulations: typing.Tuple[td.Simulation, ...] = pd.Field( - ..., + simulations: tuple[td.Simulation, ...] = Field( title="Base Simulations", description="Set of simulation without the design regions or monitors used in the objective fn.", ) - output_monitor_names: typing.Tuple[typing.Union[typing.Tuple[str, ...], None], ...] = pd.Field( + output_monitor_names: Optional[tuple[Union[tuple[str, ...], None], ...]] = Field( None, title="Output Monitor Names", description="Optional names of monitors whose data the differentiable output depends on." @@ -278,12 +267,12 @@ class InverseDesignMulti(AbstractInverseDesign): _check_sim_pixel_size = check_pixel_size("simulations") - @pd.root_validator() - def _check_lengths(cls, values): + @model_validator(mode="after") + def _check_lengths(self): """Check the lengths of all of the multi fields.""" keys = ("simulations", "post_process_fns", "output_monitor_names", "override_structure_dl") - multi_dict = {key: values.get(key) for key in keys} + multi_dict = {key: getattr(self, key) for key in keys} sizes = {key: len(val) for key, val in multi_dict.items() if val is not None} if len(set(sizes.values())) != 1: @@ -293,7 +282,7 @@ def _check_lengths(cls, values): "corresponding sizes of '{sizes}'." 
) - return values + return self @property def task_names(self) -> list[str]: @@ -301,7 +290,7 @@ def task_names(self) -> list[str]: return [f"{self.task_name}_{i}" for i in range(len(self.simulations))] @property - def designs(self) -> typing.List[InverseDesign]: + def designs(self) -> list[InverseDesign]: """List of individual ``InverseDesign`` objects corresponding to this instance.""" designs_list = [] @@ -330,4 +319,4 @@ def to_simulation_data(self, params: anp.ndarray, **kwargs) -> web.BatchData: # return self.run_async(simulations, **kwargs) -InverseDesignType = typing.Union[InverseDesign, InverseDesignMulti] +InverseDesignType = Union[InverseDesign, InverseDesignMulti] diff --git a/tidy3d/plugins/invdes/initialization.py b/tidy3d/plugins/invdes/initialization.py index eb75a940c4..9b3873dfd0 100644 --- a/tidy3d/plugins/invdes/initialization.py +++ b/tidy3d/plugins/invdes/initialization.py @@ -6,8 +6,8 @@ from typing import Optional, Union import numpy as np -import pydantic.v1 as pd from numpy.typing import NDArray +from pydantic import Field, NonNegativeInt, field_validator, model_validator import tidy3d as td from tidy3d.components.base import Tidy3dBaseModel @@ -30,34 +30,33 @@ class RandomInitializationSpec(AbstractInitializationSpec): When a seed is provided, a call to `create_parameters` will always return the same array. """ - min_value: float = pd.Field( + min_value: float = Field( 0.0, ge=0.0, le=1.0, title="Minimum Value", description="Minimum value for the random parameters (inclusive).", ) - max_value: float = pd.Field( + max_value: float = Field( 1.0, ge=0.0, le=1.0, title="Maximum Value", description="Maximum value for the random parameters (exclusive).", ) - seed: Optional[pd.NonNegativeInt] = pd.Field( - None, description="Seed for the random number generator." 
+ seed: Optional[NonNegativeInt] = Field( + None, + description="Seed for the random number generator.", ) - @pd.root_validator(pre=False) - def _validate_max_ge_min(cls, values): + @model_validator(mode="after") + def _validate_max_ge_min(self): """Ensure that max_value is greater than or equal to min_value.""" - minval = values.get("min_value") - maxval = values.get("max_value") - if minval > maxval: + if self.minval > self.maxval: raise ValidationError( - f"'max_value' ({maxval}) must be greater or equal than 'min_value' ({minval})" + f"'max_value' ({self.maxval}) must be greater or equal than 'min_value' ({self.minval})" ) - return values + return self def create_parameters(self, shape: tuple[int, ...]) -> NDArray: """Generate the parameter array based on the specification.""" @@ -68,7 +67,7 @@ def create_parameters(self, shape: tuple[int, ...]) -> NDArray: class UniformInitializationSpec(AbstractInitializationSpec): """Specification for uniform initial parameters.""" - value: float = pd.Field( + value: float = Field( 0.5, ge=0.0, le=1.0, @@ -84,38 +83,38 @@ def create_parameters(self, shape: tuple[int, ...]) -> NDArray: class CustomInitializationSpec(AbstractInitializationSpec): """Specification for custom initial parameters provided by the user.""" - params: ArrayLike = pd.Field( + params: ArrayLike = Field( ..., title="Parameters", description="Custom parameters provided by the user.", ) - @pd.validator("params") - def _validate_params_range(cls, value, values): + @field_validator("params") + def _validate_params_range(val): """Ensure that all parameter values are between 0 and 1.""" - if np.any((value < 0) | (value > 1)): + if np.any((val < 0) | (val > 1)): raise ValidationError("'params' need to be between 0 and 1.") - return value + return val - @pd.validator("params") - def _validate_params_dtype(cls, value, values): + @field_validator("params") + def _validate_params_dtype(val): """Ensure that params is real-valued.""" - if np.issubdtype(value.dtype, 
np.bool_): + if np.issubdtype(val.dtype, np.bool_): td.log.warning( "Got a boolean array for 'params'. " "This will be treated as a floating point array." ) - value = value.astype(float) - elif not np.issubdtype(value.dtype, np.floating): - raise ValidationError(f"'params' need to be real-valued, but got '{value.dtype}'.") - return value + val = val.astype(float) + elif not np.issubdtype(val.dtype, np.floating): + raise ValidationError(f"'params' need to be real-valued, but got '{val.dtype}'.") + return val - @pd.validator("params") - def _validate_params_3d(cls, value, values): + @field_validator("params") + def _validate_params_3d(val): """Ensure that params is a 3D array.""" - if value.ndim != 3: - raise ValidationError(f"'params' must be 3D, but got {value.ndim}D.") - return value + if val.ndim != 3: + raise ValidationError(f"'params' must be 3D, but got {val.ndim}D.") + return val def create_parameters(self, shape: tuple[int, ...]) -> NDArray: """Return the custom parameters provided by the user.""" diff --git a/tidy3d/plugins/invdes/optimizer.py b/tidy3d/plugins/invdes/optimizer.py index 08758f12b6..9aa5616ea4 100644 --- a/tidy3d/plugins/invdes/optimizer.py +++ b/tidy3d/plugins/invdes/optimizer.py @@ -1,13 +1,13 @@ # specification for running the optimizer import abc -import typing from copy import deepcopy +from typing import Callable, Optional import autograd as ag import autograd.numpy as anp import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, PositiveInt import tidy3d as td from tidy3d.components.types import TYPE_TAG_STR @@ -20,32 +20,29 @@ class AbstractOptimizer(InvdesBaseModel, abc.ABC): """Specification for an optimization.""" - design: InverseDesignType = pd.Field( - ..., + design: InverseDesignType = Field( title="Inverse Design Specification", description="Specification describing the inverse design problem we wish to optimize.", discriminator=TYPE_TAG_STR, ) - learning_rate: pd.PositiveFloat = pd.Field( - ..., 
+ learning_rate: PositiveFloat = Field( title="Learning Rate", description="Step size for the gradient descent optimizer.", ) - maximize: bool = pd.Field( + maximize: bool = Field( True, title="Direction of Optimization", description="If ``True``, the optimizer will maximize the objective function. If ``False``, the optimizer will minimize the objective function.", ) - num_steps: pd.PositiveInt = pd.Field( - ..., + num_steps: PositiveInt = Field( title="Number of Steps", description="Number of steps in the gradient descent optimizer.", ) - results_cache_fname: str = pd.Field( + results_cache_fname: Optional[str] = Field( None, title="History Storage File", description="If specified, will save the optimization state to a local ``.pkl`` file " @@ -57,7 +54,7 @@ class AbstractOptimizer(InvdesBaseModel, abc.ABC): "``optimizer.continue_run(result)``. ", ) - store_full_results: bool = pd.Field( + store_full_results: bool = Field( True, title="Store Full Results", description="If ``True``, stores the full history for the vector fields, specifically " @@ -83,9 +80,7 @@ def display_fn(self, result: InverseDesignResult, step_index: int) -> None: print(f"\tpost_process_val = {result.post_process_val[-1]:.3e}") print(f"\tpenalty = {result.penalty[-1]:.3e}") - def initialize_result( - self, params0: typing.Optional[anp.ndarray] = None - ) -> InverseDesignResult: + def initialize_result(self, params0: Optional[anp.ndarray] = None) -> InverseDesignResult: """ Create an initially empty `InverseDesignResult` from the starting parameters. @@ -110,8 +105,8 @@ def initialize_result( def run( self, - post_process_fn: typing.Optional[typing.Callable] = None, - callback: typing.Optional[typing.Callable] = None, + post_process_fn: Optional[Callable] = None, + callback: Optional[Callable] = None, params0: anp.ndarray = None, ) -> InverseDesignResult: """Run this inverse design problem from an optional initial set of parameters. 
@@ -140,8 +135,8 @@ def continue_run( self, result: InverseDesignResult, num_steps: int = None, - post_process_fn: typing.Optional[typing.Callable] = None, - callback: typing.Optional[typing.Callable] = None, + post_process_fn: Optional[Callable] = None, + callback: Optional[Callable] = None, ) -> InverseDesignResult: """Run optimizer for a series of steps with an initialized state. @@ -230,8 +225,8 @@ def continue_run_from_file( self, fname: str, num_steps: int = None, - post_process_fn: typing.Optional[typing.Callable] = None, - callback: typing.Optional[typing.Callable] = None, + post_process_fn: Optional[Callable] = None, + callback: Optional[Callable] = None, ) -> InverseDesignResult: """Continue the optimization run from a ``.pkl`` file with an ``InverseDesignResult``.""" result = InverseDesignResult.from_file(fname) @@ -245,8 +240,8 @@ def continue_run_from_file( def continue_run_from_history( self, num_steps: int = None, - post_process_fn: typing.Optional[typing.Callable] = None, - callback: typing.Optional[typing.Callable] = None, + post_process_fn: Optional[Callable] = None, + callback: Optional[Callable] = None, ) -> InverseDesignResult: """Continue the optimization run from a ``.pkl`` file with an ``InverseDesignResult``.""" return self.continue_run_from_file( @@ -260,7 +255,7 @@ def continue_run_from_history( class AdamOptimizer(AbstractOptimizer): """Specification for an optimization.""" - beta1: float = pd.Field( + beta1: float = Field( 0.9, ge=0.0, le=1.0, @@ -268,7 +263,7 @@ class AdamOptimizer(AbstractOptimizer): description="Beta 1 parameter in the Adam optimization method.", ) - beta2: float = pd.Field( + beta2: float = Field( 0.999, ge=0.0, le=1.0, @@ -276,7 +271,7 @@ class AdamOptimizer(AbstractOptimizer): description="Beta 2 parameter in the Adam optimization method.", ) - eps: pd.PositiveFloat = pd.Field( + eps: PositiveFloat = Field( 1e-8, title="Epsilon", description="Epsilon parameter in the Adam optimization method.", diff --git 
a/tidy3d/plugins/invdes/penalty.py b/tidy3d/plugins/invdes/penalty.py index 621cbfb9e4..1a9682ab83 100644 --- a/tidy3d/plugins/invdes/penalty.py +++ b/tidy3d/plugins/invdes/penalty.py @@ -1,10 +1,10 @@ # define penalties applied to parameters from design region import abc -import typing +from typing import Union import autograd.numpy as anp -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.constants import MICROMETER from tidy3d.plugins.autograd.invdes import make_erosion_dilation_penalty @@ -15,7 +15,7 @@ class AbstractPenalty(InvdesBaseModel, abc.ABC): """Base class for penalties added to ``invdes.DesignRegion`` objects.""" - weight: pd.NonNegativeFloat = pd.Field( + weight: NonNegativeFloat = Field( 1.0, title="Weight", description="When this penalty is evaluated, it will be weighted by this " @@ -48,8 +48,7 @@ class ErosionDilationPenalty(AbstractPenalty): """ - length_scale: pd.PositiveFloat = pd.Field( - ..., + length_scale: PositiveFloat = Field( title="Length Scale", description="Length scale of erosion and dilation. " "Corresponds to ``radius`` in the :class:`ConicFilter` used for filtering. 
" @@ -58,7 +57,7 @@ class ErosionDilationPenalty(AbstractPenalty): units=MICROMETER, ) - beta: float = pd.Field( + beta: float = Field( 100.0, ge=1.0, title="Projection Beta", @@ -67,7 +66,7 @@ class ErosionDilationPenalty(AbstractPenalty): "Higher values correspond to stronger discretization.", ) - eta0: float = pd.Field( + eta0: float = Field( 0.5, ge=0.0, le=1.0, @@ -77,7 +76,7 @@ class ErosionDilationPenalty(AbstractPenalty): "Corresponds to ``eta`` in the :class:`BinaryProjector`.", ) - delta_eta: float = pd.Field( + delta_eta: float = Field( 0.01, ge=0.0, le=1.0, @@ -97,4 +96,4 @@ def evaluate(self, x: anp.ndarray, pixel_size: float) -> float: return self.weight * penalty_unweighted -PenaltyType = typing.Union[ErosionDilationPenalty] +PenaltyType = Union[ErosionDilationPenalty] diff --git a/tidy3d/plugins/invdes/region.py b/tidy3d/plugins/invdes/region.py index b94eaf8682..4ad715a4ab 100644 --- a/tidy3d/plugins/invdes/region.py +++ b/tidy3d/plugins/invdes/region.py @@ -1,13 +1,13 @@ # container for specification fully defining the inverse design problem import abc -import typing import warnings +from typing import Literal, Optional, Union import autograd.numpy as anp import numpy as np -import pydantic.v1 as pd from autograd import elementwise_grad, grad +from pydantic import Field, PositiveFloat import tidy3d as td from tidy3d.components.types import TYPE_TAG_STR, Coordinate, Size @@ -24,28 +24,25 @@ class DesignRegion(InvdesBaseModel, abc.ABC): """Base class for design regions in the ``invdes`` plugin.""" - size: Size = pd.Field( - ..., + size: Size = Field( title="Size", description="Size in x, y, and z directions.", units=td.constants.MICROMETER, ) - center: Coordinate = pd.Field( - ..., + center: Coordinate = Field( title="Center", description="Center of object in x, y, and z.", units=td.constants.MICROMETER, ) - eps_bounds: typing.Tuple[float, float] = pd.Field( - ..., + eps_bounds: tuple[float, float] = Field( ge=1.0, title="Relative Permittivity 
Bounds", description="Minimum and maximum relative permittivity expressed to the design region.", ) - transformations: typing.Tuple[TransformationType, ...] = pd.Field( + transformations: tuple[TransformationType, ...] = Field( (), title="Transformations", description="Transformations that get applied from first to last on the parameter array." @@ -55,7 +52,7 @@ class DesignRegion(InvdesBaseModel, abc.ABC): "Specific permittivity values given the density array are determined by ``eps_bounds``.", ) - penalties: typing.Tuple[PenaltyType, ...] = pd.Field( + penalties: tuple[PenaltyType, ...] = Field( (), title="Penalties", description="Set of penalties that get evaluated on the material density. Note that the " @@ -63,18 +60,21 @@ class DesignRegion(InvdesBaseModel, abc.ABC): "inside of the penalties directly through the ``.weight`` field.", ) - initialization_spec: InitializationSpecType = pd.Field( - UniformInitializationSpec(value=0.5), + initialization_spec: InitializationSpecType = Field( + default_factory=lambda: UniformInitializationSpec(value=0.5), title="Initialization Specification", description="Specification of how to initialize the parameters in the design region.", discriminator=TYPE_TAG_STR, ) + @property def _post_init_validators(self): - """Automatically call any `_validate_XXX` method.""" + """Return any `_validate_XXX` method.""" + validators = [] for attr_name in dir(self): if attr_name.startswith("_validate") and callable(getattr(self, attr_name)): - getattr(self, attr_name)() + validators.append(getattr(self, attr_name)) + return tuple(validators) def _validate_eps_bounds(self): if self.eps_bounds[1] < self.eps_bounds[0]: @@ -131,8 +131,7 @@ def initial_parameters(self) -> np.ndarray: class TopologyDesignRegion(DesignRegion): """Design region as a pixellated permittivity grid.""" - pixel_size: pd.PositiveFloat = pd.Field( - ..., + pixel_size: PositiveFloat = Field( title="Pixel Size", description="Pixel size of the design region in x, y, z. 
For now, we only support the same " "pixel size in all 3 dimensions. If ``TopologyDesignRegion.override_structure_dl`` is left " @@ -142,14 +141,14 @@ class TopologyDesignRegion(DesignRegion): "a value on the same order as the grid size.", ) - uniform: tuple[bool, bool, bool] = pd.Field( + uniform: tuple[bool, bool, bool] = Field( (False, False, True), title="Uniform", description="Axes along which the design should be uniform. By default, the structure " "is assumed to be uniform, i.e. invariant, in the z direction.", ) - transformations: typing.Tuple[TransformationType, ...] = pd.Field( + transformations: tuple[TransformationType, ...] = Field( (), title="Transformations", description="Transformations that get applied from first to last on the parameter array." @@ -158,7 +157,7 @@ class TopologyDesignRegion(DesignRegion): "permittivity and 1 corresponds to the maximum relative permittivity. " "Specific permittivity values given the density array are determined by ``eps_bounds``.", ) - penalties: typing.Tuple[PenaltyType, ...] = pd.Field( + penalties: tuple[PenaltyType, ...] = Field( (), title="Penalties", description="Set of penalties that get evaluated on the material density. 
Note that the " @@ -166,7 +165,7 @@ class TopologyDesignRegion(DesignRegion): "inside of the penalties directly through the ``.weight`` field.", ) - override_structure_dl: typing.Union[pd.PositiveFloat, typing.Literal[False]] = pd.Field( + override_structure_dl: Optional[Union[PositiveFloat, Literal[False]]] = Field( None, title="Design Region Override Structure", description="Defines grid size when adding an ``override_structure`` to the " @@ -244,7 +243,7 @@ def _check_params(params: anp.ndarray = None): ) @property - def params_shape(self) -> typing.Tuple[int, int, int]: + def params_shape(self) -> tuple[int, int, int]: """Shape of the parameters array in (x, y, z), given the ``pixel_size`` and bounds.""" side_lengths = np.array(self.size) num_pixels = np.ceil(side_lengths / self.pixel_size) @@ -289,7 +288,7 @@ def params_ones(self): return self.params_uniform(1.0) @property - def coords(self) -> typing.Dict[str, typing.List[float]]: + def coords(self) -> dict[str, list[float]]: """Coordinates for the custom medium corresponding to this design region.""" lengths = np.array(self.size) @@ -365,4 +364,4 @@ def evaluate_penalty(self, penalty: PenaltyType, material_density: anp.ndarray) return penalty.evaluate(x=material_density, pixel_size=self.pixel_size) -DesignRegionType = typing.Union[TopologyDesignRegion] +DesignRegionType = Union[TopologyDesignRegion] diff --git a/tidy3d/plugins/invdes/result.py b/tidy3d/plugins/invdes/result.py index 86e5fcd015..01a74f34e8 100644 --- a/tidy3d/plugins/invdes/result.py +++ b/tidy3d/plugins/invdes/result.py @@ -1,10 +1,10 @@ # convenient container for the output of the inverse design (specifically the history) -import typing +from typing import Any, Union import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pd +from pydantic import Field import tidy3d as td from tidy3d.components.types import ArrayLike @@ -18,56 +18,55 @@ class InverseDesignResult(InvdesBaseModel): """Container for the result of an 
``InverseDesign.run()`` call.""" - design: InverseDesignType = pd.Field( - ..., + design: InverseDesignType = Field( title="Inverse Design Specification", description="Specification describing the inverse design problem we wish to optimize.", ) - params: typing.Tuple[ArrayLike, ...] = pd.Field( + params: tuple[ArrayLike, ...] = Field( (), title="Parameter History", description="History of parameter arrays throughout the optimization.", ) - objective_fn_val: typing.Tuple[float, ...] = pd.Field( + objective_fn_val: tuple[float, ...] = Field( (), title="Objective Function History", description="History of objective function values throughout the optimization.", ) - grad: typing.Tuple[ArrayLike, ...] = pd.Field( + grad: tuple[ArrayLike, ...] = Field( (), title="Gradient History", description="History of objective function gradient arrays throughout the optimization.", ) - penalty: typing.Tuple[float, ...] = pd.Field( + penalty: tuple[float, ...] = Field( (), title="Penalty History", description="History of weighted sum of penalties throughout the optimization.", ) - post_process_val: typing.Tuple[float, ...] = pd.Field( + post_process_val: tuple[float, ...] = Field( (), title="Post-Process Function History", description="History of return values from ``post_process_fn`` throughout the optimization.", ) - simulation: typing.Tuple[td.Simulation, ...] = pd.Field( + simulation: tuple[td.Simulation, ...] = Field( (), title="Simulation History", description="History of ``td.Simulation`` instances throughout the optimization.", ) - opt_state: typing.Tuple[dict, ...] = pd.Field( + opt_state: tuple[dict, ...] 
= Field( (), title="Optimizer State History", description="History of optimizer states throughout the optimization.", ) @property - def history(self) -> typing.Dict[str, list]: + def history(self) -> dict[str, list]: """The history-containing fields as a dictionary of lists.""" return dict( params=list(self.params), @@ -79,16 +78,16 @@ def history(self) -> typing.Dict[str, list]: ) @property - def keys(self) -> typing.List[str]: + def keys(self) -> list[str]: """Keys stored in the history.""" return list(self.history.keys()) @property - def last(self) -> typing.Dict[str, typing.Any]: + def last(self) -> dict[str, Any]: """Dictionary of last values in ``self.history``.""" return {key: value[-1] for key, value in self.history.items()} - def get(self, key: str, index: int = -1) -> typing.Any: + def get(self, key: str, index: int = -1) -> Any: """Get the value from the history at a certain index (-1 means last).""" if key not in self.keys: raise KeyError(f"'{key}' not present in 'Result.history' dict with: {self.keys}.") @@ -97,24 +96,24 @@ def get(self, key: str, index: int = -1) -> typing.Any: raise ValueError(f"Can't get the last value of '{key}' as there is no history present.") return values[index] - def get_last(self, key: str) -> typing.Any: + def get_last(self, key: str) -> Any: """Get the last value from the history.""" return self.get(key=key, index=-1) - def get_sim(self, index: int = -1) -> typing.Union[td.Simulation, typing.List[td.Simulation]]: + def get_sim(self, index: int = -1) -> Union[td.Simulation, list[td.Simulation]]: """Get the simulation at a specific index in the history (list of sims if multi).""" params = np.array(self.get(key="params", index=index)) return self.design.to_simulation(params=params) def get_sim_data( self, index: int = -1, **kwargs - ) -> typing.Union[td.SimulationData, typing.List[td.SimulationData]]: + ) -> Union[td.SimulationData, list[td.SimulationData]]: """Get the simulation data at a specific index in the history (list 
of simdata if multi).""" params = np.array(self.get(key="params", index=index)) return self.design.to_simulation_data(params=params, **kwargs) @property - def sim_last(self) -> typing.Union[td.Simulation, typing.List[td.Simulation]]: + def sim_last(self) -> Union[td.Simulation, list[td.Simulation]]: """The last simulation.""" return self.get_sim(index=-1) diff --git a/tidy3d/plugins/invdes/transformation.py b/tidy3d/plugins/invdes/transformation.py index cb9e343ceb..882f09a4dd 100644 --- a/tidy3d/plugins/invdes/transformation.py +++ b/tidy3d/plugins/invdes/transformation.py @@ -1,10 +1,10 @@ # transformations applied to design region import abc -import typing +from typing import Union import autograd.numpy as anp -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat import tidy3d as td from tidy3d.plugins.autograd.functions import threshold @@ -34,8 +34,7 @@ class FilterProject(InvdesBaseModel): """ - radius: pd.PositiveFloat = pd.Field( - ..., + radius: PositiveFloat = Field( title="Filter Radius", description="Radius of the filter to convolve with supplied spatial data. " "Note: the corresponding feature size expressed in the device is typically " @@ -48,7 +47,7 @@ class FilterProject(InvdesBaseModel): units=td.constants.MICROMETER, ) - beta: float = pd.Field( + beta: float = Field( 1.0, ge=1.0, title="Beta", @@ -57,11 +56,15 @@ class FilterProject(InvdesBaseModel): "at the expense of gradient accuracy and ease of optimization. ", ) - eta: float = pd.Field( - 0.5, ge=0.0, le=1.0, title="Eta", description="Halfway point in projection function." + eta: float = Field( + 0.5, + ge=0.0, + le=1.0, + title="Eta", + description="Halfway point in projection function.", ) - strict_binarize: bool = pd.Field( + strict_binarize: bool = Field( False, title="Binarize strictly", description="If ``False``, the binarization is still continuous between min and max. 
" @@ -81,4 +84,4 @@ def evaluate(self, spatial_data: anp.ndarray, design_region_dl: float) -> anp.nd return data_projected -TransformationType = typing.Union[FilterProject] +TransformationType = Union[FilterProject] diff --git a/tidy3d/plugins/invdes/utils.py b/tidy3d/plugins/invdes/utils.py index 3e3e1fd0af..3fc00e00b6 100644 --- a/tidy3d/plugins/invdes/utils.py +++ b/tidy3d/plugins/invdes/utils.py @@ -1,8 +1,6 @@ """Functional utilities that help define postprocessing functions more simply in ``invdes``.""" -# TODO: improve these? - -import typing +from typing import Any import autograd.numpy as anp import xarray as xr @@ -10,7 +8,7 @@ import tidy3d as td -def make_array(arr: typing.Any) -> anp.ndarray: +def make_array(arr: Any) -> anp.ndarray: """Turn something into a ``anp.ndarray``.""" if isinstance(arr, xr.DataArray): return anp.array(arr.values) diff --git a/tidy3d/plugins/invdes/validators.py b/tidy3d/plugins/invdes/validators.py index cfdaa6ad18..3aa772012a 100644 --- a/tidy3d/plugins/invdes/validators.py +++ b/tidy3d/plugins/invdes/validators.py @@ -1,21 +1,20 @@ # validator utilities for invdes plugin -import typing +from typing import Callable -import pydantic.v1 as pd +from pydantic import field_validator, model_validator import tidy3d as td -from tidy3d.components.base import skip_if_fields_missing # warn if pixel size is > PIXEL_SIZE_WARNING_THRESHOLD * (minimum wavelength in material) PIXEL_SIZE_WARNING_THRESHOLD = 0.1 -def ignore_inherited_field(field_name: str) -> typing.Callable: +def ignore_inherited_field(field_name: str) -> Callable: """Create validator that ignores a field inherited but not set by user.""" - @pd.validator(field_name, always=True) - def _ignore_field(cls, val): + @field_validator(field_name) + def _ignore_field(val): """Ignore supplied field value and warn.""" if val is not None: td.log.warning( @@ -52,16 +51,15 @@ def check_pixel_size_sim(sim: td.Simulation, pixel_size: float, index: int = Non "array resolution, one can set 
'DesignRegion.override_structure_dl'." ) - @pd.root_validator(allow_reuse=True) - @skip_if_fields_missing(["design_region"], root=True) - def _check_pixel_size(cls, values): + @model_validator(mode="after") + def _check_pixel_size(self): """Make sure region pixel_size isn't too large compared to sim's wavelength in material.""" - sim = values.get(sim_field_name) - region = values.get("design_region") + sim = getattr(self, sim_field_name) + region = self.design_region pixel_size = region.pixel_size if not sim and region: - return values + return self if isinstance(sim, (list, tuple)): for i, s in enumerate(sim): @@ -69,6 +67,6 @@ def _check_pixel_size(cls, values): else: check_pixel_size_sim(sim=sim, pixel_size=pixel_size) - return values + return self return _check_pixel_size diff --git a/tidy3d/plugins/microwave/array_factor.py b/tidy3d/plugins/microwave/array_factor.py index 5becb44414..81ca4b3f33 100644 --- a/tidy3d/plugins/microwave/array_factor.py +++ b/tidy3d/plugins/microwave/array_factor.py @@ -1,28 +1,26 @@ """Convenience functions for estimating antenna radiation by applying array factor.""" from abc import ABC, abstractmethod -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np -import pydantic.v1 as pd -from pydantic.v1 import NonNegativeFloat, PositiveInt - +from pydantic import Field, NonNegativeFloat, PositiveInt, model_validator + +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.data.monitor_data import AbstractFieldProjectionData, DirectivityData +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.geometry.base import Box, Geometry +from tidy3d.components.grid.grid_spec import GridSpec, LayerRefinementSpec +from tidy3d.components.lumped_element import LumpedElement +from tidy3d.components.medium import Medium, MediumType3D +from tidy3d.components.monitor import AbstractFieldProjectionMonitor, MonitorType +from tidy3d.components.simulation import 
Simulation +from tidy3d.components.source.utils import SourceType +from tidy3d.components.structure import MeshOverrideStructure, Structure +from tidy3d.components.types import ArrayLike, Axis, Bound +from tidy3d.constants import C_0, inf from tidy3d.log import log -from ...components.base import Tidy3dBaseModel, skip_if_fields_missing -from ...components.data.monitor_data import AbstractFieldProjectionData, DirectivityData -from ...components.data.sim_data import SimulationData -from ...components.geometry.base import Box, Geometry -from ...components.grid.grid_spec import GridSpec, LayerRefinementSpec -from ...components.lumped_element import LumpedElement -from ...components.medium import Medium, MediumType3D -from ...components.monitor import AbstractFieldProjectionMonitor, MonitorType -from ...components.simulation import Simulation -from ...components.source.utils import SourceType -from ...components.structure import MeshOverrideStructure, Structure -from ...components.types import ArrayLike, Axis, Bound -from ...constants import C_0, inf - class AbstractAntennaArrayCalculator(Tidy3dBaseModel, ABC): """Abstract base for phased array calculators.""" @@ -159,7 +157,7 @@ def _try_to_expand_geometry( def _duplicate_or_expand_list_of_objects( self, - objects: Tuple[ + objects: tuple[ Union[Structure, MeshOverrideStructure, LayerRefinementSpec, LumpedElement], ... 
], old_sim_bounds: Bound, @@ -228,7 +226,7 @@ def _duplicate_or_expand_list_of_objects( def _expand_monitors( self, - monitors: Tuple[MonitorType, ...], + monitors: tuple[MonitorType, ...], antenna_bounds: Bound, new_sim_bounds: Bound, old_sim_bounds: Bound, @@ -300,7 +298,7 @@ def _expand_monitors( return array_monitors def _duplicate_structures( - self, structures: Tuple[Structure, ...], new_sim_bounds: Bound, old_sim_bounds: Bound + self, structures: tuple[Structure, ...], new_sim_bounds: Bound, old_sim_bounds: Bound ): """Duplicate structures.""" @@ -310,8 +308,8 @@ def _duplicate_structures( def _duplicate_sources( self, - sources: Tuple[SourceType, ...], - lumped_elements: Tuple[LumpedElement, ...], + sources: tuple[SourceType, ...], + lumped_elements: tuple[LumpedElement, ...], old_sim_bounds: Bound, new_sim_bounds: Bound, ): @@ -584,13 +582,13 @@ def simulation_data_from_array_factor( simulation=sim_array.updated_copy(monitors=good_monitors), data=data_array ) - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data class RectangularAntennaArrayCalculator(AbstractAntennaArrayCalculator): @@ -619,35 +617,33 @@ class RectangularAntennaArrayCalculator(AbstractAntennaArrayCalculator): ... 
) # doctest: +SKIP """ - array_size: Tuple[PositiveInt, PositiveInt, PositiveInt] = pd.Field( + array_size: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( title="Array Size", description="Number of antennas along x, y, and z directions.", ) - spacings: Tuple[NonNegativeFloat, NonNegativeFloat, NonNegativeFloat] = pd.Field( + spacings: tuple[NonNegativeFloat, NonNegativeFloat, NonNegativeFloat] = Field( title="Antenna Spacings", description="Center-to-center spacings between antennas along x, y, and z directions.", ) - phase_shifts: Tuple[float, float, float] = pd.Field( + phase_shifts: tuple[float, float, float] = Field( (0, 0, 0), title="Phase Shifts", description="Phase-shifts between antennas along x, y, and z directions.", ) - amp_multipliers: Tuple[Optional[ArrayLike], Optional[ArrayLike], Optional[ArrayLike]] = ( - pd.Field( - (None, None, None), - title="Amplitude Multipliers", - description="Amplitude multipliers spatially distributed along x, y, and z directions.", - ) + amp_multipliers: tuple[Optional[ArrayLike], Optional[ArrayLike], Optional[ArrayLike]] = Field( + (None, None, None), + title="Amplitude Multipliers", + description="Amplitude multipliers spatially distributed along x, y, and z directions.", ) - @pd.validator("amp_multipliers", pre=True, always=True) - @skip_if_fields_missing(["array_size"]) - def _check_amp_multipliers(cls, val, values): + @model_validator(mode="after") + def _check_amp_multipliers(self): """Check that the length of the amplitude multipliers is equal to the array size along each dimension.""" - array_size = values.get("array_size") + val = self.amp_multipliers + array_size = self.array_size if len(val) != 3: raise ValueError("'amp_multipliers' must have 3 elements.") if val[0] is not None and len(val[0]) != array_size[0]: @@ -662,7 +658,7 @@ def _check_amp_multipliers(cls, val, values): raise ValueError( f"'amp_multipliers' has length of {len(val[2])} along the z direction, but the array size is {array_size[2]}." 
) - return val + return self @property def _antenna_locations(self) -> ArrayLike: @@ -702,7 +698,7 @@ def _antenna_phases(self) -> ArrayLike: return np.ravel(sum(p for p in phase_shifts_grid)) @property - def _extend_dims(self) -> Tuple[Axis, ...]: + def _extend_dims(self) -> tuple[Axis, ...]: """Dimensions along which antennas will be duplicated.""" return [ind for ind, size in enumerate(self.array_size) if size > 1] diff --git a/tidy3d/plugins/microwave/auto_path_integrals.py b/tidy3d/plugins/microwave/auto_path_integrals.py index d9140b7a57..4dd5199ddc 100644 --- a/tidy3d/plugins/microwave/auto_path_integrals.py +++ b/tidy3d/plugins/microwave/auto_path_integrals.py @@ -1,10 +1,16 @@ """Helpers for automatic setup of path integrals.""" -from ...components.geometry.base import Box -from ...components.geometry.utils import SnapBehavior, SnapLocation, SnappingSpec, snap_box_to_grid -from ...components.grid.grid import Grid -from ...components.lumped_element import LinearLumpedElement -from ...components.types import Direction +from tidy3d.components.geometry.base import Box +from tidy3d.components.geometry.utils import ( + SnapBehavior, + SnapLocation, + SnappingSpec, + snap_box_to_grid, +) +from tidy3d.components.grid.grid import Grid +from tidy3d.components.lumped_element import LinearLumpedElement +from tidy3d.components.types import Direction + from .path_integrals import ( CurrentIntegralAxisAligned, VoltageIntegralAxisAligned, diff --git a/tidy3d/plugins/microwave/custom_path_integrals.py b/tidy3d/plugins/microwave/custom_path_integrals.py index 6b6c48469a..9614607d40 100644 --- a/tidy3d/plugins/microwave/custom_path_integrals.py +++ b/tidy3d/plugins/microwave/custom_path_integrals.py @@ -5,16 +5,17 @@ from typing import Literal import numpy as np -import pydantic.v1 as pd import shapely import xarray as xr +from pydantic import Field, field_validator + +from tidy3d.components.base import cached_property +from tidy3d.components.geometry.base import Geometry 
+from tidy3d.components.types import ArrayFloat2D, Ax, Axis, Bound, Coordinate, Direction +from tidy3d.components.viz import add_ax_if_none +from tidy3d.constants import MICROMETER, fp_eps +from tidy3d.exceptions import SetupError -from ...components.base import cached_property -from ...components.geometry.base import Geometry -from ...components.types import ArrayFloat2D, Ax, Axis, Bound, Coordinate, Direction -from ...components.viz import add_ax_if_none -from ...constants import MICROMETER, fp_eps -from ...exceptions import SetupError from .path_integrals import ( AbstractAxesRH, AxisAlignedPathIntegral, @@ -48,18 +49,16 @@ class CustomPathIntegral2D(AbstractAxesRH): If the path is not closed, forward and backward differences are used at the endpoints. """ - axis: Axis = pd.Field( + axis: Axis = Field( 2, title="Axis", description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z)." ) - position: float = pd.Field( - ..., + position: float = Field( title="Position", description="Position of the plane along the ``axis``.", ) - vertices: ArrayFloat2D = pd.Field( - ..., + vertices: ArrayFloat2D = Field( title="Vertices", description="List of (d1, d2) defining the 2 dimensional positions of the path. 
" "The index of dimension should be in the ascending order, which means " @@ -209,8 +208,8 @@ def main_axis(self) -> Axis: """Axis for performing integration.""" return self.axis - @pd.validator("vertices", always=True) - def _correct_shape(cls, val): + @field_validator("vertices") + def _correct_shape(val): """Makes sure vertices size is correct.""" # overall shape of vertices if val.shape[1] != 2: diff --git a/tidy3d/plugins/microwave/impedance_calculator.py b/tidy3d/plugins/microwave/impedance_calculator.py index 488fbcd9b1..41f04c7acc 100644 --- a/tidy3d/plugins/microwave/impedance_calculator.py +++ b/tidy3d/plugins/microwave/impedance_calculator.py @@ -1,17 +1,16 @@ """Class for computing characteristic impedance of transmission lines.""" -from __future__ import annotations - from typing import Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator + +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.data.monitor_data import FieldTimeData +from tidy3d.constants import OHM +from tidy3d.exceptions import ValidationError +from tidy3d.log import log -from ...components.base import Tidy3dBaseModel -from ...components.data.monitor_data import FieldTimeData -from ...constants import OHM -from ...exceptions import ValidationError -from ...log import log from .custom_path_integrals import CustomCurrentIntegral2D, CustomVoltageIntegral2D from .path_integrals import ( AxisAlignedPathIntegral, @@ -28,13 +27,13 @@ class ImpedanceCalculator(Tidy3dBaseModel): """Tool for computing the characteristic impedance of a transmission line.""" - voltage_integral: Optional[VoltageIntegralTypes] = pd.Field( + voltage_integral: Optional[VoltageIntegralTypes] = Field( None, title="Voltage Integral", description="Definition of path integral for computing voltage.", ) - current_integral: Optional[CurrentIntegralTypes] = pd.Field( + current_integral: Optional[CurrentIntegralTypes] = Field( None, title="Current 
Integral", description="Definition of contour integral for computing current.", @@ -87,15 +86,15 @@ def compute_impedance(self, em_field: MonitorDataTypes) -> IntegralResultTypes: impedance = ImpedanceCalculator._set_data_array_attributes(impedance) return impedance - @pd.validator("current_integral", always=True) - def check_voltage_or_current(cls, val, values): + @model_validator(mode="after") + def check_voltage_or_current(self): """Raise validation error if both ``voltage_integral`` and ``current_integral`` are not provided.""" - if not values.get("voltage_integral") and not val: + if not self.voltage_integral and not self.current_integral: raise ValidationError( "At least one of 'voltage_integral' or 'current_integral' must be provided." ) - return val + return self @staticmethod def _set_data_array_attributes(data_array: IntegralResultTypes) -> IntegralResultTypes: @@ -103,10 +102,10 @@ def _set_data_array_attributes(data_array: IntegralResultTypes) -> IntegralResul data_array.name = "Z0" return data_array.assign_attrs(units=OHM, long_name="characteristic impedance") - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. 
You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data diff --git a/tidy3d/plugins/microwave/lobe_measurer.py b/tidy3d/plugins/microwave/lobe_measurer.py index 9bf0f61ddc..1feb6ed0c8 100644 --- a/tidy3d/plugins/microwave/lobe_measurer.py +++ b/tidy3d/plugins/microwave/lobe_measurer.py @@ -4,15 +4,16 @@ from typing import Optional import numpy as np -import pydantic.v1 as pd from pandas import DataFrame +from pydantic import Field, field_validator, model_validator from scipy.signal import find_peaks, peak_widths -from ...components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from ...components.types import ArrayFloat1D, ArrayLike, Ax -from ...constants import fp_eps -from ...exceptions import ValidationError -from ...log import log +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.types import ArrayFloat1D, ArrayLike, Ax +from tidy3d.constants import fp_eps +from tidy3d.exceptions import ValidationError +from tidy3d.log import log + from .viz import plot_params_lobe_FNBW, plot_params_lobe_peak, plot_params_lobe_width # The minimum plateau size for peak finding, which is set to 0 to ensure that all peaks are found. @@ -37,21 +38,19 @@ class LobeMeasurer(Tidy3dBaseModel): >>> lobe_measures = lobe_measurer.lobe_measures # doctest: +SKIP """ - angle: ArrayFloat1D = pd.Field( - ..., + angle: ArrayFloat1D = Field( title="Angle", description="A 1-dimensional array of angles in radians. 
The angles should be " "in the range [0, 2π] and should be sorted in ascending order.", ) - radiation_pattern: ArrayFloat1D = pd.Field( - ..., + radiation_pattern: ArrayFloat1D = Field( title="Radiation Pattern", description="A 1-dimensional array of real values representing the radiation pattern " "of the antenna measured on a linear scale.", ) - apply_cyclic_extension: bool = pd.Field( + apply_cyclic_extension: bool = Field( True, title="Apply Cyclic Extension", description="To enable accurate peak finding near boundaries of the ``angle`` array, " @@ -59,7 +58,7 @@ class LobeMeasurer(Tidy3dBaseModel): "of interest, this can be set to ``False``.", ) - width_measure: float = pd.Field( + width_measure: float = Field( 0.5, gt=0.0, le=1.0, @@ -68,7 +67,7 @@ class LobeMeasurer(Tidy3dBaseModel): "Default value of ``0.5`` corresponds with the half-power beamwidth.", ) - min_lobe_height: float = pd.Field( + min_lobe_height: float = Field( DEFAULT_MIN_LOBE_REL_HEIGHT, gt=0.0, le=1.0, @@ -77,7 +76,7 @@ class LobeMeasurer(Tidy3dBaseModel): "Lobe heights are measured relative to the maximum value in ``radiation_pattern``.", ) - null_threshold: float = pd.Field( + null_threshold: float = Field( DEFAULT_NULL_THRESHOLD, gt=0.0, le=1.0, @@ -86,30 +85,29 @@ class LobeMeasurer(Tidy3dBaseModel): "which is relative to the maximum value in the ``radiation_pattern``.", ) - @pd.validator("angle", always=True) - def _sorted_angle(cls, val): + @field_validator("angle") + def _sorted_angle(val): """Ensure the angle array is sorted.""" if not np.all(np.diff(val) >= 0): raise ValidationError("The angle array must be sorted in ascending order.") return val - @pd.validator("radiation_pattern", always=True) - def _nonnegative_radiation_pattern(cls, val): + @field_validator("radiation_pattern") + def _nonnegative_radiation_pattern(val): """Ensure the radiation pattern is nonnegative.""" if not np.all(val >= 0): raise ValidationError("Radiation pattern must be nonnegative.") return val - 
@pd.validator("apply_cyclic_extension", always=True) - @skip_if_fields_missing(["angle"]) - def _cyclic_extension_valid(cls, val, values): - if val: - angle = values.get("angle") + @model_validator(mode="after") + def _cyclic_extension_valid(self): + if self.apply_cyclic_extension: + angle = self.angle if np.any(angle < 0) or np.any(angle > 2 * np.pi): raise ValidationError( "When using cyclic extension, the angle array must be in the range [0, 2π]." ) - return val + return self @cached_property def lobe_measures(self) -> DataFrame: @@ -344,10 +342,10 @@ def plot( return ax - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data diff --git a/tidy3d/plugins/microwave/path_integrals.py b/tidy3d/plugins/microwave/path_integrals.py index 8b06cb2334..c7f4259aaf 100644 --- a/tidy3d/plugins/microwave/path_integrals.py +++ b/tidy3d/plugins/microwave/path_integrals.py @@ -6,12 +6,12 @@ from typing import Union import numpy as np -import pydantic.v1 as pd import shapely as shapely import xarray as xr +from pydantic import Field, model_validator -from ...components.base import Tidy3dBaseModel, cached_property -from ...components.data.data_array import ( +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.data.data_array import ( FreqDataArray, FreqModeDataArray, ScalarFieldDataArray, @@ -19,14 +19,15 @@ ScalarModeFieldDataArray, TimeDataArray, ) -from ...components.data.monitor_data import FieldData, FieldTimeData, ModeData, ModeSolverData -from ...components.geometry.base import Box, Geometry -from ...components.types import Ax, Axis, Coordinate2D, Direction -from ...components.validators import assert_line, assert_plane -from ...components.viz import add_ax_if_none 
-from ...constants import AMP, VOLT, fp_eps -from ...exceptions import DataError, Tidy3dError -from ...log import log +from tidy3d.components.data.monitor_data import FieldData, FieldTimeData, ModeData, ModeSolverData +from tidy3d.components.geometry.base import Box, Geometry +from tidy3d.components.types import Ax, Axis, Coordinate2D, Direction +from tidy3d.components.validators import assert_line, assert_plane +from tidy3d.components.viz import add_ax_if_none +from tidy3d.constants import AMP, VOLT, fp_eps +from tidy3d.exceptions import DataError, Tidy3dError +from tidy3d.log import log + from .viz import ( ARROW_CURRENT, plot_params_current_path, @@ -73,7 +74,7 @@ def local_dims(self) -> tuple[str, str, str]: dim3 = "xyz"[self.main_axis] return self.remaining_dims + tuple(dim3) - @pd.root_validator(pre=False) + @model_validator(mode="before") def _warn_rf_license(cls, values): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", @@ -87,7 +88,7 @@ class AxisAlignedPathIntegral(AbstractAxesRH, Box): _line_validator = assert_line() - extrapolate_to_endpoints: bool = pd.Field( + extrapolate_to_endpoints: bool = Field( False, title="Extrapolate to Endpoints", description="If the endpoints of the path integral terminate at or near a material interface, " @@ -95,7 +96,7 @@ class AxisAlignedPathIntegral(AbstractAxesRH, Box): "of the integral are ignored. 
Should be enabled when computing voltage between two conductors.", ) - snap_path_to_grid: bool = pd.Field( + snap_path_to_grid: bool = Field( False, title="Snap Path to Grid", description="It might be desireable to integrate exactly along the Yee grid associated with " @@ -220,8 +221,7 @@ def _make_result_data_array(result: xr.DataArray) -> IntegralResultTypes: class VoltageIntegralAxisAligned(AxisAlignedPathIntegral): """Class for computing the voltage between two points defined by an axis-aligned line.""" - sign: Direction = pd.Field( - ..., + sign: Direction = Field( title="Direction of Path Integral", description="Positive indicates V=Vb-Va where position b has a larger coordinate along the axis of integration.", ) @@ -365,19 +365,18 @@ class CurrentIntegralAxisAligned(AbstractAxesRH, Box): _plane_validator = assert_plane() - sign: Direction = pd.Field( - ..., + sign: Direction = Field( title="Direction of Contour Integral", description="Positive indicates current flowing in the positive normal axis direction.", ) - extrapolate_to_endpoints: bool = pd.Field( + extrapolate_to_endpoints: bool = Field( False, title="Extrapolate to Endpoints", description="This parameter is passed to :class:`AxisAlignedPathIntegral` objects when computing the contour integral.", ) - snap_contour_to_grid: bool = pd.Field( + snap_contour_to_grid: bool = Field( False, title="Snap Contour to Grid", description="This parameter is passed to :class:`AxisAlignedPathIntegral` objects when computing the contour integral.", diff --git a/tidy3d/plugins/microwave/rf_material_library.py b/tidy3d/plugins/microwave/rf_material_library.py index 690333bce1..7757855681 100644 --- a/tidy3d/plugins/microwave/rf_material_library.py +++ b/tidy3d/plugins/microwave/rf_material_library.py @@ -1,8 +1,8 @@ """Holds dispersive models for several commonly used RF materials.""" -# from ...components.base import Tidy3dBaseModel -from ...components.medium import PoleResidue -from 
...material_library.material_library import MaterialItem, VariantItem +from tidy3d.components.medium import PoleResidue +from tidy3d.material_library.material_library import MaterialItem, VariantItem + from .rf_material_reference import rf_material_refs Rogers3003_design = VariantItem( diff --git a/tidy3d/plugins/microwave/rf_material_reference.py b/tidy3d/plugins/microwave/rf_material_reference.py index 0767687dda..97894ce5e3 100644 --- a/tidy3d/plugins/microwave/rf_material_reference.py +++ b/tidy3d/plugins/microwave/rf_material_reference.py @@ -1,6 +1,6 @@ """Holds the reference materials for Tidy3D material library.""" -from ...material_library.material_reference import ReferenceData +from tidy3d.material_library.material_reference import ReferenceData rf_material_refs = dict( Rogers3003=ReferenceData( diff --git a/tidy3d/plugins/microwave/viz.py b/tidy3d/plugins/microwave/viz.py index 5a8b3dd582..376a8b10e1 100644 --- a/tidy3d/plugins/microwave/viz.py +++ b/tidy3d/plugins/microwave/viz.py @@ -2,7 +2,7 @@ from numpy import inf -from ...components.viz import PathPlotParams +from tidy3d.components.viz import PathPlotParams """ Constants """ VOLTAGE_COLOR = "red" diff --git a/tidy3d/plugins/mode/mode_solver.py b/tidy3d/plugins/mode/mode_solver.py index 12716b0daf..bdfab633cb 100644 --- a/tidy3d/plugins/mode/mode_solver.py +++ b/tidy3d/plugins/mode/mode_solver.py @@ -2,10 +2,8 @@ invariance along a given propagation axis. 
""" -from __future__ import annotations - -from ...components.data.monitor_data import ModeSolverData -from ...components.mode.mode_solver import MODE_MONITOR_NAME, MODE_PLANE_TYPE, ModeSolver +from tidy3d.components.data.monitor_data import ModeSolverData +from tidy3d.components.mode.mode_solver import MODE_MONITOR_NAME, MODE_PLANE_TYPE, ModeSolver _ = ModeSolver _ = ModeSolverData diff --git a/tidy3d/plugins/mode/web.py b/tidy3d/plugins/mode/web.py index edbbbeccac..ece675c49e 100644 --- a/tidy3d/plugins/mode/web.py +++ b/tidy3d/plugins/mode/web.py @@ -1,5 +1,5 @@ """Web API for mode solver""" -from ...web.api.mode import run, run_batch +from tidy3d.web.api.mode import run, run_batch __all__ = ["run", "run_batch"] diff --git a/tidy3d/plugins/polyslab/polyslab.py b/tidy3d/plugins/polyslab/polyslab.py index 1fb496bd09..2d8068585b 100644 --- a/tidy3d/plugins/polyslab/polyslab.py +++ b/tidy3d/plugins/polyslab/polyslab.py @@ -1,8 +1,8 @@ """Divide a complex polyslab where self-intersecting polygon can occur during extrusion.""" -from ...components.geometry.polyslab import ComplexPolySlabBase -from ...components.medium import MediumType -from ...components.structure import Structure +from tidy3d.components.geometry.polyslab import ComplexPolySlabBase +from tidy3d.components.medium import MediumType +from tidy3d.components.structure import Structure class ComplexPolySlab(ComplexPolySlabBase): diff --git a/tidy3d/plugins/resonance/resonance.py b/tidy3d/plugins/resonance/resonance.py index 924a24e409..07da7cc47c 100644 --- a/tidy3d/plugins/resonance/resonance.py +++ b/tidy3d/plugins/resonance/resonance.py @@ -1,20 +1,20 @@ """Find resonances in time series data""" from functools import partial -from typing import List, Tuple, Union +from typing import Optional, Union import numpy as np import scipy.linalg import xarray as xr -from pydantic.v1 import Field, NonNegativeFloat, PositiveInt, validator +from pydantic import Field, NonNegativeFloat, PositiveInt, field_validator 
-from ...components.base import Tidy3dBaseModel -from ...components.data.data_array import ScalarFieldTimeDataArray -from ...components.data.monitor_data import FieldTimeData -from ...components.types import ArrayComplex1D, ArrayComplex2D, ArrayComplex3D, ArrayFloat1D -from ...constants import HERTZ -from ...exceptions import SetupError, ValidationError -from ...log import log +from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.data.data_array import ScalarFieldTimeDataArray +from tidy3d.components.data.monitor_data import FieldTimeData +from tidy3d.components.types import ArrayComplex1D, ArrayComplex2D, ArrayComplex3D, ArrayFloat1D +from tidy3d.constants import HERTZ +from tidy3d.exceptions import SetupError, ValidationError +from tidy3d.log import log INIT_NUM_FREQS = 200 @@ -27,12 +27,19 @@ class ResonanceData(Tidy3dBaseModel): """Data class for storing objects computed while running the resonance finder.""" - eigvals: ArrayComplex1D = Field(..., title="Eigenvalues", description="Resonance eigenvalues.") - complex_amplitudes: ArrayComplex1D = Field( - None, title="Complex amplitudes", description="Complex resonance amplitudes" + eigvals: ArrayComplex1D = Field( + title="Eigenvalues", + description="Resonance eigenvalues.", ) - errors: ArrayFloat1D = Field( - None, title="Errors", description="Rough eigenvalue error estimate." + complex_amplitudes: Optional[ArrayComplex1D] = Field( + None, + title="Complex amplitudes", + description="Complex resonance amplitudes", + ) + errors: Optional[ArrayFloat1D] = Field( + None, + title="Errors", + description="Rough eigenvalue error estimate.", ) @@ -69,8 +76,7 @@ class ResonanceFinder(Tidy3dBaseModel): ... # A given dataframe """ - freq_window: Tuple[float, float] = Field( - ..., + freq_window: tuple[float, float] = Field( title="Window ``[fmin, fmax]``", description="Window ``[fmin, fmax]`` for the initial frequencies. 
" "The resonance finder is initialized with an even grid of frequencies between " @@ -101,8 +107,8 @@ class ResonanceFinder(Tidy3dBaseModel): "Making this closer to zero will typically return more resonances.", ) - @validator("freq_window", always=True) - def _check_freq_window(cls, val): + @field_validator("freq_window") + def _check_freq_window(val): """Validate ``freq_window``""" if val[1] < val[0]: raise ValidationError( @@ -110,7 +116,7 @@ def _check_freq_window(cls, val): ) return val - def run(self, signals: Union[FieldTimeData, Tuple[FieldTimeData, ...]]) -> xr.Dataset: + def run(self, signals: Union[FieldTimeData, tuple[FieldTimeData, ...]]) -> xr.Dataset: """Finds resonances in a :class:`.FieldTimeData` or a Tuple of such. The time coordinates must be uniformly spaced, and the spacing must be the same across all supplied data. The resonance finder runs on the sum of the @@ -158,13 +164,13 @@ def run_scalar_field_time(self, signal: ScalarFieldTimeDataArray) -> xr.Dataset: signal, dt = self._validate_scalar_field_time(signal) return self.run_raw_signal(signal, dt) - def run_raw_signal(self, signal: List[complex], time_step: float) -> xr.Dataset: + def run_raw_signal(self, signal: list[complex], time_step: float) -> xr.Dataset: """Finds resonances in a time series. Note that the signal should start after the sources have turned off. Parameters ---------- - signal : List[complex] + signal : list[complex] One-dimensional array holding the complex-valued time series data to search for resonances. 
time_step : float @@ -207,7 +213,7 @@ def run_raw_signal(self, signal: List[complex], time_step: float) -> xr.Dataset: def _validate_scalar_field_time( self, signal: ScalarFieldTimeDataArray - ) -> Tuple[ArrayComplex1D, float]: + ) -> tuple[ArrayComplex1D, float]: """Validates a :class:`.ScalarFieldTimeDataArray` and returns the time step as well as underlying data array.""" dts = np.diff(signal.t) @@ -227,7 +233,7 @@ def _validate_scalar_field_time( return np.squeeze(signal.data), dt def _aggregate_field_time_comps( - self, signals: Tuple[FieldTimeData, ...], comps + self, signals: tuple[FieldTimeData, ...], comps ) -> ScalarFieldTimeDataArray: """Aggregates the given components from several :class:`.FieldTimeData`.""" total_signal = None @@ -261,7 +267,7 @@ def _aggregate_field_time_comps( ) def _aggregate_field_time( - self, signals: Union[FieldTimeData, Tuple[FieldTimeData, ...]] + self, signals: Union[FieldTimeData, tuple[FieldTimeData, ...]] ) -> ScalarFieldTimeDataArray: """Aggregates several :class:`.FieldTimeData` into a single :class:`.ScalarFieldTimeDataArray`.""" @@ -345,7 +351,7 @@ def _gram_schmidt(self, a_matrix: ArrayComplex2D) -> ArrayComplex2D: def _solve_gen_eig_prob( self, a_matrix: ArrayComplex2D, b_matrix: ArrayComplex2D, rcond: float - ) -> Tuple[ArrayComplex1D, ArrayComplex2D]: + ) -> tuple[ArrayComplex1D, ArrayComplex2D]: """Solve a generalized eigenvalue problem of the form .. 
math:: diff --git a/tidy3d/plugins/smatrix/component_modelers/base.py b/tidy3d/plugins/smatrix/component_modelers/base.py index e8467f2799..8792fb0c82 100644 --- a/tidy3d/plugins/smatrix/component_modelers/base.py +++ b/tidy3d/plugins/smatrix/component_modelers/base.py @@ -4,21 +4,22 @@ import os from abc import ABC, abstractmethod -from typing import Dict, Tuple, Union, get_args +from typing import Optional, Union, get_args import numpy as np -import pydantic.v1 as pd - -from ....components.base import Tidy3dBaseModel, cached_property -from ....components.data.data_array import DataArray -from ....components.data.sim_data import SimulationData -from ....components.simulation import Simulation -from ....components.types import FreqArray -from ....config import config -from ....constants import HERTZ -from ....exceptions import SetupError, Tidy3dKeyError -from ....log import log -from ....web.api.container import Batch, BatchData +from pydantic import Field, field_validator + +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.data.data_array import DataArray +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.simulation import Simulation +from tidy3d.components.types import FreqArray +from tidy3d.config import config +from tidy3d.constants import HERTZ +from tidy3d.exceptions import SetupError, Tidy3dKeyError +from tidy3d.log import log +from tidy3d.web.api.container import Batch, BatchData + from ..ports.coaxial_lumped import CoaxialLumpedPort from ..ports.modal import Port from ..ports.rectangular_lumped import LumpedPort @@ -35,27 +36,25 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel): """Tool for modeling devices and computing port parameters.""" - simulation: Simulation = pd.Field( - ..., + simulation: Simulation = Field( title="Simulation", description="Simulation describing the device without any sources present.", ) - ports: Tuple[Union[Port, TerminalPortType], ...] 
= pd.Field( + ports: tuple[Union[Port, TerminalPortType], ...] = Field( (), title="Ports", description="Collection of ports describing the scattering matrix elements. " "For each input mode, one simulation will be run with a modal source.", ) - freqs: FreqArray = pd.Field( - ..., + freqs: FreqArray = Field( title="Frequencies", description="Array or list of frequencies at which to compute port parameters.", units=HERTZ, ) - remove_dc_component: bool = pd.Field( + remove_dc_component: bool = Field( True, title="Remove DC Component", description="Whether to remove the DC component in the Gaussian pulse spectrum. " @@ -66,19 +65,19 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel): "pulse spectrum which can have a nonzero DC component.", ) - folder_name: str = pd.Field( + folder_name: str = Field( "default", title="Folder Name", description="Name of the folder for the tasks on web.", ) - verbose: bool = pd.Field( + verbose: bool = Field( False, title="Verbosity", description="Whether the :class:`.AbstractComponentModeler` should print status and progressbars.", ) - callback_url: str = pd.Field( + callback_url: Optional[str] = Field( None, title="Callback URL", description="Http PUT url to receive simulation finish event. " @@ -86,20 +85,20 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) - path_dir: str = pd.Field( + path_dir: str = Field( DEFAULT_DATA_DIR, title="Directory Path", description="Base directory where data and batch will be downloaded.", ) - solver_version: str = pd.Field( + solver_version: str = Field( None, title="Solver Version", description_str="Custom solver version to use. " "If not supplied, uses default for the current front end version.", ) - batch_cached: Batch = pd.Field( + batch_cached: Optional[Batch] = Field( None, title="Batch (Cached)", description="Optional field to specify ``batch``. 
Only used as a workaround internally " @@ -108,15 +107,15 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel): "fields that were not used to create the task will cause errors.", ) - @pd.validator("simulation", always=True) - def _sim_has_no_sources(cls, val): + @field_validator("simulation") + def _sim_has_no_sources(val): """Make sure simulation has no sources as they interfere with tool.""" if len(val.sources) > 0: raise SetupError("'AbstractComponentModeler.simulation' must not have any sources.") return val - @pd.validator("ports", always=True) - def _warn_rf_license(cls, val): + @field_validator("ports") + def _warn_rf_license(val): """Warn about new licensing requirements for RF ports.""" rf_port = False TerminalPortTypeTuple = get_args(TerminalPortType) @@ -139,7 +138,7 @@ def _task_name(port: Port, mode_index: int = None) -> str: return f"smatrix_{port.name}" @cached_property - def sim_dict(self) -> Dict[str, Simulation]: + def sim_dict(self) -> dict[str, Simulation]: """Generate all the :class:`.Simulation` objects for the S matrix calculation.""" def to_file(self, fname: str) -> None: diff --git a/tidy3d/plugins/smatrix/component_modelers/modal.py b/tidy3d/plugins/smatrix/component_modelers/modal.py index c0acdf7fae..26742cca57 100644 --- a/tidy3d/plugins/smatrix/component_modelers/modal.py +++ b/tidy3d/plugins/smatrix/component_modelers/modal.py @@ -4,26 +4,27 @@ # "ModalPort" to explicitly differentiate these from "TerminalComponentModeler" and "LumpedPort". 
from __future__ import annotations -from typing import Dict, List, Optional, Tuple +from typing import Optional import numpy as np -import pydantic.v1 as pd - -from ....components.base import cached_property -from ....components.data.sim_data import SimulationData -from ....components.monitor import ModeMonitor -from ....components.simulation import Simulation -from ....components.source.field import ModeSource -from ....components.source.time import GaussianPulse -from ....components.types import Ax, Complex -from ....components.viz import add_ax_if_none, equal_aspect -from ....exceptions import SetupError -from ....web.api.container import BatchData +from pydantic import Field, NonNegativeInt, field_validator + +from tidy3d.components.base import cached_property +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.monitor import ModeMonitor +from tidy3d.components.simulation import Simulation +from tidy3d.components.source.field import ModeSource +from tidy3d.components.source.time import GaussianPulse +from tidy3d.components.types import Ax, Complex +from tidy3d.components.viz import add_ax_if_none, equal_aspect +from tidy3d.exceptions import SetupError +from tidy3d.web.api.container import BatchData + from ..ports.modal import ModalPortDataArray, Port from .base import FWIDTH_FRAC, AbstractComponentModeler -MatrixIndex = Tuple[str, pd.NonNegativeInt] # the 'i' in S_ij -Element = Tuple[MatrixIndex, MatrixIndex] # the 'ij' in S_ij +MatrixIndex = tuple[str, NonNegativeInt] # the 'i' in S_ij +Element = tuple[MatrixIndex, MatrixIndex] # the 'ij' in S_ij class ComponentModeler(AbstractComponentModeler): @@ -39,14 +40,14 @@ class ComponentModeler(AbstractComponentModeler): * `Computing the scattering matrix of a device <../../notebooks/SMatrix.html>`_ """ - ports: Tuple[Port, ...] = pd.Field( + ports: tuple[Port, ...] = Field( (), title="Ports", description="Collection of ports describing the scattering matrix elements. 
" "For each input mode, one simulation will be run with a modal source.", ) - element_mappings: Tuple[Tuple[Element, Element, Complex], ...] = pd.Field( + element_mappings: tuple[tuple[Element, Element, Complex], ...] = Field( (), title="Element Mappings", description="Mapping between elements of the scattering matrix, " @@ -59,7 +60,7 @@ class ComponentModeler(AbstractComponentModeler): "is skipped automatically.", ) - run_only: Optional[Tuple[MatrixIndex, ...]] = pd.Field( + run_only: Optional[tuple[MatrixIndex, ...]] = Field( None, title="Run Only", description="If given, a tuple of matrix indices, specified by (:class:`.Port`, ``int``)," @@ -71,13 +72,13 @@ class ComponentModeler(AbstractComponentModeler): :class:`ComponentModeler`. ``run_only`` contains the scattering matrix indices that the user wants to run as a source. If any indices are excluded, they will not be run.""" - verbose: bool = pd.Field( + verbose: bool = Field( False, title="Verbosity", description="Whether the :class:`.ComponentModeler` should print status and progressbars.", ) - callback_url: str = pd.Field( + callback_url: Optional[str] = Field( None, title="Callback URL", description="Http PUT url to receive simulation finish event. 
" @@ -85,15 +86,15 @@ class ComponentModeler(AbstractComponentModeler): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) - @pd.validator("simulation", always=True) - def _sim_has_no_sources(cls, val): + @field_validator("simulation") + def _sim_has_no_sources(val): """Make sure simulation has no sources as they interfere with tool.""" if len(val.sources) > 0: raise SetupError("'ComponentModeler.simulation' must not have any sources.") return val @cached_property - def sim_dict(self) -> Dict[str, Simulation]: + def sim_dict(self) -> dict[str, Simulation]: """Generate all the :class:`.Simulation` objects for the S matrix calculation.""" sim_dict = {} @@ -112,7 +113,7 @@ def sim_dict(self) -> Dict[str, Simulation]: return sim_dict @cached_property - def matrix_indices_monitor(self) -> Tuple[MatrixIndex, ...]: + def matrix_indices_monitor(self) -> tuple[MatrixIndex, ...]: """Tuple of all the possible matrix indices (port, mode_index) in the Component Modeler.""" matrix_indices = [] for port in self.ports: @@ -121,14 +122,14 @@ def matrix_indices_monitor(self) -> Tuple[MatrixIndex, ...]: return tuple(matrix_indices) @cached_property - def matrix_indices_source(self) -> Tuple[MatrixIndex, ...]: + def matrix_indices_source(self) -> tuple[MatrixIndex, ...]: """Tuple of all the source matrix indices (port, mode_index) in the Component Modeler.""" if self.run_only is not None: return self.run_only return self.matrix_indices_monitor @cached_property - def matrix_indices_run_sim(self) -> Tuple[MatrixIndex, ...]: + def matrix_indices_run_sim(self) -> tuple[MatrixIndex, ...]: """Tuple of all the source matrix indices (port, mode_index) in the Component Modeler.""" if self.element_mappings is None or self.element_mappings == {}: @@ -154,10 +155,10 @@ def matrix_indices_run_sim(self) -> Tuple[MatrixIndex, ...]: return source_indices_needed @cached_property - def port_names(self) -> Tuple[List[str], List[str]]: + def port_names(self) -> tuple[list[str], list[str]]: 
"""List of port names for inputs and outputs, respectively.""" - def get_port_names(matrix_elements: Tuple[str, int]) -> List[str]: + def get_port_names(matrix_elements: tuple[str, int]) -> list[str]: """Get the port names from a list of (port name, mode index).""" port_names = [] for port_name, _ in matrix_elements: @@ -182,7 +183,7 @@ def to_monitor(self, port: Port) -> ModeMonitor: def to_source( self, port: Port, mode_index: int, num_freqs: int = 1, **kwargs - ) -> List[ModeSource]: + ) -> list[ModeSource]: """Creates a list of mode sources from a given port.""" freq0 = np.mean(self.freqs) fdiff = max(self.freqs) - min(self.freqs) @@ -249,10 +250,10 @@ def _normalization_factor(self, port_source: Port, sim_data: SimulationData) -> return normalize_amps.values @cached_property - def max_mode_index(self) -> Tuple[int, int]: + def max_mode_index(self) -> tuple[int, int]: """maximum mode indices for the smatrix dataset for the in and out ports, respectively.""" - def get_max_mode_indices(matrix_elements: Tuple[str, int]) -> int: + def get_max_mode_indices(matrix_elements: tuple[str, int]) -> int: """Get the maximum mode index for a list of (port name, mode index).""" return max(mode_index for _, mode_index in matrix_elements) diff --git a/tidy3d/plugins/smatrix/component_modelers/terminal.py b/tidy3d/plugins/smatrix/component_modelers/terminal.py index ebd17b5046..bbfec31d82 100644 --- a/tidy3d/plugins/smatrix/component_modelers/terminal.py +++ b/tidy3d/plugins/smatrix/component_modelers/terminal.py @@ -2,31 +2,27 @@ from __future__ import annotations -from typing import Dict, Tuple, Union +from typing import Union import numpy as np -import pydantic.v1 as pd - -from ....components.base import cached_property -from ....components.data.data_array import ( - DataArray, - FreqDataArray, -) -from ....components.data.monitor_data import ( - MonitorData, -) -from ....components.data.sim_data import SimulationData -from ....components.geometry.utils_2d import 
snap_coordinate_to_grid -from ....components.microwave.data.monitor_data import AntennaMetricsData -from ....components.monitor import DirectivityMonitor -from ....components.simulation import Simulation -from ....components.source.time import GaussianPulse -from ....components.types import Ax -from ....components.viz import add_ax_if_none, equal_aspect -from ....constants import C_0, OHM -from ....exceptions import Tidy3dError, Tidy3dKeyError, ValidationError -from ....log import log -from ....web.api.container import BatchData +from pydantic import Field, field_validator, model_validator + +from tidy3d.components.base import cached_property +from tidy3d.components.data.data_array import DataArray, FreqDataArray +from tidy3d.components.data.monitor_data import MonitorData +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.geometry.utils_2d import snap_coordinate_to_grid +from tidy3d.components.microwave.data.monitor_data import AntennaMetricsData +from tidy3d.components.monitor import DirectivityMonitor +from tidy3d.components.simulation import Simulation +from tidy3d.components.source.time import GaussianPulse +from tidy3d.components.types import Ax +from tidy3d.components.viz import add_ax_if_none, equal_aspect +from tidy3d.constants import C_0, OHM +from tidy3d.exceptions import Tidy3dError, Tidy3dKeyError, ValidationError +from tidy3d.log import log +from tidy3d.web.api.container import BatchData + from ..data.terminal import PortDataArray, TerminalPortDataArray from ..ports.base_lumped import AbstractLumpedPort from ..ports.coaxial_lumped import CoaxialLumpedPort @@ -39,27 +35,27 @@ class TerminalComponentModeler(AbstractComponentModeler): """Tool for modeling two-terminal multiport devices and computing port parameters with lumped and wave ports.""" - ports: Tuple[TerminalPortType, ...] = pd.Field( + ports: tuple[TerminalPortType, ...] 
= Field( (), title="Terminal Ports", description="Collection of lumped and wave ports associated with the network. " "For each port, one simulation will be run with a source that is associated with the port.", ) - radiation_monitors: tuple[DirectivityMonitor, ...] = pd.Field( + radiation_monitors: tuple[DirectivityMonitor, ...] = Field( (), title="Radiation Monitors", description="Facilitates the calculation of figures-of-merit for antennas. " "These monitor will be included in every simulation and record the radiated fields. ", ) - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data @equal_aspect @add_ax_if_none @@ -90,7 +86,7 @@ def plot_sim_eps( return sim_plot.plot_eps(x=x, y=y, z=z, ax=ax, **kwargs) @cached_property - def sim_dict(self) -> Dict[str, Simulation]: + def sim_dict(self) -> dict[str, Simulation]: """Generate all the :class:`.Simulation` objects for the port parameter calculation.""" sim_dict = {} @@ -212,7 +208,8 @@ def _internal_construct_smatrix(self, batch_data: BatchData) -> TerminalPortData s_matrix = self.ab_to_s(a_matrix, b_matrix) return s_matrix - @pd.validator("simulation") + @field_validator("simulation") + @classmethod def _validate_3d_simulation(cls, val): """Error if :class:`.Simulation` is not a 3D simulation""" @@ -222,18 +219,18 @@ def _validate_3d_simulation(cls, val): ) return val - @pd.validator("radiation_monitors") - def _validate_radiation_monitors(cls, val, values): - freqs = set(values.get("freqs")) - for rad_mon in val: + @model_validator(mode="after") + def _validate_radiation_monitors(self): + freqs = set(self.freqs) + for rad_mon in self.radiation_monitors: mon_freqs = rad_mon.freqs is_subset = freqs.issuperset(mon_freqs) if not is_subset: raise 
ValidationError( f"The frequencies in the radiation monitor '{rad_mon.name}' " - f"must be equal to or a subset of the frequencies in the '{cls.__name__}'." + f"must be equal to or a subset of the frequencies in the '{type(self).__name__}'." ) - return val + return self @staticmethod def _check_grid_size_at_ports(simulation: Simulation, ports: list[Union[AbstractLumpedPort]]): diff --git a/tidy3d/plugins/smatrix/data/terminal.py b/tidy3d/plugins/smatrix/data/terminal.py index 6a240a1dd0..7d3d0f3ae0 100644 --- a/tidy3d/plugins/smatrix/data/terminal.py +++ b/tidy3d/plugins/smatrix/data/terminal.py @@ -1,15 +1,10 @@ """Storing data associated with results from the TerminalComponentModeler""" -from __future__ import annotations - -import pydantic.v1 as pd +from pydantic import model_validator +from tidy3d.components.data.data_array import DataArray from tidy3d.log import log -from ....components.data.data_array import ( - DataArray, -) - class PortDataArray(DataArray): """Array of values over dimensions of frequency and port name. @@ -27,13 +22,13 @@ class PortDataArray(DataArray): __slots__ = () _dims = ("f", "port") - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data class TerminalPortDataArray(DataArray): @@ -54,10 +49,10 @@ class TerminalPortDataArray(DataArray): _dims = ("f", "port_out", "port_in") _data_attrs = {"long_name": "terminal-based port matrix element"} - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. 
You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data diff --git a/tidy3d/plugins/smatrix/ports/base_lumped.py b/tidy3d/plugins/smatrix/ports/base_lumped.py index e0bc3b0147..d6b206fcff 100644 --- a/tidy3d/plugins/smatrix/ports/base_lumped.py +++ b/tidy3d/plugins/smatrix/ports/base_lumped.py @@ -3,15 +3,16 @@ from abc import abstractmethod from typing import Optional -import pydantic.v1 as pd - -from ....components.base import cached_property -from ....components.geometry.utils_2d import snap_coordinate_to_grid -from ....components.grid.grid import Grid, YeeGrid -from ....components.lumped_element import LumpedElementType -from ....components.monitor import FieldMonitor -from ....components.types import Complex, Coordinate, FreqArray -from ....constants import OHM +from pydantic import Field, PositiveInt + +from tidy3d.components.base import cached_property +from tidy3d.components.geometry.utils_2d import snap_coordinate_to_grid +from tidy3d.components.grid.grid import Grid, YeeGrid +from tidy3d.components.lumped_element import LumpedElementType +from tidy3d.components.monitor import FieldMonitor +from tidy3d.components.types import Complex, Coordinate, FreqArray +from tidy3d.constants import OHM + from .base_terminal import AbstractTerminalPort DEFAULT_PORT_NUM_CELLS = 3 @@ -21,14 +22,14 @@ class AbstractLumpedPort(AbstractTerminalPort): """Class representing a single lumped port.""" - impedance: Complex = pd.Field( + impedance: Complex = Field( DEFAULT_REFERENCE_IMPEDANCE, title="Reference impedance", description="Reference port impedance for scattering parameter computation.", units=OHM, ) - num_grid_cells: Optional[pd.PositiveInt] = pd.Field( + num_grid_cells: Optional[PositiveInt] = Field( DEFAULT_PORT_NUM_CELLS, title="Port grid cells", description="Number of mesh grid cells associated with the port along each direction, " @@ -36,7 +37,7 @@ class AbstractLumpedPort(AbstractTerminalPort): "A value of 
``None`` will turn off automatic mesh refinement.", ) - enable_snapping_points: bool = pd.Field( + enable_snapping_points: bool = Field( True, title="Snap Grid To Lumped Port", description="When enabled, snapping points are automatically generated to snap grids to key " diff --git a/tidy3d/plugins/smatrix/ports/base_terminal.py b/tidy3d/plugins/smatrix/ports/base_terminal.py index 88504856ce..10ea88f5c6 100644 --- a/tidy3d/plugins/smatrix/ports/base_terminal.py +++ b/tidy3d/plugins/smatrix/ports/base_terminal.py @@ -3,19 +3,18 @@ from abc import ABC, abstractmethod from typing import Union -import pydantic.v1 as pd +from pydantic import Field, model_validator +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.data.data_array import FreqDataArray +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.grid.grid import Grid +from tidy3d.components.monitor import FieldMonitor, ModeMonitor +from tidy3d.components.source.base import Source +from tidy3d.components.source.time import GaussianPulse +from tidy3d.components.types import FreqArray from tidy3d.log import log -from ....components.base import Tidy3dBaseModel, cached_property -from ....components.data.data_array import FreqDataArray -from ....components.data.sim_data import SimulationData -from ....components.grid.grid import Grid -from ....components.monitor import FieldMonitor, ModeMonitor -from ....components.source.base import Source -from ....components.source.time import GaussianPulse -from ....components.types import FreqArray - class AbstractTerminalPort(Tidy3dBaseModel, ABC): """Class representing a single terminal-based port. All terminal ports must provide methods @@ -23,8 +22,7 @@ class AbstractTerminalPort(Tidy3dBaseModel, ABC): terminals, and the current flowing from one terminal into the other. 
""" - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for the port.", min_length=1, @@ -65,10 +63,10 @@ def compute_voltage(self, sim_data: SimulationData) -> FreqDataArray: def compute_current(self, sim_data: SimulationData) -> FreqDataArray: """Helper to compute current flowing into the port.""" - @pd.root_validator(pre=False) - def _warn_rf_license(cls, values): + @model_validator(mode="before") + def _warn_rf_license(data): log.warning( "ℹ️ ⚠️ RF simulations are subject to new license requirements in the future. You have instantiated at least one RF-specific component.", log_once=True, ) - return values + return data diff --git a/tidy3d/plugins/smatrix/ports/coaxial_lumped.py b/tidy3d/plugins/smatrix/ports/coaxial_lumped.py index 51a0faef31..73014c3129 100644 --- a/tidy3d/plugins/smatrix/ports/coaxial_lumped.py +++ b/tidy3d/plugins/smatrix/ports/coaxial_lumped.py @@ -1,25 +1,25 @@ """Lumped port specialization with an annular geometry for exciting coaxial ports.""" import numpy as np -import pydantic.v1 as pd - -from ....components.base import cached_property -from ....components.data.data_array import FreqDataArray, ScalarFieldDataArray -from ....components.data.dataset import FieldDataset -from ....components.data.sim_data import SimulationData -from ....components.geometry.base import Box, Geometry -from ....components.geometry.utils_2d import increment_float -from ....components.grid.grid import Grid, YeeGrid -from ....components.lumped_element import CoaxialLumpedResistor -from ....components.monitor import FieldMonitor -from ....components.source.current import CustomCurrentSource -from ....components.source.time import GaussianPulse -from ....components.types import Axis, Coordinate, Direction, FreqArray, Size -from ....components.validators import skip_if_fields_missing -from ....constants import MICROMETER -from ....exceptions import SetupError, ValidationError -from ...microwave import CustomCurrentIntegral2D, 
VoltageIntegralAxisAligned -from ...microwave.path_integrals import AbstractAxesRH +from pydantic import Field, PositiveFloat, field_validator, model_validator + +from tidy3d.components.base import cached_property +from tidy3d.components.data.data_array import FreqDataArray, ScalarFieldDataArray +from tidy3d.components.data.dataset import FieldDataset +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.geometry.base import Box, Geometry +from tidy3d.components.geometry.utils_2d import increment_float +from tidy3d.components.grid.grid import Grid, YeeGrid +from tidy3d.components.lumped_element import CoaxialLumpedResistor +from tidy3d.components.monitor import FieldMonitor +from tidy3d.components.source.current import CustomCurrentSource +from tidy3d.components.source.time import GaussianPulse +from tidy3d.components.types import Axis, Coordinate, Direction, FreqArray, Size +from tidy3d.constants import MICROMETER +from tidy3d.exceptions import SetupError, ValidationError +from tidy3d.plugins.microwave import CustomCurrentIntegral2D, VoltageIntegralAxisAligned +from tidy3d.plugins.microwave.path_integrals import AbstractAxesRH + from .base_lumped import AbstractLumpedPort DEFAULT_COAX_SOURCE_NUM_POINTS = 11 @@ -40,35 +40,31 @@ class CoaxialLumpedPort(AbstractLumpedPort, AbstractAxesRH): ... 
) # doctest: +SKIP """ - center: Coordinate = pd.Field( + center: Coordinate = Field( (0.0, 0.0, 0.0), title="Center", description="Center of object in x, y, and z.", units=MICROMETER, ) - outer_diameter: pd.PositiveFloat = pd.Field( - ..., + outer_diameter: PositiveFloat = Field( title="Outer Diameter", description="Diameter of the outer coaxial circle.", units=MICROMETER, ) - inner_diameter: pd.PositiveFloat = pd.Field( - ..., + inner_diameter: PositiveFloat = Field( title="Inner Diameter", description="Diameter of the inner coaxial circle.", units=MICROMETER, ) - normal_axis: Axis = pd.Field( - ..., + normal_axis: Axis = Field( title="Normal Axis", description="Specifies the axis which is normal to the concentric circles.", ) - direction: Direction = pd.Field( - ..., + direction: Direction = Field( title="Direction", description="The direction of the signal travelling in the transmission line. " "This is needed in order to position the path integral, which is used for computing " @@ -85,25 +81,23 @@ def injection_axis(self): """Required for inheriting from AbstractTerminalPort.""" return self.normal_axis - @pd.validator("center", always=True) - def _center_not_inf(cls, val): + @field_validator("center") + def _center_not_inf(val): """Make sure center is not infinity.""" if any(np.isinf(v) for v in val): raise ValidationError("'center' can not contain 'td.inf' terms.") return val - @pd.validator("inner_diameter", always=True) - @skip_if_fields_missing(["outer_diameter"]) - def _ensure_inner_diameter_is_smaller(cls, val, values): + @model_validator(mode="after") + def _ensure_inner_diameter_is_smaller(self): """Ensures that the inner diameter is smaller than the outer diameter, so that the final shape is an annulus.""" - outer_diameter = values.get("outer_diameter") - if val >= outer_diameter: + if self.inner_diameter >= self.outer_diameter: raise ValidationError( - f"The 'inner_diameter' {val} of a coaxial lumped element must be less than its " - 
f"'outer_diameter' {outer_diameter}." + f"The 'inner_diameter' {self.inner_diameter} of a coaxial lumped element " + f"must be less than its 'outer_diameter' {self.outer_diameter}." ) - return val + return self def to_source( self, source_time: GaussianPulse, snap_center: float = None, grid: Grid = None diff --git a/tidy3d/plugins/smatrix/ports/modal.py b/tidy3d/plugins/smatrix/ports/modal.py index eb747d482c..d11b26d35e 100644 --- a/tidy3d/plugins/smatrix/ports/modal.py +++ b/tidy3d/plugins/smatrix/ports/modal.py @@ -1,11 +1,11 @@ """Class and custom data array for representing a scattering matrix port based on waveguide modes.""" -import pydantic.v1 as pd +from pydantic import Field -from ....components.data.data_array import DataArray -from ....components.geometry.base import Box -from ....components.mode_spec import ModeSpec -from ....components.types import Direction +from tidy3d.components.data.data_array import DataArray +from tidy3d.components.geometry.base import Box +from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.types import Direction class ModalPortDataArray(DataArray): @@ -37,18 +37,16 @@ class ModalPortDataArray(DataArray): class Port(Box): """Specifies a port in the scattering matrix.""" - direction: Direction = pd.Field( - ..., + direction: Direction = Field( title="Direction", description="'+' or '-', defining which direction is considered 'input'.", ) - mode_spec: ModeSpec = pd.Field( - ModeSpec(), + mode_spec: ModeSpec = Field( + default_factory=ModeSpec, title="Mode Specification", description="Specifies how the mode solver will solve for the modes of the port.", ) - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for the port.", min_length=1, diff --git a/tidy3d/plugins/smatrix/ports/rectangular_lumped.py b/tidy3d/plugins/smatrix/ports/rectangular_lumped.py index 507a80f1c4..edb9866b68 100644 --- a/tidy3d/plugins/smatrix/ports/rectangular_lumped.py +++ 
b/tidy3d/plugins/smatrix/ports/rectangular_lumped.py @@ -1,35 +1,29 @@ """Lumped port specialization with a rectangular geometry.""" import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator -from ....components.base import cached_property -from ....components.data.data_array import FreqDataArray -from ....components.data.sim_data import SimulationData -from ....components.geometry.base import Box -from ....components.geometry.utils import ( +from tidy3d.components.base import cached_property +from tidy3d.components.data.data_array import FreqDataArray +from tidy3d.components.data.sim_data import SimulationData +from tidy3d.components.geometry.base import Box +from tidy3d.components.geometry.utils import ( SnapBehavior, SnapLocation, SnappingSpec, snap_box_to_grid, ) -from ....components.geometry.utils_2d import increment_float -from ....components.grid.grid import Grid, YeeGrid -from ....components.lumped_element import ( - LinearLumpedElement, - LumpedResistor, - RLCNetwork, -) -from ....components.monitor import FieldMonitor -from ....components.source.current import UniformCurrentSource -from ....components.source.time import GaussianPulse -from ....components.types import Axis, FreqArray, LumpDistType -from ....components.validators import assert_line_or_plane -from ....exceptions import SetupError, ValidationError -from ...microwave import ( - CurrentIntegralAxisAligned, - VoltageIntegralAxisAligned, -) +from tidy3d.components.geometry.utils_2d import increment_float +from tidy3d.components.grid.grid import Grid, YeeGrid +from tidy3d.components.lumped_element import LinearLumpedElement, LumpedResistor, RLCNetwork +from tidy3d.components.monitor import FieldMonitor +from tidy3d.components.source.current import UniformCurrentSource +from tidy3d.components.source.time import GaussianPulse +from tidy3d.components.types import Axis, FreqArray, LumpDistType +from tidy3d.components.validators import assert_line_or_plane +from 
tidy3d.exceptions import SetupError, ValidationError +from tidy3d.plugins.microwave import CurrentIntegralAxisAligned, VoltageIntegralAxisAligned + from .base_lumped import AbstractLumpedPort @@ -51,14 +45,13 @@ class LumpedPort(AbstractLumpedPort, Box): The lumped element representing the load of the port. """ - voltage_axis: Axis = pd.Field( - ..., + voltage_axis: Axis = Field( title="Voltage Integration Axis", description="Specifies the axis along which the E-field line integral is performed when " "computing the port voltage. The integration axis must lie in the plane of the port.", ) - snap_perimeter_to_grid: bool = pd.Field( + snap_perimeter_to_grid: bool = Field( True, title="Snap Perimeter to Grid", description="When enabled, the perimeter of the port is snapped to the simulation grid, " @@ -66,7 +59,7 @@ class LumpedPort(AbstractLumpedPort, Box): "is always snapped to the grid along its injection axis.", ) - dist_type: LumpDistType = pd.Field( + dist_type: LumpDistType = Field( "on", title="Distribute Type", description="Optional field that is passed directly to the :class:`.LinearLumpedElement` used to model the port's load. 
" @@ -85,13 +78,12 @@ def injection_axis(self): """Injection axis of the port.""" return self.size.index(0.0) - @pd.validator("voltage_axis", always=True) - def _voltage_axis_in_plane(cls, val, values): + @model_validator(mode="after") + def _voltage_axis_in_plane(self): """Ensure voltage integration axis is in the port's plane.""" - size = values.get("size") - if val == size.index(0.0): + if self.voltage_axis == self.size.index(0.0): raise ValidationError("'voltage_axis' must lie in the port's plane.") - return val + return self @cached_property def current_axis(self) -> Axis: diff --git a/tidy3d/plugins/smatrix/ports/wave.py b/tidy3d/plugins/smatrix/ports/wave.py index e5532ef6a0..17fc76094c 100644 --- a/tidy3d/plugins/smatrix/ports/wave.py +++ b/tidy3d/plugins/smatrix/ports/wave.py @@ -3,45 +3,41 @@ from typing import Optional, Union import numpy as np -import pydantic.v1 as pd - -from ....components.base import cached_property, skip_if_fields_missing -from ....components.data.data_array import FreqDataArray, FreqModeDataArray -from ....components.data.monitor_data import ModeData -from ....components.data.sim_data import SimulationData -from ....components.geometry.base import Box -from ....components.grid.grid import Grid -from ....components.monitor import ModeMonitor -from ....components.simulation import Simulation -from ....components.source.field import ModeSource, ModeSpec -from ....components.source.time import GaussianPulse -from ....components.types import Bound, Direction, FreqArray -from ....exceptions import ValidationError -from ...microwave import ( - CurrentIntegralTypes, - ImpedanceCalculator, - VoltageIntegralTypes, -) -from ...mode import ModeSolver +from pydantic import Field, NonNegativeInt, model_validator + +from tidy3d.components.base import cached_property +from tidy3d.components.data.data_array import FreqDataArray, FreqModeDataArray +from tidy3d.components.data.monitor_data import ModeData +from tidy3d.components.data.sim_data import 
SimulationData +from tidy3d.components.geometry.base import Box +from tidy3d.components.grid.grid import Grid +from tidy3d.components.monitor import ModeMonitor +from tidy3d.components.simulation import Simulation +from tidy3d.components.source.field import ModeSource, ModeSpec +from tidy3d.components.source.time import GaussianPulse +from tidy3d.components.types import Bound, Direction, FreqArray +from tidy3d.exceptions import ValidationError +from tidy3d.plugins.microwave import CurrentIntegralTypes, ImpedanceCalculator, VoltageIntegralTypes +from tidy3d.plugins.mode import ModeSolver + from .base_terminal import AbstractTerminalPort class WavePort(AbstractTerminalPort, Box): """Class representing a single wave port""" - direction: Direction = pd.Field( - ..., + direction: Direction = Field( title="Direction", description="'+' or '-', defining which direction is considered 'input'.", ) - mode_spec: ModeSpec = pd.Field( - ModeSpec(), + mode_spec: ModeSpec = Field( + default_factory=ModeSpec, title="Mode Specification", description="Parameters to feed to mode solver which determine modes measured by monitor.", ) - mode_index: pd.NonNegativeInt = pd.Field( + mode_index: NonNegativeInt = Field( 0, title="Mode Index", description="Index into the collection of modes returned by mode solver. 
" @@ -50,13 +46,13 @@ class WavePort(AbstractTerminalPort, Box): "``num_modes`` in the solver will be set to ``mode_index + 1``.", ) - voltage_integral: Optional[VoltageIntegralTypes] = pd.Field( + voltage_integral: Optional[VoltageIntegralTypes] = Field( None, title="Voltage Integral", description="Definition of voltage integral used to compute voltage and the characteristic impedance.", ) - current_integral: Optional[CurrentIntegralTypes] = pd.Field( + current_integral: Optional[CurrentIntegralTypes] = Field( None, title="Current Integral", description="Definition of current integral used to compute current and the characteristic impedance.", @@ -192,42 +188,35 @@ def _within_port_bounds(path_bounds: Bound, port_bounds: Bound) -> bool: bound_max = np.array(port_bounds[1]) return (bound_min <= path_min).all() and (bound_max >= path_max).all() - @pd.validator("voltage_integral", "current_integral") - def _validate_path_integrals_within_port(cls, val, values): - """Raise ``ValidationError`` when the supplied path integrals are not within the port bounds.""" - center = values["center"] - size = values["size"] - box = Box(center=center, size=size) - if val and not WavePort._within_port_bounds(val.bounds, box.bounds): - raise ValidationError( - f"'{cls.__name__}' must be setup with all path integrals defined within the bounds " - f"of the port. Path bounds are '{val.bounds}', but port bounds are '{box.bounds}'." 
- ) - return val + @model_validator(mode="after") + def _validate_integrals_within_port(self): + box = Box(center=self.center, size=self.size) + + for name in ("voltage_integral", "current_integral"): + val = getattr(self, name) + if val and not self._within_port_bounds(val.bounds, box.bounds): + raise ValueError( + f"{name} bounds {val.bounds!r} must be inside port bounds {box.bounds!r}" + ) + return self - @pd.validator("current_integral", always=True) - @skip_if_fields_missing(["voltage_integral"]) - def _check_voltage_or_current(cls, val, values): + @model_validator(mode="after") + def _check_voltage_or_current(self): """Raise validation error if both ``voltage_integral`` and ``current_integral`` were not provided.""" - if values.get("voltage_integral") is None and val is None: + if self.voltage_integral is None and self.current_integral is None: raise ValidationError( "At least one of 'voltage_integral' or 'current_integral' must be provided." ) - return val + return self - @pd.validator("current_integral", always=True) - def validate_current_integral_sign(cls, val, values): - """ - Validate that the sign of ``current_integral`` matches the port direction. - """ - if val is None: - return val - - direction = values.get("direction") - name = values.get("name") - if val.sign != direction: + @model_validator(mode="after") + def validate_current_integral_sign(self): + """Validate that the sign of ``current_integral`` matches the port direction.""" + if self.current_integral is None: + return self + if self.current_integral.sign != self.direction: raise ValidationError( - f"'current_integral' sign must match the '{name}' direction '{direction}'." + f"'current_integral' sign must match the '{self.name}' direction '{self.direction}'." 
) - return val + return self diff --git a/tidy3d/plugins/waveguide/rectangular_dielectric.py b/tidy3d/plugins/waveguide/rectangular_dielectric.py index 78891e17bf..9a1ef129c2 100644 --- a/tidy3d/plugins/waveguide/rectangular_dielectric.py +++ b/tidy3d/plugins/waveguide/rectangular_dielectric.py @@ -1,32 +1,39 @@ """Rectangular dielectric waveguide utilities.""" -from typing import Any, List, Tuple, Union +from typing import Annotated, Any, Optional, Union import numpy -import pydantic.v1 as pydantic from matplotlib import pyplot -from typing_extensions import Annotated - -from ...components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from ...components.boundary import BoundarySpec, Periodic -from ...components.data.data_array import FreqModeDataArray, ModeIndexDataArray -from ...components.geometry.base import Box -from ...components.geometry.polyslab import PolySlab -from ...components.grid.grid_spec import GridSpec -from ...components.medium import Medium, MediumType -from ...components.mode_spec import ModeSpec -from ...components.simulation import Simulation -from ...components.source.field import ModeSource -from ...components.source.time import GaussianPulse -from ...components.structure import Structure -from ...components.types import TYPE_TAG_STR, ArrayFloat1D, Ax, Axis, Coordinate, Literal, Size1D -from ...components.viz import add_ax_if_none -from ...constants import C_0, MICROMETER, RADIAN, inf -from ...exceptions import Tidy3dError, ValidationError -from ...log import log -from ..mode.mode_solver import ModeSolver - -AnnotatedMedium = Annotated[MediumType, pydantic.Field(discriminator=TYPE_TAG_STR)] +from pydantic import Field, field_validator, model_validator + +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.boundary import BoundarySpec, Periodic +from tidy3d.components.data.data_array import FreqModeDataArray, ModeIndexDataArray +from tidy3d.components.geometry.base import Box 
+from tidy3d.components.geometry.polyslab import PolySlab +from tidy3d.components.grid.grid_spec import GridSpec +from tidy3d.components.medium import Medium, MediumType +from tidy3d.components.mode_spec import ModeSpec +from tidy3d.components.simulation import Simulation +from tidy3d.components.source.field import ModeSource +from tidy3d.components.source.time import GaussianPulse +from tidy3d.components.structure import Structure +from tidy3d.components.types import ( + TYPE_TAG_STR, + ArrayFloat1D, + Ax, + Axis, + Coordinate, + Literal, + Size1D, +) +from tidy3d.components.viz import add_ax_if_none +from tidy3d.constants import C_0, MICROMETER, RADIAN, inf +from tidy3d.exceptions import Tidy3dError, ValidationError +from tidy3d.log import log +from tidy3d.plugins.mode.mode_solver import ModeSolver + +AnnotatedMedium = Annotated[MediumType, Field(discriminator=TYPE_TAG_STR)] EVANESCENT_TAIL = 1.5 @@ -43,57 +50,52 @@ class RectangularDielectric(Tidy3dBaseModel): - Coupled waveguides """ - wavelength: Union[float, ArrayFloat1D] = pydantic.Field( - ..., + wavelength: Union[float, ArrayFloat1D] = Field( title="Wavelength", description="Wavelength(s) at which to calculate modes (in μm).", units=MICROMETER, ) - core_width: Union[Size1D, ArrayFloat1D] = pydantic.Field( - ..., + core_width: Union[Size1D, ArrayFloat1D] = Field( title="Core width", description="Core width at the top of the waveguide. 
If set to an array, defines " "the widths of adjacent waveguides.", units=MICROMETER, ) - core_thickness: Size1D = pydantic.Field( - ..., + core_thickness: Size1D = Field( title="Core Thickness", description="Thickness of the core layer.", units=MICROMETER, ) - core_medium: MediumType = pydantic.Field( - ..., + core_medium: MediumType = Field( title="Core Medium", description="Medium associated with the core layer.", discriminator=TYPE_TAG_STR, ) - clad_medium: Union[AnnotatedMedium, Tuple[AnnotatedMedium, ...]] = pydantic.Field( - ..., + clad_medium: Union[AnnotatedMedium, tuple[AnnotatedMedium, ...]] = Field( title="Clad Medium", description="Medium associated with the upper cladding layer. A sequence of mediums can " "be used to create a layered clad.", ) - box_medium: Union[AnnotatedMedium, Tuple[AnnotatedMedium, ...]] = pydantic.Field( + box_medium: Optional[Union[AnnotatedMedium, tuple[AnnotatedMedium, ...]]] = Field( None, title="Box Medium", description="Medium associated with the lower cladding layer. A sequence of mediums can " "be used to create a layered substrate. If not set, the first clad medium is used.", ) - slab_thickness: Size1D = pydantic.Field( + slab_thickness: Size1D = Field( 0.0, title="Slab Thickness", description="Thickness of the slab for rib geometry.", units=MICROMETER, ) - clad_thickness: Union[Size1D, ArrayFloat1D] = pydantic.Field( + clad_thickness: Optional[Union[Size1D, ArrayFloat1D]] = Field( None, title="Clad Thickness", description="Domain size above the core layer. An array can be used to define a layered " @@ -101,7 +103,7 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - box_thickness: Union[Size1D, ArrayFloat1D] = pydantic.Field( + box_thickness: Optional[Union[Size1D, ArrayFloat1D]] = Field( None, title="Box Thickness", description="Domain size below the core layer. 
An array can be used to define a layered " @@ -109,14 +111,14 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - side_margin: Size1D = pydantic.Field( + side_margin: Optional[Size1D] = Field( None, title="Side Margin", description="Domain size to the sides of the waveguide core.", units=MICROMETER, ) - sidewall_angle: float = pydantic.Field( + sidewall_angle: float = Field( 0.0, title="Sidewall Angle", description="Angle of the core sidewalls measured from the vertical direction (in " @@ -125,7 +127,7 @@ class RectangularDielectric(Tidy3dBaseModel): units=RADIAN, ) - gap: Union[float, ArrayFloat1D] = pydantic.Field( + gap: Union[float, ArrayFloat1D] = Field( 0.0, title="Gap", description="Distance between adjacent waveguides, measured at the top core edges. " @@ -133,21 +135,21 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - sidewall_thickness: Size1D = pydantic.Field( + sidewall_thickness: Size1D = Field( 0.0, title="Sidewall Thickness", description="Sidewall layer thickness (within core).", units=MICROMETER, ) - sidewall_medium: MediumType = pydantic.Field( + sidewall_medium: Optional[MediumType] = Field( None, title="Sidewall medium", description="Medium associated with the sidewall layer to model sidewall losses.", discriminator=TYPE_TAG_STR, ) - surface_thickness: Size1D = pydantic.Field( + surface_thickness: Size1D = Field( 0.0, title="Surface Thickness", description="Thickness of the surface layers defined on the top of the waveguide and " @@ -155,14 +157,14 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - surface_medium: MediumType = pydantic.Field( + surface_medium: Optional[MediumType] = Field( None, title="Surface Medium", description="Medium associated with the surface layer to model surface losses.", discriminator=TYPE_TAG_STR, ) - origin: Coordinate = pydantic.Field( + origin: Coordinate = Field( (0, 0, 0), title="Origin", description="Center of the waveguide geometry. 
This coordinate represents the base " @@ -171,57 +173,57 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - length: Size1D = pydantic.Field( + length: Size1D = Field( 1e30, title="Length", description="Length of the waveguides in the propagation direction", units=MICROMETER, ) - propagation_axis: Axis = pydantic.Field( + propagation_axis: Axis = Field( 0, title="Propagation Axis", description="Axis of propagation of the waveguide", ) - normal_axis: Axis = pydantic.Field( + normal_axis: Axis = Field( 2, title="Normal Axis", description="Axis normal to the substrate surface", ) - mode_spec: ModeSpec = pydantic.Field( - ModeSpec(num_modes=2), + mode_spec: ModeSpec = Field( + default_factory=lambda: ModeSpec(num_modes=2), title="Mode Specification", description=":class:`ModeSpec` defining waveguide mode properties.", ) - grid_resolution: int = pydantic.Field( + grid_resolution: int = Field( 15, title="Grid Resolution", description="Solver grid resolution per wavelength.", ) - max_grid_scaling: float = pydantic.Field( + max_grid_scaling: float = Field( 1.2, title="Maximal Grid Scaling", description="Maximal size increase between adjacent grid boundaries.", ) - @pydantic.validator("wavelength", "core_width", "gap", always=True) - def _set_non_negative_array(cls, val): + @field_validator("wavelength", "core_width", "gap") + def _set_non_negative_array(val): """Ensure values are not negative and convert to numpy arrays.""" val = numpy.array(val, ndmin=1) if any(val < 0): raise ValidationError("Values may not be negative.") return val - @pydantic.validator("core_medium", "clad_medium", "box_medium") - def _check_non_metallic(cls, val, values): + @field_validator("core_medium", "clad_medium", "box_medium") + def _check_non_metallic(val, info): if val is None: return val media = val if isinstance(val, tuple) else (val,) - freqs = C_0 / values["wavelength"] + freqs = C_0 / info.data["wavelength"] if any(medium.eps_model(f).real < 1 for medium in media for f 
in freqs): raise ValidationError( "'RectangularDielectric' can only be used with dielectric media. " @@ -229,60 +231,57 @@ def _check_non_metallic(cls, val, values): ) return val - @pydantic.validator("gap", always=True) - @skip_if_fields_missing(["core_width"]) - def _validate_gaps(cls, val, values): + @model_validator(mode="after") + def _validate_gaps(self): """Ensure the number of gaps is compatible with the number of cores supplied.""" - if val.size == 1 and values["core_width"].size != 2: + if self.gap.size == 1 and self.core_width.size != 2: # If a single value is defined, use it for all gaps - return numpy.array([val[0]] * (values["core_width"].size - 1)) - if val.size != values["core_width"].size - 1: + object.__setattr__(self, "gap", numpy.array([self.gap[0]] * (self.core_width.size - 1))) + return self + if self.gap.size != self.core_width.size - 1: raise ValidationError("Number of gaps must be 1 less than number of core widths.") - return val + return self - @pydantic.root_validator - def _set_box_medium(cls, values): + @model_validator(mode="after") + def _set_box_medium(self): """Set BOX medium same as cladding as default value.""" - box_medium = values.get("box_medium") - if box_medium is None: - clad_medium = values.get("clad_medium") - if clad_medium is None: - return values - if isinstance(clad_medium, tuple): - clad_medium = clad_medium[0] - values["box_medium"] = clad_medium - return values - - @pydantic.root_validator - def _set_clad_thickness(cls, values): + if self.box_medium is None: + if self.clad_medium is None: + return self + if isinstance(clad_medium := self.clad_medium, tuple): + clad_medium = clad_medium[0] + object.__setattr__(self, "box_medium", clad_medium) + return self + + @model_validator(mode="after") + def _set_clad_thickness(self): """Set default clad/BOX thickness based on the max wavelength in the medium.""" for side in ("clad", "box"): - val = values.get(side + "_thickness") + val = getattr(self, side + "_thickness") if val is None: 
- wavelength = values.get("wavelength") - medium = values.get(side + "_medium") - if wavelength is None or medium is None: - return values + medium = getattr(self, side + "_medium") + if self.wavelength is None or medium is None: + return self if isinstance(medium, tuple): medium = medium[0] - n = numpy.array([medium.nk_model(f)[0] for f in C_0 / wavelength]) - lda = wavelength / n - values[side + "_thickness"] = EVANESCENT_TAIL * lda.max() + n = numpy.array([medium.nk_model(f)[0] for f in C_0 / self.wavelength]) + lda = self.wavelength / n + object.__setattr__(self, side + "_thickness", EVANESCENT_TAIL * lda.max()) elif isinstance(val, float): if val < 0: raise ValidationError("Thickness may not be negative.") else: - values[side + "_thickness"] = cls._set_non_negative_array(val) - return values + object.__setattr__(self, side + "_thickness", self._set_non_negative_array(val)) + return self - @pydantic.root_validator - def _validate_layers(cls, values): + @model_validator(mode="after") + def _validate_layers(self): """Ensure the number of clad media is compatible with the number of layers supplied.""" for side in ("clad", "box"): - thickness = values.get(side + "_thickness") - medium = values.get(side + "_medium") + thickness = getattr(self, side + "_thickness") + medium = getattr(self, side + "_medium") if thickness is None or medium is None: - return values + return self num_layers = 1 if isinstance(thickness, float) else thickness.size num_media = 1 if not isinstance(medium, tuple) else len(medium) if num_layers != num_media: @@ -290,63 +289,56 @@ def _validate_layers(cls, values): f"Number of '{side}_thickness' values ({num_layers}) must be equal to that of " f"'{side}_medium' ({num_media})." 
) - return values + return self - @pydantic.root_validator - def _set_side_margin(cls, values): + @model_validator(mode="after") + def _set_side_margin(self): """Set default side margin based on BOX and cladding thicknesses.""" - clad_thickness = values.get("clad_thickness") - box_thickness = values.get("box_thickness") + clad_thickness = self.clad_thickness + box_thickness = self.box_thickness if clad_thickness is None or box_thickness is None: - return values - if values["side_margin"] is None: + return self + if self.side_margin is None: if not isinstance(clad_thickness, float): clad_thickness = clad_thickness.sum() if not isinstance(box_thickness, float): box_thickness = box_thickness.sum() - values["side_margin"] = max(clad_thickness, box_thickness) - return values + object.__setattr__(self, "side_margin", max(clad_thickness, box_thickness)) + return self - @pydantic.root_validator - def _ensure_consistency(cls, values): + @model_validator(mode="after") + def _ensure_consistency(self): """Ensure consistency in setting surface/sidewall models and propagation/normal axes.""" - sidewall_thickness = values["sidewall_thickness"] - sidewall_medium = values["sidewall_medium"] - surface_thickness = values["surface_thickness"] - surface_medium = values["surface_medium"] - propagation_axis = values["propagation_axis"] - normal_axis = values["normal_axis"] - - if sidewall_thickness > 0 and sidewall_medium is None: + if self.sidewall_thickness > 0 and self.sidewall_medium is None: raise ValidationError( "Sidewall medium must be provided when sidewall thickness is greater than 0." 
) - if sidewall_thickness == 0 and sidewall_medium is not None: + if self.sidewall_thickness == 0 and self.sidewall_medium is not None: log.warning("Sidewall medium not used because sidewall thickness is zero.") - if surface_thickness > 0 and surface_medium is None: + if self.surface_thickness > 0 and self.surface_medium is None: raise ValidationError( "Surface medium must be provided when surface thickness is greater than 0." ) - if surface_thickness == 0 and surface_medium is not None: + if self.surface_thickness == 0 and self.surface_medium is not None: log.warning("Surface medium not used because surface thickness is zero.") - if propagation_axis == normal_axis: + if self.propagation_axis == self.normal_axis: raise ValidationError("Propagation and normal axes must be different.") - return values + return self @property - def _clad_medium(self) -> Tuple[MediumType, ...]: + def _clad_medium(self) -> tuple[MediumType, ...]: """Normalize data type to tuple.""" if not isinstance(self.clad_medium, tuple): return (self.clad_medium,) return self.clad_medium @property - def _box_medium(self) -> Tuple[MediumType, ...]: + def _box_medium(self) -> tuple[MediumType, ...]: """Normalize data type to tuple.""" if not isinstance(self.box_medium, tuple): return (self.box_medium,) @@ -373,7 +365,7 @@ def lateral_axis(self) -> Axis: def _swap_axis( self, lateral_coord: Any, normal_coord: Any, propagation_coord: Any - ) -> List[Any]: + ) -> list[Any]: """Swap the model coordinates to desired axes.""" result = [None, None, None] result[self.lateral_axis] = lateral_coord @@ -383,13 +375,13 @@ def _swap_axis( def _translate( self, lateral_coord: float, normal_coord: float, propagation_coord: float - ) -> List[float]: + ) -> list[float]: """Swap the model coordinates to desired axes and translate to origin.""" coordinates = self._swap_axis(lateral_coord, normal_coord, propagation_coord) result = [a + b for a, b in zip(self.origin, coordinates)] return result - def 
_transform_in_plane(self, lateral_coord: float, propagation_coord: float) -> List[float]: + def _transform_in_plane(self, lateral_coord: float, propagation_coord: float) -> list[float]: """Swap the model coordinates to desired axes in the substrate plane.""" result = self._translate(lateral_coord, 0, propagation_coord) _, result = Box.pop_axis(result, self.normal_axis) @@ -409,14 +401,14 @@ def width(self) -> Size1D: return w @property - def _core_starts(self) -> List[float]: + def _core_starts(self) -> list[float]: """Starting positions of each waveguide (x is the position in the lateral direction).""" core_x = [-0.5 * (self.core_width.sum() + self.gap.sum())] core_x.extend(core_x[0] + numpy.cumsum(self.core_width[:-1]) + numpy.cumsum(self.gap)) return core_x @property - def _override_structures(self) -> List[Structure]: + def _override_structures(self) -> list[Structure]: """Build override structures to define the simulation grid.""" # Grid resolution factor applied to the materials (increase for waveguide corners @@ -547,7 +539,7 @@ def grid_spec(self) -> GridSpec: return grid_spec @cached_property - def structures(self) -> List[Structure]: + def structures(self) -> list[Structure]: """Waveguide structures for simulation, including the core(s), slabs (if any), and bottom cladding, if different from the top. 
For bend modes, the structure is a 270 degree bend regardless of :attr:`length`.""" diff --git a/tidy3d/updater.py b/tidy3d/updater.py index 6e6228e3a9..c310fd7e01 100644 --- a/tidy3d/updater.py +++ b/tidy3d/updater.py @@ -4,10 +4,10 @@ import functools import json -from typing import Callable, Dict +from typing import Callable -import pydantic.v1 as pd import yaml +from pydantic import BaseModel from .components.base import Tidy3dBaseModel from .exceptions import FileError, SetupError @@ -17,7 +17,7 @@ """Storing version numbers.""" -class Version(pd.BaseModel): +class Version(BaseModel): """Stores a version number (excluding patch).""" major: int @@ -83,7 +83,7 @@ def __ge__(self, other): """Class for updating simulation objects.""" -class Updater(pd.BaseModel): +class Updater(BaseModel): """Converts a tidy3d simulation.json file to an up-to-date Simulation instance.""" sim_dict: dict @@ -186,7 +186,7 @@ def new_update_function(sim_dict: dict) -> dict: return decorator -def iterate_update_dict(update_dict: Dict, update_types: Dict[str, Callable]): +def iterate_update_dict(update_dict: dict, update_types: dict[str, Callable]): """Recursively iterate nested ``update_dict``. For any nested ``nested_dict`` found, apply an update function if its ``nested_dict["type"]`` is in the keys of the ``update_types`` dictionary. Also iterates lists and tuples. 
diff --git a/tidy3d/web/api/asynchronous.py b/tidy3d/web/api/asynchronous.py index fc50b1d349..1f47d44227 100644 --- a/tidy3d/web/api/asynchronous.py +++ b/tidy3d/web/api/asynchronous.py @@ -1,6 +1,6 @@ """Interface to run several jobs in batch using simplified syntax.""" -from typing import Dict, List, Literal, Union +from typing import Literal, Union from ...log import log from ..core.types import PayType @@ -9,14 +9,14 @@ def run_async( - simulations: Dict[str, SimulationType], + simulations: dict[str, SimulationType], folder_name: str = "default", path_dir: str = DEFAULT_DATA_DIR, callback_url: str = None, num_workers: int = None, verbose: bool = True, simulation_type: str = "tidy3d", - parent_tasks: Dict[str, List[str]] = None, + parent_tasks: dict[str, list[str]] = None, reduce_simulation: Literal["auto", True, False] = "auto", pay_type: Union[PayType, str] = PayType.AUTO, ) -> BatchData: @@ -27,7 +27,7 @@ def run_async( Parameters ---------- - simulations : Dict[str, Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]] + simulations : dict[str, Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]] Mapping of task name to simulation. folder_name : str = "default" Name of folder to store each task on web UI. 
diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index c5e1f506b2..21e0b320a7 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -2,10 +2,10 @@ import os import tempfile -import typing from collections import defaultdict from os.path import basename, dirname, join from pathlib import Path +from typing import Any, Callable, Literal, Union import numpy as np from autograd.builtins import dict as dict_ag @@ -14,7 +14,6 @@ import tidy3d as td from tidy3d.components.autograd import AutogradFieldMap, get_static from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.components.types import Literal from ....exceptions import AdjointError from ...core.s3utils import download_file, upload_file @@ -98,8 +97,8 @@ def run( path: str = "simulation_data.hdf5", callback_url: str = None, verbose: bool = True, - progress_callback_upload: typing.Callable[[float], None] = None, - progress_callback_download: typing.Callable[[float], None] = None, + progress_callback_upload: Callable[[float], None] = None, + progress_callback_download: Callable[[float], None] = None, solver_version: str = None, worker_group: str = None, simulation_type: str = "tidy3d", @@ -107,7 +106,7 @@ def run( local_gradient: bool = LOCAL_GRADIENT, max_num_adjoint_per_fwd: int = MAX_NUM_ADJOINT_PER_FWD, reduce_simulation: Literal["auto", True, False] = "auto", - pay_type: typing.Union[PayType, str] = PayType.AUTO, + pay_type: Union[PayType, str] = PayType.AUTO, ) -> SimulationDataType: """ Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads, @@ -145,7 +144,7 @@ def run( Maximum number of adjoint simulations allowed to run automatically. reduce_simulation: Literal["auto", True, False] = "auto" Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver. 
- pay_type: typing.Union[PayType, str] = PayType.AUTO + pay_type: Union[PayType, str] = PayType.AUTO Which method to pay for the simulation. Returns ------- @@ -240,7 +239,7 @@ def run_async( local_gradient: bool = LOCAL_GRADIENT, max_num_adjoint_per_fwd: int = MAX_NUM_ADJOINT_PER_FWD, reduce_simulation: Literal["auto", True, False] = "auto", - pay_type: typing.Union[PayType, str] = PayType.AUTO, + pay_type: Union[PayType, str] = PayType.AUTO, ) -> BatchData: """Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server, starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object. @@ -269,7 +268,7 @@ def run_async( Maximum number of adjoint simulations allowed to run automatically. reduce_simulation: Literal["auto", True, False] = "auto" Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver. - pay_type: typing.Union[PayType, str] = PayType.AUTO + pay_type: Union[PayType, str] = PayType.AUTO Specify the payment method. Returns @@ -478,7 +477,7 @@ def _run_primitive( def _run_async_primitive( sim_fields_dict: dict[str, AutogradFieldMap], sims_original: dict[str, td.Simulation], - aux_data_dict: dict[dict[str, typing.Any]], + aux_data_dict: dict[dict[str, Any]], local_gradient: bool, max_num_adjoint_per_fwd: int, **run_async_kwargs, @@ -622,7 +621,7 @@ def _run_bwd( local_gradient: bool, max_num_adjoint_per_fwd: int, **run_kwargs, -) -> typing.Callable[[AutogradFieldMap], AutogradFieldMap]: +) -> Callable[[AutogradFieldMap], AutogradFieldMap]: """VJP-maker for ``_run_primitive()``. 
Constructs and runs adjoint simulations, computes grad.""" # indicate this is an adjoint run @@ -742,11 +741,11 @@ def _run_async_bwd( data_fields_original_dict: dict[str, AutogradFieldMap], sim_fields_original_dict: dict[str, AutogradFieldMap], sims_original: dict[str, td.Simulation], - aux_data_dict: dict[str, dict[str, typing.Any]], + aux_data_dict: dict[str, dict[str, Any]], local_gradient: bool, max_num_adjoint_per_fwd: int, **run_async_kwargs, -) -> typing.Callable[[dict[str, AutogradFieldMap]], dict[str, AutogradFieldMap]]: +) -> Callable[[dict[str, AutogradFieldMap]], dict[str, AutogradFieldMap]]: """VJP-maker for ``_run_primitive()``. Constructs and runs adjoint simulation, computes grad.""" # indicate this is an adjoint run diff --git a/tidy3d/web/api/autograd/utils.py b/tidy3d/web/api/autograd/utils.py index 1d5cd136c3..d42ba40aa1 100644 --- a/tidy3d/web/api/autograd/utils.py +++ b/tidy3d/web/api/autograd/utils.py @@ -1,14 +1,14 @@ # utility functions for autograd web API from __future__ import annotations -import typing +from typing import Any, Union -import pydantic as pd +from pydantic import Field import tidy3d as td from tidy3d.components.autograd.types import AutogradFieldMap, dict_ag from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import ArrayLike, tidycomplex +from tidy3d.components.types import ArrayLike, Complex """ E and D field gradient map calculation helpers. 
""" @@ -48,11 +48,11 @@ def E_to_D(fld_data: td.FieldData, eps_data: td.PermittivityData) -> td.FieldDat def multiply_field_data( - fld_1: td.FieldData, fld_2: typing.Union[td.FieldData, td.PermittivityData] + fld_1: td.FieldData, fld_2: Union[td.FieldData, td.PermittivityData] ) -> td.FieldData: """Elementwise multiply two field data objects, writes data into ``fld_1`` copy.""" - def get_field_key(dim: str, fld_data: typing.Union[td.FieldData, td.PermittivityData]) -> str: + def get_field_key(dim: str, fld_data: Union[td.FieldData, td.PermittivityData]) -> str: """Get the key corresponding to the scalar field along this dimension.""" return f"E{dim}" if isinstance(fld_data, td.FieldData) else f"eps_{dim}{dim}" @@ -70,21 +70,17 @@ def get_field_key(dim: str, fld_data: typing.Union[td.FieldData, td.Permittivity class Tracer(Tidy3dBaseModel): """Class to store a single traced field.""" - path: tuple[typing.Any, ...] = pd.Field( - ..., + path: tuple[Any, ...] = Field( title="Path to the traced object in the model dictionary.", ) - data: typing.Union[float, tidycomplex, ArrayLike] = pd.Field(..., title="Tracing data") + data: Union[float, Complex, ArrayLike] = Field(title="Tracing data") class FieldMap(Tidy3dBaseModel): """Class to store a collection of traced fields.""" - tracers: tuple[Tracer, ...] = pd.Field( - ..., - title="Collection of Tracers.", - ) + tracers: tuple[Tracer, ...] = Field(title="Collection of Tracers.") @property def to_autograd_field_map(self) -> AutogradFieldMap: @@ -104,7 +100,4 @@ def from_autograd_field_map(cls, autograd_field_map) -> FieldMap: class TracerKeys(Tidy3dBaseModel): """Class to store a collection of tracer keys.""" - keys: tuple[tuple[typing.Any, ...], ...] = pd.Field( - ..., - title="Collection of tracer keys.", - ) + keys: tuple[tuple[Any, ...], ...] 
= Field(title="Collection of tracer keys.") diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index 851cd35acd..8e0f36b562 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -8,14 +8,14 @@ from abc import ABC from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Optional, Tuple +from typing import Literal, Optional -import pydantic.v1 as pd +from pydantic import Field, PositiveInt, PrivateAttr from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn from ...components.base import Tidy3dBaseModel, cached_property from ...components.mode.mode_solver import ModeSolver -from ...components.types import Literal, annotate_type +from ...components.types import discriminated_union from ...exceptions import DataError from ...log import get_logging_console, log from ..api import webapi as web @@ -126,20 +126,24 @@ class Job(WebContainer): * `Inverse taper edge coupler <../../notebooks/EdgeCoupler.html>`_ """ - simulation: SimulationType = pd.Field( - ..., + simulation: SimulationType = Field( title="simulation", description="Simulation to run as a 'task'.", discriminator="type", ) - task_name: TaskName = pd.Field(..., title="Task Name", description="Unique name of the task.") + task_name: TaskName = Field( + title="Task Name", + description="Unique name of the task.", + ) - folder_name: str = pd.Field( - "default", title="Folder Name", description="Name of folder to store task on web UI." + folder_name: str = Field( + "default", + title="Folder Name", + description="Name of folder to store task on web UI.", ) - callback_url: str = pd.Field( + callback_url: Optional[str] = Field( None, title="Callback URL", description="Http PUT url to receive simulation finish event. 
" @@ -147,28 +151,32 @@ class Job(WebContainer): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) - solver_version: str = pd.Field( + solver_version: Optional[str] = Field( None, title="Solver Version", description="Custom solver version to use, " "otherwise uses default for the current front end version.", ) - verbose: bool = pd.Field( - True, title="Verbose", description="Whether to print info messages and progressbars." + verbose: bool = Field( + True, + title="Verbose", + description="Whether to print info messages and progressbars.", ) - simulation_type: str = pd.Field( + simulation_type: str = Field( "tidy3d", title="Simulation Type", description="Type of simulation, used internally only.", ) - parent_tasks: Tuple[TaskId, ...] = pd.Field( - None, title="Parent Tasks", description="Tuple of parent task ids, used internally only." + parent_tasks: Optional[tuple[TaskId, ...]] = Field( + None, + title="Parent Tasks", + description="Tuple of parent task ids, used internally only.", ) - task_id_cached: TaskId = pd.Field( + task_id_cached: Optional[TaskId] = Field( None, title="Task ID (Cached)", description="Optional field to specify ``task_id``. Only used as a workaround internally " @@ -177,28 +185,30 @@ class Job(WebContainer): "fields that were not used to create the task will cause errors.", ) - reduce_simulation: Literal["auto", True, False] = pd.Field( + reduce_simulation: Literal["auto", True, False] = Field( "auto", title="Reduce Simulation", description="Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.", ) - pay_type: PayType = pd.Field( + pay_type: PayType = Field( PayType.AUTO, title="Payment Type", description="Specify the payment method.", ) - _upload_fields = ( - "simulation", - "task_name", - "folder_name", - "callback_url", - "verbose", - "simulation_type", - "parent_tasks", - "solver_version", - "reduce_simulation", + _upload_fields: tuple[str, ...] 
= PrivateAttr( + ( + "simulation", + "task_name", + "folder_name", + "callback_url", + "verbose", + "simulation_type", + "parent_tasks", + "solver_version", + "reduce_simulation", + ) ) def to_file(self, fname: str) -> None: @@ -411,18 +421,20 @@ class BatchData(Tidy3dBaseModel, Mapping): * `Performing parallel / batch processing of simulations <../../notebooks/ParameterScan.html>`_ """ - task_paths: Dict[TaskName, str] = pd.Field( - ..., + task_paths: dict[TaskName, str] = Field( title="Data Paths", description="Mapping of task_name to path to corresponding data for each task in batch.", ) - task_ids: Dict[TaskName, str] = pd.Field( - ..., title="Task IDs", description="Mapping of task_name to task_id for each task in batch." + task_ids: dict[TaskName, str] = Field( + title="Task IDs", + description="Mapping of task_name to task_id for each task in batch.", ) - verbose: bool = pd.Field( - True, title="Verbose", description="Whether to print info messages and progressbars." + verbose: bool = Field( + True, + title="Verbose", + description="Whether to print info messages and progressbars.", ) def load_sim_data(self, task_name: str) -> SimulationDataType: @@ -499,30 +511,31 @@ class Batch(WebContainer): * `Inverse taper edge coupler <../../notebooks/EdgeCoupler.html>`_ """ - simulations: Dict[TaskName, annotate_type(SimulationType)] = pd.Field( - ..., + simulations: dict[TaskName, discriminated_union(SimulationType)] = Field( title="Simulations", description="Mapping of task names to Simulations to run as a batch.", ) - folder_name: str = pd.Field( + folder_name: str = Field( "default", title="Folder Name", description="Name of folder to store member of each batch on web UI.", ) - verbose: bool = pd.Field( - True, title="Verbose", description="Whether to print info messages and progressbars." 
+ verbose: bool = Field( + True, + title="Verbose", + description="Whether to print info messages and progressbars.", ) - solver_version: str = pd.Field( + solver_version: Optional[str] = Field( None, title="Solver Version", description="Custom solver version to use, " "otherwise uses default for the current front end version.", ) - callback_url: str = pd.Field( + callback_url: Optional[str] = Field( None, title="Callback URL", description="Http PUT url to receive simulation finish event. " @@ -530,19 +543,19 @@ class Batch(WebContainer): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) - simulation_type: str = pd.Field( + simulation_type: str = Field( "tidy3d", title="Simulation Type", description="Type of each simulation in the batch, used internally only.", ) - parent_tasks: Dict[str, Tuple[TaskId, ...]] = pd.Field( + parent_tasks: Optional[dict[str, tuple[TaskId, ...]]] = Field( None, title="Parent Tasks", description="Collection of parent task ids for each job in batch, used internally only.", ) - num_workers: Optional[pd.PositiveInt] = pd.Field( + num_workers: Optional[PositiveInt] = Field( DEFAULT_NUM_WORKERS, title="Number of Workers", description="Number of workers for multi-threading upload and download of batch. " @@ -551,19 +564,19 @@ class Batch(WebContainer): "number of threads available on the system.", ) - reduce_simulation: Literal["auto", True, False] = pd.Field( + reduce_simulation: Literal["auto", True, False] = Field( "auto", title="Reduce Simulation", description="Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.", ) - pay_type: PayType = pd.Field( + pay_type: PayType = Field( PayType.AUTO, title="Payment Type", description="Specify the payment method.", ) - jobs_cached: Dict[TaskName, Job] = pd.Field( + jobs_cached: Optional[dict[TaskName, Job]] = Field( None, title="Jobs (Cached)", description="Optional field to specify ``jobs``. 
Only used as a workaround internally " @@ -572,7 +585,7 @@ class Batch(WebContainer): "fields that were not used to create the task will cause errors.", ) - _job_type = Job + _job_type: type = PrivateAttr(Job) def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData: """Upload and run each simulation in :class:`Batch`. @@ -610,7 +623,7 @@ def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData: return self.load(path_dir=path_dir) @cached_property - def jobs(self) -> Dict[TaskName, Job]: + def jobs(self) -> dict[TaskName, Job]: """Create a series of tasks in the :class:`.Batch` and upload them to server. Note @@ -629,7 +642,7 @@ def jobs(self) -> Dict[TaskName, Job]: for task_name, simulation in self.simulations.items(): job_kwargs = {} - for key in JobType._upload_fields: + for key in JobType._upload_fields.default: if key in self_dict: job_kwargs[key] = self_dict.get(key) @@ -694,12 +707,12 @@ def upload(self) -> None: completed += 1 progress.update(pbar, completed=completed) - def get_info(self) -> Dict[TaskName, TaskInfo]: + def get_info(self) -> dict[TaskName, TaskInfo]: """Get information about each task in the :class:`Batch`. Returns ------- - Dict[str, :class:`TaskInfo`] + dict[str, :class:`TaskInfo`] Mapping of task name to data about task associated with each task. """ info_dict = {} @@ -723,12 +736,12 @@ def start(self) -> None: for _, job in self.jobs.items(): executor.submit(job.start) - def get_run_info(self) -> Dict[TaskName, RunInfo]: + def get_run_info(self) -> dict[TaskName, RunInfo]: """get information about a each of the tasks in the :class:`Batch`. Returns ------- - Dict[str: :class:`RunInfo`] + dict[str: :class:`RunInfo`] Maps task names to run info for each task in the :class:`Batch`. 
""" run_info_dict = {} diff --git a/tidy3d/web/api/material_fitter.py b/tidy3d/web/api/material_fitter.py index 0d91f825e5..00c5f34704 100644 --- a/tidy3d/web/api/material_fitter.py +++ b/tidy3d/web/api/material_fitter.py @@ -10,7 +10,7 @@ import numpy as np import requests -from pydantic.v1 import BaseModel, Field +from pydantic import BaseModel, Field from ...plugins.dispersion import DispersionFitter from ..core.http_util import http @@ -50,16 +50,27 @@ class _FitterRequest(BaseModel): class MaterialFitterTask(Submittable): """Material Fitter Task.""" - id: str = Field(title="Task ID", description="Task ID") + id: str = Field( + title="Task ID", + description="Task ID", + ) dispersion_fitter: DispersionFitter = Field( - title="Dispersion Fitter", description="Dispersion Fitter data" + title="Dispersion Fitter", + description="Dispersion Fitter data", + ) + status: str = Field( + title="Task Status", + description="Task Status", ) - status: str = Field(title="Task Status", description="Task Status") file_name: str = Field( - ..., title="file name", description="fitter data file name", alias="fileName" + title="file name", + description="fitter data file name", + alias="fileName", ) resource_path: str = Field( - ..., title="resource path", description="resource path", alias="resourcePath" + title="resource path", + description="resource path", + alias="resourcePath", ) @classmethod diff --git a/tidy3d/web/api/material_library.py b/tidy3d/web/api/material_library.py new file mode 100644 index 0000000000..6060a3d84e --- /dev/null +++ b/tidy3d/web/api/material_library.py @@ -0,0 +1,61 @@ +"""Material Library API.""" + +from __future__ import annotations + +import json +from typing import Optional + +from pydantic import Field, field_validator, parse_obj_as + +from ...components.medium import MediumType +from ..core.http_util import http +from ..core.types import Queryable + + +class MaterialLibrary(Queryable): + """Material Library Resource interface.""" + + id: str 
= Field( + title="Material Library ID", + description="Material Library ID", + ) + name: str = Field( + title="Material Library Name", + description="Material Library Name", + ) + medium: Optional[MediumType] = Field( + None, + title="medium", + description="medium", + alias="calcResult", + ) + medium_type: Optional[str] = Field( + None, + title="medium type", + description="medium type", + alias="mediumType", + ) + json_input: Optional[dict] = Field( + None, + title="json input", + description="original input", + alias="jsonInput", + ) + + @field_validator("medium", "json_input", mode="before") + @classmethod + def parse_result(cls, values): + """Automatically parsing medium and json_input from string to object.""" + return json.loads(values) + + @classmethod + def list(cls) -> list[MaterialLibrary]: + """List all material libraries. + + Returns + ------- + tasks : list[:class:`.MaterialLibrary`] + List of material libraries/ + """ + resp = http.get("tidy3d/libraries") + return parse_obj_as(list[MaterialLibrary], resp) if resp else None diff --git a/tidy3d/web/api/material_libray.py b/tidy3d/web/api/material_libray.py deleted file mode 100644 index 3b9e397f2c..0000000000 --- a/tidy3d/web/api/material_libray.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Material Library API.""" - -from __future__ import annotations - -import json -from typing import List, Optional - -from pydantic.v1 import Field, parse_obj_as, validator - -from ...components.medium import MediumType -from ..core.http_util import http -from ..core.types import Queryable - - -class MaterialLibray(Queryable, smart_union=True): - """Material Library Resource interface.""" - - id: str = Field(title="Material Library ID", description="Material Library ID") - name: str = Field(title="Material Library Name", description="Material Library Name") - medium: Optional[MediumType] = Field(title="medium", description="medium", alias="calcResult") - medium_type: Optional[str] = Field( - title="medium type", 
description="medium type", alias="mediumType" - ) - json_input: Optional[dict] = Field( - title="json input", description="original input", alias="jsonInput" - ) - - @validator("medium", "json_input", pre=True) - def parse_result(cls, values): - """Automatically parsing medium and json_input from string to object.""" - return json.loads(values) - - @classmethod - def list(cls) -> List[MaterialLibray]: - """List all material libraries. - - Returns - ------- - tasks : List[:class:`.MaterialLibray`] - List of material libraries/ - """ - resp = http.get("tidy3d/libraries") - return parse_obj_as(List[MaterialLibray], resp) if resp else None diff --git a/tidy3d/web/api/mode.py b/tidy3d/web/api/mode.py index 09024aeff8..78017c566c 100644 --- a/tidy3d/web/api/mode.py +++ b/tidy3d/web/api/mode.py @@ -7,18 +7,17 @@ import tempfile import time from datetime import datetime -from typing import Callable, List, Optional, Union +from typing import Callable, Literal, Optional, Union -import pydantic.v1 as pydantic from botocore.exceptions import ClientError from joblib import Parallel, delayed +from pydantic import Extra, Field from rich.progress import Progress from ...components.data.monitor_data import ModeSolverData from ...components.eme.simulation import EMESimulation from ...components.medium import AbstractCustomMedium from ...components.simulation import Simulation -from ...components.types import Literal from ...exceptions import SetupError, WebError from ...log import get_logging_console, log from ...plugins.mode.mode_solver import MODE_MONITOR_NAME, ModeSolver @@ -147,29 +146,29 @@ def run( def run_batch( - mode_solvers: List[ModeSolver], + mode_solvers: list[ModeSolver], task_name: str = "BatchModeSolver", folder_name: str = "BatchModeSolvers", - results_files: List[str] = None, + results_files: list[str] = None, verbose: bool = True, max_workers: int = DEFAULT_NUM_WORKERS, max_retries: int = DEFAULT_MAX_RETRIES, retry_delay: float = DEFAULT_RETRY_DELAY, 
progress_callback_upload: Callable[[float], None] = None, progress_callback_download: Callable[[float], None] = None, -) -> List[ModeSolverData]: +) -> list[ModeSolverData]: """ Submits a batch of ModeSolver to the server concurrently, manages progress, and retrieves results. Parameters ---------- - mode_solvers : List[ModeSolver] + mode_solvers : list[ModeSolver] List of mode solvers to be submitted to the server. task_name : str Base name for tasks. Each task in the batch will have a unique index appended to this base name. folder_name : str Name of the folder where tasks are stored on the server's web UI. - results_files : List[str], optional + results_files : list[str], optional List of file paths where the results for each ModeSolver should be downloaded. If None, a default path based on the folder name and index is used. verbose : bool If True, displays a progress bar. If False, runs silently. @@ -187,7 +186,7 @@ def run_batch( Returns ------- - List[ModeSolverData] + list[ModeSolverData] A list of ModeSolverData objects containing the results from each simulation in the batch. ``None`` is placed in the list for simulations that fail after all retries. 
""" console = get_logging_console() @@ -253,45 +252,45 @@ def handle_mode_solver(index, progress, pbar): return results -class ModeSolverTask(ResourceLifecycle, Submittable, extra=pydantic.Extra.allow): +class ModeSolverTask(ResourceLifecycle, Submittable, extra=Extra.allow): """Interface for managing the running of a :class:`.ModeSolver` task on server.""" - task_id: str = pydantic.Field( + task_id: str = Field( None, title="task_id", description="Task ID number, set when the task is created, leave as None.", alias="refId", ) - solver_id: str = pydantic.Field( + solver_id: str = Field( None, title="solver", description="Solver ID number, set when the task is created, leave as None.", alias="id", ) - real_flex_unit: float = pydantic.Field( + real_flex_unit: float = Field( None, title="real FlexCredits", description="Billed FlexCredits.", alias="charge" ) - created_at: Optional[datetime] = pydantic.Field( + created_at: Optional[datetime] = Field( title="created_at", description="Time at which this task was created.", alias="createdAt" ) - status: str = pydantic.Field( + status: str = Field( None, title="status", description="Mode solver task status.", ) - file_type: str = pydantic.Field( + file_type: str = Field( None, title="file_type", description="File type used to upload the mode solver.", alias="fileType", ) - mode_solver: ModeSolver = pydantic.Field( + mode_solver: ModeSolver = Field( None, title="mode_solver", description="Mode solver being run by this task.", diff --git a/tidy3d/web/api/tidy3d_stub.py b/tidy3d/web/api/tidy3d_stub.py index 5a5af337d1..45692e8fae 100644 --- a/tidy3d/web/api/tidy3d_stub.py +++ b/tidy3d/web/api/tidy3d_stub.py @@ -3,10 +3,9 @@ from __future__ import annotations import json -from typing import Callable, List, Union +from typing import Callable, Union -import pydantic.v1 as pd -from pydantic.v1 import BaseModel +from pydantic import BaseModel, Field from tidy3d.components.tcad.data.sim_data import HeatChargeSimulationData, 
HeatSimulationData from tidy3d.components.tcad.simulation.heat import HeatSimulation @@ -44,7 +43,7 @@ class Tidy3dStub(BaseModel, TaskStub): - simulation: SimulationType = pd.Field(discriminator="type") + simulation: SimulationType = Field(discriminator="type") @classmethod def from_file(cls, file_path: str) -> SimulationType: @@ -109,7 +108,7 @@ def to_file( """ self.simulation.to_file(file_path) - def to_hdf5_gz(self, fname: str, custom_encoders: List[Callable] = None) -> None: + def to_hdf5_gz(self, fname: str, custom_encoders: list[Callable] = None) -> None: """Exports Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] instance to .hdf5.gz file. Parameters @@ -117,7 +116,7 @@ def to_hdf5_gz(self, fname: str, custom_encoders: List[Callable] = None) -> None fname : str Full path to the .hdf5.gz file to save the Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] to. - custom_encoders : List[Callable] + custom_encoders : list[Callable] List of functions accepting (fname: str, group_path: str, value: Any) that take the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. 
diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index d0d42c7663..50e52ffea8 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py @@ -5,7 +5,7 @@ import tempfile import time from datetime import datetime, timedelta -from typing import Callable, Dict, List, Union +from typing import Callable, Union import pytz from requests import HTTPError @@ -30,12 +30,7 @@ from ..core.task_core import Folder, SimulationTask from ..core.task_info import ChargeType, TaskInfo from ..core.types import PayType -from .connect_util import ( - REFRESH_TIME, - get_grid_points_str, - get_time_steps_str, - wait_for_connection, -) +from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection from .tidy3d_stub import SimulationDataType, SimulationType, Tidy3dStub, Tidy3dStubData # time between checking run status @@ -202,7 +197,7 @@ def upload( verbose: bool = True, progress_callback: Callable[[float], None] = None, simulation_type: str = "tidy3d", - parent_tasks: List[str] = None, + parent_tasks: list[str] = None, source_required: bool = True, solver_version: str = None, reduce_simulation: Literal["auto", True, False] = "auto", @@ -227,7 +222,7 @@ def upload( Optional callback function called when uploading file with ``bytes_in_chunk`` as argument. simulation_type : str = "tidy3d" Type of simulation being uploaded. - parent_tasks : List[str] + parent_tasks : list[str] List of related task ids. source_required: bool = True If ``True``, simulations without sources will raise an error before being uploaded. @@ -909,7 +904,7 @@ def abort(task_id: TaskId): @wait_for_connection def get_tasks( num_tasks: int = None, order: Literal["new", "old"] = "new", folder: str = "default" -) -> List[Dict]: +) -> list[dict]: """Get a list with the metadata of the last ``num_tasks`` tasks. 
Parameters @@ -923,7 +918,7 @@ def get_tasks( Returns ------- - List[Dict] + list[dict] List of dictionaries storing the information for each of the tasks last ``num_tasks`` tasks. """ folder = Folder.get(folder, create=True) diff --git a/tidy3d/web/core/account.py b/tidy3d/web/core/account.py index c505b66487..3a89720eda 100644 --- a/tidy3d/web/core/account.py +++ b/tidy3d/web/core/account.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Optional -from pydantic.v1 import Extra, Field +from pydantic import Extra, Field from .http_util import http from .types import Tidy3DResource diff --git a/tidy3d/web/core/environment.py b/tidy3d/web/core/environment.py index d7aa016f53..6e62a216d3 100644 --- a/tidy3d/web/core/environment.py +++ b/tidy3d/web/core/environment.py @@ -2,8 +2,10 @@ import os import ssl +from typing import Optional -from pydantic.v1 import BaseSettings, Field +from pydantic import Field +from pydantic_settings import BaseSettings from .core_config import get_logger @@ -19,8 +21,8 @@ def __hash__(self): website_endpoint: str s3_region: str ssl_verify: bool = Field(True, env="TIDY3D_SSL_VERIFY") - enable_caching: bool = None - ssl_version: ssl.TLSVersion = None + enable_caching: Optional[bool] = None + ssl_version: Optional[ssl.TLSVersion] = None def active(self) -> None: """Activate the environment instance.""" diff --git a/tidy3d/web/core/http_util.py b/tidy3d/web/core/http_util.py index 1535959465..25e795d519 100644 --- a/tidy3d/web/core/http_util.py +++ b/tidy3d/web/core/http_util.py @@ -4,7 +4,6 @@ from enum import Enum from functools import wraps from os.path import expanduser -from typing import Dict import requests import toml @@ -104,12 +103,12 @@ def api_key_auth(request: requests.request) -> requests.request: return request -def get_headers() -> Dict[str, str]: +def get_headers() -> dict[str, str]: """get headers for http request. 
Returns ------- - Dict[str, str] + dict[str, str] dictionary with "Authorization" and "Application" keys. """ return { diff --git a/tidy3d/web/core/s3utils.py b/tidy3d/web/core/s3utils.py index b687cc11d0..faa6063726 100644 --- a/tidy3d/web/core/s3utils.py +++ b/tidy3d/web/core/s3utils.py @@ -10,7 +10,7 @@ import boto3 from boto3.s3.transfer import TransferConfig -from pydantic.v1 import BaseModel, Field +from pydantic import BaseModel, Field from rich.progress import ( BarColumn, DownloadColumn, diff --git a/tidy3d/web/core/task_core.py b/tidy3d/web/core/task_core.py index 3447369c19..90f859dec3 100644 --- a/tidy3d/web/core/task_core.py +++ b/tidy3d/web/core/task_core.py @@ -6,11 +6,10 @@ import pathlib import tempfile from datetime import datetime -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Union -import pydantic.v1 as pd from botocore.exceptions import ClientError -from pydantic.v1 import Extra, Field, parse_obj_as +from pydantic import Extra, Field, model_validator, parse_obj_as import tidy3d as td @@ -30,9 +29,15 @@ class Folder(Tidy3DResource, Queryable, extra=Extra.allow): """Tidy3D Folder.""" - folder_id: str = Field(..., title="Folder id", description="folder id", alias="projectId") + folder_id: str = Field( + title="Folder id", + description="folder id", + alias="projectId", + ) folder_name: str = Field( - ..., title="Folder name", description="folder name", alias="projectName" + title="Folder name", + description="folder name", + alias="projectName", ) @classmethod @@ -47,7 +52,7 @@ def list(cls) -> []: resp = http.get("tidy3d/projects") return ( parse_obj_as( - List[Folder], + list[Folder], resp, ) if resp @@ -101,18 +106,18 @@ def delete(self): http.delete(f"tidy3d/projects/{self.folder_id}") - def list_tasks(self) -> List[Tidy3DResource]: + def list_tasks(self) -> list[Tidy3DResource]: """List all tasks in this folder. 
Returns ------- - tasks : List[:class:`.SimulationTask`] + tasks : list[:class:`.SimulationTask`] List of tasks in this folder """ resp = http.get(f"tidy3d/projects/{self.folder_id}/tasks") return ( parse_obj_as( - List[SimulationTask], + list[SimulationTask], resp, ) if resp @@ -124,7 +129,7 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): """Interface for managing the running of a :class:`.Simulation` task on server.""" task_id: Optional[str] = Field( - ..., + None, title="task_id", description="Task ID number, set when the task is uploaded, leave as None.", alias="taskId", @@ -135,18 +140,31 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): description="Folder ID number, set when the task is uploaded, leave as None.", alias="folderId", ) - status: Optional[str] = Field(title="status", description="Simulation task status.") + status: Optional[str] = Field( + None, + title="status", + description="Simulation task status.", + ) - real_flex_unit: float = Field( - None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" + real_flex_unit: Optional[float] = Field( + None, + title="real FlexCredits", + description="Billed FlexCredits.", + alias="realCost", ) created_at: Optional[datetime] = Field( - title="created_at", description="Time at which this task was created.", alias="createdAt" + None, + title="created_at", + description="Time at which this task was created.", + alias="createdAt", ) task_type: Optional[str] = Field( - title="task_type", description="The type of task.", alias="taskType" + None, + title="task_type", + description="The type of task.", + alias="taskType", ) folder_name: Optional[str] = Field( @@ -156,7 +174,7 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): alias="folderName", ) - callback_url: str = Field( + callback_url: Optional[str] = Field( None, title="Callback URL", description="Http PUT url to receive simulation finish event. 
" @@ -164,31 +182,30 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) - # simulation_type: str = pd.Field( + # simulation_type: str = Field( # None, # title="Simulation Type", # description="Type of simulation, used internally only.", # ) - # parent_tasks: Tuple[TaskId, ...] = pd.Field( + # parent_tasks: tuple[TaskId, ...] = Field( # None, # title="Parent Tasks", # description="List of parent task ids for the simulation, used internally only." # ) - @pd.root_validator(pre=True) - def _error_if_jax_sim(cls, values): + @model_validator(mode="before") + def _error_if_jax_sim(data: dict) -> dict: """Raise error if user tries to submit simulation that's a JaxSimulation.""" - sim = values.get("simulation") - if sim is None: - return values - if "JaxSimulation" in str(type(sim)): + if data.get("sim") is None: + return data + if "JaxSimulation" in str(type(data.get("sim"))): raise ValueError( "'JaxSimulation' not compatible with regular webapi functions. " "Either convert it to Simulation with 'jax_sim.to_simulation()[0]' or use " "the 'adjoint.run' function to run JaxSimulations." ) - return values + return data @classmethod def create( @@ -198,7 +215,7 @@ def create( folder_name: str = "default", callback_url: str = None, simulation_type: str = "tidy3d", - parent_tasks: List[str] = None, + parent_tasks: list[str] = None, file_type: str = "Gz", ) -> SimulationTask: """Create a new task on the server. @@ -216,7 +233,7 @@ def create( fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``. simulation_type : str Type of simulation being uploaded. - parent_tasks : List[str] + parent_tasks : list[str] List of related task ids. 
file_type: str the simulation file type Json, Hdf5, Gz @@ -273,19 +290,19 @@ def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: return task @classmethod - def get_running_tasks(cls) -> List[SimulationTask]: + def get_running_tasks(cls) -> list[SimulationTask]: """Get a list of running tasks from the server" Returns ------- - List[:class:`.SimulationTask`] + list[:class:`.SimulationTask`] :class:`.SimulationTask` object containing info about status, size, credits of task and others. """ resp = http.get("tidy3d/py/tasks") if not resp: return [] - return parse_obj_as(List[SimulationTask], resp) + return parse_obj_as(list[SimulationTask], resp) def delete(self, versions: bool = False): """Delete current task from server. @@ -577,7 +594,7 @@ def get_simulation_hdf5( progress_callback=progress_callback, ) - def get_running_info(self) -> Tuple[float, float]: + def get_running_info(self) -> tuple[float, float]: """Gets the % done and field_decay for a running task. Returns diff --git a/tidy3d/web/core/task_info.py b/tidy3d/web/core/task_info.py index f33093b392..fcc5a08ffa 100644 --- a/tidy3d/web/core/task_info.py +++ b/tidy3d/web/core/task_info.py @@ -3,9 +3,9 @@ from abc import ABC from datetime import datetime from enum import Enum -from typing import Optional +from typing import Annotated, Optional -import pydantic.v1 as pydantic +from pydantic import BaseModel, ConfigDict, Field class TaskStatus(Enum): @@ -33,14 +33,10 @@ class TaskStatus(Enum): """The task has completed with an error.""" -class TaskBase(pydantic.BaseModel, ABC): +class TaskBase(BaseModel, ABC): """Base configuration for all task objects.""" - class Config: - """Configuration for TaskBase""" - - arbitrary_types_allowed = True - """Allow arbitrary types to be used within the model.""" + model_config = ConfigDict(arbitrary_types_allowed=True) class ChargeType(str, Enum): @@ -60,16 +56,16 @@ class TaskBlockInfo(TaskBase): such as user limits and insufficient balance. 
""" - chargeType: ChargeType = None + chargeType: Optional[ChargeType] = None """The type of charge applicable to the task (free or paid).""" - maxFreeCount: int = None + maxFreeCount: Optional[int] = None """The maximum number of free tasks allowed.""" - maxGridPoints: int = None + maxGridPoints: Optional[int] = None """The maximum number of grid points permitted.""" - maxTimeSteps: int = None + maxTimeSteps: Optional[int] = None """The maximum number of time steps allowed.""" @@ -79,55 +75,55 @@ class TaskInfo(TaskBase): taskId: str """Unique identifier for the task.""" - taskName: str = None + taskName: Optional[str] = None """Name of the task.""" - nodeSize: int = None + nodeSize: Optional[int] = None """Size of the node allocated for the task.""" completedAt: Optional[datetime] = None """Timestamp when the task was completed.""" - status: str = None + status: Optional[str] = None """Current status of the task.""" - realCost: float = None + realCost: Optional[float] = None """Actual cost incurred by the task.""" - timeSteps: int = None + timeSteps: Optional[int] = None """Number of time steps involved in the task.""" - solverVersion: str = None + solverVersion: Optional[str] = None """Version of the solver used for the task.""" createAt: Optional[datetime] = None """Timestamp when the task was created.""" - estCostMin: float = None + estCostMin: Optional[float] = None """Estimated minimum cost for the task.""" - estCostMax: float = None + estCostMax: Optional[float] = None """Estimated maximum cost for the task.""" - realFlexUnit: float = None + realFlexUnit: Optional[float] = None """Actual flexible units used by the task.""" - oriRealFlexUnit: float = None + oriRealFlexUnit: Optional[float] = None """Original real flexible units.""" - estFlexUnit: float = None + estFlexUnit: Optional[float] = None """Estimated flexible units for the task.""" - estFlexCreditTimeStepping: float = None + estFlexCreditTimeStepping: Optional[float] = None """Estimated flexible 
credits for time stepping.""" - estFlexCreditPostProcess: float = None + estFlexCreditPostProcess: Optional[float] = None """Estimated flexible credits for post-processing.""" - estFlexCreditMode: float = None + estFlexCreditMode: Optional[float] = None """Estimated flexible credits based on the mode.""" - s3Storage: float = None + s3Storage: Optional[float] = None """Amount of S3 storage used by the task.""" startSolverTime: Optional[datetime] = None @@ -136,29 +132,29 @@ class TaskInfo(TaskBase): finishSolverTime: Optional[datetime] = None """Timestamp when the solver finished.""" - totalSolverTime: int = None + totalSolverTime: Optional[int] = None """Total time taken by the solver.""" - callbackUrl: str = None + callbackUrl: Optional[str] = None """Callback URL for task notifications.""" - taskType: str = None + taskType: Optional[str] = None """Type of the task.""" - metadataStatus: str = None + metadataStatus: Optional[str] = None """Status of the metadata for the task.""" - taskBlockInfo: TaskBlockInfo = None + taskBlockInfo: Optional[TaskBlockInfo] = None """Blocking information for the task.""" class RunInfo(TaskBase): """Information about the run of a task.""" - perc_done: pydantic.confloat(ge=0.0, le=100.0) + perc_done: Annotated[float, Field(ge=0.0, le=100.0)] """Percentage of the task that is completed (0 to 100).""" - field_decay: pydantic.confloat(ge=0.0, le=1.0) + field_decay: Annotated[float, Field(ge=0.0, le=1.0)] """Field decay from the maximum value (0 to 1).""" def display(self): diff --git a/tidy3d/web/core/types.py b/tidy3d/web/core/types.py index 748cfc7a8a..78ec60a493 100644 --- a/tidy3d/web/core/types.py +++ b/tidy3d/web/core/types.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum -from pydantic.v1 import BaseModel +from pydantic import BaseModel class Tidy3DResource(BaseModel, ABC): @@ -43,7 +43,7 @@ class Queryable(BaseModel, ABC): @classmethod @abstractmethod - def list(cls, *args, **kwargs) -> [Queryable]: + 
def list(cls, *args, **kwargs) -> list[Queryable]: """List all resources of this type."""