@@ -22,15 +22,17 @@ def __init__(
2222
2323 space = spaces .Discrete (dim )
2424 self .n_invalid_actions = n_invalid_actions
25- self .possible_actions = np .arange (space .n )
25+ self .possible_actions = np .arange (space .n , dtype = int )
2626 self .invalid_actions : list [int ] = []
2727 super ().__init__ (space = space , ep_length = ep_length )
2828
2929 def _choose_next_state (self ) -> None :
3030 self .state = self .action_space .sample ()
3131 # Randomly choose invalid actions that are not the current state
3232 potential_invalid_actions = [i for i in self .possible_actions if i != self .state ]
33- self .invalid_actions = np .random .choice (potential_invalid_actions , self .n_invalid_actions , replace = False ).tolist ()
33+ self .invalid_actions = np .random .choice ( # type: ignore[assignment]
34+ potential_invalid_actions , self .n_invalid_actions , replace = False
35+ ).tolist ()
3436
3537 def action_masks (self ) -> list [bool ]:
3638 return [action not in self .invalid_actions for action in self .possible_actions ]
@@ -72,7 +74,9 @@ def _choose_next_state(self) -> None:
7274
7375 # Randomly choose invalid actions that are not the current state
7476 potential_invalid_actions = [i for i in self .possible_actions if i not in converted_state ]
75- self .invalid_actions = np .random .choice (potential_invalid_actions , self .n_invalid_actions , replace = False ).tolist ()
77+ self .invalid_actions = np .random .choice ( # type: ignore[assignment]
78+ potential_invalid_actions , self .n_invalid_actions , replace = False
79+ ).tolist ()
7680
7781 def action_masks (self ) -> list [bool ]:
7882 return [action not in self .invalid_actions for action in self .possible_actions ]
@@ -113,7 +117,9 @@ def _choose_next_state(self) -> None:
113117
114118 # Randomly choose invalid actions that are not the current state
115119 potential_invalid_actions = [i for i in self .possible_actions if i not in converted_state ]
116- self .invalid_actions = np .random .choice (potential_invalid_actions , self .n_invalid_actions , replace = False ).tolist ()
120+ self .invalid_actions = np .random .choice ( # type: ignore[assignment]
121+ potential_invalid_actions , self .n_invalid_actions , replace = False
122+ ).tolist ()
117123
118124 def action_masks (self ) -> list [bool ]:
119125 return [action not in self .invalid_actions for action in self .possible_actions ]
0 commit comments