Skip to content

Commit

Permalink
TensorOperations v5 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Jul 15, 2024
1 parent 50f0ef2 commit ad54f62
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.8'
- '1' # automatically expands to the latest stable 1.x release of Julia
os:
- ubuntu-latest
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorOperationsTBLIS"
uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865"
authors = ["lkdvos <lukas.devos@ugent.be>"]
version = "0.1.1"
version = "0.2.0"

[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -11,9 +11,9 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0"

[compat]
TensorOperations = "4"
TensorOperations = "5"
TupleTools = "1"
julia = "1.6"
julia = "1.8"
tblis_jll = "1.2"

[extras]
Expand Down
47 changes: 25 additions & 22 deletions src/TensorOperationsTBLIS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,42 @@ include("LibTblis.jl")
using .LibTblis

export tblis_set_num_threads, tblis_get_num_threads
export tblisBackend

# TensorOperations
#------------------

const tblisBackend = TensorOperations.Backend{:tblis}
struct tblisBackend <: TensorOperations.AbstractBackend end

function TensorOperations.tensoradd!(C::StridedArray{T}, pC::Index2Tuple,
A::StridedArray{T}, conjA::Symbol,
function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T},
pA::Index2Tuple, conjA::Bool,
α::Number, β::Number,
::tblisBackend) where {T<:BlasFloat}
TensorOperations.argcheck_tensoradd(C, pC, A)
TensorOperations.dimcheck_tensoradd(C, pC, A)
TensorOperations.argcheck_tensoradd(C, A, pA)
TensorOperations.dimcheck_tensoradd(C, A, pA)

szC = collect(size(C))
strC = collect(strides(C))
C_tblis = tblis_tensor(C, szC, strC, β)

szA = collect(size(A))
strA = collect(strides(A))
A_tblis = tblis_tensor(conjA == :C ? conj(A) : A, szA, strA, α)
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)

einA, einC = TensorOperations.add_labels(pC)
einA, einC = TensorOperations.add_labels(pA)
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))

return C
end

function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
function TensorOperations.tensorcontract!(C::StridedArray{T},
A::StridedArray{T}, pA::Index2Tuple,
conjA::Symbol, B::StridedArray{T},
pB::Index2Tuple, conjB::Symbol, α::Number,
β::Number, ::tblisBackend) where {T<:BlasFloat}
TensorOperations.argcheck_tensorcontract(C, pC, A, pA, B, pB)
TensorOperations.dimcheck_tensorcontract(C, pC, A, pA, B, pB)
conjA::Bool, B::StridedArray{T},
pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple,
α::Number, β::Number,
::tblisBackend) where {T<:BlasFloat}
TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB)
TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

rmul!(C, β) # TODO: is it possible to use tblis scaling here?
szC = ndims(C) == 0 ? Int[] : collect(size(C))
Expand All @@ -51,25 +53,26 @@ function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,

szA = collect(size(A))
strA = collect(strides(A))
A_tblis = tblis_tensor(conjA == :C ? conj(A) : A, szA, strA, α)
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)

szB = collect(size(B))
strB = collect(strides(B))
B_tblis = tblis_tensor(conjB == :C ? conj(B) : B, szB, strB, 1)
B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1)

einA, einB, einC = TensorOperations.contract_labels(pA, pB, pC)
einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB)
tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis,
string(einC...))

return C
end

function TensorOperations.tensortrace!(C::StridedArray{T}, pC::Index2Tuple,
A::StridedArray{T}, pA::Index2Tuple, conjA::Symbol,
function TensorOperations.tensortrace!(C::StridedArray{T},
A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple,
conjA::Bool,
α::Number, β::Number,
::tblisBackend) where {T<:BlasFloat}
TensorOperations.argcheck_tensortrace(C, pC, A, pA)
TensorOperations.dimcheck_tensortrace(C, pC, A, pA)
TensorOperations.argcheck_tensortrace(C, A, p, q)
TensorOperations.dimcheck_tensortrace(C, A, p, q)

rmul!(C, β) # TODO: is it possible to use tblis scaling here?
szC = ndims(C) == 0 ? Int[] : collect(size(C))
Expand All @@ -78,9 +81,9 @@ function TensorOperations.tensortrace!(C::StridedArray{T}, pC::Index2Tuple,

szA = collect(size(A))
strA = collect(strides(A))
A_tblis = tblis_tensor(conjA == :C ? conj(A) : A, szA, strA, α)
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)

einA, einC = TensorOperations.trace_labels(pC, pA...)
einA, einC = TensorOperations.trace_labels(p, q)

tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))

Expand Down
34 changes: 19 additions & 15 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ using TensorOperationsTBLIS
using Test
using LinearAlgebra: norm

const tblisbackend = tblisBackend()
@testset "elementary operations" verbose = true begin
@testset "tensorcopy" begin
A = randn(Float32, (3, 5, 4, 6))
@tensor C1[4, 1, 3, 2] := A[1, 2, 3, 4]
@tensor backend = tblis C2[4, 1, 3, 2] := A[1, 2, 3, 4]
@tensor backend = tblisbackend C2[4, 1, 3, 2] := A[1, 2, 3, 4]
@test C2 ≈ C1
end

