Skip to content

Commit

Permalink
Update _cmawm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ha-mano committed Nov 15, 2024
1 parent 291376e commit f3dbc1c
Showing 1 changed file with 21 additions and 46 deletions.
67 changes: 21 additions & 46 deletions cmaes/_cmawm.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,15 @@ def __init__(
for i in range(self._n_zdim):
self.z_space[i][np.isnan(self.z_space[i])] = np.nanmax(self.z_space[i])
self.z_lim[i][np.isnan(self.z_lim[i])] = np.nanmax(self.z_lim[i])
self.z_lim_low = np.concatenate(
[self.z_lim.min(axis=1).reshape([self._n_zdim, 1]), self.z_lim], 1
)
self.z_lim_up = np.concatenate(
[self.z_lim, self.z_lim.max(axis=1).reshape([self._n_zdim, 1])], 1
)
m_z = self._cma._mean[self._discrete_idx].reshape(([self._n_zdim, 1]))
m_z = self._cma._mean[self._discrete_idx]
# m_z_lim_low ->| mean vector |<- m_z_lim_up
self.m_z_lim_low = (
self.z_lim_low
* np.where(np.sort(np.concatenate([self.z_lim, m_z], 1)) == m_z, 1, 0)
).sum(axis=1)
self.m_z_lim_up = (
self.z_lim_up
* np.where(np.sort(np.concatenate([self.z_lim, m_z], 1)) == m_z, 1, 0)
).sum(axis=1)
m_pos = np.array(
[np.searchsorted(self.z_lim[i], m_z[i]) for i in range(len(m_z))]
)
z_lim_low_index = np.clip(m_pos - 1, 0, self.z_lim.shape[1] - 1)
z_lim_up_index = np.clip(m_pos, 0, self.z_lim.shape[1] - 1)
self.z_lim_low = self.z_lim[np.arange(len(self.z_lim)), z_lim_low_index]
self.z_lim_up = self.z_lim[np.arange(len(self.z_lim)), z_lim_up_index]

self._A = np.full(n_dim, 1.0)

Expand Down Expand Up @@ -250,12 +243,9 @@ def _encode_discrete_params(self, discrete_param: np.ndarray) -> np.ndarray:
x = (discrete_param - mean[self._discrete_idx]) * self._A[
self._discrete_idx
] + mean[self._discrete_idx]
x = x.reshape([self._n_zdim, 1])
x_enc = (
self.z_space
* np.where(np.sort(np.concatenate((self.z_lim, x), axis=1)) == x, 1, 0)
).sum(axis=1)
return x_enc.reshape(self._n_zdim)
x_pos = np.array([np.searchsorted(self.z_lim[i], x[i]) for i in range(len(x))])
x_enc = self.z_space[np.arange(len(self.z_space)), x_pos]
return x_enc

def tell(self, solutions: list[tuple[np.ndarray, float]]) -> None:
"""Tell evaluation values"""
Expand All @@ -267,31 +257,17 @@ def tell(self, solutions: list[tuple[np.ndarray, float]]) -> None:
if self._n_zdim == 0:
return
# margin correction
updated_m_integer = mean[self._discrete_idx, np.newaxis]
self.z_lim_low = np.concatenate(
[self.z_lim.min(axis=1).reshape([self._n_zdim, 1]), self.z_lim], 1
updated_m_integer = mean[self._discrete_idx]
m_pos = np.array(
[
np.searchsorted(self.z_lim[i], updated_m_integer[i])
for i in range(len(updated_m_integer))
]
)
self.z_lim_up = np.concatenate(
[self.z_lim, self.z_lim.max(axis=1).reshape([self._n_zdim, 1])], 1
)
self.m_z_lim_low = (
self.z_lim_low
* np.where(
np.sort(np.concatenate([self.z_lim, updated_m_integer], 1))
== updated_m_integer,
1,
0,
)
).sum(axis=1)
self.m_z_lim_up = (
self.z_lim_up
* np.where(
np.sort(np.concatenate([self.z_lim, updated_m_integer], 1))
== updated_m_integer,
1,
0,
)
).sum(axis=1)
z_lim_low_index = np.clip(m_pos - 1, 0, self.z_lim.shape[1] - 1)
z_lim_up_index = np.clip(m_pos, 0, self.z_lim.shape[1] - 1)
self.m_z_lim_low = self.z_lim[np.arange(len(self.z_lim)), z_lim_low_index]
self.m_z_lim_up = self.z_lim[np.arange(len(self.z_lim)), z_lim_up_index]

# calculate probability low_cdf := Pr(X <= m_z_lim_low) and up_cdf := Pr(m_z_lim_up < X)
# sig_z_sq_Cdiag = self.model.sigma * self.model.A * np.sqrt(np.diag(self.model.C))
Expand All @@ -300,7 +276,6 @@ def tell(self, solutions: list[tuple[np.ndarray, float]]) -> None:
* self._A[self._discrete_idx]
* np.sqrt(np.diag(C)[self._discrete_idx])
)
updated_m_integer = updated_m_integer.flatten()
low_cdf = norm_cdf(self.m_z_lim_low, loc=updated_m_integer, scale=z_scale)
up_cdf = 1.0 - norm_cdf(self.m_z_lim_up, loc=updated_m_integer, scale=z_scale)
mid_cdf = 1.0 - (low_cdf + up_cdf)
Expand Down

0 comments on commit f3dbc1c

Please sign in to comment.