Skip to content

Commit 6a4b37b

Browse files
committed
remove clipping function
change to rank adaptation problem for sw-edmd modify weighted cost func for new rank values
1 parent 96a40cf commit 6a4b37b

File tree

2 files changed

+53
-14
lines changed

2 files changed

+53
-14
lines changed

autokoopman/estimator/koopman.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,21 @@ def swdmdc(X, Xp, U, r, Js, W):
8181
# so the objective isn't numerically unstable
8282
sf = (1.0 / n_snap)
8383

84+
# check that rank is less than the number of observations
85+
if r > n_obs:
86+
warnings.warn("Rank must be less than the number of observations. Reducing rank to n_obs.")
87+
r = n_obs
88+
8489
# koopman operator
85-
K = cp.Variable((n_obs, n_obs + n_inps))
90+
# Variables for the low-rank approximation
91+
K_U = cp.Variable((n_obs, r))
92+
K_V = cp.Variable((r, n_obs + n_inps))
8693

8794
# SW-eDMD objective
88-
weights_obj = np.vstack([(np.clip(np.abs(J), 0.0, 1.0) @ w) for J, w in zip(Js, W)]).T
89-
P = sf * cp.multiply(weights_obj, Yp.T - K @ Y.T)
95+
weights_obj = np.vstack([(np.abs(J) @ w) for J, w in zip(Js, W)]).T
96+
P = sf * cp.multiply(weights_obj, Yp.T - (K_U @ K_V) @ Y.T)
9097
# add regularization
91-
objective = cp.Minimize(cp.sum_squares(P) + 1E-4 * 1.0 / (n_obs**2) * cp.norm(K, "fro"))
98+
objective = cp.Minimize(cp.sum_squares(P) + 1E-4 * 1.0 / (n_obs**2) * cp.norm(K_U @ K_V, "fro"))
9299

93100
# unconstrained problem
94101
constraints = None
@@ -100,14 +107,14 @@ def swdmdc(X, Xp, U, r, Js, W):
100107
try:
101108
_ = prob.solve(solver=cp.CLARABEL)
102109
#_ = prob.solve(solver=cp.ECOS)
103-
if K.value is None:
110+
if K_U.value is None or K_V.value is None:
104111
raise Exception("SW-eDMD (cvxpy) Optimization failed to converge.")
105112
except:
106113
warnings.warn("SW-eDMD (cvxpy) Optimization failed to converge. Switching to unweighted DMDc.")
107114
return dmdc(X, Xp, U, r)
108115

109116
# get the transformation
110-
Atilde = K.value
117+
Atilde = K_U.value @ K_V.value
111118
return Atilde[:, :state_size], Atilde[:, state_size:]
112119

113120

notebooks/weighted-cost-func.ipynb

