Skip to content

Commit

Permalink
Fix getfield of const (EnzymeAD#1572)
Browse files Browse the repository at this point in the history
* Fix getfield of const

* fix

* add test

* fixup
  • Loading branch information
wsmoses authored Jun 27, 2024
1 parent c4068fc commit cdb4df3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}(
end

const nofreefns = Set{String}((
"ijl_field_index", "jl_field_index",
"ijl_specializations_get_linfo", "jl_specializations_get_linfo",
"ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds",
"ijl_gc_get_total_bytes", "jl_gc_get_total_bytes",
Expand Down Expand Up @@ -183,6 +184,7 @@ const nofreefns = Set{String}((
))

const inactivefns = Set{String}((
"ijl_field_index", "jl_field_index",
"ijl_specializations_get_linfo", "jl_specializations_get_linfo",
"ijl_gf_invoke_lookup_worlds", "jl_gf_invoke_lookup_worlds",
"ijl_gc_get_total_bytes", "jl_gc_get_total_bytes",
Expand Down Expand Up @@ -3258,7 +3260,7 @@ function annotate!(mod, mode)
end
end

for fname in ("jl_excstack_state","ijl_excstack_state")
for fname in ("jl_excstack_state","ijl_excstack_state", "ijl_field_index", "jl_field_index")
if haskey(fns, fname)
fn = fns[fname]
if LLVM.version().major <= 15
Expand Down
34 changes: 27 additions & 7 deletions src/rules/typeunstablerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco
RT = Core.Typeof(res)

actreg = active_reg_nothrow(RT, Val(nothing))
if actreg == ActiveState
if actreg == ActiveState || (isconst && actreg == MixedState)
if length(dptrs) == 0
return Ref{RT}(make_zero(res))
else
Expand All @@ -626,6 +626,16 @@ function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isco
end)...))
return fval
end
elseif isconst
if length(dptrs) == 0
return make_zero(res)
else
fval = NT((res, (ntuple(Val(length(dptrs))) do i
Base.@_inline_meta
make_zero(res)
end)...))
return fval
end
else
if length(dptrs) == 0
return res
Expand All @@ -648,7 +658,7 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc
end
RT = Core.Typeof(res)
actreg = active_reg_nothrow(RT, Val(nothing))
if actreg == ActiveState
if actreg == ActiveState || (isconst && actreg == MixedState)
if length(dptrs) == 0
return Ref{RT}(make_zero(res))::Any
else
Expand All @@ -659,7 +669,7 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc
end
elseif actreg == MixedState
if length(dptrs) == 0
return Ref{RT}(res)::Any
return Ref{RT}(res)
else
fval = NT((Ref{RT}(res), (ntuple(Val(length(dptrs))) do i
Base.@_inline_meta
Expand All @@ -668,6 +678,16 @@ function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isc
end)...))
return fval
end
elseif isconst
if length(dptrs) == 0
return make_zero(res)::Any
else
fval = NT((res, (ntuple(Val(length(dptrs))) do i
Base.@_inline_meta
make_zero(res)
end)...))
return fval
end
else
if length(dptrs) == 0
return res::Any
Expand Down Expand Up @@ -858,7 +878,7 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta
sym = emit_apply_type!(B, Base.Val, [sym])
push!(vals, sym)

push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig))))
push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[2]))))

for v in inps[2:end]
push!(vals, v)
Expand Down Expand Up @@ -944,7 +964,7 @@ function common_jl_getfield_rev(offset, B, orig, gutils, tape)
sym = emit_apply_type!(B, Base.Val, [sym])
push!(vals, sym)

push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig))))
push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[2]))))

for v in inps[2:end]
push!(vals, v)
Expand Down Expand Up @@ -1037,7 +1057,7 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
sym = emit_apply_type!(B, Base.Val, [sym])
push!(vals, sym)

push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig))))
push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[1]))))

for v in inps[2:end]
push!(vals, v)
Expand Down Expand Up @@ -1125,7 +1145,7 @@ function jl_nthfield_rev(B, orig, gutils, tape)
sym = emit_apply_type!(B, Base.Val, [sym])
push!(vals, sym)

push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, orig))))
push!(vals, unsafe_to_llvm(Val(is_constant_value(gutils, ops[1]))))

for v in inps[2:end]
push!(vals, v)
Expand Down
28 changes: 28 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2327,7 +2327,35 @@ end

adres = Enzyme.autodiff(Reverse, sf_for3, Duplicated(mt3, dmt3), Const(:x), Const(:x), Active(3.1))
@test adres[1][4] 5050.0

mutable struct MyTypeM
x::Float64
y
end

@noinline function unstable_mul(x, y)
return (x*y)::Float64
end

function gf3(y, v::MyTypeM, fld::Symbol)
x = getfield(v, fld)
unstable_mul(x, y)
end

function gf3(y, v::MyTypeM, fld::Integer)
x = getfield_idx(v, fld)
unstable_mul(x, y)
end

mx = MyTypeM(3.0, 1)
res = Enzyme.autodiff(Reverse, gf3, Active, Active(2.7), Const(mx), Const(:x))
@test mx.x 3.0
@test res[1][1] 3.0

mx = MyTypeM(3.0, 1)
res = Enzyme.autodiff(Reverse, gf3, Active, Active(2.7), Const(mx), Const(0))
@test mx.x 3.0
@test res[1][1] 3.0
end


Expand Down

0 comments on commit cdb4df3

Please sign in to comment.