# Source code for detectors.pipelines.ood

"""
OOD Pipelines.
"""
import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Literal, Tuple, Union

import accelerate
import numpy as np
import optuna
import pandas as pd
import torch
import torch.utils.data
import torchvision
from optuna.trial import TrialState
from torch import Tensor
from tqdm import tqdm

from detectors.data import create_dataset
from detectors.eval import get_ood_results
from detectors.methods import DetectorWrapper
from detectors.pipelines import register_pipeline
from detectors.pipelines.base import Pipeline
from detectors.utils import ConcatDatasetsDim1, sync_tensor_across_gpus

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


# Human-readable display names for metric keys. `OODBenchmarkPipeline.report`
# uses this mapping to rename the columns of the results table.
# The "time" entry is filled in by `OODBenchmarkPipeline.postprocess`; the other
# keys are presumably the ones returned by `get_ood_results` -- TODO confirm.
METRICS_NAMES_PRETTY = {
    "fpr_at_0.95_tpr": "FPR at 95% TPR",
    "tnr_at_0.95_tpr": "TNR at 95% TPR",
    "detection_error": "Detection error",
    "auroc": "AUROC",
    "aupr_in": "AUPR in",
    "aupr_out": "AUPR out",
    "thr": "Threshold",
    "time": "Time",
}


