Source code for detectors.aggregations.cosine
import logging
import torch
from torch import Tensor
_logger = logging.getLogger(__name__)
class CosineAggregation:
    """Score samples by their cosine similarity to a reference vector ``mu``.

    ``mu`` starts out as ``None`` and nothing in this class assigns it, so it
    must be set on the instance before the aggregator is called.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Extra positional/keyword arguments are accepted but unused here
        # (presumably so all aggregators share one constructor signature —
        # TODO confirm).
        self.mu = None  # reference tensor; must be assigned before __call__

    def __call__(self, scores: Tensor, *args, **kwargs) -> Tensor:
        """Return the cosine similarity between ``scores`` and ``mu`` along dim 1.

        Args:
            scores: Tensor compared against ``mu`` along ``dim=1``. It is
                presumably ``(num_samples, num_features)`` with ``mu``
                broadcastable to it — TODO confirm with callers.
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The result of ``torch.nn.functional.cosine_similarity`` reduced
            over ``dim=1`` (one value per row for 2-D ``scores``).

        Raises:
            AttributeError: If ``mu`` has not been set. This is the same
                exception type the unguarded call raised, now with a
                descriptive message.
        """
        if self.mu is None:
            raise AttributeError(
                "CosineAggregation.mu is not set; assign a reference tensor to "
                "`mu` before calling the aggregator."
            )
        # Move the reference onto the scores' device so the similarity can be
        # computed regardless of where mu was created.
        reference = self.mu.to(scores.device)
        return torch.nn.functional.cosine_similarity(scores, reference, dim=1)