Source code for detectors.data.cifarc

import os
from typing import Callable, Optional

import numpy as np
import torch.utils.data.dataset as dataset
from PIL import Image
from torchvision.datasets.utils import check_integrity, download_and_extract_archive, verify_str_arg

# Corruption names accepted as the ``split`` argument of the dataset classes
# below, in alphabetical order.  Each name is also the stem of the
# ``<name>.npy`` image array looked up inside the extracted archive folder.
CORRUPTIONS = [
    "brightness", "contrast", "defocus_blur", "elastic_transform",
    "fog", "frost", "gaussian_blur", "gaussian_noise",
    "glass_blur", "impulse_noise", "jpeg_compression", "motion_blur",
    "pixelate", "saturate", "shot_noise", "snow",
    "spatter", "speckle_noise", "zoom_blur",
]


class CIFAR10_C(dataset.Dataset):
    """Corrupted CIFAR-10 images for one corruption type and intensity level.

    The archive is expected at ``<root>/CIFAR-10-C.tar`` and its extracted
    content under ``<root>/CIFAR-10-C/``: one ``<corruption>.npy`` image array
    per corruption plus a shared ``labels.npy``.  Every array stacks the
    intensity levels back to back in chunks of ``samples_per_intensity``
    samples; an instance exposes the chunk of a single
    ``(corruption, intensity)`` pair.

    Args:
        root: Directory holding (or receiving) the archive and its extracted
            folder.  ``~`` is expanded.
        split: Name of the corruption to load; must be one of ``corruptions``.
        intensity: Corruption level, an integer between 1 and
            ``max_intensity`` (inclusive).
        transform: Optional callable applied to each PIL image before it is
            returned.
        download: If true, download and extract the archive unless it is
            already present and valid.

    Raises:
        ValueError: If ``split`` is not a known corruption (raised by
            ``verify_str_arg``) or ``intensity`` is out of range.
        RuntimeError: If the extracted array is missing or the archive fails
            its MD5 check after the optional download step.
    """

    base_folder = "CIFAR-10-C"
    filename = "CIFAR-10-C.tar"
    file_md5 = "56bf5dcef84df0e2308c6dcbcbbd8499"
    url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar"
    corruptions = CORRUPTIONS
    # Layout of each ``<corruption>.npy`` array as published for CIFAR-C:
    # ``max_intensity`` consecutive chunks of ``samples_per_intensity`` images,
    # ordered by increasing corruption level.
    samples_per_intensity = 10000
    max_intensity = 5

    def __init__(
        self,
        root: str,
        split: str,
        intensity: int,
        transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.root = os.path.expanduser(root)
        self.corruption = verify_str_arg(split, "split", self.corruptions)
        self.intensity = int(intensity)
        # An out-of-range level would otherwise slice past the end of the
        # array (silently empty dataset) or, for negative values, silently
        # pick the wrong chunk -- fail loudly before any download or disk I/O.
        if not 1 <= self.intensity <= self.max_intensity:
            raise ValueError(
                f"Unknown value '{intensity}' for argument intensity. "
                f"Valid values are integers from 1 to {self.max_intensity}."
            )
        self.transform = transform

        if download:
            self.download()

        # Same condition ``download`` uses to decide the data is usable; the
        # cheap existence test runs first so the archive is only hashed when
        # the extracted array is actually there.
        if not (self._check_exists() and self._check_integrity()):
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        indices = self._indices
        images = np.load(self._perturbation_array)[indices]
        labels = np.load(self._labels_array)[indices]
        self.arrays = [images, labels]

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The image is converted to a PIL image and then passed through
        ``transform`` when one was given; the label is returned exactly as
        stored in ``labels.npy``.
        """
        images, labels = self.arrays
        x = Image.fromarray(images[index])
        # Compare against None explicitly: a callable may define ``__len__``
        # or ``__bool__`` and evaluate as falsy.
        if self.transform is not None:
            x = self.transform(x)
        return x, labels[index]

    def __len__(self):
        """Return the number of samples in the selected intensity chunk."""
        return len(self.arrays[0])

    @property
    def _indices(self):
        """Slice selecting this intensity's chunk inside the stacked arrays."""
        chunk = self.samples_per_intensity
        return slice((self.intensity - 1) * chunk, self.intensity * chunk)

    @property
    def _dataset_folder(self):
        """Path of the extracted archive folder under ``root``."""
        return os.path.join(self.root, self.base_folder)

    @property
    def _perturbation_array(self):
        """Path of the ``.npy`` image array for the selected corruption."""
        return os.path.join(self._dataset_folder, self.corruption + ".npy")

    @property
    def _labels_array(self):
        """Path of the ``labels.npy`` file in the extracted folder."""
        return os.path.join(self._dataset_folder, "labels.npy")

    def _check_integrity(self) -> bool:
        """Return whether the archive under ``root`` matches ``file_md5``."""
        fpath = os.path.join(self.root, self.filename)
        return check_integrity(fpath, self.file_md5)

    def _check_exists(self) -> bool:
        """Return whether the extracted array for the corruption exists."""
        return os.path.exists(self._perturbation_array)

    def download(self) -> None:
        """Download and extract the archive into ``root`` unless already done.

        Nothing happens when the extracted array exists and the archive
        passes its MD5 check.
        """
        # Existence is tested first so the archive is not hashed needlessly.
        if self._check_exists() and self._check_integrity():
            return
        download_and_extract_archive(self.url, download_root=self.root, md5=self.file_md5)
class CIFAR100_C(CIFAR10_C):
    """CIFAR-100-C counterpart of :class:`CIFAR10_C`.

    All behavior is inherited; only the archive location, its checksum and
    the name of the extracted folder are overridden.
    """

    # Remote archive and the MD5 digest it is verified against.
    url = "https://zenodo.org/record/3555552/files/CIFAR-100-C.tar"
    file_md5 = "11f0ed0f1191edbf9fa23466ae6021d3"

    # Local archive file name and the folder it extracts to.
    filename = "CIFAR-100-C.tar"
    base_folder = "CIFAR-100-C"

    # Accepted corruption names.
    corruptions = CORRUPTIONS