@@ -52,6 +52,7 @@ class NewtonOptConfig(OptimizerConfig):
52
52
last_best : float = 0
53
53
use_temporal_smooth : bool = False
54
54
cost_relative_threshold : float = 0.999
55
+ fix_terminal_action : bool = False
55
56
56
57
# use_update_best_kernel: bool
57
58
# c_1: float
@@ -416,16 +417,21 @@ def _armijo_line_search(self, x, step_direction):
416
417
def _approx_line_search (self , x , step_direction ):
417
418
if self .step_scale != 0.0 and self .step_scale != 1.0 :
418
419
step_direction = self .scale_step_direction (step_direction )
420
+ if self .fix_terminal_action and self .action_horizon > 1 :
421
+ step_direction [..., (self .action_horizon - 1 ) * self .d_action :] = 0.0
419
422
if self .line_search_type == LineSearchType .GREEDY :
420
- return self ._greedy_line_search (x , step_direction )
423
+ best_x , best_c , best_grad = self ._greedy_line_search (x , step_direction )
421
424
elif self .line_search_type == LineSearchType .ARMIJO :
422
- return self ._armijo_line_search (x , step_direction )
425
+ best_x , best_c , best_grad = self ._armijo_line_search (x , step_direction )
423
426
elif self .line_search_type in [
424
427
LineSearchType .WOLFE ,
425
428
LineSearchType .STRONG_WOLFE ,
426
429
LineSearchType .APPROX_WOLFE ,
427
430
]:
428
- return self ._wolfe_line_search (x , step_direction )
431
+ best_x , best_c , best_grad = self ._wolfe_line_search (x , step_direction )
432
+ if self .fix_terminal_action and self .action_horizon > 1 :
433
+ best_grad [..., (self .action_horizon - 1 ) * self .d_action :] = 0.0
434
+ return best_x , best_c , best_grad
429
435
430
436
def check_convergence (self , cost ):
431
437
above_threshold = cost > self .cost_convergence
0 commit comments