Description
We're now using the shrinkage technique from this paper in the concept-erasure repo; it makes covariance estimation more robust when the number of samples is small relative to the dimension. It might make CRC-TPC, VINC, etc. work better.
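For reference, here is the Rao-Blackwell Ledoit-Wolf (RBLW) estimator that the function below is meant to implement. It shrinks the sample covariance toward a scaled identity. This is a sketch that assumes $\hat{S}$ is the (1/n-normalized) sample covariance of $n$ i.i.d. Gaussian samples in $p$ dimensions, as the function name suggests. Here $\operatorname{tr}(\hat{S}^2)$ is the trace of the matrix product $\hat{S}\hat{S}$, and $\operatorname{tr}^2(\hat{S}) = (\operatorname{tr}\hat{S})^2$:

```latex
\hat{\rho} = \min\left(1,\; \frac{\frac{n-2}{n}\operatorname{tr}(\hat{S}^2) + \operatorname{tr}^2(\hat{S})}{(n+2)\left[\operatorname{tr}(\hat{S}^2) - \frac{\operatorname{tr}^2(\hat{S})}{p}\right]}\right),
\qquad
\hat{F} = \frac{\operatorname{tr}(\hat{S})}{p}\, I_p,
\qquad
\hat{\Sigma} = (1 - \hat{\rho})\,\hat{S} + \hat{\rho}\,\hat{F}
```

The ratio is nonnegative: the numerator is positive, and the denominator is nonnegative by Cauchy-Schwarz on the eigenvalues. So clamping it to [0, 1] in code is equivalent to taking the minimum with 1.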
import torch
from torch import Tensor


def gaussian_shrinkage(S_hat: Tensor, n: int) -> Tensor:
    """Applies Rao-Blackwell LW shrinkage to a (p, p) sample covariance estimated from n samples."""
    p = S_hat.shape[-1]
    assert n > 1 and S_hat.shape == (p, p)

    trace_S = torch.trace(S_hat)
    # tr(S^2) is the trace of the matrix product S @ S, not of the elementwise square
    trace_S_sq = torch.trace(S_hat @ S_hat)
    trace_sq_S = trace_S ** 2

    # Shrinkage intensity rho, clamped to [0, 1]
    numer = (n - 2) / n * trace_S_sq + trace_sq_S
    denom = (n + 2) * (trace_S_sq - trace_sq_S / p)
    rho = torch.clamp(numer / denom, 0, 1)

    # Shrinkage target: identity scaled to the average eigenvalue tr(S) / p
    eye = torch.eye(p, dtype=S_hat.dtype, device=S_hat.device)
    F_hat = eye * trace_S / p
    return (1 - rho) * S_hat + rho * F_hat
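Here is a minimal, self-contained sketch that checks the small-sample claim. It repeats the estimator so it can run on its own, computing tr(S²) as `torch.trace(S_hat @ S_hat)`. The setup is an illustrative assumption: n = 20 standard Gaussian samples in p = 50 dimensions. The true covariance is therefore the identity, and the zero-mean sample covariance `X.T @ X / n` is singular.

```python
import torch
from torch import Tensor


def gaussian_shrinkage(S_hat: Tensor, n: int) -> Tensor:
    """RBLW shrinkage of a (p, p) sample covariance estimated from n samples."""
    p = S_hat.shape[-1]
    assert n > 1 and S_hat.shape == (p, p)
    trace_S = torch.trace(S_hat)
    trace_S_sq = torch.trace(S_hat @ S_hat)  # tr(S^2) as a matrix product
    trace_sq_S = trace_S ** 2
    numer = (n - 2) / n * trace_S_sq + trace_sq_S
    denom = (n + 2) * (trace_S_sq - trace_sq_S / p)
    rho = torch.clamp(numer / denom, 0, 1)
    F_hat = torch.eye(p, dtype=S_hat.dtype, device=S_hat.device) * trace_S / p
    return (1 - rho) * S_hat + rho * F_hat


torch.manual_seed(0)
p, n = 50, 20  # fewer samples than dimensions
X = torch.randn(n, p, dtype=torch.float64)  # true covariance is the identity
S_hat = X.T @ X / n  # zero-mean sample covariance, 1/n normalization

S_shrunk = gaussian_shrinkage(S_hat, n)
eye = torch.eye(p, dtype=torch.float64)

print(torch.linalg.matrix_rank(S_hat).item())  # at most n = 20, so S_hat is singular
print(torch.linalg.eigvalsh(S_shrunk).min().item() > 0)  # True: positive definite
# Frobenius error to the true covariance, before vs. after shrinkage
print(torch.linalg.norm(S_hat - eye).item(), torch.linalg.norm(S_shrunk - eye).item())
```

With n < p the raw estimate has rank at most n and cannot be inverted. In the shrunk estimate, each eigenvalue λ of Ŝ becomes (1 - ρ)λ + ρ·tr(Ŝ)/p, which is strictly positive whenever ρ > 0. The trace is also unchanged, because the target F̂ has the same trace as Ŝ.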