# Source code for detectors.methods.templates
"""
Generalized detection methods templates.
"""
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import torch
import torch.distributed
from torch import Tensor, nn
from torchvision.models.feature_extraction import create_feature_extractor
from tqdm import tqdm
from detectors.aggregations import create_aggregation
from detectors.utils import sync_tensor_across_gpus
from .utils import add_output_op, create_reduction
# Module-level logger, named after this module via `__name__`.
_logger = logging.getLogger(__name__)
class Detector(ABC):
    """Detector base class.

    Defines the fitting life cycle driven by `fit` (`start`, then `update` for each
    batch, then `end`) and the scoring entry point `__call__`, which subclasses
    must implement. The three life-cycle hooks are no-ops here.
    """

    def __init__(self, **kwargs):
        pass

    def start(self, example: Optional[Tensor] = None, fit_length: Optional[int] = None, *args, **kwargs):
        """Setup detector for fitting parameters.

        Args:
            example (Optional[Tensor], optional): Input example. Useful for pre-allocating memory.
                Defaults to None.
            fit_length (Optional[int], optional): Length of the fitting dataset. Useful for pre-allocating memory.
                Defaults to None.

        This is called before the first call to `update` and is optional.
        """
        pass

    def update(self, x: Tensor, y: Tensor, *args, **kwargs):
        """Accumulate features for detector.

        Args:
            x (Tensor): input tensor.
            y (Tensor): target tensor.

        This is called for each batch of the fitting dataset and is optional.
        """
        pass

    def end(self, *args, **kwargs):
        """Finalize detector fitting process.

        This is called after the last call to `update` and is optional.
        """
        pass

    def fit(self, dataloader, **kwargs):
        """Fit detector to a dataset.

        The dataloader is iterated twice: once to peek at the first batch (used as the
        `example` passed to `start`) and once for the actual fitting pass.

        Args:
            dataloader (Dataloader): Dataloader for the fitting dataset. Each batch must be
                an `(x, y)` pair.
            **kwargs: Forwarded unchanged to `start`, `update` and `end`.

        Returns:
            Detector: the detector itself, to allow call chaining.
        """
        # `fit_length` is only a pre-allocation hint and `start` accepts None, so a loader
        # without a sized `dataset` attribute must not abort the fit.
        try:
            fit_length = len(dataloader.dataset)
        except (AttributeError, TypeError):
            fit_length = None

        # Peek at the first batch without raising StopIteration on an empty loader.
        first_batch = next(iter(dataloader), None)
        example = first_batch[0] if first_batch is not None else None

        self.start(example=example, fit_length=fit_length, **kwargs)
        for x, y in dataloader:
            self.update(x, y, **kwargs)
        self.end(**kwargs)
        return self

    @abstractmethod
    def __call__(self, x: Tensor) -> Tensor:
        """Compute scores for each input at test time.

        Args:
            x (Tensor): input tensor.

        Returns:
            Tensor: scores for each input.
        """
        raise NotImplementedError
