Source code for detectors.methods.react_projection

from functools import partial
from typing import Callable, Dict, List, Optional

import torch
import torch.fx as fx
from torchvision.models.feature_extraction import get_graph_node_names

from detectors.methods.projection import Projection
from detectors.methods.react import condition_fn, insert_fn, reactify


class ReActProjection(Projection):
    """``Projection`` detector built on a model rewritten with ``reactify``.

    Before delegating to :class:`Projection`, the model is passed through
    :func:`reactify` once per entry of ``graph_nodes_names_thr``. For each entry,
    ``reactify`` targets the node matching that key (via ``condition_fn`` with
    ``equals_to=node_name``) and inserts a node built by ``insert_node_fn`` with
    ``thr`` bound to the entry's value. (Presumably a ReAct clipping node at
    ``thr`` — NOTE(review): confirm against ``detectors.methods.react``.)

    Args:
        model: Model to wrap. The rewritten model is what ``Projection`` receives.
        features_nodes: Node names forwarded to ``Projection``.
        pooling_name: Pooling name forwarded to ``Projection``.
        graph_nodes_names_thr: Mapping from graph-node name to threshold.
            When omitted or ``None``, defaults to ``{"flatten": 1.0}``.
        insert_node_fn: Callable given to ``reactify`` as ``insert_fn``, with
            ``thr`` bound via ``functools.partial``.
        aggregation_method: Forwarded positionally to ``Projection``.
        *args: Forwarded to ``Projection.__init__``.
        **kwargs: Forwarded to ``Projection.__init__``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        features_nodes: List[str],
        pooling_name: str = "max",
        graph_nodes_names_thr: Optional[Dict[str, float]] = None,
        insert_node_fn: Callable = insert_fn,
        aggregation_method=None,
        *args,
        **kwargs,
    ):
        # A dict literal as the default would be one object shared by every
        # instance (and it is exposed below as an attribute), so build a fresh
        # default per construction instead.
        if graph_nodes_names_thr is None:
            graph_nodes_names_thr = {"flatten": 1.0}
        self.graph_nodes_names_thr = graph_nodes_names_thr
        self.insert_node_fn = insert_node_fn

        # Rewrite the model once per (node name, threshold) pair. partial()
        # binds node_name/thr eagerly, so there is no late-binding issue.
        for node_name, thr in self.graph_nodes_names_thr.items():
            model = reactify(
                model,
                condition_fn=partial(condition_fn, equals_to=node_name),
                insert_fn=partial(self.insert_node_fn, thr=thr),
            )

        super().__init__(
            model,
            features_nodes,
            pooling_name,
            aggregation_method,
            *args,
            **kwargs,
        )
def test():
    """Smoke-test ``ReActProjection`` on a torchvision DenseNet-121.

    Traces the model with ``torch.fx`` and inserts nodes with ``reactify`` at
    several graph nodes. It then swaps in a 3-output classifier, fits a
    ``ReActProjection`` on random inputs and labels, and checks that scoring
    returns one value per sample.
    """
    import torchvision.models as models

    model = models.densenet121()
    print(get_graph_node_names(model)[0])

    # Symbolically trace so nodes can be inserted into the graph.
    graph: fx.Graph = fx.Tracer().trace(model)
    model = fx.GraphModule(model, graph)
    print(model.code)

    # Thresholds as floats, matching ReActProjection's Dict[str, float] contract.
    graph_nodes_names_thr = {
        "features_transition1_pool": 1.0,
        "features_transition2_pool": 1.0,
        "features_transition3_pool": 1.0,
        "features_norm5": 1.0,
        "flatten": 1.0,
    }
    for node_name, thr in graph_nodes_names_thr.items():
        model = reactify(model, partial(condition_fn, equals_to=node_name), partial(insert_fn, thr=thr))
    print(model.graph)
    print(model.code)

    # 1024 -> 3 head so the outputs match the 3 label values drawn in y below.
    model.classifier = torch.nn.Linear(1024, 3)
    model.eval()

    x = torch.rand(32, 3, 224, 224)
    y = torch.randint(0, 3, (32,))

    # NOTE(review): ReActProjection.__init__ reactifies the model again with the
    # same graph_nodes_names_thr, so insertion happens twice. Confirm this is
    # intended and that "clip".."clip_4" name the expected nodes.
    projection = ReActProjection(
        model,
        ["clip", "clip_1", "clip_2", "clip_3", "clip_4", "classifier"],
        "max",
        graph_nodes_names_thr,
    )
    projection.fit(x, y)
    projection.end()

    # Run the (expensive) full forward pass once and reuse its result.
    scores = projection(x)
    print(scores)
    assert scores.shape == (32,)
if __name__ == "__main__":
    # Execute the smoke test only when run as a script, never on import.
    test()