Source code for detectors.pipelines.covariate_drift

import logging
from typing import Any, Dict, List

import accelerate
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import sklearn.metrics
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import Tensor
from tqdm import tqdm

import detectors
from detectors.data import create_dataset
from detectors.methods import DetectorWrapper
from detectors.pipelines import register_pipeline
from detectors.pipelines.base import Pipeline
from detectors.utils import ConcatDatasetsDim1

_logger = logging.getLogger(__name__)


class CovariateDriftPipeline(Pipeline):
    """Pipeline that evaluates a detector on a stream with increasing covariate drift.

    The evaluation stream is built as: warm-up samples taken from the fit
    split, then one clean chunk of the in-distribution split, then corrupted
    chunks concatenated by increasing intensity. A drift event is declared
    when the moving accuracy of the classifier falls below a fraction of the
    accuracy measured on the clean chunk.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_splits: List[str],
        transform,
        corruptions: List[str],
        intensities: List[int],
        batch_size: int = 128,
        limit_fit: float = 1.0,
        warmup_size=2000,
        seed=42,
        **kwargs,
    ) -> None:
        """Covariate Drift Pipeline.

        Covariate drift event: when moving accuracy is below threshold compared to training accuracy

        Args:
            dataset_name (str): Name of the dataset.
            dataset_splits (List[str]): List of dataset splits to use in fit and in dataset, respectively.
            transform (Callable): Transform to apply to the dataset.
            corruptions (List[str]): List of corruptions to apply.
            intensities (List[int]): List of intensities to apply.
            batch_size (int, optional): Batch size. Defaults to 128.
            limit_fit (float, optional): Fraction of the fit split to keep (random subset).
                A falsy value keeps everything. Defaults to 1.0.
            warmup_size (int, optional): Number of fit samples prepended to the test stream.
                Clamped to the size of the (limited) fit dataset. Defaults to 2000.
            seed (int, optional): Seed passed to ``accelerate.utils.set_seed``. Defaults to 42.
            **kwargs: Ignored.

        Raises:
            ValueError: If the in-distribution split is too small to give every chunk
                at least one sample.
        """
        self.accelerator = accelerate.Accelerator()
        accelerate.utils.set_seed(seed)

        _logger.info("Creating datasets...")
        fit_dataset = create_dataset(dataset_name, split=dataset_splits[0], transform=transform)
        # Random subset of the fit split.
        limit_fit = limit_fit or 1
        limit_fit = min(int(limit_fit * len(fit_dataset)), len(fit_dataset))
        indices = torch.randperm(len(fit_dataset)).numpy()[:limit_fit]
        fit_dataset = torch.utils.data.Subset(fit_dataset, indices)

        in_dataset = create_dataset(dataset_name, split=dataset_splits[1], transform=transform)
        # Shuffle the in-distribution split, then cut it into equally sized chunks:
        # one clean chunk followed by one chunk per (corruption, intensity) pair.
        indices = torch.randperm(len(in_dataset)).numpy()
        n_chunks = len(corruptions) * len(intensities) + 1
        chunk_size = len(in_dataset) // n_chunks
        if chunk_size == 0:
            raise ValueError(
                f"Split '{dataset_splits[1]}' has {len(in_dataset)} samples, "
                f"which is not enough for {n_chunks} chunks."
            )
        # Explicit boundaries: always ``n_chunks + 1`` entries, even when the dataset
        # length is an exact multiple of ``n_chunks`` (``torch.arange(0, len, step)``
        # would then miss the final boundary).
        boundaries = [k * chunk_size for k in range(n_chunks + 1)]

        in_dataset = torch.utils.data.Subset(in_dataset, indices[: boundaries[1]])
        # The warm-up cannot be longer than the fit dataset; clamp so that the
        # offsets stored in ``self.splits`` match the real stream layout.
        warmup_size = min(warmup_size, len(fit_dataset))
        warmup_dataset = torch.utils.data.Subset(fit_dataset, range(warmup_size))
        in_dataset = torch.utils.data.ConcatDataset([warmup_dataset, in_dataset])

        out_datasets = {}
        for i, corruption in enumerate(corruptions):
            out_datasets[corruption] = []
            for j, intensity in enumerate(intensities):
                # Flat chunk index; chunk 0 is the clean one. Using ``i + j`` here
                # would make different (corruption, intensity) pairs share a chunk.
                k = i * len(intensities) + j + 1
                chunk_indices = indices[boundaries[k] : boundaries[k + 1]]
                out_datasets[corruption].append(
                    torch.utils.data.Subset(
                        create_dataset(
                            dataset_name + "_c", split=corruption, intensity=intensity, transform=transform
                        ),
                        chunk_indices,
                    )
                )

        # Boundaries expressed in test-stream coordinates (warm-up comes first).
        self.splits = [0] + [warmup_size + boundary for boundary in boundaries]
        _logger.info("Data splits are: %s", self.splits)

        # increasing intensity concat dataset
        out_dataset = [
            torch.utils.data.ConcatDataset([out_datasets[corruption][i] for corruption in corruptions])
            for i in range(len(intensities))
        ]
        out_dataset = torch.utils.data.ConcatDataset(out_dataset)

        _logger.debug("Fit dataset size: %s", len(fit_dataset))
        _logger.debug("In dataset size: %s", len(in_dataset))
        _logger.debug("Out dataset size: %s", len(out_dataset))

        self.in_dataset = in_dataset
        self.out_dataset = out_dataset
        self.fit_dataset = fit_dataset
        self.batch_size = batch_size
        self.warmup_size = warmup_size
        # NOTE: at this point ``limit_fit`` is the number of fit samples kept, not the fraction.
        self.limit_fit = limit_fit
        self.setup()

    def setup(self):
        """Build the test dataset (inputs, targets, drift flag) and the dataloaders.

        Every sample of ``in_dataset`` gets flag 0 and every sample of
        ``out_dataset`` gets flag 1. The test dataloader is not shuffled so that
        the stream order is preserved.
        """
        test_dataset = torch.utils.data.ConcatDataset([self.in_dataset, self.out_dataset])
        test_labels = torch.utils.data.TensorDataset(
            torch.cat([torch.zeros(len(self.in_dataset)), torch.ones(len(self.out_dataset))]).long()  # type: ignore
        )

        self.test_dataset = ConcatDatasetsDim1([test_dataset, test_labels])
        self.test_dataloader = torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
        self.fit_dataloader = torch.utils.data.DataLoader(self.fit_dataset, batch_size=self.batch_size, shuffle=True)

        self.fit_dataloader = self.accelerator.prepare(self.fit_dataloader)
        # careful with this with multiple gpus
        self.test_dataloader = self.accelerator.prepare(self.test_dataloader)

    def preprocess(self, method: DetectorWrapper) -> DetectorWrapper:
        """Fit ``method`` on the fit dataloader if its detector exposes ``update``.

        Args:
            method (DetectorWrapper): Detector wrapper to fit.

        Returns:
            DetectorWrapper: The same wrapper, after ``start``/``update``/``end`` were
            called on it, or untouched when fitting is not applicable.
        """
        if not hasattr(method.detector, "update"):
            return method

        if self.fit_dataset is None:
            _logger.warning("Fit is not set or not supported. Returning.")
            return method

        if method.model is not None:
            method.detector.model = self.accelerator.prepare(method.detector.model)

        progress_bar = tqdm(
            range(len(self.fit_dataloader)), desc="Fitting", disable=not self.accelerator.is_local_main_process
        )
        method.start()
        for x, y in self.fit_dataloader:
            method.update(x, y)
            progress_bar.update(1)
        progress_bar.close()
        self.accelerator.wait_for_everyone()
        method.end()
        return method

    def run(self, method, model, **kwargs):
        """Score the whole test stream with ``method`` and classify it with ``model``.

        Args:
            method: Callable returning one score per input of a batch.
            model: Classifier returning logits for a batch.
            **kwargs: ``acc_threshold`` (default 0.90) and ``stride`` (default 1) are
                forwarded to :meth:`postprocess`.

        Returns:
            Dict[str, Any]: ``method``, the per-sample ``scores``, ``preds``, ``targets``
            and ``labels`` tensors, plus every entry returned by :meth:`postprocess`.
        """
        n_samples = len(self.test_dataset)
        test_labels = torch.empty(n_samples, dtype=torch.long)
        test_targets = torch.empty(n_samples, dtype=torch.long)
        test_scores = torch.empty(n_samples, dtype=torch.float)
        test_preds = torch.empty(n_samples, dtype=torch.long)

        idx = 0
        progress_bar = tqdm(
            range(len(self.test_dataloader)), desc="Inference", disable=not self.accelerator.is_local_main_process
        )
        for x, y, labels in self.test_dataloader:
            scores = method(x)
            with torch.no_grad():
                logits = model(x)

            labels, y, scores, logits = self.accelerator.gather_for_metrics((labels, y, scores, logits))

            # Use the gathered size: with several processes it differs from ``len(x)``.
            n_gathered = len(labels)
            test_labels[idx : idx + n_gathered] = labels.detach().cpu()
            test_targets[idx : idx + n_gathered] = y.detach().cpu()
            test_scores[idx : idx + n_gathered] = scores.detach().cpu()
            test_preds[idx : idx + n_gathered] = logits.detach().cpu().argmax(1)

            idx += n_gathered
            progress_bar.update(1)
        progress_bar.close()
        self.accelerator.wait_for_everyone()

        # Every slot of the preallocated buffers must have been written.
        assert len(test_labels) == len(test_targets) == len(test_scores) == len(test_preds) == idx

        _logger.info("Computing metrics...")
        acc_threshold = kwargs.get("acc_threshold", 0.90)
        stride = kwargs.get("stride", 1)
        metrics = self.postprocess(
            test_scores, test_preds, test_targets, test_labels, acc_threshold=acc_threshold, stride=stride
        )
        return {
            "method": method,
            "scores": test_scores,
            "preds": test_preds,
            "targets": test_targets,
            "labels": test_labels,
            **metrics,
        }

    def postprocess(
        self,
        test_scores: Tensor,
        test_preds: Tensor,
        test_targets: Tensor,
        test_labels: Tensor,
        stride=1,
        alpha=0.99,
        acc_threshold=0.90,
        **kwargs,
    ) -> Dict[str, Any]:
        """Compute drift labels and detection metrics from per-sample results.

        Moving statistics use a trailing window of ``self.batch_size`` samples that
        advances by ``stride``; window ``k`` therefore ends at sample ``k * stride``.

        Args:
            test_scores (Tensor): Detector score of every sample of the stream.
            test_preds (Tensor): Predicted class of every sample.
            test_targets (Tensor): Ground-truth class of every sample.
            test_labels (Tensor): Clean/corrupted flag of every sample (unused here).
            stride (int, optional): Step between consecutive windows. Defaults to 1.
            alpha (float, optional): Smoothing factor of the exponential moving average
                of the scores. Defaults to 0.99.
            acc_threshold (float, optional): Fraction of the clean accuracy under which
                a window is labelled as drift. Defaults to 0.90.
            **kwargs: Ignored.

        Returns:
            Dict[str, Any]: Drift labels, index of the first drift window, reference
            accuracy, splits, correlations, AUROC / FPR values for drift and for
            mistakes, and the moving statistics.
        """
        win_size = self.batch_size
        avg_warmup = test_scores[: self.warmup_size].mean().item()

        # Left-pad with the warm-up mean so that the first windows are full.
        data_padded = F.pad(test_scores.unsqueeze(0), (win_size - 1, 0), "constant", avg_warmup).squeeze(0)
        moving_average = data_padded.unfold(0, win_size, stride).mean(dim=1)

        ema = test_scores.clone()
        ema[0] = avg_warmup
        for i in range(1, len(test_scores)):
            ema[i] = alpha * ema[i - 1] + (1 - alpha) * test_scores[i]

        mistakes = (test_preds != test_targets).float()
        mistakes_padded = F.pad(mistakes.unsqueeze(0), (win_size - 1, 0), "constant", 0).squeeze(0)
        moving_accuracy = 1 - mistakes_padded.unfold(0, win_size, stride).mean(dim=1)

        # define real drift event: when moving accuracy is below threshold compared to training accuracy
        # The reference is the mean moving accuracy over the first half of the clean
        # chunk. Sample positions are converted to window indices (ceil division) so
        # that the slice stays correct when ``stride`` is not 1.
        clean_start = self.splits[1]
        clean_half = clean_start + (self.splits[2] - self.splits[1]) // 2
        first_window = (clean_start + stride - 1) // stride
        last_window = (clean_half + stride - 1) // stride
        acc = moving_accuracy[first_window:last_window].mean().item()
        ref = acc_threshold * acc
        _logger.info("Original accuracy: %s", acc)
        _logger.info("Reference accuracy to detect drift: %s", ref)
        drift_labels = (moving_accuracy < ref).float()

        corr_drift = np.corrcoef(-moving_average.numpy(), drift_labels.numpy())[0, 1]
        corr_acc = np.corrcoef(moving_average.numpy(), moving_accuracy.numpy())[0, 1]

        # Drift labels are defined per window: pair each of them with the score of the
        # sample that closes the window so both arrays have the same length.
        window_scores = test_scores[::stride]

        # check error if theres is only one label on drift_labels
        if len(np.unique(drift_labels.numpy())) == 1:
            auroc_drift = 1.0
        else:
            auroc_drift = float(sklearn.metrics.roc_auc_score(drift_labels, -window_scores))
        if len(np.unique(mistakes.numpy())) == 1:
            auroc_mistakes = 1.0
        else:
            auroc_mistakes = float(sklearn.metrics.roc_auc_score(mistakes, -test_scores))

        # NOTE(review): the ROC curves below are still evaluated when only one class
        # is present - confirm that ``fpr_at_fixed_tpr`` handles that case.
        fprs, tprs, thresholds = sklearn.metrics.roc_curve(drift_labels, -window_scores)
        fpr_drift, _, _ = detectors.eval.fpr_at_fixed_tpr(fprs, tprs, thresholds, 0.95)
        fprs, tprs, thresholds = sklearn.metrics.roc_curve(mistakes, -test_scores)
        fpr_mistakes, _, _ = detectors.eval.fpr_at_fixed_tpr(fprs, tprs, thresholds, 0.95)

        return dict(
            drift_labels=drift_labels,
            # argmax returns 0 both when the first window drifts and when none does.
            first_drift=torch.argmax(drift_labels).item(),
            ref_accuracy=ref,
            splits=self.splits,
            corr_acc=corr_acc,
            corr_drift=corr_drift,
            auroc_drift=auroc_drift,
            auroc_mistakes=auroc_mistakes,
            fpr_drift=fpr_drift,
            fpr_mistakes=fpr_mistakes,
            moving_average=moving_average,
            ema=ema,
            moving_accuracy=moving_accuracy,
            mistakes=mistakes,
        )

    def report(self, results: Dict[str, Any], subsample=1):
        """Print the summary metrics and plot scores, accuracy and drift over the stream.

        Args:
            results (Dict[str, Any]): Dictionary returned by :meth:`run`.
            subsample (int, optional): Keep one point out of ``subsample`` in the plot.
                Defaults to 1.
        """
        print("Results:")
        print("\tCorr. Acc:", results["corr_acc"])
        print("\tCorr. Drift:", results["corr_drift"])
        print("\tAUC Drift:", results["auroc_drift"])
        print("\tAUC Mistakes:", results["auroc_mistakes"])
        print("\tFPR Drift:", results["fpr_drift"])
        print("\tFPR Mistakes:", results["fpr_mistakes"])
        print("\tFirst Drift:", results["first_drift"])
        print("\tSplits", results["splits"])

        # plot results
        mpl_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        fig, ax1 = plt.subplots(1, 1, figsize=(9, 4))
        ax2 = ax1.twinx()

        # subsample values
        test_scores = results["scores"][::subsample]
        moving_average = results["moving_average"][::subsample]
        ema = results["ema"][::subsample]
        drift_labels = results["drift_labels"][::subsample]
        mistakes = results["mistakes"][::subsample]
        moving_accuracy = results["moving_accuracy"][::subsample]

        # Left axis: detector scores and their smoothed versions.
        ax1.plot(test_scores.numpy(), alpha=0.4, c=mpl_colors[1], label="score")
        ax1.plot(moving_average.numpy(), alpha=0.8, c=mpl_colors[0], label="moving avg")
        ax1.plot(ema.numpy(), alpha=0.8, c=mpl_colors[4], linewidth=2, label="ema")
        ax1.grid()

        # Right axis: drift flag, mistakes and moving accuracy (all in [0, 1]).
        ax2.vlines(
            self.warmup_size // subsample,
            0,
            1,
            linestyle="--",
            color="gray",
            alpha=0.5,
            linewidth=3,
            label="begin test set",
        )
        ax2.plot(drift_labels.numpy(), alpha=0.5, c=mpl_colors[2], linestyle="--", label="drift", linewidth=3)
        ax2.scatter(range(len(mistakes)), mistakes.numpy(), alpha=0.5, marker="*", c=mpl_colors[3], label="mistakes")
        ax2.plot(moving_accuracy.numpy(), alpha=0.5, c=mpl_colors[3], label="moving accuracy", linewidth=2)
        # plot reference accuracy
        ax2.axhline(
            results["ref_accuracy"], linestyle=":", color="black", alpha=0.5, linewidth=3, label="drift accuracy ref"
        )

        ax1.set_xlabel("Sample index")
        ax1.set_ylabel("Scores")
        ax2.set_ylabel("Drift")
        ax2.legend(loc="upper right")
        ax1.legend(loc="lower left")
        plt.suptitle(
            f"Corr. Acc. {results['corr_acc']:.2f}\n"
            f"FPR Mistakes {results['fpr_mistakes']:.2f}, "
            f"AUC Mistakes {results['auroc_mistakes']:.2f}, "
            f"AUC Drift {results['auroc_drift']:.2f}"
        )
@register_pipeline("covariate_drift_cifar10")
class OneCorruptionCovariateDriftCifar10Pipeline(CovariateDriftPipeline):
    """Covariate drift pipeline on CIFAR-10 restricted to a single corruption type.

    Uses the ``train`` split for fitting and the ``test`` split for the stream.
    """

    def __init__(self, transform, corruption: str, intensities: List[int], batch_size: int = 128, **kwargs) -> None:
        # The parent expects a list of corruptions; wrap the single one.
        super().__init__(
            dataset_name="cifar10",
            dataset_splits=["train", "test"],
            transform=transform,
            corruptions=[corruption],
            intensities=intensities,
            batch_size=batch_size,
            **kwargs,
        )
@register_pipeline("covariate_drift_cifar100")
class OneCorruptionCovariateDriftCifar100Pipeline(CovariateDriftPipeline):
    """Covariate drift pipeline on CIFAR-100 restricted to a single corruption type.

    Uses the ``train`` split for fitting and the ``test`` split for the stream.
    """

    def __init__(self, transform, corruption: str, intensities: List[int], batch_size: int = 128, **kwargs) -> None:
        # The parent expects a list of corruptions; wrap the single one.
        super().__init__(
            dataset_name="cifar100",
            dataset_splits=["train", "test"],
            transform=transform,
            corruptions=[corruption],
            intensities=intensities,
            batch_size=batch_size,
            **kwargs,
        )
@register_pipeline("covariate_drift_imagenet")
class OneCorruptionCovariateDriftImagenetPipeline(CovariateDriftPipeline):
    """Covariate drift pipeline on ImageNet restricted to a single corruption type.

    Uses the ``train`` split for fitting and the ``val`` split for the stream.
    """

    def __init__(self, transform, corruption: str, intensities: List[int], batch_size: int = 128, **kwargs) -> None:
        # The parent expects a list of corruptions; wrap the single one.
        super().__init__(
            dataset_name="imagenet",
            dataset_splits=["train", "val"],
            transform=transform,
            corruptions=[corruption],
            intensities=intensities,
            batch_size=batch_size,
            **kwargs,
        )