Skip to content

Commit

Permalink
CUDA ext
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed May 17, 2024
1 parent cd36e26 commit 286742f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
TapirAdaptExt = "Adapt"
TapirCUDAExt = "CUDA"
TapirLogDensityProblemsADExt = "LogDensityProblemsAD"
TapirSpecialFunctionsExt = "SpecialFunctions"

Expand Down
16 changes: 16 additions & 0 deletions ext/TapirCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module TapirSpecialFunctionsExt

using CUDA, Tapir

import Tapir: MinimalCtx, rrule!!

@is_primitive MinimalCtx Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}}
function rrule!!(
::CoDual{Type{C}},
::CoDual{UndefInitializer},
dims::CoDual{Int}...
) where {P<:CuArray{<:IEEEFloat}}
y = CoDual(P(undef, dims), P(undef, dims))
return y, NoPullback()
end
end
8 changes: 8 additions & 0 deletions test/integration_testing/cuda.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using CUDA

@testset "cuda" begin
TestUtils.test_derived_rule(
sr(123456), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Undef, 256;
Tapir.PInterp(), perf_flag=:stability, interface_only=true, is_primitive=true,
)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ include("front_matter.jl")
elseif test_group == "interface"
include("interface.jl")
elseif test_group == "gpu"
println("Placeholder for actual GPU code.")
include(joinpath("integration_testing", "cuda.jl"))
else
throw(error("test_group=$(test_group) is not recognised"))
end
Expand Down

0 comments on commit 286742f

Please sign in to comment.