Skip to content

Commit

Permalink
Update rocFFT (#640)
Browse files Browse the repository at this point in the history
- Update rocFFT wrapper and add error checks.
- Update HIPStream for plan.
- Re-enable rocFFT tests.
  • Loading branch information
pxl-th authored May 25, 2024
1 parent 2a00c14 commit f8ca0d6
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 239 deletions.
39 changes: 27 additions & 12 deletions src/fft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ const ROCFFT_FORWARD = true
const ROCFFT_INVERSE = false

# TODO: Real to Complex full not possible atm
# For R2C -> cast array to Complex first

# K is flag for forward/inverse
mutable struct cROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
mutable struct cROCFFTPlan{T, K, inplace, N} <: ROCFFTPlan{T, K, inplace}
handle::rocfft_plan
stream::HIPStream
workarea::ROCVector{Int8}
execution_info::rocfft_execution_info
sz::NTuple{N,Int} # Julia size of input array
osz::NTuple{N,Int} # Julia size of output array
sz::NTuple{N, Int} # Julia size of input array
osz::NTuple{N, Int} # Julia size of output array
xtype::rocfft_transform_type
region::Any
pinv::ScaledPlan # required by AbstractFFTs API
Expand All @@ -35,18 +38,20 @@ mutable struct cROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
info = info_ref[]

# assign to the current stream
rocfft_execution_info_set_stream(info, AMDGPU.stream())
stream = AMDGPU.stream()
rocfft_execution_info_set_stream(info, stream)
if length(workarea) > 0
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
end
p = new(handle, workarea, info, size(X), sizey, xtype, region)
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region)
finalizer(unsafe_free!, p)
p
end
end

mutable struct rROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
handle::rocfft_plan
stream::HIPStream
workarea::ROCVector{Int8}
execution_info::rocfft_execution_info
sz::NTuple{N,Int} # Julia size of input array
Expand All @@ -63,16 +68,26 @@ mutable struct rROCFFTPlan{T,K,inplace,N} <: ROCFFTPlan{T,K,inplace}
rocfft_execution_info_create(info_ref)
info = info_ref[]

rocfft_execution_info_set_stream(info, AMDGPU.stream())
stream = AMDGPU.stream()
rocfft_execution_info_set_stream(info, stream)
if length(workarea) > 0
rocfft_execution_info_set_work_buffer(info, workarea, length(workarea))
end
p = new(handle, workarea, info, size(X), sizey, xtype, region)
p = new(handle, stream, workarea, info, size(X), sizey, xtype, region)
finalizer(unsafe_free!, p)
p
end
end

function update_stream!(plan::ROCFFTPlan)
new_stream = AMDGPU.stream()
if plan.stream != new_stream
plan.stream = new_stream
rocfft_execution_info_set_stream(info, new_stream)
end
return
end

const xtypenames = (
"complex forward", "complex inverse", "real forward", "real inverse")

Expand Down Expand Up @@ -140,8 +155,7 @@ function plan_inv(p::cROCFFTPlan{T,ROCFFT_FORWARD,inplace,N}) where {T<:rocfftCo
xtype = rocfft_transform_type_complex_inverse
pp = get_plan(xtype, p.sz, T, inplace, p.region)
ScaledPlan(
cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}(
pp..., X, p.sz, xtype, p.region),
cROCFFTPlan{T,ROCFFT_INVERSE,inplace,N}(pp..., X, p.sz, xtype, p.region),
normalization(X, p.region))
end

Expand Down Expand Up @@ -198,9 +212,8 @@ function assert_applicable(p::ROCFFTPlan{T,K}, X::ROCArray{T}, Y::ROCArray{Ty})
end
end

# TODO update stream

function unsafe_execute!(plan::cROCFFTPlan{T,K,true,N}, X::ROCArray{T,N}) where {T,K,N}
update_stream!(plan)
rocfft_execute(plan, [pointer(X),], C_NULL, plan.execution_info)
end

Expand All @@ -209,6 +222,7 @@ function unsafe_execute!(
) where {T,N,K}
X = copy(X) # since input array can also be modified
# TODO on 1.11 we need to manually cast `pointer(X)` to `Ptr{Cvoid}`.
update_stream!(plan)
rocfft_execute(plan, [pointer(X),], [pointer(Y),], plan.execution_info)
end

