Source code for detectors.methods.igeood_logits

import logging
from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from detectors.methods.utils import input_pre_processing

_logger = logging.getLogger(__name__)


# Search space for this detector's tunable arguments: softmax temperature and
# the input-preprocessing noise magnitude (`eps`); each entry gives low/high
# bounds and a step size.
HYPERPARAMETERS = {
    "temperature": {"low": 0.1, "high": 1000, "step": 0.1},
    "eps": {"low": 0.0, "high": 0.005, "step": 0.0001},
}


def igeoodlogits_vec(logits, temperature, centroids, epsilon=1e-12):
    """Pairwise geodesic distance between temperature-scaled softmax distributions.

    Every row of ``logits`` and of ``centroids`` is mapped to a categorical
    distribution via ``softmax(row / temperature)``. The distance between two
    distributions ``p`` and ``q`` is ``2 * acos(sum_k sqrt(p_k * q_k))``
    (the Fisher-Rao distance between categorical distributions).

    Args:
        logits (Tensor): 2-D tensor of shape ``(N, C)``.
        temperature (float): softmax temperature applied to both inputs.
        centroids (Tensor): 2-D tensor of shape ``(K, C)``.
        epsilon (float, optional): margin by which the ``acos`` argument is
            kept away from +/-1, where its derivative is infinite. The
            effective margin is at least the machine epsilon of the
            computation dtype. Defaults to 1e-12.

    Returns:
        Tensor: ``(N, K)`` matrix of distances.
    """
    sqrt_p = torch.sqrt(F.softmax(logits / temperature, dim=1))
    sqrt_q = torch.sqrt(F.softmax(centroids / temperature, dim=1))
    mult = sqrt_p @ sqrt_q.T
    # 1 - 1e-12 rounds to exactly 1.0 in float32 (and lower precisions), which
    # would let acos be evaluated at 1.0 and produce inf/NaN gradients during
    # input pre-processing. Never use a margin finer than the dtype resolution.
    margin = max(epsilon, torch.finfo(mult.dtype).eps)
    return 2 * torch.acos(torch.clamp(mult, -1 + margin, 1 - margin))
def _score_fn(x: Tensor, model: torch.nn.Module, centroids: Tensor, temperature: float = 1.0, **kwargs) -> Tensor:
    """Average ``igeoodlogits_vec`` distance from ``model(x)`` to every centroid.

    Args:
        x (Tensor): model input batch.
        model (nn.Module): classifier producing the logits.
        centroids (Tensor): per-class reference logits passed to ``igeoodlogits_vec``.
        temperature (float, optional): softmax temperature. Defaults to 1.0.
        **kwargs: accepted and ignored.

    Returns:
        Tensor: one score per input row (mean over the centroid axis).
    """
    per_centroid = igeoodlogits_vec(model(x), temperature, centroids)
    return torch.mean(per_centroid, dim=1)
class IgeoodLogits:
    """IGEOOD detector.

    Args:
        model (nn.Module): classifier.
        temperature (float, optional): softmax temperature parameter. Defaults to 1.0.
        eps (float, optional): input preprocessing noise value. Defaults to 0.0 (no input preprocessing).

    References:
        [1] https://arxiv.org/abs/2203.07798
    """

    def __init__(self, model: torch.nn.Module, temperature: float = 1.0, eps: float = 0.0, **kwargs):
        self.model = model
        self.temperature = temperature
        self.eps = eps
        self.model.eval()

    @torch.no_grad()
    def start(self, example: Optional[Tensor] = None, fit_length: Optional[int] = None, *args, **kwargs):
        """Reset the fitting state.

        When both ``example`` and ``fit_length`` are given, feature and target
        buffers with ``fit_length`` rows are preallocated (the per-row feature
        shape is taken from ``model(example)``). Otherwise batches are collected
        in lists and concatenated in :meth:`end`.
        """
        self.train_features = []
        self.train_targets = []
        self.mus = []
        self.idx = 0
        if example is not None and fit_length is not None:
            logits = self.model(example)
            self.train_features = torch.zeros((fit_length,) + logits.shape[1:], dtype=logits.dtype)
            # -1 marks rows no update() has written yet; checked in end().
            self.train_targets = torch.full((fit_length,), -1, dtype=torch.long)

    @torch.no_grad()
    def update(self, x: Tensor, y: Tensor, *args, **kwargs):
        """Store the logits of batch ``x`` together with its labels ``y``."""
        self.batch_size = x.shape[0]
        logits = self.model(x)
        if isinstance(self.train_features, list):
            self.train_features.append(logits)
        else:
            self.train_features[self.idx : self.idx + logits.shape[0]] = logits
        if isinstance(self.train_targets, list):
            self.train_targets.append(y)
        else:
            self.train_targets[self.idx : self.idx + y.shape[0]] = y
        self.idx += y.shape[0]

    def end(self, *args, **kwargs):
        """Finalize the collected data and fit the per-class centroids.

        Raises:
            RuntimeError: if no samples were passed to :meth:`update`.
            AssertionError: if any stored target is negative (e.g. the ``-1``
                placeholder of a preallocated buffer).
        """
        if self.idx == 0:
            raise RuntimeError("No training samples were collected; call update() before end().")
        if isinstance(self.train_features, list):
            self.train_features = torch.cat(self.train_features, dim=0)
        else:
            self.train_features = self.train_features[: self.idx]
        if isinstance(self.train_targets, list):
            self.train_targets = torch.cat(self.train_targets, dim=0)
        else:
            self.train_targets = self.train_targets[: self.idx]
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if not torch.all(self.train_targets > -1):
            raise AssertionError("Not all targets were updated")
        self._fit_params()
        del self.train_features
        del self.train_targets

    def _fit_params(self) -> None:
        """Set ``self.mus`` to one row per class: the mean stored logits of that class.

        Rows follow the order returned by ``torch.unique`` (ascending labels).
        """
        classes = torch.unique(self.train_targets).tolist()
        self.mus = torch.cat(
            [self.train_features[self.train_targets == c].mean(0, keepdim=True) for c in classes],
            dim=0,
        )

    def __call__(self, x: Tensor) -> Tensor:
        """Score a batch with ``_score_fn`` against the fitted centroids.

        Requires :meth:`end` to have been called so ``self.mus`` is a tensor.
        When ``eps > 0``, ``x`` is first replaced by the output of
        ``input_pre_processing`` driven by the same scoring function.
        """
        self.mus = self.mus.to(x.device)
        if self.eps > 0:
            x = input_pre_processing(
                partial(_score_fn, model=self.model, temperature=self.temperature, centroids=self.mus), x, self.eps
            )
        with torch.no_grad():
            return _score_fn(x, self.model, self.mus, temperature=self.temperature)