Description
While profiling qiboml backends, I noticed that for a large number of qubits (e.g. 15) the einsum performed in `backend.apply_gate` takes the majority of the time in circuit execution. Therefore, I've been playing around with tensor unfolding as a replacement for the einsum contraction we use right now. In detail, if you move the contracted axes to the front of the state tensor and then unfold it, you obtain the same contraction with just a plain matrix product `@` (naturally, you also have to re-fold the tensor and move the axes back afterwards).
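For clarity, here is the trick in isolation, as a minimal standalone sketch in plain torch (a toy example, not the qiboml backend API):

```python
import torch

# Toy check of the unfold trick: apply a random single-qubit gate to qubit 1
# of a 3-qubit state, once with einsum and once with permute + reshape + @.
state = torch.randn(2, 2, 2, dtype=torch.complex128)
gate = torch.randn(2, 2, dtype=torch.complex128)

# reference: contract the gate index with axis 1 of the state
ref = torch.einsum("ij,ajc->aic", gate, state)

# unfold: bring axis 1 to the front, flatten the remaining axes into columns,
# apply the gate as a plain matrix product, then fold back and restore order
# (the permutation (1, 0, 2) is its own inverse here)
out = (gate @ state.permute(1, 0, 2).reshape(2, -1)).reshape(2, 2, 2).permute(1, 0, 2)

assert torch.allclose(ref, out)
```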
I've done some tests with random gate matrices and random states of 15 qubits, and this approach seems to provide a substantial speedup compared to einsum. These are the average results for 50 random matrices/states with pytorch (I tried numpy as well, whose einsum appears to be even slower, probably because it does not use opt_einsum by default):
```
####### 1-qubits matrix ########
 > einsum: 0.019747539659874747
 > unfold: 0.006024982599919895
####### 2-qubits matrix ########
 > einsum: 0.027757639880073837
 > unfold: 0.006263335839939828
####### 3-qubits matrix ########
 > einsum: 0.02750167398011399
 > unfold: 0.0069057601801250715
```
That is roughly 3 to 4 times faster depending on the case; you can try it out yourself with the following script:
```python
import torch, random
from timeit import timeit
from functools import cache
from qibo.backends import einsum_utils
from qiboml.backends import PyTorchBackend


def einsum(backend, matrix, state, qubits, nqubits):
    # reference implementation: contract gate and state with an einsum,
    # as backend.apply_gate does now
    matrix = backend.np.reshape(matrix, 2 * len(qubits) * (2,))
    opstring = einsum_utils.apply_gate_string(qubits, nqubits)
    state = backend.np.einsum(opstring, state, matrix)
    return state


@cache
def permutations(qubits, nqubits):
    # forward permutation moves the contracted axes to the front;
    # the inverse permutation restores the original axis order
    fwd_perm = list(qubits) + [q for q in range(nqubits) if q not in qubits]
    inv_perm = zip(list(range(nqubits)), fwd_perm)
    inv_perm, _ = list(zip(*sorted(inv_perm, key=lambda x: x[1])))
    return fwd_perm, inv_perm


def unfolded_matrix_product(backend, matrix, state, qubits, nqubits):
    # move the target axes to the front, unfold the state into a
    # (2**len(qubits), -1) matrix, apply the gate with a plain @,
    # then fold the tensor back and restore the axis order
    shape = state.shape
    fwd_perm, inv_perm = permutations(qubits, nqubits)
    state = backend.np.transpose(state, fwd_perm)
    state = state.reshape(2 ** len(qubits), -1)
    state = matrix @ state
    state = state.reshape(shape)
    state = backend.np.transpose(state, inv_perm)
    return state


def bench(backend, state, matrix, gate_qubits, nqubits):
    qubits = tuple(random.sample(range(nqubits), k=gate_qubits))
    einsum_res = einsum(backend, matrix, state, qubits, nqubits)
    unfolded_res = unfolded_matrix_product(backend, matrix, state, qubits, nqubits)
    diff = backend.np.abs((einsum_res - unfolded_res).ravel())
    var = locals()
    var["einsum"] = einsum
    var["unfolded_matrix_product"] = unfolded_matrix_product
    try:
        # check that the two contractions agree before timing them
        assert max(diff) < 1e-6
    except AssertionError:
        print(qubits)
        print(diff)
        print(max(diff))
    return (
        timeit("einsum(backend, matrix, state, qubits, nqubits)", globals=var, number=100),
        timeit("unfolded_matrix_product(backend, matrix, state, qubits, nqubits)", globals=var, number=100),
    )


if __name__ == "__main__":
    backend = PyTorchBackend()
    nqubits = 15
    dtype = torch.complex128
    state = torch.randn(2**nqubits).reshape(nqubits * (2,)).type(dtype)
    matrix_1q = torch.randn(2, 2).type(dtype)
    matrix_2q = torch.randn(4, 4).type(dtype)
    matrix_3q = torch.randn(8, 8).type(dtype)
    for n, m in enumerate([matrix_1q, matrix_2q, matrix_3q], start=1):
        print(f"####### {n}-qubits matrix ########")
        einsum_times, unfold_times = [], []
        N = 50
        for _ in range(N):
            res = bench(backend, state, m, n, nqubits)
            einsum_times.append(res[0])
            unfold_times.append(res[1])
        print(f" > einsum: {sum(einsum_times) / N}")
        print(f" > unfold: {sum(unfold_times) / N}")
```

I would therefore advocate for:
- replacing the `einsum` with unfolding + `@`
- avoiding reshaping the state every time a gate is applied: always work with a state of shape `(2, 2, 2, ...)` and do the reshape just at the beginning and end of the circuit execution (see the sketch below)
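As a rough illustration of the second point, here is a minimal sketch of what an execution loop could look like if the state is kept in tensored form throughout; `execute` and the `(matrix, qubits)` gate description are hypothetical, not qiboml API:

```python
import torch

def execute(gates, nqubits):
    """Hypothetical execution loop: `gates` is a list of (matrix, qubits)
    pairs, where matrix is a (2**k, 2**k) tensor acting on `qubits`."""
    state = torch.zeros(2**nqubits, dtype=torch.complex128)
    state[0] = 1.0
    state = state.reshape(nqubits * (2,))  # reshape once, at the start
    for matrix, qubits in gates:
        # unfold + @ as above; the state stays in (2, 2, ...) form throughout
        fwd = list(qubits) + [q for q in range(nqubits) if q not in qubits]
        inv = [fwd.index(q) for q in range(nqubits)]
        state = state.permute(fwd).reshape(2 ** len(qubits), -1)
        state = (matrix @ state).reshape(nqubits * (2,)).permute(inv)
    return state.reshape(-1)  # flatten once, at the end

# usage: Hadamards on qubits 0 and 2 of a 3-qubit circuit
H = torch.tensor([[1, 1], [1, -1]], dtype=torch.complex128) / 2**0.5
final = execute([(H, (0,)), (H, (2,))], nqubits=3)
```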