Skip to content

Commit

Permalink
Merge pull request #100 from JuliaGPU/jps/ka-support2
Browse files Browse the repository at this point in the history
KernelAbstractions support
  • Loading branch information
jpsamaroo authored Mar 25, 2021
2 parents 3d2959f + 1b8ed2c commit 9f387fa
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 10 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AMDGPU"
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
version = "0.2.3"
version = "0.2.4"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down Expand Up @@ -32,6 +32,8 @@ LLVM = "3"
MacroTools = "0.5"
Requires = "1"
Setfield = "0.5, 0.6, 0.7"
hsa_rocr_jll = "3.7"
hsakmt_roct_jll = "3.7"
julia = "1.6"

[extras]
Expand Down
30 changes: 21 additions & 9 deletions deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ function version_hsa(libpath)
major_ref = Ref{Cushort}(typemax(Cushort))
minor_ref = Ref{Cushort}(typemax(Cushort))
status = ccall(sym, Cint, (Cint, Ptr{Cushort}), 0, major_ref)
@assert status == 0 "HSA error: $status"
if status != 0
@warn "HSA error: $status"
return v"0"
end
status = ccall(sym, Cint, (Cint, Ptr{Cushort}), 1, minor_ref)
@assert status == 0 "HSA error: $status"
if status != 0
@warn "HSA error: $status"
return v"0"
end
return VersionNumber(major_ref[], minor_ref[])
end

Expand Down Expand Up @@ -138,35 +144,40 @@ function main()

# check that we're running Linux
if !Sys.islinux()
build_error("Not running Linux, which is the only platform currently supported by the ROCm Runtime.")
build_warning("Not running Linux, which is the only platform currently supported by the ROCm Runtime.")
return
end

roc_dirs = find_roc_paths()

config[:libhsaruntime_path] = find_hsa_library("libhsa-runtime64.so.1", roc_dirs)
if config[:libhsaruntime_path] == nothing
build_error("Could not find HSA runtime library v1.")
build_warning("Could not find HSA runtime library v1.")
return
end

# initializing the library isn't necessary, but flushes out errors that otherwise would
# happen during `version` or, worse, at package load time.
status = init_hsa(config[:libhsaruntime_path])
if status != 0
build_error("Initializing HSA runtime failed with code $status.")
build_warning("Initializing HSA runtime failed with code $status.")
return
end

config[:libhsaruntime_version] = version_hsa(config[:libhsaruntime_path])

# also shutdown just in case
status = shutdown_hsa(config[:libhsaruntime_path])
if status != 0
build_error("Shutdown of HSA runtime failed with code $status.")
build_warning("Shutdown of HSA runtime failed with code $status.")
return
end

# find the ld.lld program for linking kernels
ld_path = find_ld_lld()
if ld_path == ""
build_error("Couldn't find ld.lld, please install it with your package manager")
build_warning("Couldn't find ld.lld, please install it with your package manager")
return
end
config[:ld_lld_path] = ld_path

Expand Down Expand Up @@ -209,7 +220,7 @@ function main()

if status != 0
# we got here, so the status is non-fatal
build_error("""
build_warning("""
AMDGPU.jl has been built successfully, but there were warnings.
Some functionality may be unavailable.""")
Expand Down Expand Up @@ -246,7 +257,8 @@ if dl_info === nothing && unsatisfied
# If we don't have a compatible .tar.gz to download, complain.
# Alternatively, you could attempt to install from a separate provider,
# build from source or something even more ambitious here.
error("Your platform (\"$(Sys.MACHINE)\", parsed as \"$(triplet(platform_key_abi()))\") is not supported by this package!")
@warn "Your platform (\"$(Sys.MACHINE)\", parsed as \"$(triplet(platform_key_abi()))\") is not supported by this package!"
return
end

# If we have a download, and we are unsatisfied (or the version we're
Expand Down
2 changes: 2 additions & 0 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ end

barrier_and!(signals::Vector) = barrier_and!(default_queue().queue, signals)
barrier_or!(signals::Vector) = barrier_or!(default_queue().queue, signals)
barrier_and!(queue::HSAQueue, signals::Vector{HSA.Signal}) = barrier_and!(queue, map(HSASignal, signals))
barrier_or!(queue::HSAQueue, signals::Vector{HSA.Signal}) = barrier_or!(queue, map(HSASignal, signals))
barrier_and!(queue::HSAQueue, signals::Vector{HSASignal}) =
barrier!(HSA.BarrierAndPacket, queue, signals)
barrier_or!(queue::HSAQueue, signals::Vector{HSASignal}) =
Expand Down
1 change: 1 addition & 0 deletions src/queue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ function HSAQueue(agent::HSAAgent)
end
return queue
end
HSAQueue() = HSAQueue(get_default_agent())

get_default_queue() = get_default_queue(get_default_agent())

Expand Down
1 change: 1 addition & 0 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct HSAStatusSignal
exe::HSAExecutable
end
create_event(::typeof(HSA_rt), exe) = HSAStatusSignal(HSASignal(), exe.exe)
Base.wait(event::HSAStatusSignal; kwargs...) = wait(RuntimeEvent(event); kwargs...)
function Base.wait(event::RuntimeEvent{HSAStatusSignal}; check_exceptions=true, cleanup=true, kwargs...)
wait(event.event.signal; kwargs...) # wait for completion signal
unpreserve!(event) # allow kernel-associated objects to be freed
Expand Down
2 changes: 2 additions & 0 deletions src/signal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function HSASignal(init::Integer=1)
end
return signal
end
HSASignal(signal::HSA.Signal) = HSASignal(Ref(signal))

Adapt.adapt_structure(::Adaptor, sig::HSASignal) = sig.signal[]

Expand Down Expand Up @@ -51,3 +52,4 @@ function Base.wait(signal::HSASignal; soft=true, minlat=0.01, timeout=nothing)
end
end

Base.wait(signal::HSA.Signal; kwargs...) = wait(HSASignal(signal); kwargs...)
10 changes: 10 additions & 0 deletions test/device/launch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ end
y = Int64(x)
wait(@roc kernel(x, y))
end

@testset "Signal waiting" begin
kernel() = nothing

sig = @roc kernel()
wait(sig)
wait(sig.event)
wait(sig.event.signal)
wait(sig.event.signal.signal[])
end

2 comments on commit 9f387fa

@jpsamaroo
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32864

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.4 -m "<description of version>" 9f387fa009a9816773d56af1c30943e535c798e1
git push origin v0.2.4

Please sign in to comment.