@@ -18,6 +18,10 @@ pub use gameplay::*;
1818pub use mccfr:: * ;
1919pub use transport:: * ;
2020pub use wasm:: * ;
21+
22+ #[ cfg( feature = "native" ) ]
23+ static INTERRUPTED : std:: sync:: atomic:: AtomicBool = std:: sync:: atomic:: AtomicBool :: new ( false ) ;
24+
2125/// dimensional analysis types
2226type Chips = i16 ;
2327type Equity = f32 ;
@@ -53,7 +57,7 @@ const CFR_TREE_COUNT_RPS: usize = 8192;
5357
5458// nlhe mccfr parameters
5559const CFR_BATCH_SIZE_NLHE : usize = 64 ;
56- const CFR_TREE_COUNT_NLHE : usize = 0x10000 ;
60+ const CFR_TREE_COUNT_NLHE : usize = 0x1000000 ;
5761
5862/// profile average sampling parameters
5963const SAMPLING_THRESHOLD : Entropy = 1.0 ;
@@ -74,7 +78,7 @@ pub struct Args {
7478 pub cluster : bool ,
7579 /// Run the MCCFR training
7680 #[ arg( long) ]
77- pub solve : bool ,
81+ pub trainer : bool ,
7882 /// Publish results to the database
7983 #[ arg( long) ]
8084 pub publish : bool ,
@@ -100,15 +104,9 @@ pub fn progress(n: usize) -> indicatif::ProgressBar {
100104 progress
101105}
102106
103- /// initialize logging and exit on ctrl-c
107+ /// initialize logging and setup graceful interrupt listener
104108#[ cfg( feature = "native" ) ]
105- pub fn init ( ) {
106- tokio:: spawn ( async move {
107- tokio:: signal:: ctrl_c ( ) . await . unwrap ( ) ;
108- println ! ( ) ;
109- log:: warn!( "forcing exit" ) ;
110- std:: process:: exit ( 0 ) ;
111- } ) ;
109+ pub fn logs ( ) {
112110 std:: fs:: create_dir_all ( "logs" ) . expect ( "create logs directory" ) ;
113111 let config = simplelog:: ConfigBuilder :: new ( )
114112 . set_location_level ( log:: LevelFilter :: Off )
@@ -145,3 +143,31 @@ pub async fn db() -> std::sync::Arc<tokio_postgres::Client> {
145143 tokio:: spawn ( connection) ;
146144 std:: sync:: Arc :: new ( client)
147145}
146+
147+ #[ cfg( feature = "native" ) ]
148+ /// keyboard interruption for training
149+ /// spawn a thread to listen for 'q' input to gracefully interrupt training
150+ pub fn interrupts ( ) {
151+ // handle ctrl+c for immediate exit
152+ tokio:: spawn ( async move {
153+ tokio:: signal:: ctrl_c ( ) . await . unwrap ( ) ;
154+ println ! ( ) ;
155+ log:: warn!( "Ctrl+C received, exiting immediately" ) ;
156+ std:: process:: exit ( 0 ) ;
157+ } ) ;
158+ // handle 'q' input for graceful interrupt
159+ std:: thread:: spawn ( || {
160+ log:: info!( "training started. type 'Q + Enter' to gracefully interrupt." ) ;
161+ let ref mut buffer = String :: new ( ) ;
162+ loop {
163+ buffer. clear ( ) ;
164+ if let Ok ( _) = std:: io:: stdin ( ) . read_line ( buffer) {
165+ if buffer. trim ( ) . to_uppercase ( ) == "Q" {
166+ log:: warn!( "graceful interrupt requested, finishing current batch..." ) ;
167+ INTERRUPTED . store ( true , std:: sync:: atomic:: Ordering :: Relaxed ) ;
168+ break ;
169+ }
170+ }
171+ }
172+ } ) ;
173+ }
0 commit comments