import logging
import torch
from torch import Tensor
_logger = logging.getLogger(__name__)
[docs]def mahalanobis_distance_inv_fast(x: Tensor, y: Tensor, precision: Tensor):
"""Mahalanobis distance betwee x and y with an accelerated implementation.
Args:
x (Tensor): first point.
y (Tensor): second point.
precision (Tensor): inverse of the covariance matrix.
"""
d_squared = torch.mm(torch.mm(x - y, precision), (x - y).T).diag()
return torch.sqrt(d_squared)
[docs]class MahalanobisAggregation:
def __init__(self, *args, **kwargs) -> None:
self.mu = None
self.pinv = None
[docs] def fit(self, stack: Tensor, *args, **kwargs):
self.mu = stack.mean(dim=0, keepdim=True)
self.pinv = torch.linalg.pinv(torch.cov(stack.T))
def __call__(self, scores: Tensor, *args, **kwargs):
return -mahalanobis_distance_inv_fast(scores, self.mu.to(scores.device), self.pinv.to(scores.device))