class DetectorWrapper(Detector):
    """Detector interface.

    Wraps a detector object (an instance exposing a `model` attribute, or a callable
    exposing a `keywords` mapping that holds a `"model"` entry) and moves inputs to the
    model's device before delegating to the wrapped detector.

    Args:
        detector: The detector to wrap.
        **kwargs: Hyperparameters stored in `self.keywords`; they are used when the
            detector is rebuilt in `set_hyperparameters`.
    """

    def __init__(self, detector, **kwargs):
        self.detector = detector

        if hasattr(self.detector, "model"):
            self.model = self.detector.model
        elif hasattr(self.detector, "keywords") and "model" in self.detector.keywords:
            self.model = self.detector.keywords["model"]
        else:
            self.model = None

        # Put the model in eval mode regardless of where it was found.
        if self.model is not None:
            self.model.eval()

        self.keywords = kwargs
        self.device = self._infer_device(self.model)

    @staticmethod
    def _infer_device(model) -> torch.device:
        """Return the device of the model's first parameter, or CPU when there is none."""
        if model is None:
            return torch.device("cpu")
        # A default avoids StopIteration for models without parameters.
        first_param = next(iter(model.parameters()), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def start(self, example: Optional[Tensor] = None, fit_length: Optional[int] = None, *args, **kwargs):
        """Move `example` to the wrapper's device and forward to the detector's `start`.

        Logs a warning and returns if the wrapped detector has no `start` method.
        """
        if not hasattr(self.detector, "start"):
            _logger.warning("Detector does not have a start method.")
            return
        if example is not None:
            example = example.to(self.device)
        self.detector.start(example, fit_length, *args, **kwargs)

    def update(self, x: Tensor, y: Tensor, *args, **kwargs):
        """Move the batch to the wrapper's device and forward to the detector's `update`.

        Logs a warning and returns if the wrapped detector has no `update` method.
        """
        if not hasattr(self.detector, "update"):
            _logger.warning("Detector does not have an update method.")
            return
        x = x.to(self.device)
        y = y.to(self.device)
        self.detector.update(x, y, *args, **kwargs)

    def end(self, *args, **kwargs):
        """Forward to the detector's `end`.

        Logs a warning and returns if the wrapped detector has no `end` method.
        """
        if not hasattr(self.detector, "end"):
            _logger.warning("Detector does not have an end method.")
            return
        self.detector.end(*args, **kwargs)

    def fit(self, dataloader, **kwargs):
        """Fit the wrapped detector to a dataset.

        Args:
            dataloader (Dataloader): Dataloader for the fitting dataset. Each batch must be
                an `(x, y)` pair.
            **kwargs: Forwarded unchanged to `start`, `update` and `end`.

        Returns:
            DetectorWrapper: the wrapper itself.
        """
        # `fit_length` is only a pre-allocation hint; fall back to None when the loader
        # has no sized `dataset` instead of aborting the fit.
        try:
            fit_length = len(dataloader.dataset)
        except (AttributeError, TypeError):
            fit_length = None

        # Peek at the first batch without raising StopIteration on an empty loader.
        first_batch = next(iter(dataloader), None)
        example = first_batch[0] if first_batch is not None else None

        self.start(example=example, fit_length=fit_length, **kwargs)
        for x, y in dataloader:
            self.update(x, y, **kwargs)
        self.end(**kwargs)
        return self

    def __call__(self, x: Tensor) -> Tensor:
        """Move `x` to the wrapper's device and return the wrapped detector's output."""
        x = x.to(self.device)
        return self.detector(x)

    def set_hyperparameters(self, **params):
        """Set the parameters of the detector.

        Args:
            **params: Hyperparameters to set. An optional `model` entry replaces the model
                in use; all other entries are merged into `self.keywords`.

        Returns:
            DetectorWrapper: the wrapper itself.
        """
        model = params.pop("model", self.model)
        self.keywords.update(params)

        if hasattr(self.detector, "keywords"):
            self.detector.keywords.update(params)
            if model is not self.model:
                # A replacement model must reach the wrapped callable as well.
                self.detector.keywords["model"] = model
        else:
            self.detector = self.detector.__class__(model=model, **self.keywords)

        if model is not self.model:
            # Keep the wrapper's view of the model and its device in sync, so inputs
            # are moved to the device of the model actually in use.
            self.model = model
            self.device = self._infer_device(model)
        return self

    def save_params(self, path):
        """Save the parameters of the detector."""
        raise NotImplementedError

    def load_params(self, path):
        """Load the parameters of the detector."""
        raise NotImplementedError

    def __repr__(self):
        """Return the string representation of the detector."""
        return f"{self.__class__.__name__}()"
class DetectorWithFeatureExtraction(Detector):
    """Base class for OOD detectors with feature extraction.

    Features are extracted at the selected nodes of the model, reduced with the pooling
    operation, scored per node by `_layer_score`, and combined by the aggregation method.

    Args:
        model (nn.Module): Model to be used to extract features
        features_nodes (Optional[List[str]]): List of strings that represent the feature nodes.
            The list is copied; the caller's object is never modified. Defaults to None, in
            which case the nodes are chosen from `all_blocks` or the penultimate module.
        all_blocks (bool, optional): If True, use all blocks of the model. Only honoured when
            the model exposes `feature_info`. Defaults to False.
        last_layer (bool, optional): If True, use also the last layer of the model. Defaults to False.
        pooling_op_name (str, optional): Pooling operation to be applied to the features.
            Can be one of:
                `max`, `avg`, `none`, `flatten`, `getitem`, `avg_or_getitem`, `max_or_getitem`.
            Defaults to "avg_or_getitem".
        aggregation_method_name (str, optional): Aggregation method to be applied to the features.
            A falsy value is treated as "none". Defaults to "mean".
        **kwargs: Forwarded to `create_aggregation`.
    """

    def __init__(
        self,
        model: nn.Module,
        features_nodes: Optional[List[str]] = None,
        all_blocks: bool = False,
        last_layer: bool = False,
        pooling_op_name: str = "avg_or_getitem",
        aggregation_method_name: Optional[str] = "mean",
        **kwargs,
    ):
        self.model = model
        self.model.eval()
        # Copy, so that appending the last layer below never mutates the caller's list.
        self.features_nodes = list(features_nodes) if features_nodes is not None else None
        self.all_blocks = all_blocks
        self.pooling_op_name = pooling_op_name
        self.aggregation_method_name = aggregation_method_name or "none"

        # feature reduction operation
        self.reduction_op = create_reduction(self.pooling_op_name)

        if self.features_nodes is not None:
            # if features nodes were explicitly specified, use them
            pass
        elif hasattr(self.model, "feature_info") and self.all_blocks:
            # if all_blocks is True, use all blocks of the model
            self.features_nodes = [fi["module"] for fi in self.model.feature_info][1:]  # type: ignore
        else:
            # extract from the penultimate layer only
            self.features_nodes = [list(self.model._modules.keys())[-2]]

        if last_layer:
            # if last_layer is True, use the last layer of the model
            self.last_layer_name = list(self.model._modules.keys())[-1]
            self.features_nodes.append(self.last_layer_name)

        # remove duplicates while keeping the first occurrence order
        self.features_nodes = list(dict.fromkeys(self.features_nodes))
        _logger.info("Using features nodes: %s", self.features_nodes)

        self.feature_extractor = create_feature_extractor(self.model, self.features_nodes)
        self.feature_extractor.eval()

        # insert reduction operation after each node
        def output_reduce(x: Dict[str, Tensor]):
            return {k: self.reduction_op(v) for k, v in x.items()}

        self.feature_extractor = add_output_op(self.feature_extractor, output_reduce)

        self.aggregation_method = create_aggregation(self.aggregation_method_name, **kwargs)

        self.train_features = {}
        self.train_targets = []
        self.idx = 0

    @torch.no_grad()
    def start(self, example: Optional[Tensor] = None, fit_length: Optional[int] = None, *args, **kwargs):
        """Reset the accumulated state and optionally pre-allocate the feature buffers.

        Args:
            example (Optional[Tensor], optional): Input example, run through the feature
                extractor to discover the per-node feature shapes. Defaults to None.
            fit_length (Optional[int], optional): Number of rows to pre-allocate. Defaults to None.

        Buffers are pre-allocated only when both arguments are given; otherwise `update`
        accumulates the batches in lists.
        """
        self.train_features = {}
        self.train_targets = []
        self.idx = 0
        if example is not None and fit_length is not None:
            self.feature_extractor.to(example.device)
            example_output = self.feature_extractor(example)
            for node_name, v in example_output.items():
                _logger.debug((fit_length,) + v.shape[1:])
                self.train_features[node_name] = torch.empty((fit_length,) + v.shape[1:], dtype=v.dtype)
            self.train_targets = torch.empty((fit_length,), dtype=torch.long)

    @torch.no_grad()
    def update(self, x: Tensor, y: Tensor, *args, **kwargs):
        """Extract the features of one batch and store them with their targets.

        Args:
            x (Tensor): input tensor.
            y (Tensor): target tensor.
        """
        self.batch_size = x.shape[0]
        features: Dict[str, Tensor] = self.feature_extractor(x)
        for node_name, v in features.items():
            v = sync_tensor_across_gpus(v).cpu()
            if node_name not in self.train_features:
                # no pre-allocated buffer: accumulate batches in a list
                self.train_features[node_name] = [v]
            elif isinstance(self.train_features[node_name], list):
                self.train_features[node_name].append(v)
            else:
                # pre-allocated buffer: write at the running offset
                self.train_features[node_name][self.idx : self.idx + v.shape[0]] = v

        y = sync_tensor_across_gpus(y).cpu()
        if isinstance(self.train_targets, list):
            self.train_targets.append(y)
        else:
            self.train_targets[self.idx : self.idx + y.shape[0]] = y

        self.idx += y.shape[0]

    def end(self, *args, **kwargs):
        """Finalize fitting: consolidate the features, fit the detector and the aggregator.

        The accumulated features and targets are deleted from the instance afterwards.
        """
        for node_name, v in self.train_features.items():
            if isinstance(v, list):
                self.train_features[node_name] = torch.cat(v, dim=0)
            else:
                # drop the unused tail of the pre-allocated buffer
                self.train_features[node_name] = v[: self.idx]

        if isinstance(self.train_targets, list):
            self.train_targets = torch.cat(self.train_targets, dim=0)
        else:
            self.train_targets = self.train_targets[: self.idx]

        self._fit_params()

        _logger.debug("Fitting aggregator %s...", self.aggregation_method_name)
        n_samples = self.train_targets.shape[0]  # type: ignore
        self.batch_size = n_samples
        # A zero step would make `range` raise when nothing was accumulated.
        chunk = max(n_samples, 1)
        all_scores = torch.zeros(n_samples, len(self.train_features))
        for i, (k, v) in tqdm(enumerate(self.train_features.items())):
            for lo in range(0, v.shape[0], chunk):
                # write each chunk to its own rows, not to the whole column
                all_scores[lo : lo + chunk, i] = self._layer_score(v[lo : lo + chunk], k, i).view(-1)

        self.aggregation_method.fit(all_scores, self.train_targets)

        # TODO: compile graph with _layer_score

        del self.train_features
        del self.train_targets

    @abstractmethod
    def _fit_params(self) -> None:
        """Fit the data to the parameters of the detector."""
        pass

    @abstractmethod
    def _layer_score(self, features: Tensor, layer_name: Optional[str] = None, index: Optional[int] = None, **kwargs):
        """Compute the anomaly score for a single layer.

        Args:
            features (Tensor): features input tensor.
            layer_name (str, optional): name of the layer. Defaults to None.
            index (int, optional): index of the layer in the feature extractor. Defaults to None.
        """
        raise NotImplementedError

    @torch.no_grad()
    def __call__(self, x: Tensor) -> Tensor:
        """Compute the aggregated score of each input.

        Args:
            x (Tensor): input tensor.

        Returns:
            Tensor: flattened output of the aggregation method applied to the per-node scores.
        """
        features = self.feature_extractor(x)
        # one column of scores per feature node
        all_scores = torch.zeros(x.shape[0], len(features), device=x.device)
        for i, (k, v) in enumerate(features.items()):
            all_scores[:, i] = self._layer_score(v, k, i).view(-1)

        all_scores = self.aggregation_method(all_scores)
        return all_scores.view(-1)
class SimpleFeatureExtractor:
    """Feature extractor whose only feature is the wrapped model's output.

    Calling it returns a one-entry dict keyed `"features"`. `to` and `eval` are
    provided so that code which moves the extractor to a device or switches it to
    eval mode can treat it like a module.

    Args:
        model: Callable applied to the input.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, x):
        """Return `{"features": model(x)}`."""
        return {"features": self.model(x)}

    def to(self, *args, **kwargs):
        """Forward `to` to the wrapped model, if it has one, and return the extractor."""
        move = getattr(self.model, "to", None)
        if callable(move):
            move(*args, **kwargs)
        return self

    def eval(self):
        """Forward `eval` to the wrapped model, if it has one, and return the extractor."""
        switch = getattr(self.model, "eval", None)
        if callable(switch):
            switch()
        return self
class DetectorWithSimpleFE(DetectorWithFeatureExtraction, ABC):
    """Detector that uses the forward pass of the model.

    The parent constructor is deliberately not called: no intermediate-node feature
    extractor is built; the model output is the single feature.

    Args:
        model (nn.Module): Model whose output is used as the feature.
        pooling_op_name (str, optional): Name of the reduction operation to create.
            Defaults to "avg_or_getitem".
        **kwargs: Accepted for signature compatibility; not used here.
    """

    def __init__(self, model: nn.Module, pooling_op_name: str = "avg_or_getitem", **kwargs):
        self.model = model
        self.feature_extractor = SimpleFeatureExtractor(self.model)
        self.pooling_op_name = pooling_op_name
        # NOTE(review): `reduction_op` is created but nothing in this class applies it to
        # the extractor output - confirm whether subclasses are expected to use it.
        self.reduction_op = create_reduction(self.pooling_op_name)
        # The inherited `end` reads `aggregation_method_name`, so it must exist even
        # though the aggregation is fixed to "none" here.
        self.aggregation_method_name = "none"
        self.aggregation_method = create_aggregation(self.aggregation_method_name)
        self.train_features = {}
        self.train_targets = []
        self.idx = 0

    @abstractmethod
    def _fit_params(self) -> None:
        """Fit the data to the parameters of the detector."""
        pass

    @abstractmethod
    def _layer_score(self, features: Tensor, layer_name: Optional[str] = None, index: Optional[int] = None, **kwargs):
        """Compute the anomaly score for a single layer.

        Args:
            features (Tensor): features input tensor.
            layer_name (str, optional): name of the layer. Defaults to None.
            index (int, optional): index of the layer in the feature extractor. Defaults to None.
        """
        raise NotImplementedError