Skip to content

Commit

Permalink
Nice union{} error (EnzymeAD#1479)
Browse files Browse the repository at this point in the history
* Nice union{} error

* fixup
  • Loading branch information
wsmoses authored May 28, 2024
1 parent e362c36 commit 5609c7e
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
end

if A <: Active
if !allocatedinline(rt) || rt isa Union
if (!allocatedinline(rt) || rt isa Union) && rt != Union{}
forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI)
res = forward(f, args...)
tape = res[1]
Expand All @@ -244,7 +244,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
throw(ErrorException("Duplicated Returns not yet handled"))
end

if A <: Active && rt <: Complex
if (A <: Active && rt <: Complex) && rt != Union{}
if Holomorphic
seen = IdDict()
seen2 = IdDict()
Expand Down
79 changes: 55 additions & 24 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@ struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT
adjoint::PT
end

struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World} <: AbstractThunk{FA, RT, TT, Width}
adjoint::PT
end

@inline return_type(::AbstractThunk{FA, RT}) where {FA, RT} = RT
@inline return_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = RT

Expand Down Expand Up @@ -5277,7 +5281,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
cf = LLVM.called_operand(tmp)
if isa(cf, LLVM.Function)
nm = LLVM.name(cf)
if nm == "gpu_signal_exception" || nm == "gpu_report_exception"
if nm == "gpu_signal_exception" || nm == "gpu_report_exception" || nm == "ijl_throw" || nm == "jl_throw"
shouldemit = false
break
end
Expand Down Expand Up @@ -5433,6 +5437,9 @@ struct CompileResult{AT, PT}
TapeType::Type
end

@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} =
enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...)

@inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} =
enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...)

Expand Down Expand Up @@ -5536,7 +5543,9 @@ end
end

@inline function default_adjoint(T)
if T <: AbstractFloat
if T == Union{}
return nothing
elseif T <: AbstractFloat
return one(T)
elseif T <: Complex
error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff")
Expand All @@ -5559,7 +5568,7 @@ end

JuliaContext() do ctx
F = eltype(FA)
is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk
is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk
is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk
is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk
needs_tape = CC <: AdjointThunk
Expand All @@ -5569,32 +5578,44 @@ end
argtypes = DataType[argtt.parameters...]
argexprs = Union{Expr, Symbol}[:(args[$i]) for i in 1:N]

if !RawCall
if false && CC <: PrimalErrorThunk
primargs = [quote
convert($(eltype(T)), $(argexprs[i]).val)
end for (i, T) in enumerate(argtypes)]
return quote
fn.val($(primargs...))
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
end
end

if !RawCall && !(CC <: PrimalErrorThunk)
if rettype <: Active
if length(argtypes) + is_adjoint + needs_tape != length(argexprs)
return quote
throw(MethodError($CC($fptr), $args))
throw(MethodError($CC(fptr), $args))
end
end
elseif rettype <: Const
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC($fptr), $args))
throw(MethodError($CC(fptr), $args))
end
end
else
if length(argtypes) + needs_tape != length(argexprs)
return quote
throw(MethodError($CC($fptr), $args))
throw(MethodError($CC(fptr), $args))
end
end
end
end

types = DataType[]

if eltype(rettype) === Union{}
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
if eltype(rettype) === Union{} && false
return quote
error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
end
end
if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType)
rrt = eltype(rettype)
Expand Down Expand Up @@ -5665,7 +5686,9 @@ end
end
continue
end

if CC <: PrimalErrorThunk
continue
end
if T <: Active
if is_adjoint
if width == 1
Expand Down Expand Up @@ -5752,8 +5775,10 @@ end
end
push!(sret_types, NT)
end

@assert i == length(argexprs)+1

if !(CC <: PrimalErrorThunk)
@assert i == length(argexprs)+1
end

# Tape
if CC <: AugmentedForwardThunk
Expand Down Expand Up @@ -5785,7 +5810,7 @@ end

T_void = convert(LLVMType, Nothing)

combinedReturn = Tuple{sret_types...}
combinedReturn = (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...}
if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types)
combinedReturn = AnonymousStruct(combinedReturn)
end
Expand Down Expand Up @@ -6003,29 +6028,30 @@ end
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI)
tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

sig = Tuple{eltype(FA), map(eltype, TT.parameters)...}

interp = GPUCompiler.get_interpreter(tmp_job)

# TODO check compile return here, early
# rrt = Core.Compiler.return_type(f, primal.tt) # nothing
rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any)
rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype

run_enzyme = true

if rrt == Union{}
estr = "Function to differentiate `$mi` is guaranteed to return an error and doesn't make sense to autodiff. Giving up"
return quote
error($estr)
end
run_enzyme = false
A = Const
end

if !(A <: Const) && guaranteed_const_nongen(rrt, World)
if run_enzyme && !(A <: Const) && guaranteed_const_nongen(rrt, World)
estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant"
return quote
error($estr)
end
end

rt2 = if A isa UnionAll
rt2 = if !run_enzyme
Const{rrt}
elseif A isa UnionAll
A{rrt}
else
@assert A isa DataType
Expand All @@ -6034,7 +6060,7 @@ end
A
end

params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI)
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

# We need to use primal as the key, to lookup the right method
Expand All @@ -6045,7 +6071,13 @@ end


compile_result = cached_compilation(job)
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
if !run_enzyme
ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World}
return quote
Base.@_inline_meta
$ErrT($(compile_result.adjoint))
end
elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
TapeType = compile_result.TapeType
AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType}
AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType}
Expand Down Expand Up @@ -6086,7 +6118,6 @@ import GPUCompiler: deferred_codegen_jobs
params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI)
tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

sig = Tuple{eltype(FA), map(eltype, TT.parameters)...}
interp = GPUCompiler.get_interpreter(tmp_job)

rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any)
Expand Down
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2602,6 +2602,15 @@ end
@test 2.0 Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), Const(true))[1][1]
end


function assured_err(x)
throw(AssertionError("foo"))
end

@testset "UnionAll" begin
@test_throws AssertionError Enzyme.autodiff(Reverse, assured_err, Active, Active(2.0))
end

struct MyFlux
end

Expand Down

0 comments on commit 5609c7e

Please sign in to comment.