From d0af062f69fbcd74d47452d85ad2328f2fefd45d Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Thu, 18 Jul 2024 01:55:09 +0900 Subject: [PATCH] [BridgeBidding] Improve efficiency in using dds results (#1191) --- pgx/bridge_bidding.py | 82 +++++++----------------------------- tests/test_bridge_bidding.py | 32 +++----------- 2 files changed, 20 insertions(+), 94 deletions(-) diff --git a/pgx/bridge_bidding.py b/pgx/bridge_bidding.py index c9639671b..418992c44 100644 --- a/pgx/bridge_bidding.py +++ b/pgx/bridge_bidding.py @@ -18,7 +18,6 @@ import jax import jax.numpy as jnp -import numpy as np import pgx.core as core from pgx._src.struct import dataclass @@ -28,6 +27,8 @@ TRUE = jnp.bool_(True) FALSE = jnp.bool_(False) +HandArray = Array # (4,) int array + # The card and number correspondence # 0~12 spade, 13~25 heart, 26~38 diamond, 39~51 club # For each suit, the numbers are arranged in the following order @@ -45,42 +46,19 @@ def download_dds_results(download_dir="dds_results"): """Download and split the results into 100K chunks.""" - def split_data(data, prefix, base_i=0): - with open(data, "rb") as f: - keys, values = jnp.load(f) - n = 100_000 - m = keys.shape[0] // n - for i in range(m): - fname = os.path.join(download_dir, f"{prefix}_{base_i + i:03d}.npy") - with open(fname, "wb") as f: - print( - f"saving {fname} ... [{i * n}, {(i + 1) * n})", - file=sys.stderr, - ) - jnp.save( - f, - ( - keys[i * n : (i + 1) * n], - values[i * n : (i + 1) * n], - ), - ) - os.makedirs(download_dir, exist_ok=True) train_small_fname = os.path.join(download_dir, "dds_results_2.5M.npy") if not os.path.exists(train_small_fname): _download(DDS_RESULTS_TRAIN_SMALL_URL, train_small_fname) - split_data(train_small_fname, "train") train_large_fname = os.path.join(download_dir, "dds_results_10M.npy") if not os.path.exists(train_large_fname): _download(DDS_RESULTS_TRAIN_LARGE_URL, train_large_fname) - split_data(train_large_fname, "train", base_i=25) test_fname = os.path.join(download_dir, "dds_results_500K.npy") if not os.path.exists(test_fname): _download(DDS_RESULTS_TEST_URL, test_fname) - split_data(test_fname, "test") @dataclass @@ -136,6 +114,7 @@ class State(core.State): _first_denomination_EW: Array = jnp.full(5, -1, dtype=jnp.int32) # Number of pass _pass_num: Array = jnp.array(0, dtype=jnp.int32) + _dds_val: Array = jnp.zeros(4, dtype=jnp.int32) @property def env_id(self) -> core.EnvId: @@ -143,7 +122,7 @@ def env_id(self) -> core.EnvId: class BridgeBidding(core.Env): - def __init__(self, dds_results_table_path: str = "dds_results/train_000.npy"): + def __init__(self, dds_results_table_path: str = "dds_results/dds_results_10M.npy"): super().__init__() print( f"Loading dds results from {dds_results_table_path} ...", @@ -180,12 +159,14 @@ def __init__(self, dds_results_table_path: str = "dds_results/train_000.npy"): def _init(self, key: PRNGKey) -> State: key1, key2 = jax.random.split(key, num=2) - return _init_by_key(jax.random.choice(key1, self._lut_keys), key2) + ix = jax.random.choice(key1, jnp.arange(self._lut_keys.shape[0])) + key, val = self._lut_keys[ix], self._lut_values[ix] + return _init_by_key(key, val, key2) def _step(self, state: core.State, action: int, key) -> State: del key assert isinstance(state, State) - return _step(state, action, self._lut_keys, self._lut_values) + return _step(state, action) def _observe(self, state: core.State, player_id: Array) -> Array: assert isinstance(state, State) @@ -208,7 +189,7 @@ def _illegal_action_penalty(self) -> float: return -7600.0 -def _init_by_key(key: PRNGKey, rng: PRNGKey) -> State: +def _init_by_key(key: HandArray, val: Array, rng: PRNGKey) -> State: """Make init state from key""" rng1, rng2, rng3, rng4 = jax.random.split(rng, num=4) hand = _key_to_hand(key) @@ -230,6 +211,7 @@ def _init_by_key(key: PRNGKey, rng: PRNGKey) -> State: _vul_NS=vul_NS, _vul_EW=vul_EW, legal_action_mask=legal_actions, + _dds_val=val, ) return state @@ -285,8 +267,6 @@ def _player_position(player: Array, state: State) -> Array: def _step( state: State, action: int, - lut_keys: Array, - lut_values: Array, ) -> State: # fmt: off state = state.replace(_bidding_history=state._bidding_history.at[state._turn].set(action)) # type: ignore @@ -298,7 +278,7 @@ def _step( [ lambda: jax.lax.cond( _is_terminated(_state_pass(state)), - lambda: _terminated_step(_state_pass(state), lut_keys, lut_values), + lambda: _terminated_step(_state_pass(state)), lambda: _continue_step(_state_pass(state)), ), lambda: _continue_step(_state_X(state)), @@ -392,12 +372,10 @@ def _convert_card_pgx_to_openspiel(card: Array) -> Array: def _terminated_step( state: State, - lut_keys: Array, - lut_values: Array, ) -> State: """Return state if the game is successfully completed""" terminated = jnp.bool_(True) - reward = _reward(state, lut_keys, lut_values) + reward = _reward(state) # fmt: off return state.replace(terminated=terminated, rewards=reward) # type: ignore # fmt: on @@ -433,8 +411,6 @@ def _is_terminated(state: State) -> bool: def _reward( state: State, - lut_keys: Array, - lut_values: Array, ) -> Array: """Return reward If pass out, 0 reward for everyone; if bid, calculate and return reward @@ -442,14 +418,12 @@ def _reward( return jax.lax.cond( (state._last_bid == -1) & (state._pass_num == 4), lambda: jnp.zeros(4, dtype=jnp.float32), # pass out - lambda: _make_reward(state, lut_keys, lut_values), # caluculate reward + lambda: _make_reward(state), # caluculate reward ) def _make_reward( state: State, - lut_keys: Array, - lut_values: Array, ) -> Array: """Calculate rewards for each player by dds results @@ -459,7 +433,7 @@ def _make_reward( # Extract contract from state declare_position, denomination, level, vul = _contract(state) # Calculate trick table from hash table - dds_tricks = _calculate_dds_tricks(state, lut_keys, lut_values) + dds_tricks = _value_to_dds_tricks(state._dds_val) # Calculate the tricks you could have accomplished with this contraption dds_trick = dds_tricks[declare_position * 5 + denomination] # Clculate score @@ -814,7 +788,7 @@ def _card_str_to_int(card: str) -> int: return int(card) - 1 -def _key_to_hand(key: PRNGKey) -> Array: +def _key_to_hand(key: HandArray) -> Array: """Convert key to hand""" def _convert_quat(j): @@ -859,32 +833,6 @@ def _convert_hex(j): return jnp.array(hex_digits, dtype=jnp.int32) -def _calculate_dds_tricks( - state: State, - lut_keys: Array, - lut_values: Array, -) -> Array: - key = _state_to_key(state) - return _value_to_dds_tricks(_find_value_from_key(key, lut_keys, lut_values)) - - -def _find_value_from_key(key: PRNGKey, lut_keys: Array, lut_values: Array): - """Find a value matching key without batch processing - >>> VALUES = jnp.arange(20).reshape(5, 4) - >>> KEYS = jnp.arange(20).reshape(5, 4) - >>> key = jnp.arange(4, 8) - >>> _find_value_from_key(key, KEYS, VALUES) - Array([4, 5, 6, 7], dtype=int32) - """ - mask = jnp.where( - jnp.all((lut_keys == key), axis=1), - jnp.ones(1, dtype=np.bool_), - jnp.zeros(1, dtype=np.bool_), - ) - ix = jnp.argmax(mask) - return lut_values[ix] - - def _load_sample_hash() -> Tuple[Array, Array]: # fmt: off return jnp.array([[19556549, 61212362, 52381660, 50424958], [53254536, 21854346, 37287883, 14009558], [44178585, 6709002, 23279217, 16304124], [36635659, 48114215, 13583653, 26208086], [44309474, 39388022, 28376136, 59735189], [61391908, 52173479, 29276467, 31670621], [34786519, 13802254, 57433417, 43152306], [48319039, 55845612, 44614774, 58169152], [47062227, 32289487, 12941848, 21338650], [36579116, 15643926, 64729756, 18678099], [62136384, 37064817, 59701038, 39188202], [13417016, 56577539, 25995845, 27248037], [61125047, 43238281, 23465183, 20030494], [7139188, 31324229, 58855042, 14296487], [2653767, 47502150, 35507905, 43823846], [31453323, 11605145, 6716808, 41061859], [21294711, 49709, 26110952, 50058629], [48130172, 3340423, 60445890, 7686579], [16041939, 27817393, 37167847, 9605779], [61154057, 17937858, 12254613, 12568801], [13796245, 46546127, 49123920, 51772041], [7195005, 45581051, 41076865, 17429796], [20635965, 14642724, 7001617, 45370595], [35616421, 19938131, 45131030, 16524847], [14559399, 15413729, 39188470, 535365], [48743216, 39672069, 60203571, 60210880], [63862780, 2462075, 23267370, 36595020], [11229980, 11616119, 20292263, 3695004], [24135854, 37532826, 54421444, 14130249], [42798085, 33026223, 2460251, 18566823], [49558558, 65537599, 14768519, 31103243], [44321156, 20075251, 42663767, 11615602], [20186726, 42678073, 11763300, 56739471], [57534601, 16703645, 6039937, 17088125], [50795278, 17350238, 11955835, 21538127], [45919621, 5520088, 27736513, 52674927], [13928720, 57324148, 28222453, 15480785], [910719, 47238830, 26345802, 56166394], [58841430, 1098476, 61890558, 26907706], [10379825, 8624220, 39701822, 29045990], [54444873, 50000486, 48563308, 55867521], [47291672, 22084522, 45484828, 32878832], [55350706, 23903891, 46142039, 11499952], [4708326, 27588734, 31010458, 11730972], [27078872, 59038086, 62842566, 51147874], [28922172, 32377861, 9109075, 10154350], [26104086, 62786977, 224865, 14335943], [20448626, 33187645, 34338784, 26382893], [29194006, 19635744, 24917755, 8532577], [64047742, 34885257, 5027048, 58399668], [27603972, 26820121, 44837703, 63748595], [60038456, 19611050, 7928914, 38555047], [13583610, 19626473, 22239272, 19888268], [28521006, 1743692, 31319264, 15168920], [64585849, 63931241, 57019799, 14189800], [2632453, 7269809, 60404342, 57986125], [1996183, 49918209, 49490468, 47760867], [6233580, 15318425, 51356120, 55074857], [15769884, 61654638, 8374039, 43685186], [44162419, 47272176, 62693156, 35359329], [36345796, 15667465, 53341561, 2978505], [1664472, 12761950, 34145519, 55197543], [37567005, 3228834, 6198166, 15646487], [63233399, 42640049, 12969011, 41620641], [22090925, 3386355, 56655568, 31631004], [16442787, 9420273, 48595545, 29770176], [49404288, 37823218, 58551818, 6772527], [36575583, 53847347, 32379432, 1630009], [9004247, 12999580, 48379959, 14252211], [25850203, 26136823, 64934025, 29362603], [10214276, 43557352, 33387586, 55512773], [45810841, 49561478, 41130845, 27034816], [34460081, 16560450, 57722793, 41007718], [53414778, 6845803, 15340368, 16647575], [30535873, 5193469, 43608154, 11391114], [20622004, 34424126, 31475211, 29619615], [10428836, 49656416, 7912677, 33427787], [57600861, 18251799, 46147432, 58946294], [6760779, 14675737, 42952146, 5480498], [46037552, 39969058, 30103468, 55330772], [64466305, 29376674, 49914839, 55269895], [36494113, 27010567, 65752150, 12395385], [49385632, 19550767, 39809394, 58806235], [20987521, 37444597, 49290126, 42326125], [37150229, 37487849, 28254397, 32949826], [9724895, 53813417, 19431235, 27438556], [42132748, 47073733, 19396568, 10026137], [3961481, 27204521, 62087205, 37602005], [22178323, 17505521, 42006207, 44143605], [12753258, 63063515, 61993175, 8920985], [10998000, 64833190, 6446892, 63676805], [66983817, 63684932, 18378359, 39946382], [63476803, 60000436, 19442420, 66417845], [38004446, 64752157, 42570179, 52844512], [1270809, 23735482, 17543294, 18795903], [4862706, 16352249, 57100612, 6219870], [63203206, 25630930, 35608240, 51357885], [59819625, 64662579, 50925335, 55670434], [29216830, 26446697, 52243336, 58475666], [43138915, 30592834, 43931516, 50628002]], dtype=jnp.int32), jnp.array([[71233, 771721, 71505, 706185], [289177, 484147, 358809, 484147], [359355, 549137, 359096, 549137], [350631, 558133, 350630, 554037], [370087, 538677, 370087, 538677], [4432, 899725, 4432, 904077], [678487, 229987, 678487, 229987], [423799, 480614, 423799, 480870], [549958, 284804, 549958, 280708], [423848, 480565, 423848, 480549], [489129, 283940, 554921, 283940], [86641, 822120, 86641, 822120], [206370, 702394, 206370, 567209], [500533, 407959, 500533, 407959], [759723, 79137, 759723, 79137], [563305, 345460, 559209, 345460], [231733, 611478, 231733, 611478], [502682, 406082, 498585, 406082], [554567, 288662, 554567, 288662], [476823, 427846, 476823, 427846], [488823, 415846, 488823, 415846], [431687, 477078, 431687, 477078], [419159, 424070, 415062, 424070], [493399, 345734, 493143, 345718], [678295, 230451, 678295, 230451], [496520, 342596, 496520, 346709], [567109, 276116, 567109, 276116], [624005, 284758, 624005, 284758], [420249, 484420, 420248, 484420], [217715, 621418, 217715, 621418], [344884, 493977, 344884, 493977], [550841, 292132, 550841, 292132], [284262, 558967, 284006, 558967], [152146, 756616, 152146, 756616], [144466, 698763, 144466, 694667], [284261, 624504, 284261, 624504], [288406, 620102, 288405, 620358], [301366, 607383, 301366, 607382], [468771, 435882, 468771, 435882], [555688, 283444, 555688, 283444], [485497, 414820, 485497, 414820], [633754, 275010, 633754, 275010], [419141, 489608, 419157, 489608], [694121, 214387, 694121, 214387], [480869, 427639, 481125, 427639], [489317, 419447, 489301, 419447], [152900, 747672, 152900, 747672], [348516, 494457, 348516, 494457], [534562, 370088, 534562, 370088], [371272, 537475, 371274, 537475], [144194, 760473, 144194, 760473], [567962, 275011, 567962, 275011], [493161, 350052, 493161, 350052], [490138, 348979, 490138, 348979], [328450, 506552, 328450, 506552], [148882, 759593, 148626, 755497], [642171, 266593, 642171, 266593], [685894, 218774, 685894, 218774], [674182, 234548, 674214, 234548], [756347, 152146, 690811, 86353], [612758, 291894, 612758, 291894], [296550, 612214, 296550, 612214], [363130, 475730, 363130, 475730], [691559, 16496, 691559, 16496], [340755, 502202, 336659, 502218], [632473, 210499, 628377, 210483], [564410, 266513, 564410, 266513], [427366, 481399, 427366, 481399], [493159, 349797, 493159, 415605], [331793, 576972, 331793, 576972], [416681, 492084, 416681, 492084], [813496, 95265, 813496, 91153], [695194, 213571, 695194, 213571], [436105, 407124, 436105, 407124], [836970, 6243, 902506, 6243], [160882, 747882, 160882, 747882], [493977, 414788, 489624, 414788], [29184, 551096, 29184, 616888], [903629, 4880, 899517, 4880], [351419, 553250, 351419, 553250], [75554, 767671, 75554, 767671], [279909, 563304, 279909, 563304], [215174, 628054, 215174, 628054], [361365, 481864, 361365, 481864], [424022, 484743, 358486, 484725], [271650, 633018, 271650, 633018], [681896, 226867, 616088, 226867], [222580, 686184, 222564, 686184], [144451, 698778, 209987, 698778], [532883, 310086, 532883, 310086], [628872, 279893, 628872, 279893], [533797, 374951, 533797, 374951], [91713, 817036, 91713, 817036], [427605, 477046, 431718, 477046], [145490, 689529, 145490, 689529], [551098, 291875, 551098, 291875], [349781, 558984, 349781, 558983], [205378, 703115, 205378, 703115], [362053, 546456, 362053, 546456], [612248, 226371, 678040, 226371]], dtype=jnp.int32) diff --git a/tests/test_bridge_bidding.py b/tests/test_bridge_bidding.py index 28a4d7a18..64eebf5f0 100644 --- a/tests/test_bridge_bidding.py +++ b/tests/test_bridge_bidding.py @@ -10,7 +10,6 @@ BridgeBidding, State, _calc_score, - _calculate_dds_tricks, _contract, _init_by_key, _key_to_hand, @@ -64,7 +63,6 @@ def init(rng: jax.Array) -> State: step = jax.jit(env.step) observe = jax.jit(env.observe) _calc_score = jax.jit(_calc_score) -_calculate_dds_tricks = jax.jit(_calculate_dds_tricks) _contract = jax.jit(_contract) _init_by_key = jax.jit(_init_by_key) _key_to_hand = jax.jit(_key_to_hand) @@ -131,7 +129,7 @@ def test_step(): # fmt: on key = jax.random.PRNGKey(0) HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash() - state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key) + state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0], key) # state = init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key) state = state.replace( _dealer=jnp.int32(1), @@ -957,7 +955,7 @@ def max_action_length_agent(state: State) -> int: def test_max_action(): key = jax.random.PRNGKey(0) HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash() - state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key) + state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key) for i in range(319): if i < 318: @@ -984,7 +982,7 @@ def max_action_length_agent(state: State) -> int: def test_max_action(): key = jax.random.PRNGKey(0) HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash() - state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key) + state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key) for i in range(319): if i < 318: @@ -1005,7 +1003,7 @@ def test_pass_out(): actions = iter([0, 0, 0, 0]) key = jax.random.PRNGKey(0) HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash() - state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key) + state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key) # state = init_by_key(HASH_TABLE_SAMPLE_KEYS[1], key) state = state.replace( _dealer=jnp.int32(1), @@ -1146,7 +1144,7 @@ def test_observe(): ) key = jax.random.PRNGKey(0) HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash() - state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], key) + state = _init_by_key(HASH_TABLE_SAMPLE_KEYS[0], HASH_TABLE_SAMPLE_VALUES[0],key) state = state.replace( _dealer=jnp.int32(1), current_player=jnp.int32(3), @@ -2193,26 +2191,6 @@ def test_state_to_key_cycle(): assert jnp.all(sorted_hand == reconst_hand) -def test_calcurate_dds_tricks(): - HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES = _load_sample_hash() - samples = [] - with open("tests/assets/contractbridge-ddstable-sample100.csv", "r") as f: - reader = csv.reader(f, delimiter=",") - for i in reader: - samples.append([i[0], np.array(i[1:]).astype(np.int32)]) - key = jax.random.PRNGKey(0) - for i in range(len(HASH_TABLE_SAMPLE_KEYS)): - key, subkey = jax.random.split(key) - state = init(subkey) - state = state.replace(_hand=_key_to_hand(HASH_TABLE_SAMPLE_KEYS[i])) - dds_tricks = _calculate_dds_tricks( - state, HASH_TABLE_SAMPLE_KEYS, HASH_TABLE_SAMPLE_VALUES - ) - # calculate dds results from hash table made by dample data - # check whether the results are conssitent with sample data - assert jnp.all(dds_tricks == samples[i][1]) - - def test_value_to_dds_tricks(): value = jnp.array([4160, 904605, 4160, 904605]) # fmt: off