Source code for detectors.aggregations.quantile

import torch
from sklearn.preprocessing import QuantileTransformer
from torch import Tensor


class QuantileAggregation:
    """Aggregate several score columns by averaging their quantile transforms.

    ``fit`` learns a quantile mapping from a reference stack of scores with
    scikit-learn's ``QuantileTransformer``. Calling the instance pushes new
    scores through that mapping and returns the mean over axis 1.

    Args:
        n_quantiles: Forwarded to ``QuantileTransformer(n_quantiles=...)``.
        output_distribution: Forwarded to
            ``QuantileTransformer(output_distribution=...)``.
    """

    def __init__(self, n_quantiles: int = 100, output_distribution: str = "uniform") -> None:
        # Defaults reproduce the values that were previously hard-coded in fit().
        self.n_quantiles = n_quantiles
        self.output_distribution = output_distribution

    def fit(self, stack: Tensor, *args, **kwargs):
        """Fit the quantile transformer on a reference stack of scores.

        Args:
            stack: Reference scores. The tensor is detached, moved to CPU and
                converted to a NumPy array before fitting.
                NOTE(review): presumably shaped (n_samples, n_scores), since
                ``__call__`` averages over axis 1 -- confirm against callers.
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        reference = stack.detach().cpu().numpy()
        self.method = QuantileTransformer(
            n_quantiles=self.n_quantiles,
            output_distribution=self.output_distribution,
        )
        self.method.fit(reference)

    def __call__(self, scores: Tensor, *args, **kwargs) -> Tensor:
        """Transform ``scores`` with the fitted mapping and average over axis 1.

        Args:
            scores: Scores to aggregate, laid out like the stack given to
                ``fit``.
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A flattened 1-D tensor of aggregated scores on the same device as
            ``scores``.

        Raises:
            AttributeError: If the instance is called before ``fit``.
        """
        method = getattr(self, "method", None)
        if method is None:
            # Same exception type an unfitted instance always produced, but
            # with a message that tells the caller what to do about it.
            raise AttributeError(
                "QuantileAggregation must be fitted with fit() before it is called"
            )

        device = scores.device
        # scikit-learn works on CPU NumPy arrays, so round-trip through NumPy
        # and move the result back to the caller's device afterwards.
        transformed = method.transform(scores.detach().cpu().numpy())
        return torch.from_numpy(transformed.mean(1)).to(device).view(-1)