Expand All @@ -16,49 +17,50 @@ using LinearAlgebra: norm
B = randn(Float32, (5, 6, 3, 4))
α = randn(Float32)
@tensor C1[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d]
@tensor backend = tblis C2[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d]
@tensor backend = tblisbackend C2[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d]
@test collect(C2) ≈ C1

C = randn(ComplexF32, (5, 6, 3, 4))
D = randn(ComplexF32, (5, 3, 4, 6))
β = randn(ComplexF32)
@tensor E1[a, b, c, d] := C[a, b, c, d] + β * conj(D[a, c, d, b])
@tensor backend = tblis E2[a, b, c, d] := C[a, b, c, d] + β * conj(D[a, c, d, b])
@tensor backend = tblisbackend E2[a, b, c, d] := C[a, b, c, d] +
β * conj(D[a, c, d, b])
@test collect(E2) ≈ E1
end

@testset "tensortrace" begin
A = randn(Float32, (5, 10, 10))
@tensor B1[a] := A[a, b′, b′]
@tensor backend = tblis B2[a] := A[a, b′, b′]
@tensor backend = tblisbackend B2[a] := A[a, b′, b′]
@test B2 ≈ B1

C = randn(ComplexF32, (3, 20, 5, 3, 20, 4, 5))
@tensor D1[e, a, d] := C[a, b, c, d, b, e, c]
@tensor backend = tblis D2[e, a, d] := C[a, b, c, d, b, e, c]
@tensor backend = tblisbackend D2[e, a, d] := C[a, b, c, d, b, e, c]
@test D2 ≈ D1

@tensor D3[a, e, d] := conj(C[a, b, c, d, b, e, c])
@tensor backend = tblis D4[a, e, d] := conj(C[a, b, c, d, b, e, c])
@tensor backend = tblisbackend D4[a, e, d] := conj(C[a, b, c, d, b, e, c])
@test D4 ≈ D3

α = randn(ComplexF32)
@tensor D5[d, e, a] := α * C[a, b, c, d, b, e, c]
@tensor backend = tblis D6[d, e, a] := α * C[a, b, c, d, b, e, c]
@tensor backend = tblisbackend D6[d, e, a] := α * C[a, b, c, d, b, e, c]
@test D6 ≈ D5
end

@testset "tensorcontract" begin
A = randn(Float32, (3, 20, 5, 3, 4))
B = randn(Float32, (5, 6, 20, 3))
@tensor C1[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g]
@tensor backend = tblis C2[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g]
@tensor backend = tblisbackend C2[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g]
@test C2 ≈ C1

D = randn(ComplexF64, (3, 3, 3))
E = rand(ComplexF64, (3, 3, 3))
@tensor F1[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f])
@tensor backend = tblis F2[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f])
@tensor backend = tblisbackend F2[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f])
@test F2 ≈ F1 atol = 1e-12
end
end
Expand All @@ -72,12 +74,14 @@ end
# α = 1

@tensor D1[d, f, h] := A[c, a, f, a, e, b, b, g] * B[c, h, g, e, d] + α * C[d, h, f]
@tensor backend = tblis D2[d, f, h] := A[c, a, f, a, e, b, b, g] * B[c, h, g, e, d] +
α * C[d, h, f]
@tensor backend = tblisbackend D2[d, f, h] := A[c, a, f, a, e, b, b, g] *
B[c, h, g, e, d] +
α * C[d, h, f]
@test D2 ≈ D1 rtol = 1e-8

@test norm(vec(D1)) ≈ sqrt(abs(@tensor D1[d, f, h] * conj(D1[d, f, h])))
@test norm(D2) ≈ sqrt(abs(@tensor backend = tblis D2[d, f, h] * conj(D2[d, f, h])))
@test norm(D2) ≈
sqrt(abs(@tensor backend = tblisbackend D2[d, f, h] * conj(D2[d, f, h])))

@testset "readme example" begin
α = randn()
Expand All @@ -90,7 +94,7 @@ end
D[a, b, c] = A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
E[a, b, c] := A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
end
@tensor backend = tblis begin
@tensor backend = tblisbackend begin
D2[a, b, c] = A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
E2[a, b, c] := A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
end
Expand All @@ -113,7 +117,7 @@ end
HrA12[a, s1, s2, c] := ρₗ[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * ρᵣ[c', c] *
H[s1, s2, t1, t2]
end
@tensor backend = tblis begin
@tensor backend = tblisbackend begin
HrA12′[a, s1, s2, c] := ρₗ[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * ρᵣ[c', c] *
H[s1, s2, t1, t2]
end
Expand All @@ -123,7 +127,7 @@ end
E1 = ρₗ[a', a] * A1[a, s, b] * A2[b, s', c] * ρᵣ[c, c'] * H[t, t', s, s'] *
conj(A1[a', t, b']) * conj(A2[b', t', c'])
end
@tensor backend = tblis begin
@tensor backend = tblisbackend begin
E2 = ρₗ[a', a] * A1[a, s, b] * A2[b, s', c] * ρᵣ[c, c'] * H[t, t', s, s'] *
conj(A1[a', t, b']) * conj(A2[b', t', c'])
end
Expand Down

0 comments on commit ad54f62

Please sign in to comment.