Skip to content

Commit

Permalink
Plotting for ZXDiagrams and ZXGraphs, Equality for ZXDiagrams and con…
Browse files Browse the repository at this point in the history
…verting ZX to ZXW (#102)

* Vega and Compose Plotting and Tutorial

* fix ploting when there are no gates

replace index based access with get and a default value

* test ploting for all existing ZXDiagrams

* Add Multigraph ZXDiagram to tutorial.jl

* add research notebook

* convert plot_vega into an extension

* fix src and dst bugs

* add plots for all ZX test

* integrate equality

* fix id match and rewrite rule

* cvd accesible colors for plot

* better plot ux

* ZXW from IR import

* fix ir tests by shadowing instrinsic operations

* cleanup repo

* update notebooks

* set default indentation and margin

* apply default indent of 4

* use upstream packages

* indent gates_to_circ function

* test plot for ZXGraph

* use upstream packages in notebooks

* creat test for BlockIR -> ZXWDiagram -> Matrix

* correct names of Pkgs

* readd phase test with CodeIR

* improve testcoverage

* add more docstrings

* use builtin pluto env

* remove formater

* update of Pkg

* correct us of isnothing

* use pluto env

* fix if expression for id rule

* Minor Notebook improvements

* remove Documentation.yml

* fix indentation again

---------

Co-authored-by: liam <[email protected]>
Co-authored-by: Liam <[email protected]>
  • Loading branch information
3 people authored Jul 12, 2024
1 parent 146e39b commit a324b54
Show file tree
Hide file tree
Showing 30 changed files with 4,166 additions and 296 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
/docs/src/assets/indigo.css
/docs/Manifest.toml
.vscode
.JuliaFormatter.toml
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
YaoHIR = "6769671a-fce8-4286-b3f7-6099e1b1298a"
YaoLocations = "66df03fb-d475-48f7-b449-3d9064bf085b"

[weakdeps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Vega = "239c3e63-733f-47ad-beb7-a12fde22c578"

[extensions]
ZXCalculusExt = ["Vega", "DataFrames"]

[compat]
Expronicon = "0.10.3"
Graphs = "1"
Expand All @@ -22,11 +29,12 @@ Multigraphs = "0.3"
OMEinsum = "0.7, 0.8"
YaoHIR = "0.2"
YaoLocations = "0.1"
julia = "1.9"
julia = ">= 1.9"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Vega = "239c3e63-733f-47ad-beb7-a12fde22c578"

[targets]
test = ["Test", "Documenter"]
test = ["Test", "Documenter", "Vega", "DataFrames"]
266 changes: 266 additions & 0 deletions ext/ZXCalculusExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
module ZXCalculusExt

using Graphs
import Graphs: AbstractEdge, src, dst
using Vega, DataFrames
using ZXCalculus, ZXCalculus.ZX
using ZXCalculus: ZX

function spider_type_string(st1)
st1 == SpiderType.X && return "X"
st1 == SpiderType.Z && return "Z"
st1 == SpiderType.H && return "H"
st1 == SpiderType.Out && return "Out"
st1 == SpiderType.In && return "In"
end

function generate_d_spiders(vs, st, ps, x_locs_normal, y_locs_normal)
return DataFrame(
id = [v for v in vs],
x = [get(x_locs_normal, v, nothing) for v in vs],
y = [get(y_locs_normal, v, nothing) for v in vs],
spider_type = [spider_type_string(st[v]) for v in vs],
phase = [iszero(ps[v]) ? "" : "$(ps[v])" for v in vs],
)
end

function generate_d_edges(zxd::ZXDiagram)
s = Int[]
d = Int[]
isH = Bool[]
for e in ZXCalculus.ZX.edges(zxd.mg)
push!(s, src(e))
push!(d, dst(e))
push!(isH, false)
end
return DataFrame(src = s, dst = d, isHadamard = isH)
end
function generate_d_edges(zxd::ZXGraph)
s = Int[]
d = Int[]
isH = Bool[]
for e in ZXCalculus.ZX.edges(zxd.mg)
push!(s, src(e))
push!(d, dst(e))
push!(isH, ZXCalculus.ZX.is_hadamard(zxd, src(e), dst(e)))
end
return DataFrame(src = s, dst = d, isHadamard = isH)
end

function ZXCalculus.ZX.plot(zxd::Union{ZXDiagram,ZXGraph}; kwargs...)
scale = 2
lattice_unit = 50 * scale
zxd = copy(zxd)
ZXCalculus.ZX.generate_layout!(zxd)
vs = spiders(zxd)
x_locs = zxd.layout.spider_col
x_min = minimum(values(x_locs), init = 0)
x_max = maximum(values(x_locs), init = 1)
x_range = (x_max - x_min) * lattice_unit
y_locs = zxd.layout.spider_q
y_min = minimum(values(y_locs), init = 0)
y_max = maximum(values(y_locs), init = 1)
y_range = (y_max - y_min) * lattice_unit
x_locs_normal = copy(x_locs)
for (k, v) in x_locs_normal
x_locs_normal[k] = v * lattice_unit
end
y_locs_normal = copy(y_locs)
for (k, v) in y_locs_normal
y_locs_normal[k] = v * lattice_unit
end

st = zxd.st
ps = zxd.ps

d_spiders = generate_d_spiders(vs, st, ps, x_locs_normal, y_locs_normal)
d_edges = generate_d_edges(zxd)

spec = @vgplot(
$schema = "https://vega.github.io/schema/vega/v5.json",
height = y_range,
width = x_range,
padding = 0.5 * lattice_unit,
marks = [
{
encode = {
update = {strokeWidth = {signal = "edgeWidth"}, path = {field = "path"}},
enter = {stroke = {field = "color"}},
},
from = {data = "edges"},
type = "path",
},
{
encode = {
update = {
stroke = {value = "black"},
x = {field = "x"},
strokeWidth = {signal = "strokeWidth"},
size = {signal = "spiderSize"},
y = {field = "y"},
},
enter = {shape = {field = "shape"}, fill = {field = "color"}},
},
from = {data = "spiders"},
type = "symbol",
},
{
encode = {
update = {
align = {value = "center"},
x = {field = "x"},
ne = {value = "top"},
opacity = {signal = "showIds"},
y = {field = "y"},
fontSize = {value = 6 * lattice_unit / 50},
dy = {value = 18 * lattice_unit / 50},
},
enter = {fill = {value = "lightgray"}, text = {field = "id"}},
},
from = {data = "spiders"},
type = "text",
},
{
encode = {
update = {
align = {value = "center"},
x = {field = "x"},
dy = {value = lattice_unit / 50},
baseline = {value = "middle"},
opacity = {signal = "showPhases"},
fontSize = {value = 6 * lattice_unit / 50},
y = {field = "y"},
},
enter = {fill = {value = "black"}, text = {field = "phase"}},
},
from = {data = "spiders"},
type = "text",
},
],
data = [
{
name = "spiders",
values = d_spiders,
on = [{
modify = "whichSymbol",
values = "newLoc && {x: newLoc.x, y: newLoc.y}",
trigger = "newLoc",
}],
transform = [
{
as = "shape",
expr = "datum.spider_type === 'Z' ? 'circle' : (datum.spider_type === 'X' ? 'circle' : (datum.spider_type === 'H' ? 'square' : 'circle'))",
type = "formula",
},
{
as = "color",
expr = "datum.spider_type === 'Z' ? '#D8F8D8' : (datum.spider_type === 'X' ? '#E8A5A5' : (datum.spider_type === 'H' ? 'yellow' : '#9558B2'))",
type = "formula",
},
],
},
{
name = "edges",
values = d_edges,
transform = [
{
key = "id",
fields = ["src", "dst"],
as = ["source", "target"],
from = "spiders",
type = "lookup",
},
{
targetX = "target.x",
shape = {signal = "shape"},
sourceX = "source.x",
targetY = "target.y",
type = "linkpath",
sourceY = "source.y",
orient = {signal = "orient"},
},
{
as = "color",
expr = "datum.isHadamard ? '#4063D8' : 'black'",
type = "formula",
},
],
},
],
signals = [
{name = "showIds", bind = {input = "checkbox"}, value = true},
{name = "showPhases", bind = {input = "checkbox"}, value = true},
{
name = "spiderSize",
bind = {
step = lattice_unit / 5,
max = 40 * lattice_unit,
min = 2 * lattice_unit,
input = "range",
},
value = 20 * lattice_unit,
},
{
name = "strokeWidth",
bind = {
step = 0.001 * lattice_unit,
max = 3 * lattice_unit / 50,
min = 0,
input = "range",
},
value = 1.5 * lattice_unit / 50,
},
{
name = "edgeWidth",
bind = {
step = 0.001 * lattice_unit,
max = 3 * lattice_unit / 50,
min = 0.002 * lattice_unit,
input = "range",
},
value = 1.5 * lattice_unit / 50,
},
{
name = "orient",
bind = {options = ["horizontal", "vertical"], input = "select"},
value = "horizontal",
},
{
name = "shape",
bind = {
options = ["line", "arc", "curve", "diagonal", "orthogonal"],
input = "select",
},
value = "diagonal",
},
{
name = "whichSymbol",
on = [
{events = "symbol:mousedown", update = "datum"},
{events = "*:mouseup", update = "{}"},
],
value = {},
},
{
name = "newLoc",
on = [
{events = "symbol:mouseout[!event.buttons], window:mouseup", update = "false"},
{events = "symbol:mouseover", update = "{x: x(), y: y()}"},
{
events = "[symbol:mousedown, window:mouseup] > window:mousemove!",
update = "{x: x(), y: y()}",
},
],
value = false,
},
]
)
return spec
end






end
Loading

0 comments on commit a324b54

Please sign in to comment.