Commit 365a072

committed
Update models.py:
+ Increase the default hidden-layer size of AttentionMLP
+ Improve the model summary
+ Raise an error when the hyperparameters file is not found
1 parent cc7df54 commit 365a072

1 file changed: +80 -15 lines

scripts/models.py

@@ -27,9 +27,9 @@ def __init__(self, observation_space, features_dim=64):
 
         # MLP layers
         self.mlp = nn.Sequential(
-            nn.Linear(self.num_features, 64),
+            nn.Linear(self.num_features, 128),
             nn.ReLU(),
-            nn.Linear(64, features_dim),
+            nn.Linear(128, features_dim),
             nn.ReLU()
         )
 

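This first hunk only widens AttentionMLP's hidden layer from 64 to 128 units; features_dim stays whatever the caller passes. A minimal sketch of the new head in isolation, assuming PyTorch (num_features = 12 is a hypothetical flattened observation size, not a value from this commit):

import torch
import torch.nn as nn

num_features = 12   # hypothetical flattened observation size
features_dim = 64   # the default from the signature above

mlp = nn.Sequential(
    nn.Linear(num_features, 128),   # was nn.Linear(num_features, 64)
    nn.ReLU(),
    nn.Linear(128, features_dim),   # was nn.Linear(64, features_dim)
    nn.ReLU(),
)

print(mlp(torch.zeros(1, num_features)).shape)  # torch.Size([1, 64])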
@@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs):
            *args,
            **kwargs,
            features_extractor_class=AttentionMLP,
-            features_extractor_kwargs={"features_dim": 64},
+            features_extractor_kwargs={"features_dim": 128},
        )
 # ==========================================================================================
 class DartFeatureExtractor(BaseFeaturesExtractor):
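The default passed through features_extractor_kwargs moves to 128 to match. Assuming this codebase builds on stable-baselines3 (where BaseFeaturesExtractor and these kwargs come from), the same wiring without a custom policy class would look roughly like this; the env id is a hypothetical placeholder and AttentionMLP must match that env's observation layout:

from stable_baselines3 import PPO
from scripts.models import AttentionMLP  # assumed import path, per the file shown above

model = PPO(
    "MlpPolicy",
    "CartPole-v1",  # hypothetical env id
    policy_kwargs={
        "features_extractor_class": AttentionMLP,
        "features_extractor_kwargs": {"features_dim": 128},  # was 64
    },
)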
@@ -412,15 +412,6 @@ def __init__(self, *args, dropout_prob=0.5, **kwargs):
         self.mlp_extractor = CustomMLPExtractor(self.features_dim, net_arch, dropout_prob)
 
 # ==========================================================================================
-# Warning: input size is hardcoded for now
-def print_model_info(model):
-
-    if args.nn == 'CnnPolicy':
-        summary(model.policy, (4, 84, 84))
-    elif args.nn == 'MlpPolicy':
-        summary(model.policy, (1, 6))
-
-    return total_params
 
 def get_num_parameters(model):
     total_params = sum(p.numel() for p in model.policy.parameters() if p.requires_grad)
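The removed print_model_info hardcoded input shapes and referenced args and total_params that were not in its scope, so dropping it also removes a latent bug. Parameter counting now goes through get_num_parameters, which works for any policy shape. A usage sketch, assuming the function returns the total it computes (its return statement is outside this diff):

n_params = get_num_parameters(model)   # model: any SB3 algorithm instance
print(f"Trainable parameters: {n_params:,}")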
@@ -443,14 +434,85 @@ def load_hyperparameters(json_file):
 
     return hyperparams
 
+def print_model_summary(env, player_model, model):
+    # Print model summary
+    if player_model == '':  # Only for newly created models
+        print("\nModel Architecture Summary:")
+
+        # Get the policy network
+        policy = model.policy
+
+        # Print basic info
+        print(f"\nPolicy Network Type: {policy.__class__.__name__}")
+        print(f"Observation Space: {env.observation_space}")
+        print(f"Action Space: {env.action_space}")
+
+        # Print feature extractor info if available
+        if hasattr(policy, 'features_extractor'):
+            print("\nFeature Extractor Architecture:")
+            fe = policy.features_extractor
+            print(f"Type: {fe.__class__.__name__}")
+
+            # Print layers for common extractor types
+            if isinstance(fe, (CustomCNN, CustomImpalaFeatureExtractor, ViTFeatureExtractor)):
+                for name, layer in fe.named_children():
+                    print(f" {name}: {layer.__class__.__name__}")
+                    if hasattr(layer, 'weight'):
+                        print(f" Weight shape: {layer.weight.shape}")
+
+            # Special handling for CNN-based extractors
+            if hasattr(fe, 'cnn'):
+                print("\nCNN Layers:")
+                for name, layer in fe.cnn.named_children():
+                    print(f" {name}: {layer.__class__.__name__}")
+                    if isinstance(layer, nn.Conv2d):
+                        print(f" in_channels: {layer.in_channels}")
+                        print(f" out_channels: {layer.out_channels}")
+                        print(f" kernel_size: {layer.kernel_size}")
+
+        # Print MLP extractor info if available
+        if hasattr(policy, 'mlp_extractor'):
+            print("\nMLP Extractor Architecture:")
+            mlp = policy.mlp_extractor
+            print(f"Type: {mlp.__class__.__name__}")
+
+            if hasattr(mlp, 'shared_net'):
+                print("\nShared Layers:")
+                for i, layer in enumerate(mlp.shared_net):
+                    print(f" Layer {i}: {layer.__class__.__name__}")
+                    if hasattr(layer, 'weight'):
+                        print(f" Weight shape: {layer.weight.shape}")
+
+            if hasattr(mlp, 'policy_net'):
+                print("\nPolicy Head:")
+                for i, layer in enumerate(mlp.policy_net):
+                    print(f" Layer {i}: {layer.__class__.__name__}")
+
+            if hasattr(mlp, 'value_net'):
+                print("\nValue Head:")
+                for i, layer in enumerate(mlp.value_net):
+                    print(f" Layer {i}: {layer.__class__.__name__}")
+
+        # Print action distribution info
+        if hasattr(policy, 'action_dist'):
+            print("\nAction Distribution:")
+            print(f"Type: {policy.action_dist.__class__.__name__}")
+
+        # Print total parameters
+        total_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+        print(f"\nTotal Trainable Parameters: {total_params:,}")
+
 def init_model(output_path, player_model, player_alg, args, env, logger):
     policy_kwargs = None
     nn_type = args.nn
     size = args.nnsize
 
-    # Load hyperparameters from JSON if file provided
-    if args.hyperparams and os.path.isfile(args.hyperparams):
-        hyperparams = load_hyperparameters(args.hyperparams)
+    # Load hyperparameters from JSON
+    if args.hyperparams:
+        if os.path.isfile(args.hyperparams):
+            hyperparams = load_hyperparameters(args.hyperparams)
+        else:
+            raise FileNotFoundError(f"Hyperparameters file not found: {args.hyperparams}")
     else:
         hyperparams = {}
 

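The hyperparameters change makes a bad --hyperparams path fail fast instead of silently training with defaults, as the old `args.hyperparams and os.path.isfile(...)` check did. load_hyperparameters is not shown in full in this diff; the sketch below assumes it parses the file with json.load and returns a dict, and the file name and keys are hypothetical:

import json
import os

def load_hyperparameters(json_file):
    # assumed body: parse the JSON file into a dict of keyword arguments
    with open(json_file) as f:
        return json.load(f)

path = "configs/ppo.json"  # hypothetical value of args.hyperparams
if not os.path.isfile(path):
    raise FileNotFoundError(f"Hyperparameters file not found: {path}")
hyperparams = load_hyperparameters(path)  # e.g. {"learning_rate": 3e-4, "n_steps": 2048}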
@@ -519,5 +581,8 @@ def init_model(output_path, player_model, player_alg, args, env, logger):
     else:
         model = A2C.load(os.path.expanduser(player_model), env=env, verbose=1, tensorboard_log=output_path)
 
+
+    print_model_summary(env, player_model, model)
+
     model.set_logger(logger)
     return model
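print_model_summary is called unconditionally at the end of init_model, but it prints only for newly created models (player_model == ''); loading an existing model stays quiet. A hypothetical call, with placeholder values for everything except the signature shown in the diff:

model = init_model(
    output_path="runs/exp1",  # hypothetical output directory
    player_model="",          # '' => fresh model, summary is printed
    player_alg="ppo",         # hypothetical algorithm tag
    args=args,                # parsed CLI args (nn, nnsize, hyperparams, ...)
    env=env,                  # the training environment, assumed in scope
    logger=logger,            # SB3 logger, assumed in scope
)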
