import torch
from torch import Tensor, nn
from detectors.utils import sync_tensor_across_gpus
[docs]def kl_divergence(p: Tensor, q: Tensor, eps=1e-6):
return (p * torch.log(p / (q + eps))).sum(1)
[docs]def js_divergence(p: Tensor, q: Tensor):
m = 0.5 * (p + q)
return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
[docs]class KLMatching:
"""KL-Matching detector.
Args:
model (nn.Module): classifier.
References:
[1] https://arxiv.org/abs/1911.11132
"""
def __init__(self, model: nn.Module, **kwargs):
self.model = model
self.model.eval()
self.centroids = {}
self.train_probs = []
self.predictions = []
[docs] def start(self, *args, **kwargs):
self.centroids = {}
self.train_probs = []
self.predictions = []
[docs] @torch.no_grad()
def update(self, x: Tensor, *args, **kwargs):
logits = self.model(x)
logits = sync_tensor_across_gpus(logits).cpu()
y_hat = torch.argmax(logits, dim=1).cpu()
probs = torch.softmax(logits, dim=1).cpu()
self.train_probs.append(probs)
self.predictions.append(y_hat)
[docs] def end(self):
self.train_probs = torch.cat(self.train_probs, dim=0)
self.predictions = torch.cat(self.predictions, dim=0)
for c in torch.unique(self.predictions).detach().cpu().numpy().tolist():
self.centroids[c] = torch.mean(self.train_probs[self.predictions == c], dim=0, keepdim=True)
@torch.no_grad()
def __call__(self, x: torch.Tensor) -> torch.Tensor:
probs = torch.softmax(self.model(x), dim=1)
predictions = probs.argmax(dim=1)
scores = torch.empty_like(predictions, dtype=torch.float32)
for label in torch.unique(predictions).detach().cpu().numpy().tolist():
if label not in self.centroids:
raise ValueError(f"Label {label} not found in training set.")
centroid = self.centroids[label].to(x.device)
scores[predictions == label] = kl_divergence(probs[predictions == label], centroid)
return -scores