@@ -27,9 +27,9 @@ def __init__(self, observation_space, features_dim=64):
 
         # MLP layers
        self.mlp = nn.Sequential(
-            nn.Linear(self.num_features, 64),
+            nn.Linear(self.num_features, 128),
             nn.ReLU(),
-            nn.Linear(64, features_dim),
+            nn.Linear(128, features_dim),
             nn.ReLU()
        )
 
@@ -59,7 +59,7 @@ def __init__(self, *args, **kwargs):
             *args,
             **kwargs,
             features_extractor_class=AttentionMLP,
-            features_extractor_kwargs={"features_dim": 64},
+            features_extractor_kwargs={"features_dim": 128},
         )
 # ==========================================================================================
 class DartFeatureExtractor(BaseFeaturesExtractor):
@@ -412,15 +412,6 @@ def __init__(self, *args, dropout_prob=0.5, **kwargs):
         self.mlp_extractor = CustomMLPExtractor(self.features_dim, net_arch, dropout_prob)
 
 # ==========================================================================================
-# Warning: input size is hardcoded for now
-def print_model_info(model):
-
-    if args.nn == 'CnnPolicy':
-        summary(model.policy, (4, 84, 84))
-    elif args.nn == 'MlpPolicy':
-        summary(model.policy, (1, 6))
-
-    return total_params
 
 def get_num_parameters(model):
     total_params = sum(p.numel() for p in model.policy.parameters() if p.requires_grad)
@@ -443,14 +434,85 @@ def load_hyperparameters(json_file):
 
     return hyperparams
 
+def print_model_summary(env, player_model, model):
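+    """Print a human-readable architecture summary (feature extractor, MLP heads,
+    action distribution, parameter count); runs only for newly created models."""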
+    # Print model summary
+    if player_model == '':  # Only for newly created models
+        print("\nModel Architecture Summary:")
+
+        # Get the policy network
+        policy = model.policy
+
+        # Print basic info
+        print(f"\nPolicy Network Type: {policy.__class__.__name__}")
+        print(f"Observation Space: {env.observation_space}")
+        print(f"Action Space: {env.action_space}")
+
+        # Print feature extractor info if available
+        if hasattr(policy, 'features_extractor'):
+            print("\nFeature Extractor Architecture:")
+            fe = policy.features_extractor
+            print(f"Type: {fe.__class__.__name__}")
+
+            # Print layers for common extractor types
+            if isinstance(fe, (CustomCNN, CustomImpalaFeatureExtractor, ViTFeatureExtractor)):
+                for name, layer in fe.named_children():
+                    print(f"  {name}: {layer.__class__.__name__}")
+                    if hasattr(layer, 'weight'):
+                        print(f"    Weight shape: {layer.weight.shape}")
+
+            # Special handling for CNN-based extractors
+            if hasattr(fe, 'cnn'):
+                print("\nCNN Layers:")
+                for name, layer in fe.cnn.named_children():
+                    print(f"  {name}: {layer.__class__.__name__}")
+                    if isinstance(layer, nn.Conv2d):
+                        print(f"    in_channels: {layer.in_channels}")
+                        print(f"    out_channels: {layer.out_channels}")
+                        print(f"    kernel_size: {layer.kernel_size}")
+
+        # Print MLP extractor info if available
+        if hasattr(policy, 'mlp_extractor'):
+            print("\nMLP Extractor Architecture:")
+            mlp = policy.mlp_extractor
+            print(f"Type: {mlp.__class__.__name__}")
+
+            if hasattr(mlp, 'shared_net'):
+                print("\nShared Layers:")
+                for i, layer in enumerate(mlp.shared_net):
+                    print(f"  Layer {i}: {layer.__class__.__name__}")
+                    if hasattr(layer, 'weight'):
+                        print(f"    Weight shape: {layer.weight.shape}")
+
+            if hasattr(mlp, 'policy_net'):
+                print("\nPolicy Head:")
+                for i, layer in enumerate(mlp.policy_net):
+                    print(f"  Layer {i}: {layer.__class__.__name__}")
+
+            if hasattr(mlp, 'value_net'):
+                print("\nValue Head:")
+                for i, layer in enumerate(mlp.value_net):
+                    print(f"  Layer {i}: {layer.__class__.__name__}")
+
+        # Print action distribution info
+        if hasattr(policy, 'action_dist'):
+            print("\nAction Distribution:")
+            print(f"Type: {policy.action_dist.__class__.__name__}")
+
+        # Print total parameters
+        total_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+        print(f"\nTotal Trainable Parameters: {total_params:,}")
+
 def init_model(output_path, player_model, player_alg, args, env, logger):
     policy_kwargs = None
     nn_type = args.nn
     size = args.nnsize
 
-    # Load hyperparameters from JSON if file provided
-    if args.hyperparams and os.path.isfile(args.hyperparams):
-        hyperparams = load_hyperparameters(args.hyperparams)
+    # Load hyperparameters from JSON
+    if args.hyperparams:
+        if os.path.isfile(args.hyperparams):
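+            # Assumed file format (illustration only): a flat JSON dict of
+            # algorithm kwargs, e.g. {"learning_rate": 0.0003, "n_steps": 2048}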
+            hyperparams = load_hyperparameters(args.hyperparams)
+        else:
+            raise FileNotFoundError(f"Hyperparameters file not found: {args.hyperparams}")
     else:
         hyperparams = {}
 
@@ -519,5 +581,8 @@ def init_model(output_path, player_model, player_alg, args, env, logger):
     else:
         model = A2C.load(os.path.expanduser(player_model), env=env, verbose=1, tensorboard_log=output_path)
 
+
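+    # No-op when an existing model was loaded (player_model != '')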
+    print_model_summary(env, player_model, model)
+
     model.set_logger(logger)
     return model
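
Usage sketch (illustrative; the gymnasium env and PPO below are assumptions standing in
for the project's real entry point, and an empty player_model string is the "newly
created model" convention that print_model_summary checks):

    import gymnasium as gym            # assumed env backend
    from stable_baselines3 import PPO  # assumed algorithm; the diff itself shows A2C

    env = gym.make("CartPole-v1")      # placeholder environment
    model = PPO("MlpPolicy", env)
    print_model_summary(env, player_model="", model=model)  # '' => summary is printed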