class OODBenchmarkPipeline(Pipeline, ABC):
    """OOD Benchmark pipeline.

    Fits a detector on (a subset of) the in-distribution training data, scores a
    shuffled mix of in-distribution and out-of-distribution test samples, and
    computes detection metrics for every OOD dataset plus their average.

    Subclasses provide the datasets by implementing `_setup_datasets`.

    Args:
        in_dataset_name (str): Name of the in-distribution dataset.
        out_datasets_names_splits (Dict[str, Any]): Dictionary mapping out-distribution dataset
            names to their splits.
        transform (Callable): Transform to apply to the datasets.
        batch_size (int): Batch size.
        num_workers (int, optional): Number of workers. Defaults to 4.
        pin_memory (bool, optional): Pin memory. Defaults to True.
        prefetch_factor (int, optional): Prefetch factor. Only forwarded to the data loaders
            when `num_workers > 0`. Defaults to 2.
        limit_fit (float, optional): Fraction of the training set to use for fitting.
            `None` means 1.0. Defaults to 1.0.
        limit_run (float, optional): Fraction of the testing set to use for running.
            `None` means 1.0. Defaults to 1.0.
        seed (int, optional): Random seed. Defaults to 42.
        accelerator (Any, optional): Accelerator. Defaults to None.
    """

    def __init__(
        self,
        in_dataset_name: str,
        out_datasets_names_splits: Dict[str, Any],
        transform: Callable,
        batch_size: int,
        num_workers: int = 4,
        pin_memory: bool = True,
        prefetch_factor: int = 2,
        limit_fit: float = 1.0,
        limit_run: float = 1.0,
        seed: int = 42,
        accelerator=None,
    ) -> None:
        self.in_dataset_name = in_dataset_name
        self.out_datasets_names_splits = out_datasets_names_splits
        # Position i in this list corresponds to test label i + 1 (label 0 is in-distribution).
        self.out_datasets_names = list(out_datasets_names_splits.keys())
        self.limit_fit = limit_fit
        self.limit_run = limit_run
        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor
        self.seed = seed
        self.accelerator = accelerator

        # Populated by `_setup_datasets`.
        self.fit_dataset = None
        self.in_dataset = None
        self.out_dataset = None
        self.out_datasets = None

        # Seed before `setup()` so the random subsets drawn there are reproducible.
        accelerate.utils.set_seed(seed)

        _logger.info("Setting up datasets...")
        self.setup()

    @abstractmethod
    def _setup_datasets(self):
        """Setup `in_dataset`, `out_dataset`, `fit_dataset` and `out_datasets`."""
        ...

    @staticmethod
    def _subset_size(limit, total: int) -> int:
        """Turn a fractional limit into a sample count that never exceeds `total`."""
        fraction = 1.0 if limit is None else limit
        return min(int(fraction * total), total)

    def _build_dataloader(self, dataset, shuffle: bool):
        """Create a data loader for `dataset` using the pipeline's loader settings."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "shuffle": shuffle,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
        }
        # Recent PyTorch versions reject an explicit `prefetch_factor` when loading
        # happens in the main process, so only pass it along with worker processes.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return torch.utils.data.DataLoader(dataset, **loader_kwargs)

    def _progress_disabled(self) -> bool:
        """Return True when progress bars should be hidden (non-main accelerator process)."""
        return self.accelerator is not None and not self.accelerator.is_main_process

    def _setup_dataloaders(self):
        """Build `fit_dataloader` and `test_dataloader` from the datasets.

        The fit set is randomly subsampled according to `limit_fit`. The test set is
        the concatenation of the in-distribution and OOD datasets, paired with an
        integer label per sample (0 for in-distribution, `i + 1` for the i-th OOD
        dataset), then randomly subsampled according to `limit_run`.

        Raises:
            ValueError: If any dataset has not been set by `_setup_datasets`.
        """
        if self.fit_dataset is None or self.in_dataset is None or self.out_datasets is None or self.out_dataset is None:
            raise ValueError("Datasets are not set.")

        # Keep `self.limit_fit` as the configured fraction; the absolute count is local
        # so that calling `setup()` again subsamples the same way.
        fit_total = len(self.fit_dataset)
        fit_count = self._subset_size(self.limit_fit, fit_total)
        # random indices
        fit_indices = np.random.choice(np.arange(fit_total), fit_count, replace=False).tolist()
        self.fit_dataset = torch.utils.data.Subset(self.fit_dataset, fit_indices)
        self.fit_dataloader = self._build_dataloader(self.fit_dataset, shuffle=True)

        self.test_dataset = torch.utils.data.ConcatDataset([self.in_dataset, self.out_dataset])
        label_chunks = [torch.zeros(len(self.in_dataset))]  # type: ignore
        label_chunks.extend(
            torch.ones(len(dataset)) * (position + 1)  # type: ignore
            for position, dataset in enumerate(self.out_datasets.values())
        )
        test_labels = torch.utils.data.TensorDataset(torch.cat(label_chunks).long())
        # Pair every test sample with its dataset label; `run` unpacks three items per batch.
        self.test_dataset = ConcatDatasetsDim1([self.test_dataset, test_labels])

        # shuffle and subsample test_dataset
        test_total = len(self.test_dataset)
        test_count = self._subset_size(self.limit_run, test_total)
        test_indices = np.random.choice(np.arange(test_total), test_count, replace=False).tolist()
        self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
        self.test_dataloader = self._build_dataloader(self.test_dataset, shuffle=False)

        if self.accelerator is not None:
            self.fit_dataloader = self.accelerator.prepare(self.fit_dataloader)
            self.test_dataloader = self.accelerator.prepare(self.test_dataloader)

        _logger.info("Using %s samples for fitting.", len(self.fit_dataset))
        _logger.info("Using %s samples for testing.", len(self.test_dataset))

    def setup(self):
        """Load the datasets and build the data loaders."""
        self._setup_datasets()
        self._setup_dataloaders()

    def preprocess(self, method: DetectorWrapper) -> DetectorWrapper:
        """Fit the detector on the fit dataset, if it supports fitting.

        Args:
            method (DetectorWrapper): Detection method to fit.

        Returns:
            DetectorWrapper: The same method. It is returned unchanged when there is no
                fit dataset or the wrapped detector has no `update` attribute.
        """
        if self.fit_dataset is None:
            _logger.warning("Fit dataset is not set or not supported. Returning.")
            return method

        if not hasattr(method.detector, "update"):
            _logger.warning("Detector does not support fitting. Returning.")
            return method

        fit_length = len(self.fit_dataloader.dataset)
        # One batch of inputs is handed to `start` together with the fit set size.
        example = next(iter(self.fit_dataloader))[0]
        method.start(example=example, fit_length=fit_length)

        # The context manager closes the bar even if `update` raises.
        with tqdm(total=len(self.fit_dataloader), desc="Fitting", disable=self._progress_disabled()) as progress_bar:
            for x, y in self.fit_dataloader:
                method.update(x, y)
                progress_bar.update(1)

        method.end()
        return method

    def run(self, method: DetectorWrapper) -> Dict[str, Any]:
        """Fit the method, score the test set and compute the metrics.

        Args:
            method (DetectorWrapper): Detection method to evaluate.

        Returns:
            Dict[str, Any]: Dictionary with keys `"results"` (output of `postprocess`),
                `"scores"` and `"labels"` (tensors covering the processed test samples).
        """
        self.method = method

        _logger.info("Running pipeline...")
        self.method = self.preprocess(self.method)

        # initialize based on dataset size
        dataset_size = len(self.test_dataloader.dataset)
        test_labels = torch.empty(dataset_size, dtype=torch.int64)
        test_scores = torch.empty(dataset_size, dtype=torch.float32)
        _logger.debug("test_labels shape: %s", test_labels.shape)
        _logger.debug("test_scores shape: %s", test_scores.shape)

        batch_times = []
        idx = 0
        with tqdm(total=len(self.test_dataloader), desc="Inference", disable=self._progress_disabled()) as progress_bar:
            for x, _, labels in self.test_dataloader:
                # Monotonic clock: only the scoring call itself is timed.
                t1 = time.perf_counter()
                score = self.method(x)
                t2 = time.perf_counter()
                batch_times.append(t2 - t1)

                if self.accelerator is not None:
                    score = self.accelerator.gather_for_metrics(score)
                    labels = self.accelerator.gather_for_metrics(labels)

                test_labels[idx : idx + labels.shape[0]] = labels.cpu()
                # Detach so the score buffer never holds on to an autograd graph.
                test_scores[idx : idx + score.shape[0]] = score.detach().cpu()

                idx += labels.shape[0]
                progress_bar.update(1)

        # Mean per-batch scoring time in seconds; NaN when no batch was processed.
        self.infer_times = np.mean(batch_times) if batch_times else float("nan")
        # Drop the unused tail of the pre-allocated buffers.
        test_scores = test_scores[:idx]
        test_labels = test_labels[:idx]

        res_obj = self.postprocess(test_scores, test_labels)

        return {"results": res_obj, "scores": test_scores, "labels": test_labels}

    def postprocess(self, test_scores: Tensor, test_labels: Tensor):
        """Compute the metrics of every OOD dataset against the in-distribution scores.

        Args:
            test_scores (Tensor): Score of every test sample.
            test_labels (Tensor): Dataset label of every test sample (0 for
                in-distribution, `i + 1` for the i-th OOD dataset).

        Returns:
            Dict[str, Dict[str, Any]]: One entry per OOD dataset name plus an `"average"`
                entry holding the mean of each metric. Every entry also has a `"time"` key.
        """
        _logger.info("Computing metrics...")
        # `infer_times` is set by `run`; fall back to NaN when called on its own.
        infer_time = getattr(self, "infer_times", float("nan"))
        in_scores = test_scores[test_labels == 0]

        results = {}
        for label, ood_dataset_name in enumerate(self.out_datasets_names, start=1):
            ood_scores = test_scores[test_labels == label]
            results[ood_dataset_name] = get_ood_results(in_scores, ood_scores)
            results[ood_dataset_name]["time"] = infer_time

        metric_keys = list(results[self.out_datasets_names[0]].keys())
        results["average"] = {
            key: np.mean([results[name][key] for name in self.out_datasets_names]) for key in metric_keys
        }
        results["average"]["time"] = infer_time

        return results

    def report(self, results: Dict[str, Dict[str, Any]]) -> str:
        """Format the results as a text table with one row per dataset.

        Args:
            results (Dict[str, Dict[str, Any]]): Either the output of `postprocess` or the
                output of `run` (in which case its `"results"` entry is used).

        Returns:
            str: The table, with floats formatted to four decimal places.
        """
        # log results in a table
        if "results" in results:
            results = results["results"]

        # Build all rows first and concatenate once instead of growing a frame in a loop.
        rows = [pd.DataFrame(res, index=[ood_dataset]) for ood_dataset, res in results.items()]
        df = pd.concat(rows) if rows else pd.DataFrame()
        # Metrics without a registered display name keep their raw key.
        df.columns = [METRICS_NAMES_PRETTY.get(k, k) for k in df.columns]
        return df.to_string(index=True, float_format="{:.4f}".format)
@register_pipeline("ood_benchmark_cifar10")
class OODCifar10BenchmarkPipeline(OODBenchmarkPipeline):
    """OOD benchmark with CIFAR-10 as the in-distribution dataset.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.0.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 1.0.
        batch_size (int, optional): Batch size. Defaults to 128.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: `num_workers`, `pin_memory`, `prefetch_factor` and `accelerator` are
            forwarded to `OODBenchmarkPipeline`; any other extra keyword is ignored.
    """

    def __init__(self, transform: Callable, limit_fit=1.0, limit_run=1.0, batch_size=128, seed=42, **kwargs) -> None:
        # Pass through only the options the base constructor accepts, so unrelated
        # extras handed to every registered pipeline keep being tolerated.
        forwarded = {
            key: kwargs[key]
            for key in ("num_workers", "pin_memory", "prefetch_factor", "accelerator")
            if key in kwargs
        }
        super().__init__(
            "cifar10",
            {
                "cifar100": "test",
                "svhn": "test",
                "isun": None,
                "lsun_c": None,
                "lsun_r": None,
                "tiny_imagenet_c": None,
                "tiny_imagenet_r": None,
                "textures": None,
                "places365": None,
                "english_chars": None,
                "uniform": None,
                "gaussian": None,
            },
            transform=transform,
            batch_size=batch_size,
            limit_fit=limit_fit,
            limit_run=limit_run,
            seed=seed,
            **forwarded,
        )

    def _setup_datasets(self):
        """Create the fit ("train" split), in-distribution ("test" split) and OOD datasets."""
        _logger.info("Loading In-distribution dataset...")
        self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform, download=True)
        self.in_dataset = create_dataset(self.in_dataset_name, split="test", transform=self.transform, download=True)

        _logger.info("Loading OOD datasets...")
        self.out_datasets = {
            ds: create_dataset(ds, split=split, transform=self.transform, download=True)
            for ds, split in self.out_datasets_names_splits.items()
        }
        # Concatenation order follows the dict order, which the test labels rely on.
        self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values()))
@register_pipeline("ood_benchmark_cifar100")
class OODCifar100BenchmarkPipeline(OODBenchmarkPipeline):
    """OOD benchmark with CIFAR-100 as the in-distribution dataset.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.0.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 1.0.
        batch_size (int, optional): Batch size. Defaults to 128.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: `num_workers`, `pin_memory`, `prefetch_factor` and `accelerator` are
            forwarded to `OODBenchmarkPipeline`; any other extra keyword is ignored.
    """

    def __init__(self, transform: Callable, limit_fit=1.0, limit_run=1.0, batch_size=128, seed=42, **kwargs) -> None:
        # Pass through only the options the base constructor accepts, so unrelated
        # extras handed to every registered pipeline keep being tolerated.
        forwarded = {
            key: kwargs[key]
            for key in ("num_workers", "pin_memory", "prefetch_factor", "accelerator")
            if key in kwargs
        }
        super().__init__(
            "cifar100",
            {
                "cifar10": "test",
                "svhn": "test",
                "isun": None,
                "lsun_c": None,
                "lsun_r": None,
                "tiny_imagenet_c": None,
                "tiny_imagenet_r": None,
                "textures": None,
                "places365": None,
                "english_chars": None,
                "uniform": None,
                "gaussian": None,
            },
            transform=transform,
            batch_size=batch_size,
            limit_fit=limit_fit,
            limit_run=limit_run,
            seed=seed,
            **forwarded,
        )

    def _setup_datasets(self):
        """Create the fit ("train" split), in-distribution ("test" split) and OOD datasets."""
        _logger.info("Loading In-distribution dataset...")
        self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform, download=True)
        self.in_dataset = create_dataset(self.in_dataset_name, split="test", transform=self.transform, download=True)

        _logger.info("Loading OOD datasets...")
        self.out_datasets = {
            ds: create_dataset(ds, split=split, transform=self.transform, download=True)
            for ds, split in self.out_datasets_names_splits.items()
        }
        # Concatenation order follows the dict order, which the test labels rely on.
        self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values()))
@register_pipeline("ood_benchmark_imagenet")
class OODImageNetBenchmarkPipeline(OODBenchmarkPipeline):
    """OOD benchmark with ImageNet (`ilsvrc2012`) as the in-distribution dataset.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.0.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 1.0.
        batch_size (int, optional): Batch size. Defaults to 64.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: `num_workers`, `pin_memory`, `prefetch_factor` and `accelerator` are
            forwarded to `OODBenchmarkPipeline`; any other extra keyword is ignored.
    """

    def __init__(self, transform: Callable, limit_fit=1.0, limit_run=1.0, batch_size=64, seed=42, **kwargs) -> None:
        # Pass through only the options the base constructor accepts, so unrelated
        # extras handed to every registered pipeline keep being tolerated.
        forwarded = {
            key: kwargs[key]
            for key in ("num_workers", "pin_memory", "prefetch_factor", "accelerator")
            if key in kwargs
        }
        super().__init__(
            "ilsvrc2012",
            {
                "mos_inaturalist": None,
                "mos_sun": None,
                "mos_places365": None,
                "textures": None,
                "imagenet_o": None,
                "openimage_o": None,
                "imagenet_a": None,
                "imagenet_r": None,
                "uniform": None,
                "gaussian": None,
            },
            limit_fit=limit_fit,
            limit_run=limit_run,
            transform=transform,
            batch_size=batch_size,
            seed=seed,
            **forwarded,
        )

    def _setup_datasets(self):
        """Create the fit ("train" split), in-distribution ("val" split) and OOD datasets.

        Unlike the OOD datasets, the in-distribution datasets are created without
        `download=True`.
        """
        _logger.info("Loading In-distribution dataset...")
        self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform)
        self.in_dataset = create_dataset(self.in_dataset_name, split="val", transform=self.transform)

        _logger.info("Loading OOD datasets...")
        self.out_datasets = {
            ds: create_dataset(ds, split=split, transform=self.transform, download=True)
            for ds, split in self.out_datasets_names_splits.items()
        }
        # Concatenation order follows the dict order, which the test labels rely on.
        self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values()))
@register_pipeline("ood_mnist_benchmark")
class OODMNISTBenchmarkPipeline(OODBenchmarkPipeline):
    """OOD benchmark with MNIST as the in-distribution dataset.

    A single-channel grayscale conversion is applied after the user transform to
    every dataset.

    Args:
        transform (Callable): Transform to apply to the datasets. It is wrapped, not modified.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.
        batch_size (int, optional): Batch size. Defaults to 64.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 1.0.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: `num_workers`, `pin_memory`, `prefetch_factor` and `accelerator` are
            forwarded to `OODBenchmarkPipeline`; any other extra keyword is ignored.
    """

    def __init__(self, transform: Callable, limit_fit=1, batch_size=64, limit_run=1.0, seed=42, **kwargs) -> None:
        # Compose a new transform instead of appending to the caller's object: the
        # caller's transform stays untouched, any callable is accepted, and repeated
        # `setup()` calls do not stack additional grayscale steps.
        grayscale_transform = torchvision.transforms.Compose(
            [transform, torchvision.transforms.Grayscale(num_output_channels=1)]
        )
        # Pass through only the options the base constructor accepts.
        forwarded = {
            key: kwargs[key]
            for key in ("num_workers", "pin_memory", "prefetch_factor", "accelerator")
            if key in kwargs
        }
        super().__init__(
            "mnist",
            {
                "fashion_mnist": "test",
                "svhn": "test",
                "cifar10": "test",
                "textures": None,
                "english_chars": None,
            },
            limit_fit=limit_fit,
            limit_run=limit_run,
            transform=grayscale_transform,
            batch_size=batch_size,
            seed=seed,
            **forwarded,
        )

    def _setup_datasets(self):
        """Create the fit ("train" split), in-distribution ("test" split) and OOD datasets."""
        _logger.info("Loading In-distribution dataset...")
        self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform)
        self.in_dataset = create_dataset(self.in_dataset_name, split="test", transform=self.transform)

        _logger.info("Loading OOD datasets...")
        self.out_datasets = {
            ds: create_dataset(ds, split=split, transform=self.transform, download=True)
            for ds, split in self.out_datasets_names_splits.items()
        }
        # Concatenation order follows the dict order, which the test labels rely on.
        self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values()))
class OODValidationPipeline(OODBenchmarkPipeline, ABC):
    """Pipeline for OOD validation.

    This pipeline is used to validate the performance of a model on OOD datasets:
    it searches the hyperparameters of a detection method with Optuna, running the
    benchmark pipeline once per trial.
    """

    # TODO: include prevent refit flag.

    # Metrics for which larger is better; every other objective metric is minimized.
    _MAXIMIZED_METRICS = frozenset({"auroc", "aupr_in", "aupr_out", "tnr_at_0.95_tpr"})

    def run(
        self,
        method: DetectorWrapper,
        hyperparameters: Dict[str, Union[List[Any], Tuple[Any], Dict[str, Any]]],
        objective_metric: Literal["fpr_at_0.95_tpr", "auroc"] = "auroc",
        objective_dataset: str = "average",
        n_trials=20,
    ) -> Dict[str, Any]:
        """Search the hyperparameters and configure the method with the best ones.

        Args:
            method (DetectorWrapper): The OOD detection method to use.
            hyperparameters (Dict[str, Union[List[Any], Tuple[Any], Dict[str, Any]]]): Search space.
                A list/tuple value enumerates categorical choices; a dict value describes a
                numeric range with keys `"low"`, `"high"` and optionally `"step"`.
            objective_metric (str, optional): The metric to optimize. Defaults to "auroc".
            objective_dataset (str, optional): Entry of the results whose metric is optimized.
                Defaults to "average".
            n_trials (int, optional): The number of trials to run. With a pure grid search
                it is capped at the grid size. Defaults to 20.

        Returns:
            Dict[str, Any]: Dictionary with keys `"method"`, `"study"`, `"best_params"`
                and `"best_value"`.
        """
        self.method = method
        self.hyperparameters = hyperparameters
        self.objective_metric = objective_metric
        self.objective_dataset = objective_dataset

        direction = "maximize" if objective_metric in self._MAXIMIZED_METRICS else "minimize"

        sampler = None
        # Only an all-categorical search space can be enumerated exhaustively.
        if all(isinstance(choices, (list, tuple)) for choices in hyperparameters.values()):
            sampler = optuna.samplers.GridSampler(search_space=hyperparameters)
            grid_shape = np.array([len(choices) for choices in hyperparameters.values()])
            n_trials = min(int(np.prod(grid_shape)), n_trials)

        study = optuna.create_study(study_name="ood-val", sampler=sampler, direction=direction)
        study.optimize(self.objective, n_trials=n_trials, show_progress_bar=True)

        self.method.set_hyperparameters(**study.best_params)

        return {
            "method": self.method,
            "study": study,
            "best_params": study.best_params,
            "best_value": study.best_trial.value,
        }

    def objective(self, trial: optuna.trial.Trial) -> float:
        """Evaluate one trial: sample hyperparameters, run the benchmark, return the metric.

        Args:
            trial (optuna.trial.Trial): Trial used to sample the hyperparameters.

        Returns:
            float: Value of `objective_metric` for `objective_dataset`.
        """
        # build detector from trial params
        new_params = {}
        for name, spec in self.hyperparameters.items():
            if isinstance(spec, (list, tuple)):
                new_params[name] = trial.suggest_categorical(name, spec)
                continue

            if not isinstance(spec, dict):
                _logger.warning("Skipping hyperparameter %r: unsupported search space %r.", name, spec)
                continue

            low = spec["low"]
            high = spec["high"]
            step = spec.get("step")
            # The step decides the numeric type when given; otherwise the bounds do.
            type_probe = (low, high) if step is None else (step,)
            if any(isinstance(value, (float, np.floating)) for value in type_probe):
                new_params[name] = trial.suggest_float(name, low=low, high=high, step=step)
            elif all(isinstance(value, (int, np.integer)) for value in type_probe):
                new_params[name] = trial.suggest_int(name, low=low, high=high, step=1 if step is None else step)
            else:
                _logger.warning("Skipping hyperparameter %r: range %r is neither int nor float.", name, spec)

        self.method.set_hyperparameters(**new_params)
        # The benchmark run (including the fitting step) is repeated for every trial.
        run_obj = super().run(self.method)
        results = run_obj["results"]
        return results[self.objective_dataset][self.objective_metric]

    def report(self, results: Dict[str, Any]) -> str:
        """Summarize the hyperparameter search.

        Args:
            results (Dict[str, Any]): Output of `run`; must contain the `"study"` key.

        Returns:
            str: Trial counts plus the value and parameters of the best trial.

        Raises:
            ValueError: If `results` has no `"study"` key.
        """
        if "study" not in results:
            raise ValueError("The results dict must contain a 'study' key.")

        study = results["study"]
        pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
        complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
        best_trial = study.best_trial
        best_value = best_trial.value

        report_lines = [
            "",
            "Study statistics:",
            f"    Number of finished trials: {len(study.trials)}",
            f"    Number of pruned trials: {len(pruned_trials)}",
            f"    Number of complete trials: {len(complete_trials)}",
            "Best trial:",
            f"    Value: {best_value}",
            f"    Params: {best_trial.params}",
            "",
        ]
        return "\n".join(report_lines)
@register_pipeline("ood_validation_cifar10")
class OODCifar10ValidationPipeline(OODCifar10BenchmarkPipeline, OODValidationPipeline):
    """Hyperparameter validation on the CIFAR-10 benchmark datasets.

    Combines the datasets of `OODCifar10BenchmarkPipeline` with the search logic of
    `OODValidationPipeline`. By default only 10% of the test set is used per run.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 0.1.
        batch_size (int, optional): Batch size. Defaults to 128.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: Extra keyword arguments, passed on to `OODCifar10BenchmarkPipeline`.
    """

    def __init__(self, transform: Callable, limit_fit=1, limit_run=0.1, batch_size=128, seed=42, **kwargs) -> None:
        # The parent constructor accepts arbitrary extras, so hand them over
        # instead of dropping them here.
        super().__init__(
            transform=transform,
            batch_size=batch_size,
            limit_fit=limit_fit,
            limit_run=limit_run,
            seed=seed,
            **kwargs,
        )
@register_pipeline("ood_validation_noise_cifar10")
class OODCifar10NoiseValidationPipeline(OODValidationPipeline):
    """Hyperparameter validation for CIFAR-10 using only noise datasets as OOD data.

    The OOD datasets are `"uniform"` and `"gaussian"`. By default only 10% of the
    test set is used per run.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 0.1.
        batch_size (int, optional): Batch size. Defaults to 128.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: `num_workers`, `pin_memory`, `prefetch_factor` and `accelerator` are
            forwarded to `OODBenchmarkPipeline`; any other extra keyword is ignored.
    """

    def __init__(self, transform: Callable, limit_fit=1, limit_run=0.1, batch_size=128, seed=42, **kwargs) -> None:
        # Pass through only the options the base constructor accepts, so unrelated
        # extras handed to every registered pipeline keep being tolerated.
        forwarded = {
            key: kwargs[key]
            for key in ("num_workers", "pin_memory", "prefetch_factor", "accelerator")
            if key in kwargs
        }
        super().__init__(
            "cifar10",
            {
                "uniform": None,
                "gaussian": None,
            },
            transform=transform,
            batch_size=batch_size,
            limit_fit=limit_fit,
            limit_run=limit_run,
            seed=seed,
            **forwarded,
        )

    def _setup_datasets(self):
        """Create the fit ("train" split), in-distribution ("test" split) and OOD datasets."""
        _logger.info("Loading In-distribution dataset...")
        # Download when missing, matching the CIFAR-10 benchmark pipeline.
        self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform, download=True)
        self.in_dataset = create_dataset(self.in_dataset_name, split="test", transform=self.transform, download=True)

        _logger.info("Loading OOD datasets...")
        self.out_datasets = {
            ds: create_dataset(ds, split=split, transform=self.transform, download=True)
            for ds, split in self.out_datasets_names_splits.items()
        }
        # Concatenation order follows the dict order, which the test labels rely on.
        self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values()))
@register_pipeline("ood_validation_cifar100")
class OODCifar100ValidationPipeline(OODCifar100BenchmarkPipeline, OODValidationPipeline):
    """Hyperparameter validation on the CIFAR-100 benchmark datasets.

    Combines the datasets of `OODCifar100BenchmarkPipeline` with the search logic of
    `OODValidationPipeline`. By default only 10% of the test set is used per run.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 0.1.
        batch_size (int, optional): Batch size. Defaults to 128.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: Extra keyword arguments, passed on to `OODCifar100BenchmarkPipeline`.
    """

    def __init__(self, transform: Callable, limit_fit=1, limit_run=0.1, batch_size=128, seed=42, **kwargs) -> None:
        # The parent constructor accepts arbitrary extras, so hand them over
        # instead of dropping them here.
        super().__init__(
            transform=transform,
            batch_size=batch_size,
            limit_fit=limit_fit,
            limit_run=limit_run,
            seed=seed,
            **kwargs,
        )
@register_pipeline("ood_validation_noise_cifar100")
class OODCifar100NoiseValidationPipeline(OODValidationPipeline):
    """Hyperparameter validation for CIFAR-100 using only noise datasets as OOD data.

    The OOD datasets are `"uniform"` and `"gaussian"`. By default only 10% of the
    test set is used per run.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 0.1.
        batch_size (int, optional): Batch size. Defaults to 128.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: `num_workers`, `pin_memory`, `prefetch_factor` and `accelerator` are
            forwarded to `OODBenchmarkPipeline`; any other extra keyword is ignored.
    """

    def __init__(self, transform: Callable, limit_fit=1, limit_run=0.1, batch_size=128, seed=42, **kwargs) -> None:
        # Pass through only the options the base constructor accepts, so unrelated
        # extras handed to every registered pipeline keep being tolerated.
        forwarded = {
            key: kwargs[key]
            for key in ("num_workers", "pin_memory", "prefetch_factor", "accelerator")
            if key in kwargs
        }
        super().__init__(
            "cifar100",
            {
                "uniform": None,
                "gaussian": None,
            },
            transform=transform,
            batch_size=batch_size,
            limit_fit=limit_fit,
            limit_run=limit_run,
            seed=seed,
            **forwarded,
        )

    def _setup_datasets(self):
        """Create the fit ("train" split), in-distribution ("test" split) and OOD datasets."""
        _logger.info("Loading In-distribution dataset...")
        # Download when missing, matching the CIFAR-100 benchmark pipeline.
        self.fit_dataset = create_dataset(self.in_dataset_name, split="train", transform=self.transform, download=True)
        self.in_dataset = create_dataset(self.in_dataset_name, split="test", transform=self.transform, download=True)

        _logger.info("Loading OOD datasets...")
        self.out_datasets = {
            ds: create_dataset(ds, split=split, transform=self.transform, download=True)
            for ds, split in self.out_datasets_names_splits.items()
        }
        # Concatenation order follows the dict order, which the test labels rely on.
        self.out_dataset = torch.utils.data.ConcatDataset(list(self.out_datasets.values()))
@register_pipeline("ood_validation_imagenet")
class OODImageNetValidationPipeline(OODImageNetBenchmarkPipeline, OODValidationPipeline):
    """Hyperparameter validation on the ImageNet benchmark datasets.

    Combines the datasets of `OODImageNetBenchmarkPipeline` with the search logic of
    `OODValidationPipeline`. By default only 10% of the test set is used per run.

    Args:
        transform (Callable): Transform to apply to the datasets.
        limit_fit (float, optional): Fraction of the training set used for fitting. Defaults to 1.
        limit_run (float, optional): Fraction of the test set used for running. Defaults to 0.1.
        batch_size (int, optional): Batch size. Defaults to 64.
        seed (int, optional): Random seed. Defaults to 42.
        **kwargs: Extra keyword arguments, passed on to `OODImageNetBenchmarkPipeline`.
    """

    def __init__(self, transform: Callable, limit_fit=1, limit_run=0.1, batch_size=64, seed=42, **kwargs) -> None:
        # The parent constructor accepts arbitrary extras, so hand them over
        # instead of dropping them here.
        super().__init__(
            transform=transform,
            batch_size=batch_size,
            limit_fit=limit_fit,
            limit_run=limit_run,
            seed=seed,
            **kwargs,
        )