Tensor unfolding + matrix product @ for apply gate #1710

@BrunoLiegiBastonLiegi

Description

While profiling the qiboml backends, I noticed that for a large number of qubits (e.g. 15) the einsum performed in backend.apply_gate takes up most of the circuit execution time. I have therefore been experimenting with tensor unfolding as a replacement for the einsum contraction we currently use. In detail, if you move the contracted axes to the front of the state tensor and then unfold (reshape) it, you obtain the same contraction with just a plain matrix product @ (naturally you also have to re-fold the tensor and move the axes back afterwards).
I ran some tests with random gate matrices and random 15-qubit states, and this approach seems to provide a substantial speedup over einsum. These are the average results over 50 random matrices/states with pytorch (I tried numpy as well; its einsum appears to be even slower, probably because it does not use opt_einsum by default):
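As a minimal illustration of the equivalence, here is a sketch in plain numpy (variable names are illustrative, not the qibo API), applying a random 2-qubit gate both ways and checking the results agree:

```python
import numpy as np

nqubits, targets = 5, (1, 3)  # 2-qubit gate acting on qubits 1 and 3
shape = nqubits * (2,)
state = np.random.randn(*shape) + 1j * np.random.randn(*shape)
gate = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)

# Reference: einsum contracting the gate's input legs with the target axes.
ref = np.einsum("xybd,abcde->axcye", gate.reshape(2, 2, 2, 2), state)

# Unfolding: bring the target axes to the front, flatten to a matrix,
# do one matmul, then fold and move the axes back.
tmp = np.moveaxis(state, targets, (0, 1)).reshape(2 ** len(targets), -1)
out = np.moveaxis((gate @ tmp).reshape(shape), (0, 1), targets)

assert np.allclose(ref, out)
```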

####### 1-qubits matrix ########
 > einsum: 0.019747539659874747
 > unfold: 0.006024982599919895
####### 2-qubits matrix ########
 > einsum: 0.027757639880073837
 > unfold: 0.006263335839939828
####### 3-qubits matrix ########
 > einsum: 0.02750167398011399
 > unfold: 0.0069057601801250715

Thus roughly 3~4 times faster depending on the case; you can try it out yourself with the script:

import torch, random

from timeit import timeit
from functools import cache

from qibo.backends import einsum_utils
from qiboml.backends import PyTorchBackend


def einsum(backend, matrix, state, qubits, nqubits):
    matrix = backend.np.reshape(matrix, 2 * len(qubits) * (2,))
    opstring = einsum_utils.apply_gate_string(qubits, nqubits)
    state = backend.np.einsum(opstring, state, matrix)
    return state

@cache
def permutations(qubits, nqubits):
    # Forward permutation: bring the target-qubit axes to the front.
    fwd_perm = list(qubits) + [q for q in range(nqubits) if q not in qubits]
    # Inverse permutation: restores the original axis order afterwards.
    inv_perm, _ = zip(*sorted(zip(range(nqubits), fwd_perm), key=lambda x: x[1]))
    return fwd_perm, inv_perm

def unfolded_matrix_product(backend, matrix, state, qubits, nqubits):
    shape = state.shape
    fwd_perm, inv_perm = permutations(qubits, nqubits)
    # Move the contracted axes to the front and unfold into a matrix.
    state = backend.np.transpose(state, fwd_perm)
    state = state.reshape(2 ** len(qubits), -1)
    # The whole contraction reduces to a single matrix product.
    state = matrix @ state
    # Re-fold and move the axes back to their original positions.
    state = state.reshape(shape)
    state = backend.np.transpose(state, inv_perm)
    return state


def bench(backend, state, matrix, gate_qubits, nqubits):
    qubits = tuple(random.sample(range(nqubits), k=gate_qubits))
    einsum_res = einsum(backend, matrix, state, qubits, nqubits)
    unfolded_res = unfolded_matrix_product(backend, matrix, state, qubits, nqubits)
    diff = backend.np.abs((einsum_res - unfolded_res).ravel())
    var = locals()
    var["einsum"] = einsum
    var["unfolded_matrix_product"] = unfolded_matrix_product
    try:
        assert max(diff) < 1e-6
    except AssertionError:
        print(qubits)
        print(diff)
        print(max(diff))
    return (
        timeit("einsum(backend, matrix, state, qubits, nqubits)", globals=var, number=100),
        timeit("unfolded_matrix_product(backend, matrix, state, qubits, nqubits)", globals=var, number=100)
    )

if __name__ == "__main__":

    backend = PyTorchBackend()
    nqubits = 15
    dtype = torch.complex128
    state = torch.randn(2**nqubits).reshape(nqubits * (2,)).type(dtype)
    matrix_1q = torch.randn(2, 2).type(dtype)
    matrix_2q = torch.randn(4, 4).type(dtype)
    matrix_3q = torch.randn(8, 8).type(dtype)

    for n, m in enumerate([matrix_1q, matrix_2q, matrix_3q], start=1):
        print(f"####### {n}-qubits matrix ########")
        einsum_times, unfold_times = [], []
        N = 50
        for _ in range(N):
            res = bench(backend, state, m, n, nqubits)
            einsum_times.append(res[0])
            unfold_times.append(res[1])
        print(f" > einsum: {sum(einsum_times) / N}")
        print(f" > unfold: {sum(unfold_times) / N}")

I would therefore advocate for:

  • replacing the einsum with unfolding + @
  • avoiding reshaping the state every time a gate is applied: keep the state in (2, 2, 2, ...) form throughout and reshape only at the beginning and end of circuit execution
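For the second point, a sketch of what that could look like (plain numpy, hypothetical helper name, not the qibo API): the state is reshaped to (2,)*nqubits once, every gate application works directly on that form, and the flat vector is recovered only at the end.

```python
import numpy as np

def apply_gate(state, gate, targets, nqubits):
    """Apply `gate` to `targets` on a state kept in (2,)*nqubits form."""
    k = len(targets)
    state = np.moveaxis(state, targets, tuple(range(k)))
    state = gate @ state.reshape(2**k, -1)
    state = state.reshape(nqubits * (2,))
    return np.moveaxis(state, tuple(range(k)), targets)

nqubits = 4
psi = np.zeros(2**nqubits, dtype=complex)
psi[0] = 1.0
psi = psi.reshape(nqubits * (2,))   # reshape once, at the start
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
for q in range(nqubits):
    psi = apply_gate(psi, H, (q,), nqubits)
psi = psi.reshape(-1)               # and once, at the end
# H applied to every qubit of |0000> gives the uniform superposition:
# all 16 amplitudes equal 1/4
```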
