diff --git a/urnai/sc2/actions/collectables.py b/urnai/sc2/actions/collectables.py index 493da9d..b7a4711 100644 --- a/urnai/sc2/actions/collectables.py +++ b/urnai/sc2/actions/collectables.py @@ -31,8 +31,7 @@ def __init__(self): self.pending_actions = [] def is_action_done(self): - # return len(self.pending_actions) == 0 - return True + return len(self.pending_actions) == 0 def reset(self): self.move_number = 0 diff --git a/urnai/sc2/environments/sc2environment.py b/urnai/sc2/environments/sc2environment.py index c4d9e58..1d61b7b 100644 --- a/urnai/sc2/environments/sc2environment.py +++ b/urnai/sc2/environments/sc2environment.py @@ -1,6 +1,3 @@ -import sys - -from absl import flags from pysc2.env import sc2_env from pysc2.env.environment import TimeStep from pysc2.lib import actions, features diff --git a/urnai/sc2/states/collectables.py b/urnai/sc2/states/collectables.py index b541b51..7208708 100644 --- a/urnai/sc2/states/collectables.py +++ b/urnai/sc2/states/collectables.py @@ -21,33 +21,41 @@ class CollectablesMethod(Enum): class CollectablesState(StateBase): def __init__(self, trim_map : bool = False, - method : CollectablesMethod = CollectablesMethod.STATE_MAP): + method : CollectablesMethod = CollectablesMethod.STATE_MAP, + map_size = (64, 64)): + + """ + map_reduction_factor's default value is 1. + Example: if the value is 2, the map's size gets reduced by half. + + Non-spatial is state is composed of: + . x distance to next mineral shard + . y distance to next mineral shard + + """ + self.previous_state = None self.method = method - # number of quadrants is the amount of parts the map should be reduced - # this helps the agent to deal with the big size of state space - # if 1 (default value), the map wont be reduced + self.map_size = map_size self.map_reduction_factor = STATE_MAP_DEFAULT_REDUCTIONFACTOR + self.non_spatial_maximums = [ STATE_MAX_COLL_DIST, STATE_MAX_COLL_DIST, - # RTSGeneralization.STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS, ] + self.non_spatial_minimums = [ 0, 0, - # 0, ] - # non-spatial is composed of - # X distance to next mineral shard - # Y distance to next mineral shard - # number of mineral shards left + self.non_spatial_state = [ 0, 0, - # 0, ] + self.trim_map = trim_map + self.trim_factor = (22/64, 0.25) self.reset() def update(self, obs): @@ -68,7 +76,6 @@ def update(self, obs): def build_map(self, obs): map_ = self.build_basic_map(obs) map_ = self.reduce_map(map_) - map_ = self.normalize_map(map_) return map_ @@ -89,8 +96,7 @@ def build_basic_map(self, obs): return map_ def normalize_map(self, map_): - # map = (map_ - map_.min()) / (map_.max() - map_.min()) - return map_ + return (map_ - map_.min()) / (map_.max() - map_.min()) def normalize_non_spatial_list(self): for i in range(len(self.non_spatial_state)): @@ -103,16 +109,17 @@ def normalize_non_spatial_list(self): def normalize_value(self, value, max_, min_=0): return (value - min_) / (max_ - min_) - # TODO: Remove magic numbers @property def dimension(self): if self.method == CollectablesMethod.STATE_MAP: if self.trim_map: - a = int(22 / self.map_reduction_factor) - b = int(16 / self.map_reduction_factor) + a = int(self.trim_factor[0] * + self.map_size[0] / self.map_reduction_factor) + b = int(self.trim_factor[1] * + self.map_size[1] / self.map_reduction_factor) else: - a = int(64 / self.map_reduction_factor) - b = int(64 / self.map_reduction_factor) + a = int(self.map_size[0] / self.map_reduction_factor) + b = int(self.map_size[1] / self.map_reduction_factor) return int(a * b) elif self.method == CollectablesMethod.STATE_NON_SPATIAL: return len(self.non_spatial_state) @@ -151,12 +158,8 @@ def get_closest_mineral_shard_x_y(self, obs): def build_non_spatial_state(self, obs): x, y = self.get_closest_mineral_shard_x_y(obs) - # position 0: distance x to closest shard self.non_spatial_state[0] = int(x) - # position 1: distance y to closest shard self.non_spatial_state[1] = int(y) - # position 2: number of remaining shards - # self.non_spatial_state[2]=np.count_nonzero(obs.feature_minimap[4]==16) self.normalize_non_spatial_list() self.non_spatial_state = np.array(self.non_spatial_state) return self.non_spatial_state @@ -178,13 +181,13 @@ def reset(self): def trim_matrix(self, matrix, x1, y1, x2, y2): """ - If you have a 2D numpy array - and you want a submatrix of that array, - you can use this function to extract it. - You just need to tell this function - what are the top-left and bottom-right - corners of this submatrix, by setting - x1, y1 and x2, y2. + This function extracts a submatrix of a + 2D numpy array. + + The arguments x1, y1 and x2, y2 are the + top-left and bottom-right corners of + this submatrix, respectively. + For example: some maps of StarCraft II have parts that are not walkable, this happens specially in PySC2 mini-games @@ -199,25 +202,23 @@ def trim_matrix(self, matrix, x1, y1, x2, y2): matrix = np.delete(matrix, np.s_[y2 - y1 + 1::1], 0) return matrix - def lower_featuremap_resolution(self, map, rf): # rf = reduction_factor + def lower_featuremap_resolution(self, map, reduction_factor): """ Reduces a matrix "resolution" by a reduction factor. If we have a 64x64 matrix and rf=4 the map will be reduced to 16x16 in which every new element of the matrix is an average from 4x4=16 elements from the original matrix """ - if rf == 1: + if reduction_factor == 1: return map N, M = map.shape - N = N // rf - M = M // rf + N = N // reduction_factor + M = M // reduction_factor reduced_map = np.empty((N, M)) for i in range(N): for j in range(M): - # reduction_array = map[rf*i:rf*i+rf, rf*j:rf*j+rf].flatten() - # reduced_map[i,j] = Counter(reduction_array).most_common(1)[0][0] - + rf = reduction_factor reduced_map[i, j] = ((map[rf * i:rf * i + rf, rf * j:rf * j + rf].sum()) / (rf * rf)) diff --git a/urnai/trainers/stablebaselines3_trainer.py b/urnai/trainers/stablebaselines3_trainer.py index bcf3c3c..ab4b94e 100644 --- a/urnai/trainers/stablebaselines3_trainer.py +++ b/urnai/trainers/stablebaselines3_trainer.py @@ -1,10 +1,11 @@ import os +import wandb + from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.type_aliases import MaybeCallback -import wandb from urnai.environments.stablebaselines3.custom_env import CustomEnv @@ -37,7 +38,7 @@ def load_most_recent_model(self, model_path): raise Exception(f"No models found in {model_path}") else: def only_digits(filename): - return ''.join(c for c in filename if c.isdigit()) + return int(''.join(c for c in filename if c.isdigit())) save_files.sort(reverse=True, key=only_digits) self.load_model(f"{model_path}/{save_files[0]}")