"""Source code for the ``detectors.methods.max_logits`` module (Max Logits OOD detector)."""

import torch
from torch import Tensor


@torch.no_grad()
def max_logits(input: Tensor, model: torch.nn.Module, **kwargs) -> Tensor:
    """Max Logits OOD detector.

    Runs ``model`` on ``input`` without gradient tracking. For each sample it
    returns the largest logit along dimension 1.

    Args:
        input (Tensor): batch of inputs fed to ``model``.
        model (torch.nn.Module): network producing logits. It is switched to
            evaluation mode via ``model.eval()`` and is left in that mode.
        **kwargs: unused by this detector.

    Returns:
        Tensor: OOD scores for each input, i.e. the maximum logit per sample
        (dimension 1 of the model output is reduced away).

    References:
        [1] https://arxiv.org/abs/1911.11132
    """
    # Disable train-time behaviour (dropout, batch-norm statistics updates)
    # before scoring; note this mutates the caller's model.
    model.eval()
    logits = model(input)
    # Assumes logits are (batch, num_classes) — TODO confirm; dim 1 is reduced.
    return logits.max(dim=1).values