|
39 | 39 | "fhn = fhn.FitzHughNagumo()\n", |
40 | 40 | "training_data = fhn.solve_ivps(\n", |
41 | 41 | " initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n", |
42 | | - " tspan=[0.0, 10.0],\n", |
| 42 | + " tspan=[0.0, 6.0],\n", |
43 | 43 | " sampling_period=0.1\n", |
44 | 44 | ")" |
45 | 45 | ] |
|
66 | 66 | " \n", |
67 | 67 | "\n", |
68 | 68 | " # weight good trajectory by its 1 norm\n", |
69 | | - " w = np.sum(traj.abs().states, axis=1)\n", |
| 69 | + " #w = np.sum(traj.abs().states, axis=1)\n", |
| 70 | + " w = np.ones(traj.states.shape)\n", |
70 | 71 | " weights.append(w)\n", |
71 | 72 | "\n", |
72 | 73 | " # weight garbage trajectory to zero\n", |
73 | | - " w = np.zeros(len(traj.states))\n", |
| 74 | + " #w = np.zeros(len(traj.states))\n", |
| 75 | + " w = np.zeros(traj.states.shape)\n", |
74 | 76 | " weights.append(w)\n", |
75 | 77 | "\n", |
76 | 78 | "# you can also use a dict to name the trajectories if using TrajectoriesData (numpy arrays are named by their index number)\n", |
77 | | - "weights = {idx: w for idx, w in enumerate(weights)}" |
| 79 | + "#weights = {idx: w for idx, w in enumerate(weights)}" |
| 80 | + ] |
| 81 | + }, |
| 82 | + { |
| 83 | + "cell_type": "code", |
| 84 | + "execution_count": null, |
| 85 | + "id": "280b4bb3-4f7d-4a94-a983-663c6255bc83", |
| 86 | + "metadata": {}, |
| 87 | + "outputs": [], |
| 88 | + "source": [ |
| 89 | + "weights[1].shape" |
78 | 90 | ] |
79 | 91 | }, |
80 | 92 | { |
|
93 | 105 | " learning_weights=weights, # weight the eDMD algorithm objectives\n", |
94 | 106 | " scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n", |
95 | 107 | " opt=\"grid\", # grid search to find best hyperparameters\n", |
96 | | - " n_obs=200, # maximum number of observables to try\n", |
| 108 | + " n_obs=40, # maximum number of observables to try\n", |
97 | 109 | " max_opt_iter=200, # maximum number of optimization iterations\n", |
98 | 110 | " grid_param_slices=5, # for grid search, number of slices for each parameter\n", |
99 | 111 | " n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n", |
|
117 | 129 | " learning_weights=None, # don't use eDMD weighting\n", |
118 | 130 | " scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n", |
119 | 131 | " opt=\"grid\", # grid search to find best hyperparameters\n", |
120 | | - " n_obs=200, # maximum number of observables to try\n", |
| 132 | + " n_obs=40, # maximum number of observables to try\n", |
121 | 133 | " max_opt_iter=200, # maximum number of optimization iterations\n", |
122 | 134 | " grid_param_slices=5, # for grid search, number of slices for each parameter\n", |
123 | 135 | " n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n", |
|
178 | 190 | "plt.figure(figsize=(10, 6))\n", |
179 | 191 | "\n", |
180 | 192 | "# plot the results\n", |
| 193 | + "plt.plot(*true_trajectory.states.T, linewidth=2, label='Ground Truth')\n", |
181 | 194 | "plt.plot(*trajectory.states.T, label='Weighted Trajectory Prediction')\n", |
182 | 195 | "plt.plot(*trajectory_uw.states.T, label='Trajectory Prediction')\n", |
183 | | - "plt.plot(*true_trajectory.states.T, label='Ground Truth')\n", |
| 196 | + "\n", |
184 | 197 | "\n", |
185 | 198 | "plt.xlabel(\"$x_1$\")\n", |
186 | 199 | "plt.ylabel(\"$x_2$\")\n", |
|
0 commit comments