MWE:

```python
In [1]: import exponax as ex

In [2]: import jax

In [3]: a = jax.random.normal(jax.random.key(0), (2, 1, 100))

In [4]: ex.metrics.sRMSE(a[0], a[1])
Out[4]: Array(1.5277674, dtype=float32)
```
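To cross-check the value above without going through the library, a symmetric RMSE can be computed by hand. The sketch below uses a hypothetical helper, `symmetric_rmse`. It assumes that sRMSE is defined as 2 * ||u_pred - u_ref|| / (||u_pred|| + ||u_ref||), with norms taken over all entries of a single sample. That definition is an assumption, not taken from exponax, so compare it with the `ex.metrics.sRMSE` docstring before relying on it. Under this definition, any constant factor from the domain extent or grid spacing cancels in the ratio.

```python
import numpy as np


def symmetric_rmse(u_pred: np.ndarray, u_ref: np.ndarray) -> float:
    """Hypothetical reference: 2 * ||u_pred - u_ref|| / (||u_pred|| + ||u_ref||).

    The definition is assumed, not taken from exponax. Norms are plain
    Euclidean norms over all entries, so a grid-spacing factor would cancel.
    """
    diff_norm = np.linalg.norm(u_pred - u_ref)
    sum_norm = np.linalg.norm(u_pred) + np.linalg.norm(u_ref)
    return float(2.0 * diff_norm / sum_norm)


# Same shapes as the MWE: two samples, one channel, 100 grid points.
# NumPy's RNG differs from JAX's, so this input is NOT the MWE's `a`.
rng = np.random.default_rng(0)
a = rng.standard_normal((2, 1, 100))
print(symmetric_rmse(a[0], a[1]))
```

To compare directly against `Out[4]`, call `symmetric_rmse(np.asarray(a[0]), np.asarray(a[1]))` inside the MWE session, where `a` is the JAX array. Under the assumed definition, the triangle inequality bounds the result to [0, 2].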