diff --git a/Project.toml b/Project.toml index 9460321e..8c9b1976 100644 --- a/Project.toml +++ b/Project.toml @@ -7,9 +7,19 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" LLVMExtra_jll = "dad2f222-ce93-54a1-a47d-0025e8a3acab" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +[weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" + +[extensions] +BFloat16sExt = "BFloat16s" + [compat] CEnum = "0.2, 0.3, 0.4" LLVMExtra_jll = "=0.0.26" julia = "1.8" + +[extras] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" diff --git a/ext/BFloat16sExt.jl b/ext/BFloat16sExt.jl new file mode 100644 index 00000000..d37e670e --- /dev/null +++ b/ext/BFloat16sExt.jl @@ -0,0 +1,18 @@ +module BFloat16sExt + +using LLVM +using LLVM: API + +isdefined(Base, :get_extension) ? (using BFloat16s) : (using ..BFloat16s) + +## constant values + +LLVM.ConstantFP(val::BFloat16) = ConstantFP(BFloatType(), val) + +Base.convert(::Type{BFloat16}, val::ConstantFP) = + convert(BFloat16, API.LLVMConstRealGetDouble(val, Ref{API.LLVMBool}())) + +ConstantDataArray(data::AbstractVector{BFloat16}) = + ConstantDataArray(BFloatType(), data) + +end diff --git a/src/LLVM.jl b/src/LLVM.jl index b433d732..eb8dad42 100644 --- a/src/LLVM.jl +++ b/src/LLVM.jl @@ -4,6 +4,10 @@ using Unicode using Printf using Libdl +if !isdefined(Base, :get_extension) + using Requires: @require +end + ## source code includes @@ -106,6 +110,12 @@ function __init__() Please re-compile Julia and LLVM.jl (but note that USE_SYSTEM_LLVM is not a supported configuration).""" end + @static if !isdefined(Base, :get_extension) + @require BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" begin + include("../ext/BFloat16sExt.jl") + end + end + _install_handlers() _install_handlers(GlobalContext()) end diff --git a/src/core/type.jl b/src/core/type.jl index 4cae7e9c..54636973 100644 --- a/src/core/type.jl +++ b/src/core/type.jl @@ -85,7 +85,7 @@ width(inttyp::IntegerType) = API.LLVMGetIntTypeWidth(inttyp) # we add it for convenience of typechecking generic values (see execution.jl) abstract type FloatingPointType <: LLVMType end -for T in [:Half, :Float, :Double, :FP128, :X86_FP80, :PPC_FP128] +for T in [:Half, :Float, :Double, :BFloat, :FP128, :X86_FP80, :PPC_FP128] CleanT = Symbol(replace(String(T), "_"=>"")) # only the type kind retains the underscore jl_fname = Symbol(CleanT, :Type) api_typename = Symbol(:LLVM, CleanT) diff --git a/src/core/value/constant.jl b/src/core/value/constant.jl index c057f371..b7212655 100644 --- a/src/core/value/constant.jl +++ b/src/core/value/constant.jl @@ -101,12 +101,9 @@ register(ConstantFP, API.LLVMConstantFPValueKind) ConstantFP(typ::FloatingPointType, val::Real) = ConstantFP(API.LLVMConstReal(typ, Cdouble(val))) -ConstantFP(val::Float16) = - ConstantFP(HalfType(), val) -ConstantFP(val::Float32) = - ConstantFP(FloatType(), val) -ConstantFP(val::Float64) = - ConstantFP(DoubleType(), val) +ConstantFP(val::Float64) = ConstantFP(DoubleType(), val) +ConstantFP(val::Float32) = ConstantFP(FloatType(), val) +ConstantFP(val::Float16) = ConstantFP(HalfType(), val) Base.convert(::Type{T}, val::ConstantFP) where {T<:AbstractFloat} = convert(T, API.LLVMConstRealGetDouble(val, Ref{API.LLVMBool}())) @@ -166,12 +163,12 @@ ConstantDataArray(data::AbstractVector{T}) where {T<:Integer} = ConstantDataArray(IntType(sizeof(T)*8), data) ConstantDataArray(data::AbstractVector{Core.Bool}) = ConstantDataArray(Int1Type(), data) -ConstantDataArray(data::AbstractVector{Float16}) = - ConstantDataArray(HalfType(), data) -ConstantDataArray(data::AbstractVector{Float32}) = - ConstantDataArray(FloatType(), data) ConstantDataArray(data::AbstractVector{Float64}) = ConstantDataArray(DoubleType(), data) +ConstantDataArray(data::AbstractVector{Float32}) = + ConstantDataArray(FloatType(), data) +ConstantDataArray(data::AbstractVector{Float16}) = + ConstantDataArray(HalfType(), data) @checked struct ConstantDataVector <: ConstantDataSequential ref::API.LLVMValueRef diff --git a/test/Project.toml b/test/Project.toml index 17352403..943a00d0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LLVMExtra_jll = "dad2f222-ce93-54a1-a47d-0025e8a3acab" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" diff --git a/test/core_tests.jl b/test/core_tests.jl index 2203e37a..8d997edd 100644 --- a/test/core_tests.jl +++ b/test/core_tests.jl @@ -1,5 +1,7 @@ @testitem "core" setup=[TestHelpers] begin +using BFloat16s + struct TestStruct x::Bool y::Int64 @@ -410,6 +412,11 @@ end c = ConstantFP(typ, 1.1) @test convert(Float64, c) == 1.1 end + let + typ = LLVM.BFloatType() + c = ConstantFP(typ, BFloat16(1.1)) + @test convert(BFloat16, c) == BFloat16(1.1) + end let typ = LLVM.X86FP80Type() # TODO: how to construct full-width constants?