+40-8
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
},
1717
{
1818
"cell_type": "code",
19-
"execution_count": null,
19+
"execution_count": 1,
2020
"id": "dfdb47e2-0ea0-4bf8-8279-8500ff3cf21f",
2121
"metadata": {},
2222
"outputs": [],
@@ -36,7 +36,7 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": null,
39+
"execution_count": 2,
4040
"id": "291d3409-1c8c-44cb-8380-44f08019b57d",
4141
"metadata": {},
4242
"outputs": [],
@@ -45,7 +45,7 @@
4545
"import autokoopman.benchmark.fhn as fhn\n",
4646
"fhn = fhn.FitzHughNagumo()\n",
4747
"training_data = fhn.solve_ivps(\n",
48-
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(1, 2)),\n",
48+
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n",
4949
" tspan=[0.0, 6.0],\n",
5050
" sampling_period=0.1\n",
5151
")\n",
@@ -68,7 +68,7 @@
6868
},
6969
{
7070
"cell_type": "code",
71-
"execution_count": null,
71+
"execution_count": 3,
7272
"id": "e2d42e41-46c2-467c-9ce7-9bd6a7c509a1",
7373
"metadata": {},
7474
"outputs": [],
@@ -92,6 +92,7 @@
9292
" w = np.ones(traj.states.shape)\n",
9393
" w[:, -3:] = 0.0\n",
9494
" w[:, :2] = 1.0\n",
95+
" w[:, 0] = 1.0\n",
9596
" weights.append(w)\n",
9697
"\n",
9798
" # weight garbage trajectory to zero\n",
@@ -103,6 +104,25 @@
103104
"#weights = {idx: w for idx, w in enumerate(weights)}"
104105
]
105106
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 4,
110+
"id": "ddd76415-6d19-4a38-a2b0-84eb48d0fdfc",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"from autokoopman.observable import ReweightedRFFObservable\n",
115+
"import autokoopman.observable as akobs\n",
116+
"\n",
117+
"X, WX = list(zip(*list((trajectories[i], w) for i, w in enumerate(weights))))\n",
118+
"X, WX = np.vstack(X), np.vstack(WX)\n",
119+
"X, WX = np.tile(X, (3, 1)), np.tile(WX, (3, 1))\n",
120+
"idxs = np.random.permutation(np.arange(len(X)))\n",
121+
"Y, WY = X[idxs], WX[idxs]\n",
122+
"\n",
123+
"reweight_obs = akobs.IdentityObservable() | akobs.ReweightedRFFObservable(5, 40, 1.0, X, Y, WX, WY, 'rff')"
124+
]
125+
},
106126
{
107127
"cell_type": "markdown",
108128
"id": "a706f212-36cd-4203-b209-cb7c5ce4ad94",
@@ -122,7 +142,19 @@
122142
"execution_count": null,
123143
"id": "98510aa7-3416-4181-a493-00500be53f61",
124144
"metadata": {},
125-
"outputs": [],
145+
"outputs": [
146+
{
147+
"name": "stderr",
148+
"output_type": "stream",
149+
"text": [
150+
"Tuning GridSearchTuner: 0%| | 0/40 [00:00<?, ?it/s]/home/elew/AutoKoopman/notebooks/../autokoopman/estimator/koopman.py:113: UserWarning: SW-eDMD (cvxpy) Optimization failed to converge. Switching to unweighted DMDc.\n",
151+
" warnings.warn(\"SW-eDMD (cvxpy) Optimization failed to converge. Switching to unweighted DMDc.\")\n",
152+
"Tuning GridSearchTuner: 12%|▊ | 5/40 [00:23<03:07, 5.37s/it]/home/elew/anaconda3/envs/autokoopman/lib/python3.12/site-packages/numpy/linalg/linalg.py:2582: RuntimeWarning: overflow encountered in multiply\n",
153+
" s = (x.conj() * x).real\n",
154+
"Tuning GridSearchTuner: 30%|█▌ | 12/40 [00:52<02:04, 4.44s/it]"
155+
]
156+
}
157+
],
126158
"source": [
127159
"# learn model from weighted data\n",
128160
"experiment_results = auto_koopman(\n",
@@ -136,8 +168,8 @@
136168
" n_obs=40, # maximum number of observables to try\n",
137169
" max_opt_iter=200, # maximum number of optimization iterations\n",
138170
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
139-
" n_splits=None, # k-folds validation for tuning, helps stabilize the scoring\n",
140-
" rank=(40, 41, 1) # NOTE: don't rank tune (for now)--SW-eDMD already optimizes for this\n",
171+
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
172+
" rank=(1, 41, 5) # rank (SW-eDMD now uses rank adaptation)\n",
141173
")"
142174
]
143175
},
@@ -290,7 +322,7 @@
290322
"name": "python",
291323
"nbconvert_exporter": "python",
292324
"pygments_lexer": "ipython3",
293-
"version": "3.9.0"
325+
"version": "3.12.3"
294326
}
295327
},
296328
"nbformat": 4,

0 commit comments

Comments
 (0)