# Source code for detectors.models.vgg

"""VGG models for CIFAR10, CIFAR100 and SVHN datasets."""
import timm
import timm.models
import torch
import torch.nn as nn
from timm.models import register_model as timm_register_model

from detectors.data import CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD
from detectors.data.constants import CIFAR100_DEFAULT_MEAN, CIFAR100_DEFAULT_STD, SVHN_DEFAULT_MEAN, SVHN_DEFAULT_STD
from detectors.models.utils import ModelDefaultConfig, hf_hub_url_template


def _cfg(url="", **kwargs):
    """Assemble a ``ModelDefaultConfig`` for a VGG variant taking 3x32x32 inputs.

    Args:
        url: Checkpoint location stored on the config (empty string by default).
        **kwargs: Optional ``num_classes``, ``mean`` and ``std`` overrides; when
            absent they fall back to 10 classes and the CIFAR-10 normalisation
            constants. Any remaining entries are handed to
            ``ModelDefaultConfig`` unchanged.

    Returns:
        The populated ``ModelDefaultConfig`` instance.
    """
    settings = {
        "url": url,
        "num_classes": kwargs.pop("num_classes", 10),
        "input_size": (3, 32, 32),
        "pool_size": (4, 4),
        "crop_pct": 1,
        "interpolation": "bilinear",
        "mean": kwargs.pop("mean", CIFAR10_DEFAULT_MEAN),
        "std": kwargs.pop("std", CIFAR10_DEFAULT_STD),
        "first_conv": "features.0",
        "classifier": "head.fc",
    }
    # Whatever is left in kwargs (e.g. ``architecture``) rides along untouched.
    return ModelDefaultConfig(**settings, **kwargs)


# Registry of per-variant configs. Every entry shares the "vgg16_bn"
# architecture and derives its checkpoint URL from its own key; only the
# dataset-specific overrides (class count, normalisation statistics) differ.
default_cfgs = {
    variant: _cfg(url=hf_hub_url_template(variant), architecture="vgg16_bn", **overrides)
    for variant, overrides in (
        ("vgg16_bn_cifar10", {}),
        (
            "vgg16_bn_cifar100",
            {"num_classes": 100, "mean": CIFAR100_DEFAULT_MEAN, "std": CIFAR100_DEFAULT_STD},
        ),
        ("vgg16_bn_svhn", {"mean": SVHN_DEFAULT_MEAN, "std": SVHN_DEFAULT_STD}),
    )
}


def _create_vgg_small(variant, features_dim=512, pretrained=False, **kwargs):
    """Build the timm model registered under ``variant`` with a replaced head.

    Args:
        variant: Key into ``default_cfgs``; also used as the stem of the
            cached checkpoint file name (``"<variant>.pth"``).
        features_dim: Input width of the new linear classifier.
        pretrained: When true, download the state dict from the config's URL
            and load it into the model.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The constructed model, with weights loaded if ``pretrained`` was set.
    """
    cfg = default_cfgs[variant]

    # Start from the stock timm architecture without timm's own weights.
    net = timm.create_model(cfg.architecture, pretrained=False)

    # Make the model advertise this project's config instead of timm's.
    net.default_cfg = cfg
    net.pretrained_cfg = cfg

    # Drop the pre-logits stage and attach a classifier sized for the dataset.
    net.pre_logits = nn.Identity()  # type: ignore
    net.head.fc = nn.Linear(features_dim, cfg.num_classes)

    if not pretrained:
        return net

    # Fetch on CPU so loading does not depend on an accelerator being present.
    state_dict = torch.hub.load_state_dict_from_url(
        cfg.url,
        progress=True,
        map_location="cpu",
        file_name=f"{variant}.pth",
    )
    net.load_state_dict(state_dict)
    return net


@timm_register_model
def vgg16_bn_cifar10(pretrained=False, **kwargs):
    """Create the ``vgg16_bn_cifar10`` model.

    Args:
        pretrained: Passed to ``_create_vgg_small``; when true the checkpoint
            referenced by the variant's default config is loaded.
        **kwargs: Forwarded to ``_create_vgg_small``.

    Returns:
        The model built by ``_create_vgg_small`` for this variant.
    """
    return _create_vgg_small("vgg16_bn_cifar10", pretrained=pretrained, **kwargs)
@timm_register_model
def vgg16_bn_cifar100(pretrained=False, **kwargs):
    """Create the ``vgg16_bn_cifar100`` model.

    Args:
        pretrained: Passed to ``_create_vgg_small``; when true the checkpoint
            referenced by the variant's default config is loaded.
        **kwargs: Forwarded to ``_create_vgg_small``.

    Returns:
        The model built by ``_create_vgg_small`` for this variant.
    """
    return _create_vgg_small("vgg16_bn_cifar100", pretrained=pretrained, **kwargs)
@timm_register_model
def vgg16_bn_svhn(pretrained=False, **kwargs):
    """Create the ``vgg16_bn_svhn`` model.

    Args:
        pretrained: Passed to ``_create_vgg_small``; when true the checkpoint
            referenced by the variant's default config is loaded.
        **kwargs: Forwarded to ``_create_vgg_small``.

    Returns:
        The model built by ``_create_vgg_small`` for this variant.
    """
    return _create_vgg_small("vgg16_bn_svhn", pretrained=pretrained, **kwargs)