Skip to content

Commit 78a707f

Browse files
committed
fix: fix get_random_weights
1 parent 14dc492 commit 78a707f

File tree

1 file changed

+43
-41
lines changed

1 file changed

+43
-41
lines changed

okama/common/helpers/helpers.py

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -77,58 +77,60 @@ def get_random_weights(n: int, w_shape: int, bounds: Optional[Tuple[Tuple[float,
7777
Constraints for each asset's weight, e.g., ((0, 1), (0, 0.5), (0.5, 1), ...).
7878
If None, default constraints are applied.
7979
"""
80+
# Case 1: default bounds
8081
if bounds is None:
81-
# Case 1: default bounds
8282
random_numbers = np.random.rand(n, w_shape)
8383
# keepdims instead of transpose
84-
weights = random_numbers / random_numbers.sum(axis=1, keepdims=True)
85-
84+
weights = random_numbers / random_numbers.sum(axis=1, keepdims=True)
85+
86+
# Case 2: custom bounds
8687
else:
87-
# Case 2: custom bounds
88+
bounds_arr = np.array(bounds)
89+
mins = bounds_arr[:, 0]
90+
maxs = bounds_arr[:, 1]
91+
8892
weights = []
89-
attempts_per_weight = 1000
93+
batch_size = min(1000, n)
9094

9195
while len(weights) < n:
92-
remaining = 1.0
93-
indices = list(range(w_shape))
94-
np.random.shuffle(indices)
95-
w = np.zeros(w_shape)
96-
valid = True
97-
98-
for i, j in enumerate(indices[:-1]):
99-
low, high = bounds[j]
100-
101-
min_remaining = sum(bounds[k][0] for k in indices[i+1:])
102-
max_remaining = sum(bounds[k][1] for k in indices[i+1:])
103-
104-
adjusted_low = max(low, remaining - max_remaining)
105-
adjusted_high = min(high, remaining - min_remaining)
106-
107-
if adjusted_low > adjusted_high:
108-
valid = False
109-
break
110-
111-
w[j] = np.random.uniform(adjusted_low, adjusted_high)
112-
remaining -= w[j]
96+
97+
remaining = np.ones(batch_size)
98+
indices = np.arange(w_shape)
99+
batch_w = np.zeros((batch_size, w_shape))
100+
valid_mask = np.ones(batch_size, dtype=bool)
101+
102+
shuffled_indices = np.tile(indices, (batch_size, 1))
103+
for i in range(batch_size):
104+
np.random.shuffle(shuffled_indices[i])
113105

114-
if not valid:
115-
continue
106+
for i in range(w_shape - 1):
107+
108+
idx = shuffled_indices[:, i]
109+
low = mins[idx]
110+
high = maxs[idx]
116111

117-
last_idx = indices[-1]
118-
w[last_idx] = remaining
119-
120-
if not (bounds[last_idx][0] <= w[last_idx] <= bounds[last_idx][1]):
121-
valid = False
122-
123-
if np.any(w < 0):
124-
valid = False
112+
future_mins = np.sum(mins[shuffled_indices[:, i+1:]], axis=1)
113+
future_maxs = np.sum(maxs[shuffled_indices[:, i+1:]], axis=1)
125114

126-
if valid:
127-
weights.append(w)
115+
adjusted_low = np.maximum(low, remaining - future_maxs)
116+
adjusted_high = np.minimum(high, remaining - future_mins)
117+
118+
rand_vals = np.random.uniform(adjusted_low, adjusted_high)
128119

129-
elif len(weights) + (n - len(weights)) * attempts_per_weight < attempts_per_weight * n:
130-
continue
131-
else:
120+
batch_w[np.arange(batch_size), idx] = rand_vals
121+
remaining -= rand_vals
122+
123+
valid_mask &= (adjusted_low <= adjusted_high)
124+
125+
last_idx = shuffled_indices[:, -1]
126+
batch_w[np.arange(batch_size), last_idx] = remaining
127+
valid_mask &= (mins[last_idx] <= remaining) & (remaining <= maxs[last_idx])
128+
valid_mask &= np.all(batch_w >= 0, axis=1)
129+
130+
valid_weights = batch_w[valid_mask]
131+
weights.extend(valid_weights.tolist())
132+
133+
if len(weights) >= n:
132134
break
133135

134136
return pd.Series([np.array(w) for w in weights[:n]])

0 commit comments

Comments
 (0)