# Source code for detectors.data.mnist_wrapped

from typing import Callable, Optional

from torchvision.datasets import MNIST, FashionMNIST
from torchvision.datasets.utils import verify_str_arg


class MNISTWrapped(MNIST):
    """MNIST dataset selected by a string ``split`` instead of a boolean ``train`` flag.

    Args:
        root: Dataset root directory, passed positionally to ``MNIST``.
        split: Partition to load; must be one of :attr:`splits`.
            ``"train"`` is forwarded to the parent as ``train=True``; any other
            accepted value (i.e. ``"test"``) as ``train=False``.
        transform: Optional callable forwarded to the parent as ``transform``.
        download: Forwarded to the parent as ``download``.
        **kwargs: Any further keyword arguments, forwarded unchanged to ``MNIST``.

    Raises:
        ValueError: If ``split`` is not in :attr:`splits` (raised by ``verify_str_arg``).
    """

    # Accepted values for the ``split`` argument.
    splits = ("train", "test")

    def __init__(
        self,
        root: str,
        split: str = "test",
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ) -> None:
        # Validate (and store) the split first, so an invalid value is rejected
        # before the parent constructor does any work such as downloading.
        self.split = verify_str_arg(split, "split", self.splits)
        use_train_partition = split == "train"
        super().__init__(
            root,
            train=use_train_partition,
            transform=transform,
            download=download,
            **kwargs,
        )
class FashionMNISTWrapped(FashionMNIST):
    """Fashion-MNIST dataset selected by a string ``split`` instead of a boolean ``train`` flag.

    Args:
        root: Dataset root directory, passed positionally to ``FashionMNIST``.
        split: Partition to load; must be one of :attr:`splits`.
            ``"train"`` maps to the parent's ``train=True``, ``"test"`` to ``train=False``.
        transform: Optional callable forwarded to the parent as ``transform``.
        download: Forwarded to the parent as ``download``.
        **kwargs: Extra keyword arguments, forwarded unchanged to ``FashionMNIST``.

    Raises:
        ValueError: If ``split`` is not in :attr:`splits` (raised by ``verify_str_arg``).
    """

    # Accepted values for the ``split`` argument.
    splits = ("train", "test")

    def __init__(
        self,
        root: str,
        split: str = "test",
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ) -> None:
        # Reject unknown splits up front, before the parent constructor runs.
        self.split = verify_str_arg(split, "split", self.splits)

        if split == "train":
            train_flag = True
        else:
            train_flag = False

        forwarded = dict(transform=transform, download=download, **kwargs)
        super().__init__(root, train=train_flag, **forwarded)