@@ -650,41 +650,39 @@ function PIPN(chain, strategy = GridTraining(0.1);
650650 logger = nothing ,
651651 log_options = LogOptions(),
652652 iteration = nothing ,
653+ shared_mlp1_sizes = [64 , 64 ],
654+ shared_mlp2_sizes = [128 , 1024 ],
655+ after_pool_mlp_sizes = [512 , 256 , 128 ],
653656 kwargs... )
654657
655- input_dim = chain[1 ]. in_dims[1 ]
656- output_dim = chain[end ]. out_dims[1 ]
658+ input_dim = chain[1 ]. in_dims[1 ]
659+ output_dim = chain[end ]. out_dims[1 ]
657660
658- println(" hi" );
661+ # Create shared_mlp1
662+ shared_mlp1_layers = [Lux. Dense(i == 1 ? input_dim : shared_mlp1_sizes[i- 1 ] => shared_mlp1_sizes[i], tanh) for i in 1 : length(shared_mlp1_sizes)]
663+ shared_mlp1 = Lux. Chain(shared_mlp1_layers... )
659664
660- shared_mlp1 = Lux. Chain(
661- Lux. Dense(input_dim => 64 , tanh),
662- Lux. Dense(64 => 64 , tanh)
663- )
665+ # Create shared_mlp2
666+ shared_mlp2_layers = [Lux. Dense(i == 1 ? shared_mlp1_sizes[end ] : shared_mlp2_sizes[i- 1 ] => shared_mlp2_sizes[i], tanh) for i in 1 : length(shared_mlp2_sizes)]
667+ shared_mlp2 = Lux. Chain(shared_mlp2_layers... )
664668
665- shared_mlp2 = Lux . Chain(
666- Lux . Dense( 64 => 128 , tanh),
667- Lux. Dense(128 => 1024 , tanh)
668- )
669+ # Create after_pool_mlp
670+ after_pool_input_size = 2 * shared_mlp2_sizes[ end ] # Doubled due to concatenation
671+ after_pool_mlp_layers = [ Lux. Dense(i == 1 ? after_pool_input_size : after_pool_mlp_sizes[i - 1 ] => after_pool_mlp_sizes[i] , tanh) for i in 1 : length(after_pool_mlp_sizes)]
672+ after_pool_mlp = Lux . Chain(after_pool_mlp_layers ... )
669673
670- after_pool_mlp = Lux. Chain(
671- Lux. Dense(2048 => 512 , tanh), # Changed from 1024 to 2048
672- Lux. Dense(512 => 256 , tanh),
673- Lux. Dense(256 => 128 , tanh)
674- )
674+ final_layer = Lux. Dense(after_pool_mlp_sizes[end ] => output_dim)
675675
676- final_layer = Lux. Dense(128 => output_dim)
677-
678- if iteration isa Vector{Int64}
679- self_increment = false
680- else
681- iteration = [1 ]
682- self_increment = true
683- end
676+ if iteration isa Vector{Int64}
677+ self_increment = false
678+ else
679+ iteration = [1 ]
680+ self_increment = true
681+ end
684682
685- PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
686- strategy, init_params, param_estim, additional_loss, adaptive_loss,
687- logger, log_options, iteration, self_increment, kwargs)
683+ PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
684+ strategy, init_params, param_estim, additional_loss, adaptive_loss,
685+ logger, log_options, iteration, self_increment, kwargs)
688686end
689687
690688function (model:: PIPN )(x, ps, st:: NamedTuple )
0 commit comments