import logging
import types
from functools import partial
from torch import Tensor
from .anomaly import IFAggregation, LOFAggregation
from .basics import (
avg_topk_aggregation,
depth_weighted_sum,
layer_idx,
max_aggregation,
mean_aggregation,
median_aggregation,
min_aggregation,
none_aggregation,
topk_aggregation,
)
from .cosine import CosineAggregation
from .innerprod import InnerProductAggregation
from .mahalanobis import MahalanobisAggregation
from .quantile import QuantileAggregation
_logger = logging.getLogger(name=__name__)

# Public name -> aggregation implementation. `create_aggregation` either
# instantiates an entry or binds the caller's kwargs to it with
# `functools.partial`. Insertion order is what `list_aggregations()` reports,
# so new entries should be appended rather than inserted.
aggregations_registry = {
    "none": none_aggregation,  # .basics
    "mean": mean_aggregation,  # .basics
    "max": max_aggregation,  # .basics
    "min": min_aggregation,  # .basics
    "median": median_aggregation,  # .basics
    "dws": depth_weighted_sum,  # .basics
    "avg_topk": avg_topk_aggregation,  # .basics
    "topk": topk_aggregation,  # .basics
    "lof": LOFAggregation,  # .anomaly
    "if": IFAggregation,  # .anomaly
    "layer_idx": layer_idx,  # .basics
    "mahalanobis": MahalanobisAggregation,  # .mahalanobis
    "innerprod": InnerProductAggregation,  # .innerprod
    "cosine": CosineAggregation,  # .cosine
    "quantile": QuantileAggregation,  # .quantile
}
class Aggregation:
    """Aggregation wrapper class.

    Wraps an aggregation callable behind a uniform interface: an optional
    :meth:`fit` step and a call that aggregates a stack. ``create_aggregation``
    passes either a ``functools.partial`` of a registered function or an
    instance of a registered class.
    """

    def __init__(self, aggregation_method, *args, **kwargs) -> None:
        """Store the wrapped aggregation.

        Args:
            aggregation_method: Callable invoked by :meth:`__call__`. If it
                also has a ``fit`` attribute, :meth:`fit` delegates to it.
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        self.aggregation_method = aggregation_method

    def fit(self, stack: Tensor, *args, **kwargs):
        """Fit the wrapped aggregation on ``stack`` when it supports fitting.

        Only ``stack`` is forwarded to the wrapped ``fit``. Extra positional
        and keyword arguments are accepted but deliberately not passed on, so
        callers may send arguments that some aggregation methods do not take.
        Methods without a ``fit`` attribute are skipped with a debug log.

        Returns:
            None.
        """
        if not hasattr(self.aggregation_method, "fit"):
            _logger.debug("Aggregation method does not have a `fit` method.")
            return
        self.aggregation_method.fit(stack)

    def __call__(self, stack: Tensor, *args, **kwargs):
        """Aggregate ``stack`` by calling the wrapped method with all arguments."""
        return self.aggregation_method(stack, *args, **kwargs)

    def __repr__(self) -> str:
        # Show the wrapped method so wrappers are distinguishable when debugging.
        return f"{type(self).__name__}({self.aggregation_method!r})"
def register_aggregation(name: str):
    """Create a decorator that records its target under ``name``.

    The decorated function or class is stored in ``aggregations_registry``,
    replacing any previous entry with the same key. It is returned unchanged,
    so it remains directly usable after decoration.

    Args:
        name: Registry key later accepted by ``create_aggregation``.

    Returns:
        The decorator.
    """

    def _register(target):
        aggregations_registry.update({name: target})
        return target

    return _register
def create_aggregation(aggregation_name: str, **kwargs) -> Aggregation:
    """Build an :class:`Aggregation` for a registered method name.

    Classes in ``aggregations_registry`` are instantiated with ``kwargs``.
    Every other registered callable (plain function, builtin, bound method,
    ``functools.partial``, ...) has ``kwargs`` bound with
    :func:`functools.partial`, so it can later be called on a stack.

    Args:
        aggregation_name: Key in ``aggregations_registry``.
        **kwargs: Constructor arguments for class-based methods, or keyword
            arguments bound to callable-based methods.

    Returns:
        An :class:`Aggregation` wrapping the configured method.

    Raises:
        ValueError: If ``aggregation_name`` is not registered.
    """
    try:
        method = aggregations_registry[aggregation_name]
    except KeyError:
        raise ValueError(
            f"Unknown aggregation method: {aggregation_name}. "
            f"Available methods: {list(aggregations_registry)}"
        ) from None
    # Dispatch on "is a class" rather than "is a pure-Python function": a
    # `types.FunctionType` check sends builtins, bound methods and partials
    # down the instantiation path, calling them with kwargs but no stack.
    if isinstance(method, type):
        return Aggregation(method(**kwargs))
    return Aggregation(partial(method, **kwargs))
def list_aggregations() -> list:
    """Return all registered aggregation names in registry insertion order."""
    return [*aggregations_registry]