Skip to content

Commit 98d3350

Browse files
committed
i have solved the interrupt problem
1 parent 9ae2277 commit 98d3350

File tree

3 files changed

+62
-19
lines changed

3 files changed

+62
-19
lines changed

src/lib.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ pub use gameplay::*;
1818
pub use mccfr::*;
1919
pub use transport::*;
2020
pub use wasm::*;
21+
22+
#[cfg(feature = "native")]
23+
static INTERRUPTED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
24+
2125
/// dimensional analysis types
2226
type Chips = i16;
2327
type Equity = f32;
@@ -53,7 +57,7 @@ const CFR_TREE_COUNT_RPS: usize = 8192;
5357

5458
// nlhe mccfr parameters
5559
const CFR_BATCH_SIZE_NLHE: usize = 64;
56-
const CFR_TREE_COUNT_NLHE: usize = 0x10000;
60+
const CFR_TREE_COUNT_NLHE: usize = 0x1000000;
5761

5862
/// profile average sampling parameters
5963
const SAMPLING_THRESHOLD: Entropy = 1.0;
@@ -74,7 +78,7 @@ pub struct Args {
7478
pub cluster: bool,
7579
/// Run the MCCFR training
7680
#[arg(long)]
77-
pub solve: bool,
81+
pub trainer: bool,
7882
/// Publish results to the database
7983
#[arg(long)]
8084
pub publish: bool,
@@ -100,15 +104,9 @@ pub fn progress(n: usize) -> indicatif::ProgressBar {
100104
progress
101105
}
102106

103-
/// initialize logging and exit on ctrl-c
107+
/// initialize logging and setup graceful interrupt listener
104108
#[cfg(feature = "native")]
105-
pub fn init() {
106-
tokio::spawn(async move {
107-
tokio::signal::ctrl_c().await.unwrap();
108-
println!();
109-
log::warn!("forcing exit");
110-
std::process::exit(0);
111-
});
109+
pub fn logs() {
112110
std::fs::create_dir_all("logs").expect("create logs directory");
113111
let config = simplelog::ConfigBuilder::new()
114112
.set_location_level(log::LevelFilter::Off)
@@ -145,3 +143,31 @@ pub async fn db() -> std::sync::Arc<tokio_postgres::Client> {
145143
tokio::spawn(connection);
146144
std::sync::Arc::new(client)
147145
}
146+
147+
#[cfg(feature = "native")]
148+
/// keyboard interruption for training
149+
/// spawn a thread to listen for 'q' input to gracefully interrupt training
150+
pub fn interrupts() {
151+
// handle ctrl+c for immediate exit
152+
tokio::spawn(async move {
153+
tokio::signal::ctrl_c().await.unwrap();
154+
println!();
155+
log::warn!("Ctrl+C received, exiting immediately");
156+
std::process::exit(0);
157+
});
158+
// handle 'q' input for graceful interrupt
159+
std::thread::spawn(|| {
160+
log::info!("training started. type 'Q + Enter' to gracefully interrupt.");
161+
let ref mut buffer = String::new();
162+
loop {
163+
buffer.clear();
164+
if let Ok(_) = std::io::stdin().read_line(buffer) {
165+
if buffer.trim().to_uppercase() == "Q" {
166+
log::warn!("graceful interrupt requested, finishing current batch...");
167+
INTERRUPTED.store(true, std::sync::atomic::Ordering::Relaxed);
168+
break;
169+
}
170+
}
171+
}
172+
});
173+
}

src/main.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@ use robopoker::*;
33

44
#[tokio::main]
55
async fn main() {
6-
crate::init();
7-
if crate::Args::parse().cluster {
6+
crate::logs();
7+
crate::interrupts();
8+
let ref arguments = crate::Args::parse();
9+
if arguments.cluster {
810
crate::clustering::Layer::learn();
911
}
10-
if crate::Args::parse().solve {
12+
if arguments.trainer {
1113
crate::mccfr::NLHE::train();
1214
}
13-
if crate::Args::parse().publish {
15+
if arguments.publish {
1416
crate::save::Writer::publish().await.unwrap();
1517
}
16-
if crate::Args::parse().analyze {
18+
if arguments.analyze {
1719
crate::analysis::Server::run().await.unwrap();
1820
}
1921
}

src/mccfr/traits/blueprint.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ use crate::mccfr::structs::infoset::InfoSet;
88
use crate::mccfr::structs::tree::Tree;
99
use crate::mccfr::types::counterfactual::Counterfactual;
1010

11+
#[cfg(feature = "native")]
12+
use crate::INTERRUPTED;
13+
1114
/// given access to a Profile and Encoder,
1215
/// we enapsulate the process of
1316
/// 1) sampling Trees
@@ -58,18 +61,30 @@ pub trait Blueprint: Send + Sync {
5861
where
5962
Self: Sized,
6063
{
61-
let t = Self::iterations();
62-
log::info!("beginning training loop ({})", t);
63-
for _ in 0..t {
64+
'training: for _ in 0..Self::iterations() {
6465
for ref update in self.batch() {
6566
self.update_regret(update);
6667
self.update_weight(update);
6768
}
68-
self.advance();
69+
if self.interrupted() {
70+
break 'training;
71+
}
6972
}
7073
self
7174
}
7275

76+
/// handles interrupt for training process
77+
fn interrupted(&mut self) -> bool {
78+
if INTERRUPTED.load(std::sync::atomic::Ordering::Relaxed) {
79+
let t = self.profile().epochs();
80+
log::warn!("training interrupted @ {}", t);
81+
true
82+
} else {
83+
self.advance();
84+
false
85+
}
86+
}
87+
7388
/// Updates accumulated regret values for each edge in the counterfactual.
7489
fn update_regret(&mut self, cfr: &Counterfactual<Self::E, Self::I>) {
7590
let ref info = cfr.0.clone();

0 commit comments

Comments
 (0)