Skip to content

Commit

Permalink
Move away from Luxor for visualization (#34)
Browse files Browse the repository at this point in the history
* Refactor visualization code

* Set compat bounds for `Cobweb`

* Make block fill transparent

Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com>

* Rename `DRAWING_STYLE` to `DEFAULT_STYLE`

* Rename `draw` functions

* Refactor code

---------

Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com>
  • Loading branch information
mofeing and jofrevalles authored Jun 29, 2023
1 parent e9a8626 commit 23c1814
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 127 deletions.
10 changes: 2 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,13 @@ authors = ["Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>"]
version = "0.2.1"

[deps]
Cobweb = "ec354790-cf28-43e8-bb59-b484409b7bad"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Luxor = "ae8d54c2-7ccd-5906-9d76-62fc9837b5bc"
MathTeXEngine = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"

[compat]
Cobweb = "0.5"
Combinatorics = "1.0.0"
LaTeXStrings = "1.3"
Luxor = "3.5"
MathTeXEngine = "0.5"
Requires = "1.0"
SimpleWeightedGraphs = "1.2"
julia = "1.8"
8 changes: 1 addition & 7 deletions src/Quac.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
module Quac

using Requires: @require

include("Gate.jl")
include("Array.jl")
include("Circuit.jl")
include("Algorithms.jl")

function __init__()
@require Luxor = "ae8d54c2-7ccd-5906-9d76-62fc9837b5bc" include("Visualization.jl")
@require VSCodeServer = "9f5989ce-84fe-42d4-91ec-6a7a8d53ed0f" using Luxor
end
include("Visualization.jl")

end
245 changes: 133 additions & 112 deletions src/Visualization.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,74 @@
using Luxor
using MathTeXEngine
using LaTeXStrings
using Cobweb: h

texname(::Type{Op}) where {Op<:Operator} = LaTeXString(String(nameof(Op)))
function svg end

texname(::Type{Sd}) = L"S^\dagger"
texname(::Type{Td}) = L"T^\dagger"
Base.show(io::IO, ::MIME"image/svg+xml", circuit::Circuit) = print(io, svg(circuit))

texname(::Type{Rx}) = L"R_X"
texname(::Type{Ry}) = L"R_Y"
texname(::Type{Rz}) = L"R_Z"
texname(::Type{Op}) where {Op<:Operator} = String(nameof(Op))

texname(::Type{Hz}) = L"H_Z"
texname(::Type{Sd}) = """S<tspan font-size="60%" baseline-shift="super">†</tspan>"""
texname(::Type{Td}) = """T<tspan font-size="60%" baseline-shift="super">†</tspan>"""

texname(::Type{FSim}) = L"F_S"
texname(::Type{Rx}) = """R<tspan font-size="60%" baseline-shift="sub">X</tspan>"""
texname(::Type{Ry}) = """R<tspan font-size="60%" baseline-shift="sub">Y</tspan>"""
texname(::Type{Rz}) = """R<tspan font-size="60%" baseline-shift="sub">Z</tspan>"""

function draw end
export draw
texname(::Type{Hz}) = """H<tspan font-size="60%" baseline-shift="sub">Z</tspan>"""

function draw(circuit::Circuit; kwargs...)
texname(::Type{FSim}) = """F<tspan font-size="60%" baseline-shift="sub">S</tspan>"""

const DEFAULT_STYLE = h.style(
"""
.wire {
stroke: black;
stroke-width: 2px;
}
.lane {}
.virtual {}
.block {
stroke: black;
fill: transparent;
}
text {
text-anchor: middle;
dominant-baseline: central;
}
""",
type = "text/css",
)

function __svg_vcat_blocks(blocks...)
container = h.svg(width = maximum(x -> x.width, blocks), height = sum(x -> parse(Int, x.height), blocks))

for (block, y) in zip(blocks, Iterators.flatten([0, cumsum(Iterators.map(x -> parse(Int, x.height), blocks))]))
block.y = y
push!(container, block)
end

return container
end

function __svg_hcat_blocks(blocks...)
container = h.svg(height = maximum(x -> x.height, blocks), width = sum(x -> parse(Int, x.width), blocks))

for (block, x) in zip(blocks, Iterators.flatten([0, Iterators.map(x -> parse(Int, x.width), blocks)]))
block.x = x
push!(container, block)
end

return container
end

function svg(circuit::Circuit; kwargs...)
n = lanes(circuit)

if isempty(circuit)
return vcat([draw(Gate{I}(lane); kwargs...) for lane in 1:n]...)
svg = __svg_vcat_blocks([svg(Gate{I}(lane); kwargs...) for lane in 1:n]...)
push!(svg, DEFAULT_STYLE)
return svg
end

# split moments if gates overlap in 1D
Expand All @@ -48,148 +94,123 @@ function draw(circuit::Circuit; kwargs...)
return ms
end |> Iterators.flatten |> collect

return mapreduce(hcat, _moments) do moment
svg = mapreduce(__svg_hcat_blocks, _moments) do moment
(min, max) = extrema(mapreduce(lanes, , moment))
moment = [map(I, filter(<(min), 1:n))..., moment..., map(I, filter(>(max), 1:n))...]

mapreduce(x -> draw(x; kwargs...), vcat, moment)
mapreduce(x -> svg(x; kwargs...), __svg_vcat_blocks, moment)
end
end

function Base.show(io::IO, ::MIME"image/svg+xml", circuit::Circuit)
_ = draw(circuit)
print(io, svgstring())
end
push!(svg, DEFAULT_STYLE)

function draw(gate::Gate{Op,1,P}; kwargs...) where {Op,P}
draw_block(; top = false, bottom = false)
return svg
end

function draw(gate::Gate{Op,N,P}; kwargs...) where {Op,N,P}
svg(gate::Gate{Op,1,P}) where {Op,P} = __svg_block(; top = false, bottom = false)

function svg(gate::Gate{Op,N,P}) where {Op,N,P}
a, b = extrema(lanes(gate))
n = b - a + 1
vcat(
draw_block(; top = true, bottom = false, kwargs...),
fill(draw_multiblock_mid(; kwargs...), (n - 2))...,
draw_block(; top = false, bottom = true, kwargs...),
__svg_vcat_blocks(
__svg_block(; top = true, bottom = false),
fill(__svg_multiblock_mid(), (n - 2))...,
__svg_block(; top = false, bottom = true),
)
end

function draw(::Gate{I,1,NamedTuple{(),Tuple{}}}; background = nothing)
@drawsvg begin
(background !== nothing) && Luxor.background(background)
origin()
line(Point(-25, 0), Point(25, 0), action = :stroke)
end 50 50
end
svg(::Gate{I,1,NamedTuple{(),Tuple{}}}) =
h.svg(h.line."wire lane"(; x1 = -25, y1 = 0, x2 = 25, y2 = 0), viewBox = "-25 -25 50 50", width = 50, height = 50)

for Op in [X, Y, Z, H, S, Sd, T, Td, Rx, Ry, Rz, Hz, FSim]
@eval draw(::Gate{$Op,1,P}; kwargs...) where {P} = draw_block(texname($Op); kwargs...)
@eval svg(::Gate{$Op,1,P}; kwargs...) where {P} = __svg_block(texname($Op); kwargs...)
end

function draw(gate::Gate{<:Control}; kwargs...)
function svg(gate::Gate{<:Control}; kwargs...)
c = control(gate)
t = target(gate)
r = range(extrema(lanes(gate))...)

# TODO Control{Swap}
@assert length(t) == 1

vcat(
__svg_vcat_blocks(
[
if lane == first(r)
draw_copy(:top; kwargs...)
__svg_copy(:top)
elseif lane c
draw_copy(:mid; kwargs...)
__svg_copy(:mid)
else
draw_cross(; kwargs...)
__svg_cross()
end for lane in r if lane < only(t)
]...,
draw(
svg(
Gate{targettype(operator(gate))}(target(gate)...; parameters(gate)...);
top = !any(<(only(t)), c),
bottom = !any(>(only(t)), c),
kwargs...,
),
[
if lane == last(r)
draw_copy(:bottom; kwargs...)
__svg_copy(:bottom)
elseif lane c
draw_copy(:mid; kwargs...)
__svg_copy(:mid)
else
draw_cross(; kwargs...)
__svg_cross()
end for lane in r if lane > only(t)
]...,
)
end

function draw_block(label = ""; top::Bool = false, bottom::Bool = false, background = nothing)
@drawsvg begin
(background !== nothing) && Luxor.background(background)
origin()

# lane wire
line(Point(-25, 0), Point(-15, 0), action = :stroke)
line(Point(25, 0), Point(15, 0), action = :stroke)

# control connectors
if top
line(Point(0, 25), Point(0, 15), action = :stroke)
end
if bottom
line(Point(0, -25), Point(0, -15), action = :stroke)
end

rect(-15, -15, 30, 30, action = :stroke)

# label
fontsize(16)
text(label, Point(0, 0), valign = :middle, halign = :center)
end 50 50
end

function draw_multiblock_mid(; background = nothing)
@drawsvg begin
(background !== nothing) && Luxor.background(background)
origin()

# lane wire
line(Point(-25, 0), Point(-15, 0), action = :stroke)
line(Point(25, 0), Point(15, 0), action = :stroke)

# vertical lines
line(Point(-25, 0), Point(25, 0), action = :stroke)
line(Point(0, -25), Point(0, 25), action = :stroke)
end 50 50
function __svg_block(label = ""; top::Bool = false, bottom::Bool = false)
drawing = h.svg(
h.line."wire lane"(x1 = -25, y1 = 0, x2 = -15, y2 = 0),
h.line."wire lane"(x1 = 25, y1 = 0, x2 = 15, y2 = 0),
h.rect."block"(x = -15, y = -15, width = 30, height = 30),
h.text."label"(label, x = 0, y = 0),
viewBox = "-25 -25 50 50",
width = 50,
height = 50,
)
top && push!(drawing, h.line."wire virtual"(x1 = 0, y1 = 25, x2 = 0, y2 = 15))
bottom && push!(drawing, h.line."wire virtual"(x1 = 0, y1 = -25, x2 = 0, y2 = -15))
return drawing
end

function draw_cross(; background = nothing)
@drawsvg begin
(background !== nothing) && Luxor.background(background)
origin()

line(Point(-25, 0), Point(25, 0), action = :stroke)
line(Point(0, -25), Point(0, 25), action = :stroke)
end 50 50
end
__svg_multiblock_mid() = h.svg(
h.line."wire lane"(x1 = -25, y1 = 0, x2 = -15, y2 = 0),
h.line."wire lane"(x1 = 25, y1 = 0, x2 = 15, y2 = 0),
h.line(x1 = -25, y1 = 0, x2 = 25, y2 = 0), # TODO assign class. fill?
h.line(x1 = 0, y1 = -25, x2 = 0, y2 = 25), # TODO assign class. fill?
viewBox = "-25 -25 50 50",
width = 50,
height = 50,
)

__svg_cross() = h.svg(
h.line."wire lane"(x1 = -25, y1 = 0, x2 = 25, y2 = 0),
h.line."wire virtual"(x1 = 0, y1 = -25, x2 = 0, y2 = 25),
viewBox = "-25 -25 50 50",
width = 50,
height = 50,
)

function __svg_copy(dir::Symbol)
(a, b) = if dir === :top
0, 25
elseif dir === :bottom
0, -25
elseif dir === :mid
25, -25
else
throw(ArgumentError("`dir`=$dir is invalid"))
end

function draw_copy(dir::Symbol; background = nothing)
@drawsvg begin
(background !== nothing) && Luxor.background(background)
origin()
line(Point(-25, 0), Point(25, 0), action = :stroke)

circle(0, 0, 5, action = :fill)

(a, b) = if dir == :top
Point(0, 0), Point(0, 25)
elseif dir == :bottom
Point(0, 0), Point(0, -25)
elseif dir == :mid
Point(0, 25), Point(0, -25)
else
throw(ArgumentError("`dir`=$dir is invalid"))
end
line(a, b, action = :stroke)
end 50 50
h.svg(
h.line."wire lane"(x1 = -25, y1 = 0, x2 = 25, y2 = 0),
h.circle."copy"(cx = 0, cy = 0, r = 5),
h.line."wire virtual"(x1 = 0, y1 = a, x2 = 0, y2 = b),
viewBox = "-25 -25 50 50",
width = 50,
height = 50,
)
end

0 comments on commit 23c1814

Please sign in to comment.