Skip to content

Update the ModeSolver by adding perturbation-based group index and group velocity dispersion calculatoion method #2243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: develop
Choose a base branch
from
24 changes: 24 additions & 0 deletions tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,26 @@ class ModeSolverDataset(ElectromagneticFieldDataset):
description="Index associated with group velocity of the mode.",
)

n_group_analytic: GroupIndexDataArray = pd.Field(
None,
title="Group Index",
description="Index associated with group velocity of the mode.",
)

dispersion_raw: ModeDispersionDataArray = pd.Field(
None,
title="Dispersion",
description="Dispersion parameter for the mode.",
units=PICOSECOND_PER_NANOMETER_PER_KILOMETER,
)

dispersion_analytic: GroupIndexDataArray = pd.Field(
None,
title="Dispersion",
description="Dispersion parameter for the mode.",
units=PICOSECOND_PER_NANOMETER_PER_KILOMETER,
)

@property
def field_components(self) -> Dict[str, DataArray]:
"""Maps the field components to their associated data."""
Expand Down Expand Up @@ -431,6 +444,17 @@ def n_group(self) -> GroupIndexDataArray:
)
return self.n_group_raw

@property
def n_group_new(self) -> GroupIndexDataArray:
"""Group index."""
if self.n_group_analytic is None:
log.warning(
"The group index was not computed. To calculate group index, pass "
"'calculate_group_index = True' in the 'ModeSpec'.",
log_once=True,
)
return self.n_group_analytic

@property
def dispersion(self) -> ModeDispersionDataArray:
r"""Dispersion parameter.
Expand Down
48 changes: 48 additions & 0 deletions tidy3d/components/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ class ModeSpec(Tidy3dBaseModel):
f"default of {GROUP_INDEX_STEP} is used.",
)

calculate_group_index: bool = pd.Field(
False,
title="Perturbation-based calculation of the group index and the GVD",
description="Control the computation of the group index and the group velocity dispersion"
"alongside the effective index. If set to 'True', the perturbation theory based algorithm"
" is used for calculation.",
)

@pd.validator("bend_axis", always=True)
@skip_if_fields_missing(["bend_radius"])
def bend_axis_given(cls, val, values):
Expand Down Expand Up @@ -207,3 +215,43 @@ def check_precision(cls, values):
)

return values

class SingleFreqModesData(Tidy3dBaseModel):
"""A data class to store the modes data for a single frequency"""

E_vectors: np.ndarray = pd.Field(
None, title="E field", description="Electric field of the eigenmodes, shape (3, N, num_modes)."
)

H_vectors: np.ndarray = pd.Field(
None, title="H field", description="Magnetic field of the eigenmodes, shape (3, N, num_modes)."
)

E_fields: np.ndarray = pd.Field(
None, title="E field", description="Electric field of the eigenmodes, shape (3, Nx, Ny, 1, num_modes)."
)

H_fields: np.ndarray = pd.Field(
None, title="H field", description="Magnetic field of the eigenmodes, shape (3, Nx, Ny, 1, num_modes)."
)

n_eff: np.ndarray = pd.Field(
None, title="Mode refractive index", description="Real part of the effective index, shape (num_modes, )."
)

k_eff: np.ndarray = pd.Field(
None, title="Mode absorption index", description="Imaginary part of the effective index, shape (num_modes, )."
)

n_group: np.ndarray = pd.Field(
None, title="Mode group index", description="Real part of the effective group index, shape (num_modes, )."
)

GVD: np.ndarray = pd.Field(
None, title="Group velocity dispersion", description="Group velocity dispersion data, shape (num_modes, )."
)

eps_spec : Literal["diagonal", "tensorial_real", "tensorial_complex"] = pd.Field(
None,
title="Permittivity characterization on the mode solver's plane",
)
51 changes: 44 additions & 7 deletions tidy3d/plugins/mode/mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,63 @@
invariance along a given propagation axis.
"""

from __future__ import annotations

from functools import wraps
from math import isclose
from typing import Dict, List, Tuple, Union

import numpy as np
import pydantic.v1 as pydantic
import xarray as xr
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

from ...components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
from ...components.boundary import PML, Absorber, Boundary, BoundarySpec, PECBoundary, StablePML
from ...components.data.data_array import (
FreqModeDataArray,
ModeIndexDataArray,
GroupIndexDataArray,
ModeDispersionDataArray,
ScalarModeFieldDataArray,
)
from ...components.data.monitor_data import ModeSolverData
from ...components.data.sim_data import SimulationData
from ...components.eme.data.sim_data import EMESimulationData
from ...components.eme.simulation import EMESimulation
from ...components.geometry.base import Box
from ...components.grid.grid import Grid
from ...components.medium import FullyAnisotropicMedium
from ...components.mode import ModeSpec
from ...components.monitor import ModeMonitor, ModeSolverMonitor
from ...components.simulation import Simulation
from ...components.source import ModeSource, SourceTime
from ...components.types import (
TYPE_TAG_STR,
ArrayComplex3D,
ArrayComplex4D,
ArrayFloat1D,
Ax,
Axis,
Direction,
EpsSpecType,
FreqArray,
Literal,
PlotScale,
Symmetry,
)
from ...components.validators import (
validate_freqs_min,
validate_freqs_not_empty,
validate_mode_plane_radius,
)
from ...components.viz import make_ax, plot_params_pml
from ...constants import C_0
from ...exceptions import SetupError, ValidationError
from ...log import log

