Skip to content

Commit

Permalink
refactor: Remove comments
Browse files Browse the repository at this point in the history
  • Loading branch information
CinquilCinquil committed Jan 31, 2025
1 parent 0675ed6 commit e274fdc
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 44 deletions.
3 changes: 1 addition & 2 deletions urnai/sc2/actions/collectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def __init__(self):
self.pending_actions = []

def is_action_done(self):
# return len(self.pending_actions) == 0
return True
return len(self.pending_actions) == 0

def reset(self):
self.move_number = 0
Expand Down
3 changes: 0 additions & 3 deletions urnai/sc2/environments/sc2environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import sys

from absl import flags
from pysc2.env import sc2_env
from pysc2.env.environment import TimeStep
from pysc2.lib import actions, features
Expand Down
75 changes: 38 additions & 37 deletions urnai/sc2/states/collectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,41 @@ class CollectablesMethod(Enum):
class CollectablesState(StateBase):

def __init__(self, trim_map : bool = False,
method : CollectablesMethod = CollectablesMethod.STATE_MAP):
method : CollectablesMethod = CollectablesMethod.STATE_MAP,
map_size = (64, 64)):

"""
map_reduction_factor's default value is 1.
Example: if the value is 2, the map's size gets reduced by half.
Non-spatial is state is composed of:
. x distance to next mineral shard
. y distance to next mineral shard
"""

self.previous_state = None
self.method = method
# number of quadrants is the amount of parts the map should be reduced
# this helps the agent to deal with the big size of state space
# if 1 (default value), the map wont be reduced
self.map_size = map_size
self.map_reduction_factor = STATE_MAP_DEFAULT_REDUCTIONFACTOR

self.non_spatial_maximums = [
STATE_MAX_COLL_DIST,
STATE_MAX_COLL_DIST,
# RTSGeneralization.STATE_MAXIMUM_NUMBER_OF_MINERAL_SHARDS,
]

self.non_spatial_minimums = [
0,
0,
# 0,
]
# non-spatial is composed of
# X distance to next mineral shard
# Y distance to next mineral shard
# number of mineral shards left

self.non_spatial_state = [
0,
0,
# 0,
]

self.trim_map = trim_map
self.trim_factor = (22/64, 0.25)
self.reset()

def update(self, obs):
Expand All @@ -68,7 +76,6 @@ def update(self, obs):
def build_map(self, obs):
map_ = self.build_basic_map(obs)
map_ = self.reduce_map(map_)
map_ = self.normalize_map(map_)

return map_

Expand All @@ -89,8 +96,7 @@ def build_basic_map(self, obs):
return map_

def normalize_map(self, map_):
# map = (map_ - map_.min()) / (map_.max() - map_.min())
return map_
return (map_ - map_.min()) / (map_.max() - map_.min())

def normalize_non_spatial_list(self):
for i in range(len(self.non_spatial_state)):
Expand All @@ -103,16 +109,17 @@ def normalize_non_spatial_list(self):
def normalize_value(self, value, max_, min_=0):
return (value - min_) / (max_ - min_)

# TODO: Remove magic numbers
@property
def dimension(self):
if self.method == CollectablesMethod.STATE_MAP:
if self.trim_map:
a = int(22 / self.map_reduction_factor)
b = int(16 / self.map_reduction_factor)
a = int(self.trim_factor[0] *
self.map_size[0] / self.map_reduction_factor)
b = int(self.trim_factor[1] *
self.map_size[1] / self.map_reduction_factor)
else:
a = int(64 / self.map_reduction_factor)
b = int(64 / self.map_reduction_factor)
a = int(self.map_size[0] / self.map_reduction_factor)
b = int(self.map_size[1] / self.map_reduction_factor)
return int(a * b)
elif self.method == CollectablesMethod.STATE_NON_SPATIAL:
return len(self.non_spatial_state)
Expand Down Expand Up @@ -151,12 +158,8 @@ def get_closest_mineral_shard_x_y(self, obs):

def build_non_spatial_state(self, obs):
x, y = self.get_closest_mineral_shard_x_y(obs)
# position 0: distance x to closest shard
self.non_spatial_state[0] = int(x)
# position 1: distance y to closest shard
self.non_spatial_state[1] = int(y)
# position 2: number of remaining shards
# self.non_spatial_state[2]=np.count_nonzero(obs.feature_minimap[4]==16)
self.normalize_non_spatial_list()
self.non_spatial_state = np.array(self.non_spatial_state)
return self.non_spatial_state
Expand All @@ -178,13 +181,13 @@ def reset(self):

def trim_matrix(self, matrix, x1, y1, x2, y2):
"""
If you have a 2D numpy array
and you want a submatrix of that array,
you can use this function to extract it.
You just need to tell this function
what are the top-left and bottom-right
corners of this submatrix, by setting
x1, y1 and x2, y2.
This function extracts a submatrix of a
2D numpy array.
The arguments x1, y1 and x2, y2 are the
top-left and bottom-right corners of
this submatrix, respectively.
For example: some maps of StarCraft II
have parts that are not walkable, this
happens specially in PySC2 mini-games
Expand All @@ -199,25 +202,23 @@ def trim_matrix(self, matrix, x1, y1, x2, y2):
matrix = np.delete(matrix, np.s_[y2 - y1 + 1::1], 0)
return matrix

def lower_featuremap_resolution(self, map, rf): # rf = reduction_factor
def lower_featuremap_resolution(self, map, reduction_factor):
"""
Reduces a matrix "resolution" by a reduction factor. If we have a 64x64 matrix
and rf=4 the map will be reduced to 16x16 in which every new element of the
matrix is an average from 4x4=16 elements from the original matrix
"""
if rf == 1:
if reduction_factor == 1:
return map

N, M = map.shape
N = N // rf
M = M // rf
N = N // reduction_factor
M = M // reduction_factor

reduced_map = np.empty((N, M))
for i in range(N):
for j in range(M):
# reduction_array = map[rf*i:rf*i+rf, rf*j:rf*j+rf].flatten()
# reduced_map[i,j] = Counter(reduction_array).most_common(1)[0][0]

rf = reduction_factor
reduced_map[i, j] = ((map[rf * i:rf * i + rf, rf * j:rf * j + rf].sum())
/ (rf * rf))

Expand Down
5 changes: 3 additions & 2 deletions urnai/trainers/stablebaselines3_trainer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os

import wandb

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.type_aliases import MaybeCallback

import wandb
from urnai.environments.stablebaselines3.custom_env import CustomEnv


Expand Down Expand Up @@ -37,7 +38,7 @@ def load_most_recent_model(self, model_path):
raise Exception(f"No models found in {model_path}")
else:
def only_digits(filename):
return ''.join(c for c in filename if c.isdigit())
return int(''.join(c for c in filename if c.isdigit()))
save_files.sort(reverse=True, key=only_digits)
self.load_model(f"{model_path}/{save_files[0]}")

Expand Down

0 comments on commit e274fdc

Please sign in to comment.