"""ResNet models for CIFAR10, CIFAR100, and SVHN datasets."""
import logging
import timm
import timm.models
import torch
import torch.nn as nn
from timm.models import register_model as timm_register_model
from detectors.data import CIFAR10_DEFAULT_MEAN, CIFAR10_DEFAULT_STD
from detectors.data.constants import CIFAR100_DEFAULT_MEAN, CIFAR100_DEFAULT_STD, SVHN_DEFAULT_MEAN, SVHN_DEFAULT_STD
from detectors.models.utils import ModelDefaultConfig, hf_hub_url_template
# Module-level logger named after this module; nothing in the visible code calls it.
_logger = logging.getLogger(__name__)
def _cfg(url="", **kwargs):
    """Build a ``ModelDefaultConfig`` describing a 3x32x32-input ResNet variant.

    Args:
        url: Location of the pretrained weights; empty string by default.
        **kwargs: ``num_classes`` (default ``10``), ``mean`` and ``std``
            (defaulting to the CIFAR-10 constants) are extracted here.
            Whatever remains is forwarded unchanged to ``ModelDefaultConfig``.

    Returns:
        The populated ``ModelDefaultConfig`` instance.
    """
    # The overridable fields are popped out of ``kwargs`` so they are not
    # passed a second time when the remaining kwargs are forwarded below.
    settings = dict(
        url=url,
        num_classes=kwargs.pop("num_classes", 10),
        input_size=(3, 32, 32),
        pool_size=(4, 4),
        crop_pct=1,
        interpolation="bilinear",
        mean=kwargs.pop("mean", CIFAR10_DEFAULT_MEAN),
        std=kwargs.pop("std", CIFAR10_DEFAULT_STD),
        first_conv="conv1",
        classifier="fc",
    )
    return ModelDefaultConfig(**settings, **kwargs)
# Default configurations keyed by registered model name. In every entry the
# architecture is the leading token of the name (e.g. "resnet34"), and the
# weight URL is produced by ``hf_hub_url_template`` from the same name.
default_cfgs = {
    # CIFAR-10 classifiers: class count and normalisation come from `_cfg` defaults.
    **{
        name: _cfg(url=hf_hub_url_template(name), architecture=name.split("_")[0])
        for name in ("resnet18_cifar10", "resnet34_cifar10", "resnet50_cifar10")
    },
    # CIFAR-10 SimCLR / SupCon variants: `num_classes` is 512 or 2048, the same
    # values the entry points pass as `features_dim`.
    **{
        name: _cfg(url=hf_hub_url_template(name), architecture=name.split("_")[0], num_classes=width)
        for name, width in (
            ("resnet34_simclr_cifar10", 512),
            ("resnet50_simclr_cifar10", 2048),
            ("resnet34_supcon_cifar10", 512),
            ("resnet50_supcon_cifar10", 2048),
        )
    },
    # CIFAR-100 classifiers.
    **{
        name: _cfg(
            url=hf_hub_url_template(name),
            num_classes=100,
            mean=CIFAR100_DEFAULT_MEAN,
            std=CIFAR100_DEFAULT_STD,
            architecture=name.split("_")[0],
        )
        for name in ("resnet18_cifar100", "resnet34_cifar100", "resnet50_cifar100")
    },
    # CIFAR-100 SimCLR / SupCon variants.
    **{
        name: _cfg(
            url=hf_hub_url_template(name),
            architecture=name.split("_")[0],
            mean=CIFAR100_DEFAULT_MEAN,
            std=CIFAR100_DEFAULT_STD,
            num_classes=width,
        )
        for name, width in (
            ("resnet34_simclr_cifar100", 512),
            ("resnet50_simclr_cifar100", 2048),
            ("resnet34_supcon_cifar100", 512),
            ("resnet50_supcon_cifar100", 2048),
        )
    },
    # SVHN classifiers: default class count from `_cfg`, SVHN normalisation.
    **{
        name: _cfg(
            url=hf_hub_url_template(name),
            mean=SVHN_DEFAULT_MEAN,
            std=SVHN_DEFAULT_STD,
            architecture=name.split("_")[0],
        )
        for name in ("resnet18_svhn", "resnet34_svhn", "resnet50_svhn")
    },
}
def _create_resnet_small(variant, features_dim=512, pretrained=False, **kwargs):
    """Create a timm ResNet with a replaced stem and a new linear ``fc`` head.

    Args:
        variant: Key into ``default_cfgs`` selecting the model configuration.
        features_dim: Input width of the replacement ``fc`` layer.
        pretrained: When true, download the state dict from the configured
            URL (cached as ``<variant>.pth``) and load it into the model.
        **kwargs: Accepted for call compatibility; not used by this function.

    Returns:
        The modified timm model.

    Raises:
        KeyError: If ``variant`` is not a key of ``default_cfgs``.
    """
    cfg = default_cfgs[variant]

    # Fall back to the leading token of the variant name when the config
    # does not name an architecture.
    backbone = cfg.architecture or variant.split("_")[0]
    net = timm.create_model(backbone, pretrained=False)

    # Replace the config timm attached to the model with ours.
    net.default_cfg = cfg
    net.pretrained_cfg = cfg

    # Swap the stem for a 3x3 stride-1 convolution, disable the max-pool,
    # and size the classifier for this variant's number of classes.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    net.maxpool = nn.Identity()  # type: ignore
    net.fc = nn.Linear(features_dim, cfg.num_classes)

    if not pretrained:
        return net

    state_dict = torch.hub.load_state_dict_from_url(cfg.url, map_location="cpu", file_name=f"{variant}.pth")
    net.load_state_dict(state_dict)
    return net
