Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] symbolic algebra functionality #14

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,17 @@ uuid = "c7a6d0f7-daa6-4368-ba67-89ed64127c3b"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.1"

[deps]
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
LightSumTypes = "f56206fc-af4c-5561-a72a-43fe2ca5a923"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[compat]
Dictionaries = "0.4.4"
LightSumTypes = "5.0.1"
SymbolicUtils = "3.23.0"
TermInterface = "2.0.0"
VectorInterface = "0.5.0"
julia = "1.10"
18 changes: 13 additions & 5 deletions src/QuantumOperatorAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
module QuantumOperatorAlgebra

include("LazyApply/LazyApply.jl")
# Make these available as `QuantumOperatorAlgebra.f`.
using .LazyApply: coefficient, terms
export Op, LocalOp

include("op.jl")
include("trotter.jl")
using LightSumTypes
using VectorInterface
import VectorInterface: scalartype
using TermInterface
using Dictionaries

import Base: +, *, -, /, \
import Base: one, zero, isone, iszero
import Base: show, show_unquoted

include("symbolicalgebra/abstractalgebra.jl")
include("symbolicalgebra/localalgebra.jl")

end
223 changes: 223 additions & 0 deletions src/symbolicalgebra/abstractalgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
"""
SymbolicAlgebra

Abstract supertype for working with symbolic (operator) algebras.
"""
abstract type SymbolicAlgebra{T<:Number} end

# Building blocks
# ---------------
struct Op{T<:Number,A}
id::A
end

struct Scaled{O,T<:Number}
op::O
scalar::T
end

struct Sum{O,T<:Number}
terms::Dictionary{O,T}
Sum{O,T}() where {O,T} = new{O,T}(Dictionary{O,T}())
Sum{O,T}(terms::Dictionary{O,T}) where {O,T} = new{O,T}(terms)
end

struct Prod{O}
factors::Vector{O}
Prod{O}() where {O} = new{O}(O[])
Prod{O}(factors::Vector{O}) where {O} = new{O}(factors)
end

struct Kron{O}
factors::Vector{O}
Kron{O}() where {O} = new{O}(O[])
Kron{O}(factors::Vector{O}) where {O} = new{O}(factors)
end

struct Fun{O}
f::Any
args::Vector{O}
Fun{O}(f) where {O} = new{O}(f, O[])
Fun{O}(f, args::Vector{O}) where {O} = new{O}(f, args)
end


# Properties
# ----------
VectorInterface.scalartype(::Type{<:SymbolicAlgebra{T}}) where {T} = scalartype(T)
algebratype(a::SymbolicAlgebra) = algebratype(typeof(a))

# Linear algebra
# --------------

# functionality to rewrite basic operations in terms of a more limited set
(O::SymbolicAlgebra * λ::Number) = scale(O, λ)
(λ::Number * O::SymbolicAlgebra) = scale(O, λ)
(O::SymbolicAlgebra / λ::Number) = scale(O, inv(λ))
(λ::Number \ O::SymbolicAlgebra) = scale(O, inv(λ))

+(O::SymbolicAlgebra) = scale(O, one(scalartype(O)))
-(O::SymbolicAlgebra) = scale(O, -one(scalartype(O)))
(O₁::SymbolicAlgebra + O₂::SymbolicAlgebra) = add(O₁, O₂)
(O₁::SymbolicAlgebra - O₂::SymbolicAlgebra) = add(O₁, O₂, -one(scalartype(O₁)))
# (O1::SymbolicAlgebra - O2::SymbolicAlgebra) = -(promote(O1, O2)...)

# Show utility
# ------------
# functionality to display symbolic expressions
# -> expressions show have two variants: show and show_unquoted to determine whether
# the expressions should be using parentheses

"""
show_scaled(io::IO, operator, scalar)

Utility function to display a scaled operator as `scalar * operator`.
"""
function show_scaled(io::IO, operator, scalar)
if isone(scalar)
show(io, operator)
return nothing
end

if isreal(scalar) && isone(abs(scalar))
print(io, '-')
show(io, operator)
return nothing
end

show_unquoted(io, scalar, 0, Base.operator_precedence(:*))
print(io, " * ")
show_unquoted(io, operator, 0, Base.operator_precedence(:*))

return nothing
end

"""
show_scaled_unquoted(io::IO, operator, scalar, indent::Int, precedence::Int)

Utility function to display a scaled operator as `scalar * operator` within the context of
a larger expression. This function will parenthesize the scaled operator if necessary, based
on the relative precedence of `*` over `precedence`.

See also `Base.show_unquoted` and `Base.operator_precedence`.
"""
function show_scaled_unquoted(io::IO, operator, scalar, indent::Int, precedence::Int)
should_parenthesize =
!isone(scalar) &&
(!isreal(scalar) || !isone(abs(scalar))) &&
Base.operator_precedence(:*) ≤ precedence

if should_parenthesize
print(io, "(")
show_scaled(io, operator, scalar)
print(io, ")")
else
show_scaled(io, operator, scalar)
end

return nothing
end

"""
show_summed(io::IO, operators, [scalars])

Utility function to display a sum of operators as `operators[1] + operators[2] + ...`.
"""
function show_summed(io::IO, operators)
precedence = Base.operator_precedence(:+)
for (i, operator) in enumerate(operators)
if i == 1
show_unquoted(io, operator, 0, precedence)
else
print(io, " + ")
show_unquoted(io, operator, 0, precedence)
end
end
return nothing
end

function show_summed(io::IO, operators, scalars)
precedence = Base.operator_precedence(:+)

for (i, (operator, scalar)) in enumerate(zip(operators, scalars))
if i == 1
show_scaled_unquoted(io, operator, scalars[i], 0, precedence)
continue
end

# attempt to absorb the sign of the scalar
if isreal(scalar) && scalar < 0
print(io, " - ")
scalar = abs(scalar)
else
print(io, " + ")
end

show_scaled_unquoted(io, operator, scalar, 0, precedence)
end

return nothing
end

function show_summed_unquoted(io::IO, operators, indent::Int, precedence::Int)
if length(operators) == 1
show_unquoted(io, operators[1], indent, precedence)
return nothing
end

if Base.operator_precedence(:+) ≤ precedence
print(io, "(")
show_summed(io, operators)
print(io, ")")
else
show_summed(io, operators)
end

return nothing
end
function show_summed_unquoted(io::IO, operators, scalars, indent::Int, precedence::Int)
if length(operators) == 1
show_scaled_unquoted(io, only(operators), only(scalars)indent, precedence)
return nothing
end

if Base.operator_precedence(:+) ≤ precedence
print(io, "(")
show_summed(io, operators, scalars)
print(io, ")")
else
show_summed(io, operators, scalars)
end

return nothing
end

function show_product(io::IO, factors)
precedence = Base.operator_precedence(:*)

for (i, factor) in enumerate(factors)
if i == 1
show_unquoted(io, factor, 0, precedence)
else
print(io, " * ")
show_unquoted(io, factor, 0, precedence)
end
end
end

function show_product_unquoted(io::IO, factors, indent::Int, precedence::Int)
if length(operators) == 1
show_unquoted(io, only(factors), indent, precedence)
return nothing
end

if Base.operator_precedence(:*) ≤ precedence
print(io, "(")
show_prod(io, factors)
print(io, ")")
else
show_prod(io, factors)
end

return nothing
end
Loading
Loading