import logging
from typing import Callable, Optional
import accelerate
import numpy as np
import torch
import torch.utils.data
import detectors
from detectors.pipelines import register_pipeline
from detectors.pipelines.base import Pipeline
from detectors.pipelines.ood import OODBenchmarkPipeline
from detectors.utils import ConcatDatasetsDim1
_logger = logging.getLogger(__name__)
[docs]@register_pipeline("osr_cifar10")
class OSRCifar10(OODBenchmarkPipeline):
def __init__(
self, transform: Callable, batch_size: int, limit_fit: Optional[int] = None, seed: int = 42, **kwargs
) -> None:
super().__init__(
"cifar10", {"cifar100": "test"}, transform, batch_size, limit_fit=limit_fit, seed=seed, **kwargs
)
[docs]@register_pipeline("osr_cifar100")
class OSRCifar100(OODBenchmarkPipeline):
def __init__(
self, transform: Callable, batch_size: int, limit_fit: Optional[int] = None, seed: int = 42, **kwargs
) -> None:
super().__init__(
"cifar100", {"cifar10": "test"}, transform, batch_size, limit_fit=limit_fit, seed=seed, **kwargs
)
[docs]@register_pipeline("osr_imagenet")
class OSRImagenet(OODBenchmarkPipeline):
def __init__(
self, transform: Callable, batch_size: int, limit_fit: Optional[int] = None, seed: int = 42, **kwargs
) -> None:
super().__init__(
"imagenet",
{
"imagenet_o": None,
},
transform,
batch_size,
limit_fit=limit_fit,
seed=seed,
**kwargs,
)
[docs]@register_pipeline("one_class_versus_others_cifar10")
class SingleClassCifar10(Pipeline):
# TODO
def __init__(
self,
in_dataset_name: str,
in_dataset_split: str,
transform: Callable,
batch_size: int,
num_workers: int = 4,
pin_memory: bool = True,
prefetch_factor: int = 2,
limit_fit: float = 1.0,
limit_run: float = 1.0,
seed: int = 42,
**kwargs,
) -> None:
self.transform = transform
self.batch_size = batch_size
self.limit_fit = limit_fit
self.limit_run = limit_run
self.seed = seed
self.kwargs = kwargs
self.num_workers = num_workers
self.pin_memory = pin_memory
self.prefetch_factor = prefetch_factor
self.in_dataset_name = in_dataset_name
self.in_dataset_split = in_dataset_split
self.in_dataset_name = "cifar10"
accelerate.utils.set_seed(seed)
self.setup()
def _setup_datasets(self):
"""Setup `in_dataset`."""
self.in_dataset = detectors.create_dataset(
self.in_dataset_name, transform=self.transform, split=self.in_dataset_split
)
def _setup_dataloaders(self):
if self.limit_fit is None:
self.limit_fit = 1.0
self.limit_fit = min(int(self.limit_fit * len(self.fit_dataset)), len(self.fit_dataset))
# random indices
subset = np.random.choice(np.arange(len(self.fit_dataset)), self.limit_fit, replace=False).tolist()
self.fit_dataset = torch.utils.data.Subset(self.fit_dataset, subset)
self.fit_dataloader = torch.utils.data.DataLoader(
self.fit_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
prefetch_factor=self.prefetch_factor,
)
self.test_dataset = torch.utils.data.ConcatDataset([self.in_dataset])
test_labels = torch.utils.data.TensorDataset(
torch.cat(
[torch.zeros(len(self.in_dataset))] # type: ignore
+ [torch.ones(len(d)) * (i + 1) for i, d in enumerate(self.out_datasets.values())] # type: ignore
).long()
)
self.test_dataset = ConcatDatasetsDim1([self.test_dataset, test_labels])
# shuffle and subsample test_dataset
subset = np.random.choice(
np.arange(len(self.test_dataset)), int(self.limit_run * len(self.test_dataset)), replace=False
).tolist()
self.test_dataset = torch.utils.data.Subset(self.test_dataset, subset)
self.test_dataloader = torch.utils.data.DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
prefetch_factor=self.prefetch_factor,
)
_logger.info(f"Using {len(self.fit_dataset)} samples for fitting.")
_logger.info(f"Using {len(self.test_dataset)} samples for testing.")
[docs] def setup(self):
self._setup_datasets()
self._setup_dataloaders()