Expand All @@ -218,6 +232,7 @@ function unsafe_execute!(
) where {T<:rocfftReals,N}
@assert plan.xtype == rocfft_transform_type_real_forward
Xcopy = copy(X)
update_stream!(plan)
rocfft_execute(plan, [pointer(Xcopy),], [pointer(Y),], plan.execution_info)
end

Expand All @@ -227,10 +242,10 @@ function unsafe_execute!(
) where {T<:rocfftComplexes,N}
@assert plan.xtype == rocfft_transform_type_real_inverse
Xcopy = copy(X)
update_stream!(plan)
rocfft_execute(plan, [pointer(Xcopy),], [pointer(Y),], plan.execution_info)
end


function LinearAlgebra.mul!(y::ROCArray{Ty}, p::ROCFFTPlan{T,K,false}, x::ROCArray{T}) where {T,Ty,K}
assert_applicable(p, x, y)
unsafe_execute!(p, x, y)
Expand Down
86 changes: 63 additions & 23 deletions src/fft/librocfft.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using CEnum

mutable struct rocfft_plan_t end

const rocfft_plan = Ptr{rocfft_plan_t}
Expand All @@ -12,6 +10,14 @@ mutable struct rocfft_execution_info_t end

const rocfft_execution_info = Ptr{rocfft_execution_info_t}

mutable struct rocfft_field_t end

const rocfft_field = Ptr{rocfft_field_t}

mutable struct rocfft_brick_t end

const rocfft_brick = Ptr{rocfft_brick_t}

@cenum rocfft_status_e::UInt32 begin
rocfft_status_success = 0
rocfft_status_failure = 1
Expand All @@ -38,6 +44,7 @@ const rocfft_transform_type = rocfft_transform_type_e
@cenum rocfft_precision_e::UInt32 begin
rocfft_precision_single = 0
rocfft_precision_double = 1
rocfft_precision_half = 2
end

const rocfft_precision = rocfft_precision_e
Expand All @@ -60,100 +67,133 @@ end

const rocfft_array_type = rocfft_array_type_e

# no prototype is found for this function at rocfft.h:124:29, please use with caution
function rocfft_setup()
AMDGPU.prepare_state()
ccall((:rocfft_setup, librocfft), rocfft_status, ()) |> check
@check ccall((:rocfft_setup, librocfft), rocfft_status, ())
end

# no prototype is found for this function at rocfft.h:128:29, please use with caution
function rocfft_cleanup()
AMDGPU.prepare_state()
ccall((:rocfft_cleanup, librocfft), rocfft_status, ()) |> check
@check ccall((:rocfft_cleanup, librocfft), rocfft_status, ())
end

function rocfft_plan_create(plan, placement, transform_type, precision, dimensions, lengths, number_of_transforms, description)
AMDGPU.prepare_state()
ccall((:rocfft_plan_create, librocfft), rocfft_status, (Ptr{rocfft_plan}, rocfft_result_placement, rocfft_transform_type, rocfft_precision, Cint, Ptr{Cint}, Cint, rocfft_plan_description), plan, placement, transform_type, precision, dimensions, lengths, number_of_transforms, description) |> check
@check ccall((:rocfft_plan_create, librocfft), rocfft_status, (Ptr{rocfft_plan}, rocfft_result_placement, rocfft_transform_type, rocfft_precision, Csize_t, Ptr{Csize_t}, Csize_t, rocfft_plan_description), plan, placement, transform_type, precision, dimensions, lengths, number_of_transforms, description)
end

function rocfft_execute(plan, in_buffer, out_buffer, info)
AMDGPU.prepare_state()
ccall((:rocfft_execute, librocfft), rocfft_status, (rocfft_plan, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, rocfft_execution_info), plan, in_buffer, out_buffer, info) |> check
@check ccall((:rocfft_execute, librocfft), rocfft_status, (rocfft_plan, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, rocfft_execution_info), plan, in_buffer, out_buffer, info)
end

function rocfft_plan_destroy(plan)
AMDGPU.prepare_state()
ccall((:rocfft_plan_destroy, librocfft), rocfft_status, (rocfft_plan,), plan) |> check
@check ccall((:rocfft_plan_destroy, librocfft), rocfft_status, (rocfft_plan,), plan)
end

