88from __future__ import annotations
99
1010import warnings
11+ from copy import deepcopy
1112from typing import Any , Dict , List , Mapping , Optional , Tuple , Union
1213
1314import numpy as np
@@ -56,6 +57,7 @@ def __init__(
5657 name : str = "" ,
5758 run_indefinitely : bool = False ,
5859 transforms : ChainedInputTransform = ChainedInputTransform (** {}),
60+ copy_model : bool = False ,
5961 ) -> None :
6062 """Initialize the strategy object.
6163
@@ -90,6 +92,9 @@ def __init__(
9092 should be defined in raw parameter space for initialization. However,
9193 if the lb/ub attribute are access from an initialized Strategy object,
9294 it will be returned in transformed space.
95+ copy_model (bool): Whether to do any model-related methods on a
96+ copy or the original. Used for multi-client strategies. Defaults
97+ to False.
9398 """
9499 self .is_finished = False
95100
@@ -160,6 +165,7 @@ def __init__(
160165 self .min_total_outcome_occurrences = min_total_outcome_occurrences
161166 self .max_asks = max_asks or generator .max_asks
162167 self .keep_most_recent = keep_most_recent
168+ self .copy_model = copy_model
163169
164170 self .transforms = transforms
165171 if self .transforms is not None :
@@ -267,7 +273,8 @@ def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor:
267273 self .model .to (self .generator_device ) # type: ignore
268274
269275 self ._count = self ._count + num_points
270- points = self .generator .gen (num_points , self .model , ** kwargs )
276+ model = deepcopy (self .model ) if self .copy_model else self .model
277+ points = self .generator .gen (num_points , model , ** kwargs )
271278
272279 if original_device is not None :
273280 self .model .to (original_device ) # type: ignore
@@ -295,9 +302,9 @@ def get_max(
295302 self .model is not None
296303 ), "model is None! Cannot get the max without a model!"
297304 self .model .to (self .model_device )
298-
305+ model = deepcopy ( self . model ) if self . copy_model else self . model
299306 val , arg = get_max (
300- self . model ,
307+ model ,
301308 self .bounds ,
302309 locked_dims = constraints ,
303310 probability_space = probability_space ,
@@ -324,9 +331,9 @@ def get_min(
324331 self .model is not None
325332 ), "model is None! Cannot get the min without a model!"
326333 self .model .to (self .model_device )
327-
334+ model = deepcopy ( self . model ) if self . copy_model else self . model
328335 val , arg = get_min (
329- self . model ,
336+ model ,
330337 self .bounds ,
331338 locked_dims = constraints ,
332339 probability_space = probability_space ,
@@ -358,9 +365,9 @@ def inv_query(
358365 self .model is not None
359366 ), "model is None! Cannot get the inv_query without a model!"
360367 self .model .to (self .model_device )
361-
368+ model = deepcopy ( self . model ) if self . copy_model else self . model
362369 val , arg = inv_query (
363- model = self . model ,
370+ model = model ,
364371 y = y ,
365372 bounds = self .bounds ,
366373 locked_dims = constraints ,
@@ -385,7 +392,8 @@ def predict(
385392 """
386393 assert self .model is not None , "model is None! Cannot predict without a model!"
387394 self .model .to (self .model_device )
388- return self .model .predict (x = x , probability_space = probability_space )
395+ model = deepcopy (self .model ) if self .copy_model else self .model
396+ return model .predict (x = x , probability_space = probability_space )
389397
390398 @ensure_model_is_fresh
391399 def sample (self , x : torch .Tensor , num_samples : int = 1000 ) -> torch .Tensor :
@@ -400,7 +408,8 @@ def sample(self, x: torch.Tensor, num_samples: int = 1000) -> torch.Tensor:
400408 """
401409 assert self .model is not None , "model is None! Cannot sample without a model!"
402410 self .model .to (self .model_device )
403- return self .model .sample (x , num_samples = num_samples )
411+ model = deepcopy (self .model ) if self .copy_model else self .model
412+ return model .sample (x , num_samples = num_samples )
404413
405414 def finish (self ) -> None :
406415 """Finish the strategy."""
@@ -442,7 +451,8 @@ def finished(self) -> bool:
442451 assert (
443452 self .model is not None
444453 ), "model is None! Cannot predict without a model!"
445- fmean , _ = self .model .predict (self .eval_grid , probability_space = True )
454+ model = deepcopy (self .model ) if self .copy_model else self .model
455+ fmean , _ = model .predict (self .eval_grid , probability_space = True )
446456 meets_post_range = bool (
447457 ((fmean .max () - fmean .min ()) >= self .min_post_range ).item ()
448458 )
@@ -504,9 +514,10 @@ def fit(self) -> None:
504514 """Fit the model."""
505515 if self .can_fit :
506516 self .model .to (self .model_device ) # type: ignore
517+ model = deepcopy (self .model ) if self .copy_model else self .model
507518 if self .keep_most_recent is not None :
508519 try :
509- self . model .fit ( # type: ignore
520+ model .fit ( # type: ignore
510521 self .x [- self .keep_most_recent :], # type: ignore
511522 self .y [- self .keep_most_recent :], # type: ignore
512523 )
@@ -516,21 +527,23 @@ def fit(self) -> None:
516527 )
517528 else :
518529 try :
519- self . model .fit (self .x , self .y ) # type: ignore
530+ model .fit (self .x , self .y ) # type: ignore
520531 except ModelFittingError :
521532 logger .warning (
522533 "Failed to fit model! Predictions may not be accurate!"
523534 )
535+ self .model = model
524536 else :
525537 warnings .warn ("Cannot fit: no model has been initialized!" , RuntimeWarning )
526538
527539 def update (self ) -> None :
528540 """Update the model."""
529541 if self .can_fit :
530542 self .model .to (self .model_device ) # type: ignore
543+ model = deepcopy (self .model ) if self .copy_model else self .model
531544 if self .keep_most_recent is not None :
532545 try :
533- self . model .update ( # type: ignore
546+ model .update ( # type: ignore
534547 self .x [- self .keep_most_recent :], # type: ignore
535548 self .y [- self .keep_most_recent :], # type: ignore
536549 )
@@ -540,11 +553,13 @@ def update(self) -> None:
540553 )
541554 else :
542555 try :
543- self . model .update (self .x , self .y ) # type: ignore
556+ model .update (self .x , self .y ) # type: ignore
544557 except ModelFittingError :
545558 logger .warning (
546559 "Failed to fit model! Predictions may not be accurate!"
547560 )
561+
562+ self .model = model
548563 else :
549564 warnings .warn ("Cannot fit: no model has been initialized!" , RuntimeWarning )
550565
0 commit comments