|
| 1 | +import numpy as np |
| 2 | +import eudist |
| 3 | + |
| 4 | +from warnings import warn |
| 5 | + |
| 6 | + |
| 7 | +def rz_to_ab(rz, mesh): |
| 8 | + ij = mesh.find_cell(rz) |
| 9 | + assert ij >= 0, "We left the domain" |
| 10 | + _, nz = mesh.shape |
| 11 | + i, j = ij // nz, ij % nz |
| 12 | + ABCD = mesh.grid[i : i + 2, j : j + 2] |
| 13 | + A = ABCD[0, 0] |
| 14 | + a = ABCD[0, 1] - A |
| 15 | + b = ABCD[1, 0] - A |
| 16 | + c = ABCD[1, 1] - A - a - b |
| 17 | + rz0 = rz - A |
| 18 | + |
| 19 | + def fun(albe): |
| 20 | + al, be = albe |
| 21 | + return rz0 - a * al - b * be - c * al * be |
| 22 | + |
| 23 | + def J(albe): |
| 24 | + al, be = albe |
| 25 | + return np.array([-a - c * be, -b - c * al]) |
| 26 | + |
| 27 | + tol = 1e-13 |
| 28 | + albe = np.ones(2) / 2 |
| 29 | + while True: |
| 30 | + res = np.sum(fun(albe) ** 2) |
| 31 | + albe = albe - np.linalg.inv(J(albe).T) @ fun(albe) |
| 32 | + if res < tol: |
| 33 | + return albe, ij |
| 34 | + |
| 35 | + |
| 36 | +def ab_to_rz(ab, ij, mesh): |
| 37 | + _, nz = mesh.shape |
| 38 | + i, j = ij // nz, ij % nz |
| 39 | + A = mesh.grid[i, j] |
| 40 | + a = mesh.grid[i, j + 1] - A |
| 41 | + b = mesh.grid[i + 1, j] - A |
| 42 | + c = mesh.grid[i + 1, j + 1] - A - a - b |
| 43 | + al, be = ab |
| 44 | + return A + al * a + be * b + al * be * c |
| 45 | + |
| 46 | + |
| 47 | +def setup_mesh(x, y): |
| 48 | + def per(d): |
| 49 | + return np.concatenate((d, d[:, :1]), axis=1) |
| 50 | + |
| 51 | + assert x.dims == y.dims |
| 52 | + assert x.dims == ("x", "z") |
| 53 | + x = per(x.data) |
| 54 | + y = per(y.data) |
| 55 | + return mymesh(x, y) |
| 56 | + |
| 57 | + |
| 58 | +class mymesh(eudist.PolyMesh): |
| 59 | + def __init__(self, x, y): |
| 60 | + super().__init__() |
| 61 | + self.r = x |
| 62 | + self.z = y |
| 63 | + self.grid = np.array([x, y]).transpose(1, 2, 0) |
| 64 | + self.shape = tuple([x - 1 for x in x.shape]) |
| 65 | + |
| 66 | + |
| 67 | +class Tracer: |
| 68 | + def __init__(self, ds, direction="forward"): |
| 69 | + meshes = [] |
| 70 | + for yi in range(len(ds.y)): |
| 71 | + dsi = ds.isel(y=yi) |
| 72 | + meshes.append( |
| 73 | + [ |
| 74 | + setup_mesh(dsi[f"{pre}R"], dsi[f"{pre}Z"]) |
| 75 | + for pre in ["", f"{direction}_"] |
| 76 | + ] |
| 77 | + + [yi] |
| 78 | + ) |
| 79 | + self.meshes = meshes |
| 80 | + |
| 81 | + def poincare(self, rz, yind=0, num=100, early_exit="warn"): |
| 82 | + rz = np.array(rz) |
| 83 | + assert rz.shape == (2,) |
| 84 | + thismeshes = self.meshes[yind:] + self.meshes[:yind] |
| 85 | + out = np.empty((num, 2)) |
| 86 | + out[0] = rz |
| 87 | + last = None |
| 88 | + for i in range(1, num): |
| 89 | + for d, meshes in enumerate(thismeshes): |
| 90 | + try: |
| 91 | + abij = rz_to_ab(rz, meshes[0]) |
| 92 | + except AssertionError as e: |
| 93 | + if early_exit == "warn": |
| 94 | + warn(f"early exit in iteration {i} because `{e}`") |
| 95 | + elif early_exit == "plot": |
| 96 | + m = meshes[0] |
| 97 | + import matplotlib.pyplot as plt |
| 98 | + |
| 99 | + plt.plot(m.r, m.z) |
| 100 | + if last: |
| 101 | + plt.plot(last[1].r.T, last[1].z.T) |
| 102 | + |
| 103 | + plt.plot(*rz, "o") |
| 104 | + plt.show() |
| 105 | + elif early_exit == "raise": |
| 106 | + raise |
| 107 | + else: |
| 108 | + assert ( |
| 109 | + early_exit == "ignore" |
| 110 | + ), f'early_exit needs to be one of ["warn", "plot", "raise", ignore"] but got `{early_exit}`' |
| 111 | + return out[:i] |
| 112 | + rz = ab_to_rz(*abij, meshes[1]) |
| 113 | + last = meshes |
| 114 | + out[i] = rz |
| 115 | + return out |
0 commit comments