from typing import Callable, Optional
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.datasets.utils import verify_str_arg
class CIFAR10Wrapped(CIFAR10):
    """CIFAR-10 dataset selected by a ``split`` name instead of a ``train`` flag.

    Args:
        root: Root directory, passed through to the ``CIFAR10`` base class.
        split: One of ``splits`` (``"train"`` or ``"test"``); defaults to
            ``"test"``. Checked with ``verify_str_arg``, which rejects any
            other value. The validated value is stored as ``self.split``.
        transform: Optional callable forwarded to the base class.
        download: Forwarded to the base class.
        **kwargs: Any further keyword arguments accepted by the base class.
    """

    # Allowed values for ``split``.
    splits = ("train", "test")

    def __init__(
        self,
        root: str,
        split: str = "test",
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ) -> None:
        # Validate and store the split before delegating to the base constructor.
        self.split = verify_str_arg(split, "split", self.splits)
        super().__init__(
            root,
            # Translate the split name into the base class's boolean flag.
            train=self.split == "train",
            transform=transform,
            download=download,
            **kwargs,
        )
class CIFAR100Wrapped(CIFAR100):
    """CIFAR-100 dataset selected by a ``split`` name instead of a ``train`` flag.

    Args:
        root: Root directory, passed through to the ``CIFAR100`` base class.
        split: One of ``splits`` (``"train"`` or ``"test"``); defaults to
            ``"test"``. Checked with ``verify_str_arg``, which rejects any
            other value. The validated value is stored as ``self.split``.
        transform: Optional callable forwarded to the base class.
        download: Forwarded to the base class.
        **kwargs: Any further keyword arguments accepted by the base class.
    """

    # Allowed values for ``split``.
    splits = ("train", "test")

    def __init__(
        self,
        root: str,
        split: str = "test",
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ) -> None:
        # Validate and store the split before delegating to the base constructor.
        self.split = verify_str_arg(split, "split", self.splits)
        super().__init__(
            root,
            # Translate the split name into the base class's boolean flag.
            train=self.split == "train",
            transform=transform,
            download=download,
            **kwargs,
        )
# The ten CIFAR-10 class names as a tuple of strings. Tuple position is
# presumably the integer target index — confirm against the dataset's classes.
CIFAR10_LABELS = tuple(
    "airplane automobile bird cat deer dog frog horse ship truck".split()
)
# CIFAR-100 superclass name -> list of its five fine-class names, written as
# human-readable (mostly plural) phrases. These spellings differ from the
# identifiers in CIFAR100_LABELS (e.g. "aquarium fish" vs "aquarium_fish").
# Multi-word class names are single entries; beware adjacent string literals,
# which Python silently concatenates ("pickup" "truck" == "pickuptruck").
CIFAR100_CATS = {
    "aquatic mammals": ["beaver", "dolphin", "otter", "seal", "whale"],
    "fish": ["aquarium fish", "flatfish", "ray", "shark", "trout"],
    "flowers": ["orchids", "poppies", "roses", "sunflowers", "tulips"],
    "food containers": ["bottles", "bowls", "cans", "cups", "plates"],
    "fruit and vegetables": [
        "apples",
        "mushrooms",
        "oranges",
        "pears",
        "sweet peppers",
    ],
    "household electrical devices": [
        "clock",
        "computer keyboard",
        "lamp",
        "telephone",
        "television",
    ],
    "household furniture": ["bed", "chair", "couch", "table", "wardrobe"],
    "insects": ["bee", "beetle", "butterfly", "caterpillar", "cockroach"],
    "large carnivores": ["bear", "leopard", "lion", "tiger", "wolf"],
    "large man-made outdoor things": ["bridge", "castle", "house", "road", "skyscraper"],
    "large natural outdoor scenes": ["cloud", "forest", "mountain", "plain", "sea"],
    "large omnivores and herbivores": ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"],
    "medium-sized mammals": ["fox", "porcupine", "possum", "raccoon", "skunk"],
    "non-insect invertebrates": ["crab", "lobster", "snail", "spider", "worm"],
    "people": ["baby", "boy", "girl", "man", "woman"],
    "reptiles": ["crocodile", "dinosaur", "lizard", "snake", "turtle"],
    "small mammals": ["hamster", "mouse", "rabbit", "shrew", "squirrel"],
    "trees": ["maple", "oak", "palm", "pine", "willow"],
    "vehicles 1": ["bicycle", "bus", "motorcycle", "pickup truck", "train"],
    "vehicles 2": ["lawn-mower", "rocket", "streetcar", "tank", "tractor"],
}
# The 100 CIFAR-100 fine-class identifiers, in alphabetical order.
# Tuple position is presumably the integer target index — confirm against the
# dataset's classes before relying on it.
CIFAR100_LABELS = (
    "apple",
    "aquarium_fish",
    "baby",
    "bear",
    "beaver",
    "bed",
    "bee",
    "beetle",
    "bicycle",
    "bottle",
    "bowl",
    "boy",
    "bridge",
    "bus",
    "butterfly",
    "camel",
    "can",
    "castle",
    "caterpillar",
    "cattle",
    "chair",
    "chimpanzee",
    "clock",
    "cloud",
    "cockroach",
    "couch",
    "crab",
    "crocodile",
    "cup",
    "dinosaur",
    "dolphin",
    "elephant",
    "flatfish",
    "forest",
    "fox",
    "girl",
    "hamster",
    "house",
    "kangaroo",
    "keyboard",
    "lamp",
    "lawn_mower",
    "leopard",
    "lion",
    "lizard",
    "lobster",
    "man",
    "maple_tree",
    "motorcycle",
    "mountain",
    "mouse",
    "mushroom",
    "oak_tree",
    "orange",
    "orchid",
    "otter",
    "palm_tree",
    "pear",
    "pickup_truck",
    "pine_tree",
    "plain",
    "plate",
    "poppy",
    "porcupine",
    "possum",
    "rabbit",
    "raccoon",
    "ray",
    "road",
    "rocket",
    "rose",
    "sea",
    "seal",
    "shark",
    "shrew",
    "skunk",
    "skyscraper",
    "snail",
    "snake",
    "spider",
    "squirrel",
    "streetcar",
    "sunflower",
    "sweet_pepper",
    "table",
    "tank",
    "telephone",
    "television",
    "tiger",
    "tractor",
    "train",
    "trout",
    "tulip",
    "turtle",
    "wardrobe",
    "whale",
    "willow_tree",
    "wolf",
    "woman",
    "worm",
)