Description
One of the main uses of the QR decomposition A=QR is in solving the linear least squares problem Ax=b. To solve this problem, the only access we need to Q is the ability to apply it (or its transpose, or adjoint) to vectors, i.e. we need only black-box access to Q through its multiplication operators. The QR decomposition algorithm initially forms Q indirectly, as a product of Householder transformations; see jax._src.lax.geqrf, and jax._src.lax.geqp3 for the pivoted version. This form is efficient for functional access through the LAPACK function ormqr and its equivalents, but that functionality is not exposed by jax. Instead, jax exposes only the function householder_product, which wraps the LAPACK function orgqr and its equivalents. Therefore, for the application to linear least squares, we are forced to wastefully materialize the matrix Q explicitly and use the @ operator. This is done, e.g., in https://github.com/patrick-kidger/lineax/blob/main/lineax/_solver/qr.py .
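To make the distinction concrete, here is a minimal pure-JAX sketch of what "applying Q^T without forming Q" means. It assumes the real case and the LAPACK storage convention (reflector i is zero above row i, has an implicit 1 at row i, and is stored below the diagonal of column i). The name `apply_qt` and the hand-built reflectors are hypothetical, for illustration only; this is not how ormqr is implemented and it is not an existing jax function.

```python
import jax
import jax.numpy as jnp


def apply_qt(packed, taus, b):
    """Apply Q^T to b, where Q = H_0 H_1 ... H_{k-1} and
    H_i = I - taus[i] * v_i v_i^T, without ever forming Q (real case)."""
    rows = jnp.arange(packed.shape[0])

    def body(i, x):
        col = packed[:, i]
        # v_i: zero above row i, one at row i, stored entries below row i.
        v = jnp.where(rows > i, col, 0.0)
        v = jnp.where(rows == i, 1.0, v)
        return x - taus[i] * v * jnp.dot(v, x)

    # Q^T b = H_{k-1} ... H_0 b, so H_0 is applied first.
    return jax.lax.fori_loop(0, taus.shape[0], body, b)


# Hand-built reflectors (illustrative data, not the output of geqrf).
key_v, key_b = jax.random.split(jax.random.PRNGKey(0))
m, n = 6, 3
packed = jnp.tril(jax.random.normal(key_v, (m, n)), k=-1)
v_full = packed + jnp.eye(m, n)            # add the implicit unit diagonal
taus = 2.0 / jnp.sum(v_full**2, axis=0)    # makes each H_i a true reflector
b = jax.random.normal(key_b, (m,))

# What jax exposes today: materialize the first n columns of Q.
q = jax.lax.linalg.householder_product(packed, taus)
qtb_explicit = q.T @ b                         # first n entries of Q^T b
qtb_implicit = apply_qt(packed, taus, b)[:n]   # same values, Q never formed
```

Each reflector application costs O(m) work on a vector, which is the access pattern an ormqr-style primitive would provide natively.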
I believe the best solution is to expose ormqr in the same way that orgqr is exposed. Together with the existing triangular_solve, this would allow an efficient implementation of the algorithm used by the LAPACK function sgels for linear least squares with full-rank matrices.
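For reference, a sketch of the full-rank, overdetermined case as it can be written with public jax functions today. The function name `lstsq_qr` is hypothetical, and A is assumed to have full column rank with at least as many rows as columns. The line computing `q.T @ b` is the step that an exposed ormqr could perform directly from the Householder form.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def lstsq_qr(a, b):
    """Least squares solution of a @ x ~= b via a reduced QR factorization.

    Assumes a has full column rank and shape (m, n) with m >= n.
    """
    q, r = jnp.linalg.qr(a, mode="reduced")  # q is materialized explicitly here
    qtb = q.T @ b                            # the step ormqr would replace
    return solve_triangular(r, qtb, lower=False)


key_a, key_b = jax.random.split(jax.random.PRNGKey(1))
a = jax.random.normal(key_a, (8, 3))
b = jax.random.normal(key_b, (8,))
x = lstsq_qr(a, b)
```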
Another remark: it is tempting to expose the functionality of sgels (and other LAPACK drivers) directly, as scipy.linalg.lstsq does. Perhaps this should also be done, but when we want to take a jvp through a linear least squares solve, we solve not only Ax=b but also systems A^t y = c. Both can be solved efficiently while factoring A only once (this is the approach taken by lineax), but the scipy.linalg.lstsq interface does not expose this possibility.
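A minimal sketch of reusing one factorization for both systems, assuming A is (m, n) with m >= n and full column rank. Since A^t = R^t Q^t, the minimum-norm solution of A^t y = c is y = Q R^{-t} c. The helper names below are hypothetical and this is not lineax's actual code; Q is again materialized because that is what the public API currently allows.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def solve_ls(q, r, b):
    """Least squares solution of A x ~= b from a reduced QR of A."""
    return solve_triangular(r, q.T @ b, lower=False)


def solve_transpose_min_norm(q, r, c):
    """Minimum-norm solution of A^T y = c from the same reduced QR of A."""
    # r.T is lower triangular, so this is a forward substitution.
    return q @ solve_triangular(r.T, c, lower=True)


key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(2), 3)
a = jax.random.normal(key_a, (8, 3))
b = jax.random.normal(key_b, (8,))
c = jax.random.normal(key_c, (3,))

q, r = jnp.linalg.qr(a, mode="reduced")  # factor once
x = solve_ls(q, r, b)                    # forward system
y = solve_transpose_min_norm(q, r, c)    # transposed system, same factors
```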
Adding ormqr would also allow an almost completely efficient implementation of sgelsy, which is faster than singular value decomposition for minimum-norm least squares on rank-deficient problems (and would again allow one factorization to serve the jvp as well). To get the full benefit, one would also want stzrzf exposed, to compute the complete orthogonal factorization and make use of upper triangularity. However, since jax wants static shapes for intermediate terms (they cannot be rank dependent, e.g.), maybe what is actually needed is an stzrzf that can be instructed to ignore some zero rows padding the bottom. (I am just starting out in jax, so maybe the last sentence is unnecessary for some reason?)