# Source code for detectors.methods.react

import logging
from functools import partial
from typing import Callable, List, Optional

import torch
import torch.fx as fx
import torch.utils.data
from torch import Tensor
from torchvision.models.feature_extraction import create_feature_extractor

_logger = logging.getLogger(__name__)

# Range/default/step description for the ReAct quantile ``p`` (the value passed to
# ``ReAct(p=...)``). NOTE(review): presumably consumed by a hyperparameter search
# elsewhere in the package — confirm against callers.
HYPERPARAMETERS = {
    "p": {
        "low": 0.1,
        "high": 1.0,
        "type": float,
        "default": 0.9,
        "step": 0.05,
    },
}


def reactify(m: torch.nn.Module, condition_fn: Callable, insert_fn: Callable) -> torch.nn.Module:
    """Trace ``m`` with torch.fx and let ``insert_fn`` edit the graph at selected nodes.

    Args:
        m: Module to trace.
        condition_fn: Predicate ``condition_fn(node) -> bool`` choosing which nodes
            of the traced graph are passed to ``insert_fn``.
        insert_fn: Callback ``insert_fn(node, graph)`` that modifies ``graph`` around
            ``node`` (e.g. inserting a clipping node after it).

    Returns:
        A new ``fx.GraphModule`` built from ``m`` and the modified graph.
    """
    graph: fx.Graph = fx.Tracer().trace(m)
    # Iterate over a snapshot of the ORIGINAL nodes. Nodes added by ``insert_fn``
    # must not be visited: a predicate that also matches them would rewrite them
    # again, and an insert-after callback would then never terminate.
    for node in list(graph.nodes):
        if condition_fn(node):
            insert_fn(node, graph)
    # Return new Module
    return fx.GraphModule(m, graph)
def condition_fn(node, equals_to: str) -> bool:
    """Return True when ``node.name`` equals ``equals_to``.

    Bound with ``functools.partial(condition_fn, equals_to=...)`` to select a single
    fx node by name.
    """
    return node.name == equals_to
def insert_fn(node, graph: fx.Graph, thr: float = 1.0):
    """Insert ``torch.clip(node, max=thr)`` right after ``node`` and reroute its users.

    Every consumer of ``node`` is rewired to read from the new clip node, while the
    clip node itself keeps ``node`` as its input. ``graph`` is modified in place.

    Args:
        node: fx node whose output gets clipped.
        graph: Graph that owns ``node``.
        thr: Upper bound passed as ``max`` to ``torch.clip``. Defaults to 1.0.
    """
    with graph.inserting_after(node):
        clipped = graph.call_function(torch.clip, args=(node,), kwargs={"max": thr})
    # replace_all_uses_with also rewrites the clip node's own argument (making it
    # refer to itself), so point that single input back at the original node.
    node.replace_all_uses_with(clipped)
    clipped.replace_input_with(clipped, node)
