Skip to content

Commit 81dda39

Browse files
committed
start the state weighted eDMD
1 parent 0ee18bf commit 81dda39

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

autokoopman/core/trajectory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def n_step_matrices(
626626
if weights is None:
627627
W = None
628628
else:
629-
W = np.hstack([weights[idx].flatten()[1:] for idx, _ in items])
629+
W = np.vstack([weights[idx][:-nstep:nstep, :] for idx, _ in items])
630630

631631
return X, Xp, U, W
632632

autokoopman/estimator/koopman.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,32 @@ def wdmdc(X, Xp, U, r, W):
5252
return Atilde[:, :state_size], Atilde[:, state_size:]
5353

5454

55+
def swdmdc(X, Xp, U, r, Js, W):
56+
"""State Weighted Dynamic Mode Decomposition with Control (wDMDC)"""
57+
assert len(W.shape) == 2, "weights must be 2D for snapshot x state"
58+
59+
if U is not None:
60+
Y = np.hstack((X, U)).T
61+
else:
62+
Y = X.T
63+
Yp = Xp.T
64+
65+
# compute observables weights from state weights
66+
Wy = np.vstack([(np.abs(J) @ np.atleast_2d(w).T).T for J, w in zip(Js, W)])
67+
68+
# apply weights element-wise
69+
Y, Yp = Wy.T * Y, Wy.T * Yp
70+
state_size = Yp.shape[0]
71+
72+
# compute Atilde
73+
U, Sigma, V = np.linalg.svd(Y, False)
74+
U, Sigma, V = U[:, :r], np.diag(Sigma)[:r, :r], V.conj().T[:, :r]
75+
76+
# get the transformation
77+
Atilde = Yp @ V @ np.linalg.inv(Sigma) @ U.conj().T
78+
return Atilde[:, :state_size], Atilde[:, state_size:]
79+
80+
5581
class KoopmanDiscEstimator(kest.NextStepEstimator):
5682
"""Koopman Discrete Estimator
5783
@@ -96,9 +122,20 @@ def fit_next_step(
96122
if weights is None:
97123
self._A, self._B = dmdc(G, Gp, U.T if U is not None else U, self.rank)
98124
else:
99-
self._A, self._B = wdmdc(
100-
G, Gp, U.T if U is not None else U, self.rank, weights
101-
)
125+
# TODO: change this condition to be more accurate
126+
if False: # len(weights[0].shape) == 1:
127+
self._A, self._B = wdmdc(
128+
G, Gp, U.T if U is not None else U, self.rank, weights
129+
)
130+
else:
131+
self._A, self._B = swdmdc(
132+
G,
133+
Gp,
134+
U.T if U is not None else U,
135+
self.rank,
136+
[self.obs.obs_grad(xi) for xi in X.T],
137+
weights,
138+
)
102139
self._has_input = U is not None
103140

104141
@property

0 commit comments

Comments
 (0)