Skip to content

Commit

Permalink
Handle xlogy limit (EnzymeAD#1615)
Browse files Browse the repository at this point in the history
* Handle xlogy limit

* with test

* fixup
  • Loading branch information
wsmoses authored Jul 11, 2024
1 parent c83fcf8 commit ff9d320
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 111 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
EnzymeChainRulesCoreExt = "ChainRulesCore"
EnzymeLogExpFunctionsExt = "LogExpFunctions"
EnzymeSpecialFunctionsExt = "SpecialFunctions"
EnzymeStaticArraysExt = "StaticArrays"

Expand All @@ -33,6 +35,7 @@ EnzymeCore = "0.7.5"
Enzyme_jll = "0.0.133"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1, 7, 8"
LogExpFunctions = "0.3"
ObjectFile = "0.4"
Preferences = "1.4"
SpecialFunctions = "1, 2"
Expand All @@ -41,5 +44,6 @@ julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
10 changes: 10 additions & 0 deletions ext/EnzymeLogExpFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module EnzymeLogExpFunctionsExt

using LogExpFunctions
using Enzyme

function __init__()
Enzyme.Compiler.known_ops[typeof(LogExpFunctions.xlogy)] = (:xlogy_jl, 2, nothing)
end

end
107 changes: 53 additions & 54 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,58 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}(
@static if VERSION >= v"1.8.0"
known_ops[typeof(Base.fma_emulated)] = (:fma, 3, nothing)
end
@inline function find_math_method(@nospecialize(func), sparam_vals)
if func keys(known_ops)
name, arity, toinject = known_ops[func]
Tys = (Float32, Float64)

if length(sparam_vals) == arity
T = first(sparam_vals)
legal = T Tys

if legal
if name == :ldexp
if !(sparam_vals[2] <: Integer)
legal = false
end
elseif name == :pow
if sparam_vals[2] <: Integer
name = :powi
elseif sparam_vals[2] != T
legal = false
end
elseif name == :jl_rem2pi
else
if !all(==(T), sparam_vals)
legal = false
end
end
end
if legal
return name, toinject, T
end
end
end

if func keys(cmplx_known_ops)
name, arity, toinject = cmplx_known_ops[func]
Tys = (Complex{Float32}, Complex{Float64})
if length(sparam_vals) == arity
T = first(sparam_vals)
legal = T Tys

if legal
if !all(==(T), sparam_vals)
legal = false
end
end
if legal
return name, toinject, T
end
end
end
return nothing, nothing, nothing
end

const nofreefns = Set{String}((
"ijl_f_isdefined", "jl_f_isdefined",
Expand Down Expand Up @@ -5621,61 +5673,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end
continue
end

@inline function find_math_method()
if func keys(known_ops)
name, arity, toinject = known_ops[func]
Tys = (Float32, Float64)

if length(sparam_vals) == arity
T = first(sparam_vals)
legal = T Tys

if legal
if name == :ldexp
if !(sparam_vals[2] <: Integer)
legal = false
end
elseif name == :pow
if sparam_vals[2] <: Integer
name = :powi
elseif sparam_vals[2] != T
legal = false
end
elseif name == :jl_rem2pi
else
if !all(==(T), sparam_vals)
legal = false
end
end
end
if legal
return name, toinject, T
end
end
end

if func keys(cmplx_known_ops)
name, arity, toinject = cmplx_known_ops[func]
Tys = (Complex{Float32}, Complex{Float64})
if length(sparam_vals) == arity
T = first(sparam_vals)
legal = T Tys

if legal
if !all(==(T), sparam_vals)
legal = false
end
end
if legal
return name, toinject, T
end
end
end
return nothing, nothing, nothing
end

name, toinject, T = find_math_method()
name, toinject, T = find_math_method(func, sparam_vals)
if name === nothing
continue
end
Expand Down
60 changes: 3 additions & 57 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,65 +108,11 @@ function is_primitive_func(@nospecialize(TT))
if ft == typeof(Enzyme.pmap)
return true
end
if ft === typeof(Base.rem2pi)
if TT <: Tuple{ft, Float32, <:Any} || TT <: Tuple{ft, Float64, <:Any} || TT <: Tuple{ft, Float16, <:Any}
return true
end
end

if ft == typeof(Base.inv) || ft == typeof(Base.sqrt)
if TT <: Tuple{ft, Complex{Float32}} || TT <: Tuple{ft, Complex{Float64}}
return true
end
end

@static if VERSION >= v"1.9-"
if ft === typeof(Base.rem)
if TT <: Tuple{ft, Float32, Float32} || TT <: Tuple{ft, Float64, Float64}
return true
end
end
match = Enzyme.Compiler.find_math_method(ft, TT.parameters[2:end])[1]
if match !== nothing
return true
end

if ft === typeof(Base.cbrt) || ft === typeof(Base.sin) || ft === typeof(Base.cos) ||
ft === typeof(Base.sinc) ||
ft === typeof(Base.tan) || ft === typeof(Base.exp) || ft === typeof(Base.FastMath.exp_fast) ||
ft === typeof(Base.exp10) ||
ft === typeof(Base.exp2) ||
ft === typeof(Base.expm1) ||
ft === typeof(Base.log) || ft === typeof(Base.FastMath.log) ||
ft === typeof(Base.log1p) ||
ft === typeof(Base.log2) ||
ft === typeof(Base.log10) ||
ft === typeof(Base.asin) ||
ft === typeof(Base.acos) ||
ft === typeof(Base.atan) ||
ft === typeof(Base.sinpi) ||
ft === typeof(Base.cospi) ||
ft === typeof(Base.sinh) || ft === typeof(Base.FastMath.sinh_fast) ||
ft === typeof(Base.cosh) || ft === typeof(Base.FastMath.cosh_fast) ||
ft === typeof(Base.tanh) || ft === typeof(Base.FastMath.tanh_fast) ||
ft === typeof(Base.sqrt) || ft === typeof(Base.sincos) || ft === typeof(Base.sincospi)
if TT <: Tuple{ft, Float32} || TT <: Tuple{ft, Float64} || TT <: Tuple{ft, Float16}
return true
end
end
@static if VERSION < v"1.8.0"
else
if ft === typeof(Base.fma_emulated)
if TT <: Tuple{ft, Float32, Float32, Float32} || TT <: Tuple{ft, Float64, Float64, Float64}
return true
end
end
end
if ft === typeof(Base.:^) || ft === typeof(Base.atan)
if TT <: Tuple{ft, Float32, Float32} || TT <: Tuple{ft, Float64, Float64}
return true
end
if TT <: Tuple{ft, Float32, <:Integer} || TT <: Tuple{ft, Float64, <:Integer}
return true
end
end
# FIXME(@wsmoses): For which types should we not inline?
if ft === typeof(Base.wait) || ft === typeof(Base._wait) || ft === typeof(Base.enq_work) ||
ft === typeof(Base.Threads.threadid) || ft == typeof(Base.Threads.nthreads) ||
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
14 changes: 14 additions & 0 deletions test/ext/logexpfunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using LogExpFunctions


xlogydiff(x) = xlogy(x[1], 23.0)
@testset "LogExpFunctions" begin

x = [0.0]

grad_forward = Enzyme.gradient(Enzyme.Forward, xlogydiff, x)
grad_reverse = Enzyme.gradient(Enzyme.Reverse, xlogydiff, x)

@test grad_forward[1] log(23.0)
@test grad_reverse[1] log(23.0)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3511,6 +3511,7 @@ end
@testset "ChainRulesCore ext" begin
include("ext/chainrulescore.jl")
end
include("ext/logexpfunctions.jl")
end


Expand Down

0 comments on commit ff9d320

Please sign in to comment.