Control Input for NeuralODE #844

Closed
MortezaBabazadehShareh opened this issue Jul 21, 2023 · 3 comments
Comments

@MortezaBabazadehShareh

MortezaBabazadehShareh commented Jul 21, 2023

How can we define a control input for a Neural ODE? The control input is an array with a specific value for each time step.

@MortezaBabazadehShareh MortezaBabazadehShareh changed the title Control Function in NeuralODE Control Input for NeuralODE Jul 21, 2023
@ChrisRackauckas
Member

Just use DataInterpolations to make the interpolation.
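
As a minimal sketch of what this could look like (the time grid and control values below are made-up placeholders, not data from this issue), `LinearInterpolation` from DataInterpolations takes the sampled values first and the time points second, and the resulting object can be called like a function of time:

```julia
using DataInterpolations

# Hypothetical data: one control value per time step (placeholders, not from the thread)
tsteps = collect(0.0:1.0:4.0)
u = [0.0, 10.0, 20.0, 15.0, 5.0]

# Values first, time points second
Control = LinearInterpolation(u, tsteps)

Control(1.0)   # 10.0, the value at a sample point
Control(1.5)   # 15.0, halfway between the samples at t = 1 and t = 2
```

Because `Control(t)` is defined for any `t` inside the sampled range, an adaptive ODE solver can evaluate it at whatever intermediate times it chooses.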

@MortezaBabazadehShareh
Author

After creating this interpolation:

Control=LinearInterpolation(u,tsteps);

How should I use the Control as an external control signal in the following code?

data_dim=size(ode_data[:, 1])[1]
dudt2 = Lux.Chain(Lux.Dense(data_dim, 64, relu),
    Lux.Dense(64, 32, relu),
    Lux.Dense(32, data_dim),
    )

rng = Random.default_rng()
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

Control is the daily vaccination rate, and the primary variable in ode_data is the number of infections.

@Spinachboul

@ChrisRackauckas and @MortezaBabazadehShareh

This is just concept code that could be taken into consideration.

The snippet above can be modified slightly so that the control input is passed to the neural network as an additional input. This version builds the network layers with Flux instead of Lux.

Here is the code snippet:

# Include the required packages
using Flux, DataInterpolations

# Assumes Control = LinearInterpolation(u, tsteps) is already defined, so that Control(t) returns the control value at time t

# Define the Neural ODE model
data_dim = size(ode_data, 1)  # number of state variables (rows of ode_data)
dudt2 = Flux.Chain(
    Flux.Dense(data_dim + 1 => 64, relu),  # +1 for the control input
    Flux.Dense(64 => 32, relu),
    Flux.Dense(32 => data_dim)
)

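# Flatten the parameters into a vector p; re(p) rebuilds the network (no rng or st is needed with Flux)
p, re = Flux.destructure(dudt2)

# Function to calculate ODE with control input
function dudt_with_control(u, p, t)
    control_input = Control(t)  # control value at time t from the interpolation
    u_with_control = vcat(u, control_input)
    return re(p)(u_with_control)  # evaluate the rebuilt network on [u; control]
end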

# Then the further steps include training and loss calculations

Here, dudt_with_control takes the state u, the parameters p, and the time t, evaluates the control signal at that time via Control(t), and appends it to the state before passing both to the network. The remaining steps (setting up the ODE problem, the loss, and the training loop) would need to use this function as the right-hand side in place of the bare network.
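
Since the original question uses Lux, the same idea can be sketched there too. This is only an illustration under stated assumptions: the initial state, time grid, and control values are hypothetical placeholders, and it uses a plain `ODEProblem` with a hand-written right-hand side rather than the `NeuralODE` wrapper, because the right-hand side needs access to `t` in order to evaluate `Control(t)`.

```julia
using Lux, Random, ComponentArrays, OrdinaryDiffEq, DataInterpolations

# Hypothetical stand-ins for the thread's data (placeholders, not from the issue)
tspan = (0.0f0, 4.0f0)
tsteps = collect(0.0f0:1.0f0:4.0f0)
Control = LinearInterpolation(Float32[0.0, 0.1, 0.2, 0.15, 0.05], tsteps)
u0 = Float32[1.0, 0.5]
data_dim = length(u0)

# The network takes the state plus one control value as input
model = Lux.Chain(Lux.Dense(data_dim + 1 => 64, relu),
    Lux.Dense(64 => 32, relu),
    Lux.Dense(32 => data_dim))

rng = Random.default_rng()
p, st = Lux.setup(rng, model)
p = ComponentArray(p)  # flat parameter container usable by solvers and optimizers

# Right-hand side: append Control(t) to the state before calling the network
function dudt_with_control(u, p, t)
    du, _ = model(vcat(u, Control(t)), p, st)
    return du
end

prob = ODEProblem(dudt_with_control, u0, tspan, p)
sol = solve(prob, Tsit5(), saveat = tsteps)
```

A loss for training would then compare `Array(sol)` against `ode_data`, with `p` as the quantity being optimized.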