function rocfft_plan_description_set_scale_factor(description, scale_factor)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_set_scale_factor, librocfft), rocfft_status, (rocfft_plan_description, Cdouble), description, scale_factor) |> check
@check ccall((:rocfft_plan_description_set_scale_factor, librocfft), rocfft_status, (rocfft_plan_description, Cdouble), description, scale_factor)
end

function rocfft_plan_description_set_data_layout(description, in_array_type, out_array_type, in_offsets, out_offsets, in_strides_size, in_strides, in_distance, out_strides_size, out_strides, out_distance)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_set_data_layout, librocfft), rocfft_status, (rocfft_plan_description, rocfft_array_type, rocfft_array_type, Ptr{Cint}, Ptr{Cint}, Cint, Ptr{Cint}, Cint, Cint, Ptr{Cint}, Cint), description, in_array_type, out_array_type, in_offsets, out_offsets, in_strides_size, in_strides, in_distance, out_strides_size, out_strides, out_distance) |> check
@check ccall((:rocfft_plan_description_set_data_layout, librocfft), rocfft_status, (rocfft_plan_description, rocfft_array_type, rocfft_array_type, Ptr{Csize_t}, Ptr{Csize_t}, Csize_t, Ptr{Csize_t}, Csize_t, Csize_t, Ptr{Csize_t}, Csize_t), description, in_array_type, out_array_type, in_offsets, out_offsets, in_strides_size, in_strides, in_distance, out_strides_size, out_strides, out_distance)
end

function rocfft_field_create(field)
AMDGPU.prepare_state()
@check ccall((:rocfft_field_create, librocfft), rocfft_status, (Ptr{rocfft_field},), field)
end

function rocfft_field_destroy(field)
AMDGPU.prepare_state()
@check ccall((:rocfft_field_destroy, librocfft), rocfft_status, (rocfft_field,), field)
end

function rocfft_get_version_string(buf, len)
AMDGPU.prepare_state()
ccall((:rocfft_get_version_string, librocfft), rocfft_status, (Ptr{Cchar}, Cint), buf, len) |> check
@check ccall((:rocfft_get_version_string, librocfft), rocfft_status, (Ptr{Cchar}, Csize_t), buf, len)
end

function rocfft_brick_create(brick, field_lower, field_upper, brick_stride, dim, deviceID)
AMDGPU.prepare_state()
@check ccall((:rocfft_brick_create, librocfft), rocfft_status, (Ptr{rocfft_brick}, Ptr{Csize_t}, Ptr{Csize_t}, Ptr{Csize_t}, Csize_t, Cint), brick, field_lower, field_upper, brick_stride, dim, deviceID)
end

function rocfft_brick_destroy(brick)
AMDGPU.prepare_state()
@check ccall((:rocfft_brick_destroy, librocfft), rocfft_status, (rocfft_brick,), brick)
end

function rocfft_field_add_brick(field, brick)
AMDGPU.prepare_state()
@check ccall((:rocfft_field_add_brick, librocfft), rocfft_status, (rocfft_field, rocfft_brick), field, brick)
end

function rocfft_plan_description_add_infield(description, field)
AMDGPU.prepare_state()
@check ccall((:rocfft_plan_description_add_infield, librocfft), rocfft_status, (rocfft_plan_description, rocfft_field), description, field)
end

function rocfft_plan_description_add_outfield(description, field)
AMDGPU.prepare_state()
@check ccall((:rocfft_plan_description_add_outfield, librocfft), rocfft_status, (rocfft_plan_description, rocfft_field), description, field)
end

function rocfft_plan_get_work_buffer_size(plan, size_in_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_plan_get_work_buffer_size, librocfft), rocfft_status, (rocfft_plan, Ptr{Cint}), plan, size_in_bytes) |> check
@check ccall((:rocfft_plan_get_work_buffer_size, librocfft), rocfft_status, (rocfft_plan, Ptr{Csize_t}), plan, size_in_bytes)
end

