Source code for detectors.data.tiny_imagenet

"""Simple Tiny ImageNet dataset utility class for pytorch."""

import os
import shutil
from typing import Callable, Optional

from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg


def normalize_tin_val_folder_structure(path, images_folder="images", annotations_file="val_annotations.txt"):
    """Reorganize a Tiny ImageNet validation folder into one folder per label.

    Every image listed in the annotations file is moved from
    ``<path>/<images_folder>/<image>`` to ``<path>/<label>/<image>``. Once all
    images have been moved, the flat images folder and the annotations file
    are deleted. The function is safe to call again: if neither the images
    folder nor the annotations file is left, it returns without touching
    anything, and images that were already moved by an interrupted run are
    skipped.

    Args:
        path: Directory of the validation split.
        images_folder: Name of the sub-folder of ``path`` holding the flat
            list of images.
        annotations_file: Name of the annotations file inside ``path``. Each
            line is whitespace-separated; the first field is the image file
            name and the second field is its label. Further fields are
            ignored, as are lines with fewer than two fields.

    Raises:
        RuntimeError: If ``path`` is empty, or if images are still present in
            the images folder after all annotations have been processed.
    """
    images_folder = os.path.join(path, images_folder)
    annotations_file = os.path.join(path, annotations_file)

    # Neither marker is left: a previous run already reorganized the folder.
    if not os.path.exists(images_folder) and not os.path.exists(annotations_file):
        if not os.listdir(path):
            raise RuntimeError("Validation folder is empty.")
        return

    # os.sync() only exists on Unix; elsewhere flushing is simply skipped.
    sync = getattr(os, "sync", None)

    # Parse the annotations and move every image into its label folder.
    with open(annotations_file) as f:
        for line in f:
            values = line.split()
            if len(values) < 2:
                # Blank or malformed line: nothing to move. An image that is
                # left behind because of it is reported by the check below.
                continue
            img, label = values[0], values[1]
            label_folder = os.path.join(path, label)
            os.makedirs(label_folder, exist_ok=True)
            try:
                shutil.move(os.path.join(images_folder, img), os.path.join(label_folder, img))
            except FileNotFoundError:
                # Already moved by an earlier, interrupted run.
                continue

    if sync is not None:
        sync()

    # The images folder may already be gone if an earlier run was interrupted
    # between removing it and removing the annotations file.
    if os.path.isdir(images_folder):
        # Explicit check instead of ``assert`` so that it still protects
        # unmoved images from ``rmtree`` when Python runs with ``-O``.
        leftovers = os.listdir(images_folder)
        if leftovers:
            raise RuntimeError(
                f"{len(leftovers)} file(s) in {images_folder!r} are not listed in "
                f"{annotations_file!r}; refusing to delete them."
            )
        shutil.rmtree(images_folder)

    os.remove(annotations_file)

    if sync is not None:
        sync()
class TinyImageNet(ImageFolder):
    """Dataset for TinyImageNet-200.

    The data is looked up in ``<root>/tiny-imagenet-200/<split>`` and that
    folder is handed to ``ImageFolder``.

    Args:
        root: Directory that contains (or will contain) the
            ``tiny-imagenet-200`` folder. ``~`` is expanded.
        split: Which split to load; checked against ``splits`` with
            ``verify_str_arg``.
        transform: Forwarded to ``ImageFolder``.
        download: If true, call :meth:`download` before looking for the split.
        **kwargs: Forwarded to ``ImageFolder``.

    Raises:
        RuntimeError: If the split folder does not exist (after the optional
            download).
    """

    # Folder created under ``root`` when the archive is extracted.
    base_folder = "tiny-imagenet-200"
    # MD5 checksum passed to ``download_and_extract_archive``.
    zip_md5 = "90528d7ca1a48142e341f4ef8d21d0de"
    # Accepted values for the ``split`` argument.
    splits = ("train", "val", "test")
    # File name under which the archive is stored in ``root``.
    filename = "tiny-imagenet-200.zip"
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        download: bool = False,
        **kwargs,
    ):
        self.data_root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", self.splits)

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        # NOTE(review): only the "val" folder is reorganized by download();
        # confirm that the "test" folder layout is usable by ImageFolder.
        super().__init__(self.split_folder, transform=transform, **kwargs)

    @property
    def dataset_folder(self) -> str:
        """Path of the extracted dataset: ``<root>/<base_folder>``."""
        return os.path.join(self.data_root, self.base_folder)

    @property
    def split_folder(self) -> str:
        """Path of the selected split inside :attr:`dataset_folder`."""
        return os.path.join(self.dataset_folder, self.split)

    def _check_exists(self) -> bool:
        """Return whether the folder of the selected split exists."""
        return os.path.exists(self.split_folder)

    def extra_repr(self) -> str:
        """Return the split name for inclusion in the dataset's repr."""
        return f"Split: {self.split}"

    def download(self) -> None:
        """Download and extract the archive, then reorganize the ``val`` folder.

        Nothing is downloaded when the folder of the selected split already
        exists. In that case the validation folder is still reorganized if it
        contains a flat ``images`` sub-folder, which means an earlier run
        extracted the archive but did not finish the reorganization.
        """
        val_folder = os.path.join(self.dataset_folder, "val")

        if not self._check_exists():
            download_and_extract_archive(self.url, self.data_root, filename=self.filename, md5=self.zip_md5)
            assert "val" in self.splits
            normalize_tin_val_folder_structure(val_folder)
            return

        # Already extracted. A leftover flat "images" folder would otherwise be
        # picked up as a class folder, so finish the interrupted
        # reorganization; the normalizer skips images that were already moved.
        if os.path.isdir(os.path.join(val_folder, "images")):
            normalize_tin_val_folder_structure(val_folder)