Source code for detectors.methods.mcdropout

import torch
from torch import Tensor, nn


@torch.no_grad()
def mcdropout(x: Tensor, model: nn.Module, k: int = 5, **kwargs) -> Tensor:
    """MC Dropout.

    Forward-propagates the input through the model ``k`` times with the
    dropout layers activated, averages the resulting softmax distributions
    and returns the largest averaged probability for each sample.

    If the model contains no dropout layer, a single deterministic forward
    pass is performed instead and ``k`` is ignored.

    Args:
        x (Tensor): input tensor.
        model (nn.Module): classifier. Its output is soft-maxed over
            ``dim=1``.
        k (int, optional): number of stochastic forward passes. Defaults to 5.
        **kwargs: accepted and ignored.

    Returns:
        Tensor: maximum over ``dim=1`` of the (averaged) softmax output.

    Raises:
        ValueError: if the model has dropout layers and ``k`` is less than 1.

    Note:
        The model is put in eval mode. Dropout layers are switched to train
        mode only for the duration of the stochastic passes and are set back
        to eval mode before returning.

    References:
        [1] http://proceedings.mlr.press/v48/gal16.pdf
    """
    model.eval()
    dropout_layers = [
        m
        for m in model.modules()
        if isinstance(m, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d))
    ]

    if not dropout_layers:
        # Nothing stochastic in the model: one pass is enough.
        return torch.softmax(model(x), dim=1).max(dim=1)[0]

    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    # Only the dropout layers are activated; every other module stays in
    # eval mode (set by model.eval() above).
    for layer in dropout_layers:
        layer.train()
    try:
        # The first pass initialises the accumulator and the remaining k - 1
        # passes are added to it, so each pass is counted exactly once.
        summed = torch.softmax(model(x), dim=1)
        for _ in range(k - 1):
            summed = summed + torch.softmax(model(x), dim=1)
    finally:
        # Do not leave the caller's model with active dropout.
        for layer in dropout_layers:
            layer.eval()

    mean_probs = summed / k
    return mean_probs.max(dim=1)[0]