function rocfft_plan_get_print(plan)
AMDGPU.prepare_state()
ccall((:rocfft_plan_get_print, librocfft), rocfft_status, (rocfft_plan,), plan) |> check
@check ccall((:rocfft_plan_get_print, librocfft), rocfft_status, (rocfft_plan,), plan)
end

function rocfft_plan_description_create(description)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_create, librocfft), rocfft_status, (Ptr{rocfft_plan_description},), description) |> check
@check ccall((:rocfft_plan_description_create, librocfft), rocfft_status, (Ptr{rocfft_plan_description},), description)
end

function rocfft_plan_description_destroy(description)
AMDGPU.prepare_state()
ccall((:rocfft_plan_description_destroy, librocfft), rocfft_status, (rocfft_plan_description,), description) |> check
@check ccall((:rocfft_plan_description_destroy, librocfft), rocfft_status, (rocfft_plan_description,), description)
end

function rocfft_execution_info_create(info)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_create, librocfft), rocfft_status, (Ptr{rocfft_execution_info},), info) |> check
@check ccall((:rocfft_execution_info_create, librocfft), rocfft_status, (Ptr{rocfft_execution_info},), info)
end

function rocfft_execution_info_destroy(info)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_destroy, librocfft), rocfft_status, (rocfft_execution_info,), info) |> check
@check ccall((:rocfft_execution_info_destroy, librocfft), rocfft_status, (rocfft_execution_info,), info)
end

function rocfft_execution_info_set_work_buffer(info, work_buffer, size_in_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_work_buffer, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Cvoid}, Cint), info, work_buffer, size_in_bytes) |> check
@check ccall((:rocfft_execution_info_set_work_buffer, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Cvoid}, Csize_t), info, work_buffer, size_in_bytes)
end

function rocfft_execution_info_set_stream(info, stream)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_stream, librocfft), rocfft_status, (rocfft_execution_info, hipStream_t), info, stream) |> check
@check ccall((:rocfft_execution_info_set_stream, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Cvoid}), info, stream)
end

function rocfft_execution_info_set_load_callback(info, cb_functions, cb_data, shared_mem_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_load_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Cint), info, cb_functions, cb_data, shared_mem_bytes) |> check
@check ccall((:rocfft_execution_info_set_load_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Csize_t), info, cb_functions, cb_data, shared_mem_bytes)
end

function rocfft_execution_info_set_store_callback(info, cb_functions, cb_data, shared_mem_bytes)
AMDGPU.prepare_state()
ccall((:rocfft_execution_info_set_store_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Cint), info, cb_functions, cb_data, shared_mem_bytes) |> check
@check ccall((:rocfft_execution_info_set_store_callback, librocfft), rocfft_status, (rocfft_execution_info, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Csize_t), info, cb_functions, cb_data, shared_mem_bytes)
end

const rocfft_version_major = 1

const rocfft_version_minor = 0

const rocfft_version_patch = 21
const rocfft_version_patch = 27
16 changes: 8 additions & 8 deletions src/fft/rocFFT.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
module rocFFT
export ROCFFTError

import AbstractFFTs: complexfloat, realfloat
import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft!
import AbstractFFTs: plan_rfft, plan_brfft, plan_inv, normalization
import AbstractFFTs: fft, bfft, ifft, rfft, Plan, ScaledPlan
using CEnum
using LinearAlgebra

# TODO
# @reexport using AbstractFFTs

using LinearAlgebra
import AbstractFFTs: complexfloat, realfloat
import AbstractFFTs: plan_fft, plan_fft!, plan_bfft, plan_bfft!
import AbstractFFTs: plan_rfft, plan_brfft, plan_inv, normalization
import AbstractFFTs: fft, bfft, ifft, rfft, Plan, ScaledPlan

import ..AMDGPU
import .AMDGPU: ROCArray, ROCVector, HandleCache, HIP, unsafe_free!, check
import .AMDGPU: ROCArray, ROCVector, HandleCache, HIP, unsafe_free!, check, @check
import AMDGPU: librocfft
import .HIP: hipStream_t, HIPContext, HIPStream

using CEnum

include("librocfft.jl")
include("error.jl")
include("util.jl")
include("wrappers.jl")
include("fft.jl")

version() = VersionNumber(
Expand Down
Loading

0 comments on commit f8ca0d6

Please sign in to comment.