Skip to content

Commit

Permalink
Adaptible PIPNs
Browse files Browse the repository at this point in the history
  • Loading branch information
ka-bear authored and ChrisRackauckas committed Sep 3, 2024
1 parent ce182c9 commit 06f41b5
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,41 +650,39 @@ function PIPN(chain, strategy = GridTraining(0.1);
logger = nothing,
log_options = LogOptions(),
iteration = nothing,
shared_mlp1_sizes = [64, 64],
shared_mlp2_sizes = [128, 1024],
after_pool_mlp_sizes = [512, 256, 128],
kwargs...)

input_dim = chain[1].in_dims[1]
output_dim = chain[end].out_dims[1]
input_dim = chain[1].in_dims[1]
output_dim = chain[end].out_dims[1]

println("hi");
# Create shared_mlp1
shared_mlp1_layers = [Lux.Dense(i == 1 ? input_dim : shared_mlp1_sizes[i-1] => shared_mlp1_sizes[i], tanh) for i in 1:length(shared_mlp1_sizes)]
shared_mlp1 = Lux.Chain(shared_mlp1_layers...)

shared_mlp1 = Lux.Chain(
Lux.Dense(input_dim => 64, tanh),
Lux.Dense(64 => 64, tanh)
)
# Create shared_mlp2
shared_mlp2_layers = [Lux.Dense(i == 1 ? shared_mlp1_sizes[end] : shared_mlp2_sizes[i-1] => shared_mlp2_sizes[i], tanh) for i in 1:length(shared_mlp2_sizes)]
shared_mlp2 = Lux.Chain(shared_mlp2_layers...)

shared_mlp2 = Lux.Chain(
Lux.Dense(64 => 128, tanh),
Lux.Dense(128 => 1024, tanh)
)
# Create after_pool_mlp
after_pool_input_size = 2 * shared_mlp2_sizes[end] # Doubled due to concatenation
after_pool_mlp_layers = [Lux.Dense(i == 1 ? after_pool_input_size : after_pool_mlp_sizes[i-1] => after_pool_mlp_sizes[i], tanh) for i in 1:length(after_pool_mlp_sizes)]
after_pool_mlp = Lux.Chain(after_pool_mlp_layers...)

after_pool_mlp = Lux.Chain(
Lux.Dense(2048 => 512, tanh), # Changed from 1024 to 2048
Lux.Dense(512 => 256, tanh),
Lux.Dense(256 => 128, tanh)
)
final_layer = Lux.Dense(after_pool_mlp_sizes[end] => output_dim)

final_layer = Lux.Dense(128 => output_dim)

if iteration isa Vector{Int64}
self_increment = false
else
iteration = [1]
self_increment = true
end
if iteration isa Vector{Int64}
self_increment = false
else
iteration = [1]
self_increment = true
end

PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
strategy, init_params, param_estim, additional_loss, adaptive_loss,
logger, log_options, iteration, self_increment, kwargs)
PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
strategy, init_params, param_estim, additional_loss, adaptive_loss,
logger, log_options, iteration, self_increment, kwargs)
end

function (model::PIPN)(x, ps, st::NamedTuple)
Expand Down

0 comments on commit 06f41b5

Please sign in to comment.