|
16 | 16 | },
|
17 | 17 | {
|
18 | 18 | "cell_type": "code",
|
19 |
| - "execution_count": null, |
| 19 | + "execution_count": 1, |
20 | 20 | "id": "dfdb47e2-0ea0-4bf8-8279-8500ff3cf21f",
|
21 | 21 | "metadata": {},
|
22 | 22 | "outputs": [],
|
|
36 | 36 | },
|
37 | 37 | {
|
38 | 38 | "cell_type": "code",
|
39 |
| - "execution_count": null, |
| 39 | + "execution_count": 2, |
40 | 40 | "id": "291d3409-1c8c-44cb-8380-44f08019b57d",
|
41 | 41 | "metadata": {},
|
42 | 42 | "outputs": [],
|
|
45 | 45 | "import autokoopman.benchmark.fhn as fhn\n",
|
46 | 46 | "fhn = fhn.FitzHughNagumo()\n",
|
47 | 47 | "training_data = fhn.solve_ivps(\n",
|
48 |
| - " initial_states=np.random.uniform(low=-2.0, high=2.0, size=(1, 2)),\n", |
| 48 | + " initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n", |
49 | 49 | " tspan=[0.0, 6.0],\n",
|
50 | 50 | " sampling_period=0.1\n",
|
51 | 51 | ")\n",
|
|
68 | 68 | },
|
69 | 69 | {
|
70 | 70 | "cell_type": "code",
|
71 |
| - "execution_count": null, |
| 71 | + "execution_count": 3, |
72 | 72 | "id": "e2d42e41-46c2-467c-9ce7-9bd6a7c509a1",
|
73 | 73 | "metadata": {},
|
74 | 74 | "outputs": [],
|
|
92 | 92 | " w = np.ones(traj.states.shape)\n",
|
93 | 93 | " w[:, -3:] = 0.0\n",
|
94 | 94 | " w[:, :2] = 1.0\n",
|
| 95 | + " w[:, 0] = 1.0\n", |
95 | 96 | " weights.append(w)\n",
|
96 | 97 | "\n",
|
97 | 98 | " # weight garbage trajectory to zero\n",
|
|
103 | 104 | "#weights = {idx: w for idx, w in enumerate(weights)}"
|
104 | 105 | ]
|
105 | 106 | },
|
| 107 | + { |
| 108 | + "cell_type": "code", |
| 109 | + "execution_count": 4, |
| 110 | + "id": "ddd76415-6d19-4a38-a2b0-84eb48d0fdfc", |
| 111 | + "metadata": {}, |
| 112 | + "outputs": [], |
| 113 | + "source": [ |
| 114 | + "from autokoopman.observable import ReweightedRFFObservable\n", |
| 115 | + "import autokoopman.observable as akobs\n", |
| 116 | + "\n", |
| 117 | + "X, WX = list(zip(*list((trajectories[i], w) for i, w in enumerate(weights))))\n", |
| 118 | + "X, WX = np.vstack(X), np.vstack(WX)\n", |
| 119 | + "X, WX = np.tile(X, (3, 1)), np.tile(WX, (3, 1))\n", |
| 120 | + "idxs = np.random.permutation(np.arange(len(X)))\n", |
| 121 | + "Y, WY = X[idxs], WX[idxs]\n", |
| 122 | + "\n", |
| 123 | + "reweight_obs = akobs.IdentityObservable() | akobs.ReweightedRFFObservable(5, 40, 1.0, X, Y, WX, WY, 'rff')" |
| 124 | + ] |
| 125 | + }, |
106 | 126 | {
|
107 | 127 | "cell_type": "markdown",
|
108 | 128 | "id": "a706f212-36cd-4203-b209-cb7c5ce4ad94",
|
|
122 | 142 | "execution_count": null,
|
123 | 143 | "id": "98510aa7-3416-4181-a493-00500be53f61",
|
124 | 144 | "metadata": {},
|
125 |
| - "outputs": [], |
| 145 | + "outputs": [ |
| 146 | + { |
| 147 | + "name": "stderr", |
| 148 | + "output_type": "stream", |
| 149 | + "text": [ |
| 150 | + "Tuning GridSearchTuner: 0%| | 0/40 [00:00<?, ?it/s]/home/elew/AutoKoopman/notebooks/../autokoopman/estimator/koopman.py:113: UserWarning: SW-eDMD (cvxpy) Optimization failed to converge. Switching to unweighted DMDc.\n", |
| 151 | + " warnings.warn(\"SW-eDMD (cvxpy) Optimization failed to converge. Switching to unweighted DMDc.\")\n", |
| 152 | + "Tuning GridSearchTuner: 12%|▊ | 5/40 [00:23<03:07, 5.37s/it]/home/elew/anaconda3/envs/autokoopman/lib/python3.12/site-packages/numpy/linalg/linalg.py:2582: RuntimeWarning: overflow encountered in multiply\n", |
| 153 | + " s = (x.conj() * x).real\n", |
| 154 | + "Tuning GridSearchTuner: 30%|█▌ | 12/40 [00:52<02:04, 4.44s/it]" |
| 155 | + ] |
| 156 | + } |
| 157 | + ], |
126 | 158 | "source": [
|
127 | 159 | "# learn model from weighted data\n",
|
128 | 160 | "experiment_results = auto_koopman(\n",
|
|
136 | 168 | " n_obs=40, # maximum number of observables to try\n",
|
137 | 169 | " max_opt_iter=200, # maximum number of optimization iterations\n",
|
138 | 170 | " grid_param_slices=5, # for grid search, number of slices for each parameter\n",
|
139 |
| - " n_splits=None, # k-folds validation for tuning, helps stabilize the scoring\n", |
140 |
| - " rank=(40, 41, 1) # NOTE: don't rank tune (for now)--SW-eDMD already optimizes for this\n", |
| 171 | + " n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n", |
| 172 | + " rank=(1, 41, 5) # rank (SW-eDMD now uses rank adaptation)\n", |
141 | 173 | ")"
|
142 | 174 | ]
|
143 | 175 | },
|
|
290 | 322 | "name": "python",
|
291 | 323 | "nbconvert_exporter": "python",
|
292 | 324 | "pygments_lexer": "ipython3",
|
293 |
| - "version": "3.9.0" |
| 325 | + "version": "3.12.3" |
294 | 326 | }
|
295 | 327 | },
|
296 | 328 | "nbformat": 4,
|
|
0 commit comments