Source code for detectors.aggregations.innerprod

import logging

import torch
from torch import Tensor

_logger = logging.getLogger(__name__)


[docs]class InnerProductAggregation: def __init__(self, *args, **kwargs) -> None: self.max_trajectory = None self.ref_trajectory = None self.scale = None
[docs] def fit(self, stack: Tensor, *args, **kwargs): self.max_trajectory = stack.max(dim=0, keepdim=True)[0] self.ref_trajectory = stack.mean(dim=0, keepdim=True) / self.max_trajectory self.scale = torch.sum(self.ref_trajectory**2) _logger.debug("InnerProductAggregation parameters") _logger.debug(f"max_trajectory: {self.max_trajectory}") _logger.debug(f"ref_trajectory: {self.ref_trajectory}") _logger.debug(f"scale: {self.scale}")
def __call__(self, scores: Tensor, *args, **kwargs): scores = scores / self.max_trajectory.to(scores.device) scores = torch.sum(scores * self.ref_trajectory.to(scores.device), dim=1) / self.scale.to(scores.device) return scores