Skip to content

Commit

Permalink
refactor: clean up computation of tstops_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Aug 28, 2023
1 parent c57c468 commit d5836b4
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,13 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
weights = strategy.weights ./ sum(strategy.weights)

N = length(weights)
samples = strategy.samples
points = strategy.points

difference = (maxT - minT) / N

data = Float64[]
for (index, item) in enumerate(weights)
temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+
temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
((index - 1) * difference)
data = append!(data, temp_data)
end
Expand Down Expand Up @@ -450,19 +450,17 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
if !(tstops isa Nothing)
num_tstops_points = length(tstops)
tstops_loss_func = evaluate_tstops_loss(phi, f, autodiff, tstops, p, batch)
total_tstops_loss = tstops_loss_func(θ, phi) * num_tstops_points
tstops_loss = tstops_loss_func(θ, phi)
if strategy isa GridTraining
num_original_points = length(tspan[1]:(strategy.dx):tspan[2])
elseif strategy isa WeightedIntervalTraining
num_original_points = strategy.samples
elseif strategy isa StochasticTraining
elseif strategy isa Union{WeightedIntervalTraining, StochasticTraining}
num_original_points = strategy.points
else
L2_loss = L2_loss + tstops_loss_func(θ, phi)
return L2_loss
return L2_loss + tstops_loss
end

total_original_loss = L2_loss * num_original_points
total_tstops_loss = tstops_loss * num_original_points
total_points = num_original_points + num_tstops_points
L2_loss = (total_original_loss + total_tstops_loss) / total_points

Expand Down

0 comments on commit d5836b4

Please sign in to comment.