def _create_resnet_small_ssl(variant, features_dim=512, pretrained=False, **kwargs):
    """Create a timm ResNet with ``num_classes=0`` and a replaced stem.

    Args:
        variant: Key into ``default_cfgs`` selecting the model configuration.
        features_dim: Accepted for signature parity with
            ``_create_resnet_small``; not used by this function.
        pretrained: When true, download the state dict from the configured
            URL (cached as ``<variant>.pth``) and load it into the model.
        **kwargs: Accepted for call compatibility; not used by this function.

    Returns:
        The modified timm model.

    Raises:
        KeyError: If ``variant`` is not a key of ``default_cfgs``.
    """
    cfg = default_cfgs[variant]

    # Fall back to the leading token of the variant name when the config
    # does not name an architecture.
    backbone = cfg.architecture or variant.split("_")[0]
    net = timm.create_model(backbone, pretrained=False, num_classes=0)

    # Replace the config timm attached to the model with ours.
    net.default_cfg = cfg
    net.pretrained_cfg = cfg

    # Swap the stem for a 3x3 stride-1 convolution and disable the max-pool.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    net.maxpool = nn.Identity()  # type: ignore

    if not pretrained:
        return net

    state_dict = torch.hub.load_state_dict_from_url(cfg.url, map_location="cpu", file_name=f"{variant}.pth")
    net.load_state_dict(state_dict)
    return net
@timm_register_model
def resnet18_cifar10(pretrained=False, **kwargs):
    """Build the ``resnet18_cifar10`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=512``.
    """
    return _create_resnet_small("resnet18_cifar10", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet34_cifar10(pretrained=False, **kwargs):
    """Build the ``resnet34_cifar10`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=512``.
    """
    return _create_resnet_small("resnet34_cifar10", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet50_cifar10(pretrained=False, **kwargs):
    """Build the ``resnet50_cifar10`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=2048``.
    """
    return _create_resnet_small("resnet50_cifar10", features_dim=2048, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet34_simclr_cifar10(pretrained=False, **kwargs):
    """Build the ``resnet34_simclr_cifar10`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=512``).
    """
    return _create_resnet_small_ssl("resnet34_simclr_cifar10", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet50_simclr_cifar10(pretrained=False, **kwargs):
    """Build the ``resnet50_simclr_cifar10`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=2048``).
    """
    return _create_resnet_small_ssl("resnet50_simclr_cifar10", features_dim=2048, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet34_supcon_cifar10(pretrained=False, **kwargs):
    """Build the ``resnet34_supcon_cifar10`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=512``).
    """
    return _create_resnet_small_ssl("resnet34_supcon_cifar10", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet50_supcon_cifar10(pretrained=False, **kwargs):
    """Build the ``resnet50_supcon_cifar10`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=2048``).
    """
    return _create_resnet_small_ssl("resnet50_supcon_cifar10", features_dim=2048, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet18_cifar100(pretrained=False, **kwargs):
    """Build the ``resnet18_cifar100`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=512``.
    """
    return _create_resnet_small("resnet18_cifar100", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet34_cifar100(pretrained=False, **kwargs):
    """Build the ``resnet34_cifar100`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=512``.
    """
    return _create_resnet_small("resnet34_cifar100", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet50_cifar100(pretrained=False, **kwargs):
    """Build the ``resnet50_cifar100`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=2048``.
    """
    return _create_resnet_small("resnet50_cifar100", features_dim=2048, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet34_simclr_cifar100(pretrained=False, **kwargs):
    """Build the ``resnet34_simclr_cifar100`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=512``).
    """
    return _create_resnet_small_ssl("resnet34_simclr_cifar100", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet50_simclr_cifar100(pretrained=False, **kwargs):
    """Build the ``resnet50_simclr_cifar100`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=2048``).
    """
    return _create_resnet_small_ssl("resnet50_simclr_cifar100", features_dim=2048, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet34_supcon_cifar100(pretrained=False, **kwargs):
    """Build the ``resnet34_supcon_cifar100`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=512``).
    """
    return _create_resnet_small_ssl("resnet34_supcon_cifar100", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet50_supcon_cifar100(pretrained=False, **kwargs):
    """Build the ``resnet50_supcon_cifar100`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small_ssl``.
        **kwargs: Forwarded to ``_create_resnet_small_ssl``.

    Returns:
        The model produced by ``_create_resnet_small_ssl`` (called with ``features_dim=2048``).
    """
    return _create_resnet_small_ssl("resnet50_supcon_cifar100", features_dim=2048, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet18_svhn(pretrained=False, **kwargs):
    """Build the ``resnet18_svhn`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=512``.
    """
    return _create_resnet_small("resnet18_svhn", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet34_svhn(pretrained=False, **kwargs):
    """Build the ``resnet34_svhn`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=512``.
    """
    return _create_resnet_small("resnet34_svhn", features_dim=512, pretrained=pretrained, **kwargs)
@timm_register_model
def resnet50_svhn(pretrained=False, **kwargs):
    """Build the ``resnet50_svhn`` variant (registered via ``timm_register_model``).

    Args:
        pretrained: Forwarded to ``_create_resnet_small``.
        **kwargs: Forwarded to ``_create_resnet_small``.

    Returns:
        The model produced by ``_create_resnet_small`` with ``features_dim=2048``.
    """
    return _create_resnet_small("resnet50_svhn", features_dim=2048, pretrained=pretrained, **kwargs)