Skip to content

Commit

Permalink
Some work
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed May 16, 2024
1 parent 6c7b345 commit ccdf6ec
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ function DynamicDerivedRule(interp::TapirInterpreter, safety_on::Bool)
end

function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N}
sig = signature_from_values(tuple_map(primal, args))
sig = Tuple{map(_typeof, args)...}
is_primitive(context_type(dynamic_rule.interp), sig) && return rrule!!(args...)
rule = get(dynamic_rule.cache, sig, nothing)
if rule === nothing
Expand Down
9 changes: 1 addition & 8 deletions src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ end

# A function with the same semantics as `Core._apply_iterate`, but which is differentiable.
function _apply_iterate_equivalent(itr, f::F, args::Vararg{Any, N}) where {F, N}
vec_args = reduce(vcat, tuple_map(collect, args))
vec_args = reduce(vcat, map(collect, args))
tuple_args = __vec_to_tuple(vec_args)
return tuple_splat(f, tuple_args)
end
Expand Down Expand Up @@ -418,13 +418,6 @@ function build_rrule(
return ApplyIterateRule(build_rrule(interp, new_sig; kwargs...))
end

function rule_type(
interp::TapirInterpreter{C}, sig::Type{<:Tuple{typeof(Core._apply_iterate), Vararg}}
) where {C}
new_sig = Tuple{typeof(_apply_iterate_equivalent), sig.parameters[2:end]...}
return ApplyIterateRule{rule_type(interp, new_sig)}
end

# Core._apply_pure
# Core._call_in_world
# Core._call_in_world_total
Expand Down
7 changes: 0 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@ _typeof(x) = Base._stable_typeof(x)
_typeof(x::Tuple) = Tuple{map(_typeof, x)...}
_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names, _typeof(Tuple(x))}

"""
signature_from_values(x::Tuple)
"""
signature_from_values(x::Tuple) = Tuple{map(Base._stable_typeof, x)...}

"""
tuple_map(f::F, x::Tuple) where {F}
Expand Down
97 changes: 53 additions & 44 deletions test/integration_testing/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ end
#w ~ arraydist([dist[doc[i]] for i in 1:length(doc)])
end

function make_large_model(num_tildes::Int)
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!
mainbody = last(expr.args)
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes])
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
return invokelatest(f)
end

function build_turing_problem(rng, model, example=nothing)
ctx = Turing.DefaultContext()
vi = example === nothing ? Turing.SimpleVarInfo(model) : Turing.SimpleVarInfo(example)
Expand All @@ -89,14 +97,15 @@ end
interp = Tapir.PInterp()
@testset "$(typeof(model))" for (interface_only, name, model, ex) in vcat(
Any[
(false, "simple_model", simple_model(), nothing),
(false, "demo", demo(), nothing),
(
false,
"broadcast_demo",
broadcast_demo(rand(LogNormal(1.5, 0.5), 1_000)),
nothing,
),
# (false, "simple_model", simple_model(), nothing),
# (false, "demo", demo(), nothing),
# (
# false,
# "broadcast_demo",
# broadcast_demo(rand(LogNormal(1.5, 0.5), 1_000)),
# nothing,
# ),
(false, "large model", make_large_model(33), nothing),
# (
# false,
# "CollapsedLDA",
Expand All @@ -105,50 +114,50 @@ end
# ),
# ), doesn't currently work with SimpleVarInfo
],
Any[
(false, "demo_$n", m, Turing.DynamicPPL.TestUtils.rand_prior_true(m)) for
(n, m) in enumerate(Turing.DynamicPPL.TestUtils.DEMO_MODELS)
],
# Any[
# (false, "demo_$n", m, Turing.DynamicPPL.TestUtils.rand_prior_true(m)) for
# (n, m) in enumerate(Turing.DynamicPPL.TestUtils.DEMO_MODELS)
# ],
)
@info name
rng = sr(123)
f, x = build_turing_problem(rng, model, ex)
TestUtils.test_derived_rule(
sr(123456), f, x;
perf_flag=:none, interface_only=true, is_primitive=false, interp
perf_flag=:none, interface_only=true, is_primitive=false, interp, safety_on=true
)

# rule = build_rrule(interp, _typeof((f, x)))
# codualed_args = map(zero_codual, (f, x))
# TestUtils.to_benchmark(rule, codualed_args...)

# primal = @benchmark $f($x)
# gradient = @benchmark(TestUtils.to_benchmark($rule, $codualed_args...))

# println("primal")
# display(primal)
# println()

# println("gradient")
# display(gradient)
# println()

# try
# tape = ReverseDiff.GradientTape(f, x);
# ReverseDiff.gradient!(tape, x);
# result = zeros(size(x));
# ReverseDiff.gradient!(result, tape, x)

# revdiff = @benchmark ReverseDiff.gradient!($result, $tape, $x)
# println("ReverseDiff")
# display(revdiff)
# println()
# @show time(revdiff) / time(primal)
# catch
# display("revdiff failed")
# end

# @show time(gradient) / time(primal)
rule = build_rrule(interp, _typeof((f, x)))
codualed_args = map(zero_codual, (f, x))
TestUtils.to_benchmark(rule, codualed_args...)

primal = @benchmark $f($x)
gradient = @benchmark(TestUtils.to_benchmark($rule, $codualed_args...))

println("primal")
display(primal)
println()

println("gradient")
display(gradient)
println()

try
tape = ReverseDiff.GradientTape(f, x);
ReverseDiff.gradient!(tape, x);
result = zeros(size(x));
ReverseDiff.gradient!(result, tape, x)

revdiff = @benchmark ReverseDiff.gradient!($result, $tape, $x)
println("ReverseDiff")
display(revdiff)
println()
@show time(revdiff) / time(primal)
catch
display("revdiff failed")
end

@show time(gradient) / time(primal)

# @profview run_many_times(10_000, TestUtils.to_benchmark, rule, codualed_args...)

Expand Down

0 comments on commit ccdf6ec

Please sign in to comment.