Source code for detectors.data.imagenet

import logging
import os
from typing import Callable, Optional

import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import check_integrity, download_and_extract_archive, verify_str_arg
from tqdm import tqdm

# Module-level logger named after this module (``logging.getLogger`` returns the same
# object for the same name, so every call site in this file shares it).
_logger = logging.getLogger(name=__name__)


class ImageNetA(ImageFolder):
    """ImageNetA dataset.

    The tar archive is downloaded to ``root`` and extracted there; images are then read
    from ``<root>/<base_folder>`` with :class:`ImageFolder`. The integrity check runs on the
    tar archive itself, so the archive must stay on disk next to the extracted folder.

    Args:
        root: Directory that holds the archive and the extracted folder (``~`` is expanded).
        split: Accepted for signature compatibility; ignored.
        transform: Forwarded to :class:`ImageFolder`.
        download: If True, download and extract the archive unless it is already present
            and verified.
        **kwargs: Forwarded to :class:`ImageFolder`.

    Raises:
        RuntimeError: If the archive is missing or its MD5 checksum does not match.

    - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174).
    """

    base_folder = "imagenet-a"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar"
    filename = "imagenet-a.tar"
    tgz_md5 = "c3e55429088dc681f30d81f4726b6595"

    def __init__(
        self,
        root: str,
        split=None,
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ):
        # Expand "~" once so every path derived from ``self.root`` agrees. ``check_integrity``
        # does not expand it, while the download helper does; without this, a download to
        # "~/..." succeeds and the integrity check below still fails.
        self.root = os.path.expanduser(root)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")

        super().__init__(root=os.path.join(self.root, self.base_folder), transform=transform, **kwargs)

    def _check_exists(self) -> bool:
        """Return True if the extracted dataset folder exists under ``self.root``."""
        return os.path.exists(os.path.join(self.root, self.base_folder))

    def _check_integrity(self) -> bool:
        """Return True if the tar archive exists under ``self.root`` and matches ``tgz_md5``."""
        return check_integrity(os.path.join(self.root, self.filename), self.tgz_md5)

    def download(self) -> None:
        """Download and extract the archive unless it is already present and verified."""
        # Test the cheap directory existence first: hashing a multi-GB archive is skipped
        # whenever the dataset has not been extracted yet.
        if self._check_exists() and self._check_integrity():
            _logger.debug("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
class ImageNetO(ImageNetA):
    """ImageNet-O dataset.

    Contains images of classes that are unknown to ImageNet-1k. Only the download
    location, archive name, checksum and folder name differ from :class:`ImageNetA`;
    all loading logic is inherited.

    - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174)
    """

    # Archive metadata.
    filename = "imagenet-o.tar"
    tgz_md5 = "86bd7a50c1c4074fb18fc5f219d6d50b"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-o.tar"
    # Folder (relative to ``root``) produced by extracting the archive.
    base_folder = "imagenet-o"
class ImageNetR(ImageNetA):
    """ImageNet-R(endition) dataset.

    Contains art, cartoons, deviantart, graffiti, embroidery, graphics, origami, paintings,
    patterns, plastic objects, plush objects, sculptures, sketches, tattoos, toys, and video
    game renditions of ImageNet-1k classes. Only the download location, archive name,
    checksum and folder name differ from :class:`ImageNetA`; all loading logic is inherited.

    - Paper: [https://arxiv.org/abs/2006.16241](https://arxiv.org/abs/2006.16241)
    """

    # Archive metadata.
    filename = "imagenet-r.tar"
    tgz_md5 = "a61312130a589d0ca1a8fca1f2bd3337"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar"
    # Folder (relative to ``root``) produced by extracting the archive.
    base_folder = "imagenet-r"
# All 19 ImageNet-C corruption names in alphabetical order. They are listed below by the
# archive family that ships them; ``sorted`` restores the alphabetical order.
CORRUPTIONS = sorted(
    [
        # noise
        "gaussian_noise",
        "shot_noise",
        "impulse_noise",
        # blur
        "defocus_blur",
        "glass_blur",
        "motion_blur",
        "zoom_blur",
        # weather
        "frost",
        "snow",
        "fog",
        "brightness",
        # digital
        "contrast",
        "elastic_transform",
        "pixelate",
        "jpeg_compression",
        # extra
        "speckle_noise",
        "spatter",
        "gaussian_blur",
        "saturate",
    ]
)
class ImageNetC(ImageNetA):
    """Corrupted version of the ImageNet-1k dataset.

    It contains the following subsets:

    - `noise` (21GB): gaussian_noise, shot_noise, and impulse_noise.
    - `blur` (7GB): defocus_blur, glass_blur, motion_blur, and zoom_blur.
    - `weather` (12GB): frost, snow, fog, and brightness.
    - `digital` (7GB): contrast, elastic_transform, pixelate, and jpeg_compression.
    - `extra` (15GB): speckle_noise, spatter, gaussian_blur, and saturate.

    Each subset is a separate ``<subset>.tar`` archive stored in ``root``. It is extracted to
    ``<root>/ImageNetC/<subset>``, and images are read from
    ``<root>/ImageNetC/<subset>/<corruption>/<intensity>``.

    Args:
        root: Directory holding the downloaded archives (``~`` is expanded).
        split: Corruption name; must be one of ``corruptions``.
        intensity: Corruption intensity; its string form names the image sub-directory.
        transform: Forwarded to :class:`ImageNetA`.
        download: If True, download and extract the subset archive unless already verified.
        **kwargs: Forwarded to :class:`ImageNetA`.

    - Paper: [https://arxiv.org/abs/1903.12261v1](https://arxiv.org/abs/1903.12261v1)
    """

    split_list = ["blur", "digital", "extra", "noise", "weather"]
    base_folder_name = "ImageNetC"
    url_base = "https://zenodo.org/record/2235448/files/"
    # MD5 checksums of the subset archives, in the same order as ``split_list``.
    tgz_md5_list = [
        "2d8e81fdd8e07fef67b9334fa635e45c",
        "89157860d7b10d5797849337ca2e5c03",
        "d492dfba5fc162d8ec2c3cd8ee672984",
        "e80562d7f6c3f8834afb1ecf27252745",
        "33ffea4db4d93fe4a428c40a6ce0c25d",
    ]
    corruptions = CORRUPTIONS
    # Subset (archive) name -> corruptions shipped in that archive.
    corruption_groups = {
        "blur": ("defocus_blur", "glass_blur", "motion_blur", "zoom_blur"),
        "digital": ("contrast", "elastic_transform", "pixelate", "jpeg_compression"),
        "extra": ("speckle_noise", "spatter", "gaussian_blur", "saturate"),
        "noise": ("gaussian_noise", "shot_noise", "impulse_noise"),
        "weather": ("frost", "snow", "fog", "brightness"),
    }

    def __init__(
        self,
        root: str,
        split: str,
        intensity: int,
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ) -> None:
        self.root = os.path.expanduser(root)
        self.corruption = verify_str_arg(split, "split", self.corruptions)
        split_group = self._get_corruption_group(self.corruption)

        # Absolute directory into which this subset's archive is extracted.
        self._base_folder = os.path.join(self.root, self.base_folder_name, split_group)
        self.filename = split_group + ".tar"
        self.url = self.url_base + self.filename
        self.tgz_md5 = self.tgz_md5_list[self.split_list.index(split_group)]
        # Must be RELATIVE to root: ImageNetA joins it with root for both ``_check_exists``
        # and the ImageFolder path. Including root here doubled the prefix for relative roots
        # (e.g. "data/data/ImageNetC/...").
        self.base_folder = os.path.join(self.base_folder_name, split_group, self.corruption, str(intensity))

        # Pass the expanded root so the parent does not overwrite ``self.root`` with "~".
        super().__init__(self.root, transform=transform, download=download, **kwargs)

    def download(self) -> None:
        """Download this subset's archive and extract it into ``<root>/ImageNetC/<subset>``."""
        # Cheap existence test first, so the MD5 of a multi-GB archive is only computed when
        # the dataset already looks extracted.
        if self._check_exists() and self._check_integrity():
            _logger.debug("Files already downloaded and verified")
            return
        download_and_extract_archive(
            self.url, self.root, extract_root=self._base_folder, filename=self.filename, md5=self.tgz_md5
        )

    @staticmethod
    def _get_corruption_group(corruption: str) -> str:
        """Return the subset (archive) name containing ``corruption``, or "" if none does."""
        for group, members in ImageNetC.corruption_groups.items():
            if corruption in members:
                return group
        return ""
def _imagenet_c_to_npz(root: str, split: str, intensity: int, dest_folder: str = "ImageNetCnpz") -> None:
    """Convert one ImageNet-C (corruption, intensity) subset into a single ``.npz`` archive.

    The raw subset is obtained through :class:`ImageNetC` (downloading it if needed). Every
    image is stacked into a ``(N, H, W, 3)`` uint8 array ``x`` and the labels into an int32
    array ``y``. Both arrays are written to ``<root>/<dest_folder>/<split>-<intensity>.npz``.

    Args:
        root: Root directory for both the raw ImageNet-C data and the output folder.
        split: Corruption name.
        intensity: Corruption intensity.
        dest_folder: Output sub-directory of ``root``.

    Raises:
        RuntimeError: If the subset does not contain exactly 50,000 images.
    """
    dataset = ImageNetC(root, split, intensity, download=True)
    num_images = len(dataset)
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if num_images != 50_000:
        raise RuntimeError("ImageNetC should have 50,000 images. Please check the dataset.")

    # All images are assumed to share the first image's size; a mismatch makes the
    # assignment into ``x`` below fail with a shape error.
    width, height = dataset[0][0].size
    _logger.info("Image size: %d x %d", width, height)

    x = np.empty((num_images, height, width, 3), dtype=np.uint8)
    y = np.empty(num_images, dtype=np.int32)
    for i in tqdm(range(num_images)):
        image, label = dataset[i]
        x[i] = np.asarray(image)
        y[i] = label

    out_dir = os.path.join(root, dest_folder)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{split}-{intensity}.npz")
    # Write to a temporary file, then rename atomically. An interrupted run then never
    # leaves a truncated archive that a later existence check would accept as a valid
    # cache. A file object is passed so ``np.savez`` does not append another ".npz".
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "wb") as fh:
        np.savez(fh, x=x, y=y)
    os.replace(tmp_path, out_path)
