# Source code for detectors.models.vit

"""Finetuned ViT models for CIFAR10, CIFAR100, and SVHN datasets."""
import timm
import timm.models
import torch
import torch.nn as nn
from timm.models import register_model as timm_register_model

from detectors.models.utils import ModelDefaultConfig, hf_hub_url_template


def _cfg(url, architecture: str, **kwargs):
    """Build the ``ModelDefaultConfig`` describing one fine-tuned variant.

    The configuration starts from the ``default_cfg`` of the timm model named
    by ``architecture`` and is then overridden with the supplied values.

    Args:
        url: Value stored under the ``"url"`` key of the configuration.
        architecture: timm model name; also stored under ``"architecture"``.
        **kwargs: Extra entries merged over the base configuration. The
            optional ``num_classes`` entry defaults to 10.

    Returns:
        A ``ModelDefaultConfig`` built from the merged configuration.
    """
    class_count = kwargs.pop("num_classes", 10)

    # An untrained model is instantiated solely to read its default config.
    cfg = timm.create_model(architecture, pretrained=False).default_cfg

    # Caller-supplied entries go in first; the three fixed keys below are
    # written afterwards so they always take precedence.
    cfg.update(kwargs)
    for key, value in (
        ("architecture", architecture),
        ("url", url),
        ("num_classes", class_count),
    ):
        cfg[key] = value

    return ModelDefaultConfig(**cfg)


# Registry of default configurations, keyed by variant name. Every variant
# shares the same timm architecture; only the CIFAR100 entry passes an explicit
# ``num_classes``, the others rely on the default applied by ``_cfg``.
default_cfgs = {
    variant_name: _cfg(
        url=hf_hub_url_template(variant_name),
        architecture="timm/vit_base_patch16_224.orig_in21k_ft_in1k",
        **overrides,
    )
    for variant_name, overrides in (
        ("vit_base_patch16_224_in21k_ft_cifar10", {}),
        ("vit_base_patch16_224_in21k_ft_cifar100", {"num_classes": 100}),
        ("vit_base_patch16_224_in21k_ft_svhn", {}),
    )
}


def _create_vit_ft(variant, pretrained=False, **kwargs):
    """Create the ViT model registered under ``variant``.

    Args:
        variant: Key into ``default_cfgs`` selecting the configuration.
        pretrained: When truthy, the state dict found at the configuration's
            ``"url"`` is downloaded and loaded into the model.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The timm model with its ``default_cfg`` replaced by the variant's
        configuration and its ``head`` replaced by a new linear layer with
        ``num_classes`` outputs.
    """
    cfg = default_cfgs[variant]

    # NOTE(review): the flag is inverted for the timm call, so timm's own
    # pretrained weights are requested only when the fine-tuned checkpoint is
    # NOT being loaded. This looks intentional -- confirm.
    use_timm_weights = not pretrained
    net = timm.create_model(cfg["architecture"], pretrained=use_timm_weights)
    net.default_cfg = cfg

    # Swap the classifier for one sized to this variant's label set.
    feature_dim = net.head.in_features
    net.head = nn.Linear(feature_dim, cfg["num_classes"])

    if not pretrained:
        return net

    state_dict = torch.hub.load_state_dict_from_url(
        cfg["url"],
        map_location="cpu",
        file_name=f"{variant}.pth",
    )
    net.load_state_dict(state_dict)
    return net


@timm_register_model
def vit_base_patch16_224_in21k_ft_cifar10(pretrained=False, **kwargs):
    """Create the ``vit_base_patch16_224_in21k_ft_cifar10`` model.

    Args:
        pretrained: Forwarded to ``_create_vit_ft``.
        **kwargs: Forwarded to ``_create_vit_ft``.

    Returns:
        The model returned by ``_create_vit_ft`` for this variant.
    """
    return _create_vit_ft("vit_base_patch16_224_in21k_ft_cifar10", pretrained=pretrained, **kwargs)
@timm_register_model
def vit_base_patch16_224_in21k_ft_cifar100(pretrained=False, **kwargs):
    """Create the ``vit_base_patch16_224_in21k_ft_cifar100`` model.

    Args:
        pretrained: Forwarded to ``_create_vit_ft``.
        **kwargs: Forwarded to ``_create_vit_ft``.

    Returns:
        The model returned by ``_create_vit_ft`` for this variant.
    """
    return _create_vit_ft("vit_base_patch16_224_in21k_ft_cifar100", pretrained=pretrained, **kwargs)
@timm_register_model
def vit_base_patch16_224_in21k_ft_svhn(pretrained=False, **kwargs):
    """Create the ``vit_base_patch16_224_in21k_ft_svhn`` model.

    Args:
        pretrained: Forwarded to ``_create_vit_ft``.
        **kwargs: Forwarded to ``_create_vit_ft``.

    Returns:
        The model returned by ``_create_vit_ft`` for this variant.
    """
    return _create_vit_ft("vit_base_patch16_224_in21k_ft_svhn", pretrained=pretrained, **kwargs)