|
| 1 | +# SPDX-FileCopyrightText: 2024 Alexandru Fikl <[email protected]> |
| 2 | +# SPDX-License-Identifier: MIT |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +from collections.abc import Iterator |
| 7 | +from dataclasses import dataclass |
| 8 | +from functools import cached_property |
| 9 | +from typing import Any, NamedTuple |
| 10 | + |
| 11 | +import numpy as np |
| 12 | + |
| 13 | +from pycaputo.derivatives import CaputoDerivative, Side |
| 14 | +from pycaputo.events import Event, StepCompleted |
| 15 | +from pycaputo.history import History, ProductIntegrationHistory |
| 16 | +from pycaputo.logging import get_logger |
| 17 | +from pycaputo.stepping import ( |
| 18 | + FractionalDifferentialEquationMethod, |
| 19 | + evolve, |
| 20 | + make_initial_condition, |
| 21 | +) |
| 22 | +from pycaputo.typing import Array, StateFunctionT |
| 23 | + |
| 24 | +logger = get_logger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +class AdvanceResult(NamedTuple): |
| 28 | + """Result of :func:`~pycaputo.stepping.advance` for |
| 29 | + :class:`ProductIntegrationMethod` subclasses.""" |
| 30 | + |
| 31 | + y: Array |
| 32 | + """Estimated solution at the next time step.""" |
| 33 | + trunc: Array |
| 34 | + """Estimated truncation error at the next time step.""" |
| 35 | + storage: Array |
| 36 | + """Array to add to the history storage.""" |
| 37 | + |
| 38 | + |
| 39 | +@dataclass(frozen=True) |
| 40 | +class SplineCollocationMethod(FractionalDifferentialEquationMethod[StateFunctionT]): |
| 41 | + """A spline collocation method for""" |
| 42 | + |
| 43 | + @cached_property |
| 44 | + def d(self) -> tuple[CaputoDerivative, ...]: |
| 45 | + return tuple([ |
| 46 | + CaputoDerivative(alpha=alpha, side=Side.Left) |
| 47 | + for alpha in self.derivative_order |
| 48 | + ]) |
| 49 | + |
| 50 | + |
| 51 | +@make_initial_condition.register(SplineCollocationMethod) |
| 52 | +def _make_initial_condition_caputo( |
| 53 | + m: SplineCollocationMethod[StateFunctionT], |
| 54 | +) -> Array: |
| 55 | + return m.y0[0] |
| 56 | + |
| 57 | + |
| 58 | +@evolve.register(SplineCollocationMethod) |
| 59 | +def _evolve_pi( |
| 60 | + m: SplineCollocationMethod[StateFunctionT], |
| 61 | + *, |
| 62 | + history: History[Any] | None = None, |
| 63 | + dtinit: float | None = None, |
| 64 | +) -> Iterator[Event]: |
| 65 | + from pycaputo.controller import estimate_initial_time_step |
| 66 | + |
| 67 | + if history is None: |
| 68 | + history = ProductIntegrationHistory.empty_like(m.y0[0]) |
| 69 | + |
| 70 | + # initialize |
| 71 | + c = m.control |
| 72 | + n = 0 |
| 73 | + t = c.tstart |
| 74 | + |
| 75 | + # determine the initial condition |
| 76 | + yprev = make_initial_condition(m) |
| 77 | + history.append(t, m.source(t, yprev)) |
| 78 | + |
| 79 | + # determine initial time step |
| 80 | + if dtinit is None: |
| 81 | + dt = estimate_initial_time_step( |
| 82 | + t, yprev, m.source, m.smallest_derivative_order, trunc=m.order + 1 |
| 83 | + ) |
| 84 | + else: |
| 85 | + dt = dtinit |
| 86 | + |
| 87 | + yield StepCompleted( |
| 88 | + t=t, |
| 89 | + iteration=n, |
| 90 | + dt=dt, |
| 91 | + y=yprev, |
| 92 | + eest=0.0, |
| 93 | + q=1.0, |
| 94 | + trunc=np.zeros_like(yprev), |
| 95 | + ) |
0 commit comments