Source code for detectors.methods.mahalanobis

import logging
from typing import List, Literal, Optional

import torch
from torch import Tensor, nn

from detectors.methods.templates import DetectorWithFeatureExtraction

from .utils import sklearn_cov_matrix_estimarion

_logger = logging.getLogger(__name__)


[docs]def mahalanobis_dist_forward_substitution(x: Tensor, y: Tensor, L: Tensor): return torch.sqrt(torch.sum(torch.square(torch.mm(x, L).unsqueeze(1) - torch.mm(y, L).unsqueeze(0)), dim=-1)).min( dim=1, keepdim=True )[0]
[docs]def mahalanobis_distance_inv_fast(x: Tensor, y: Tensor, precision: Tensor): """Mahalanobis distance betwee x and y with an accelerated implementation. Args: x (Tensor): first point. y (Tensor): second point. precision (Tensor): inverse of the covariance matrix. """ d_squared = torch.mm(torch.mm(x - y, precision), (x - y).T).diag() return torch.sqrt(d_squared)
[docs]def mahalanobis_distance_inv(x: Tensor, y: Tensor, precision: Tensor): """Mahalanobis distance betwee x and y. Args: x (Tensor): first point. y (Tensor): second point. precision (Tensor): inverse of the covariance matrix. """ d_squared = torch.sum((x - y).T * torch.mm(precision, (x - y).T), dim=0) return torch.sqrt(d_squared)
[docs]def mahalanobis_inv_layer_score(x: Tensor, mus: Tensor, inv: Tensor) -> Tensor: stack = torch.zeros((x.shape[0], mus.shape[0]), device=x.device, dtype=x.dtype) for i, mu in enumerate(mus): stack[:, i] = mahalanobis_distance_inv(x, mu.reshape(1, -1), inv).reshape(-1) return -torch.nan_to_num(stack.min(1, keepdim=True)[0], nan=1e6)
[docs]def mahalanobis_inv_layer_score_fast(x: Tensor, mus: Tensor, inv: Tensor) -> Tensor: stack = torch.zeros((x.shape[0], mus.shape[0]), device=x.device, dtype=x.dtype) for i, mu in enumerate(mus): stack[:, i] = mahalanobis_distance_inv_fast(x, mu.reshape(1, -1), inv).reshape(-1) return -torch.nan_to_num(stack.min(1, keepdim=True)[0], nan=1e6)
[docs]def class_cond_mus_cov_inv_matrix( x: Tensor, targets: Tensor, cov_method: str = "EmpiricalCovariance", device=torch.device("cpu") ): class_cond_mean = {} centered_data_per_class = {} unique_classes = sorted(torch.unique(targets.detach().cpu()).numpy().tolist()) for c in unique_classes: filt = targets == c temp = x[filt].to(device) class_cond_mean[c] = temp.mean(0, keepdim=True) centered_data_per_class[c] = temp - class_cond_mean[c] class_cond_mean[c] = class_cond_mean[c].detach().cpu() centered_data_per_class[c] = centered_data_per_class[c].detach().cpu() centered_data_per_class = torch.vstack(list(centered_data_per_class.values())) mus = torch.vstack(list(class_cond_mean.values())) mu, cov_mat, inv_mat = sklearn_cov_matrix_estimarion(centered_data_per_class.numpy(), method=cov_method) cov_mat = torch.from_numpy(cov_mat).float() inv_mat = torch.from_numpy(inv_mat).float() return mus, cov_mat, inv_mat
HYPERPARAMETERS = dict( cov_mat_method=[ "EmpiricalCovariance", "GraphicalLasso", "GraphicalLassoCV", "LedoitWolf", "ShrunkCovariance", "OAS", ] )
[docs]class Mahalanobis(DetectorWithFeatureExtraction): """Mahalanobis OOD detector. Args: model (nn.Module): Model to be used to extract features features_nodes (Optional[List[str]]): List of strings that represent the feature nodes. Defaults to None. all_blocks (bool, optional): If True, use all blocks of the model. Defaults to False. last_layer (bool, optional): If True, use also the last layer of the model. Defaults to False. pooling_op_name (str, optional): Pooling operation to be applied to the features. Can be one of `max`, `avg`, `flatten`, `getitem`, `avg_or_getitem`, `max_or_getitem`, `none`. Defaults to `avg`. aggregation_method_name (str, optional): Aggregation method to be applied to the features. Defaults to None. cov_mat_method (str, optional): Covariance matrix estimation method. Can be one of: `EmpiricalCovariance`, `GraphicalLasso`, `GraphicalLassoCV`, `LedoitWolf`, `MinCovDet`, `ShrunkCovariance`, `OAS`. Defaults to `EmpiricalCovariance`. mu_cov_inv_est_fn (function, optional): Function to be used to estimate the means, covariance and inverse matrix. Defaults to `class_cond_mus_cov_inv_matrix`. cov_reg (float, optional): Covariance regularization. Defaults to 1e-6. References: [1] https://arxiv.org/abs/1807.03888 """ def __init__( self, model: nn.Module, features_nodes: Optional[List[str]] = None, all_blocks: bool = False, last_layer: bool = False, pooling_op_name: str = "avg_or_getitem", aggregation_method_name: Optional[str] = "mean", cov_mat_method: Literal[ "EmpiricalCovariance", "GraphicalLasso", "GraphicalLassoCV", "LedoitWolf", "MinCovDet", "ShrunkCovariance", "OAS", ] = "EmpiricalCovariance", mu_cov_inv_est_fn=class_cond_mus_cov_inv_matrix, cov_reg: float = 1e-6, **kwargs, ): super().__init__( model, features_nodes, all_blocks, last_layer, pooling_op_name, aggregation_method_name, **kwargs ) self.cov_mat_method = cov_mat_method self.mu_cov_inv_est_fn = mu_cov_inv_est_fn self.cov_reg = cov_reg def _layer_score(self, x: Tensor, layer_name: Optional[str] = None, index: Optional[int] = None): return mahalanobis_inv_layer_score( x, self.mus[layer_name].to(x.device), self.precision_chols[layer_name].to(x.device) ) def _fit_params(self) -> None: self.mus = {} self.invs = {} self.precision_chols = {} device = next(self.model.parameters()).device for layer_name, layer_features in self.train_features.items(): self.mus[layer_name], cov, self.invs[layer_name] = self.mu_cov_inv_est_fn( layer_features, self.train_targets, self.cov_mat_method, device=device ) cov_chol = torch.linalg.cholesky(cov.to(device) + self.cov_reg * torch.eye(cov.shape[1], device=device)) self.precision_chols[layer_name] = torch.linalg.solve_triangular( cov_chol, torch.eye(cov_chol.shape[1], device=device), upper=False ).T