class ImageNetCnpz(Dataset):
    """Corrupted version of the ImageNet-1k dataset saved in npz format.

    Each (corruption, intensity) pair is cached as
    ``<root>/ImageNetCnpz/<corruption>-<intensity>.npz`` holding an image array ``x`` and a
    label array ``y``. With ``download=True`` a missing cache is built from the raw
    ImageNet-C data.

    Args:
        root: Root directory (``~`` is expanded).
        split: Corruption name; must be one of ``corruptions``.
        intensity: Corruption intensity (converted with ``int``).
        transform: Optional callable applied to each PIL image.
        download: If True, build the cache file when it does not exist.
        **kwargs: Accepted for signature compatibility; ignored.

    Raises:
        FileNotFoundError: If the cache file is missing and ``download`` is False.
    """

    corruptions = CORRUPTIONS
    base_folder_name = "ImageNetCnpz"

    def __init__(
        self,
        root: str,
        split: str,
        intensity: int,
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.root = os.path.expanduser(root)
        self.corruption = verify_str_arg(split, "split", self.corruptions)
        self.intensity = int(intensity)
        # Use the normalised int so the path matches what ``download`` writes (e.g. an
        # intensity of 3.0 previously looked for "fog-3.0.npz" while "fog-3.npz" was built).
        self.path = os.path.join(self.root, self.base_folder_name, f"{self.corruption}-{self.intensity}.npz")
        self.transform = transform

        if download:
            self.download()
        if not self._check_exists():
            raise FileNotFoundError(f"{self.path} not found. You can use download=True to create it")

        # ``np.load`` ignores ``mmap_mode`` for .npz archives, so both arrays are read fully
        # into memory here. The context manager closes the underlying zip file handle,
        # which otherwise stayed open for the lifetime of the dataset.
        with np.load(self.path) as data:
            self.images = data["x"]
            self.labels = data["y"]

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is a PIL image, transformed if a transform is set."""
        image = Image.fromarray(self.images[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]

    def __len__(self):
        return len(self.images)

    def _check_exists(self) -> bool:
        """Return True if the cached ``.npz`` file is present."""
        return os.path.exists(self.path)

    def download(self) -> None:
        """Build the ``.npz`` cache from raw ImageNet-C data unless it already exists."""
        if self._check_exists():
            return
        _imagenet_c_to_npz(self.root, self.corruption, self.intensity, self.base_folder_name)