class ReAct:
    """ReAct detector.

    Clips intermediate activations at a threshold equal to the ``p``-quantile of the
    activations collected during ``update`` calls, and scores inputs with the
    log-sum-exp of the resulting logits.

    Args:
        model (torch.nn.Module): Model to be used to extract features.
        features_nodes (Optional[List[str]]): Names of the nodes whose activations are
            collected to estimate the clipping thresholds. Defaults to None, in which
            case the second-to-last direct child module of ``model`` is used.
        graph_nodes_names (Optional[List[str]]): Names of the traced graph nodes after
            which a clipping node is inserted; the i-th name uses the i-th computed
            threshold. Defaults to None, in which case only the last feature node is
            clipped and fed to the last direct child module of ``model``.
        insert_node_fn (Callable): Function ``(node, graph, thr=...)`` used to insert
            the clipping node into the graph. Defaults to insert_fn.
        p (float, optional): Quantile of the collected activations used as clipping
            threshold. Defaults to 0.9.

    References:
        [1] https://arxiv.org/abs/2111.12797
    """

    # Maximum number of activation values kept per feature node / fed to torch.quantile.
    LIMIT = 2_560_000

    def __init__(
        self,
        model: torch.nn.Module,
        features_nodes: Optional[List[str]] = None,
        graph_nodes_names: Optional[List[str]] = None,
        insert_node_fn: Callable = insert_fn,
        p=0.9,
        **kwargs,
    ) -> None:
        self.model = model
        self.device = next(self.model.parameters()).device
        self.model.eval()

        self.features_nodes = features_nodes
        self.graph_nodes_names = graph_nodes_names
        if self.features_nodes is None:
            # Default: the penultimate direct child module of the model.
            self.features_nodes = [list(self.model._modules.keys())[-2]]
        self.feature_extractor = create_feature_extractor(self.model, self.features_nodes)
        # Used by __call__ to map clipped features to logits when no graph rewrite is done.
        self.last_layer = list(self.model._modules.values())[-1]

        self.insert_node_fn = insert_node_fn
        self.p = p
        self.thr = None  # legacy attribute, never updated; thresholds live in ``self.thrs``
        self.thrs: List[float] = []  # filled by end(), one threshold per feature node
        self.training_features = {}

    def start(self, *args, **kwargs):
        """Reset the collected training activations."""
        self.training_features = {}

    def update(self, x: Tensor, y: Tensor) -> None:
        """Collect the feature-node activations of batch ``x`` (``y`` is unused).

        Collection stops once the first stored feature node holds more than ``LIMIT``
        values. Activations are moved to CPU and concatenated along dim 0.
        """
        if self.training_features:
            first_key = next(iter(self.training_features))
            # numel() works for non-contiguous tensors, unlike .view(-1).
            if self.training_features[first_key].numel() > self.LIMIT:
                return

        with torch.no_grad():
            features = self.feature_extractor(x)

        # accumulate training features
        for k, batch_feats in features.items():
            batch_feats = batch_feats.cpu()
            if k in self.training_features:
                self.training_features[k] = torch.cat((self.training_features[k], batch_feats), dim=0)
            else:
                self.training_features[k] = batch_feats

    def end(self, *args, **kwargs):
        """Compute the clipping thresholds and, if requested, rewrite the model graph.

        One threshold per collected feature node: the ``p``-quantile of (at most
        ``LIMIT`` of) its flattened activations. When ``graph_nodes_names`` is given,
        ``insert_node_fn`` inserts a clipping node after each named graph node, using
        the threshold at the same position. The collected activations are deleted
        afterwards; call ``start`` before collecting again.
        """
        # reshape (not view) so non-contiguous stored tensors are handled too.
        self.thrs = [
            torch.quantile(feats.reshape(-1)[: self.LIMIT].to(self.device), self.p).item()
            for feats in self.training_features.values()
        ]
        if self.graph_nodes_names is not None:
            for i, node_name in enumerate(self.graph_nodes_names):
                # add clipping node to every feature node in the graph passed in the constructor,
                # using the insertion function chosen in the constructor
                self.model = reactify(
                    self.model,
                    condition_fn=partial(condition_fn, equals_to=node_name),
                    insert_fn=partial(self.insert_node_fn, thr=self.thrs[i]),
                )
        _logger.info("ReAct thresholds = %s", dict(zip(self.features_nodes, self.thrs)))
        del self.training_features

    @torch.no_grad()
    def __call__(self, x: Tensor) -> Tensor:
        """Return the log-sum-exp over the last dimension of the rectified logits."""
        if self.graph_nodes_names is not None:
            # Clipping nodes were inserted into the model graph by end().
            logits = self.model(x)
        else:
            # Clip only the last feature node's output with the last threshold.
            features = torch.clip(list(self.feature_extractor(x).values())[-1], max=self.thrs[-1])
            logits = self.last_layer(features)  # type: ignore
        return torch.logsumexp(logits, dim=-1)