# Importing the local solver may not work if e.g. scipy is not installed

Check failure on line 61 in tidy3d/plugins/mode/mode_solver.py

View workflow job for this annotation

GitHub Actions / Run linting

Ruff (I001)

tidy3d/plugins/mode/mode_solver.py:5:1: I001 Import block is un-sorted or un-formatted
IMPORT_ERROR_MSG = """Could not import local solver, 'ModeSolver' objects can still be constructed
but will have to be run through the server.
"""
Expand Down Expand Up @@ -370,7 +372,7 @@
)

# Compute and store the modes at all frequencies
n_complex, fields, eps_spec = solver._solve_all_freqs(
n_complex, fields, n_group, n_GVD, eps_spec = solver._solve_all_freqs(
coords=_solver_coords, symmetry=solver.solver_symmetry
)

Expand All @@ -384,6 +386,29 @@
)
data_dict = {"n_complex": index_data}

if n_group is not None:
if n_group[0] is not None:
n_group_data = GroupIndexDataArray(
np.stack(n_group, axis=0),
coords=dict(
f=list(solver.freqs),
mode_index=np.arange(solver.mode_spec.num_modes),
),
)

data_dict["n_group_analytic"] = n_group_data

if n_GVD is not None:
if n_GVD[0] is not None:
n_GVD_data = ModeDispersionDataArray(
np.stack(n_GVD, axis=0),
coords=dict(
f=list(solver.freqs),
mode_index=np.arange(solver.mode_spec.num_modes),
),
)
data_dict["dispersion_analytic"] = n_GVD_data

# Construct the field data on Yee grid
for field_name in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
xyz_coords = solver.grid_snapped[field_name].to_list
Expand Down Expand Up @@ -670,15 +695,19 @@

fields = []
n_complex = []
n_group = []
n_GVD = []
eps_spec = []
for freq in self.freqs:
n_freq, fields_freq, eps_spec_freq = self._solve_single_freq(
n_freq, fields_freq, n_group_freq, GVD_freq, eps_spec_freq = self._solve_single_freq(
freq=freq, coords=coords, symmetry=symmetry
)
fields.append(fields_freq)
n_complex.append(n_freq)
n_group.append(n_group_freq)
n_GVD.append(GVD_freq)
eps_spec.append(eps_spec_freq)
return n_complex, fields, eps_spec
return n_complex, fields, n_group, n_GVD, eps_spec

def _solve_all_freqs_relative(
self,
Expand Down Expand Up @@ -731,7 +760,8 @@
if not LOCAL_SOLVER_IMPORTED:
raise ImportError(IMPORT_ERROR_MSG)

solver_fields, n_complex, eps_spec = compute_modes(
# solver_fields, n_complex, eps_spec = compute_modes(
modes_data = compute_modes(
eps_cross=self._solver_eps(freq),
coords=coords,
freq=freq,
Expand All @@ -740,8 +770,11 @@
direction=self.direction,
)

solver_fields = fields = np.stack((modes_data.E_fields, modes_data.H_fields), axis=0)
fields = self._postprocess_solver_fields(solver_fields)
return n_complex, fields, eps_spec

n_complex = modes_data.n_eff + 1j * modes_data.k_eff
return n_complex, fields, modes_data.n_group, modes_data.GVD, modes_data.eps_spec

def _rotate_field_coords_inverse(self, field: FIELD) -> FIELD:
"""Move the propagation axis to the z axis in the array."""
Expand Down Expand Up @@ -782,7 +815,7 @@

solver_basis_fields = self._postprocess_solver_fields_inverse(basis_fields)

solver_fields, n_complex, eps_spec = compute_modes(
modes_data = compute_modes(
eps_cross=self._solver_eps(freq),
coords=coords,
freq=freq,
Expand All @@ -792,8 +825,12 @@
solver_basis_fields=solver_basis_fields,
)

solver_fields = fields = np.stack((modes_data.E_fields, modes_data.H_fields), axis=0)
fields = self._postprocess_solver_fields(solver_fields)
return n_complex, fields, eps_spec

n_complex = modes_data.n_eff + 1j * modes_data.k_eff

return n_complex, fields, modes_data.eps_spec

def _rotate_field_coords(self, field: FIELD) -> FIELD:
"""Move the propagation axis=z to the proper order in the array."""
Expand Down
Loading
Loading